Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support demo with hf #179

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
332 changes: 332 additions & 0 deletions examples/internlm2_agent_web_demo_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
import copy
import hashlib
import json
import os

import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode

# from streamlit.logger import get_logger


class SessionState:

def init_state(self):
"""Initialize session state variables."""
st.session_state['assistant'] = []
st.session_state['user'] = []

action_list = [
ArxivSearch(),
]
st.session_state['plugin_map'] = {
action.name: action
for action in action_list
}
st.session_state['model_map'] = {}
st.session_state['model_selected'] = None
st.session_state['plugin_actions'] = set()
st.session_state['history'] = []

def clear_state(self):
"""Clear the existing session state."""
st.session_state['assistant'] = []
st.session_state['user'] = []
st.session_state['model_selected'] = None
st.session_state['file'] = set()
if 'chatbot' in st.session_state:
st.session_state['chatbot']._session_history = []


class StreamlitUI:

def __init__(self, session_state: SessionState):
self.init_streamlit()
self.session_state = session_state

def init_streamlit(self):
"""Initialize Streamlit's UI settings."""
st.set_page_config(
layout='wide',
page_title='lagent-web',
page_icon='./docs/imgs/lagent_icon.png')
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
st.sidebar.title('模型控制')
st.session_state['file'] = set()
st.session_state['model_path'] = None

def setup_sidebar(self):
"""Setup the sidebar for model and plugin selection."""
# model_name = st.sidebar.selectbox('模型选择:', options=['internlm'])
model_name = st.sidebar.text_input('模型名称:', value='internlm2-chat-7b')
meta_prompt = st.sidebar.text_area('系统提示词', value=META_CN)
da_prompt = st.sidebar.text_area('数据分析提示词', value=INTERPRETER_CN)
plugin_prompt = st.sidebar.text_area('插件提示词', value=PLUGIN_CN)
model_path = st.sidebar.text_input(
'模型路径:', value='internlm/internlm2-chat-20b')
if model_name != st.session_state['model_selected'] or st.session_state[
'model_path'] != model_path:
st.session_state['model_path'] = model_path
model = self.init_model(model_name, model_path)
self.session_state.clear_state()
st.session_state['model_selected'] = model_name
if 'chatbot' in st.session_state:
del st.session_state['chatbot']
else:
model = st.session_state['model_map'][model_name]

plugin_name = st.sidebar.multiselect(
'插件选择',
options=list(st.session_state['plugin_map'].keys()),
default=[],
)
da_flag = st.sidebar.checkbox(
'数据分析',
value=False,
)
plugin_action = [
st.session_state['plugin_map'][name] for name in plugin_name
]

if 'chatbot' in st.session_state:
if len(plugin_action) > 0:
st.session_state['chatbot']._action_executor = ActionExecutor(
actions=plugin_action)
else:
st.session_state['chatbot']._action_executor = None
if da_flag:
st.session_state[
'chatbot']._interpreter_executor = ActionExecutor(
actions=[IPythonInterpreter()])
else:
st.session_state['chatbot']._interpreter_executor = None
st.session_state['chatbot']._protocol._meta_template = meta_prompt
st.session_state['chatbot']._protocol.plugin_prompt = plugin_prompt
st.session_state[
'chatbot']._protocol.interpreter_prompt = da_prompt
if st.sidebar.button('清空对话', key='clear'):
self.session_state.clear_state()
uploaded_file = st.sidebar.file_uploader('上传文件')

return model_name, model, plugin_action, uploaded_file, model_path

def init_model(self, model_name, path):
"""Initialize the model based on the input model name."""
st.session_state['model_map'][model_name] = HFTransformer(
path=path,
meta_template=META,
max_new_tokens=1024,
top_p=0.8,
top_k=None,
temperature=0.1,
repetition_penalty=1.0,
stop_words=['<|im_end|>'])
return st.session_state['model_map'][model_name]

def initialize_chatbot(self, model, plugin_action):
"""Initialize the chatbot with the given model and plugin actions."""
return Internlm2Agent(
llm=model,
protocol=Internlm2Protocol(
tool=dict(
begin='{start_token}{name}\n',
start_token='<|action_start|>',
name_map=dict(
plugin='<|plugin|>', interpreter='<|interpreter|>'),
belong='assistant',
end='<|action_end|>\n',
), ),
max_turn=7)

def render_user(self, prompt: str):
with st.chat_message('user'):
st.markdown(prompt)

def render_assistant(self, agent_return):
with st.chat_message('assistant'):
for action in agent_return.actions:
if (action) and (action.type != 'FinishAction'):
self.render_action(action)
st.markdown(agent_return.response)

def render_plugin_args(self, action):
action_name = action.type
args = action.args
import json
parameter_dict = dict(name=action_name, parameters=args)
parameter_str = '```json\n' + json.dumps(
parameter_dict, indent=4, ensure_ascii=False) + '\n```'
st.markdown(parameter_str)

def render_interpreter_args(self, action):
st.info(action.type)
st.markdown(action.args['text'])

def render_action(self, action):
st.markdown(action.thought)
if action.type == 'IPythonInterpreter':
self.render_interpreter_args(action)
elif action.type == 'FinishAction':
pass
else:
self.render_plugin_args(action)
self.render_action_results(action)

def render_action_results(self, action):
"""Render the results of action, including text, images, videos, and
audios."""
if (isinstance(action.result, dict)):
if 'text' in action.result:
st.markdown('```\n' + action.result['text'] + '\n```')
if 'image' in action.result:
# image_path = action.result['image']
for image_path in action.result['image']:
image_data = open(image_path, 'rb').read()
st.image(image_data, caption='Generated Image')
if 'video' in action.result:
video_data = action.result['video']
video_data = open(video_data, 'rb').read()
st.video(video_data)
if 'audio' in action.result:
audio_data = action.result['audio']
audio_data = open(audio_data, 'rb').read()
st.audio(audio_data)
elif isinstance(action.result, list):
for item in action.result:
if item['type'] == 'text':
st.markdown('```\n' + item['content'] + '\n```')
elif item['type'] == 'image':
image_data = open(item['content'], 'rb').read()
st.image(image_data, caption='Generated Image')
elif item['type'] == 'video':
video_data = open(item['content'], 'rb').read()
st.video(video_data)
elif item['type'] == 'audio':
audio_data = open(item['content'], 'rb').read()
st.audio(audio_data)
if action.errmsg:
st.error(action.errmsg)


def main():
# logger = get_logger(__name__)
# Initialize Streamlit UI and setup sidebar
if 'ui' not in st.session_state:
session_state = SessionState()
session_state.init_state()
st.session_state['ui'] = StreamlitUI(session_state)

else:
st.set_page_config(
layout='wide',
page_title='lagent-web',
page_icon='./docs/imgs/lagent_icon.png')
st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
_, model, plugin_action, uploaded_file, _ = st.session_state[
'ui'].setup_sidebar()

# Initialize chatbot if it is not already initialized
# or if the model has changed
if 'chatbot' not in st.session_state or model != st.session_state[
'chatbot']._llm:
st.session_state['chatbot'] = st.session_state[
'ui'].initialize_chatbot(model, plugin_action)
st.session_state['session_history'] = []

for prompt, agent_return in zip(st.session_state['user'],
st.session_state['assistant']):
st.session_state['ui'].render_user(prompt)
st.session_state['ui'].render_assistant(agent_return)

if user_input := st.chat_input(''):
with st.container():
st.session_state['ui'].render_user(user_input)
st.session_state['user'].append(user_input)
# Add file uploader to sidebar
if (uploaded_file
and uploaded_file.name not in st.session_state['file']):

st.session_state['file'].add(uploaded_file.name)
file_bytes = uploaded_file.read()
file_type = uploaded_file.type
if 'image' in file_type:
st.image(file_bytes, caption='Uploaded Image')
elif 'video' in file_type:
st.video(file_bytes, caption='Uploaded Video')
elif 'audio' in file_type:
st.audio(file_bytes, caption='Uploaded Audio')
# Save the file to a temporary location and get the path

postfix = uploaded_file.name.split('.')[-1]
# prefix = str(uuid.uuid4())
prefix = hashlib.md5(file_bytes).hexdigest()
filename = f'{prefix}.{postfix}'
file_path = os.path.join(root_dir, filename)
with open(file_path, 'wb') as tmpfile:
tmpfile.write(file_bytes)
file_size = os.stat(file_path).st_size / 1024 / 1024
file_size = f'{round(file_size, 2)} MB'
# st.write(f'File saved at: {file_path}')
user_input = [
dict(role='user', content=user_input),
dict(
role='user',
content=json.dumps(dict(path=file_path, size=file_size)),
name='file')
]
if isinstance(user_input, str):
user_input = [dict(role='user', content=user_input)]
st.session_state['last_status'] = AgentStatusCode.SESSION_READY
for agent_return in st.session_state['chatbot'].stream_chat(
st.session_state['session_history'] + user_input):
if agent_return.state == AgentStatusCode.PLUGIN_RETURN:
with st.container():
st.session_state['ui'].render_plugin_args(
agent_return.actions[-1])
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif agent_return.state == AgentStatusCode.CODE_RETURN:
with st.container():
st.session_state['ui'].render_action_results(
agent_return.actions[-1])
elif (agent_return.state == AgentStatusCode.STREAM_ING
or agent_return.state == AgentStatusCode.CODING):
# st.markdown(agent_return.response)
# 清除占位符的当前内容,并显示新内容
with st.container():
if agent_return.state != st.session_state['last_status']:
st.session_state['temp'] = ''
placeholder = st.empty()
st.session_state['placeholder'] = placeholder
if isinstance(agent_return.response, dict):
action = f"\n\n {agent_return.response['name']}: \n\n"
action_input = agent_return.response['parameters']
if agent_return.response[
'name'] == 'IPythonInterpreter':
action_input = action_input['command']
response = action + action_input
else:
response = agent_return.response
st.session_state['temp'] = response
st.session_state['placeholder'].markdown(
st.session_state['temp'])
elif agent_return.state == AgentStatusCode.END:
st.session_state['session_history'] += (
user_input + agent_return.inner_steps)
agent_return = copy.deepcopy(agent_return)
agent_return.response = st.session_state['temp']
st.session_state['assistant'].append(
copy.deepcopy(agent_return))
st.session_state['last_status'] = agent_return.state


if __name__ == '__main__':
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
root_dir = os.path.join(root_dir, 'tmp_dir')
os.makedirs(root_dir, exist_ok=True)
main()