From 3527008c8a77434763251f02ce20cc773d2face9 Mon Sep 17 00:00:00 2001
From: Caren Thomas
Date: Wed, 18 Dec 2024 13:59:59 -0800
Subject: [PATCH] propagate error on tool failure

In the exception path of _handle_ai_response, build the error text with
get_friendly_error_msg. When the function response string starts with
"Error", package it as a failed tool response, report it through the
interface with an "Error: " prefix (surfaced as status "error"), and force
a heartbeat so the agent can handle the failure.

Strip only the leading "Error: " marker in the REST interface so that an
error message which itself contains "Error: " is passed through intact.

---
Notes (ignored by git am):
  * str.replace() accepts `count` as a keyword only from Python 3.13, so the
    interface.py hunks pass it positionally to keep older interpreters working.
  * Hunk indentation was reconstructed; verify context lines against the tree
    (or use `git apply --ignore-whitespace`) before applying.
  * The post-image blob id on the interface.py index line was not regenerated.

 letta/agent.py                     | 25 ++++++++++++++++++++--
 letta/server/rest_api/interface.py |  4 ++--
 tests/test_client.py               | 34 ++++++++++++++++++++++++++++++
 3 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/letta/agent.py b/letta/agent.py
index 485f2112b9..00a0765c88 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -822,7 +822,7 @@ def _handle_ai_response(
                 function_args.pop("self", None)
                 # error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}"
                 # Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise)
-                error_msg = f"Error calling function {function_name}: {str(e)}"
+                error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
                 error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
                 printd(error_msg_user)
                 function_response = package_function_response(False, error_msg)
@@ -844,8 +844,29 @@ def _handle_ai_response(
                 self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
                 return messages, False, True  # force a heartbeat to allow agent to handle error
 
+            # Step 4: check if function response is an error
+            if function_response_string.startswith("Error"):
+                function_response = package_function_response(False, function_response_string)
+                # TODO: truncate error message somehow
+                messages.append(
+                    Message.dict_to_message(
+                        agent_id=self.agent_state.id,
+                        user_id=self.agent_state.created_by_id,
+                        model=self.model,
+                        openai_message_dict={
+                            "role": "tool",
+                            "name": function_name,
+                            "content": function_response,
+                            "tool_call_id": tool_call_id,
+                        },
+                    )
+                )  # extend conversation with function response
+                self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
+                self.interface.function_message(f"Error: {function_response_string}", msg_obj=messages[-1])
+                return messages, False, True  # force a heartbeat to allow agent to handle error
+
             # If no failures happened along the way: ...
-            # Step 4: send the info on the function call and function response to GPT
+            # Step 5: send the info on the function call and function response to GPT
             messages.append(
                 Message.dict_to_message(
                     agent_id=self.agent_state.id,
diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py
index 11843250c7..7483120082 100644
--- a/letta/server/rest_api/interface.py
+++ b/letta/server/rest_api/interface.py
@@ -238,7 +238,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_
             new_message = {"function_return": msg, "status": "success"}
 
         elif msg.startswith("Error: "):
-            msg = msg.replace("Error: ", "")
+            msg = msg.replace("Error: ", "", 1)
             new_message = {"function_return": msg, "status": "error"}
 
         else:
@@ -951,7 +951,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
             )
 
         elif msg.startswith("Error: "):
-            msg = msg.replace("Error: ", "")
+            msg = msg.replace("Error: ", "", 1)
             # new_message = {"function_return": msg, "status": "error"}
             assert msg_obj.tool_call_id is not None
             new_message = FunctionReturn(
diff --git a/tests/test_client.py b/tests/test_client.py
index f37fe86267..4bf6c530cf 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import os
 import threading
 import time
@@ -382,6 +383,39 @@ def big_return():
     client.delete_agent(agent_id=agent.id)
 
 
+def test_function_always_error(client: Union[LocalClient, RESTClient]):
+    """Test to see if function that errors works correctly"""
+
+    def always_error():
+        """
+        Always throw an error.
+        """
+        return 5/0
+
+    tool = client.create_or_update_tool(func=always_error)
+    agent = client.create_agent(tool_ids=[tool.id])
+    # get function response
+    response = client.send_message(agent_id=agent.id, message="call the always_error function", role="user")
+    print(response.messages)
+
+    response_message = None
+    for message in response.messages:
+        if isinstance(message, FunctionReturn):
+            response_message = message
+            break
+
+    assert response_message, "FunctionReturn message not found in response"
+    assert response_message.status == "error"
+    if isinstance(client, RESTClient):
+        assert response_message.function_return == "Error executing function always_error: ZeroDivisionError: division by zero"
+    else:
+        response_json = json.loads(response_message.function_return)
+        assert response_json['status'] == "Failed"
+        assert response_json['message'] == "Error executing function always_error: ZeroDivisionError: division by zero"
+
+    client.delete_agent(agent_id=agent.id)
+
+
 @pytest.mark.asyncio
 async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
     """