From c89729620f4b012aac56bcb413505b844a0569de Mon Sep 17 00:00:00 2001
From: liujiangning30 <147385819+liujiangning30@users.noreply.github.com>
Date: Tue, 23 Jan 2024 10:57:58 +0800
Subject: [PATCH] update ReAct example for internlm2 (#85)

* update ReAct example for internlm2

* update ReAct example for internlm2

* update base_llm

* rename file

* update readme

* update meta_template
---
 README.md                    | 20 +++++++++----
 examples/hf_react_example.py | 56 +++++++++++++++---------------------
 lagent/llms/huggingface.py   |  8 ++++--
 lagent/llms/meta_template.py | 40 ++++++++++++++++++++++++++
 4 files changed, 83 insertions(+), 41 deletions(-)
 create mode 100644 lagent/llms/meta_template.py

diff --git a/README.md b/README.md
index 60833af0..6f679853 100644
--- a/README.md
+++ b/README.md
@@ -121,24 +121,34 @@ from lagent.agents import ReAct
 from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
 from lagent.llms import HFTransformer
 
-# Initialize the HFTransformer-based Language Model (llm) and provide the model name.
-llm = HFTransformer('internlm/internlm-chat-7b-v1_1')
+from lagent.llms.meta_template import INTERNLM2_META as META
+
+# Initialize the HFTransformer-based Language Model (llm) and
+# provide the model name.
+llm = HFTransformer(
+    path='internlm/internlm2-chat-7b',
+    meta_template=META
+)
 
 # Initialize the Google Search tool and provide your API key.
-search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')
+search_tool = GoogleSearch(
+    api_key='Your SERPER_API_KEY')
 
 # Initialize the Python Interpreter tool.
 python_interpreter = PythonInterpreter()
 
 # Create a chatbot by configuring the ReAct agent.
+# Specify the actions the chatbot can perform.
 chatbot = ReAct(
     llm=llm,  # Provide the Language Model instance.
     action_executor=ActionExecutor(
-        actions=[search_tool, python_interpreter]  # Specify the actions the chatbot can perform.
+        actions=[python_interpreter]
     ),
 )
 # Ask the chatbot a mathematical question in LaTeX format.
-response = chatbot.chat('若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$')
+response = chatbot.chat(
+    r'若$z=-1+\sqrt{3}i$,则$\frac{z}{{z\overline{z}-1}}=\left(\ \ \right)$'
+)
 
 # Print the chatbot's response.
 print(response.response)  # Output the response generated by the chatbot.
diff --git a/examples/hf_react_example.py b/examples/hf_react_example.py
index 268ef363..69954c44 100644
--- a/examples/hf_react_example.py
+++ b/examples/hf_react_example.py
@@ -1,37 +1,27 @@
-from lagent.actions.action_executor import ActionExecutor
-from lagent.actions.python_interpreter import PythonInterpreter
-from lagent.agents.react import ReAct
-from lagent.llms.huggingface import HFTransformer
+# Import necessary modules and classes from the 'lagent' library.
+from lagent.actions import ActionExecutor, GoogleSearch, PythonInterpreter
+from lagent.agents import ReAct
+from lagent.llms import HFTransformer
+from lagent.llms.meta_template import INTERNLM2_META as META
 
-model = HFTransformer(
-    path='internlm/internlm-chat-7b-v1_1',
-    meta_template=[
-        dict(role='system', begin='<|System|>:', end='<TOKENS_UNUSED_2>\n'),
-        dict(role='user', begin='<|User|>:', end='<eoh>\n'),
-        dict(role='assistant', begin='<|Bot|>:', end='<eoa>\n', generate=True)
-    ],
-)
-
-chatbot = ReAct(
-    llm=model,
-    action_executor=ActionExecutor(actions=[PythonInterpreter()]),
-)
+# Initialize the HFTransformer-based Language Model (llm) and
+# provide the model name.
+llm = HFTransformer(path='internlm/internlm2-chat-7b', meta_template=META)
 
+# Initialize the Google Search tool and provide your API key.
+search_tool = GoogleSearch(api_key='Your SERPER_API_KEY')
 
-def input_prompt():
-    print('\ndouble enter to end input >>> ', end='')
-    sentinel = ''  # ends when this string is seen
-    return '\n'.join(iter(input, sentinel))
+# Initialize the Python Interpreter tool.
+python_interpreter = PythonInterpreter()
 
-
-while True:
-    try:
-        prompt = input_prompt()
-    except UnicodeDecodeError:
-        print('UnicodeDecodeError')
-        continue
-    if prompt == 'exit':
-        exit(0)
-
-    agent_return = chatbot.chat(prompt)
-    print(agent_return.response)
+# Create a chatbot by configuring the ReAct agent.
+# Specify the actions the chatbot can perform.
+chatbot = ReAct(
+    llm=llm,  # Provide the Language Model instance.
+    action_executor=ActionExecutor(actions=[python_interpreter]),
+)
+# Ask the chatbot a mathematical question in LaTeX format.
+response = chatbot.chat(
+    '若$z=-1+\\sqrt{3}i$,则$\\frac{z}{{z\\overline{z}-1}}=\\left(\\ \\ \\right)$')
+# Print the chatbot's response.
+print(response.response)  # Output the response generated by the chatbot.
diff --git a/lagent/llms/huggingface.py b/lagent/llms/huggingface.py
index beba7430..8f991015 100644
--- a/lagent/llms/huggingface.py
+++ b/lagent/llms/huggingface.py
@@ -123,9 +123,11 @@ def generate_from_template(self, templates, max_out_len: int, **kwargs):
         """
         inputs = self.parse_template(templates)
         response = self.generate(inputs, max_out_len=max_out_len, **kwargs)
-        return response.replace(
-            self.template_parser.roles['assistant']['end'].strip(),
-            '').strip()
+        end_token = self.template_parser.meta_template[0]['end'].strip()
+        # return response.replace(
+        #     self.template_parser.roles['assistant']['end'].strip(),
+        #     '').strip()
+        return response.split(end_token.strip())[0]
 
 
 class HFTransformerCasualLM(HFTransformer):
diff --git a/lagent/llms/meta_template.py b/lagent/llms/meta_template.py
new file mode 100644
index 00000000..9b4ed978
--- /dev/null
+++ b/lagent/llms/meta_template.py
@@ -0,0 +1,40 @@
+INTERNLM2_META = [
+    dict(
+        role='system',
+        begin=dict(
+            with_name='<|im_start|>system name={name}\n',
+            without_name='<|im_start|>system\n',
+            name={
+                'interpreter': '<|interpreter|>',
+                'plugin': '<|plugin|>',
+            }),
+        end='<|im_end|>\n',
+    ),
+    dict(
+        role='user',
+        begin=dict(
+            with_name='<|im_start|>user name={name}\n',
+            without_name='<|im_start|>user\n',
+        ),
+        end='<|im_end|>\n'),
+    dict(
+        role='assistant',
+        begin=dict(
+            with_name='<|im_start|>assistant name={name}\n',
+            without_name='<|im_start|>assistant\n',
+            name={
+                'interpreter': '<|interpreter|>',
+                'plugin': '<|plugin|>',
+            }),
+        end='<|im_end|>\n'),
+    dict(
+        role='environment',
+        begin=dict(
+            with_name='<|im_start|>environment name={name}\n',
+            without_name='<|im_start|>environment\n',
+            name={
+                'interpreter': '<|interpreter|>',
+                'plugin': '<|plugin|>',
+            }),
+        end='<|im_end|>\n'),
+]