-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update example ipynb && openai_function_agent
- Loading branch information
Showing
6 changed files
with
411 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
144 changes: 144 additions & 0 deletions
144
src/codeinterpreter/agent/openai_functions_agent/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent | ||
|
||
import json | ||
from json import JSONDecodeError | ||
from typing import Union | ||
|
||
|
||
from langchain.schema import ( | ||
AgentAction, | ||
AgentFinish, | ||
OutputParserException, | ||
) | ||
from langchain.schema.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
) | ||
|
||
from typing import Any, List, Tuple, Union, Dict | ||
|
||
|
||
from langchain.agents.openai_functions_agent.base import ( | ||
OpenAIFunctionsAgent, | ||
_FunctionsAgentAction, | ||
_format_intermediate_steps, | ||
) | ||
from langchain.schema import ( | ||
AgentAction, | ||
AgentFinish, | ||
AIMessage, | ||
BaseMessage, | ||
) | ||
from typing import Union | ||
from json import JSONDecodeError | ||
|
||
|
||
from langchain.callbacks.manager import Callbacks | ||
|
||
|
||
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
    """Parse an OpenAI function-calling AI message into an agent step.

    Args:
        message: The model's response; must be an ``AIMessage``.

    Returns:
        A ``_FunctionsAgentAction`` when the message carries a
        ``function_call`` in ``additional_kwargs``, otherwise an
        ``AgentFinish`` wrapping the plain-text content.

    Raises:
        TypeError: If ``message`` is not an ``AIMessage``.
    """
    if not isinstance(message, AIMessage):
        raise TypeError(f"Expected an AI message got {type(message)}")

    function_call = message.additional_kwargs.get("function_call", {})

    if function_call:
        function_name = function_call["name"]
        try:
            _tool_input = json.loads(function_call["arguments"])
        except JSONDecodeError:
            # Arguments were not valid JSON; pass the raw string through
            # and let the tool deal with it.
            _tool_input = function_call["arguments"]
        # HACK HACK HACK:
        # The code that encodes tool input into Open AI uses a special variable
        # name called `__arg1` to handle old style tools that do not expose a
        # schema and expect a single string argument as an input.
        # We unpack the argument here if it exists.
        # Open AI does not support passing in a JSON array as an argument.
        if "__arg1" in _tool_input:
            tool_input = _tool_input["__arg1"]
        else:
            tool_input = _tool_input

        # BUG FIX: the original used the plain string "responded: {content}\n",
        # which logged the literal placeholder `{content}` instead of the
        # model's actual text. Interpolate the content with an f-string.
        content_msg = f"responded: {message.content}\n" if message.content else "\n"

        return _FunctionsAgentAction(
            tool=function_name,
            tool_input=tool_input,
            log=f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n",
            message_log=[message],
        )

    # No function call requested: the model's content is the final answer.
    return AgentFinish(return_values={"output": message.content}, log=message.content)
|
||
|
||
class CustomOpenAIFunctionsAgent(OpenAIFunctionsAgent):
    """OpenAI functions agent that can optionally omit function definitions.

    Workaround for https://github.com/langchain-ai/langchain/issues/6364:
    ``with_functions=False`` lets a final call be made without sending the
    ``functions`` payload, so the model answers in plain text instead of
    forcing another tool invocation.
    """

    def _build_messages(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        **kwargs: Any,
    ) -> List[BaseMessage]:
        """Format the prompt (inputs + agent scratchpad) into chat messages.

        Shared by ``plan`` and ``aplan`` so the two paths cannot drift apart.
        """
        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
        selected_inputs = {
            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
        }
        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
        return self.prompt.format_prompt(**full_inputs).to_messages()

    def plan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        with_functions: bool = True,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decide what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with
                observations.
            callbacks: Callbacks to pass through to the LLM call.
            with_functions: When False, omit the ``functions`` payload so the
                model must answer in plain text.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or a finish.
        """
        messages = self._build_messages(intermediate_steps, **kwargs)
        if with_functions:
            predicted_message = self.llm.predict_messages(
                messages,
                functions=self.functions,
                callbacks=callbacks,
            )
        else:
            predicted_message = self.llm.predict_messages(
                messages,
                callbacks=callbacks,
            )
        return _parse_ai_message(predicted_message)

    async def aplan(
        self,
        intermediate_steps: List[Tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        with_functions: bool = True,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Async version of ``plan``.

        Args:
            intermediate_steps: Steps the LLM has taken to date, along with
                observations.
            callbacks: Callbacks to pass through to the LLM call.
            with_functions: When False, omit the ``functions`` payload; added
                for parity with ``plan`` (default preserves prior behavior).
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use, or a finish.
        """
        messages = self._build_messages(intermediate_steps, **kwargs)
        if with_functions:
            predicted_message = await self.llm.apredict_messages(
                messages,
                functions=self.functions,
                callbacks=callbacks,
            )
        else:
            predicted_message = await self.llm.apredict_messages(
                messages,
                callbacks=callbacks,
            )
        return _parse_ai_message(predicted_message)
Oops, something went wrong.