Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: avoid to use extra space when finding model by name #13043

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions api/core/model_runtime/model_providers/__base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,12 @@ def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Op
:param credentials: model credentials
:return: model schema
"""
# get predefined models (predefined_models)
models = self.predefined_models()

model_map = {model.model: model for model in models}
if model in model_map:
return model_map[model]
# Try to get model schema from predefined models
for predefined_model in self.predefined_models():
if model == predefined_model.model:
return predefined_model

# Try to get model schema from credentials
if credentials:
model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
if model_schema:
Expand Down
17 changes: 9 additions & 8 deletions api/core/model_runtime/model_providers/cohere/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,16 +677,17 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
:return: model schema
"""
# get model schema
models = self.predefined_models()
model_map = {model.model: model for model in models}

mode = credentials.get("mode")
base_model_schema = None
for predefined_model in self.predefined_models():
if (
mode == "chat" and predefined_model.model == "command-light-chat"
) or predefined_model.model == "command-light":
base_model_schema = predefined_model
break

if mode == "chat":
base_model_schema = model_map["command-light-chat"]
else:
base_model_schema = model_map["command-light"]
if not base_model_schema:
raise ValueError("Model not found")

base_model_schema = cast(AIModelEntity, base_model_schema)

Expand Down
20 changes: 10 additions & 10 deletions api/core/model_runtime/model_providers/openai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,6 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]:
:param credentials: provider credentials
:return:
"""
# get predefined models
predefined_models = self.predefined_models()
predefined_models_map = {model.model: model for model in predefined_models}

# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
Expand All @@ -359,9 +356,10 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]:
base_model = model.id.split(":")[1]

base_model_schema = None
for predefined_model_name, predefined_model in predefined_models_map.items():
if predefined_model_name in base_model:
for predefined_model in self.predefined_models():
if predefined_model.model in base_model:
base_model_schema = predefined_model
break

if not base_model_schema:
continue
Expand Down Expand Up @@ -1186,12 +1184,14 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
base_model = model.split(":")[1]

# get model schema
models = self.predefined_models()
model_map = {model.model: model for model in models}
if base_model not in model_map:
raise ValueError(f"Base model {base_model} not found")
base_model_schema = None
for predefined_model in self.predefined_models():
if base_model == predefined_model.model:
base_model_schema = predefined_model
break

base_model_schema = model_map[base_model]
if not base_model_schema:
raise ValueError(f"Base model {base_model} not found")

base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties
Expand Down
Loading