From bbe8deaa2921f482b38747257e811e5862a283fa Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 20 Dec 2024 16:56:53 -0800 Subject: [PATCH] chore: Clean up `.load_agent` usage (#2298) --- ...0_add_cascading_deletes_for_sources_to_.py | 35 ++++++++ letta/agent.py | 28 ------- letta/client/client.py | 11 ++- letta/orm/source.py | 13 ++- letta/orm/sources_agents.py | 4 +- letta/server/rest_api/routers/v1/sources.py | 11 +-- letta/server/server.py | 79 +------------------ letta/settings.py | 3 - tests/test_model_letta_perfomance.py | 35 +++++++- tests/test_server.py | 77 +++++++++--------- 10 files changed, 133 insertions(+), 163 deletions(-) create mode 100644 alembic/versions/e78b4e82db30_add_cascading_deletes_for_sources_to_.py diff --git a/alembic/versions/e78b4e82db30_add_cascading_deletes_for_sources_to_.py b/alembic/versions/e78b4e82db30_add_cascading_deletes_for_sources_to_.py new file mode 100644 index 0000000000..dd59f2a046 --- /dev/null +++ b/alembic/versions/e78b4e82db30_add_cascading_deletes_for_sources_to_.py @@ -0,0 +1,35 @@ +"""Add cascading deletes for sources to agents + +Revision ID: e78b4e82db30 +Revises: d6632deac81d +Create Date: 2024-12-20 16:30:17.095888 + +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "e78b4e82db30" +down_revision: Union[str, None] = "d6632deac81d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("sources_agents_agent_id_fkey", "sources_agents", type_="foreignkey") + op.drop_constraint("sources_agents_source_id_fkey", "sources_agents", type_="foreignkey") + op.create_foreign_key(None, "sources_agents", "sources", ["source_id"], ["id"], ondelete="CASCADE") + op.create_foreign_key(None, "sources_agents", "agents", ["agent_id"], ["id"], ondelete="CASCADE") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "sources_agents", type_="foreignkey") + op.drop_constraint(None, "sources_agents", type_="foreignkey") + op.create_foreign_key("sources_agents_source_id_fkey", "sources_agents", "sources", ["source_id"], ["id"]) + op.create_foreign_key("sources_agents_agent_id_fkey", "sources_agents", "agents", ["agent_id"], ["id"]) + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 072154a22e..a297f5b6ce 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -44,7 +44,6 @@ from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics -from letta.schemas.user import User as PydanticUser from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import ( @@ -53,7 +52,6 @@ ) from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager -from letta.services.source_manager import SourceManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.streaming_interface import StreamingRefreshCLIInterface from letta.system import ( @@ -969,32 +967,6 @@ def migrate_embedding(self, embedding_config: EmbeddingConfig): # TODO: recall memory raise NotImplementedError() - def attach_source( - self, - user: PydanticUser, - source_id: str, - source_manager: SourceManager, - agent_manager: AgentManager, - ): - """Attach a source to the agent using the SourcesAgents ORM relationship. - - Args: - user: User performing the action - source_id: ID of the source to attach - source_manager: SourceManager instance to verify source exists - agent_manager: AgentManager instance to manage agent-source relationship - """ - # Verify source exists and user has permission to access it - source = source_manager.get_source_by_id(source_id=source_id, actor=user) - assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})" - - # Use the agent_manager to create the relationship - agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user) - - printd( - f"Attached data source {source.name} to agent {self.agent_state.name}.", - ) - def get_context_window(self) -> ContextWindowOverview: """Get the context window of the agent""" diff --git a/letta/client/client.py b/letta/client/client.py index e575979dc0..bb6d2f0ff8 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2987,7 +2987,11 @@ def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_id (str): ID of the source source_name (str): Name of the source """ - self.server.attach_source_to_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id) + if source_name: + source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) + source_id = source.id + + self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user) def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None): """ @@ -2999,7 +3003,10 @@ def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = Non Returns: source (Source): Detached source """ - return self.server.detach_source_from_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id) + if source_name: + source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) + source_id = source.id + return self.server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=self.user) def list_sources(self) -> List[Source]: """ diff --git a/letta/orm/source.py b/letta/orm/source.py index 3ecffda6d9..e7443ea67e 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -11,10 +11,10 @@ from letta.schemas.source import Source as PydanticSource if TYPE_CHECKING: - from letta.orm.organization import Organization + from letta.orm.agent import Agent from letta.orm.file import FileMetadata + from letta.orm.organization import Organization from letta.orm.passage import SourcePassage - from letta.orm.agent import Agent class Source(SqlalchemyBase, OrganizationMixin): @@ -32,4 +32,11 @@ class Source(SqlalchemyBase, OrganizationMixin): organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan") - agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") + agents: Mapped[List["Agent"]] = relationship( + "Agent", + secondary="sources_agents", + back_populates="sources", + lazy="selectin", + cascade="all, delete", # Ensures rows in sources_agents are deleted when the source is deleted + passive_deletes=True, # Allows the database to handle deletion of orphaned rows + ) diff --git a/letta/orm/sources_agents.py b/letta/orm/sources_agents.py index cf502e71f4..ffe8a9d0ea 100644 --- a/letta/orm/sources_agents.py +++ b/letta/orm/sources_agents.py @@ -9,5 +9,5 @@ class SourcesAgents(Base): __tablename__ = "sources_agents" - agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True) - source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id"), primary_key=True) + agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True) + source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), primary_key=True) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index bcc3203dde..fb48d125ca 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -130,11 +130,8 @@ def attach_source_to_agent( Attach a data source to an existing agent. """ actor = server.user_manager.get_user_or_default(user_id=user_id) - - source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) - assert source is not None, f"Source with id={source_id} not found." - source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id) - return source + server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=actor) + return server.source_manager.get_source_by_id(source_id=source_id, actor=actor) @router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source") @@ -148,8 +145,8 @@ def detach_source_from_agent( Detach a data source from an existing agent. """ actor = server.user_manager.get_user_or_default(user_id=user_id) - - return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id) + server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor) + return server.source_manager.get_source_by_id(source_id=source_id, actor=actor) @router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source") diff --git a/letta/server/server.py b/letta/server/server.py index 9d5dc28dbb..85aee52b2c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -59,7 +59,7 @@ from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.source import Source -from letta.schemas.tool import Tool, ToolCreate +from letta.schemas.tool import Tool from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.agent_manager import AgentManager @@ -303,11 +303,6 @@ def __init__( self.block_manager.add_default_blocks(actor=self.default_user) self.tool_manager.upsert_base_tools(actor=self.default_user) - # If there is a default org/user - # This logic may have to change in the future - if settings.load_default_external_tools: - self.add_default_external_tools(actor=self.default_user) - # collect providers (always has Letta as a default) self._enabled_providers: List[Provider] = [LettaProvider()] if model_settings.openai_api_key: @@ -431,9 +426,6 @@ def _step( skip_verify=True, ) - # save agent after step - # save_agent(letta_agent) - except Exception as e: logger.error(f"Error in server._step: {e}") print(traceback.print_exc()) @@ -944,11 +936,10 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor) for agent_state in agent_states: agent_id = agent_state.id - agent = self.load_agent(agent_id=agent_id, actor=actor) # Attach source to agent curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id) - agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager) + self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor) new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id) assert new_passage_size >= curr_passage_size # in case empty files are added @@ -973,56 +964,6 @@ def load_data( passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user) return passage_count, document_count - def attach_source_to_agent( - self, - user_id: str, - agent_id: str, - source_id: Optional[str] = None, - source_name: Optional[str] = None, - ) -> Source: - # attach a data source to an agent - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - if source_id: - data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor) - elif source_name: - data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor) - else: - raise ValueError(f"Need to provide at least source_id or source_name to find the source.") - - assert data_source, f"Data source with id={source_id} or name={source_name} does not exist" - - # load agent - agent = self.load_agent(agent_id=agent_id, actor=actor) - - # attach source to agent - agent.attach_source(user=actor, source_id=data_source.id, source_manager=self.source_manager, agent_manager=self.agent_manager) - - return data_source - - def detach_source_from_agent( - self, - user_id: str, - agent_id: str, - source_id: Optional[str] = None, - source_name: Optional[str] = None, - ) -> Source: - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - if source_id: - source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor) - elif source_name: - source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor) - source_id = source.id - else: - raise ValueError(f"Need to provide at least source_id or source_name to find the source.") - - # delete agent-source mapping - self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor) - - # return back source data - return source - def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] @@ -1060,22 +1001,6 @@ def list_all_sources(self, actor: User) -> List[Source]: return sources_with_metadata - def add_default_external_tools(self, actor: User) -> bool: - """Add default langchain tools. Return true if successful, false otherwise.""" - success = True - tool_creates = ToolCreate.load_default_langchain_tools() - if tool_settings.composio_api_key: - tool_creates += ToolCreate.load_default_composio_tools() - for tool_create in tool_creates: - try: - self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor) - except Exception as e: - warnings.warn(f"An error occurred while creating tool {tool_create}: {e}") - warnings.warn(traceback.format_exc()) - success = False - - return success - def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message: """Update the details of a message associated with an agent""" diff --git a/letta/settings.py b/letta/settings.py index d6907b11ee..1b6ba44bbe 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -83,9 +83,6 @@ class Settings(BaseSettings): pg_pool_recycle: int = 1800 # When to recycle connections pg_echo: bool = False # Logging - # tools configuration - load_default_external_tools: Optional[bool] = None - @property def letta_pg_uri(self) -> str: if self.pg_uri: diff --git a/tests/test_model_letta_perfomance.py b/tests/test_model_letta_perfomance.py index e473d5bb45..d45654eaaa 100644 --- a/tests/test_model_letta_perfomance.py +++ b/tests/test_model_letta_perfomance.py @@ -56,10 +56,35 @@ def wrapper(*args, **kwargs): return decorator_retry +def retry_until_success(max_attempts=10, sleep_time_seconds=4): + """ + Decorator to retry a function until it succeeds or the maximum number of attempts is reached. + + :param max_attempts: Maximum number of attempts to retry the function. + :param sleep_time_seconds: Time to wait between attempts, in seconds. + """ + + def decorator_retry(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for attempt in range(1, max_attempts + 1): + try: + return func(*args, **kwargs) + except Exception as e: + print(f"\033[93mAttempt {attempt} failed with error:\n{e}\033[0m") + if attempt == max_attempts: + raise + time.sleep(sleep_time_seconds) + + return wrapper + + return decorator_retry + + # ====================================================================================================================== # OPENAI TESTS # ====================================================================================================================== -@retry_until_threshold(threshold=0.75, max_attempts=4) +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_returns_valid_first_message(): filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_first_response_is_valid_for_llm_endpoint(filename) @@ -67,6 +92,7 @@ def test_openai_gpt_4o_returns_valid_first_message(): print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_returns_keyword(): keyword = "banana" filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") @@ -75,6 +101,7 @@ def test_openai_gpt_4o_returns_keyword(): print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_uses_external_tool(): filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_uses_external_tool(filename) @@ -82,6 +109,7 @@ def test_openai_gpt_4o_uses_external_tool(): print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_recall_chat_memory(): filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_recall_chat_memory(filename) @@ -89,6 +117,7 @@ def test_openai_gpt_4o_recall_chat_memory(): print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_archival_memory_retrieval(): filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_archival_memory_retrieval(filename) @@ -96,6 +125,7 @@ def test_openai_gpt_4o_archival_memory_retrieval(): print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_archival_memory_insert(): filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_archival_memory_insert(filename) @@ -103,6 +133,7 @@ def test_openai_gpt_4o_archival_memory_insert(): print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_edit_core_memory(): filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_edit_core_memory(filename) @@ -110,12 +141,14 @@ def test_openai_gpt_4o_edit_core_memory(): print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_openai_gpt_4o_summarize_memory(): filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") response = check_agent_summarize_memory_simple(filename) print(f"Got successful response from client: \n\n{response}") +@retry_until_success(max_attempts=5, sleep_time_seconds=2) def test_embedding_endpoint_openai(): filename = os.path.join(embedding_config_dir, "openai_embed.json") run_embedding_endpoint(filename) diff --git a/tests/test_server.py b/tests/test_server.py index 66e192cd64..4775ed91f5 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -362,10 +362,10 @@ def other_agent_id(server, user_id, base_tools): server.agent_manager.delete_agent(agent_state.id, actor=actor) -def test_error_on_nonexistent_agent(server, user_id, agent_id): +def test_error_on_nonexistent_agent(server, user, agent_id): try: fake_agent_id = str(uuid.uuid4()) - server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?") + server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?") raise Exception("user_message call should have failed") except (KeyError, ValueError) as e: # Error is expected @@ -375,9 +375,9 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id): @pytest.mark.order(1) -def test_user_message_memory(server, user_id, agent_id): +def test_user_message_memory(server, user, agent_id): try: - server.user_message(user_id=user_id, agent_id=agent_id, message="/memory") + server.user_message(user_id=user.id, agent_id=agent_id, message="/memory") raise Exception("user_message call should have failed") except ValueError as e: # Error is expected @@ -385,13 +385,11 @@ def test_user_message_memory(server, user_id, agent_id): except: raise - server.run_command(user_id=user_id, agent_id=agent_id, command="/memory") + server.run_command(user_id=user.id, agent_id=agent_id, command="/memory") @pytest.mark.order(3) -def test_load_data(server, user_id, agent_id): - user = server.user_manager.get_user_or_default(user_id=user_id) - +def test_load_data(server, user, agent_id): # create source passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000) assert len(passages_before) == 0 @@ -409,10 +407,10 @@ def test_load_data(server, user_id, agent_id): "Shishir loves indian food", ] connector = DummyDataConnector(archival_memories) - server.load_data(user_id, connector, source.name) + server.load_data(user.id, connector, source.name) # attach source - server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source") + server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user) # check archival memory size passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000) @@ -425,9 +423,9 @@ def test_save_archival_memory(server, user_id, agent_id): @pytest.mark.order(4) -def test_user_message(server, user_id, agent_id): +def test_user_message(server, user, agent_id): # add data into recall memory - server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") + server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?") # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") @@ -435,21 +433,20 @@ def test_user_message(server, user_id, agent_id): @pytest.mark.order(5) -def test_get_recall_memory(server, org_id, user_id, agent_id): +def test_get_recall_memory(server, org_id, user, agent_id): # test recall memory cursor pagination - actor = server.user_manager.get_user_or_default(user_id=user_id) - messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2) + actor = user + messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=2) cursor1 = messages_1[-1].id - messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000) - messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000) + messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, after=cursor1, limit=1000) + messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=1000) messages_3[-1].id assert messages_3[-1].created_at >= messages_3[0].created_at assert len(messages_3) == len(messages_1) + len(messages_2) - messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1) + messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, reverse=True, before=cursor1) assert len(messages_4) == 1 # test in-context message ids - # in_context_ids = server.get_in_context_message_ids(agent_id=agent_id) in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids message_ids = [m.id for m in messages_3] @@ -458,13 +455,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id): @pytest.mark.order(6) -def test_get_archival_memory(server, user_id, agent_id): +def test_get_archival_memory(server, user, agent_id): # test archival memory cursor pagination - user = server.user_manager.get_user_by_id(user_id=user_id) + actor = user # List latest 2 passages passages_1 = server.agent_manager.list_passages( - actor=user, + actor=actor, agent_id=agent_id, ascending=False, limit=2, @@ -474,7 +471,7 @@ def test_get_archival_memory(server, user_id, agent_id): # List next 3 passages (earliest 3) cursor1 = passages_1[-1].id passages_2 = server.agent_manager.list_passages( - actor=user, + actor=actor, agent_id=agent_id, ascending=False, cursor=cursor1, @@ -483,7 +480,7 @@ def test_get_archival_memory(server, user_id, agent_id): # List all 5 cursor2 = passages_1[0].created_at passages_3 = server.agent_manager.list_passages( - actor=user, + actor=actor, agent_id=agent_id, ascending=False, end_date=cursor2, @@ -496,20 +493,20 @@ def test_get_archival_memory(server, user_id, agent_id): earliest = passages_2[-1] # test archival memory - passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True) + passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True) assert len(passage_1) == 1 assert passage_1[0].text == "alpha" - passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True) + passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True) assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test assert all("alpha" not in passage.text for passage in passage_2) # test safe empty return - passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True) + passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True) assert len(passage_none) == 0 -def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str): +def test_get_context_window_overview(server: SyncServer, user, agent_id): """Test that the context window overview fetch works""" - overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id)) + overview = server.get_agent_context_window(agent_id=agent_id, actor=user) assert overview is not None # Run some basic checks @@ -546,7 +543,7 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: ) -def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str): +def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User): agent_state = server.create_agent( request=CreateAgent( name="nonexistent_tools_agent", @@ -554,7 +551,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str): llm="openai/gpt-4", embedding="openai/text-embedding-ada-002", ), - actor=server.user_manager.get_user_or_default(user_id), + actor=user, ) # create another user in the same org @@ -566,14 +563,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str): def _test_get_messages_letta_format( server, - user_id, + user, agent_id, reverse=False, ): """Test mapping between messages and letta_messages with reverse=False.""" messages = server.get_agent_recall_cursor( - user_id=user_id, + user_id=user.id, agent_id=agent_id, limit=1000, reverse=reverse, @@ -582,7 +579,7 @@ def _test_get_messages_letta_format( assert all(isinstance(m, Message) for m in messages) letta_messages = server.get_agent_recall_cursor( - user_id=user_id, + user_id=user.id, agent_id=agent_id, limit=1000, reverse=reverse, @@ -675,10 +672,10 @@ def _test_get_messages_letta_format( warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}") -def test_get_messages_letta_format(server, user_id, agent_id): +def test_get_messages_letta_format(server, user, agent_id): # for reverse in [False, True]: for reverse in [False]: - _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse) + _test_get_messages_letta_format(server, user, agent_id, reverse=reverse) EXAMPLE_TOOL_SOURCE = ''' @@ -825,9 +822,9 @@ def test_composio_client_simple(server): assert len(actions) > 0 -def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools): +def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, base_memory_tools): """Test that the memory rebuild is generating the correct number of role=system messages""" - actor = server.user_manager.get_user_or_default(user_id) + actor = user # create agent agent_state = server.create_agent( request=CreateAgent( @@ -848,7 +845,7 @@ def count_system_messages_in_recall() -> Tuple[int, List[LettaMessage]]: # At this stage, there should only be 1 system message inside of recall storage letta_messages = server.get_agent_recall_cursor( - user_id=user_id, + user_id=user.id, agent_id=agent_state.id, limit=1000, # reverse=reverse, @@ -870,7 +867,7 @@ def count_system_messages_in_recall() -> Tuple[int, List[LettaMessage]]: assert num_system_messages == 1, (num_system_messages, all_messages) # Assuming core memory append actually ran correctly, at this point there should be 2 messages - server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory") + server.user_message(user_id=user.id, agent_id=agent_state.id, message="Append 'banana' to your core memory") # At this stage, there should be 2 system message inside of recall storage num_system_messages, all_messages = count_system_messages_in_recall()