From 0b5bebcf4db66d2d421a33baf4afff80ba60b56a Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 14:39:31 -0800 Subject: [PATCH] Fix summarizer tests --- .github/workflows/tests.yml | 1 - letta/services/agent_manager.py | 2 +- tests/integration_test_summarizer.py | 110 +++++++++++++++++++++- tests/test_summarize.py | 133 --------------------------- 4 files changed, 110 insertions(+), 136 deletions(-) delete mode 100644 tests/test_summarize.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43e66727ab..e4c46c5e85 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,6 @@ jobs: - "test_memory.py" - "test_utils.py" - "test_stream_buffer_readers.py" - - "test_summarize.py" services: qdrant: image: qdrant/qdrant diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index fd0ac5bf31..711dbcbaa6 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -374,7 +374,7 @@ def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message - return self.set_in_context_messages(agent_id=agent_id, message_ids=[m.id for m in new_messages], actor=actor) + return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) @enforce_types def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 9131797cc5..b4de0043b4 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -1,13 +1,16 @@ import json import os import uuid +from typing import List import pytest from letta import create_client from letta.agent import Agent +from letta.client.client import LocalClient from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message from letta.streaming_interface import StreamingRefreshCLIInterface from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH from tests.helpers.utils import cleanup @@ -16,6 +19,110 @@ LLM_CONFIG_DIR = "tests/configs/llm_model_configs" SUMMARY_KEY_PHRASE = "The following is a summary" +test_agent_name = f"test_client_{str(uuid.uuid4())}" + +# TODO: these tests should include looping through LLM providers, since behavior may vary across providers +# TODO: these tests should add function calls into the summarized message sequence:W + + +@pytest.fixture(scope="module") +def client(): + client = create_client() + # client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) + client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) + client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + + yield client + + +@pytest.fixture(scope="module") +def agent_state(client): + # Generate uuid for agent name for this example + agent_state = client.create_agent(name=test_agent_name) + yield agent_state + + client.delete_agent(agent_state.id) + + +def test_summarize_messages_inplace(client, agent_state, mock_e2b_api_key_none): + """Test summarization via sending the summarize CLI command or via a direct call to the agent object""" + # First send a few messages (5) + response = client.user_message( + agent_id=agent_state.id, + message="Hey, how's it going? What do you think about this whole shindig", + ).messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + response = client.user_message( + agent_id=agent_state.id, + message="Any thoughts on the meaning of life?", + ).messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + response = client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?").messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + response = client.user_message( + agent_id=agent_state.id, + message="Would you be surprised to learn that you're actually conversing with an AI right now?", + ).messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + # reload agent object + agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) + + agent_obj.summarize_messages_inplace() + + +def test_auto_summarize(client, mock_e2b_api_key_none): + """Test that the summarizer triggers by itself""" + small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") + small_context_llm_config.context_window = 4000 + + small_agent_state = client.create_agent( + name="small_context_agent", + llm_config=small_context_llm_config, + ) + + try: + + def summarize_message_exists(messages: List[Message]) -> bool: + for message in messages: + if message.text and "The following is a summary of the previous" in message.text: + print(f"Summarize message found after {message_count} messages: \n {message.text}") + return True + return False + + MAX_ATTEMPTS = 10 + message_count = 0 + while True: + + # send a message + response = client.user_message( + agent_id=small_agent_state.id, + message="What is the meaning of life?", + ) + message_count += 1 + + print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------") + + # check if the summarize message is inside the messages + assert isinstance(client, LocalClient), "Test only works with LocalClient" + in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=small_agent_state.id, actor=client.user) + print("SUMMARY", summarize_message_exists(in_context_messages)) + if summarize_message_exists(in_context_messages): + break + + if message_count > MAX_ATTEMPTS: + raise Exception(f"Summarize message not found after {message_count} messages") + + finally: + client.delete_agent(small_agent_state.id) + @pytest.mark.parametrize( "config_filename", @@ -69,4 +176,5 @@ def test_summarizer(config_filename): # Invoke a summarize letta_agent.summarize_messages_inplace(preserve_last_N_messages=False) - assert SUMMARY_KEY_PHRASE in letta_agent.messages[1]["content"], f"Test failed for config: {config_filename}" + in_context_messages = client.get_in_context_messages(agent_state.id) + assert SUMMARY_KEY_PHRASE in in_context_messages[1].text, f"Test failed for config: {config_filename}" diff --git a/tests/test_summarize.py b/tests/test_summarize.py deleted file mode 100644 index d798ff866c..0000000000 --- a/tests/test_summarize.py +++ /dev/null @@ -1,133 +0,0 @@ -import uuid -from typing import List - -from letta import create_client -from letta.client.client import LocalClient -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message - -from .utils import wipe_config - -# test_agent_id = "test_agent" -test_agent_name = f"test_client_{str(uuid.uuid4())}" -client = None -agent_obj = None - -# TODO: these tests should include looping through LLM providers, since behavior may vary across providers -# TODO: these tests should add function calls into the summarized message sequence:W - - -def create_test_agent(): - """Create a test agent that we can call functions on""" - wipe_config() - - global client - client = create_client() - - client.set_default_llm_config(LLMConfig.default_config("gpt-4")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - agent_state = client.create_agent( - name=test_agent_name, - ) - - global agent_obj - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - - -def test_summarize_messages_inplace(mock_e2b_api_key_none): - """Test summarization via sending the summarize CLI command or via a direct call to the agent object""" - global client - global agent_obj - - if agent_obj is None: - create_test_agent() - - assert agent_obj is not None, "Run create_agent test first" - assert client is not None, "Run create_agent test first" - - # First send a few messages (5) - response = client.user_message( - agent_id=agent_obj.agent_state.id, - message="Hey, how's it going? What do you think about this whole shindig", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_obj.agent_state.id, - message="Any thoughts on the meaning of life?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message(agent_id=agent_obj.agent_state.id, message="Does the number 42 ring a bell?").messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_obj.agent_state.id, - message="Would you be surprised to learn that you're actually conversing with an AI right now?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - # reload agent object - agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id, actor=client.user) - - agent_obj.summarize_messages_inplace() - print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}") - # response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize") - - -def test_auto_summarize(mock_e2b_api_key_none): - """Test that the summarizer triggers by itself""" - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - small_context_llm_config = LLMConfig.default_config("gpt-4") - # default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens - SMALL_CONTEXT_WINDOW = 4000 - small_context_llm_config.context_window = SMALL_CONTEXT_WINDOW - - agent_state = client.create_agent( - name="small_context_agent", - llm_config=small_context_llm_config, - ) - - try: - - def summarize_message_exists(messages: List[Message]) -> bool: - for message in messages: - if message.text and "The following is a summary of the previous" in message.text: - print(f"Summarize message found after {message_count} messages: \n {message.text}") - return True - return False - - MAX_ATTEMPTS = 5 - message_count = 0 - while True: - - # send a message - response = client.user_message( - agent_id=agent_state.id, - message="What is the meaning of life?", - ) - message_count += 1 - - print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------") - - # check if the summarize message is inside the messages - assert isinstance(client, LocalClient), "Test only works with LocalClient" - in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=client.user) - print("SUMMARY", summarize_message_exists(in_context_messages)) - if summarize_message_exists(in_context_messages): - break - - if message_count > MAX_ATTEMPTS: - raise Exception(f"Summarize message not found after {message_count} messages") - - finally: - client.delete_agent(agent_state.id)