From c89729620f4b012aac56bcb413505b844a0569de Mon Sep 17 00:00:00 2001 From: liujiangning30 <147385819+liujiangning30@users.noreply.github.com> Date: Tue, 23 Jan 2024 10:57:58 +0800 Subject: [PATCH] update ReAct example for internlm2 (#85) * update ReAct example for internlm2 * update ReAct example for internlm2 * update base_llm * rename file * update readme * update meta_template --- README.md | 20 +++++++++---- examples/hf_react_example.py | 56 +++++++++++++++--------------------- lagent/llms/huggingface.py | 8 ++++-- lagent/llms/meta_template.py | 40 ++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 41 deletions(-) create mode 100644 lagent/llms/meta_template.py diff --git a/README.md b/README.md index 60833af0..6f679853 100644 --- a/README.md +++ b/README.md @@ -121,24 +121,34 @@ from lagent.agents import ReAct from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter from lagent.llms import HFTransformer -# Initialize the HFTransformer-based Language Model (llm) and provide the model name. -llm = HFTransformer('internlm/internlm-chat-7b-v1_1') +from lagent.llms.meta_template import INTERNLM2_META as META + +# Initialize the HFTransformer-based Language Model (llm) and +# provide the model name. +llm = HFTransformer( + path='internlm/internlm2-chat-7b', + meta_template=META +) # Initialize the Google Search tool and provide your API key. -search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') +search_tool = GoogleSearch( + api_key='Your SERPER_API_KEY') # Initialize the Python Interpreter tool. python_interpreter = PythonInterpreter() # Create a chatbot by configuring the ReAct agent. +# Specify the actions the chatbot can perform. chatbot = ReAct( llm=llm, # Provide the Language Model instance. action_executor=ActionExecutor( - actions=[search_tool, python_interpreter] # Specify the actions the chatbot can perform. + actions=[python_interpreter] ), ) # Ask the chatbot a mathematical question in LaTeX format. -response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$') +response = chatbot.chat( + '若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$' +) # Print the chatbot's response. print(response.response) # Output the response generated by the chatbot. diff --git a/examples/hf_react_example.py b/examples/hf_react_example.py index 268ef363..69954c44 100644 --- a/examples/hf_react_example.py +++ b/examples/hf_react_example.py @@ -1,37 +1,27 @@ -from lagent.actions.action_executor import ActionExecutor -from lagent.actions.python_interpreter import PythonInterpreter -from lagent.agents.react import ReAct -from lagent.llms.huggingface import HFTransformer +# Import necessary modules and classes from the 'lagent' library. +from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter +from lagent.agents import ReAct +from lagent.llms import HFTransformer +from lagent.llms.meta_template import INTERNLM2_META as META -model = HFTransformer( - path='internlm/internlm-chat-7b-v1_1', - meta_template=[ - dict(role='system', begin='<|System|>:', end='<TOKENS_UNUSED_2>\n'), - dict(role='user', begin='<|User|>:', end='<eoh>\n'), - dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True) - ], -) - -chatbot = ReAct( - llm=model, - action_executor=ActionExecutor(actions=[PythonInterpreter()]), -) +# Initialize the HFTransformer-based Language Model (llm) and +# provide the model name. +llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META) +# Initialize the Google Search tool and provide your API key. +search_tool = GoogleSearch(api_key='Your SERPER_API_KEY') -def input_prompt(): - print('\ndouble enter to end input >>> ', end='') - sentinel = '' # ends when this string is seen - return '\n'.join(iter(input, sentinel)) +# Initialize the Python Interpreter tool. +python_interpreter = PythonInterpreter() - -while True: - try: - prompt = input_prompt() - except UnicodeDecodeError: - print('UnicodeDecodeError') - continue - if prompt == 'exit': - exit(0) - - agent_return = chatbot.chat(prompt) - print(agent_return.response) +# Create a chatbot by configuring the ReAct agent. +# Specify the actions the chatbot can perform. +chatbot = ReAct( + llm=llm, # Provide the Language Model instance. + action_executor=ActionExecutor(actions=[python_interpreter]), +) +# Ask the chatbot a mathematical question in LaTeX format. +response = chatbot.chat( + '若$z=-1+\\sqrt{3}i$,则$\frac{z}{{z\\overline{z}-1}}=\\left(\\ \\ \right)$') +# Print the chatbot's response. +print(response.response) # Output the response generated by the chatbot. diff --git a/lagent/llms/huggingface.py b/lagent/llms/huggingface.py index beba7430..8f991015 100644 --- a/lagent/llms/huggingface.py +++ b/lagent/llms/huggingface.py @@ -123,9 +123,11 @@ def generate_from_template(self, templates, max_out_len: int, **kwargs): """ inputs = self.parse_template(templates) response = self.generate(inputs, max_out_len=max_out_len, **kwargs) - return response.replace( - self.template_parser.roles['assistant']['end'].strip(), - '').strip() + end_token = self.template_parser.meta_template[0]['end'].strip() + # return response.replace( + # self.template_parser.roles['assistant']['end'].strip(), + # '').strip() + return response.split(end_token.strip())[0] class HFTransformerCasualLM(HFTransformer): diff --git a/lagent/llms/meta_template.py b/lagent/llms/meta_template.py new file mode 100644 index 00000000..9b4ed978 --- /dev/null +++ b/lagent/llms/meta_template.py @@ -0,0 +1,40 @@ +INTERNLM2_META = [ + dict( + role='system', + begin=dict( + with_name='<|im_start|>system name={name}\n', + without_name='<|im_start|>system\n', + name={ + 'interpreter': '<|interpreter|>', + 'plugin': '<|plugin|>', + }), + end='<|im_end|>\n', + ), + dict( + role='user', + begin=dict( + with_name='<|im_start|>user name={name}\n', + without_name='<|im_start|>user\n', + ), + end='<|im_end|>\n'), + dict( + role='assistant', + begin=dict( + with_name='<|im_start|>assistant name={name}\n', + without_name='<|im_start|>assistant\n', + name={ + 'interpreter': '<|interpreter|>', + 'plugin': '<|plugin|>', + }), + end='<|im_end|>\n'), + dict( + role='environment', + begin=dict( + with_name='<|im_start|>environment name={name}\n', + without_name='<|im_start|>environment\n', + name={ + 'interpreter': '<|interpreter|>', + 'plugin': '<|plugin|>', + }), + end='<|im_end|>\n'), +]