Skip to content

Commit

Permalink
feat: Add new types and other changes (#2348)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 authored Jan 10, 2025
1 parent 008ef65 commit 1cab3a9
Show file tree
Hide file tree
Showing 17 changed files with 127 additions and 179 deletions.
1 change: 0 additions & 1 deletion .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ jobs:
- "integration_test_tool_execution_sandbox.py"
- "integration_test_offline_memory_agent.py"
- "integration_test_agent_tool_graph.py"
- "integration_test_o1_agent.py"
services:
qdrant:
image: qdrant/qdrant
Expand Down
3 changes: 1 addition & 2 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
O1_BASE_TOOLS,
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import ContextWindowExceededError
Expand Down Expand Up @@ -212,7 +211,7 @@ def execute_tool_and_persist_state(self, function_name: str, function_args: dict
# TODO: This is NO BUENO
# TODO: Matching purely by names is extremely problematic, users can create tools with these names and run them in the agent loop
# TODO: We will have probably have to match the function strings exactly for safety
if function_name in BASE_TOOLS or function_name in O1_BASE_TOOLS:
if function_name in BASE_TOOLS:
# base tools are allowed to access the `Agent` object and run on the database
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = callable_func(**function_args)
Expand Down
1 change: 0 additions & 1 deletion letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
# Base tools that cannot be edited, as they access agent state directly
# Note that we don't include "conversation_search_date" for now
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
O1_BASE_TOOLS = ["send_thinking_message", "send_final_message"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]

Expand Down
86 changes: 0 additions & 86 deletions letta/o1_agent.py

This file was deleted.

6 changes: 6 additions & 0 deletions letta/orm/enums.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from enum import Enum


class ToolType(str, Enum):
CUSTOM = "custom"
LETTA_CORE = "letta_core"
LETTA_MEMORY_CORE = "letta_memory_core"


class ToolSourceType(str, Enum):
"""Defines what a tool was derived from"""

Expand Down
1 change: 0 additions & 1 deletion letta/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class AgentType(str, Enum):

memgpt_agent = "memgpt_agent"
split_thread_agent = "split_thread_agent"
o1_agent = "o1_agent"
offline_memory_agent = "offline_memory_agent"
chat_only_agent = "chat_only_agent"

Expand Down
9 changes: 3 additions & 6 deletions letta/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,19 +206,16 @@ class ToolUpdate(LettaBase):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
return_char_limit: Optional[int] = Field(None, description="The maximum number of characters in the response.")

class Config:
extra = "ignore" # Allows extra fields without validation errors
# TODO: Remove this, and clean usage of ToolUpdate everywhere else


class ToolRun(LettaBase):
id: str = Field(..., description="The ID of the tool to run.")
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")


class ToolRunFromSource(LettaBase):
source_code: str = Field(..., description="The source code of the function.")
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
args: Dict[str, str] = Field(..., description="The arguments to pass to the tool.")
env_vars: Dict[str, str] = Field(None, description="The environment variables to pass to the tool.")
name: Optional[str] = Field(None, description="The name of the tool to run.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
3 changes: 2 additions & 1 deletion letta/server/rest_api/routers/v1/sandbox_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ def delete_sandbox_config(
def list_sandbox_configs(
limit: int = Query(1000, description="Number of results to return"),
cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"),
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, cursor=cursor)
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, cursor=cursor, sandbox_type=sandbox_type)


### Sandbox Environment Variable Routes
Expand Down
1 change: 1 addition & 0 deletions letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def run_tool_from_source(
tool_source=request.source_code,
tool_source_type=request.source_type,
tool_args=request.args,
tool_env_vars=request.env_vars,
tool_name=request.name,
actor=actor,
)
Expand Down
21 changes: 7 additions & 14 deletions letta/server/server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# inspecting tools
import json
import os
import traceback
import warnings
from abc import abstractmethod
from datetime import datetime
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

from composio.client import Composio
from composio.client.collections import ActionModel, AppModel
Expand All @@ -23,7 +22,6 @@
from letta.interface import AgentInterface # abstract
from letta.interface import CLIInterface # for printing to terminal
from letta.log import get_logger
from letta.o1_agent import O1Agent
from letta.offline_memory_agent import OfflineMemoryAgent
from letta.orm import Base
from letta.orm.errors import NoResultFound
Expand Down Expand Up @@ -391,8 +389,6 @@ def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface
interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
agent = Agent(agent_state=agent_state, interface=interface, user=actor)
elif agent_state.agent_type == AgentType.o1_agent:
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)
elif agent_state.agent_type == AgentType.offline_memory_agent:
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
elif agent_state.agent_type == AgentType.chat_only_agent:
Expand Down Expand Up @@ -1117,22 +1113,17 @@ def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowO
def run_tool_from_source(
self,
actor: User,
tool_args: str,
tool_args: Dict[str, str],
tool_source: str,
tool_env_vars: Optional[Dict[str, str]] = None,
tool_source_type: Optional[str] = None,
tool_name: Optional[str] = None,
) -> ToolReturnMessage:
"""Run a tool from source code"""

try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
raise ValueError("Invalid JSON string for tool_args")

if tool_source_type is not None and tool_source_type != "python":
raise ValueError("Only Python source code is supported at this time")

# NOTE: we're creating a floating Tool object and NOT persiting to DB
# NOTE: we're creating a floating Tool object and NOT persisting to DB
tool = Tool(
name=tool_name,
source_code=tool_source,
Expand All @@ -1144,7 +1135,9 @@ def run_tool_from_source(

# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
agent_state=agent_state, additional_env_vars=tool_env_vars
)
return ToolReturnMessage(
id="null",
tool_call_id="null",
Expand Down
2 changes: 0 additions & 2 deletions letta/services/helpers/agent_manager_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
# TODO: don't hardcode
if agent_type == AgentType.memgpt_agent:
system = gpt_system.get_system_text("memgpt_chat")
elif agent_type == AgentType.o1_agent:
system = gpt_system.get_system_text("memgpt_modified_o1")
elif agent_type == AgentType.offline_memory_agent:
system = gpt_system.get_system_text("memgpt_offline_memory")
elif agent_type == AgentType.chat_only_agent:
Expand Down
13 changes: 6 additions & 7 deletions letta/services/sandbox_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,15 @@ def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) ->

@enforce_types
def list_sandbox_configs(
self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50
self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50, sandbox_type: Optional[SandboxType] = None
) -> List[PydanticSandboxConfig]:
"""List all sandbox configurations with optional pagination."""
kwargs = {"organization_id": actor.organization_id}
if sandbox_type:
kwargs.update({"type": sandbox_type})

with self.session_maker() as session:
sandboxes = SandboxConfigModel.list(
db_session=session,
cursor=cursor,
limit=limit,
organization_id=actor.organization_id,
)
sandboxes = SandboxConfigModel.list(db_session=session, cursor=cursor, limit=limit, **kwargs)
return [sandbox.to_pydantic() for sandbox in sandboxes]

@enforce_types
Expand Down
22 changes: 16 additions & 6 deletions letta/services/tool_execution_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,23 @@ def __init__(self, tool_name: str, args: dict, user: User, force_recreate=False,
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
self.force_recreate = force_recreate

def run(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
def run(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
"""
Run the tool in a sandbox environment.
Args:
agent_state (Optional[AgentState]): The state of the agent invoking the tool
additional_env_vars (Optional[Dict]): Environment variables to inject into the sandbox
Returns:
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
"""
if tool_settings.e2b_api_key:
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
result = self.run_e2b_sandbox(agent_state=agent_state)
result = self.run_e2b_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)
else:
logger.debug(f"Using local sandbox to execute {self.tool_name}")
result = self.run_local_dir_sandbox(agent_state=agent_state)
result = self.run_local_dir_sandbox(agent_state=agent_state, additional_env_vars=additional_env_vars)

# Log out any stdout/stderr from the tool run
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
Expand All @@ -98,19 +99,25 @@ def temporary_env_vars(self, env_vars: dict):
os.environ.clear()
os.environ.update(original_env) # Restore original environment variables

def run_local_dir_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
def run_local_dir_sandbox(
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
) -> SandboxRunResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
local_configs = sbx_config.get_local_config()

# Get environment variables for the sandbox
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env = os.environ.copy()
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env.update(env_vars)

# Get environment variables for this agent specifically
if agent_state:
env.update(agent_state.get_agent_env_vars_as_dict())

# Finally, get any that are passed explicitly into the `run` function call
if additional_env_vars:
env.update(additional_env_vars)

# Safety checks
if not os.path.isdir(local_configs.sandbox_dir):
raise FileNotFoundError(f"Sandbox directory does not exist: {local_configs.sandbox_dir}")
Expand Down Expand Up @@ -277,7 +284,7 @@ def create_venv_for_local_sandbox(self, sandbox_dir_path: str, venv_path: str, e

# e2b sandbox specific functions

def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRunResult:
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
if not sbx or self.force_recreate:
Expand All @@ -300,6 +307,9 @@ def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None) -> SandboxRu
if agent_state:
env_vars.update(agent_state.get_agent_env_vars_as_dict())

# Finally, get any that are passed explicitly into the `run` function call
if additional_env_vars:
env_vars.update(additional_env_vars)
code = self.generate_execution_script(agent_state=agent_state)
execution = sbx.run_code(code, envs=env_vars)

Expand Down
Loading

0 comments on commit 1cab3a9

Please sign in to comment.