feat: store handle in configs (#2299)
Co-authored-by: Caren Thomas <caren@caren-mac.local>
carenthomas and Caren Thomas authored Dec 21, 2024
1 parent 2fc9b54 commit 160aef5
Showing 3 changed files with 22 additions and 3 deletions.
23 changes: 20 additions & 3 deletions letta/providers.py
@@ -27,6 +27,10 @@ def get_model_context_window(self, model_name: str) -> Optional[int]:
     def provider_tag(self) -> str:
         """String representation of the provider for display purposes"""
         raise NotImplementedError
 
+    def get_handle(self, model_name: str) -> str:
+        return f"{self.name}/{model_name}"
+
+
 
 class LettaProvider(Provider):
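
The get_handle helper added above is the core of this change: every provider now stamps the configs it emits with a stable "provider/model-name" identifier. A minimal sketch of the behavior (illustrative only; the provider name and model below are assumptions, not part of the commit):

    # Sketch, not from the commit: any Provider whose `name` is "openai"
    # would produce handles of the form "openai/<model-name>".
    class FakeProvider:
        name = "openai"  # hypothetical provider name

        def get_handle(self, model_name: str) -> str:
            return f"{self.name}/{model_name}"

    print(FakeProvider().get_handle("gpt-4"))  # -> "openai/gpt-4"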
@@ -40,6 +44,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                 model_endpoint_type="openai",
                 model_endpoint="https://inference.memgpt.ai",
                 context_window=16384,
+                handle=self.get_handle("letta-free")
             )
         ]

@@ -51,6 +56,7 @@ def list_embedding_models(self):
                 embedding_endpoint="https://embeddings.memgpt.ai",
                 embedding_dim=1024,
                 embedding_chunk_size=300,
+                handle=self.get_handle("letta-free")
             )
         ]

@@ -115,7 +121,7 @@ def list_llm_models(self) -> List[LLMConfig]:
             # continue
 
             configs.append(
-                LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
+                LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size, handle=self.get_handle(model_name))
             )
 
         # for OpenAI, sort in reverse order
@@ -135,6 +141,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
                 embedding_endpoint="https://api.openai.com/v1",
                 embedding_dim=1536,
                 embedding_chunk_size=300,
+                handle=self.get_handle("text-embedding-ada-002")
             )
         ]

@@ -163,6 +170,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                     model_endpoint_type="anthropic",
                     model_endpoint=self.base_url,
                     context_window=model["context_window"],
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -195,6 +203,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                     model_endpoint_type="openai",
                     model_endpoint=self.base_url,
                     context_window=model["max_context_length"],
+                    handle=self.get_handle(model["id"])
                 )
             )

@@ -250,6 +259,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window,
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -325,6 +335,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
                     embedding_endpoint=self.base_url,
                     embedding_dim=embedding_dim,
                     embedding_chunk_size=300,
+                    handle=self.get_handle(model["name"])
                 )
             )
         return configs
@@ -345,7 +356,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                 continue
             configs.append(
                 LLMConfig(
-                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"]
+                    model=model["id"], model_endpoint_type="groq", model_endpoint=self.base_url, context_window=model["context_window"], handle=self.get_handle(model["id"])
                 )
             )
         return configs
@@ -413,6 +424,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=context_window_size,
+                    handle=self.get_handle(model_name)
                 )
             )

@@ -493,6 +505,7 @@ def list_llm_models(self):
                     model_endpoint_type="google_ai",
                     model_endpoint=self.base_url,
                     context_window=self.get_model_context_window(model),
+                    handle=self.get_handle(model)
                 )
             )
         return configs
@@ -516,6 +529,7 @@ def list_embedding_models(self):
                     embedding_endpoint=self.base_url,
                     embedding_dim=768,
                     embedding_chunk_size=300,  # NOTE: max is 2048
+                    handle=self.get_handle(model)
                 )
             )
         return configs
@@ -556,7 +570,7 @@ def list_llm_models(self) -> List[LLMConfig]:
             context_window_size = self.get_model_context_window(model_name)
             model_endpoint = get_azure_chat_completions_endpoint(self.base_url, model_name, self.api_version)
             configs.append(
-                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size)
+                LLMConfig(model=model_name, model_endpoint_type="azure", model_endpoint=model_endpoint, context_window=context_window_size, handle=self.get_handle(model_name))
             )
         return configs

@@ -577,6 +591,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
                     embedding_endpoint=model_endpoint,
                     embedding_dim=768,
                     embedding_chunk_size=300,  # NOTE: max is 2048
+                    handle=self.get_handle(model_name)
                 )
             )
         return configs
@@ -610,6 +625,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                     model_endpoint_type="openai",
                     model_endpoint=self.base_url,
                     context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"])
                 )
             )
         return configs
@@ -642,6 +658,7 @@ def list_llm_models(self) -> List[LLMConfig]:
                     model_endpoint=self.base_url,
                     model_wrapper=self.default_prompt_formatter,
                     context_window=model["max_model_len"],
+                    handle=self.get_handle(model["id"])
                 )
             )
         return configs
1 change: 1 addition & 0 deletions letta/schemas/embedding_config.py
@@ -43,6 +43,7 @@ class EmbeddingConfig(BaseModel):
     embedding_model: str = Field(..., description="The model for the embedding.")
     embedding_dim: int = Field(..., description="The dimension of the embedding.")
     embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
+    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
 
     # azure only
     azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")
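
With the new field in place, an EmbeddingConfig can carry its handle alongside the endpoint details. A hedged sketch of constructing one (the endpoint type and handle value below mirror the letta-free defaults in the providers.py diff above and are assumptions, not prescribed usage):

    from letta.schemas.embedding_config import EmbeddingConfig

    config = EmbeddingConfig(
        embedding_model="letta-free",
        embedding_endpoint_type="hugging-face",  # assumed endpoint type for this model
        embedding_endpoint="https://embeddings.memgpt.ai",
        embedding_dim=1024,
        embedding_chunk_size=300,
        handle="letta/letta-free",  # provider/model-name, as get_handle would build it
    )
    print(config.handle)  # -> "letta/letta-free"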
1 change: 1 addition & 0 deletions letta/schemas/llm_config.py
@@ -44,6 +44,7 @@ class LLMConfig(BaseModel):
         True,
         description="Puts 'inner_thoughts' as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts.",
     )
+    handle: Optional[str] = Field(None, description="The handle for this config, in the format provider/model-name.")
 
     # FIXME hack to silence pydantic protected namespace warning
     model_config = ConfigDict(protected_namespaces=())
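
Because handle defaults to None, existing code that constructs LLMConfig objects without one keeps validating; only provider-generated configs populate it. A sketch (the model name and context window are assumed values for illustration):

    from letta.schemas.llm_config import LLMConfig

    config = LLMConfig(
        model="gpt-4",                               # assumed model name
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=8192,                         # assumed context window
        handle="openai/gpt-4",                       # provider/model-name
    )
    print(config.handle)  # -> "openai/gpt-4"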
