diff --git a/examples/helper.py b/examples/helper.py index 7fc15bf8b0..f19e8ffb13 100644 --- a/examples/helper.py +++ b/examples/helper.py @@ -75,6 +75,10 @@ def nb_print(messages): return_data = json.loads(msg.function_return) if "message" in return_data and return_data["message"] == "None": continue + if msg.message_type == "tool_return_message": + return_data = json.loads(msg.tool_return) + if "message" in return_data and return_data["message"] == "None": + continue title = msg.message_type.replace("_", " ").upper() html_output += f""" @@ -94,11 +98,17 @@ def get_formatted_content(msg): elif msg.message_type == "function_call": args = format_json(msg.function_call.arguments) return f'
{html.escape(msg.function_call.name)}({args})
' + elif msg.message_type == "tool_call_message": + args = format_json(msg.tool_call.arguments) + return f'
{html.escape(msg.function_call.name)}({args})
' elif msg.message_type == "function_return": - return_value = format_json(msg.function_return) # return f'
Status: {html.escape(msg.status)}
{return_value}
' return f'
{return_value}
' + elif msg.message_type == "tool_return_message": + return_value = format_json(msg.tool_return) + # return f'
Status: {html.escape(msg.status)}
{return_value}
' + return f'
{return_value}
' elif msg.message_type == "user_message": if is_json(msg.message): return f'
{format_json(msg.message)}
' diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py index 45c56ec3eb..7d04df6c5a 100644 --- a/examples/tool_rule_usage.py +++ b/examples/tool_rule_usage.py @@ -2,7 +2,7 @@ import uuid from letta import create_client -from letta.schemas.letta_message import FunctionCallMessage +from letta.schemas.letta_message import ToolCallMessage from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from tests.helpers.endpoints_helper import ( assert_invoked_send_message_with_keyword, @@ -116,9 +116,9 @@ def main(): # 6. Here, we thoroughly check the correctness of the response tool_names += ["send_message"] # Add send message because we expect this to be called at the end for m in response.messages: - if isinstance(m, FunctionCallMessage): + if isinstance(m, ToolCallMessage): # Check that it's equal to the first one - assert m.function_call.name == tool_names[0] + assert m.tool_call.name == tool_names[0] # Pop out first one tool_names = tool_names[1:] diff --git a/letta/client/streaming.py b/letta/client/streaming.py index 80a8a814e5..9f57352b73 100644 --- a/letta/client/streaming.py +++ b/letta/client/streaming.py @@ -8,8 +8,8 @@ from letta.errors import LLMError from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( - FunctionCallMessage, - FunctionReturn, + ToolCallMessage, + ToolReturnMessage, InternalMonologue, ) from letta.schemas.letta_response import LettaStreamingResponse @@ -55,10 +55,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe chunk_data = json.loads(sse.data) if "internal_monologue" in chunk_data: yield InternalMonologue(**chunk_data) - elif "function_call" in chunk_data: - yield FunctionCallMessage(**chunk_data) - elif "function_return" in chunk_data: - yield FunctionReturn(**chunk_data) + elif "tool_call" in chunk_data: + yield ToolCallMessage(**chunk_data) + elif "tool_return" in chunk_data: + yield ToolReturnMessage(**chunk_data) elif "usage" in chunk_data: yield LettaUsageStatistics(**chunk_data["usage"]) else: diff --git a/letta/errors.py b/letta/errors.py index 0dc7cc9ec0..4957139bee 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -131,16 +131,16 @@ def construct_error_message(messages: List[Union["Message", "LettaMessage"]], er return f"{error_msg}\n\n{message_json}" -class MissingFunctionCallError(LettaMessageError): - """Error raised when a message is missing a function call.""" +class MissingToolCallError(LettaMessageError): + """Error raised when a message is missing a tool call.""" - default_error_message = "The message is missing a function call." + default_error_message = "The message is missing a tool call." -class InvalidFunctionCallError(LettaMessageError): - """Error raised when a message uses an invalid function call.""" +class InvalidToolCallError(LettaMessageError): + """Error raised when a message uses an invalid tool call.""" - default_error_message = "The message uses an invalid function call or has improper usage of a function call." + default_error_message = "The message uses an invalid tool call or has improper usage of a tool call." class MissingInnerMonologueError(LettaMessageError): diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index 3b2dc73480..424a003844 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -9,7 +9,7 @@ class LettaMessage(BaseModel): """ - Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, function calls, and function returns in a simplified format that does not include additional information other than the content and timestamp. + Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp. Attributes: id (str): The ID of the message @@ -74,18 +74,18 @@ class InternalMonologue(LettaMessage): internal_monologue: str -class FunctionCall(BaseModel): +class ToolCall(BaseModel): name: str arguments: str - function_call_id: str + tool_call_id: str -class FunctionCallDelta(BaseModel): +class ToolCallDelta(BaseModel): name: Optional[str] arguments: Optional[str] - function_call_id: Optional[str] + tool_call_id: Optional[str] # NOTE: this is a workaround to exclude None values from the JSON dump, # since the OpenAI style of returning chunks doesn't include keys with null values @@ -97,67 +97,67 @@ def json(self, *args, **kwargs): return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs) -class FunctionCallMessage(LettaMessage): +class ToolCallMessage(LettaMessage): """ - A message representing a request to call a function (generated by the LLM to trigger function execution). + A message representing a request to call a tool (generated by the LLM to trigger tool execution). Attributes: - function_call (Union[FunctionCall, FunctionCallDelta]): The function call + tool_call (Union[ToolCall, ToolCallDelta]): The tool call id (str): The ID of the message date (datetime): The date the message was created in ISO format """ - message_type: Literal["function_call"] = "function_call" - function_call: Union[FunctionCall, FunctionCallDelta] + message_type: Literal["tool_call_message"] = "tool_call_message" + tool_call: Union[ToolCall, ToolCallDelta] - # NOTE: this is required for the FunctionCallDelta exclude_none to work correctly + # NOTE: this is required for the ToolCallDelta exclude_none to work correctly def model_dump(self, *args, **kwargs): kwargs["exclude_none"] = True data = super().model_dump(*args, **kwargs) - if isinstance(data["function_call"], dict): - data["function_call"] = {k: v for k, v in data["function_call"].items() if v is not None} + if isinstance(data["tool_call"], dict): + data["tool_call"] = {k: v for k, v in data["tool_call"].items() if v is not None} return data class Config: json_encoders = { - FunctionCallDelta: lambda v: v.model_dump(exclude_none=True), - FunctionCall: lambda v: v.model_dump(exclude_none=True), + ToolCallDelta: lambda v: v.model_dump(exclude_none=True), + ToolCall: lambda v: v.model_dump(exclude_none=True), } - # NOTE: this is required to cast dicts into FunctionCallMessage objects + # NOTE: this is required to cast dicts into ToolCallMessage objects # Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None - # (instead of properly casting to FunctionCallDelta instead of FunctionCall) - @field_validator("function_call", mode="before") + # (instead of properly casting to ToolCallDelta instead of ToolCall) + @field_validator("tool_call", mode="before") @classmethod - def validate_function_call(cls, v): + def validate_tool_call(cls, v): if isinstance(v, dict): - if "name" in v and "arguments" in v and "function_call_id" in v: - return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"]) - elif "name" in v or "arguments" in v or "function_call_id" in v: - return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id")) + if "name" in v and "arguments" in v and "tool_call_id" in v: + return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"]) + elif "name" in v or "arguments" in v or "tool_call_id" in v: + return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id")) else: - raise ValueError("function_call must contain either 'name' or 'arguments'") + raise ValueError("tool_call must contain either 'name' or 'arguments'") return v -class FunctionReturn(LettaMessage): +class ToolReturnMessage(LettaMessage): """ - A message representing the return value of a function call (generated by Letta executing the requested function). + A message representing the return value of a tool call (generated by Letta executing the requested tool). Attributes: - function_return (str): The return value of the function - status (Literal["success", "error"]): The status of the function call + tool_return (str): The return value of the tool + status (Literal["success", "error"]): The status of the tool call id (str): The ID of the message date (datetime): The date the message was created in ISO format - function_call_id (str): A unique identifier for the function call that generated this message - stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation - stderr (Optional[List(str)]): Captured stderr from the function invocation + tool_call_id (str): A unique identifier for the tool call that generated this message + stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation + stderr (Optional[List(str)]): Captured stderr from the tool invocation """ - message_type: Literal["function_return"] = "function_return" - function_return: str + message_type: Literal["tool_return_message"] = "tool_return_message" + tool_return: str status: Literal["success", "error"] - function_call_id: str + tool_call_id: str stdout: Optional[List[str]] = None stderr: Optional[List[str]] = None @@ -174,10 +174,32 @@ class LegacyFunctionCallMessage(LettaMessage): function_call: str -LegacyLettaMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn] +class LegacyFunctionReturn(LettaMessage): + """ + A message representing the return value of a function call (generated by Letta executing the requested function). + + Attributes: + function_return (str): The return value of the function + status (Literal["success", "error"]): The status of the function call + id (str): The ID of the message + date (datetime): The date the message was created in ISO format + function_call_id (str): A unique identifier for the function call that generated this message + stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation + stderr (Optional[List(str)]): Captured stderr from the function invocation + """ + + message_type: Literal["function_return"] = "function_return" + function_return: str + status: Literal["success", "error"] + function_call_id: str + stdout: Optional[List[str]] = None + stderr: Optional[List[str]] = None + + +LegacyLettaMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn] LettaMessageUnion = Annotated[ - Union[SystemMessage, UserMessage, InternalMonologue, FunctionCallMessage, FunctionReturn, AssistantMessage], + Union[SystemMessage, UserMessage, InternalMonologue, ToolCallMessage, ToolReturnMessage, AssistantMessage], Field(discriminator="message_type"), ] diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index 58dbf42929..990a55cb15 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -43,11 +43,17 @@ def get_formatted_content(msg): elif msg.message_type == "function_call": args = format_json(msg.function_call.arguments) return f'
{html.escape(msg.function_call.name)}({args})
' + elif msg.message_type == "tool_call_message": + args = format_json(msg.tool_call.arguments) + return f'
{html.escape(msg.function_call.name)}({args})
' elif msg.message_type == "function_return": - return_value = format_json(msg.function_return) # return f'
Status: {html.escape(msg.status)}
{return_value}
' return f'
{return_value}
' + elif msg.message_type == "tool_return_message": + return_value = format_json(msg.tool_return) + # return f'
Status: {html.escape(msg.status)}
{return_value}
' + return f'
{return_value}
' elif msg.message_type == "user_message": if is_json(msg.message): return f'
{format_json(msg.message)}
' diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 0d2f77a715..b548ec256a 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -16,9 +16,9 @@ from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import ( AssistantMessage, - FunctionCall, - FunctionCallMessage, - FunctionReturn, + ToolCall as LettaToolCall, + ToolCallMessage, + ToolReturnMessage, InternalMonologue, LettaMessage, SystemMessage, @@ -172,18 +172,18 @@ def to_letta_message( ) else: messages.append( - FunctionCallMessage( + ToolCallMessage( id=self.id, date=self.created_at, - function_call=FunctionCall( + tool_call=LettaToolCall( name=tool_call.function.name, arguments=tool_call.function.arguments, - function_call_id=tool_call.id, + tool_call_id=tool_call.id, ), ) ) elif self.role == MessageRole.tool: - # This is type FunctionReturn + # This is type ToolReturnMessage # Try to interpret the function return, recall that this is how we packaged: # def package_function_response(was_success, response_string, timestamp=None): # formatted_time = get_local_time() if timestamp is None else timestamp @@ -208,12 +208,12 @@ def to_letta_message( messages.append( # TODO make sure this is what the API returns # function_return may not match exactly... - FunctionReturn( + ToolReturnMessage( id=self.id, date=self.created_at, - function_return=self.text, + tool_return=self.text, status=status_enum, - function_call_id=self.tool_call_id, + tool_call_id=self.tool_call_id, ) ) elif self.role == MessageRole.user: diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 11843250c7..85248ec3bb 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -12,10 +12,10 @@ from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( AssistantMessage, - FunctionCall, - FunctionCallDelta, - FunctionCallMessage, - FunctionReturn, + ToolCall, + ToolCallDelta, + ToolCallMessage, + ToolReturnMessage, InternalMonologue, LegacyFunctionCallMessage, LegacyLettaMessage, @@ -411,7 +411,7 @@ def clear(): def _process_chunk_to_letta_style( self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime - ) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]: + ) -> Optional[Union[InternalMonologue, ToolCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -442,7 +442,7 @@ def _process_chunk_to_letta_style( if self.inner_thoughts_in_kwargs: raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported") - # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode + # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard ToolCallMessage passthrough mode # Track the function name while streaming # If we were previously on a 'send_message', we need to 'toggle' into 'content' mode @@ -474,7 +474,7 @@ def _process_chunk_to_letta_style( assistant_message=cleaned_func_args, ) - # otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage + # otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage else: tool_call_delta = {} if tool_call.id: @@ -485,13 +485,13 @@ def _process_chunk_to_letta_style( if tool_call.function.name: tool_call_delta["name"] = tool_call.function.name - processed_chunk = FunctionCallMessage( + processed_chunk = ToolCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta( + tool_call=ToolCallDelta( name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments"), - function_call_id=tool_call_delta.get("id"), + tool_call_id=tool_call_delta.get("id"), ), ) @@ -531,7 +531,7 @@ def _process_chunk_to_letta_style( else: self.function_args_buffer += updates_main_json - # If we have main_json, we should output a FunctionCallMessage + # If we have main_json, we should output a ToolCallMessage elif updates_main_json: # If there's something in the function_name buffer, we should release it first @@ -539,13 +539,13 @@ def _process_chunk_to_letta_style( # however the frontend may expect name first, then args, so to be # safe we'll output name first in a separate chunk if self.function_name_buffer: - processed_chunk = FunctionCallMessage( + processed_chunk = ToolCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta( + tool_call=ToolCallDelta( name=self.function_name_buffer, arguments=None, - function_call_id=self.function_id_buffer, + tool_call_id=self.function_id_buffer, ), ) # Clear the buffer @@ -561,20 +561,20 @@ def _process_chunk_to_letta_style( self.function_args_buffer += updates_main_json # If there was nothing in the name buffer, we can proceed to - # output the arguments chunk as a FunctionCallMessage + # output the arguments chunk as a ToolCallMessage else: # There may be a buffer from a previous chunk, for example # if the previous chunk had arguments but we needed to flush name if self.function_args_buffer: # In this case, we should release the buffer + new data at once combined_chunk = self.function_args_buffer + updates_main_json - processed_chunk = FunctionCallMessage( + processed_chunk = ToolCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta( + tool_call=ToolCallDelta( name=None, arguments=combined_chunk, - function_call_id=self.function_id_buffer, + tool_call_id=self.function_id_buffer, ), ) # clear buffer @@ -582,13 +582,13 @@ def _process_chunk_to_letta_style( self.function_id_buffer = None else: # If there's no buffer to clear, just output a new chunk with new data - processed_chunk = FunctionCallMessage( + processed_chunk = ToolCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta( + tool_call=ToolCallDelta( name=None, arguments=updates_main_json, - function_call_id=self.function_id_buffer, + tool_call_id=self.function_id_buffer, ), ) self.function_id_buffer = None @@ -608,10 +608,10 @@ def _process_chunk_to_letta_style( # # if tool_call.function.name: # # tool_call_delta["name"] = tool_call.function.name - # processed_chunk = FunctionCallMessage( + # processed_chunk = ToolCallMessage( # id=message_id, # date=message_date, - # function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + # tool_call=ToolCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), # ) else: @@ -642,10 +642,10 @@ def _process_chunk_to_letta_style( # if tool_call.function.name: # tool_call_delta["name"] = tool_call.function.name - # processed_chunk = FunctionCallMessage( + # processed_chunk = ToolCallMessage( # id=message_id, # date=message_date, - # function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + # tool_call=ToolCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), # ) # elif False and self.inner_thoughts_in_kwargs and tool_call.function: @@ -682,13 +682,13 @@ def _process_chunk_to_letta_style( # If it does match, start processing the value (stringified-JSON string # And with each new chunk, output it as a chunk of type InternalMonologue - # If the key doesn't match, then flush the buffer as a single FunctionCallMessage chunk + # If the key doesn't match, then flush the buffer as a single ToolCallMessage chunk # If we're reading a value # If we're reading the inner thoughts value, we output chunks of type InternalMonologue - # Otherwise, do simple chunks of FunctionCallMessage + # Otherwise, do simple chunks of ToolCallMessage else: @@ -701,13 +701,13 @@ def _process_chunk_to_letta_style( if tool_call.function.name: tool_call_delta["name"] = tool_call.function.name - processed_chunk = FunctionCallMessage( + processed_chunk = ToolCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta( + tool_call=ToolCallDelta( name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments"), - function_call_id=tool_call_delta.get("id"), + tool_call_id=tool_call_delta.get("id"), ), ) @@ -911,13 +911,13 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): assistant_message=func_args[self.assistant_message_tool_kwarg], ) else: - processed_chunk = FunctionCallMessage( + processed_chunk = ToolCallMessage( id=msg_obj.id, date=msg_obj.created_at, - function_call=FunctionCall( + tool_call=ToolCall( name=function_call.function.name, arguments=function_call.function.arguments, - function_call_id=function_call.id, + tool_call_id=function_call.id, ), ) @@ -942,24 +942,24 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None): msg = msg.replace("Success: ", "") # new_message = {"function_return": msg, "status": "success"} assert msg_obj.tool_call_id is not None - new_message = FunctionReturn( + new_message = ToolReturnMessage( id=msg_obj.id, date=msg_obj.created_at, - function_return=msg, + tool_return=msg, status="success", - function_call_id=msg_obj.tool_call_id, + tool_call_id=msg_obj.tool_call_id, ) elif msg.startswith("Error: "): msg = msg.replace("Error: ", "") # new_message = {"function_return": msg, "status": "error"} assert msg_obj.tool_call_id is not None - new_message = FunctionReturn( + new_message = ToolReturnMessage( id=msg_obj.id, date=msg_obj.created_at, - function_return=msg, + tool_return=msg, status="error", - function_call_id=msg_obj.tool_call_id, + tool_call_id=msg_obj.tool_call_id, ) else: diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 3dc7916a35..deabcaf5c5 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException from letta.schemas.enums import MessageRole -from letta.schemas.letta_message import FunctionCall, LettaMessage +from letta.schemas.letta_message import ToolCall, LettaMessage from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_response import ( ChatCompletionResponse, @@ -94,7 +94,7 @@ async def create_chat_completion( created_at = None for letta_msg in response_messages.messages: assert isinstance(letta_msg, LettaMessage) - if isinstance(letta_msg, FunctionCall): + if isinstance(letta_msg, ToolCall): if letta_msg.name and letta_msg.name == "send_message": try: letta_function_call_args = json.loads(letta_msg.arguments) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 15979346c4..500e6fc86e 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -7,7 +7,7 @@ from letta.errors import LettaToolCreateError from letta.orm.errors import UniqueConstraintViolationError -from letta.schemas.letta_message import FunctionReturn +from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate from letta.schemas.user import User from letta.server.rest_api.utils import get_letta_server @@ -163,7 +163,7 @@ def upsert_base_tools( return server.tool_manager.upsert_base_tools(actor=actor) -@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source") +@router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source") def run_tool_from_source( server: SyncServer = Depends(get_letta_server), request: ToolRunFromSource = Body(...), diff --git a/letta/server/server.py b/letta/server/server.py index 1b8ff9cdc7..617283345a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -47,7 +47,7 @@ # openai schemas from letta.schemas.enums import JobStatus from letta.schemas.job import Job, JobUpdate -from letta.schemas.letta_message import FunctionReturn, LettaMessage +from letta.schemas.letta_message import ToolReturnMessage, LettaMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ( ArchivalMemorySummary, @@ -1350,7 +1350,7 @@ def run_tool_from_source( tool_source: str, tool_source_type: Optional[str] = None, tool_name: Optional[str] = None, - ) -> FunctionReturn: + ) -> ToolReturnMessage: """Run a tool from source code""" try: @@ -1374,24 +1374,24 @@ def run_tool_from_source( # Next, attempt to run the tool with the sandbox try: sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state) - return FunctionReturn( + return ToolReturnMessage( id="null", - function_call_id="null", + tool_call_id="null", date=get_utc_time(), status=sandbox_run_result.status, - function_return=str(sandbox_run_result.func_return), + tool_return=str(sandbox_run_result.func_return), stdout=sandbox_run_result.stdout, stderr=sandbox_run_result.stderr, ) except Exception as e: func_return = get_friendly_error_msg(function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e)) - return FunctionReturn( + return ToolReturnMessage( id="null", - function_call_id="null", + tool_call_id="null", date=get_utc_time(), status="error", - function_return=func_return, + tool_return=func_return, stdout=[], stderr=[traceback.format_exc()], ) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 8f1aa99c74..a526f2315b 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -15,9 +15,9 @@ from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA from letta.embeddings import embedding_model from letta.errors import ( - InvalidFunctionCallError, + InvalidToolCallError, InvalidInnerMonologueError, - MissingFunctionCallError, + MissingToolCallError, MissingInnerMonologueError, ) from letta.llm_api.llm_api_tools import create @@ -25,7 +25,7 @@ from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_message import ( - FunctionCallMessage, + ToolCallMessage, InternalMonologue, LettaMessage, ) @@ -377,23 +377,23 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo # Find first instance of send_message target_message = None for message in messages: - if isinstance(message, FunctionCallMessage) and message.function_call.name == "send_message": + if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message": target_message = message break # No messages found with `send_messages` if target_message is None: - raise MissingFunctionCallError(messages=messages, explanation="Missing `send_message` function call") + raise MissingToolCallError(messages=messages, explanation="Missing `send_message` function call") - send_message_function_call = target_message.function_call + send_message_function_call = target_message.tool_call try: arguments = json.loads(send_message_function_call.arguments) except: - raise InvalidFunctionCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON") + raise InvalidToolCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON") # Message field not in send_message if "message" not in arguments: - raise InvalidFunctionCallError( + raise InvalidToolCallError( messages=[target_message], explanation=f"send_message function call does not have required field `message`" ) @@ -403,16 +403,16 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo arguments["message"] = arguments["message"].lower() if not keyword in arguments["message"]: - raise InvalidFunctionCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}") + raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}") def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None: for message in messages: - if isinstance(message, FunctionCallMessage) and message.function_call.name == function_name: + if isinstance(message, ToolCallMessage) and message.tool_call.name == function_name: # Found it, do nothing return - raise MissingFunctionCallError( + raise MissingToolCallError( messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}" ) @@ -446,7 +446,7 @@ def assert_contains_valid_function_call( if (hasattr(message, "function_call") and message.function_call is not None) and ( hasattr(message, "tool_calls") and message.tool_calls is not None ): - raise InvalidFunctionCallError(messages=[message], explanation="Both function_call and tool_calls is present in the message") + raise InvalidToolCallError(messages=[message], explanation="Both function_call and tool_calls is present in the message") elif hasattr(message, "function_call") and message.function_call is not None: function_call = message.function_call elif hasattr(message, "tool_calls") and message.tool_calls is not None: @@ -455,10 +455,10 @@ def assert_contains_valid_function_call( function_call = message.tool_calls[0].function else: # Throw a missing function call error - raise MissingFunctionCallError(messages=[message]) + raise MissingToolCallError(messages=[message]) if function_call_validator and not function_call_validator(function_call): - raise InvalidFunctionCallError(messages=[message], explanation=validation_failure_summary) + raise InvalidToolCallError(messages=[message], explanation=validation_failure_summary) def assert_inner_monologue_is_valid(message: Message) -> None: diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 336777215d..b013f55cb3 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -3,7 +3,7 @@ import pytest from letta import create_client -from letta.schemas.letta_message import FunctionCallMessage +from letta.schemas.letta_message import ToolCallMessage from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule from tests.helpers.endpoints_helper import ( assert_invoked_function_call, @@ -115,9 +115,9 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none): tool_names = [t.name for t in [t1, t2, t3, t4]] tool_names += ["send_message"] for m in response.messages: - if isinstance(m, FunctionCallMessage): + if isinstance(m, ToolCallMessage): # Check that it's equal to the first one - assert m.function_call.name == tool_names[0] + assert m.tool_call.name == tool_names[0] # Pop out first one tool_names = tool_names[1:] @@ -220,9 +220,9 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): tool_names = [t.name for t in [t1, t2]] tool_names += ["send_message"] for m in messages: - if isinstance(m, FunctionCallMessage): + if isinstance(m, ToolCallMessage): # Check that it's equal to the first one - assert m.function_call.name == tool_names[0] + assert m.tool_call.name == tool_names[0] # Pop out first one tool_names = tool_names[1:] @@ -273,9 +273,9 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): # Check ordering of tool calls tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] for m in response.messages: - if isinstance(m, FunctionCallMessage): + if isinstance(m, ToolCallMessage): # Check that it's equal to the first one - assert m.function_call.name == tool_names[0] + assert m.tool_call.name == tool_names[0] # Pop out first one tool_names = tool_names[1:] diff --git a/tests/test_client.py b/tests/test_client.py index f37fe86267..ec2bb6cb15 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -15,7 +15,7 @@ from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.job import JobStatus -from letta.schemas.letta_message import FunctionReturn +from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.llm_config import LLMConfig from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxType from letta.utils import create_random_username @@ -365,12 +365,12 @@ def big_return(): response_message = None for message in response.messages: - if isinstance(message, FunctionReturn): + if isinstance(message, ToolReturnMessage): response_message = message break - assert response_message, "FunctionReturn message not found in response" - res = response_message.function_return + assert response_message, "ToolReturnMessage message not found in response" + res = response_message.tool_return assert "function output was truncated " in res # TODO: Re-enable later diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 7c634e5fa3..30b78c911f 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -18,8 +18,8 @@ from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.letta_message import ( AssistantMessage, - FunctionCallMessage, - FunctionReturn, + ToolCallMessage, + ToolReturnMessage, InternalMonologue, LettaMessage, SystemMessage, @@ -172,8 +172,8 @@ def test_agent_interactions(mock_e2b_api_key_none, client: Union[LocalClient, RE SystemMessage, UserMessage, InternalMonologue, - FunctionCallMessage, - FunctionReturn, + ToolCallMessage, + ToolReturnMessage, AssistantMessage, ], f"Unexpected message type: {type(letta_message)}" @@ -258,7 +258,7 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "": inner_thoughts_exist = True inner_thoughts_count += 1 - if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message": + if isinstance(chunk, ToolCallMessage) and chunk.tool_call and chunk.tool_call.name == "send_message": send_message_ran = True if isinstance(chunk, MessageStreamStatus): if chunk == MessageStreamStatus.done: @@ -534,7 +534,7 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") print("Messages=", message_response) assert isinstance(message_response, LettaResponse) - assert isinstance(message_response.messages[-1], FunctionReturn) + assert isinstance(message_response.messages[-1], ToolReturnMessage) message = message_response.messages[-1] new_text = "This exact string would never show up in the message???" diff --git a/tests/test_server.py b/tests/test_server.py index 2003c68880..fc39d527a5 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -10,8 +10,8 @@ from letta.schemas.block import CreateBlock from letta.schemas.enums import MessageRole from letta.schemas.letta_message import ( - FunctionCallMessage, - FunctionReturn, + ToolCallMessage, + ToolReturnMessage, InternalMonologue, LettaMessage, SystemMessage, @@ -677,14 +677,14 @@ def _test_get_messages_letta_format( print(f"Assistant Message at {i}: {type(letta_message)}") if reverse: - # Reverse handling: FunctionCallMessages come first + # Reverse handling: ToolCallMessage come first if message.tool_calls: for tool_call in message.tool_calls: try: json.loads(tool_call.function.arguments) except json.JSONDecodeError: warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}") - assert isinstance(letta_message, FunctionCallMessage) + assert isinstance(letta_message, ToolCallMessage) letta_message_index += 1 if letta_message_index >= len(letta_messages): break @@ -710,9 +710,9 @@ def _test_get_messages_letta_format( json.loads(tool_call.function.arguments) except json.JSONDecodeError: warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}") - assert isinstance(letta_message, FunctionCallMessage) - assert tool_call.function.name == letta_message.function_call.name - assert tool_call.function.arguments == letta_message.function_call.arguments + assert isinstance(letta_message, ToolCallMessage) + assert tool_call.function.name == letta_message.tool_call.name + assert tool_call.function.arguments == letta_message.tool_call.arguments letta_message_index += 1 if letta_message_index >= len(letta_messages): break @@ -729,8 +729,8 @@ def _test_get_messages_letta_format( letta_message_index += 1 elif message.role == MessageRole.tool: - assert isinstance(letta_message, FunctionReturn) - assert message.text == letta_message.function_return + assert isinstance(letta_message, ToolReturnMessage) + assert message.text == letta_message.tool_return letta_message_index += 1 else: @@ -802,7 +802,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): ) print(result) assert result.status == "success" - assert result.function_return == "Ingested message Hello, world!", result.function_return + assert result.tool_return == "Ingested message Hello, world!", result.tool_return assert not result.stdout assert not result.stderr @@ -815,7 +815,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): ) print(result) assert result.status == "success" - assert result.function_return == "Ingested message Well well well", result.function_return + assert result.tool_return == "Ingested message Well well well", result.tool_return assert not result.stdout assert not result.stderr @@ -828,8 +828,8 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): ) print(result) assert result.status == "error" - assert "Error" in result.function_return, result.function_return - assert "missing 1 required positional argument" in result.function_return, result.function_return + assert "Error" in result.tool_return, result.tool_return + assert "missing 1 required positional argument" in result.tool_return, result.tool_return assert not result.stdout assert result.stderr assert "missing 1 required positional argument" in result.stderr[0] @@ -844,7 +844,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): ) print(result) assert result.status == "success" - assert result.function_return == "Ingested message Well well well", result.function_return + assert result.tool_return == "Ingested message Well well well", result.tool_return assert result.stdout assert "I'm a distractor" in result.stdout[0] assert not result.stderr @@ -859,7 +859,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): ) print(result) assert result.status == "success" - assert result.function_return == "Ingested message Well well well", result.function_return + assert result.tool_return == "Ingested message Well well well", result.tool_return assert result.stdout assert "I'm a distractor" in result.stdout[0] assert not result.stderr @@ -874,7 +874,7 @@ def test_tool_run(server, mock_e2b_api_key_none, user, agent_id): ) print(result) assert result.status == "success" - assert result.function_return == str(None), result.function_return + assert result.tool_return == str(None), result.tool_return assert result.stdout assert "I'm a distractor" in result.stdout[0] assert not result.stderr