-
Notifications
You must be signed in to change notification settings - Fork 5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add Multi-Model Provider Support to Agent Component (#4416)
* Add Multi-Model Provider Support to Agent Component - Integrated model provider constants from `model_input_constants.py` into the Agent component to support multiple LLM providers - Added dynamic field management for different model providers (OpenAI, Azure, Groq, Anthropic, NVIDIA) - Implemented a dropdown for model provider selection with automatic input field updates * sorted list * Update agent.py making custom separate from sort * chore: remove unit test --------- Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
- Loading branch information
1 parent
1692495
commit ae7d037
Showing
6 changed files
with
134 additions
and
131 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
71 changes: 71 additions & 0 deletions
71
src/backend/base/langflow/base/models/model_input_constants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from langflow.base.models.model import LCModelComponent | ||
from langflow.components.models.anthropic import AnthropicModelComponent | ||
from langflow.components.models.azure_openai import AzureChatOpenAIComponent | ||
from langflow.components.models.groq import GroqModel | ||
from langflow.components.models.nvidia import NVIDIAModelComponent | ||
from langflow.components.models.openai import OpenAIModelComponent | ||
|
||
|
||
def get_filtered_inputs(component_class): | ||
base_input_names = {field.name for field in LCModelComponent._base_inputs} | ||
return [ | ||
set_advanced_true(input_) if input_.name == "temperature" else input_ | ||
for input_ in component_class().inputs | ||
if input_.name not in base_input_names | ||
] | ||
|
||
|
||
def set_advanced_true(component_input): | ||
component_input.advanced = True | ||
return component_input | ||
|
||
|
||
def create_input_fields_dict(inputs, prefix): | ||
return {f"{prefix}_{input_.name}": input_ for input_ in inputs} | ||
|
||
|
||
OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent) | ||
AZURE_INPUTS = get_filtered_inputs(AzureChatOpenAIComponent) | ||
GROQ_INPUTS = get_filtered_inputs(GroqModel) | ||
ANTHROPIC_INPUTS = get_filtered_inputs(AnthropicModelComponent) | ||
NVIDIA_INPUTS = get_filtered_inputs(NVIDIAModelComponent) | ||
|
||
|
||
OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS} | ||
|
||
|
||
AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure") | ||
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq") | ||
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic") | ||
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia") | ||
|
||
MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA"] | ||
|
||
MODEL_PROVIDERS_DICT = { | ||
"Azure OpenAI": { | ||
"fields": AZURE_FIELDS, | ||
"inputs": AZURE_INPUTS, | ||
"prefix": "azure_", | ||
"component_class": AzureChatOpenAIComponent(), | ||
}, | ||
"OpenAI": { | ||
"fields": OPENAI_FIELDS, | ||
"inputs": OPENAI_INPUTS, | ||
"prefix": "", | ||
"component_class": OpenAIModelComponent(), | ||
}, | ||
"Groq": {"fields": GROQ_FIELDS, "inputs": GROQ_INPUTS, "prefix": "groq_", "component_class": GroqModel()}, | ||
"Anthropic": { | ||
"fields": ANTHROPIC_FIELDS, | ||
"inputs": ANTHROPIC_INPUTS, | ||
"prefix": "anthropic_", | ||
"component_class": AnthropicModelComponent(), | ||
}, | ||
"NVIDIA": { | ||
"fields": NVIDIA_FIELDS, | ||
"inputs": NVIDIA_INPUTS, | ||
"prefix": "nvidia_", | ||
"component_class": NVIDIAModelComponent(), | ||
}, | ||
} | ||
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
25 changes: 0 additions & 25 deletions
25
src/backend/tests/unit/base/models/test_model_constants.py
This file was deleted.
Oops, something went wrong.