Skip to content

Commit

Permalink
add embedding_chunk_size param, more error checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Caren Thomas committed Dec 17, 2024
1 parent f40ecf9 commit d960894
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 15 deletions.
1 change: 1 addition & 0 deletions letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

# embeddings
MAX_EMBEDDING_DIM = 4096 # maximum supported embedding size - do NOT change or else DBs will need to be reset
DEFAULT_EMBEDDING_CHUNK_SIZE = 300

# tokenizers
EMBEDDING_TO_TOKENIZER_MAP = {
Expand Down
10 changes: 5 additions & 5 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pydantic import BaseModel, Field, field_validator

from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_base import OrmMetadataBase
Expand Down Expand Up @@ -110,13 +111,13 @@ class CreateAgent(BaseModel, validate_assignment=True): #
llm: Optional[str] = Field(
None,
description="The LLM configuration handle used by the agent, specified in the format "
"provider/model-name, as an alternative to specifying llm_config. This field can also "
"be used to override the context window by optionally appending ':context_window'.",
"provider/model-name, as an alternative to specifying llm_config.",
)
embedding: Optional[str] = Field(
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
)
context_window: Optional[int] = Field(None, description="The context window specification used by the agent.")
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")

@field_validator("name")
@classmethod
Expand Down Expand Up @@ -150,9 +151,8 @@ def validate_llm(cls, llm: Optional[str]) -> Optional[str]:
return llm

provider_name, model_name = llm.split("/", 1)
model_name, _, _ = model_name.partition(":")
if not provider_name or not model_name:
raise ValueError("The llm config handle should be in the format provider/model-name[:context_window]")
raise ValueError("The llm config handle should be in the format provider/model-name")

return llm

Expand Down
32 changes: 22 additions & 10 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,12 +780,14 @@ def create_agent(
if request.llm_config is None:
if request.llm is None:
raise ValueError("Must specify either llm or llm_config in request")
request.llm_config = self.get_llm_config_from_handle(request.llm, request.context_window)
request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit)

if request.embedding_config is None:
if request.embedding is None:
raise ValueError("Must specify either embedding or embedding_config in request")
request.embedding_config = self.get_embedding_config_from_handle(request.embedding)
request.embedding_config = self.get_embedding_config_from_handle(
handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
)

"""Create a new agent using a config"""
# Invoke manager
Expand Down Expand Up @@ -1373,25 +1375,28 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models

def get_llm_config_from_handle(self, handle: str, context_window: Optional[int] = None) -> LLMConfig:
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
provider_name, model_name = handle.split("/", 1)
model_name, _, context_window_override = model_name.partition(":")
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
if not providers:
raise ValueError(f"Provider {provider_name} is not supported")

llm_configs = [llm_config for llm_config in providers[0].list_llm_models() if llm_config.model == model_name]
if not llm_configs:
raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
elif len(llm_configs) > 1:
raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
else:
llm_config = llm_configs[0]

llm_config = llm_configs[0]
context_window = int(context_window_override) if context_window_override else context_window
if context_window:
llm_config.context_window = context_window
if context_window_limit:
if context_window_limit > llm_config.context_window:
raise ValueError(f"Context window limit ({context_window_limit}) is greater than max model context window ({llm_config.context_window})")
llm_config.context_window = context_window_limit

return llm_config

def get_embedding_config_from_handle(self, handle: str) -> EmbeddingConfig:
def get_embedding_config_from_handle(self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE) -> EmbeddingConfig:
provider_name, model_name = handle.split("/", 1)
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
if not providers:
Expand All @@ -1402,8 +1407,15 @@ def get_embedding_config_from_handle(self, handle: str) -> EmbeddingConfig:
]
if not embedding_configs:
raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
elif len(embedding_configs) > 1:
raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
else:
embedding_config = embedding_configs[0]

if embedding_chunk_size:
embedding_config.embedding_chunk_size = embedding_chunk_size

return embedding_configs[0]
return embedding_config

def add_llm_model(self, request: LLMConfig) -> LLMConfig:
"""Add a new LLM model"""
Expand Down
3 changes: 3 additions & 0 deletions letta/services/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def create_agent(
) -> PydanticAgentState:
system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)

if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")

# create blocks (note: cannot be linked until the agent_id is created)
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
for create_block in agent_create.memory_blocks:
Expand Down

0 comments on commit d960894

Please sign in to comment.