Skip to content

Commit

Permalink
feat: Add Multi-Model Provider Support to Agent Component (#4416)
Browse files Browse the repository at this point in the history
* Add Multi-Model Provider Support to Agent Component

- Integrated model provider constants from `model_input_constants.py` into the Agent component to support multiple LLM providers
- Added dynamic field management for different model providers (OpenAI, Azure, Groq, Anthropic, NVIDIA)
- Implemented a dropdown for model provider selection with automatic input field updates

* sorted list

* Update agent.py

making custom separate from sort

* chore: remove unit test

---------

Co-authored-by: italojohnny <italojohnnydosanjos@gmail.com>
  • Loading branch information
edwinjosechittilappilly and italojohnny authored Nov 6, 2024
1 parent 1692495 commit ae7d037
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 131 deletions.
17 changes: 0 additions & 17 deletions src/backend/base/langflow/base/models/model_constants.py

This file was deleted.

71 changes: 71 additions & 0 deletions src/backend/base/langflow/base/models/model_input_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from langflow.base.models.model import LCModelComponent
from langflow.components.models.anthropic import AnthropicModelComponent
from langflow.components.models.azure_openai import AzureChatOpenAIComponent
from langflow.components.models.groq import GroqModel
from langflow.components.models.nvidia import NVIDIAModelComponent
from langflow.components.models.openai import OpenAIModelComponent


def get_filtered_inputs(component_class):
base_input_names = {field.name for field in LCModelComponent._base_inputs}
return [
set_advanced_true(input_) if input_.name == "temperature" else input_
for input_ in component_class().inputs
if input_.name not in base_input_names
]


def set_advanced_true(component_input):
component_input.advanced = True
return component_input


def create_input_fields_dict(inputs, prefix):
return {f"{prefix}_{input_.name}": input_ for input_ in inputs}


OPENAI_INPUTS = get_filtered_inputs(OpenAIModelComponent)
AZURE_INPUTS = get_filtered_inputs(AzureChatOpenAIComponent)
GROQ_INPUTS = get_filtered_inputs(GroqModel)
ANTHROPIC_INPUTS = get_filtered_inputs(AnthropicModelComponent)
NVIDIA_INPUTS = get_filtered_inputs(NVIDIAModelComponent)


OPENAI_FIELDS = {input_.name: input_ for input_ in OPENAI_INPUTS}


AZURE_FIELDS = create_input_fields_dict(AZURE_INPUTS, "azure")
GROQ_FIELDS = create_input_fields_dict(GROQ_INPUTS, "groq")
ANTHROPIC_FIELDS = create_input_fields_dict(ANTHROPIC_INPUTS, "anthropic")
NVIDIA_FIELDS = create_input_fields_dict(NVIDIA_INPUTS, "nvidia")

MODEL_PROVIDERS = ["Azure OpenAI", "OpenAI", "Groq", "Anthropic", "NVIDIA"]

MODEL_PROVIDERS_DICT = {
"Azure OpenAI": {
"fields": AZURE_FIELDS,
"inputs": AZURE_INPUTS,
"prefix": "azure_",
"component_class": AzureChatOpenAIComponent(),
},
"OpenAI": {
"fields": OPENAI_FIELDS,
"inputs": OPENAI_INPUTS,
"prefix": "",
"component_class": OpenAIModelComponent(),
},
"Groq": {"fields": GROQ_FIELDS, "inputs": GROQ_INPUTS, "prefix": "groq_", "component_class": GroqModel()},
"Anthropic": {
"fields": ANTHROPIC_FIELDS,
"inputs": ANTHROPIC_INPUTS,
"prefix": "anthropic_",
"component_class": AnthropicModelComponent(),
},
"NVIDIA": {
"fields": NVIDIA_FIELDS,
"inputs": NVIDIA_INPUTS,
"prefix": "nvidia_",
"component_class": NVIDIAModelComponent(),
},
}
ALL_PROVIDER_FIELDS: list[str] = [field for provider in MODEL_PROVIDERS_DICT.values() for field in provider["fields"]]
31 changes: 0 additions & 31 deletions src/backend/base/langflow/base/models/model_utils.py

This file was deleted.

121 changes: 63 additions & 58 deletions src/backend/base/langflow/components/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from langflow.base.agents.agent import LCToolsAgentComponent
from langflow.base.models.model import LCModelComponent
from langflow.base.models.model_input_constants import ALL_PROVIDER_FIELDS, MODEL_PROVIDERS_DICT
from langflow.components.agents.tool_calling import ToolCallingAgentComponent
from langflow.components.helpers.memory import MemoryComponent
from langflow.components.models.azure_openai import AzureChatOpenAIComponent
from langflow.components.models.openai import OpenAIModelComponent
from langflow.io import (
DropdownInput,
MultilineInput,
Expand All @@ -25,30 +23,19 @@ class AgentComponent(ToolCallingAgentComponent):
beta = True
name = "Agent"

azure_inputs = [
set_advanced_true(component_input) if component_input.name == "temperature" else component_input
for component_input in AzureChatOpenAIComponent().inputs
if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs]
]
openai_inputs = [
set_advanced_true(component_input) if component_input.name == "temperature" else component_input
for component_input in OpenAIModelComponent().inputs
if component_input.name not in [input_field.name for input_field in LCModelComponent._base_inputs]
]

memory_inputs = [set_advanced_true(component_input) for component_input in MemoryComponent().inputs]

inputs = [
DropdownInput(
name="agent_llm",
display_name="Model Provider",
options=["Azure OpenAI", "OpenAI", "Custom"],
options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"],
value="OpenAI",
real_time_refresh=True,
refresh_button=True,
input_types=[],
),
*openai_inputs,
*MODEL_PROVIDERS_DICT["OpenAI"]["inputs"],
MultilineInput(
name="system_prompt",
display_name="Agent Instructions",
Expand Down Expand Up @@ -86,71 +73,89 @@ def get_memory_data(self):
return MemoryComponent().set(**memory_kwargs).retrieve_messages()

def get_llm(self):
try:
if self.agent_llm == "OpenAI":
return self._build_llm_model(OpenAIModelComponent(), self.openai_inputs)
if self.agent_llm == "Azure OpenAI":
return self._build_llm_model(AzureChatOpenAIComponent(), self.azure_inputs, prefix="azure_param_")
except Exception as e:
msg = f"Error building {self.agent_llm} language model"
raise ValueError(msg) from e
if isinstance(self.agent_llm, str):
try:
provider_info = MODEL_PROVIDERS_DICT.get(self.agent_llm)
if provider_info:
component_class = provider_info.get("component_class")
inputs = provider_info.get("inputs")
prefix = provider_info.get("prefix", "")
return self._build_llm_model(component_class, inputs, prefix)
except Exception as e:
msg = f"Error building {self.agent_llm} language model"
raise ValueError(msg) from e
return self.agent_llm

def _build_llm_model(self, component, inputs, prefix=""):
return component.set(
**{component_input.name: getattr(self, f"{prefix}{component_input.name}") for component_input in inputs}
).build_model()
model_kwargs = {input_.name: getattr(self, f"{prefix}{input_.name}") for input_ in inputs}
return component.set(**model_kwargs).build_model()

def delete_fields(self, build_config, fields):
def delete_fields(self, build_config: dotdict, fields: dict | list[str]) -> None:
"""Delete specified fields from build_config."""
for field in fields:
build_config.pop(field, None)

def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None):
def update_input_types(self, build_config: dotdict) -> dotdict:
"""Update input types for all fields in build_config."""
for key, value in build_config.items():
if isinstance(value, dict):
if value.get("input_types") is None:
build_config[key]["input_types"] = []
elif hasattr(value, "input_types") and value.input_types is None:
value.input_types = []
return build_config

def update_build_config(self, build_config: dotdict, field_value: str, field_name: str | None = None) -> dotdict:
if field_name == "agent_llm":
openai_fields = {component_input.name: component_input for component_input in self.openai_inputs}
azure_fields = {
f"azure_param_{component_input.name}": component_input for component_input in self.azure_inputs
# Define provider configurations as (fields_to_add, fields_to_delete)
provider_configs: dict[str, tuple[dict, list[dict]]] = {
provider: (
MODEL_PROVIDERS_DICT[provider]["fields"],
[
MODEL_PROVIDERS_DICT[other_provider]["fields"]
for other_provider in MODEL_PROVIDERS_DICT
if other_provider != provider
],
)
for provider in MODEL_PROVIDERS_DICT
}

if field_value == "OpenAI":
self.delete_fields(build_config, {**azure_fields})
if not any(field in build_config for field in openai_fields):
build_config.update(openai_fields)
build_config["agent_llm"]["input_types"] = []
build_config = self.update_input_types(build_config)
if field_value in provider_configs:
fields_to_add, fields_to_delete = provider_configs[field_value]

# Delete fields from other providers
for fields in fields_to_delete:
self.delete_fields(build_config, fields)

elif field_value == "Azure OpenAI":
self.delete_fields(build_config, {**openai_fields})
build_config.update(azure_fields)
# Add provider-specific fields
if field_value == "OpenAI" and not any(field in build_config for field in fields_to_add):
build_config.update(fields_to_add)
else:
build_config.update(fields_to_add)
# Reset input types for agent_llm
build_config["agent_llm"]["input_types"] = []
build_config = self.update_input_types(build_config)
elif field_value == "Custom":
self.delete_fields(build_config, {**openai_fields})
self.delete_fields(build_config, {**azure_fields})
new_component = DropdownInput(
# Delete all provider fields
self.delete_fields(build_config, ALL_PROVIDER_FIELDS)
# Update with custom component
custom_component = DropdownInput(
name="agent_llm",
display_name="Language Model",
options=["Azure OpenAI", "OpenAI", "Custom"],
options=[*sorted(MODEL_PROVIDERS_DICT.keys()), "Custom"],
value="Custom",
real_time_refresh=True,
input_types=["LanguageModel"],
)
build_config.update({"agent_llm": new_component.to_dict()})
build_config = self.update_input_types(build_config)
build_config.update({"agent_llm": custom_component.to_dict()})

# Update input types for all fields
build_config = self.update_input_types(build_config)

# Validate required keys
default_keys = ["code", "_type", "agent_llm", "tools", "input_value"]
missing_keys = [key for key in default_keys if key not in build_config]
if missing_keys:
msg = f"Missing required keys in build_config: {missing_keys}"
raise ValueError(msg)
return build_config

def update_input_types(self, build_config):
for key, value in build_config.items():
# Check if the value is a dictionary
if isinstance(value, dict):
if value.get("input_types") is None:
build_config[key]["input_types"] = []
# Check if the value has an attribute 'input_types' and it is None
elif hasattr(value, "input_types") and value.input_types is None:
value.input_types = []
return build_config
Empty file.
25 changes: 0 additions & 25 deletions src/backend/tests/unit/base/models/test_model_constants.py

This file was deleted.

0 comments on commit ae7d037

Please sign in to comment.