From 1dea00c268e6851637eee0004ffc3a1f919aa8e3 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Wed, 18 Dec 2024 11:23:09 -0800 Subject: [PATCH 01/11] Add ConditionalToolRule * update ToolRuleSolver to * check for Init->Terminal paths * remove cycle detection * updated tests for conditional rules, valid paths --- letta/helpers/tool_rule_solver.py | 93 +++++++++++++++++--------- letta/schemas/enums.py | 1 + letta/schemas/tool_rule.py | 15 ++++- tests/test_tool_rule_solver.py | 104 ++++++++++++++++++++++++------ 4 files changed, 160 insertions(+), 53 deletions(-) diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index ef4d9a9b37..8d5aec466e 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Union +from collections import deque from pydantic import BaseModel, Field @@ -6,6 +7,7 @@ from letta.schemas.tool_rule import ( BaseToolRule, ChildToolRule, + ConditionalToolRule, InitToolRule, TerminalToolRule, ) @@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel): init_tool_rules: List[InitToolRule] = Field( default_factory=list, description="Initial tool rules to be used at the start of tool execution." ) - tool_rules: List[ChildToolRule] = Field( + tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field( default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." ) terminal_tool_rules: List[TerminalToolRule] = Field( @@ -35,15 +37,22 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs): # Separate the provided tool rules into init, standard, and terminal categories for rule in tool_rules: if rule.type == ToolRuleType.run_first: + assert isinstance(rule, InitToolRule) self.init_tool_rules.append(rule) elif rule.type == ToolRuleType.constrain_child_tools: + assert isinstance(rule, ChildToolRule) + self.tool_rules.append(rule) + elif rule.type == ToolRuleType.conditional: + assert isinstance(rule, ConditionalToolRule) + self.validate_conditional_tool(rule) self.tool_rules.append(rule) elif rule.type == ToolRuleType.exit_loop: + assert isinstance(rule, TerminalToolRule) self.terminal_tool_rules.append(rule) # Validate the tool rules to ensure they form a DAG if not self.validate_tool_rules(): - raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.") + raise ToolRuleValidationError("Tool rules does not have a path from Init to Terminal.") def update_tool_usage(self, tool_name: str): """Update the internal state to track the last tool called.""" @@ -78,38 +87,58 @@ def has_children_tools(self, tool_name): """Check if the tool has children tools""" return any(rule.tool_name == tool_name for rule in self.tool_rules) + def validate_conditional_tool(self, rule: ConditionalToolRule): + if rule.children is None or len(rule.children) == 0: + raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") + if len(rule.children) != len(rule.child_output_mapping): + raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") + if set(rule.children) != set(rule.child_output_mapping.values()): + raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") + return True + def validate_tool_rules(self) -> bool: """ - Validate that the tool rules define a directed acyclic graph (DAG). - Returns True if valid (no cycles), otherwise False. + Validate that there exists a path from every init tool to a terminal tool. + Returns True if valid (path exists), otherwise False. """ # Build adjacency list for the tool graph adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules} - # Track visited nodes - visited: Set[str] = set() - path_stack: Set[str] = set() - - # Define DFS helper function - def dfs(tool_name: str) -> bool: - if tool_name in path_stack: - return False # Cycle detected - if tool_name in visited: - return True # Already validated - - # Mark the node as visited in the current path - path_stack.add(tool_name) - for child in adjacency_list.get(tool_name, []): - if not dfs(child): - return False # Cycle detected in DFS - path_stack.remove(tool_name) # Remove from current path - visited.add(tool_name) - return True - - # Run DFS from each tool in `tool_rules` - for rule in self.tool_rules: - if rule.tool_name not in visited: - if not dfs(rule.tool_name): - return False # Cycle found, invalid tool rules - - return True # No cycles, valid DAG + init_tool_names = {rule.tool_name for rule in self.init_tool_rules} + terminal_tool_names = {rule.tool_name for rule in self.terminal_tool_rules} + + # Initial checks + if len(init_tool_names) == 0: + if len(terminal_tool_names) + len(self.tool_rules) > 0: + return False # No init tools defined + else: + return True # No tool rules + if len(terminal_tool_names) == 0: + if len(adjacency_list) > 0: + return False # No terminal tools defined + else: + return True # Only init tools + + # Define BFS helper function to find path to terminal tool + def has_path_to_terminal(start_tool: str) -> bool: + visited = set() + queue = deque([start_tool]) + visited.add(start_tool) + + while queue: + current_tool = queue.popleft() + if current_tool in terminal_tool_names: + return True + + for child in adjacency_list.get(current_tool, []): + if child not in visited: + visited.add(child) + queue.append(child) + return False + + # Check if each init tool has a path to a terminal tool + for init_tool_name in init_tool_names: + if not has_path_to_terminal(init_tool_name): + return False + + return True # All init tools have paths to terminal tools diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 8b74b83732..6183033f54 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -45,5 +45,6 @@ class ToolRuleType(str, Enum): run_first = "InitToolRule" exit_loop = "TerminalToolRule" # reasoning loop should exit continue_loop = "continue_loop" # reasoning loop should continue + conditional = "conditional" constrain_child_tools = "ToolRule" require_parent_tools = "require_parent_tools" diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index b320917d25..a7f4f7cc1b 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Dict, List, Union from pydantic import Field @@ -21,6 +21,17 @@ class ChildToolRule(BaseToolRule): children: List[str] = Field(..., description="The children tools that can be invoked.") +class ConditionalToolRule(BaseToolRule): + """ + A ToolRule that conditionally maps to different child tools based on the output. + """ + type: ToolRuleType = ToolRuleType.conditional + default_child: str = Field(..., description="The default child tool to be called") + child_output_mapping: Dict[Union[bool, str, int], str] = Field(..., description="The output case to check for mapping") + children: List[str] = Field(..., description="The child tool to call when output matches the case") + throw_error: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case") + + class InitToolRule(BaseToolRule): """ Represents the initial tool rule configuration. @@ -37,4 +48,4 @@ class TerminalToolRule(BaseToolRule): type: ToolRuleType = ToolRuleType.exit_loop -ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule] +ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule] diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 9de6a6302b..e1170e4b73 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -2,7 +2,12 @@ from letta.helpers import ToolRulesSolver from letta.helpers.tool_rule_solver import ToolRuleValidationError -from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ( + ChildToolRule, + ConditionalToolRule, + InitToolRule, + TerminalToolRule +) # Constants for tool names used in the tests START_TOOL = "start_tool" @@ -31,7 +36,9 @@ def test_get_allowed_tool_names_with_subsequent_rule(): # Setup: Tool rule sequence init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[]) + rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[END_TOOL]) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule]) # Action: Update usage and get allowed tools solver.update_tool_usage(START_TOOL) @@ -44,21 +51,22 @@ def test_get_allowed_tool_names_with_subsequent_rule(): def test_is_terminal_tool(): # Setup: Terminal tool rule configuration init_rule = InitToolRule(tool_name=START_TOOL) + rule_1 = ChildToolRule(tool_name=START_TOOL, children=[END_TOOL]) terminal_rule = TerminalToolRule(tool_name=END_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[terminal_rule]) # Action & Assert: Verify terminal and non-terminal tools assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool" assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool" -def test_get_allowed_tool_names_no_matching_rule_warning(): - # Setup: Tool rules with no matching rule for the last tool - init_rule = InitToolRule(tool_name=START_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) +# def test_get_allowed_tool_names_no_matching_rule_warning(): +# # Setup: Tool rules with no matching rule for the last tool +# init_rule = InitToolRule(tool_name=START_TOOL) +# solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) - # Action: Set last tool to an unrecognized tool and check warnings - solver.update_tool_usage(UNRECOGNIZED_TOOL) +# # Action: Set last tool to an unrecognized tool and check warnings +# solver.update_tool_usage(UNRECOGNIZED_TOOL) # NOTE: removed for now since this warning is getting triggered on every LLM call # with warnings.catch_warnings(record=True) as w: @@ -104,7 +112,65 @@ def test_update_tool_usage_and_get_allowed_tool_names_combined(): assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal" -def test_tool_rules_with_cycle_detection(): +def test_conditional_tool_rule(): + # Setup: Define a conditional tool rule + init_rule = InitToolRule(tool_name=START_TOOL) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + rule = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL, END_TOOL], + default_child=END_TOOL, + child_output_mapping={True: END_TOOL, False: START_TOOL} + ) + solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) + + # Action & Assert: Verify the rule properties + # Step 1: Initially allowed tools + assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'" + + # Step 2: After using 'start_tool' + solver.update_tool_usage(START_TOOL) + assert set(solver.get_allowed_tool_names()) == set({END_TOOL, START_TOOL}), "After 'start_tool', should allow 'start_tool' or 'end_tool'" + + # Step 3: After using 'end_tool' + assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal" + + +def test_invalid_conditional_tool_rule(): + # Setup: Define an invalid conditional tool rule + init_rule = InitToolRule(tool_name=START_TOOL) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + invalid_rule_1 = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL], + default_child=END_TOOL, + child_output_mapping={True: END_TOOL, False: START_TOOL} + ) + invalid_rule_2 = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL, END_TOOL], + default_child=END_TOOL, + child_output_mapping={True: END_TOOL} + ) + invalid_rule_3 = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL, FINAL_TOOL], + default_child=FINAL_TOOL, + child_output_mapping={True: END_TOOL, False: START_TOOL} + ) + + # Test 1: Missing child output mapping + with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): + ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule]) + with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): + ToolRulesSolver(tool_rules=[init_rule, invalid_rule_2, terminal_rule]) + + # Test 2: Missing child + with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): + ToolRulesSolver(tool_rules=[init_rule, invalid_rule_3, terminal_rule]) + + +def test_tool_rules_with_invalid_path(): # Setup: Define tool rules with both connected, disconnected nodes and a cycle init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) @@ -114,14 +180,14 @@ def test_tool_rules_with_cycle_detection(): terminal_rule = TerminalToolRule(tool_name=END_TOOL) # Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError - with pytest.raises(ToolRuleValidationError, match="Tool rules contain cycles"): + with pytest.raises(ToolRuleValidationError, match="Tool rules does not have a path from Init to Terminal."): ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule]) - # Extra setup: Define tool rules without a cycle but with hanging nodes - rule_5 = ChildToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool - - # Assert that a configuration without cycles does not raise an error - try: - ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_4, rule_5, terminal_rule]) - except ToolRuleValidationError: - pytest.fail("ToolRulesSolver raised ValidationError unexpectedly on a valid DAG with hanging nodes") + # Now: add a path from the start tool to the final tool + rule_5 = ConditionalToolRule( + tool_name=HELPER_TOOL, + children=[START_TOOL, FINAL_TOOL], + default_child=FINAL_TOOL, + child_output_mapping={True: START_TOOL, False: FINAL_TOOL}, + ) + ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule]) From b4a1534434f9eeb178b6fbf8fabad8d79c0df5fc Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Wed, 18 Dec 2024 14:45:58 -0800 Subject: [PATCH 02/11] Tool chaining in agent flow * added state to track last function result in agent * added logic in ToolRuleSolver to choose correct next tool * Integrated test cases for conditional tools in agent --- letta/agent.py | 7 +- letta/helpers/tool_rule_solver.py | 71 +++++-- letta/orm/custom_columns.py | 7 +- tests/integration_test_agent_tool_graph.py | 211 ++++++++++++++++++++- 4 files changed, 281 insertions(+), 15 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 485f2112b9..e3bf859908 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -298,6 +298,9 @@ def __init__( self.first_message_verify_mono = first_message_verify_mono + # State needed for conditional tool chaining + self.last_function_response = None + # Controls if the convo memory pressure warning is triggered # When an alert is sent in the message queue, set this to True (to avoid repeat alerts) # When the summarizer is run, set this back to False (to reset) @@ -586,7 +589,7 @@ def _get_ai_reply( ) -> ChatCompletionResponse: """Get response from LLM API with robust retry mechanism.""" - allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names() + allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response) agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] allowed_functions = ( @@ -826,6 +829,7 @@ def _handle_ai_response( error_msg_user = f"{error_msg}\n{traceback.format_exc()}" printd(error_msg_user) function_response = package_function_response(False, error_msg) + self.last_function_response = function_response # TODO: truncate error message somehow messages.append( Message.dict_to_message( @@ -861,6 +865,7 @@ def _handle_ai_response( ) # extend conversation with function response self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1]) + self.last_function_response = function_response else: # Standard non-function reply diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 8d5aec466e..b8f2abbd71 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,3 +1,4 @@ +import json from typing import Dict, List, Optional, Union from collections import deque @@ -58,7 +59,7 @@ def update_tool_usage(self, tool_name: str): """Update the internal state to track the last tool called.""" self.last_tool_name = tool_name - def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]: + def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]: """Get a list of tool names allowed based on the last tool called.""" if self.last_tool_name is None: # Use initial tool rules if no tool has been called yet @@ -67,18 +68,20 @@ def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]: # Find a matching ToolRule for the last tool used current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None) - # Return children which must exist on ToolRule - if current_rule: - return current_rule.children - - # Default to empty if no rule matches - message = "User provided tool rules and execution state resolved to no more possible tool calls." - if error_on_empty: - raise RuntimeError(message) - else: - # warnings.warn(message) + if current_rule is None: + if error_on_empty: + raise ValueError(f"No tool rule found for {self.last_tool_name}") return [] + # If the current rule is a conditional tool rule, use the LLM response to + # determine which child tool to use + if isinstance(current_rule, ConditionalToolRule): + if not last_function_response: + raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use") + return [self.evaluate_conditional_tool(current_rule, last_function_response)] + + return current_rule.children if current_rule.children else [] + def is_terminal_tool(self, tool_name: str) -> bool: """Check if the tool is defined as a terminal tool in the terminal tool rules.""" return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules) @@ -88,6 +91,15 @@ def has_children_tools(self, tool_name): return any(rule.tool_name == tool_name for rule in self.tool_rules) def validate_conditional_tool(self, rule: ConditionalToolRule): + ''' + Validate a conditional tool rule + + Args: + rule (ConditionalToolRule): The conditional tool rule to validate + + Raises: + ToolRuleValidationError: If the rule is invalid + ''' if rule.children is None or len(rule.children) == 0: raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") if len(rule.children) != len(rule.child_output_mapping): @@ -142,3 +154,40 @@ def has_path_to_terminal(start_tool: str) -> bool: return False return True # All init tools have paths to terminal tools + + def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str: + ''' + Parse function response to determine which child tool to use based on the mapping + + Args: + tool (ConditionalToolRule): The conditional tool rule + last_function_response (str): The function response in JSON format + + Returns: + str: The name of the child tool to use next + ''' + json_response = json.loads(last_function_response) + function_output = json_response["message"] + + # Try to match the function output with a mapping key + for key in tool.child_output_mapping: + + # Convert function output to match key type for comparison + if key == "true" or key == "false": + try: + typed_output = function_output.lower() + except AttributeError: + continue + elif isinstance(key, int): + try: + typed_output = int(function_output) + except (ValueError, TypeError): + continue + else: # string + typed_output = str(function_output) + + if typed_output == key: + return tool.child_output_mapping[key] + + # If no match found, use default + return tool.default_child diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py index 1d8263e332..f53169d93e 100644 --- a/letta/orm/custom_columns.py +++ b/letta/orm/custom_columns.py @@ -9,7 +9,7 @@ from letta.schemas.enums import ToolRuleType from letta.schemas.llm_config import LLMConfig from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction -from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule class EmbeddingConfigColumn(TypeDecorator): @@ -80,7 +80,7 @@ def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, Init return value @staticmethod - def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]: + def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]: """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var if rule_type == ToolRuleType.run_first: @@ -90,6 +90,9 @@ def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, Term elif rule_type == ToolRuleType.constrain_child_tools: rule = ChildToolRule(**data) return rule + elif rule_type == ToolRuleType.conditional: + rule = ConditionalToolRule(**data) + return rule else: raise ValueError(f"Unknown tool rule type: {rule_type}") diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 336777215d..aa59c47c4d 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -4,7 +4,12 @@ import pytest from letta import create_client from letta.schemas.letta_message import FunctionCallMessage -from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ( + ChildToolRule, + ConditionalToolRule, + InitToolRule, + TerminalToolRule, +) from tests.helpers.endpoints_helper import ( assert_invoked_function_call, assert_invoked_send_message_with_keyword, @@ -68,6 +73,50 @@ def fourth_secret_word(prev_secret_word: str): return "banana" +def flip_coin(): + """ + Call this to retrieve the password to the secret word, which you will need to output in a send_message later. + If it returns an empty string, try flipping again! + + Returns: + str: The password or an empty string + """ + import random + + # Flip a coin with 50% chance + if random.random() < 0.5: + return "" + return "hj2hwibbqm" + + +def flip_coin_hard(): + """ + Call this to retrieve the password to the secret word, which you will need to output in a send_message later. + If it returns an empty string, try flipping again! + + Returns: + str: The password or an empty string + """ + import random + + # Flip a coin with 50% chance + result = random.random() + if result < 0.5: + return "" + if result < 0.75: + return "START_OVER" + return "hj2hwibbqm" + + +def can_play_game(): + """ + Call this to start the tool chain. + """ + import random + + return random.random() < 0.5 + + def auto_error(): """ If you call this function, it will throw an error automatically. @@ -282,3 +331,163 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): print(f"Got successful response from client: \n\n{response}") cleanup(client=client, agent_uuid=agent_uuid) + + +@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely +def test_agent_conditional_tool_easy(mock_e2b_api_key_none): + """ + Test the agent with a conditional tool that has a child tool. + + Tool Flow: + + ------- + | | + | v + -- flip_coin + | + v + reveal_secret_word + """ + + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + coin_flip_name = "flip_coin" + secret_word_tool = "fourth_secret_word" + flip_coin_tool = client.create_or_update_tool(flip_coin, name=coin_flip_name) + reveal_secret = client.create_or_update_tool(fourth_secret_word, name=secret_word_tool) + + # Make tool rules + tool_rules = [ + InitToolRule(tool_name=coin_flip_name), + ConditionalToolRule( + tool_name=coin_flip_name, + default_child=coin_flip_name, + children=[secret_word_tool], + child_output_mapping={ + "hj2hwibbqm": secret_word_tool, + } + ), + TerminalToolRule(tool_name=secret_word_tool), + ] + tools = [flip_coin_tool, reveal_secret] + + config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word") + + # Make checks + assert_sanity_checks(response) + + # Assert the tools were called + assert_invoked_function_call(response.messages, "flip_coin") + assert_invoked_function_call(response.messages, "fourth_secret_word") + + # Check ordering of tool calls + found_secret_word = False + for m in response.messages: + if isinstance(m, FunctionCallMessage): + if m.function_call.name == secret_word_tool: + # Should be the last tool call + found_secret_word = True + else: + # Before finding secret_word, only flip_coin should be called + assert m.function_call.name == coin_flip_name + assert not found_secret_word + + # Ensure we found the secret word exactly once + assert found_secret_word + + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid) + + + +@pytest.mark.timeout(90) # Longer timeout since this test has more steps +def test_agent_conditional_tool_hard(mock_e2b_api_key_none): + """ + Test the agent with a complex conditional tool graph + + Tool Flow: + + can_play_game <---+ + | | + v | + flip_coin -----+ + | + v + fourth_secret_word + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools + play_game = "can_play_game" + coin_flip_name = "flip_coin_hard" + final_tool = "fourth_secret_word" + play_game_tool = client.create_or_update_tool(can_play_game, name=play_game) + flip_coin_tool = client.create_or_update_tool(flip_coin_hard, name=coin_flip_name) + reveal_secret = client.create_or_update_tool(fourth_secret_word, name=final_tool) + + # Make tool rules - chain them together with conditional rules + tool_rules = [ + InitToolRule(tool_name=play_game), + ConditionalToolRule( + tool_name=play_game, + default_child=play_game, # Keep trying if we can't play + children=[coin_flip_name], + child_output_mapping={ + True: coin_flip_name # Only allow access when can_play_game returns True + } + ), + ConditionalToolRule( + tool_name=coin_flip_name, + default_child=coin_flip_name, + children=[play_game, final_tool], + child_output_mapping={ + "hj2hwibbqm": final_tool, "START_OVER": play_game + } + ), + TerminalToolRule(tool_name=final_tool), + ] + + # Setup agent with all tools + tools = [play_game_tool, flip_coin_tool, reveal_secret] + config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + agent_state = setup_agent( + client, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules + ) + + # Ask agent to try to get all secret words + response = client.user_message(agent_id=agent_state.id, message="hi") + + # Make checks + assert_sanity_checks(response) + + # Assert all tools were called + assert_invoked_function_call(response.messages, play_game) + assert_invoked_function_call(response.messages, final_tool) + + # Check ordering of tool calls + found_words = [] + for m in response.messages: + if isinstance(m, FunctionCallMessage): + name = m.function_call.name + if name in [play_game, coin_flip_name]: + # Before finding secret_word, only can_play_game and flip_coin should be called + assert name in [play_game, coin_flip_name] + else: + # Should find secret words in order + expected_word = final_tool + assert name == expected_word, f"Found {name} but expected {expected_word}" + found_words.append(name) + + # Ensure we found all secret words in order + assert found_words == [final_tool] + + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid) From 9d28212ce1e374fffedddb15bcdccadaa4762766 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Wed, 18 Dec 2024 16:58:14 -0800 Subject: [PATCH 03/11] Fixes: less stringent tool chain checks, correct boolean eval logic for conditional tools --- letta/helpers/tool_rule_solver.py | 62 ++----------------- tests/integration_test_agent_tool_graph.py | 1 + .../integration_test_offline_memory_agent.py | 4 +- tests/test_tool_rule_solver.py | 16 +++-- 4 files changed, 16 insertions(+), 67 deletions(-) diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index b8f2abbd71..7fd8410eb9 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -51,9 +51,6 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs): assert isinstance(rule, TerminalToolRule) self.terminal_tool_rules.append(rule) - # Validate the tool rules to ensure they form a DAG - if not self.validate_tool_rules(): - raise ToolRuleValidationError("Tool rules does not have a path from Init to Terminal.") def update_tool_usage(self, tool_name: str): """Update the internal state to track the last tool called.""" @@ -108,53 +105,6 @@ def validate_conditional_tool(self, rule: ConditionalToolRule): raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") return True - def validate_tool_rules(self) -> bool: - """ - Validate that there exists a path from every init tool to a terminal tool. - Returns True if valid (path exists), otherwise False. - """ - # Build adjacency list for the tool graph - adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules} - - init_tool_names = {rule.tool_name for rule in self.init_tool_rules} - terminal_tool_names = {rule.tool_name for rule in self.terminal_tool_rules} - - # Initial checks - if len(init_tool_names) == 0: - if len(terminal_tool_names) + len(self.tool_rules) > 0: - return False # No init tools defined - else: - return True # No tool rules - if len(terminal_tool_names) == 0: - if len(adjacency_list) > 0: - return False # No terminal tools defined - else: - return True # Only init tools - - # Define BFS helper function to find path to terminal tool - def has_path_to_terminal(start_tool: str) -> bool: - visited = set() - queue = deque([start_tool]) - visited.add(start_tool) - - while queue: - current_tool = queue.popleft() - if current_tool in terminal_tool_names: - return True - - for child in adjacency_list.get(current_tool, []): - if child not in visited: - visited.add(child) - queue.append(child) - return False - - # Check if each init tool has a path to a terminal tool - for init_tool_name in init_tool_names: - if not has_path_to_terminal(init_tool_name): - return False - - return True # All init tools have paths to terminal tools - def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str: ''' Parse function response to determine which child tool to use based on the mapping @@ -173,18 +123,18 @@ def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_res for key in tool.child_output_mapping: # Convert function output to match key type for comparison - if key == "true" or key == "false": - try: - typed_output = function_output.lower() - except AttributeError: - continue + if isinstance(key, bool): + typed_output = function_output.lower() == "true" elif isinstance(key, int): try: typed_output = int(function_output) except (ValueError, TypeError): continue else: # string - typed_output = str(function_output) + if function_output == "True" or function_output == "False": + typed_output = function_output.lower() + else: + typed_output = function_output if typed_output == key: return tool.child_output_mapping[key] diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index aa59c47c4d..e5acfd4824 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -250,6 +250,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): tool_rules = [ InitToolRule(tool_name=t1_name), ChildToolRule(tool_name=t1_name, children=[t2_name]), + TerminalToolRule(tool_name=t2_name) ] tools = [t1, t2] diff --git a/tests/integration_test_offline_memory_agent.py b/tests/integration_test_offline_memory_agent.py index 07b7c732b2..15d4161d5e 100644 --- a/tests/integration_test_offline_memory_agent.py +++ b/tests/integration_test_offline_memory_agent.py @@ -74,8 +74,8 @@ def test_ripple_edit(client, mock_e2b_api_key_none): assert set(conversation_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"} - rethink_memory_tool = client.create_tool(rethink_memory) - finish_rethinking_memory_tool = client.create_tool(finish_rethinking_memory) + rethink_memory_tool = client.create_or_update_tool(rethink_memory) + finish_rethinking_memory_tool = client.create_or_update_tool(finish_rethinking_memory) offline_memory_agent = client.create_agent( name="offline_memory_agent", agent_type=AgentType.offline_memory_agent, diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index e1170e4b73..25434ca28d 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -36,9 +36,7 @@ def test_get_allowed_tool_names_with_subsequent_rule(): # Setup: Tool rule sequence init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) - rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[END_TOOL]) - terminal_rule = TerminalToolRule(tool_name=END_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[]) # Action: Update usage and get allowed tools solver.update_tool_usage(START_TOOL) @@ -51,9 +49,8 @@ def test_get_allowed_tool_names_with_subsequent_rule(): def test_is_terminal_tool(): # Setup: Terminal tool rule configuration init_rule = InitToolRule(tool_name=START_TOOL) - rule_1 = ChildToolRule(tool_name=START_TOOL, children=[END_TOOL]) terminal_rule = TerminalToolRule(tool_name=END_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule]) # Action & Assert: Verify terminal and non-terminal tools assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool" @@ -83,9 +80,9 @@ def test_get_allowed_tool_names_no_matching_rule_error(): init_rule = InitToolRule(tool_name=START_TOOL) solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) - # Action & Assert: Set last tool to an unrecognized tool and expect RuntimeError when error_on_empty=True + # Action & Assert: Set last tool to an unrecognized tool and expect ValueError solver.update_tool_usage(UNRECOGNIZED_TOOL) - with pytest.raises(RuntimeError, match="resolved to no more possible tool calls"): + with pytest.raises(ValueError, match=f"No tool rule found for {UNRECOGNIZED_TOOL}"): solver.get_allowed_tool_names(error_on_empty=True) @@ -119,7 +116,7 @@ def test_conditional_tool_rule(): rule = ConditionalToolRule( tool_name=START_TOOL, children=[START_TOOL, END_TOOL], - default_child=END_TOOL, + default_child=START_TOOL, child_output_mapping={True: END_TOOL, False: START_TOOL} ) solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) @@ -130,7 +127,8 @@ def test_conditional_tool_rule(): # Step 2: After using 'start_tool' solver.update_tool_usage(START_TOOL) - assert set(solver.get_allowed_tool_names()) == set({END_TOOL, START_TOOL}), "After 'start_tool', should allow 'start_tool' or 'end_tool'" + assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'" + assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'" # Step 3: After using 'end_tool' assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal" From 5e6300f61b6ae4fb352f815bf6f3088453dd3611 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Wed, 18 Dec 2024 17:01:46 -0800 Subject: [PATCH 04/11] fixed a test --- tests/test_tool_rule_solver.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 25434ca28d..a516b5e70e 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -177,9 +177,7 @@ def test_tool_rules_with_invalid_path(): rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here terminal_rule = TerminalToolRule(tool_name=END_TOOL) - # Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError - with pytest.raises(ToolRuleValidationError, match="Tool rules does not have a path from Init to Terminal."): - ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule]) + ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule]) # Now: add a path from the start tool to the final tool rule_5 = ConditionalToolRule( From aaab950600ba55b87bdc14ceb36645d839a62a90 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Wed, 18 Dec 2024 17:11:08 -0800 Subject: [PATCH 05/11] removed unnecessary imports --- letta/helpers/tool_rule_solver.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 7fd8410eb9..0bca7a398a 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,6 +1,5 @@ import json -from typing import Dict, List, Optional, Union -from collections import deque +from typing import List, Optional, Union from pydantic import BaseModel, Field From 990fdc0e14136c82e45aee496cfc80bb33eb7505 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Wed, 18 Dec 2024 17:14:33 -0800 Subject: [PATCH 06/11] added test back in --- tests/test_tool_rule_solver.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index a516b5e70e..e5d8d8b251 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -57,15 +57,15 @@ def test_is_terminal_tool(): assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool" -# def test_get_allowed_tool_names_no_matching_rule_warning(): -# # Setup: Tool rules with no matching rule for the last tool -# init_rule = InitToolRule(tool_name=START_TOOL) -# solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) +def test_get_allowed_tool_names_no_matching_rule_warning(): + # Setup: Tool rules with no matching rule for the last tool + init_rule = InitToolRule(tool_name=START_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) -# # Action: Set last tool to an unrecognized tool and check warnings -# solver.update_tool_usage(UNRECOGNIZED_TOOL) + # Action: Set last tool to an unrecognized tool and check warnings + solver.update_tool_usage(UNRECOGNIZED_TOOL) - # NOTE: removed for now since this warning is getting triggered on every LLM call + # # NOTE: removed for now since this warning is getting triggered on every LLM call # with warnings.catch_warnings(record=True) as w: # allowed_tools = solver.get_allowed_tool_names() From 17660f78da5b57936e3ca7ea4b3f59e91be030f9 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Thu, 19 Dec 2024 10:03:05 -0800 Subject: [PATCH 07/11] Resolved PR comments * removed 'children' field * allowed default_child to be None * updated tests --- letta/agent.py | 1 + letta/helpers/tool_rule_solver.py | 16 +++-- letta/schemas/tool_rule.py | 9 ++- tests/integration_test_agent_tool_graph.py | 82 +++++++++++++++++++++- tests/test_tool_rule_solver.py | 27 +------ 5 files changed, 97 insertions(+), 38 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index e3bf859908..ed7f3a90b9 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -299,6 +299,7 @@ def __init__( self.first_message_verify_mono = first_message_verify_mono # State needed for conditional tool chaining + # TODO: when agent reloads, load this from past messages self.last_function_response = None # Controls if the convo memory pressure warning is triggered diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 0bca7a398a..02919b2e8a 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -74,7 +74,8 @@ def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_res if isinstance(current_rule, ConditionalToolRule): if not last_function_response: raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use") - return [self.evaluate_conditional_tool(current_rule, last_function_response)] + next_tool = self.evaluate_conditional_tool(current_rule, last_function_response) + return [next_tool] if next_tool else [] return current_rule.children if current_rule.children else [] @@ -96,12 +97,8 @@ def validate_conditional_tool(self, rule: ConditionalToolRule): Raises: ToolRuleValidationError: If the rule is invalid ''' - if rule.children is None or len(rule.children) == 0: + if len(rule.child_output_mapping) == 0: raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") - if len(rule.children) != len(rule.child_output_mapping): - raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") - if set(rule.children) != set(rule.child_output_mapping.values()): - raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") return True def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str: @@ -129,9 +126,16 @@ def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_res typed_output = int(function_output) except (ValueError, TypeError): continue + elif isinstance(key, float): + try: + typed_output = float(function_output) + except (ValueError, TypeError): + continue else: # string if function_output == "True" or function_output == "False": typed_output = function_output.lower() + elif function_output == "None": + typed_output = None else: typed_output = function_output diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index a7f4f7cc1b..259e5452dc 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Any, Dict, List, Optional, Union from pydantic import Field @@ -26,10 +26,9 @@ class ConditionalToolRule(BaseToolRule): A ToolRule that conditionally maps to different child tools based on the output. """ type: ToolRuleType = ToolRuleType.conditional - default_child: str = Field(..., description="The default child tool to be called") - child_output_mapping: Dict[Union[bool, str, int], str] = Field(..., description="The output case to check for mapping") - children: List[str] = Field(..., description="The child tool to call when output matches the case") - throw_error: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case") + default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.") + child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping") + require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case") class InitToolRule(BaseToolRule): diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index e5acfd4824..d433513e21 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -117,6 +117,13 @@ def can_play_game(): return random.random() < 0.5 +def return_none(): + """ + Really simple function + """ + return None + + def auto_error(): """ If you call this function, it will throw an error automatically. @@ -364,7 +371,6 @@ def test_agent_conditional_tool_easy(mock_e2b_api_key_none): ConditionalToolRule( tool_name=coin_flip_name, default_child=coin_flip_name, - children=[secret_word_tool], child_output_mapping={ "hj2hwibbqm": secret_word_tool, } @@ -436,7 +442,6 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): ConditionalToolRule( tool_name=play_game, default_child=play_game, # Keep trying if we can't play - children=[coin_flip_name], child_output_mapping={ True: coin_flip_name # Only allow access when can_play_game returns True } @@ -444,7 +449,6 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): ConditionalToolRule( tool_name=coin_flip_name, default_child=coin_flip_name, - children=[play_game, final_tool], child_output_mapping={ "hj2hwibbqm": final_tool, "START_OVER": play_game } @@ -492,3 +496,75 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): print(f"Got successful response from client: \n\n{response}") cleanup(client=client, agent_uuid=agent_uuid) + + +@pytest.mark.timeout(60) +def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none): + """ + Test the agent with a conditional tool that allows any child tool to be called if a function returns None. + + Tool Flow: + + return_none + | + v + any tool... <-- When output doesn't match mapping, agent can call any tool + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools - we'll make several available to the agent + tool_name = "return_none" + + tool = client.create_or_update_tool(return_none, name=tool_name) + secret_word = client.create_or_update_tool(first_secret_word, name="first_secret_word") + + # Make tool rules - only map one output, let others be free choice + tool_rules = [ + InitToolRule(tool_name=tool_name), + ConditionalToolRule( + tool_name=tool_name, + default_child=None, # Allow any tool to be called if output doesn't match + child_output_mapping={ + "anything but none": "first_secret_word" + } + ) + ] + tools = [tool, secret_word] + + # Setup agent with all tools + agent_state = setup_agent( + client, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules + ) + + # Ask agent to try different tools based on the game output + response = client.user_message( + agent_id=agent_state.id, + message="call a function, any function. then call send_message" + ) + + # Make checks + assert_sanity_checks(response) + + # Assert return_none was called + assert_invoked_function_call(response.messages, tool_name) + + # Assert any base function called afterward + found_any_tool = False + found_return_none = False + for m in response.messages: + if isinstance(m, FunctionCallMessage): + if m.function_call.name == tool_name: + found_return_none = True + elif found_return_none and m.function_call.name: + found_any_tool = True + break + + assert found_any_tool, "Should have called any tool after return_none" + + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid) \ No newline at end of file diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index e5d8d8b251..c524d53a34 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -115,8 +115,7 @@ def test_conditional_tool_rule(): terminal_rule = TerminalToolRule(tool_name=END_TOOL) rule = ConditionalToolRule( tool_name=START_TOOL, - children=[START_TOOL, END_TOOL], - default_child=START_TOOL, + default_child=None, child_output_mapping={True: END_TOOL, False: START_TOOL} ) solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) @@ -140,32 +139,13 @@ def test_invalid_conditional_tool_rule(): terminal_rule = TerminalToolRule(tool_name=END_TOOL) invalid_rule_1 = ConditionalToolRule( tool_name=START_TOOL, - children=[START_TOOL], default_child=END_TOOL, - child_output_mapping={True: END_TOOL, False: START_TOOL} - ) - invalid_rule_2 = ConditionalToolRule( - tool_name=START_TOOL, - children=[START_TOOL, END_TOOL], - default_child=END_TOOL, - child_output_mapping={True: END_TOOL} - ) - invalid_rule_3 = ConditionalToolRule( - tool_name=START_TOOL, - children=[START_TOOL, FINAL_TOOL], - default_child=FINAL_TOOL, - child_output_mapping={True: END_TOOL, False: START_TOOL} + child_output_mapping={} ) # Test 1: Missing child output mapping - with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): + with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."): ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule]) - with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): - ToolRulesSolver(tool_rules=[init_rule, invalid_rule_2, terminal_rule]) - - # Test 2: Missing child - with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): - ToolRulesSolver(tool_rules=[init_rule, invalid_rule_3, terminal_rule]) def test_tool_rules_with_invalid_path(): @@ -182,7 +162,6 @@ def test_tool_rules_with_invalid_path(): # Now: add a path from the start tool to the final tool rule_5 = ConditionalToolRule( tool_name=HELPER_TOOL, - children=[START_TOOL, FINAL_TOOL], default_child=FINAL_TOOL, child_output_mapping={True: START_TOOL, False: FINAL_TOOL}, ) From 2faf5055bed32bc2846c6c6264ca5746d62062be Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Thu, 19 Dec 2024 12:19:22 -0800 Subject: [PATCH 08/11] agent remembers last_function_response on reload --- letta/agent.py | 22 ++++- tests/integration_test_agent_tool_graph.py | 110 +++++++++++++++++---- 2 files changed, 111 insertions(+), 21 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index ed7f3a90b9..005254d383 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1,5 +1,6 @@ import datetime import inspect +import json import time import traceback import warnings @@ -32,6 +33,7 @@ from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import FunctionReturn from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message, MessageUpdate from letta.schemas.openai.chat_completion_request import ( @@ -298,10 +300,6 @@ def __init__( self.first_message_verify_mono = first_message_verify_mono - # State needed for conditional tool chaining - # TODO: when agent reloads, load this from past messages - self.last_function_response = None - # Controls if the convo memory pressure warning is triggered # When an alert is sent in the message queue, set this to True (to avoid repeat alerts) # When the summarizer is run, set this back to False (to reset) @@ -375,6 +373,9 @@ def __init__( self._append_to_messages(added_messages=init_messages_objs) self._validate_message_buffer_is_utc() + # Load last function response from message history + self.last_function_response = self.load_last_function_response() + # Keep track of the total number of messages throughout all time self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) self.messages_total_init = len(self._messages) - 1 @@ -393,6 +394,19 @@ def check_tool_rules(self): else: self.supports_structured_output = True + def load_last_function_response(self): + """Load the last function response from message history""" + for i in range(len(self._messages) - 1, -1, -1): + msg = self._messages[i] + if msg.role == MessageRole.tool and msg.text: + try: + response_json = json.loads(msg.text) + if response_json.get("message"): + return response_json["message"] + except (json.JSONDecodeError, KeyError): + raise ValueError(f"Invalid JSON format in message: {msg.text}") + return None + def update_memory_if_change(self, new_memory: Memory) -> bool: """ Update internal memory object and system prompt if there have been modifications. diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index d433513e21..99b345c4db 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -316,28 +316,44 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): ] for config in config_files: - agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search") + max_retries = 3 + last_error = None - # Make checks - assert_sanity_checks(response) + for attempt in range(max_retries): + try: + agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search") - # Assert the tools were called - assert_invoked_function_call(response.messages, "archival_memory_search") - assert_invoked_function_call(response.messages, "archival_memory_insert") - assert_invoked_function_call(response.messages, "send_message") + # Make checks + assert_sanity_checks(response) - # Check ordering of tool calls - tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] - for m in response.messages: - if isinstance(m, FunctionCallMessage): - # Check that it's equal to the first one - assert m.function_call.name == tool_names[0] + # Assert the tools were called + assert_invoked_function_call(response.messages, "archival_memory_search") + assert_invoked_function_call(response.messages, "archival_memory_insert") + assert_invoked_function_call(response.messages, "send_message") - # Pop out first one - tool_names = tool_names[1:] + # Check ordering of tool calls + tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] + for m in response.messages: + if isinstance(m, FunctionCallMessage): + # Check that it's equal to the first one + assert m.function_call.name == tool_names[0] + + # Pop out first one + tool_names = tool_names[1:] + + print(f"Got successful response from client: \n\n{response}") + break # Test passed, exit retry loop + + except AssertionError as e: + last_error = e + print(f"Attempt {attempt + 1} failed, retrying..." if attempt < max_retries - 1 else f"All {max_retries} attempts failed") + cleanup(client=client, agent_uuid=agent_uuid) + continue + + if last_error and attempt == max_retries - 1: + raise last_error # Re-raise the last error if all retries failed - print(f"Got successful response from client: \n\n{response}") cleanup(client=client, agent_uuid=agent_uuid) @@ -566,5 +582,65 @@ def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none): assert found_any_tool, "Should have called any tool after return_none" + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid) + + +@pytest.mark.timeout(60) +def test_agent_reload_remembers_function_response(mock_e2b_api_key_none): + """ + Test that when an agent is reloaded, it remembers the last function response for conditional tool chaining. + + Tool Flow: + + flip_coin + | + v + fourth_secret_word <-- Should remember coin flip result after reload + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools + flip_coin_name = "flip_coin" + secret_word = "fourth_secret_word" + flip_coin_tool = client.create_or_update_tool(flip_coin, name=flip_coin_name) + secret_word_tool = client.create_or_update_tool(fourth_secret_word, name=secret_word) + + # Make tool rules - map coin flip to fourth_secret_word + tool_rules = [ + InitToolRule(tool_name=flip_coin_name), + ConditionalToolRule( + tool_name=flip_coin_name, + default_child=flip_coin_name, # Allow any tool to be called if output doesn't match + child_output_mapping={ + "hj2hwibbqm": secret_word + } + ), + TerminalToolRule(tool_name=secret_word) + ] + tools = [flip_coin_tool, secret_word_tool] + + # Setup initial agent + agent_state = setup_agent( + client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules + ) + + # Call flip_coin first + response = client.user_message(agent_id=agent_state.id, message="flip a coin") + assert_invoked_function_call(response.messages, flip_coin_name) + assert_invoked_function_call(response.messages, secret_word) + found_fourth_secret = False + for m in response.messages: + if isinstance(m, FunctionCallMessage) and m.function_call.name == secret_word: + found_fourth_secret = True + break + + assert found_fourth_secret, "Reloaded agent should remember coin flip result and call fourth_secret_word if True" + + # Reload the agent + reloaded_agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user) + assert reloaded_agent.last_function_response is not None + print(f"Got successful response from client: \n\n{response}") cleanup(client=client, agent_uuid=agent_uuid) \ No newline at end of file From bbf9e5deb3780554330bd35906cd4165b2d11012 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Thu, 19 Dec 2024 13:16:52 -0800 Subject: [PATCH 09/11] remove unnecessary imports --- letta/agent.py | 1 - tests/integration_test_agent_tool_graph.py | 1 - 2 files changed, 2 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index bff3e233c3..82958acda5 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -33,7 +33,6 @@ from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole -from letta.schemas.letta_message import FunctionReturn from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message, MessageUpdate from letta.schemas.openai.chat_completion_request import ( diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 6f14740a72..746cc59479 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -3,7 +3,6 @@ import pytest from letta import create_client -from letta.schemas.letta_message import FunctionCallMessage from letta.schemas.tool_rule import ( ChildToolRule, ConditionalToolRule, From d3f97ca6e271a02a9eb7b9b9cd025a1a33baf49a Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Thu, 19 Dec 2024 13:45:38 -0800 Subject: [PATCH 10/11] fixed tests after merge --- tests/helpers/endpoints_helper.py | 6 ++--- tests/integration_test_agent_tool_graph.py | 30 +++++++++------------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index a526f2315b..c9cca1a279 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -1,7 +1,7 @@ import json import logging import uuid -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Sequence, Union from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs from letta.schemas.tool_rule import BaseToolRule @@ -373,7 +373,7 @@ def assert_sanity_checks(response: LettaResponse): assert len(response.messages) > 0, response -def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str, case_sensitive: bool = False) -> None: +def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], keyword: str, case_sensitive: bool = False) -> None: # Find first instance of send_message target_message = None for message in messages: @@ -406,7 +406,7 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}") -def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None: +def assert_invoked_function_call(messages: Sequence[LettaMessage], function_name: str) -> None: for message in messages: if isinstance(message, ToolCallMessage) and message.tool_call.name == function_name: # Found it, do nothing diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 746cc59479..44aad0d077 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -3,6 +3,7 @@ import pytest from letta import create_client +from letta.schemas.letta_message import ToolCallMessage from letta.schemas.tool_rule import ( ChildToolRule, ConditionalToolRule, @@ -331,19 +332,12 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): assert_invoked_function_call(response.messages, "archival_memory_insert") assert_invoked_function_call(response.messages, "send_message") - # Check ordering of tool calls - tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] - for m in response.messages: - if isinstance(m, ToolCallMessage): - # Check that it's equal to the first one - assert m.tool_call.name == tool_names[0] - # Check ordering of tool calls tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] for m in response.messages: - if isinstance(m, FunctionCallMessage): + if isinstance(m, ToolCallMessage): # Check that it's equal to the first one - assert m.function_call.name == tool_names[0] + assert m.tool_call.name == tool_names[0] # Pop out first one tool_names = tool_names[1:] @@ -415,13 +409,13 @@ def test_agent_conditional_tool_easy(mock_e2b_api_key_none): # Check ordering of tool calls found_secret_word = False for m in response.messages: - if isinstance(m, FunctionCallMessage): - if m.function_call.name == secret_word_tool: + if isinstance(m, ToolCallMessage): + if m.tool_call.name == secret_word_tool: # Should be the last tool call found_secret_word = True else: # Before finding secret_word, only flip_coin should be called - assert m.function_call.name == coin_flip_name + assert m.tool_call.name == coin_flip_name assert not found_secret_word # Ensure we found the secret word exactly once @@ -502,8 +496,8 @@ def test_agent_conditional_tool_hard(mock_e2b_api_key_none): # Check ordering of tool calls found_words = [] for m in response.messages: - if isinstance(m, FunctionCallMessage): - name = m.function_call.name + if isinstance(m, ToolCallMessage): + name = m.tool_call.name if name in [play_game, coin_flip_name]: # Before finding secret_word, only can_play_game and flip_coin should be called assert name in [play_game, coin_flip_name] @@ -579,10 +573,10 @@ def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none): found_any_tool = False found_return_none = False for m in response.messages: - if isinstance(m, FunctionCallMessage): - if m.function_call.name == tool_name: + if isinstance(m, ToolCallMessage): + if m.tool_call.name == tool_name: found_return_none = True - elif found_return_none and m.function_call.name: + elif found_return_none and m.tool_call.name: found_any_tool = True break @@ -638,7 +632,7 @@ def test_agent_reload_remembers_function_response(mock_e2b_api_key_none): assert_invoked_function_call(response.messages, secret_word) found_fourth_secret = False for m in response.messages: - if isinstance(m, FunctionCallMessage) and m.function_call.name == secret_word: + if isinstance(m, ToolCallMessage) and m.tool_call.name == secret_word: found_fourth_secret = True break From 0b4eab6b894cd10e12c1c03a4ca89ba949099390 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Thu, 19 Dec 2024 13:57:44 -0800 Subject: [PATCH 11/11] fixed client legacy test --- tests/test_client_legacy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 30b78c911f..6d13046cd0 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -529,6 +529,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentState): """Test that we can update the details of a message""" + import json # create a message message_response = client.send_message(agent_id=agent.id, message="Test message", role="user") @@ -537,7 +538,7 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat assert isinstance(message_response.messages[-1], ToolReturnMessage) message = message_response.messages[-1] - new_text = "This exact string would never show up in the message???" + new_text = json.dumps({"message": "This exact string would never show up in the message???"}) new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id) assert new_message.text == new_text