diff --git a/chatgpt_tool_hub/apps/app.py b/chatgpt_tool_hub/apps/app.py index 30cb3e3..2d4b7fb 100644 --- a/chatgpt_tool_hub/apps/app.py +++ b/chatgpt_tool_hub/apps/app.py @@ -2,7 +2,7 @@ from typing import List -from chatgpt_tool_hub.chains.base import Chain +from chatgpt_tool_hub.engine.tool_engine import ToolEngine from chatgpt_tool_hub.common.log import LOG from chatgpt_tool_hub.tools.base_tool import BaseTool @@ -11,7 +11,7 @@ class App: _instance = None # 存储单例实例 init_flag = False # 记录是否执行过初始化动作 - engine: Chain = None + engine: ToolEngine = None # 当前已加载工具 tools: set = set() @@ -76,4 +76,4 @@ def _check_mandatory_tools(self, use_tools: list) -> bool: return True def get_tool_list(self) -> List[str]: - return list(self.tools) + return list(self.engine.get_tool_list()) diff --git a/chatgpt_tool_hub/apps/app_factory.py b/chatgpt_tool_hub/apps/app_factory.py index f556d6e..308175b 100644 --- a/chatgpt_tool_hub/apps/app_factory.py +++ b/chatgpt_tool_hub/apps/app_factory.py @@ -1,4 +1,5 @@ import logging +import os from rich.console import Console @@ -33,10 +34,16 @@ def init_env(self, **kwargs): else: self.default_tools_list = ["python", "terminal", "url-get", "meteo-weather"] + # set proxy + # _proxy = get_from_dict_or_env(kwargs, "proxy", "PROXY", "") + # if not _proxy: + # os.environ["http_proxy"] = str(_proxy) + # os.environ["https_proxy"] = str(_proxy) + # dynamic loading tool dynamic_tool_loader() - def create_app(self, app_type: str = 'victorinox', tools_list: list = None, **kwargs) -> App: + def create_app(self, app_type: str = 'victorinox', tools_list: list = None, console=Console(quiet=True), **kwargs) -> App: tools_list = tools_list if tools_list else [] self.init_env(**kwargs) @@ -58,7 +65,7 @@ def create_app(self, app_type: str = 'victorinox', tools_list: list = None, **kw if "browser" in tools_list: tools_list = list(filter(lambda tool: tool != "url-get", tools_list)) - app = Victorinox(**build_model_params(kwargs)) + app = Victorinox(console, **build_model_params(kwargs)) app.create(tools_list, **kwargs) return app else: diff --git a/chatgpt_tool_hub/bots/chat_bot/base.py b/chatgpt_tool_hub/bots/chat_bot/base.py index e734645..2b715c3 100644 --- a/chatgpt_tool_hub/bots/chat_bot/base.py +++ b/chatgpt_tool_hub/bots/chat_bot/base.py @@ -173,7 +173,7 @@ def _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]: if action.lower() != "answer-user": self.console.print(f"√ 我在用 [bold cyan]{action}[/] 工具...") - # todo + LOG.info(f"执行Tool: {action}中...") return action.strip(), action_input.strip() @@ -223,7 +223,8 @@ def parse_reply_json(self, assistant_reply) -> dict: title=f"{self.ai_prefix.upper()}的内心独白", highlight=True, style='dim')) # it's useful for avoid splitting Panel - self.console.print("\n") + LOG.info(f"{self.ai_prefix.upper()}的内心独白: {thoughts_text}") + return assistant_reply_json except json.decoder.JSONDecodeError as e: call_stack = traceback.format_exc() diff --git a/chatgpt_tool_hub/bots/chat_bot/prompt.py b/chatgpt_tool_hub/bots/chat_bot/prompt.py index 0beebed..81d38a1 100644 --- a/chatgpt_tool_hub/bots/chat_bot/prompt.py +++ b/chatgpt_tool_hub/bots/chat_bot/prompt.py @@ -27,7 +27,7 @@ "speak": "thoughts summary to say to {human_prefix}", }}}}, "tool": {{{{ - "name": "the tool to use, You must use one of the tools from the list: [{tool_names}]", + "name": "the tool to use, You must use one of the tools from the list: [{tool_names}, answer-user]", "input": "the input to the tool" }}}} }}}} diff --git a/chatgpt_tool_hub/chains/api/base.py b/chatgpt_tool_hub/chains/api/base.py index 98804f0..6ac4a3b 100644 --- a/chatgpt_tool_hub/chains/api/base.py +++ b/chatgpt_tool_hub/chains/api/base.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, root_validator from rich.console import Console +from rich.panel import Panel from chatgpt_tool_hub.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from chatgpt_tool_hub.chains.base import Chain @@ -69,12 +70,21 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: api_url = self.api_request_chain.predict( question=question, api_docs=self.api_docs ) - LOG.debug(f"[API] generate url: {str(api_url)}") + self.console.print(Panel(f"{str(api_url)}", + title=f"[bright_magenta]URL 构造[/]", + highlight=True)) + + LOG.info(f"URL 构造: {str(api_url)}") self.callback_manager.on_text( api_url, color="green", end="\n", verbose=self.verbose ) api_response = self.requests_wrapper.get(api_url) - LOG.debug(f"[API] response: {str(api_response)}") + + self.console.print(Panel(f"{repr(api_response)}", + title=f"[bright_magenta]API 响应[/]", + highlight=True)) + + LOG.info(f"API 响应: {repr(api_response)}") self.callback_manager.on_text( api_response, color="yellow", end="\n", verbose=self.verbose ) diff --git a/chatgpt_tool_hub/common/log.py b/chatgpt_tool_hub/common/log.py index 317bbf5..cbe4aca 100644 --- a/chatgpt_tool_hub/common/log.py +++ b/chatgpt_tool_hub/common/log.py @@ -6,13 +6,12 @@ def _get_logger(level: int = LOGGING_LEVEL): logger = logging.getLogger("tool") + logger.setLevel(level) ch = logging.StreamHandler(sys.stdout) - ch.setLevel(level) ch.setFormatter(logging.Formatter(LOGGING_FMT, datefmt=LOGGING_DATEFMT)) fh = logging.FileHandler(f'{os.getcwd()}/tool.log', encoding='utf-8') - fh.setLevel(logging.DEBUG) fh.setFormatter(logging.Formatter(LOGGING_FMT, datefmt=LOGGING_DATEFMT)) logger.addHandler(ch) diff --git a/chatgpt_tool_hub/database/chat_memory.py b/chatgpt_tool_hub/database/chat_memory.py index a9ad239..a3ece05 100644 --- a/chatgpt_tool_hub/database/chat_memory.py +++ b/chatgpt_tool_hub/database/chat_memory.py @@ -38,8 +38,8 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: output_key = list(outputs.keys())[0] else: output_key = self.output_key - self.chat_memory.add_user_message(inputs[prompt_input_key]) - self.chat_memory.add_ai_message(outputs[output_key]) + self.chat_memory.add_user_message(repr(inputs[prompt_input_key])) + self.chat_memory.add_ai_message(repr(outputs[output_key])) def clear(self) -> None: """Clear memory contents.""" diff --git a/chatgpt_tool_hub/engine/bot.py b/chatgpt_tool_hub/engine/bot.py index c81847c..603bc82 100644 --- a/chatgpt_tool_hub/engine/bot.py +++ b/chatgpt_tool_hub/engine/bot.py @@ -95,7 +95,7 @@ def _crop_full_input(self, inputs: str) -> str: try: os.remove(file_path) except Exception as e: - LOG.info(f"remove {file_path} failed... error_info: {repr(e)}") + LOG.debug(f"remove {file_path} failed... error_info: {repr(e)}") return _input diff --git a/chatgpt_tool_hub/engine/tool_engine.py b/chatgpt_tool_hub/engine/tool_engine.py index 42cc2fe..764b5a6 100644 --- a/chatgpt_tool_hub/engine/tool_engine.py +++ b/chatgpt_tool_hub/engine/tool_engine.py @@ -130,8 +130,8 @@ def _take_next_step( self.console.print(Panel(f"{output.tool_input}", title=f"我将给工具 [bright_magenta]{output.tool}[/] 发送如下信息", highlight=True)) - self.console.print("\n") - + + LOG.info(f"我将给工具发送如下信息: \n{output.tool_input}") tool = name_to_tool_map[output.tool] return_direct = tool.return_direct # color = color_mapping[output.tool] @@ -146,8 +146,8 @@ def _take_next_step( ) else: self.console.print(f"× 该工具 [bright_magenta]{output.tool}[/] 无效") - self.console.print("\n") - + + LOG.info(f"该工具 {output.tool} 无效") observation = InvalidTool().run( output.tool, verbose=self.verbose, @@ -159,8 +159,8 @@ def _take_next_step( self.console.print(Panel(observation + "\n", title=f"工具 [bright_magenta]{output.tool}[/] 返回内容", highlight=True, style='dim')) - self.console.print("\n") - + + LOG.info(f"工具 {output.tool} 返回内容: {observation}") return output, observation def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: @@ -321,3 +321,9 @@ async def _areturn( if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps return final_output + + def get_tool_list(self) -> List[str]: + _tool_list = [] + for tool in self.tools: + _tool_list.extend(tool.get_tool_list()) + return _tool_list diff --git a/chatgpt_tool_hub/tools/arxiv_search/api_prompt.py b/chatgpt_tool_hub/tools/arxiv_search/api_prompt.py index a3df72a..f38f338 100644 --- a/chatgpt_tool_hub/tools/arxiv_search/api_prompt.py +++ b/chatgpt_tool_hub/tools/arxiv_search/api_prompt.py @@ -1,5 +1,5 @@ ARXIV_PROMPT = """ -You are a converter from natural language to json to search research papers in arxiv api. +You are a converter from natural language to json to search research papers by arxiv api. You should only respond in JSON format as described below. Response Format: @@ -10,7 +10,7 @@ "sort_order": "descending" }} json note: -search_query: an arXiv query string. I will show you how to generate a search_query +search_query: an arXiv query string. I will teach you how to generate a search_query below. max_results: int, range: 1~20 sort_by: The sort criterion for results: "relevance", "lastUpdatedDate", or "submittedDate" sort_order: The sort order for results: "descending" or "ascending" diff --git a/chatgpt_tool_hub/tools/arxiv_search/tool.py b/chatgpt_tool_hub/tools/arxiv_search/tool.py index 9a313a8..8abc3ba 100644 --- a/chatgpt_tool_hub/tools/arxiv_search/tool.py +++ b/chatgpt_tool_hub/tools/arxiv_search/tool.py @@ -1,9 +1,12 @@ +import os +import tempfile from typing import Any from rich.console import Console from chatgpt_tool_hub.chains import LLMChain from chatgpt_tool_hub.common.log import LOG +from chatgpt_tool_hub.common.utils import get_from_dict_or_env from chatgpt_tool_hub.models import build_model_params from chatgpt_tool_hub.models.model_factory import ModelFactory from chatgpt_tool_hub.prompts import PromptTemplate @@ -11,6 +14,7 @@ from chatgpt_tool_hub.tools.arxiv_search.api_prompt import ARXIV_PROMPT from chatgpt_tool_hub.tools.arxiv_search.wrapper import ArxivAPIWrapper from chatgpt_tool_hub.tools.base_tool import BaseTool +from chatgpt_tool_hub.tools.summary import SummaryTool default_tool_name = "arxiv" @@ -32,6 +36,7 @@ def __init__(self, console: Console = Console(), **tool_kwargs: Any): super().__init__(console=console, return_direct=True) self.api_wrapper = ArxivAPIWrapper(**tool_kwargs) + self.arxiv_summary = get_from_dict_or_env(tool_kwargs, "arxiv_summary", "ARXIV_SUMMARY", True) llm = ModelFactory().create_llm_model(**build_model_params(tool_kwargs)) prompt = PromptTemplate( @@ -46,7 +51,23 @@ def _run(self, query: str) -> str: _llm_response = self.bot.run(query) LOG.info(f"[arxiv]: search_query: {_llm_response}") - return self.api_wrapper.run(_llm_response) + _api_response = self.api_wrapper.run(_llm_response) + + if not self.arxiv_summary: + return _api_response + + temp_file = tempfile.mkstemp() + file_path = temp_file[1] + with open(file_path, "w") as f: + f.write(_api_response + "\n") + + _input = SummaryTool(max_segment_length=2400).run(f"{str(file_path)}, 0") + try: + os.remove(file_path) + except Exception as e: + LOG.debug(f"remove {file_path} failed... error_info: {repr(e)}") + + return _input async def _arun(self, query: str) -> str: """Use the Arxiv tool asynchronously.""" diff --git a/chatgpt_tool_hub/tools/base_tool.py b/chatgpt_tool_hub/tools/base_tool.py index fa98e7c..2d1874e 100644 --- a/chatgpt_tool_hub/tools/base_tool.py +++ b/chatgpt_tool_hub/tools/base_tool.py @@ -122,3 +122,6 @@ async def arun( observation, verbose=verbose, color=color, **kwargs ) return observation + + def get_tool_list(self): + return [self.name] diff --git a/chatgpt_tool_hub/tools/meteo/tool.py b/chatgpt_tool_hub/tools/meteo/tool.py index e43668a..4ff7dcd 100644 --- a/chatgpt_tool_hub/tools/meteo/tool.py +++ b/chatgpt_tool_hub/tools/meteo/tool.py @@ -24,7 +24,7 @@ class MeteoWeatherTool(BaseTool): def __init__(self, console: Console = Console(), **tool_kwargs): super().__init__(console=console, return_direct=False) llm = ModelFactory().create_llm_model(**build_model_params(tool_kwargs)) - self.api_chain = APIChain.from_llm_and_api_docs(llm, OPEN_METEO_DOCS) + self.api_chain = APIChain.from_llm_and_api_docs(llm, OPEN_METEO_DOCS, console=console) def _run(self, query: str) -> str: """Use the tool.""" diff --git a/chatgpt_tool_hub/tools/news/morning_news/tool.py b/chatgpt_tool_hub/tools/news/morning_news/tool.py index e4915b5..d8f9aa2 100644 --- a/chatgpt_tool_hub/tools/news/morning_news/tool.py +++ b/chatgpt_tool_hub/tools/news/morning_news/tool.py @@ -57,7 +57,7 @@ def _run(self, query: str) -> str: _news_content = "\n".join(_return_data.get("news")) _weiyu = _return_data.get('weiyu') _image_url = _return_data.get('image') - return f"\n今日日期:{_date}\n今日早报:{_news_content}\n今日微语:{_weiyu}\nURL: {_image_url}" + return f"\n今日日期:{_date}\n\n今日早报:{_news_content}\n\n今日微语:{_weiyu}\n\nURL: {_image_url}" else: return f"[{default_tool_name}] api error, error_info: {_response_json.get('msg')}" diff --git a/chatgpt_tool_hub/tools/news/news_api/docs_prompts.py b/chatgpt_tool_hub/tools/news/news_api/docs_prompt.py similarity index 89% rename from chatgpt_tool_hub/tools/news/news_api/docs_prompts.py rename to chatgpt_tool_hub/tools/news/news_api/docs_prompt.py index 7eb608e..ae551d0 100644 --- a/chatgpt_tool_hub/tools/news/news_api/docs_prompts.py +++ b/chatgpt_tool_hub/tools/news/news_api/docs_prompt.py @@ -7,7 +7,6 @@ This endpoint is great for retrieving headlines for use with news tickers or similar. Request parameters - country | The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae ar at au be bg br ca ch cn co cu cz de eg fr gb gr hk hu id ie il in it jp kr lt lv ma mx my ng nl no nz ph pl pt ro rs ru sa se sg si sk th tr tw ua us ve za. Note: you can't mix this param with the sources param. category | The category you want to get headlines for. Possible options: business entertainment general health science sports technology. Note: you can't mix this param with the sources param. sources | A comma-seperated string of identifiers for the news sources or blogs you want headlines from. Use the /top-headlines/sources endpoint to locate these programmatically or look at the sources index. Note: you can't mix this param with the country or category params. @@ -28,7 +27,11 @@ publishedAt | string | The date and time that the article was published, in UTC (+000) content | string | The unformatted content of the article, where available. This is truncated to 200 chars. -Use page size: 2 -The endpoint and prefix of the document are absolutely correct, and that other URL prefixes should not be fabricated. -the url must not contain apiKey param -""" +Pay attention: +1. page size: 5 +2. You should only use the parameters described in this document +to construct the API URL, and should not make up parameters. +3. The endpoint and top headlines of the document are absolutely correct, +any other URL prefixes should not be fabricated. + +URL: """ diff --git a/chatgpt_tool_hub/tools/news/news_api/tool.py b/chatgpt_tool_hub/tools/news/news_api/tool.py index a5fd93a..4cba1ed 100644 --- a/chatgpt_tool_hub/tools/news/news_api/tool.py +++ b/chatgpt_tool_hub/tools/news/news_api/tool.py @@ -8,7 +8,7 @@ from chatgpt_tool_hub.models.model_factory import ModelFactory from chatgpt_tool_hub.tools.base_tool import BaseTool from chatgpt_tool_hub.tools.news import news_tool_register -from chatgpt_tool_hub.tools.news.news_api.docs_prompts import NEWS_DOCS +from chatgpt_tool_hub.tools.news.news_api.docs_prompt import NEWS_DOCS default_tool_name = "news-api" @@ -29,7 +29,7 @@ def __init__(self, console: Console = Console(), **tool_kwargs: Any): llm = ModelFactory().create_llm_model(**build_model_params(tool_kwargs)) self.api_chain = APIChain.from_llm_and_api_docs( - llm, NEWS_DOCS, headers={"X-Api-Key": news_api_key} + llm, NEWS_DOCS, console=console, headers={"X-Api-Key": news_api_key} ) def _run(self, query: str) -> str: diff --git a/chatgpt_tool_hub/tools/news/tool.py b/chatgpt_tool_hub/tools/news/tool.py index f237b3e..5c02729 100644 --- a/chatgpt_tool_hub/tools/news/tool.py +++ b/chatgpt_tool_hub/tools/news/tool.py @@ -17,9 +17,9 @@ class NewsTool(BaseTool): name: str = default_tool_name description: str = ( - "Useful when you want to get information about current news stories, " - "such as financial news, daily morning reports and any other news. " - "The input should be a description of your needs in natural language." + "当你想要获取实时新闻资讯时可以使用该工具,你能获取任何与新闻有关的信息。" + "该工具包含了金融、早报和news-api三个子工具,访问这些工具前你需要先访问本工具。" + "工具输入:你目前了解到的所有信息的总结 和 用户想获取的新闻内容。" ) engine: ToolEngine = Any @@ -30,6 +30,7 @@ def __init__(self, console: Console = Console(), **tool_kwargs: Any): tools = load_tools(news_tool_register.get_registered_tool_names(), news_tool_register, console, **tool_kwargs) llm = ModelFactory().create_llm_model(**build_model_params(tool_kwargs)) + # subtool 思考深度暂时固定为2 self.engine = init_tool_engine(tools, llm, bot="qa-bot", verbose=True, console=self.console, max_iterations=2, early_stopping_method="force") @@ -47,5 +48,8 @@ async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" raise NotImplementedError("NewsTool does not support async") + def get_tool_list(self): + return news_tool_register.get_registered_tool_names() + main_tool_register.register_tool(default_tool_name, lambda console, kwargs: NewsTool(console, **kwargs), []) diff --git a/chatgpt_tool_hub/tools/summary/tool.py b/chatgpt_tool_hub/tools/summary/tool.py index 5812fb8..c10d28c 100644 --- a/chatgpt_tool_hub/tools/summary/tool.py +++ b/chatgpt_tool_hub/tools/summary/tool.py @@ -151,11 +151,6 @@ def _run(self, tool_input: str) -> str: map_text_list = asyncio.run(self._acall(self.map_bot, _clip_text_list)) map_text = _clipper.seperator.join(map_text_list) - self.console.print(Panel(f"{map_text}", - title=f"=[bright_magenta]Summary tool[/] 总结", - highlight=True)) - self.console.print("\n") - LOG.debug(f"[summary] round:{ctn}, map_list: {map_text}") # reduce _clip_summary_list = _clipper.clip(map_text, self.message_num) @@ -164,6 +159,16 @@ def _run(self, tool_input: str) -> str: reduce_text = _clipper.seperator.join(reduce_text_list) LOG.debug(f"[summary] round:{ctn}, reduce_list: {reduce_text}") _text = reduce_text + + self.console.print(Panel(f"{_text}", + title=f"[bright_magenta]Summary tool[/] 第{ctn}轮总结", + highlight=True)) + + if ctn > 2: + self.console.print(Panel(f"{_text}", + title=f"[bright_magenta]Summary tool[/] 最终总结", + highlight=True)) + return _text async def _arun(self, file_path: str) -> str: diff --git a/chatgpt_tool_hub/tools/web_requests/__init__.py b/chatgpt_tool_hub/tools/web_requests/__init__.py index 21c539b..835647e 100644 --- a/chatgpt_tool_hub/tools/web_requests/__init__.py +++ b/chatgpt_tool_hub/tools/web_requests/__init__.py @@ -37,7 +37,7 @@ def filter_text(html: str) -> str: try: os.remove(file_path) except Exception as e: - LOG.info(f"remove {file_path} failed... error_info: {repr(e)}") + LOG.debug(f"remove {file_path} failed... error_info: {repr(e)}") return _summary.encode('utf-8').decode() diff --git a/chatgpt_tool_hub/tools/web_requests/browser.py b/chatgpt_tool_hub/tools/web_requests/browser.py index c9c05ad..e765ef0 100644 --- a/chatgpt_tool_hub/tools/web_requests/browser.py +++ b/chatgpt_tool_hub/tools/web_requests/browser.py @@ -113,7 +113,7 @@ class BrowserTool(BaseTool): description = ( "A Google Chrome browser. Use this when you need to get specific content from a website. " "Input should be a url (i.e. https://github.com/goldfishh/chatgpt-tool-hub). " - "The output will be the text response of browser. This tool has a higher priority than url-get tool." + "The output will be the text response of browser. " ) browser: ChromeBrowser = None diff --git a/chatgpt_tool_hub/tools/web_requests/get.py b/chatgpt_tool_hub/tools/web_requests/get.py index ff1dca4..48c1728 100644 --- a/chatgpt_tool_hub/tools/web_requests/get.py +++ b/chatgpt_tool_hub/tools/web_requests/get.py @@ -21,10 +21,9 @@ class RequestsGetTool(BaseRequestsTool, BaseTool): "The output will be the text response of the GET request." ) - def __init__(self, console: Console = Console(), - requests_wrapper: RequestsWrapper = RequestsWrapper(), **tool_kwargs: Any): + def __init__(self, console: Console = Console(), **tool_kwargs: Any): # 这个工具直接返回内容 - super().__init__(console=console, requests_wrapper=requests_wrapper, return_direct=False) + super().__init__(console=console, requests_wrapper=RequestsWrapper(**tool_kwargs), return_direct=False) def _run(self, url: str) -> str: """Run the tool.""" @@ -49,7 +48,7 @@ async def _arun(self, url: str) -> str: return _content -main_tool_register.register_tool(default_tool_name, lambda console, kwargs: RequestsGetTool(console, requests_wrapper=RequestsWrapper(**kwargs)), []) +main_tool_register.register_tool(default_tool_name, lambda console, kwargs: RequestsGetTool(console, **kwargs), []) if __name__ == "__main__": diff --git a/chatgpt_tool_hub/tools/wikipedia/wikipedia.py b/chatgpt_tool_hub/tools/wikipedia/wikipedia.py index 8ed4a13..6842c95 100644 --- a/chatgpt_tool_hub/tools/wikipedia/wikipedia.py +++ b/chatgpt_tool_hub/tools/wikipedia/wikipedia.py @@ -1,4 +1,5 @@ """Tool for the Wikipedia API.""" +from typing import Any from rich.console import Console @@ -18,12 +19,12 @@ class WikipediaTool(BaseTool): "people, places, companies, historical events, or other subjects. " "Input should be a search query." ) - api_wrapper: WikipediaAPIWrapper + api_wrapper: WikipediaAPIWrapper = None - def __init__(self, console: Console = Console(), api_wrapper: WikipediaAPIWrapper = WikipediaAPIWrapper()): + def __init__(self, console: Console = Console(), **tool_kwargs: Any): # 这个工具直接返回内容 super().__init__(console=console, return_direct=False) - self.api_wrapper = api_wrapper + self.api_wrapper = WikipediaAPIWrapper(**tool_kwargs) def _run(self, query: str) -> str: """Use the Wikipedia tool.""" @@ -34,4 +35,4 @@ async def _arun(self, query: str) -> str: raise NotImplementedError("WikipediaQueryRun does not support async") -main_tool_register.register_tool(default_tool_name, lambda console, kwargs: WikipediaTool(console, api_wrapper=WikipediaAPIWrapper(**kwargs)), []) +main_tool_register.register_tool(default_tool_name, lambda console, kwargs: WikipediaTool(console, **kwargs), []) diff --git a/chatgpt_tool_hub/tools/wikipedia/wrapper.py b/chatgpt_tool_hub/tools/wikipedia/wrapper.py index f4b5655..52e341f 100644 --- a/chatgpt_tool_hub/tools/wikipedia/wrapper.py +++ b/chatgpt_tool_hub/tools/wikipedia/wrapper.py @@ -1,4 +1,5 @@ """Util that calls Wikipedia.""" +import time from typing import Any, Dict, Optional from pydantic import BaseModel, Extra, root_validator @@ -47,9 +48,16 @@ def run(self, query: str) -> str: search_results = self.wiki_client.search(query) summaries = [] for i in range(min(self.top_k_results, len(search_results))): - summary = self.fetch_formatted_page_summary(search_results[i]) - if summary is not None: - summaries.append(summary) + retry_num = 0 + while retry_num <= 1: + summary = self.fetch_formatted_page_summary(search_results[i]) + if summary is not None: + summaries.append(summary) + break + else: + # wikipedia api 限制 + time.sleep(2) + retry_num += 1 _content = "\n\n".join(summaries) LOG.debug(f"[wikipedia]: {_content}") return _content @@ -58,9 +66,6 @@ def fetch_formatted_page_summary(self, page: str) -> Optional[str]: try: wiki_page = self.wiki_client.page(title=page, auto_suggest=False) return f"Page: {page}\nSummary: {wiki_page.summary}" - except ( - self.wiki_client.exceptions.PageError, - self.wiki_client.exceptions.DisambiguationError, - ) as e: - LOG.error(f"[wikipedia]: {repr(e)}") + except Exception as e: + LOG.info(f"[wikipedia]: {repr(e)}") return None diff --git a/chatgpt_tool_hub/tools/wolfram_alpha/wolfram_alpha.py b/chatgpt_tool_hub/tools/wolfram_alpha/wolfram_alpha.py index 4a6a239..590498a 100644 --- a/chatgpt_tool_hub/tools/wolfram_alpha/wolfram_alpha.py +++ b/chatgpt_tool_hub/tools/wolfram_alpha/wolfram_alpha.py @@ -21,9 +21,9 @@ class WolframAlphaTool(BaseTool): "Science, Technology, Culture, Society and Everyday Life. " "Input should be a search query." ) - api_wrapper: WolframAlphaAPIWrapper + api_wrapper: WolframAlphaAPIWrapper = None - def __init__(self, console: Console = Console(), **tool_kwargs): + def __init__(self, console: Console = Console(), **tool_kwargs: Any): # 这个工具直接返回内容 super().__init__(console=console, return_direct=False) self.api_wrapper = WolframAlphaAPIWrapper(**tool_kwargs) @@ -38,4 +38,4 @@ async def _arun(self, query: str) -> str: main_tool_register.register_tool(default_tool_name, lambda console, kwargs: WolframAlphaTool(console, **kwargs), - tool_input_keys=["wolfram_alpha_appid"]) + tool_input_keys=["wolfram_alpha_appid"]) diff --git a/terminal_io.py b/terminal_io.py index a6043f6..ce05f6b 100644 --- a/terminal_io.py +++ b/terminal_io.py @@ -23,6 +23,7 @@ from chatgpt_tool_hub.common.calculate_token import count_message_tokens from chatgpt_tool_hub.common.constants import LOGGING_FMT, LOGGING_DATEFMT from chatgpt_tool_hub.tools.all_tool_list import main_tool_register +from chatgpt_tool_hub.tools.news import news_tool_register logging.basicConfig(filename=f'{os.getcwd()}/llmos.log', format=LOGGING_FMT, datefmt=LOGGING_DATEFMT, level=logging.INFO) @@ -130,7 +131,7 @@ def __init__(self, timeout: int): self.current_tokens = count_message_tokens(self.messages) def create_app(self): - return AppFactory().create_app(tools_list=config["tools"], **config["kwargs"]) + return AppFactory().create_app(tools_list=config["tools"], console=console, **config["kwargs"]) @property def get_app(self) -> App: @@ -264,7 +265,8 @@ def get_completions(self, document, complete_event): yield Completion(model, start_position=-len(model_prefix)) if text.startswith('/add '): tool_prefix = text[5:] - available_tools = main_tool_register.get_registered_tool_names() + available_tools = main_tool_register.get_registered_tool_names() \ + + news_tool_register.get_registered_tool_names() for tool in available_tools: if tool.startswith(tool_prefix): yield Completion(tool, start_position=-len(tool_prefix)) @@ -336,10 +338,12 @@ def handle_command(command: str, llm_os: LLMOS): tools_kwargs[tool_args] = add_tool_args # todo 目前tool-hub不支持subtool粒度增删tool - if add_tool not in main_tool_register.get_registered_tool_names(): - console.print(f"发现未知工具: {add_tool}") - return + if add_tool not in subtool_parent.keys(): + console.print(f"发现未知工具: {add_tool}") + return + elif subtool_parent[add_tool] not in llm_os.get_app.get_tool_list(): + add_tool = subtool_parent[add_tool] app = llm_os.get_app app.add_tool(add_tool, **tools_kwargs)