Skip to content

Commit

Permalink
Harrison/serialize from llm and tools (#760)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwchase17 authored Jan 27, 2023
1 parent 12dc7f2 commit e2a7fed
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 4 deletions.
10 changes: 9 additions & 1 deletion langchain/agents/conversational/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,22 @@ def from_llm_and_tools(
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
ai_prefix: str = "AI",
human_prefix: str = "Human",
input_variables: Optional[List[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools, ai_prefix=ai_prefix, human_prefix=human_prefix
tools,
ai_prefix=ai_prefix,
human_prefix=human_prefix,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=llm,
Expand Down
4 changes: 3 additions & 1 deletion langchain/agents/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def initialize_agent(
llm, tools, callback_manager=callback_manager
)
elif agent_path is not None:
agent_obj = load_agent(agent_path, callback_manager=callback_manager)
agent_obj = load_agent(
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
)
else:
raise ValueError(
"Somehow both `agent` and `agent_path` are None, "
Expand Down
38 changes: 36 additions & 2 deletions langchain/agents/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import tempfile
from pathlib import Path
from typing import Any, Union
from typing import Any, List, Optional, Union

import requests
import yaml
Expand All @@ -13,7 +13,9 @@
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool
from langchain.chains.loading import load_chain, load_chain_from_config
from langchain.llms.base import BaseLLM

AGENT_TO_CLASS = {
"zero-shot-react-description": ZeroShotAgent,
Expand All @@ -25,10 +27,42 @@
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"


def load_agent_from_config(config: dict, **kwargs: Any) -> Agent:
def _load_agent_from_tools(
config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any
) -> Agent:
config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS:
raise ValueError(f"Loading {config_type} agent not supported")

if config_type not in AGENT_TO_CLASS:
raise ValueError(f"Loading {config_type} agent not supported")
agent_cls = AGENT_TO_CLASS[config_type]
combined_config = {**config, **kwargs}
return agent_cls.from_llm_and_tools(llm, tools, **combined_config)


def load_agent_from_config(
config: dict,
llm: Optional[BaseLLM] = None,
tools: Optional[List[Tool]] = None,
**kwargs: Any,
) -> Agent:
"""Load agent from Config Dict."""
if "_type" not in config:
raise ValueError("Must specify an agent Type in config")
load_from_tools = config.pop("load_from_llm_and_tools", False)
if load_from_tools:
if llm is None:
raise ValueError(
"If `load_from_llm_and_tools` is set to True, "
"then LLM must be provided"
)
if tools is None:
raise ValueError(
"If `load_from_llm_and_tools` is set to True, "
"then tools must be provided"
)
return _load_agent_from_tools(config, llm, tools, **kwargs)
config_type = config.pop("_type")

if config_type not in AGENT_TO_CLASS:
Expand Down
26 changes: 26 additions & 0 deletions langchain/agents/mrkl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from langchain.agents.agent import Agent, AgentExecutor
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate

Expand Down Expand Up @@ -92,6 +94,30 @@ def create_prompt(
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)

@classmethod
def from_llm_and_tools(
cls,
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
input_variables: Optional[List[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)

@classmethod
def _validate_tools(cls, tools: List[Tool]) -> None:
for tool in tools:
Expand Down

0 comments on commit e2a7fed

Please sign in to comment.