Skip to content

Commit

Permalink
feat: rename function to tool in sdk (#2288)
Browse files Browse the repository at this point in the history
Co-authored-by: Caren Thomas <caren@caren-mac.local>
  • Loading branch information
carenthomas and Caren Thomas authored Dec 19, 2024
1 parent d8be870 commit 87820d9
Show file tree
Hide file tree
Showing 16 changed files with 199 additions and 161 deletions.
12 changes: 11 additions & 1 deletion examples/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def nb_print(messages):
return_data = json.loads(msg.function_return)
if "message" in return_data and return_data["message"] == "None":
continue
if msg.message_type == "tool_return_message":
return_data = json.loads(msg.tool_return)
if "message" in return_data and return_data["message"] == "None":
continue

title = msg.message_type.replace("_", " ").upper()
html_output += f"""
Expand All @@ -94,11 +98,17 @@ def get_formatted_content(msg):
elif msg.message_type == "function_call":
args = format_json(msg.function_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
elif msg.message_type == "tool_call_message":
args = format_json(msg.tool_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
elif msg.message_type == "function_return":

return_value = format_json(msg.function_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "tool_return_message":
return_value = format_json(msg.tool_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "user_message":
if is_json(msg.message):
return f'<div class="content">{format_json(msg.message)}</div>'
Expand Down
6 changes: 3 additions & 3 deletions examples/tool_rule_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid

from letta import create_client
from letta.schemas.letta_message import FunctionCallMessage
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_send_message_with_keyword,
Expand Down Expand Up @@ -116,9 +116,9 @@ def main():
# 6. Here, we thoroughly check the correctness of the response
tool_names += ["send_message"] # Add send message because we expect this to be called at the end
for m in response.messages:
if isinstance(m, FunctionCallMessage):
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.function_call.name == tool_names[0]
assert m.tool_call.name == tool_names[0]
# Pop out first one
tool_names = tool_names[1:]

Expand Down
12 changes: 6 additions & 6 deletions letta/client/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from letta.errors import LLMError
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
FunctionCallMessage,
FunctionReturn,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
)
from letta.schemas.letta_response import LettaStreamingResponse
Expand Down Expand Up @@ -55,10 +55,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
chunk_data = json.loads(sse.data)
if "internal_monologue" in chunk_data:
yield InternalMonologue(**chunk_data)
elif "function_call" in chunk_data:
yield FunctionCallMessage(**chunk_data)
elif "function_return" in chunk_data:
yield FunctionReturn(**chunk_data)
elif "tool_call" in chunk_data:
yield ToolCallMessage(**chunk_data)
elif "tool_return" in chunk_data:
yield ToolReturnMessage(**chunk_data)
elif "usage" in chunk_data:
yield LettaUsageStatistics(**chunk_data["usage"])
else:
Expand Down
12 changes: 6 additions & 6 deletions letta/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,16 @@ def construct_error_message(messages: List[Union["Message", "LettaMessage"]], er
return f"{error_msg}\n\n{message_json}"


class MissingFunctionCallError(LettaMessageError):
"""Error raised when a message is missing a function call."""
class MissingToolCallError(LettaMessageError):
"""Error raised when a message is missing a tool call."""

default_error_message = "The message is missing a function call."
default_error_message = "The message is missing a tool call."


class InvalidFunctionCallError(LettaMessageError):
"""Error raised when a message uses an invalid function call."""
class InvalidToolCallError(LettaMessageError):
"""Error raised when a message uses an invalid tool call."""

default_error_message = "The message uses an invalid function call or has improper usage of a function call."
default_error_message = "The message uses an invalid tool call or has improper usage of a tool call."


class MissingInnerMonologueError(LettaMessageError):
Expand Down
94 changes: 58 additions & 36 deletions letta/schemas/letta_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class LettaMessage(BaseModel):
"""
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, function calls, and function returns in a simplified format that does not include additional information other than the content and timestamp.
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp.
Attributes:
id (str): The ID of the message
Expand Down Expand Up @@ -74,18 +74,18 @@ class InternalMonologue(LettaMessage):
internal_monologue: str


class FunctionCall(BaseModel):
class ToolCall(BaseModel):

name: str
arguments: str
function_call_id: str
tool_call_id: str


class FunctionCallDelta(BaseModel):
class ToolCallDelta(BaseModel):

name: Optional[str]
arguments: Optional[str]
function_call_id: Optional[str]
tool_call_id: Optional[str]

# NOTE: this is a workaround to exclude None values from the JSON dump,
# since the OpenAI style of returning chunks doesn't include keys with null values
Expand All @@ -97,67 +97,67 @@ def json(self, *args, **kwargs):
return json.dumps(self.model_dump(exclude_none=True), *args, **kwargs)


class FunctionCallMessage(LettaMessage):
class ToolCallMessage(LettaMessage):
"""
A message representing a request to call a function (generated by the LLM to trigger function execution).
A message representing a request to call a tool (generated by the LLM to trigger tool execution).
Attributes:
function_call (Union[FunctionCall, FunctionCallDelta]): The function call
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
"""

message_type: Literal["function_call"] = "function_call"
function_call: Union[FunctionCall, FunctionCallDelta]
message_type: Literal["tool_call_message"] = "tool_call_message"
tool_call: Union[ToolCall, ToolCallDelta]

# NOTE: this is required for the FunctionCallDelta exclude_none to work correctly
# NOTE: this is required for the ToolCallDelta exclude_none to work correctly
def model_dump(self, *args, **kwargs):
kwargs["exclude_none"] = True
data = super().model_dump(*args, **kwargs)
if isinstance(data["function_call"], dict):
data["function_call"] = {k: v for k, v in data["function_call"].items() if v is not None}
if isinstance(data["tool_call"], dict):
data["tool_call"] = {k: v for k, v in data["tool_call"].items() if v is not None}
return data

class Config:
json_encoders = {
FunctionCallDelta: lambda v: v.model_dump(exclude_none=True),
FunctionCall: lambda v: v.model_dump(exclude_none=True),
ToolCallDelta: lambda v: v.model_dump(exclude_none=True),
ToolCall: lambda v: v.model_dump(exclude_none=True),
}

# NOTE: this is required to cast dicts into FunctionCallMessage objects
# NOTE: this is required to cast dicts into ToolCallMessage objects
# Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
# (instead of properly casting to FunctionCallDelta instead of FunctionCall)
@field_validator("function_call", mode="before")
# (instead of properly casting to ToolCallDelta instead of ToolCall)
@field_validator("tool_call", mode="before")
@classmethod
def validate_function_call(cls, v):
def validate_tool_call(cls, v):
if isinstance(v, dict):
if "name" in v and "arguments" in v and "function_call_id" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"])
elif "name" in v or "arguments" in v or "function_call_id" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id"))
if "name" in v and "arguments" in v and "tool_call_id" in v:
return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
elif "name" in v or "arguments" in v or "tool_call_id" in v:
return ToolCallDelta(name=v.get("name"), arguments=v.get("arguments"), tool_call_id=v.get("tool_call_id"))
else:
raise ValueError("function_call must contain either 'name' or 'arguments'")
raise ValueError("tool_call must contain either 'name' or 'arguments'")
return v


class FunctionReturn(LettaMessage):
class ToolReturnMessage(LettaMessage):
"""
A message representing the return value of a function call (generated by Letta executing the requested function).
A message representing the return value of a tool call (generated by Letta executing the requested tool).
Attributes:
function_return (str): The return value of the function
status (Literal["success", "error"]): The status of the function call
tool_return (str): The return value of the tool
status (Literal["success", "error"]): The status of the tool call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
function_call_id (str): A unique identifier for the function call that generated this message
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
stderr (Optional[List(str)]): Captured stderr from the function invocation
tool_call_id (str): A unique identifier for the tool call that generated this message
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation
stderr (Optional[List(str)]): Captured stderr from the tool invocation
"""

message_type: Literal["function_return"] = "function_return"
function_return: str
message_type: Literal["tool_return_message"] = "tool_return_message"
tool_return: str
status: Literal["success", "error"]
function_call_id: str
tool_call_id: str
stdout: Optional[List[str]] = None
stderr: Optional[List[str]] = None

Expand All @@ -174,10 +174,32 @@ class LegacyFunctionCallMessage(LettaMessage):
function_call: str


LegacyLettaMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]
class LegacyFunctionReturn(LettaMessage):
"""
A message representing the return value of a function call (generated by Letta executing the requested function).
Attributes:
function_return (str): The return value of the function
status (Literal["success", "error"]): The status of the function call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
function_call_id (str): A unique identifier for the function call that generated this message
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
stderr (Optional[List(str)]): Captured stderr from the function invocation
"""

message_type: Literal["function_return"] = "function_return"
function_return: str
status: Literal["success", "error"]
function_call_id: str
stdout: Optional[List[str]] = None
stderr: Optional[List[str]] = None


LegacyLettaMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]


LettaMessageUnion = Annotated[
Union[SystemMessage, UserMessage, InternalMonologue, FunctionCallMessage, FunctionReturn, AssistantMessage],
Union[SystemMessage, UserMessage, InternalMonologue, ToolCallMessage, ToolReturnMessage, AssistantMessage],
Field(discriminator="message_type"),
]
8 changes: 7 additions & 1 deletion letta/schemas/letta_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,17 @@ def get_formatted_content(msg):
elif msg.message_type == "function_call":
args = format_json(msg.function_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
elif msg.message_type == "tool_call_message":
args = format_json(msg.tool_call.arguments)
return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
elif msg.message_type == "function_return":

return_value = format_json(msg.function_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "tool_return_message":
return_value = format_json(msg.tool_return)
# return f'<div class="status-line">Status: {html.escape(msg.status)}</div><div class="content">{return_value}</div>'
return f'<div class="content">{return_value}</div>'
elif msg.message_type == "user_message":
if is_json(msg.message):
return f'<div class="content">{format_json(msg.message)}</div>'
Expand Down
20 changes: 10 additions & 10 deletions letta/schemas/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCall,
FunctionCallMessage,
FunctionReturn,
ToolCall as LettaToolCall,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
LettaMessage,
SystemMessage,
Expand Down Expand Up @@ -172,18 +172,18 @@ def to_letta_message(
)
else:
messages.append(
FunctionCallMessage(
ToolCallMessage(
id=self.id,
date=self.created_at,
function_call=FunctionCall(
tool_call=LettaToolCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
function_call_id=tool_call.id,
tool_call_id=tool_call.id,
),
)
)
elif self.role == MessageRole.tool:
# This is type FunctionReturn
# This is type ToolReturnMessage
# Try to interpret the function return, recall that this is how we packaged:
# def package_function_response(was_success, response_string, timestamp=None):
# formatted_time = get_local_time() if timestamp is None else timestamp
Expand All @@ -208,12 +208,12 @@ def to_letta_message(
messages.append(
# TODO make sure this is what the API returns
# function_return may not match exactly...
FunctionReturn(
ToolReturnMessage(
id=self.id,
date=self.created_at,
function_return=self.text,
tool_return=self.text,
status=status_enum,
function_call_id=self.tool_call_id,
tool_call_id=self.tool_call_id,
)
)
elif self.role == MessageRole.user:
Expand Down
Loading

0 comments on commit 87820d9

Please sign in to comment.