From 04d13a81168d2142fb1f3faefb992f814fb3d040 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 7 Feb 2025 11:01:31 +0800 Subject: [PATCH] feat(credits): Allow to configure model-credit mapping (#13274) Signed-off-by: -LAN- --- .../feature/hosted_service/__init__.py | 35 ++++++++++++++++++- api/core/workflow/nodes/llm/node.py | 6 ++-- .../deduct_quota_when_message_created.py | 6 ++-- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 7dd47e3658134f..63aec890383f2e 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -1,9 +1,40 @@ from typing import Optional -from pydantic import Field, NonNegativeInt +from pydantic import Field, NonNegativeInt, computed_field from pydantic_settings import BaseSettings +class HostedCreditConfig(BaseSettings): + HOSTED_MODEL_CREDIT_CONFIG: str = Field( + description="Model credit configuration in format 'model:credits,model:credits', e.g., 'gpt-4:20,gpt-4o:10'", + default="", + ) + + def get_model_credits(self, model_name: str) -> int: + """ + Get credit value for a specific model name. + Returns 1 if model is not found in configuration (default credit). + + :param model_name: The name of the model to search for + :return: The credit value for the model + """ + if not self.HOSTED_MODEL_CREDIT_CONFIG: + return 1 + + try: + credit_map = dict( + item.strip().split(":", 1) for item in self.HOSTED_MODEL_CREDIT_CONFIG.split(",") if ":" in item + ) + + # Search for matching model pattern + for pattern, credit in credit_map.items(): + if pattern.strip() in model_name: + return int(credit) + return 1 # Default quota if no match found + except (ValueError, AttributeError): + return 1 # Return default quota if parsing fails + + class HostedOpenAiConfig(BaseSettings): """ Configuration for hosted OpenAI service @@ -202,5 +233,7 @@ class HostedServiceConfig( HostedZhipuAIConfig, # moderation HostedModerationConfig, + # credit config + HostedCreditConfig, ): pass diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 6a4f8c4e207bb2..7e28aa7a3ffb3d 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -3,6 +3,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast +from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit @@ -732,10 +733,7 @@ def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: if quota_unit == QuotaUnit.TOKENS: used_quota = usage.total_tokens elif quota_unit == QuotaUnit.CREDITS: - used_quota = 1 - - if "gpt-4" in model_instance.model: - used_quota = 20 + used_quota = dify_config.get_model_credits(model_instance.model) else: used_quota = 1 diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index 1ed37efba0b3be..d196a4862013b7 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -1,3 +1,4 @@ +from configs import dify_config from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from core.entities.provider_entities import QuotaUnit from events.message_event import message_was_created @@ -37,10 +38,7 @@ def handle(sender, **kwargs): if quota_unit == QuotaUnit.TOKENS: used_quota = message.message_tokens + message.answer_tokens elif quota_unit == QuotaUnit.CREDITS: - used_quota = 1 - - if "gpt-4" in model_config.model: - used_quota = 20 + used_quota = dify_config.get_model_credits(model_config.model) else: used_quota = 1