diff --git a/langchain/agents/conversational/base.py b/langchain/agents/conversational/base.py
index 1cf24fa4649ee..3cbccf86a308c 100644
--- a/langchain/agents/conversational/base.py
+++ b/langchain/agents/conversational/base.py
@@ -91,14 +91,22 @@ def from_llm_and_tools(
         llm: BaseLLM,
         tools: List[Tool],
         callback_manager: Optional[BaseCallbackManager] = None,
+        prefix: str = PREFIX,
+        suffix: str = SUFFIX,
         ai_prefix: str = "AI",
         human_prefix: str = "Human",
+        input_variables: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
         prompt = cls.create_prompt(
-            tools, ai_prefix=ai_prefix, human_prefix=human_prefix
+            tools,
+            ai_prefix=ai_prefix,
+            human_prefix=human_prefix,
+            prefix=prefix,
+            suffix=suffix,
+            input_variables=input_variables,
         )
         llm_chain = LLMChain(
             llm=llm,
diff --git a/langchain/agents/initialize.py b/langchain/agents/initialize.py
index 17d1cf532327e..b0b9a78c80558 100644
--- a/langchain/agents/initialize.py
+++ b/langchain/agents/initialize.py
@@ -54,7 +54,9 @@ def initialize_agent(
             llm, tools, callback_manager=callback_manager
         )
     elif agent_path is not None:
-        agent_obj = load_agent(agent_path, callback_manager=callback_manager)
+        agent_obj = load_agent(
+            agent_path, llm=llm, tools=tools, callback_manager=callback_manager
+        )
     else:
         raise ValueError(
             "Somehow both `agent` and `agent_path` are None, "
diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py
index 42731e7fcc686..5b36f908ad7af 100644
--- a/langchain/agents/loading.py
+++ b/langchain/agents/loading.py
@@ -3,7 +3,7 @@
 import os
 import tempfile
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, List, Optional, Union
 
 import requests
 import yaml
@@ -13,7 +13,9 @@
 from langchain.agents.mrkl.base import ZeroShotAgent
 from langchain.agents.react.base import ReActDocstoreAgent
 from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
+from langchain.agents.tools import Tool
 from langchain.chains.loading import load_chain, load_chain_from_config
+from langchain.llms.base import BaseLLM
 
 AGENT_TO_CLASS = {
     "zero-shot-react-description": ZeroShotAgent,
@@ -25,10 +27,42 @@
 
 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
 
 
-def load_agent_from_config(config: dict, **kwargs: Any) -> Agent:
+def _load_agent_from_tools(
+    config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any
+) -> Agent:
+    config_type = config.pop("_type")
+    if config_type not in AGENT_TO_CLASS:
+        raise ValueError(f"Loading {config_type} agent not supported")
+
+    if config_type not in AGENT_TO_CLASS:
+        raise ValueError(f"Loading {config_type} agent not supported")
+    agent_cls = AGENT_TO_CLASS[config_type]
+    combined_config = {**config, **kwargs}
+    return agent_cls.from_llm_and_tools(llm, tools, **combined_config)
+
+
+def load_agent_from_config(
+    config: dict,
+    llm: Optional[BaseLLM] = None,
+    tools: Optional[List[Tool]] = None,
+    **kwargs: Any,
+) -> Agent:
     """Load agent from Config Dict."""
     if "_type" not in config:
         raise ValueError("Must specify an agent Type in config")
+    load_from_tools = config.pop("load_from_llm_and_tools", False)
+    if load_from_tools:
+        if llm is None:
+            raise ValueError(
+                "If `load_from_llm_and_tools` is set to True, "
+                "then LLM must be provided"
+            )
+        if tools is None:
+            raise ValueError(
+                "If `load_from_llm_and_tools` is set to True, "
+                "then tools must be provided"
+            )
+        return _load_agent_from_tools(config, llm, tools, **kwargs)
     config_type = config.pop("_type")
     if config_type not in AGENT_TO_CLASS:
diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py
index 037367ed32b50..fb2acfb4a4c6a 100644
--- a/langchain/agents/mrkl/base.py
+++ b/langchain/agents/mrkl/base.py
@@ -7,6 +7,8 @@
 from langchain.agents.agent import Agent, AgentExecutor
 from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
 from langchain.agents.tools import Tool
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.chains import LLMChain
 from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 
@@ -92,6 +94,30 @@ def create_prompt(
             input_variables = ["input", "agent_scratchpad"]
         return PromptTemplate(template=template, input_variables=input_variables)
 
+    @classmethod
+    def from_llm_and_tools(
+        cls,
+        llm: BaseLLM,
+        tools: List[Tool],
+        callback_manager: Optional[BaseCallbackManager] = None,
+        prefix: str = PREFIX,
+        suffix: str = SUFFIX,
+        input_variables: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Agent:
+        """Construct an agent from an LLM and tools."""
+        cls._validate_tools(tools)
+        prompt = cls.create_prompt(
+            tools, prefix=prefix, suffix=suffix, input_variables=input_variables
+        )
+        llm_chain = LLMChain(
+            llm=llm,
+            prompt=prompt,
+            callback_manager=callback_manager,
+        )
+        tool_names = [tool.name for tool in tools]
+        return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
+
     @classmethod
     def _validate_tools(cls, tools: List[Tool]) -> None:
         for tool in tools: