diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 88531d8ae00037..4523da438814fd 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -1,5 +1,5 @@ from decimal import Decimal -from enum import Enum +from enum import StrEnum from typing import Optional from pydantic import BaseModel @@ -8,7 +8,7 @@ from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo -class LLMMode(Enum): +class LLMMode(StrEnum): """ Enum class for large language model mode. """ diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 505068104c2c2d..d6fff2a793de04 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.entities import ImagePromptMessageContent +from core.model_runtime.entities.llm_entities import LLMMode from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.workflow.entities.variable_entities import VariableSelector from core.workflow.nodes.base import BaseNodeData @@ -12,7 +13,7 @@ class ModelConfig(BaseModel): provider: str name: str - mode: str + mode: LLMMode = LLMMode.COMPLETION completion_params: dict[str, Any] = {}