From d9608943433c0c8f2b738d543f06565cdaea0a82 Mon Sep 17 00:00:00 2001
From: Caren Thomas
Date: Mon, 16 Dec 2024 17:20:04 -0800
Subject: [PATCH] add embedding_chunk_size param, more error checks

---
 letta/constants.py              |  1 +
 letta/schemas/agent.py          | 10 +++++-----
 letta/server/server.py          | 32 ++++++++++++++++++++++----------
 letta/services/agent_manager.py |  3 +++
 4 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/letta/constants.py b/letta/constants.py
index 5e9ac9b268..dd32d3b99d 100644
--- a/letta/constants.py
+++ b/letta/constants.py
@@ -23,6 +23,7 @@

 # embeddings
 MAX_EMBEDDING_DIM = 4096  # maximum supported embedding size - do NOT change or else DBs will need to be reset
+DEFAULT_EMBEDDING_CHUNK_SIZE = 300

 # tokenizers
 EMBEDDING_TO_TOKENIZER_MAP = {
diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py
index e94fd29941..840ca58961 100644
--- a/letta/schemas/agent.py
+++ b/letta/schemas/agent.py
@@ -3,6 +3,7 @@

 from pydantic import BaseModel, Field, field_validator

+from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
 from letta.schemas.block import CreateBlock
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.letta_base import OrmMetadataBase
@@ -110,13 +111,13 @@ class CreateAgent(BaseModel, validate_assignment=True):  #
     llm: Optional[str] = Field(
         None,
         description="The LLM configuration handle used by the agent, specified in the format "
-        "provider/model-name, as an alternative to specifying llm_config. This field can also "
-        "be used to override the context window by optionally appending ':context_window'.",
+        "provider/model-name, as an alternative to specifying llm_config.",
     )
     embedding: Optional[str] = Field(
         None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
     )
-    context_window: Optional[int] = Field(None, description="The context window specification used by the agent.")
+    context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
+    embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")

     @field_validator("name")
     @classmethod
@@ -150,9 +151,8 @@ def validate_llm(cls, llm: Optional[str]) -> Optional[str]:
             return llm

         provider_name, model_name = llm.split("/", 1)
-        model_name, _, _ = model_name.partition(":")
         if not provider_name or not model_name:
-            raise ValueError("The llm config handle should be in the format provider/model-name[:context_window]")
+            raise ValueError("The llm config handle should be in the format provider/model-name")

         return llm

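For illustration, a minimal sketch of how the new CreateAgent fields are exercised after this change; the handle strings and values below are placeholders, not taken from this patch:

    from letta.schemas.agent import CreateAgent

    # Handles keep the provider/model-name format enforced by validate_llm; the
    # old ":context_window" suffix is replaced by the context_window_limit field.
    request = CreateAgent(
        llm="openai/gpt-4o-mini",                   # placeholder handle
        embedding="openai/text-embedding-ada-002",  # placeholder handle
        context_window_limit=8192,                  # must not exceed the model's max window
        embedding_chunk_size=300,                   # falls back to DEFAULT_EMBEDDING_CHUNK_SIZE
    )
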
diff --git a/letta/server/server.py b/letta/server/server.py
index 3d5854027b..ed4ae05a55 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -780,12 +780,14 @@ def create_agent(
         if request.llm_config is None:
             if request.llm is None:
                 raise ValueError("Must specify either llm or llm_config in request")
-            request.llm_config = self.get_llm_config_from_handle(request.llm, request.context_window)
+            request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit)

         if request.embedding_config is None:
             if request.embedding is None:
                 raise ValueError("Must specify either embedding or embedding_config in request")
-            request.embedding_config = self.get_embedding_config_from_handle(request.embedding)
+            request.embedding_config = self.get_embedding_config_from_handle(
+                handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+            )

         """Create a new agent using a config"""
         # Invoke manager
@@ -1373,9 +1375,8 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
                 warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
         return embedding_models

-    def get_llm_config_from_handle(self, handle: str, context_window: Optional[int] = None) -> LLMConfig:
+    def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
         provider_name, model_name = handle.split("/", 1)
-        model_name, _, context_window_override = model_name.partition(":")
         providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
         if not providers:
             raise ValueError(f"Provider {provider_name} is not supported")
@@ -1383,15 +1384,19 @@ def get_llm_config_from_handle(self, handle: str, context_window: Optional[int]
         llm_configs = [llm_config for llm_config in providers[0].list_llm_models() if llm_config.model == model_name]
         if not llm_configs:
             raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
+        elif len(llm_configs) > 1:
+            raise ValueError(f"Multiple LLM models with name {model_name} are supported by {provider_name}")
+        else:
+            llm_config = llm_configs[0]

-        llm_config = llm_configs[0]
-        context_window = int(context_window_override) if context_window_override else context_window
-        if context_window:
-            llm_config.context_window = context_window
+        if context_window_limit:
+            if context_window_limit > llm_config.context_window:
+                raise ValueError(f"Context window limit ({context_window_limit}) is greater than the max model context window ({llm_config.context_window})")
+            llm_config.context_window = context_window_limit

         return llm_config

-    def get_embedding_config_from_handle(self, handle: str) -> EmbeddingConfig:
+    def get_embedding_config_from_handle(self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE) -> EmbeddingConfig:
         provider_name, model_name = handle.split("/", 1)
         providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
         if not providers:
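As a sanity check on the new bounds logic, a hypothetical pytest sketch; the server fixture and model handle are assumptions for illustration, not part of this patch:

    import pytest

    def test_context_window_limit_is_capped(server):
        # A limit at or below the model maximum is applied as-is.
        config = server.get_llm_config_from_handle("openai/gpt-4o-mini", context_window_limit=4096)
        assert config.context_window == 4096

        # A limit above the model maximum now raises instead of being silently applied.
        with pytest.raises(ValueError, match="greater than the max model context window"):
            server.get_llm_config_from_handle("openai/gpt-4o-mini", context_window_limit=10**9)
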
@@ -1402,8 +1407,15 @@ def get_embedding_config_from_handle(self, handle: str) -> EmbeddingConfig:
         ]
         if not embedding_configs:
             raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
+        elif len(embedding_configs) > 1:
+            raise ValueError(f"Multiple embedding models with name {model_name} are supported by {provider_name}")
+        else:
+            embedding_config = embedding_configs[0]
+
+        if embedding_chunk_size:
+            embedding_config.embedding_chunk_size = embedding_chunk_size

-        return embedding_configs[0]
+        return embedding_config

     def add_llm_model(self, request: LLMConfig) -> LLMConfig:
         """Add a new LLM model"""
diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py
index 093831aab7..aac3b7221b 100644
--- a/letta/services/agent_manager.py
+++ b/letta/services/agent_manager.py
@@ -49,6 +49,9 @@ def create_agent(
     ) -> PydanticAgentState:
         system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)

+        if not agent_create.llm_config or not agent_create.embedding_config:
+            raise ValueError("llm_config and embedding_config are required")
+
         # create blocks (note: cannot be linked until the agent_id is created)
         block_ids = list(agent_create.block_ids or [])  # Create a local copy to avoid modifying the original
         for create_block in agent_create.memory_blocks:
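Taken together, a rough end-to-end sketch of the patched flow; the server and actor objects, the create_agent call shape, and the handle values are assumptions for illustration:

    from letta.schemas.agent import CreateAgent

    request = CreateAgent(
        name="docs-agent",
        llm="openai/gpt-4o-mini",                   # resolved by get_llm_config_from_handle
        embedding="openai/text-embedding-ada-002",  # resolved by get_embedding_config_from_handle
        embedding_chunk_size=512,
    )
    agent_state = server.create_agent(request, actor=actor)

    # By the time agent_manager runs, both configs must be resolved; otherwise its
    # new guard raises "llm_config and embedding_config are required".
    assert agent_state.embedding_config.embedding_chunk_size == 512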