Skip to content

Commit

Permalink
chore: Add tool rules example (#1998)
Browse files Browse the repository at this point in the history
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
  • Loading branch information
mattzh72 and sarahwooders authored Nov 7, 2024
1 parent e5c194b commit d9d53db
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 39 deletions.
11 changes: 9 additions & 2 deletions examples/docs/agent_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,31 @@

# create a new agent
agent_state = client.create_agent(
# agent's name (unique per-user, autogenerated if not provided)
name="agent_name",
# in-context memory representation with human/persona blocks
memory=ChatMemory(human="Name: Sarah", persona="You are a helpful assistant that loves emojis"),
# LLM model & endpoint configuration
llm_config=LLMConfig(
model="gpt-4",
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
context_window=8000,
context_window=8000, # set to <= max context window
),
# embedding model & endpoint configuration (cannot be changed)
embedding_config=EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_model="text-embedding-ada-002",
embedding_dim=1536,
embedding_chunk_size=300,
),
# system instructions for the agent (defaults to `memgpt_chat`)
system=gpt_system.get_system_text("memgpt_chat"),
tools=[],
# whether to include base letta tools (default: True)
include_base_tools=True,
# list of additional tools (by name) to add to the agent
tools=[],
)
print(f"Created agent with name {agent_state.name} and unique ID {agent_state.id}")

Expand Down
15 changes: 2 additions & 13 deletions examples/docs/agent_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,8 @@
client = create_client()

# set automatic defaults for LLM/embedding config
client.set_default_llm_config(
LLMConfig(model="gpt-4o-mini", model_endpoint_type="openai", model_endpoint="https://api.openai.com/v1", context_window=128000)
)
client.set_default_embedding_config(
EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_model="text-embedding-ada-002",
embedding_dim=1536,
embedding_chunk_size=300,
)
)

client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
client.set_default_embedding_config(EmbeddingConfig.default_config(model_name="text-embedding-ada-002"))

# create a new agent
agent_state = client.create_agent()
Expand Down
33 changes: 18 additions & 15 deletions examples/docs/tools.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
from letta import EmbeddingConfig, LLMConfig, create_client
from letta.schemas.tool_rule import TerminalToolRule

client = create_client()
# set automatic defaults for LLM/embedding config
client.set_default_llm_config(
LLMConfig(model="gpt-4", model_endpoint_type="openai", model_endpoint="https://api.openai.com/v1", context_window=8000)
)
client.set_default_embedding_config(
EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_model="text-embedding-ada-002",
embedding_dim=1536,
embedding_chunk_size=300,
)
)
client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
client.set_default_embedding_config(EmbeddingConfig.default_config(model_name="text-embedding-ada-002"))


# define a function with a docstring
def roll_d20() -> str:
def roll_d20(self) -> str:
"""
Simulate the roll of a 20-sided die (d20).
Expand All @@ -38,10 +29,22 @@ def roll_d20() -> str:
return output_string


tool = client.create_tool(roll_d20, name="roll_dice")
# create a tool from the function
tool = client.create_tool(roll_d20)
print(f"Created tool with name {tool.name}")

# create a new agent
agent_state = client.create_agent(tools=[tool.name])
agent_state = client.create_agent(
# create the agent with an additional tool
tools=[tool.name],
# add tool rules that terminate execution after specific tools
tool_rules=[
# exit after roll_d20 is called
TerminalToolRule(tool_name=tool.name),
# exit after send_message is called (default behavior)
TerminalToolRule(tool_name="send_message"),
],
)
print(f"Created agent with name {agent_state.name} with tools {agent_state.tools}")

# Message an agent
Expand Down
132 changes: 132 additions & 0 deletions examples/tool_rule_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os
import uuid

from letta import create_client
from letta.schemas.letta_message import FunctionCallMessage
from letta.schemas.tool_rule import InitToolRule, TerminalToolRule, ToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_send_message_with_keyword,
setup_agent,
)
from tests.helpers.utils import cleanup
from tests.test_endpoints import llm_config_dir

"""
This example shows how you can constrain tool calls in your agent.
Please note that this currently only works reliably for models with Structured Outputs (e.g. gpt-4o).
Start by installing the dependencies.
```
poetry install --all-extras
```
"""

# --- Example configuration --------------------------------------------------
# Deterministic UUID (uuid5 over a fixed name) so the example agent gets the
# same name on every run, letting cleanup() find and remove stale copies.
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "agent_tool_graph"))
# LLM config for this example; gpt-4o is used because tool rules currently
# work reliably only with models that support Structured Outputs (see module docstring).
config_file = os.path.join(llm_config_dir, "openai-gpt-4o.json")

"""Contrived tools for this test case"""


def first_secret_word(self: "Agent"):
    """
    Call this to retrieve the first secret word, which you will need for the second_secret_word function.
    """
    # Fixed first secret; second_secret_word checks for this exact value.
    secret = "v0iq020i0g"
    return secret


def second_secret_word(self: "Agent", prev_secret_word: str):
    """
    Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error.
    Args:
        prev_secret_word (str): The secret word retrieved from calling first_secret_word.
    """
    # FIX: the original nested double quotes inside a double-quoted f-string
    # ({"v0iq020i0g"}), which is a SyntaxError on Python < 3.12 (PEP 701).
    # Hoisting the literal keeps the message identical and parses everywhere.
    expected = "v0iq020i0g"
    if prev_secret_word != expected:
        raise RuntimeError(f"Expected secret {expected}, got {prev_secret_word}")

    return "4rwp2b4gxq"


def third_secret_word(self: "Agent", prev_secret_word: str):
    """
    Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error.
    Args:
        prev_secret_word (str): The secret word retrieved from calling second_secret_word.
    """
    # FIX: the original nested double quotes inside a double-quoted f-string
    # ({"4rwp2b4gxq"}), which is a SyntaxError on Python < 3.12 (PEP 701).
    # Hoisting the literal keeps the message identical and parses everywhere.
    expected = "4rwp2b4gxq"
    if prev_secret_word != expected:
        raise RuntimeError(f"Expected secret {expected}, got {prev_secret_word}")

    return "hj2hwibbqm"


def fourth_secret_word(self: "Agent", prev_secret_word: str):
    """
    Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error.
    Args:
        prev_secret_word (str): The secret word retrieved from calling third_secret_word.
    """
    # FIX: the original nested double quotes inside a double-quoted f-string
    # ({"hj2hwibbqm"}), which is a SyntaxError on Python < 3.12 (PEP 701).
    # Hoisting the literal keeps the message identical and parses everywhere.
    expected = "hj2hwibbqm"
    if prev_secret_word != expected:
        raise RuntimeError(f"Expected secret {expected}, got {prev_secret_word}")

    return "banana"


def auto_error(self: "Agent"):
    """
    If you call this function, it will throw an error automatically.
    """
    # Unconditionally fail; the tool rules should make this tool unreachable.
    message = "This should never be called."
    raise RuntimeError(message)


def main():
    """Demonstrate constraining an agent's tool-call order with tool rules."""
    # 1. Set up the client and remove any leftover agent from a prior run
    client = create_client()
    cleanup(client=client, agent_uuid=agent_uuid)

    # 2. Register every example tool with the client
    secret_funcs = [first_secret_word, second_secret_word, third_secret_word, fourth_secret_word, auto_error]
    tools = [client.create_tool(func) for func in secret_funcs]
    # Expected call order excludes auto_error (the rules never reach it)
    tool_names = [t.name for t in tools[:-1]]

    # 3. Create the tool rules. It must be called in this order, or there will be an error thrown.
    tool_rules = [
        InitToolRule(tool_name="first_secret_word"),
        ToolRule(tool_name="first_secret_word", children=["second_secret_word"]),
        ToolRule(tool_name="second_secret_word", children=["third_secret_word"]),
        ToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]),
        ToolRule(tool_name="fourth_secret_word", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    # 4. Create the agent with all tools attached and the rules applied
    agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules)

    # 5. Ask for the final secret word
    response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")

    # 6. Verify every tool call happened in exactly the constrained order
    tool_names.append("send_message")  # the terminal send_message comes last
    position = 0
    for message in response.messages:
        if isinstance(message, FunctionCallMessage):
            assert message.function_call.name == tool_names[position]
            position += 1

    # Check final send message contains "banana"
    assert_invoked_send_message_with_keyword(response.messages, "banana")
    print(f"Got successful response from client: \n\n{response}")
    cleanup(client=client, agent_uuid=agent_uuid)


if __name__ == "__main__":
main()
11 changes: 8 additions & 3 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,14 @@ def __init__(

if agent_state.tool_rules is None:
agent_state.tool_rules = []
agent_state.tool_rules.append(TerminalToolRule(tool_name="send_message"))
# Define the rule to add
send_message_terminal_rule = TerminalToolRule(tool_name="send_message")
# Check if an equivalent rule is already present
if not any(
isinstance(rule, TerminalToolRule) and rule.tool_name == send_message_terminal_rule.tool_name for rule in agent_state.tool_rules
):
agent_state.tool_rules.append(send_message_terminal_rule)

self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)

# gpt-4, gpt-3.5-turbo, ...
Expand Down Expand Up @@ -395,7 +402,6 @@ def link_tools(self, tools: List[Tool]):
exec(tool.module, env)
else:
exec(tool.source_code, env)

self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]]
self.functions.append(tool.json_schema)
except Exception as e:
Expand Down Expand Up @@ -787,7 +793,6 @@ def _handle_ai_response(

# Update ToolRulesSolver state with last called function
self.tool_rules_solver.update_tool_usage(function_name)

# Update heartbeat request according to provided tool rules
if self.tool_rules_solver.has_children_tools(function_name):
heartbeat_request = True
Expand Down
8 changes: 3 additions & 5 deletions letta/services/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def __init__(self):
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
# Derive json_schema
derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(
source_code=pydantic_tool.source_code, name=pydantic_tool.name
)
derived_json_schema = pydantic_tool.json_schema or derive_openai_json_schema(source_code=pydantic_tool.source_code)
derived_name = pydantic_tool.name or derived_json_schema["name"]

try:
Expand Down Expand Up @@ -120,8 +118,8 @@ def update_tool_by_id(self, tool_id: str, tool_update: ToolUpdate, actor: Pydant
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
pydantic_tool = tool.to_pydantic()

name = update_data["name"] if "name" in update_data.keys() else None
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code, name=name)
update_data["name"] if "name" in update_data.keys() else None
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code)

tool.json_schema = new_schema

Expand Down
2 changes: 1 addition & 1 deletion tests/test_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def counter_tool(counter: int):
og_json_schema = tool_fixture["tool_create"].json_schema

source_code = parse_source_code(counter_tool)
name = "test_function_name_explicit"
name = "counter_tool"

# Create a ToolUpdate object to modify the tool's source_code
tool_update = ToolUpdate(name=name, source_code=source_code)
Expand Down

0 comments on commit d9d53db

Please sign in to comment.