From e32417e988c53069465425ef0d32e96793056bc8 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 17 Dec 2024 15:31:19 -0800 Subject: [PATCH] feat: Add optional llm and embedding handle args to CreateAgent request (#2260) --- letta/constants.py | 1 + letta/providers.py | 2 ++ letta/schemas/agent.py | 35 ++++++++++++++++++ letta/server/server.py | 63 +++++++++++++++++++++++++++++++++ letta/services/agent_manager.py | 3 ++ tests/test_server.py | 21 ++++++----- 6 files changed, 114 insertions(+), 11 deletions(-) diff --git a/letta/constants.py b/letta/constants.py index 437d956c49..d47f63a2ab 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -23,6 +23,7 @@ # embeddings MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset +DEFAULT_EMBEDDING_CHUNK_SIZE = 300 # tokenizers EMBEDDING_TO_TOKENIZER_MAP = { diff --git a/letta/providers.py b/letta/providers.py index b28baee791..5721db4627 100644 --- a/letta/providers.py +++ b/letta/providers.py @@ -13,6 +13,7 @@ class Provider(BaseModel): + name: str = Field(..., description="The name of the provider") def list_llm_models(self) -> List[LLMConfig]: return [] @@ -465,6 +466,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]: class GoogleAIProvider(Provider): # gemini + name: str = "google_ai" api_key: str = Field(..., description="API key for the Google AI API.") base_url: str = "https://generativelanguage.googleapis.com" diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index ea3afd28a7..840ca58961 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator +from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_base import OrmMetadataBase @@ -107,6 +108,16 @@ class CreateAgent(BaseModel, validate_assignment=True): # include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.") description: Optional[str] = Field(None, description="The description of the agent.") metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") + llm: Optional[str] = Field( + None, + description="The LLM configuration handle used by the agent, specified in the format " + "provider/model-name, as an alternative to specifying llm_config.", + ) + embedding: Optional[str] = Field( + None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name." + ) + context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.") + embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.") @field_validator("name") @classmethod @@ -133,6 +144,30 @@ def validate_name(cls, name: str) -> str: return name + @field_validator("llm") + @classmethod + def validate_llm(cls, llm: Optional[str]) -> Optional[str]: + if not llm: + return llm + + provider_name, model_name = llm.split("/", 1) + if not provider_name or not model_name: + raise ValueError("The llm config handle should be in the format provider/model-name") + + return llm + + @field_validator("embedding") + @classmethod + def validate_embedding(cls, embedding: Optional[str]) -> Optional[str]: + if not embedding: + return embedding + + provider_name, model_name = embedding.split("/", 1) + if not provider_name or not model_name: + raise ValueError("The embedding config handle should be in the format provider/model-name") + + return embedding + class UpdateAgent(BaseModel): name: Optional[str] = Field(None, description="The name of the agent.") diff --git a/letta/server/server.py b/letta/server/server.py index 71b0ac78da..24d70ef351 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -776,6 +776,18 @@ def create_agent( # interface interface: Union[AgentInterface, None] = None, ) -> AgentState: + if request.llm_config is None: + if request.llm is None: + raise ValueError("Must specify either llm or llm_config in request") + request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit) + + if request.embedding_config is None: + if request.embedding is None: + raise ValueError("Must specify either embedding or embedding_config in request") + request.embedding_config = self.get_embedding_config_from_handle( + handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE + ) + """Create a new agent using a config""" # Invoke manager agent_state = self.agent_manager.create_agent( @@ -1283,6 +1295,57 @@ def list_embedding_models(self) -> List[EmbeddingConfig]: warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") return embedding_models + def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig: + provider_name, model_name = handle.split("/", 1) + provider = self.get_provider_from_name(provider_name) + + llm_configs = [config for config in provider.list_llm_models() if config.model == model_name] + if not llm_configs: + raise ValueError(f"LLM model {model_name} is not supported by {provider_name}") + elif len(llm_configs) > 1: + raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}") + else: + llm_config = llm_configs[0] + + if context_window_limit: + if context_window_limit > llm_config.context_window: + raise ValueError( + f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})" + ) + llm_config.context_window = context_window_limit + + return llm_config + + def get_embedding_config_from_handle( + self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE + ) -> EmbeddingConfig: + provider_name, model_name = handle.split("/", 1) + provider = self.get_provider_from_name(provider_name) + + embedding_configs = [config for config in provider.list_embedding_models() if config.embedding_model == model_name] + if not embedding_configs: + raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}") + elif len(embedding_configs) > 1: + raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}") + else: + embedding_config = embedding_configs[0] + + if embedding_chunk_size: + embedding_config.embedding_chunk_size = embedding_chunk_size + + return embedding_config + + def get_provider_from_name(self, provider_name: str) -> Provider: + providers = [provider for provider in self._enabled_providers if provider.name == provider_name] + if not providers: + raise ValueError(f"Provider {provider_name} is not supported") + elif len(providers) > 1: + raise ValueError(f"Multiple providers with name {provider_name} supported") + else: + provider = providers[0] + + return provider + def add_llm_model(self, request: LLMConfig) -> LLMConfig: """Add a new LLM model""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d1edb3eaba..99dfa3ae47 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -61,6 +61,9 @@ def create_agent( ) -> PydanticAgentState: system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system) + if not agent_create.llm_config or not agent_create.embedding_config: + raise ValueError("llm_config and embedding_config are required") + # create blocks (note: cannot be linked into the agent_id is created) block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original for create_block in agent_create.memory_blocks: diff --git a/tests/test_server.py b/tests/test_server.py index 975cde698f..93159aa583 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -24,7 +24,6 @@ from letta.schemas.agent import CreateAgent, UpdateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.job import Job as PydanticJob -from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.source import Source as PydanticSource from letta.server.server import SyncServer @@ -329,8 +328,8 @@ def agent_id(server, user_id, base_tools): name="test_agent", tool_ids=[t.id for t in base_tools], memory_blocks=[], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), + llm="openai/gpt-4", + embedding="openai/text-embedding-ada-002", ), actor=actor, ) @@ -350,8 +349,8 @@ def other_agent_id(server, user_id, base_tools): name="test_agent_other", tool_ids=[t.id for t in base_tools], memory_blocks=[], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), + llm="openai/gpt-4", + embedding="openai/text-embedding-ada-002", ), actor=actor, ) @@ -618,8 +617,8 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str): request=CreateAgent( name="nonexistent_tools_agent", memory_blocks=[], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), + llm="openai/gpt-4", + embedding="openai/text-embedding-ada-002", ), actor=server.user_manager.get_user_or_default(user_id), ) @@ -904,8 +903,8 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools CreateBlock(label="human", value="The human's name is Bob."), CreateBlock(label="persona", value="My name is Alice."), ], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), + llm="openai/gpt-4", + embedding="openai/text-embedding-ada-002", ), actor=actor, ) @@ -1091,8 +1090,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to CreateBlock(label="human", value="The human's name is Bob."), CreateBlock(label="persona", value="My name is Alice."), ], - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), + llm="openai/gpt-4", + embedding="openai/text-embedding-ada-002", include_base_tools=False, ), actor=actor,