Skip to content

Commit

Permalink
chore: Clean up .load_agent usage (#2298)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 authored Dec 21, 2024
1 parent 50c3e00 commit bbe8dea
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 163 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Add cascading deletes for sources to agents
Revision ID: e78b4e82db30
Revises: d6632deac81d
Create Date: 2024-12-20 16:30:17.095888
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "e78b4e82db30"
down_revision: Union[str, None] = "d6632deac81d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("sources_agents_agent_id_fkey", "sources_agents", type_="foreignkey")
op.drop_constraint("sources_agents_source_id_fkey", "sources_agents", type_="foreignkey")
op.create_foreign_key(None, "sources_agents", "sources", ["source_id"], ["id"], ondelete="CASCADE")
op.create_foreign_key(None, "sources_agents", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, "sources_agents", type_="foreignkey")
op.drop_constraint(None, "sources_agents", type_="foreignkey")
op.create_foreign_key("sources_agents_source_id_fkey", "sources_agents", "sources", ["source_id"], ["id"])
op.create_foreign_key("sources_agents_agent_id_fkey", "sources_agents", "agents", ["agent_id"], ["id"])
# ### end Alembic commands ###
28 changes: 0 additions & 28 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
Expand All @@ -53,7 +52,6 @@
)
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import (
Expand Down Expand Up @@ -969,32 +967,6 @@ def migrate_embedding(self, embedding_config: EmbeddingConfig):
# TODO: recall memory
raise NotImplementedError()

def attach_source(
self,
user: PydanticUser,
source_id: str,
source_manager: SourceManager,
agent_manager: AgentManager,
):
"""Attach a source to the agent using the SourcesAgents ORM relationship.
Args:
user: User performing the action
source_id: ID of the source to attach
source_manager: SourceManager instance to verify source exists
agent_manager: AgentManager instance to manage agent-source relationship
"""
# Verify source exists and user has permission to access it
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"

# Use the agent_manager to create the relationship
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)

printd(
f"Attached data source {source.name} to agent {self.agent_state.name}.",
)

def get_context_window(self) -> ContextWindowOverview:
"""Get the context window of the agent"""

Expand Down
11 changes: 9 additions & 2 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2987,7 +2987,11 @@ def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None,
source_id (str): ID of the source
source_name (str): Name of the source
"""
self.server.attach_source_to_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
if source_name:
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
source_id = source.id

self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user)

def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
"""
Expand All @@ -2999,7 +3003,10 @@ def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = Non
Returns:
source (Source): Detached source
"""
return self.server.detach_source_from_agent(source_id=source_id, source_name=source_name, agent_id=agent_id, user_id=self.user_id)
if source_name:
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
source_id = source.id
return self.server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=self.user)

def list_sources(self) -> List[Source]:
"""
Expand Down
13 changes: 10 additions & 3 deletions letta/orm/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from letta.schemas.source import Source as PydanticSource

if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.passage import SourcePassage
from letta.orm.agent import Agent


class Source(SqlalchemyBase, OrganizationMixin):
Expand All @@ -32,4 +32,11 @@ class Source(SqlalchemyBase, OrganizationMixin):
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
agents: Mapped[List["Agent"]] = relationship(
"Agent",
secondary="sources_agents",
back_populates="sources",
lazy="selectin",
cascade="all, delete", # Ensures rows in sources_agents are deleted when the source is deleted
passive_deletes=True, # Allows the database to handle deletion of orphaned rows
)
4 changes: 2 additions & 2 deletions letta/orm/sources_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ class SourcesAgents(Base):

__tablename__ = "sources_agents"

agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), primary_key=True)
11 changes: 4 additions & 7 deletions letta/server/rest_api/routers/v1/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,8 @@ def attach_source_to_agent(
Attach a data source to an existing agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)

source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=actor.id)
return source
server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=actor)
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)


@router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source")
Expand All @@ -148,8 +145,8 @@ def detach_source_from_agent(
Detach a data source from an existing agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)

return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)


@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
Expand Down
79 changes: 2 additions & 77 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
Expand Down Expand Up @@ -303,11 +303,6 @@ def __init__(
self.block_manager.add_default_blocks(actor=self.default_user)
self.tool_manager.upsert_base_tools(actor=self.default_user)

# If there is a default org/user
# This logic may have to change in the future
if settings.load_default_external_tools:
self.add_default_external_tools(actor=self.default_user)

# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
if model_settings.openai_api_key:
Expand Down Expand Up @@ -431,9 +426,6 @@ def _step(
skip_verify=True,
)

# save agent after step
# save_agent(letta_agent)

except Exception as e:
logger.error(f"Error in server._step: {e}")
print(traceback.print_exc())
Expand Down Expand Up @@ -944,11 +936,10 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor
agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent_state in agent_states:
agent_id = agent_state.id
agent = self.load_agent(agent_id=agent_id, actor=actor)

# Attach source to agent
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor)
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added

Expand All @@ -973,56 +964,6 @@ def load_data(
passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
return passage_count, document_count

def attach_source_to_agent(
self,
user_id: str,
agent_id: str,
source_id: Optional[str] = None,
source_name: Optional[str] = None,
) -> Source:
# attach a data source to an agent
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
if source_id:
data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
elif source_name:
data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
else:
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")

assert data_source, f"Data source with id={source_id} or name={source_name} does not exist"

# load agent
agent = self.load_agent(agent_id=agent_id, actor=actor)

# attach source to agent
agent.attach_source(user=actor, source_id=data_source.id, source_manager=self.source_manager, agent_manager=self.agent_manager)

return data_source

def detach_source_from_agent(
self,
user_id: str,
agent_id: str,
source_id: Optional[str] = None,
source_name: Optional[str] = None,
) -> Source:
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
if source_id:
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
elif source_name:
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
source_id = source.id
else:
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")

# delete agent-source mapping
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)

# return back source data
return source

def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
return []
Expand Down Expand Up @@ -1060,22 +1001,6 @@ def list_all_sources(self, actor: User) -> List[Source]:

return sources_with_metadata

def add_default_external_tools(self, actor: User) -> bool:
"""Add default langchain tools. Return true if successful, false otherwise."""
success = True
tool_creates = ToolCreate.load_default_langchain_tools()
if tool_settings.composio_api_key:
tool_creates += ToolCreate.load_default_composio_tools()
for tool_create in tool_creates:
try:
self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor)
except Exception as e:
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
warnings.warn(traceback.format_exc())
success = False

return success

def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
"""Update the details of a message associated with an agent"""

Expand Down
3 changes: 0 additions & 3 deletions letta/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ class Settings(BaseSettings):
pg_pool_recycle: int = 1800 # When to recycle connections
pg_echo: bool = False # Logging

# tools configuration
load_default_external_tools: Optional[bool] = None

@property
def letta_pg_uri(self) -> str:
if self.pg_uri:
Expand Down
Loading

0 comments on commit bbe8dea

Please sign in to comment.