From 92f30ae88baec4213270afabecb21e606cbf1f22 Mon Sep 17 00:00:00 2001 From: cpacker Date: Fri, 11 Oct 2024 15:39:54 -0700 Subject: [PATCH] chore: added same pretty print to embeddings --- letta/cli/cli.py | 12 +----------- letta/schemas/embedding_config.py | 7 +++++++ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/letta/cli/cli.py b/letta/cli/cli.py index 9e95800c84..04dbf359a4 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -14,7 +14,6 @@ from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL from letta.log import get_logger from letta.metadata import MetadataStore -from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import OptionState from letta.schemas.memory import ChatMemory, Memory from letta.server.server import logger as server_logger @@ -249,17 +248,8 @@ def run( embedding_configs = client.list_embedding_configs() embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs] - # TODO move into EmbeddingConfig as a class method? - def prettify_embed_config(embedding_config: EmbeddingConfig) -> str: - return ( - f"{embedding_config.embedding_model}" + f" ({embedding_config.embedding_endpoint})" - if embedding_config.embedding_endpoint - else "" - ) - embedding_choices = [ - questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config) - for embedding_config in embedding_configs + questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs ] # select model diff --git a/letta/schemas/embedding_config.py b/letta/schemas/embedding_config.py index e56b2f8272..31f7ee8da3 100644 --- a/letta/schemas/embedding_config.py +++ b/letta/schemas/embedding_config.py @@ -52,3 +52,10 @@ def default_config(cls, model_name: Optional[str] = None, provider: Optional[str ) else: raise ValueError(f"Model {model_name} not supported.") + + def pretty_print(self) -> str: + return ( + f"{self.embedding_model}" + + (f" [type={self.embedding_endpoint_type}]" if self.embedding_endpoint_type else "") + + (f" [ip={self.embedding_endpoint}]" if self.embedding_endpoint else "") + )