diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 2cd0b2a7d..fd98db24b 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -56,6 +56,12 @@ jobs: - name: Run Tool run: poetry run -C api bash dev/pytest/pytest_tools.sh + - name: Run mypy + run: | + pushd api + poetry run python -m mypy --install-types --non-interactive . + popd + - name: Set up dotenvs run: | cp docker/.env.example docker/.env diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index b5e63a887..12213380b 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -82,6 +82,33 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: yarn run lint + docker-compose-template: + name: Docker Compose Template + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Check changed files + id: changed-files + uses: tj-actions/changed-files@v45 + with: + files: | + docker/generate_docker_compose + docker/.env.example + docker/docker-compose-template.yaml + docker/docker-compose.yaml + + - name: Generate Docker Compose + if: steps.changed-files.outputs.any_changed == 'true' + run: | + cd docker + ./generate_docker_compose + + - name: Check for changes + if: steps.changed-files.outputs.any_changed == 'true' + run: git diff --exit-code superlinter: name: SuperLinter diff --git a/api/.env.example b/api/.env.example index 3e74c1276..4a82a52e9 100644 --- a/api/.env.example +++ b/api/.env.example @@ -23,6 +23,9 @@ FILES_ACCESS_TIMEOUT=300 # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 +# Refresh token expiration time in days +REFRESH_TOKEN_EXPIRE_DAYS=30 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 @@ -65,7 +68,7 @@ OPENDAL_FS_ROOT=storage # S3 Storage configuration S3_USE_AWS_MANAGED_IAM=false -S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com +S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com S3_BUCKET_NAME=your-bucket-name S3_ACCESS_KEY=your-access-key S3_SECRET_KEY=your-secret-key @@ -74,7 +77,7 @@ S3_REGION=your-region # Azure Blob Storage configuration AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_KEY=your-account-key -AZURE_BLOB_CONTAINER_NAME=yout-container-name +AZURE_BLOB_CONTAINER_NAME=your-container-name AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net # Aliyun oss Storage configuration @@ -88,7 +91,7 @@ ALIYUN_OSS_REGION=your-region ALIYUN_OSS_PATH=your-path # Google Storage configuration -GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name +GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string # Tencent COS Storage configuration diff --git a/api/.ruff.toml b/api/.ruff.toml index 26a1b977a..89a2da35d 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -67,7 +67,7 @@ ignore = [ "SIM105", # suppressible-exception "SIM107", # return-in-try-except-finally "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # eumerate-for-loop + "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements "SIM210", # if-expr-with-true-false ] @@ -85,11 +85,11 @@ ignore = [ ] "tests/*" = [ "F811", # redefined-while-unused - "F401", # unused-import ] [lint.pyflakes] -extend-generics = [ +allowed-unused-imports = [ "_pytest.monkeypatch", "tests.integration_tests", + "tests.unit_tests", ] diff --git a/api/Dockerfile b/api/Dockerfile index c6fc8af77..06309cda3 100644 --- a/api/Dockerfile +++ b/api/Dockerfile 
@@ -55,7 +55,7 @@ RUN apt-get update \ # && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ + && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.19+dfsg-1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ # install a chinese font to support the use of tools like matplotlib && apt-get install -y fonts-noto-cjk \ && apt-get autoremove -y \ diff --git a/api/app.py b/api/app.py index c6a082908..9830a8090 100644 --- a/api/app.py +++ b/api/app.py @@ -1,12 +1,8 @@ -from libs import version_utils - -# preparation before creating app -version_utils.check_supported_python_version() +import os +import sys def is_db_command(): - import sys - if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": return True return False @@ -18,10 +14,25 @@ def is_db_command(): app = create_migrations_app() else: - from app_factory import create_app - from libs import threadings_utils + # It seems that JetBrains Python debugger does not work well with gevent, + # so we need to disable gevent in debug mode. + # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent. + if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}: + from gevent import monkey # type: ignore + + # gevent + monkey.patch_all() + + from grpc.experimental import gevent as grpc_gevent # type: ignore - threadings_utils.apply_gevent_threading_patch() + # grpc gevent + grpc_gevent.init_gevent() + + import psycogreen.gevent # type: ignore + + psycogreen.gevent.patch_psycopg() + + from app_factory import create_app app = create_app() celery = app.extensions["celery"] diff --git a/api/commands.py b/api/commands.py index bf013cc77..334e7daab 100644 --- a/api/commands.py +++ b/api/commands.py @@ -159,8 +159,7 @@ def migrate_annotation_vector_database(): try: # get apps info apps = ( - db.session.query(App) - .filter(App.status == "normal") + App.query.filter(App.status == "normal") .order_by(App.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -285,8 +284,7 @@ def migrate_knowledge_vector_database(): while True: try: datasets = ( - db.session.query(Dataset) - .filter(Dataset.indexing_technique == "high_quality") + Dataset.query.filter(Dataset.indexing_technique == "high_quality") .order_by(Dataset.created_at.desc()) .paginate(page=page, per_page=50) ) @@ -450,7 +448,8 @@ def convert_to_agent_apps(): if app_id not in proceeded_app_ids: proceeded_app_ids.append(app_id) app = db.session.query(App).filter(App.id == app_id).first() - apps.append(app) + if app is not None: + apps.append(app) if len(apps) == 0: break @@ -562,8 +561,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str new_password = secrets.token_urlsafe(16) # register account - account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - + account = RegisterService.register( + email=email, + name=account_name, + password=new_password, + language=language, + create_workspace_required=False, + ) TenantService.create_owner_tenant_if_not_exist(account, name) click.echo( @@ -583,7 +587,7 @@ def upgrade_db(): click.echo(click.style("Starting database migration.", fg="green")) # run db migration - import flask_migrate + import flask_migrate # 
type: ignore flask_migrate.upgrade() @@ -621,6 +625,10 @@ def fix_app_site_missing(): try: app = db.session.query(App).filter(App.id == app_id).first() + if not app: + print(f"App {app_id} not found") + continue + tenant = app.tenant if tenant: accounts = tenant.get_accounts() diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 73f8a9598..59309fd25 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -239,7 +239,6 @@ class HttpConfig(BaseSettings): ) @computed_field - @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") @@ -250,7 +249,6 @@ def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: ) @computed_field - @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") @@ -490,6 +488,11 @@ class AuthConfig(BaseSettings): default=60, ) + REFRESH_TOKEN_EXPIRE_DAYS: PositiveFloat = Field( + description="Expiration time for refresh tokens in days", + default=30, + ) + LOGIN_LOCKOUT_DURATION: PositiveInt = Field( description="Time (in seconds) a user must wait before retrying login after exceeding the rate limit.", default=86400, @@ -603,7 +606,7 @@ class RagEtlConfig(BaseSettings): UNSTRUCTURED_API_KEY: Optional[str] = Field( description="API key for Unstructured.io service", - default=None, + default="", ) SCARF_NO_ANALYTICS: Optional[str] = Field( @@ -669,6 +672,11 @@ class IndexingConfig(BaseSettings): default=4000, ) + CHILD_CHUNKS_PREVIEW_NUMBER: PositiveInt = Field( + description="Maximum number of child chunks to preview", + default=50, + ) + class MultiModalTransferConfig(BaseSettings): MULTIMODAL_SEND_FORMAT: Literal["base64", "url"] = Field( @@ -715,27 +723,27 @@ class PositionConfig(BaseSettings): default="", ) - @computed_field + @property def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_PINS_LIST(self) -> list[str]: return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] - @computed_field + @property def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} - @computed_field + @property def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} @@ -767,6 +775,13 @@ class LoginConfig(BaseSettings): ) +class AccountConfig(BaseSettings): + ACCOUNT_DELETION_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a account deletion token remains valid", + default=5, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, @@ -794,6 +809,7 @@ class FeatureConfig( WorkflowNodeExecutionConfig, WorkspaceConfig, LoginConfig, + AccountConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 9265a48d9..f6a44eaa4 100644 --- 
a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -130,7 +130,6 @@ class DatabaseConfig(BaseSettings): ) @computed_field - @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS @@ -168,7 +167,6 @@ def SQLALCHEMY_DATABASE_URI(self) -> str: ) @computed_field - @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: return { "pool_size": self.SQLALCHEMY_POOL_SIZE, @@ -206,7 +204,6 @@ class CeleryConfig(DatabaseConfig): ) @computed_field - @property def CELERY_RESULT_BACKEND(self) -> str | None: return ( "db+{}".format(self.SQLALCHEMY_DATABASE_URI) @@ -214,7 +211,6 @@ def CELERY_RESULT_BACKEND(self) -> str | None: else self.CELERY_BROKER_URL ) - @computed_field @property def BROKER_USE_SSL(self) -> bool: return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 231cbbbe8..ebdf8857b 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings): description="Name of the Milvus database to connect to (default is 'default')", default="default", ) + + MILVUS_ENABLE_HYBRID_SEARCH: bool = Field( + description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with " + "older versions", + default=True, + ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 4a168a3fb..a54c5bf5e 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.14.2", + default="0.15.1", ) COMMIT_SHA: str = Field( diff --git a/api/configs/remote_settings_sources/apollo/client.py b/api/configs/remote_settings_sources/apollo/client.py index d1f6781ed..03c64ea00 100644 --- a/api/configs/remote_settings_sources/apollo/client.py +++ b/api/configs/remote_settings_sources/apollo/client.py @@ -4,6 +4,7 @@ import os import threading import time +from collections.abc import Mapping from pathlib import Path from .python_3x import http_request, makedirs_wrapper @@ -255,8 +256,8 @@ def _listener(self): logger.info("stopped, long_poll") # add the need for endorsement to the header - def _sign_headers(self, url): - headers = {} + def _sign_headers(self, url: str) -> Mapping[str, str]: + headers: dict[str, str] = {} if self.secret == "": return headers uri = url[len(self.config_url) : len(url)] diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 7e1a19635..c26d8c018 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,8 +1,9 @@ import json +from collections.abc import Mapping from models.model import AppMode -default_app_templates = { +default_app_templates: Mapping[AppMode, Mapping] = { # workflow default mode AppMode.WORKFLOW: { "app": { diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 79869916e..b1ebc444a 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore parameters__system_parameters = { "image_file_size_limit": fields.Integer, diff --git a/api/controllers/console/__init__.py 
b/api/controllers/console/__init__.py index ad68993fe..d58d9b403 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -3,6 +3,25 @@ from libs.external_api import ExternalApi from .app.app_import import AppImportApi, AppImportConfirmApi +from .explore.audio import ChatAudioApi, ChatTextApi +from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi +from .explore.conversation import ( + ConversationApi, + ConversationListApi, + ConversationPinApi, + ConversationRenameApi, + ConversationUnPinApi, +) +from .explore.message import ( + MessageFeedbackApi, + MessageListApi, + MessageMoreLikeThisApi, + MessageSuggestedQuestionApi, +) +from .explore.workflow import ( + InstalledAppWorkflowRunApi, + InstalledAppWorkflowTaskStopApi, +) from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi @@ -69,15 +88,81 @@ # Import explore controllers from .explore import ( - audio, - completion, - conversation, installed_app, - message, parameter, recommended_app, saved_message, - workflow, +) + +# Explore Audio +api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") + +# Explore Completion +api.add_resource( + CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) + +# Explore Conversation +api.add_resource( + ConversationRenameApi, + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) + + +# Explore Message +api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) +# Explore Workflow +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" ) # Import tag controllers diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 8c0bf8710..52e0bb6c5 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from 
flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 4b428d6aa..09d535298 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,6 +1,8 @@ -import flask_restful +from typing import Any + +import flask_restful # type: ignore from flask import request # 二开部分 - 密钥额度限制 -from flask_login import current_user +from flask_login import current_user # type: ignore from flask_restful import Resource, fields, marshal_with from sqlalchemy.orm import aliased # 二开部分 - 密钥额度限制 from werkzeug.exceptions import Forbidden @@ -46,14 +48,15 @@ def _get_resource(resource_id, tenant_id, resource_model): class BaseApiKeyListResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None - token_prefix = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None + token_prefix: str | None = None max_keys = 10 @marshal_with(api_key_list) def get(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) # keys = ( @@ -85,6 +88,7 @@ def get(self, resource_id): @marshal_with(api_key_fields) def post(self, resource_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_admin_or_owner: @@ -143,7 +147,7 @@ def post(self, resource_id): # --------------------- 二开部分End - 密钥额度限制 --------------------- return api_token, 201 - + # --------------------- 二开部分Begin - 密钥额度限制 --------------------- @marshal_with(api_key_fields) def put(self, resource_id): @@ -152,7 +156,7 @@ def put(self, resource_id): if not current_user.is_admin_or_owner: raise Forbidden() - + content_type = request.headers.get("Content-Type") if content_type == "application/json": try: @@ -212,11 +216,12 @@ def put(self, resource_id): class BaseApiKeyResource(Resource): method_decorators = [account_initialization_required, login_required, setup_required] - resource_type = None - resource_model = None - resource_id_field = None + resource_type: str | None = None + resource_model: Any = None + resource_id_field: str | None = None def delete(self, resource_id, api_key_id): + assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) api_key_id = str(api_key_id) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c228743fa..8d0c5b84a 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index d43341589..920cae0d8 100644 --- 
a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index fd05cbc19..24f1020c1 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -110,7 +110,7 @@ def get(self, app_id): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - keyword = request.args.get("keyword", default=None, type=str) + keyword = request.args.get("keyword", default="", type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 6a5d3903a..e0997c844 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,8 +1,8 @@ import uuid from typing import cast -from flask_login import current_user -from flask_restful import Resource, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, abort @@ -57,12 +57,13 @@ def uuid_list(value): ) parser.add_argument("name", type=str, location="args", required=False) parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) + parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False) args = parser.parse_args() # get app list app_service = AppService() - app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) + app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args) if not app_pagination: # ---------------- start app list return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False, "recommended_apps": []} diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 244dcd75d..7e2888d71 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,7 @@ from typing import cast -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 695b8890e..12d9157dd 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import 
Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services @@ -22,7 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required -from models.model import AppMode +from models import App, AppMode from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource): @login_required @account_initialization_required @get_app_model - def post(self, app_model): + def post(self, app_model: App): from werkzeug.exceptions import InternalServerError try: @@ -98,9 +98,13 @@ def post(self, app_model): and app_model.workflow.features_dict ): text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + if text_to_speech is None: + raise ValueError("TTS is not enabled") voice = args.get("voice") or text_to_speech.get("voice") else: try: + if app_model.app_model_config is None: + raise ValueError("AppModelConfig not found") voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 52d254323..630ec9bfe 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,7 +1,7 @@ import logging -import flask_login -from flask_restful import Resource, reqparse +import flask_login # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services @@ -21,7 +21,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( - AppInvokeQuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError, @@ -78,7 +77,7 @@ def post(self, app_model): raise ProviderModelCurrentlyNotSupportError() except InvokeError as e: raise CompletionRequestError(e.description) - except (ValueError, AppInvokeQuotaExceededError) as e: + except ValueError as e: raise e except Exception as e: logging.exception("internal server error.") @@ -144,7 +143,7 @@ def post(self, app_model): raise InvokeRateLimitHttpError(ex.description) except InvokeError as e: raise CompletionRequestError(e.description) - except (ValueError, AppInvokeQuotaExceededError) as e: + except ValueError as e: raise e except Exception as e: logging.exception("internal server error.") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index a25004be4..8827f129d 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,9 +1,9 @@ from datetime import UTC, datetime -import pytz -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +import pytz # pip install pytz +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy import func, or_ from sqlalchemy.orm import joinedload from werkzeug.exceptions import Forbidden, NotFound @@ -77,8 +77,9 @@ def get(self, app_model): query = query.where(Conversation.created_at < 
end_datetime_utc) + # FIXME, the type ignore in this file if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -222,7 +223,7 @@ def get(self, app_model): query = query.where(Conversation.created_at <= end_datetime_utc) if args["annotation_status"] == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( + query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) elif args["annotation_status"] == "not_annotated": @@ -234,7 +235,7 @@ def get(self, app_model): if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( - query.options(joinedload(Conversation.messages)) + query.options(joinedload(Conversation.messages)) # type: ignore .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) .having(func.count(Message.id) >= args["message_count_gte"]) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index d49f433ba..c0a20b716 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse # type: ignore from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 9c3cbe4e3..8518d34a8 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,7 @@ import os -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.error import ( diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b7a4c31a1..b5828b6b4 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,8 +1,8 @@ import logging -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a46bc6a8a..8ecc8a9db 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,8 +1,9 @@ import json +from typing import cast from flask import request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -26,7 +27,9 
@@ def post(self, app_model): """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, + config=cast(dict, request.json), + app_mode=AppMode.value_of(app_model.mode), ) new_app_model_config = AppModelConfig( @@ -38,9 +41,11 @@ def post(self, app_model): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = ( + original_app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() ) + if original_app_model_config is None: + raise ValueError("Original app model config not found") agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 3f10215e7..dd25af8eb 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import BadRequest from controllers.console import api diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 407f68981..db29b95c4 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,7 +1,7 @@ from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language @@ -50,7 +50,7 @@ def post(self, app_model): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() + site = Site.query.filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ "title", diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 8824448d4..10acbd1ce 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -3,8 +3,8 @@ import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model @@ -302,8 +302,7 @@ def get(self, app_model): messages m ON c.id = m.conversation_id WHERE - c.override_model_configs IS NULL - AND c.app_id = :app_id""" + c.app_id = :app_id""" arg_dict = {"tz": account.timezone, "app_id": app_model.id} # Extend Start: added a new app personal expenses page diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 9aebf0d6b..a8cec9265 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -2,7 +2,7 @@ import logging from flask import abort, request -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from 
werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -15,7 +15,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from factories import variable_factory -from fields.workflow_fields import workflow_fields +from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper from libs.helper import TimestampField, uuid_value @@ -443,6 +443,31 @@ def get(self, app_model: App): } +class PublishedAllWorkflowApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @marshal_with(workflow_pagination_fields) + def get(self, app_model: App): + """ + Get published workflows + """ + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + args = parser.parse_args() + page = args.get("page") + limit = args.get("limit") + workflow_service = WorkflowService() + workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit) + + return {"items": workflows, "page": page, "limit": limit, "has_more": has_more} + + api.add_resource(DraftWorkflowApi, "/apps//workflows/draft") api.add_resource(WorkflowConfigApi, "/apps//workflows/draft/config") api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps//advanced-chat/workflows/draft/run") @@ -457,6 +482,7 @@ def get(self, app_model: App): WorkflowDraftRunIterationNodeApi, "/apps//workflows/draft/iteration/nodes//run" ) api.add_resource(PublishedWorkflowApi, "/apps//workflows/publish") +api.add_resource(PublishedAllWorkflowApi, "/apps//workflows") api.add_resource(DefaultBlockConfigsApi, "/apps//workflows/default-workflow-block-configs") api.add_resource( DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs/" diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 2940556f8..882c53e4f 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 08ab61bbb..25a99c1e1 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index d3d36ee09..e79672595 100644 --- 
a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -3,8 +3,8 @@ import pytz from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.app.wraps import get_app_model diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 63edb8307..9ad8c1584 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,7 +8,7 @@ from models import App, AppMode -def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index d2aa7c903..c56f551d4 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,14 +1,14 @@ import datetime from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db from libs.helper import StrLen, email, extract_remote_ip, timezone -from models.account import AccountStatus, Tenant +from models.account import AccountStatus from services.account_service import AccountService, RegisterService @@ -27,7 +27,7 @@ def get(self): invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) if invitation: data = invitation.get("data", {}) - tenant: Tenant = invitation.get("tenant", None) + tenant = invitation.get("tenant", None) workspace_name = tenant.name if tenant else None workspace_id = tenant.id if tenant else None invitee_email = data.get("email") if data else None diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 465c44e9b..ea00c2b8c 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index faca67bb1..e911c9a5e 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -2,8 +2,8 @@ import requests from flask import current_app, redirect, request -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config @@ -17,8 +17,8 @@ def get_oauth_providers(): with current_app.app_context(): notion_oauth = NotionOAuth( - client_id=dify_config.NOTION_CLIENT_ID, - 
client_secret=dify_config.NOTION_CLIENT_SECRET, + client_id=dify_config.NOTION_CLIENT_ID or "", + client_secret=dify_config.NOTION_CLIENT_SECRET or "", redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", ) diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index e6e30c3c0..8ef10c7bb 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -53,3 +53,9 @@ class EmailCodeLoginRateLimitExceededError(BaseHTTPException): error_code = "email_code_login_rate_limit_exceeded" description = "Too many login emails have been sent. Please try again in 5 minutes." code = 429 + + +class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException): + error_code = "email_code_account_deletion_rate_limit_exceeded" + description = "Too many account deletion emails have been sent. Please try again in 5 minutes." + code = 429 diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index fb32bb2b6..a9c4300b9 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -2,17 +2,12 @@ import secrets from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from constants.languages import languages from controllers.console import api -from controllers.console.auth.error import ( - EmailCodeError, - InvalidEmailError, - InvalidTokenError, - PasswordMismatchError, -) -from controllers.console.error import AccountNotFound, EmailSendIpLimitError +from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError +from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError from controllers.console.wraps import setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db @@ -20,6 +15,7 @@ from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService +from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError from services.feature_service import FeatureService @@ -122,13 +118,15 @@ def post(self): else: try: account = AccountService.create_account_and_tenant( - email=reset_data.get("email"), - name=reset_data.get("email"), + email=reset_data.get("email", ""), + name=reset_data.get("email", ""), password=password_confirm, interface_language=languages[0], ) except WorkSpaceNotAllowedCreateError: pass + except AccountRegisterError as are: + raise AccountInFreezeError() return {"result": "success"} diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index f4463ce9c..41362e9fa 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,10 +1,11 @@ from typing import cast -import flask_login +import flask_login # type: ignore from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore import services +from configs import dify_config from constants.languages import languages from controllers.console import api from controllers.console.auth.error import ( @@ -16,6 +17,7 @@ ) from controllers.console.error import ( AccountBannedError, + AccountInFreezeError, AccountNotFound, 
EmailSendIpLimitError, NotAllowedCreateWorkspace, @@ -26,6 +28,8 @@ from libs.password import valid_password from models.account import Account from services.account_service import AccountService, RegisterService, TenantService +from services.billing_service import BillingService +from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError from services.feature_service import FeatureService @@ -44,6 +48,9 @@ def post(self): parser.add_argument("language", type=str, required=False, default="en-US", location="json") args = parser.parse_args() + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]): + raise AccountInFreezeError() + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) if is_login_error_rate_limit: raise EmailPasswordLoginLimitError() @@ -113,8 +120,10 @@ def post(self): language = "zh-Hans" else: language = "en-US" - - account = AccountService.get_user_through_email(args["email"]) + try: + account = AccountService.get_user_through_email(args["email"]) + except AccountRegisterError as are: + raise AccountInFreezeError() if account is None: if FeatureService.get_system_features().is_allow_register: token = AccountService.send_reset_password_email(email=args["email"], language=language) @@ -142,8 +151,11 @@ def post(self): language = "zh-Hans" else: language = "en-US" + try: + account = AccountService.get_user_through_email(args["email"]) + except AccountRegisterError as are: + raise AccountInFreezeError() - account = AccountService.get_user_through_email(args["email"]) if account is None: if FeatureService.get_system_features().is_allow_register: token = AccountService.send_email_code_login_email(email=args["email"], language=language) @@ -177,7 +189,10 @@ def post(self): raise EmailCodeError() AccountService.revoke_email_code_login_token(args["token"]) - account = AccountService.get_user_through_email(user_email) + try: + account = AccountService.get_user_through_email(user_email) + except AccountRegisterError as are: + raise AccountInFreezeError() if account: tenant = TenantService.get_join_tenants(account) if not tenant: @@ -196,6 +211,8 @@ def post(self): ) except WorkSpaceNotAllowedCreateError: return NotAllowedCreateWorkspace() + except AccountRegisterError as are: + raise AccountInFreezeError() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 45e262346..27a802de4 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,7 @@ import requests from flask import current_app, redirect, request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -17,7 +17,7 @@ from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService from services.account_service_extend import TenantExtendService -from services.errors.account import AccountNotFoundError +from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.feature_service import FeatureService @@ -87,7 +87,8 @@ def 
get(self, provider: str): token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) except requests.exceptions.RequestException as e: - logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + error_text = e.response.text if e.response else str(e) + logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}") return {"error": "OAuth process failed"}, 400 if invite_token and RegisterService.is_valid_invite_token(invite_token): @@ -108,6 +109,8 @@ def get(self, provider: str): f"{dify_config.CONSOLE_WEB_URL}/signin" "?message=Workspace not found, please contact system admin to invite you to join in a workspace." ) + except AccountRegisterError as e: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}") # Check account status if account.status == AccountStatus.BANNED.value: @@ -151,7 +154,7 @@ def get(self, provider: str): def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: - account = Account.get_by_openid(provider, user_info.id) + account: Optional[Account] = Account.get_by_openid(provider, user_info.id) if not account: account = Account.query.filter_by(email=user_info.email).first() diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 4b0c82ae6..fd7b7bd8c 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 278295ca3..3a4a6d75e 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -2,8 +2,8 @@ import json from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api @@ -218,7 +218,7 @@ def post(self): args["doc_form"], args["doc_language"], ) - return response, 200 + return response.model_dump(), 200 class DataSourceNotionDatasetSyncApi(Resource): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 95d4013e3..386e45c58 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,7 +1,7 @@ -import flask_restful +import flask_restful # type: ignore from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore # type: ignore +from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services @@ -52,12 +52,12 @@ def get(self): # provider = request.args.get("provider", default="vendor") search = request.args.get("keyword", default=None, type=str) tag_ids = request.args.getlist("tag_ids") - + 
include_all = request.args.get("include_all", default="false").lower() == "true" if ids: datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) else: datasets, total = DatasetService.get_datasets( - page, limit, current_user.current_tenant_id, current_user, search, tag_ids + page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all ) # check embedding setting @@ -464,7 +464,7 @@ def post(self): except Exception as e: raise IndexingEstimateError(str(e)) - return response, 200 + return response.model_dump(), 200 class DatasetRelatedAppListApi(Resource): @@ -640,6 +640,7 @@ def get(self): | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH + | VectorType.ELASTICSEARCH_JA | VectorType.PGVECTOR | VectorType.TIDB_ON_QDRANT | VectorType.LINDORM @@ -683,6 +684,7 @@ def get(self, vector_type): | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH + | VectorType.ELASTICSEARCH_JA | VectorType.COUCHBASE | VectorType.PGVECTOR | VectorType.LINDORM @@ -733,6 +735,18 @@ def get(self, dataset_id): }, 200 +class DatasetAutoDisableLogApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 + + api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetApi, "/datasets/") api.add_resource(DatasetUseCheckApi, "/datasets//use-check") @@ -747,3 +761,4 @@ def get(self, dataset_id): api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") +api.add_resource(DatasetAutoDisableLogApi, "/datasets//auto-disable-logs") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ad4768f51..c11beaeee 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,12 +1,13 @@ import logging from argparse import ArgumentTypeError from datetime import UTC, datetime +from typing import cast from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore from sqlalchemy import asc, desc -from transformers.hf_argparser import string_to_bool +from transformers.hf_argparser import string_to_bool # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services @@ -51,6 +52,7 @@ from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -254,20 +256,23 @@ def post(self, dataset_id): parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") parser.add_argument("original_document_id", type=str, 
required=False, location="json") parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") parser.add_argument( "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() + knowledge_config = KnowledgeConfig(**args) - if not dataset.indexing_technique and not args["indexing_technique"]: + if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: - documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -277,6 +282,25 @@ def post(self, dataset_id): return {"documents": documents, "batch": batch} + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id): + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + + try: + document_ids = request.args.getlist("document_id") + DocumentService.delete_documents(dataset, document_ids) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError("Cannot delete document during indexing.") + + return {"result": "success"}, 204 + class DatasetInitApi(Resource): @setup_required @@ -312,9 +336,9 @@ def post(self): # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - - if args["indexing_technique"] == "high_quality": - if args["embedding_model"] is None or args["embedding_model_provider"] is None: + knowledge_config = KnowledgeConfig(**args) + if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: model_manager = ModelManager() @@ -333,11 +357,11 @@ def post(self): raise ProviderNotInitializeError(ex.description) # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, document_data=args, account=current_user + tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -390,7 +414,7 @@ def get(self, dataset_id, document_id): indexing_runner = IndexingRunner() try: - response = 
indexing_runner.indexing_estimate( + estimate_response = indexing_runner.indexing_estimate( current_user.current_tenant_id, [extract_setting], data_process_rule_dict, @@ -398,6 +422,7 @@ def get(self, dataset_id, document_id): "English", dataset_id, ) + return estimate_response.model_dump(), 200 except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " @@ -408,7 +433,7 @@ def get(self, dataset_id, document_id): except Exception as e: raise IndexingEstimateError(str(e)) - return response + return response, 200 class DocumentBatchIndexingEstimateApi(DocumentResource): @@ -419,9 +444,8 @@ def get(self, dataset_id, batch): dataset_id = str(dataset_id) batch = str(batch) documents = self.get_batch_documents(dataset_id, batch) - response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: - return response + return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 data_process_rule = documents[0].dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() info_list = [] @@ -499,6 +523,7 @@ def get(self, dataset_id, batch): "English", dataset_id, ) + return response.model_dump(), 200 except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " @@ -508,7 +533,6 @@ def get(self, dataset_id, batch): raise ProviderNotInitializeError(ex.description) except Exception as e: raise IndexingEstimateError(str(e)) - return response class DocumentBatchIndexingStatusApi(DocumentResource): @@ -581,7 +605,8 @@ def get(self, dataset_id, document_id): if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} elif metadata == "without": - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -589,7 +614,8 @@ def get(self, dataset_id, document_id): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -612,7 +638,8 @@ def get(self, dataset_id, document_id): "doc_language": document.doc_language, } else: - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -620,7 +647,8 @@ def get(self, dataset_id, document_id): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -733,8 +761,7 @@ def put(self, dataset_id, document_id): if not isinstance(doc_metadata, dict): raise ValueError("doc_metadata must be 
a dictionary.") - - metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] + metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]) document.doc_metadata = {} if doc_type == "others": @@ -757,9 +784,8 @@ class DocumentStatusApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, action): dataset_id = str(dataset_id) - document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") @@ -774,84 +800,79 @@ def patch(self, dataset_id, document_id, action): # check user's permission DatasetService.check_dataset_permission(dataset, current_user) - document = self.get_document(dataset_id, document_id) + document_ids = request.args.getlist("document_id") + for document_id in document_ids: + document = self.get_document(dataset_id, document_id) - indexing_cache_key = "document_{}_indexing".format(document.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Document is being indexed, please try again later") + indexing_cache_key = "document_{}_indexing".format(document.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") - if action == "enable": - if document.enabled: - raise InvalidActionError("Document already enabled.") + if action == "enable": + if document.enabled: + continue + document.enabled = True + document.disabled_at = None + document.disabled_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + add_document_to_index_task.delay(document_id) - add_document_to_index_task.delay(document_id) + elif action == "disable": + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError(f"Document: {document.name} is not completed.") + if not document.enabled: + continue - return {"result": "success"}, 200 + document.enabled = False + document.disabled_at = datetime.now(UTC).replace(tzinfo=None) + document.disabled_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise InvalidActionError("Document is not completed.") - if not document.enabled: - raise InvalidActionError("Document already disabled.") + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - document.enabled = False - document.disabled_at = datetime.now(UTC).replace(tzinfo=None) - document.disabled_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + remove_document_from_index_task.delay(document_id) - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + elif 
action == "archive": + if document.archived: + continue - remove_document_from_index_task.delay(document_id) + document.archived = True + document.archived_at = datetime.now(UTC).replace(tzinfo=None) + document.archived_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - return {"result": "success"}, 200 + if document.enabled: + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - elif action == "archive": - if document.archived: - raise InvalidActionError("Document already archived.") + remove_document_from_index_task.delay(document_id) - document.archived = True - document.archived_at = datetime.now(UTC).replace(tzinfo=None) - document.archived_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + elif action == "un_archive": + if not document.archived: + continue + document.archived = False + document.archived_at = None + document.archived_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - if document.enabled: # Set cache to prevent indexing the same document multiple times redis_client.setex(indexing_cache_key, 600, 1) - remove_document_from_index_task.delay(document_id) - - return {"result": "success"}, 200 - elif action == "un_archive": - if not document.archived: - raise InvalidActionError("Document is not archived.") - - document.archived = False - document.archived_at = None - document.archived_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + add_document_to_index_task.delay(document_id) - add_document_to_index_task.delay(document_id) - - return {"result": "success"}, 200 - else: - raise InvalidActionError() + else: + raise InvalidActionError() + return {"result": "success"}, 200 class DocumentPauseApi(DocumentResource): @@ -1022,7 +1043,7 @@ def get(self, dataset_id, document_id): ) api.add_resource(DocumentDeleteApi, "/datasets//documents/") api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") -api.add_resource(DocumentStatusApi, "/datasets//documents//status/") +api.add_resource(DocumentStatusApi, "/datasets//documents/status//batch") api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") api.add_resource(DocumentRetryApi, "/datasets//retry") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 6f7ef86d2..d48dbe177 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,16 +1,21 @@ import uuid -from datetime import UTC, datetime import pandas as pd from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError -from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError +from controllers.console.datasets.error import ( + 
ChildChunkDeleteIndexError, + ChildChunkIndexingError, + InvalidActionError, + NoFileUploadedError, + TooManyFilesError, +) from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, @@ -20,15 +25,15 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.segment_fields import segment_fields +from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import login_required -from models import DocumentSegment +from models.dataset import ChildChunk, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs +from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError +from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task -from tasks.disable_segment_from_index_task import disable_segment_from_index_task -from tasks.enable_segment_to_index_task import enable_segment_to_index_task class DatasetDocumentSegmentListApi(Resource): @@ -53,15 +58,16 @@ def get(self, dataset_id, document_id): raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument("last_id", type=str, default=None, location="args") parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument("hit_count_gte", type=int, default=None, location="args") parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + args = parser.parse_args() - last_id = args["last_id"] + page = args["page"] limit = min(args["limit"], 100) status_list = args["status"] hit_count_gte = args["hit_count_gte"] @@ -69,14 +75,7 @@ def get(self, dataset_id, document_id): query = DocumentSegment.query.filter( DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ) - - if last_id is not None: - last_segment = db.session.get(DocumentSegment, str(last_id)) - if last_segment: - query = query.filter(DocumentSegment.position > last_segment.position) - else: - return {"data": [], "has_more": False, "limit": limit}, 200 + ).order_by(DocumentSegment.position.asc()) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -93,21 +92,44 @@ def get(self, dataset_id, document_id): elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - total = query.count() - segments = query.order_by(DocumentSegment.position).limit(limit + 1).all() - - has_more = False - if len(segments) > limit: - has_more = True - segments = segments[:-1] + segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) - return { - "data": marshal(segments, segment_fields), - "doc_form": document.doc_form, - "has_more": has_more, + response = { + "data": marshal(segments.items, segment_fields), "limit": limit, - "total": total, - }, 200 + "total": 
segments.total, + "total_pages": segments.pages, + "page": page, + } + return response, 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + segment_ids = request.args.getlist("segment_id") + + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + SegmentService.delete_segments(segment_ids, document, dataset) + return {"result": "success"}, 200 class DatasetDocumentSegmentApi(Resource): @@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, segment_id, action): + def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -147,59 +173,17 @@ def patch(self, dataset_id, segment_id, action): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) + segment_ids = request.args.getlist("segment_id") - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() - - if not segment: - raise NotFound("Segment not found.") - - if segment.status != "completed": - raise NotFound("Segment is not completed, enable or disable function is not allowed") - - document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - - indexing_cache_key = "segment_{}_indexing".format(segment.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Segment is being indexed, please try again later") - - if action == "enable": - if segment.enabled: - raise InvalidActionError("Segment is already enabled.") - - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - enable_segment_to_index_task.delay(segment.id) - - return {"result": "success"}, 200 - elif action == "disable": - if not segment.enabled: - raise InvalidActionError("Segment is already disabled.") - - segment.enabled = False - segment.disabled_at = 
datetime.now(UTC).replace(tzinfo=None) - segment.disabled_by = current_user.id - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - disable_segment_from_index_task.delay(segment.id) - - return {"result": "success"}, 200 - else: - raise InvalidActionError() + try: + SegmentService.update_segments_status(segment_ids, action, dataset, document) + except Exception as e: + raise InvalidActionError(str(e)) + return {"result": "success"}, 200 class DatasetDocumentSegmentAddApi(Resource): @@ -307,9 +291,12 @@ def patch(self, dataset_id, document_id, segment_id): parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") + parser.add_argument( + "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" + ) args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(args, segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -381,9 +368,9 @@ def post(self, dataset_id, document_id): result = [] for index, row in df.iterrows(): if document.doc_form == "qa_model": - data = {"content": row[0], "answer": row[1]} + data = {"content": row.iloc[0], "answer": row.iloc[1]} else: - data = {"content": row[0]} + data = {"content": row.iloc[0]} result.append(data) if len(result) == 0: raise ValueError("The CSV file is empty.") @@ -412,8 +399,248 @@ def get(self, job_id): return {"job_id": job_id, "job_status": cache_result.decode()}, 200 +class ChildChunkAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") + def post(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + if not current_user.is_editor: + raise Forbidden() + # check embedding model setting + if dataset.indexing_technique == "high_quality": + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + parser = reqparse.RequestParser() + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + + args = parser.parse_args() + + page = args["page"] + limit = min(args["limit"], 100) + keyword = args["keyword"] + + child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) + return { + "data": marshal(child_chunks.items, child_chunk_fields), + "total": child_chunks.total, + "total_pages": child_chunks.pages, + "page": page, + "limit": limit, + }, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + chunks = [ChildChunkUpdateArgs(**chunk) for chunk in 
args.get("chunks")] + child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunks, child_chunk_fields)}, 200 + + +class ChildChunkUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + try: + SegmentService.delete_child_chunk(child_chunk, dataset) + except ChildChunkDeleteIndexServiceError as e: + raise ChildChunkDeleteIndexError(str(e)) + return {"result": "success"}, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.update_child_chunk( + 
args.get("content"), child_chunk, segment, document, dataset + ) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") -api.add_resource(DatasetDocumentSegmentApi, "/datasets//segments//") +api.add_resource( + DatasetDocumentSegmentApi, "/datasets//documents//segment/" +) api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") api.add_resource( DatasetDocumentSegmentUpdateApi, @@ -424,3 +651,11 @@ def get(self, job_id): "/datasets//documents//segments/batch_import", "/datasets/batch_import_status/", ) +api.add_resource( + ChildChunkAddApi, + "/datasets//documents//segments//child_chunks", +) +api.add_resource( + ChildChunkUpdateApi, + "/datasets//documents//segments//child_chunks/", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 6a7a3971a..2f00a84de 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException): error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 + + +class ChildChunkIndexingError(BaseHTTPException): + error_code = "child_chunk_indexing_error" + description = "Create child chunk index failed: {message}" + code = 500 + + +class ChildChunkDeleteIndexError(BaseHTTPException): + error_code = "child_chunk_delete_index_error" + description = "Delete child chunk index failed: {message}" + code = 500 diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index bc6e3687c..48f360dcd 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 495f51127..18b746f54 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 3b4c07686..bd944602c 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services.dataset_service diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 9127c8af4..da995537e 100644 --- a/api/controllers/console/datasets/website.py +++ 
b/api/controllers/console/datasets/website.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 1b4e6deae..ee87138a4 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -92,3 +92,12 @@ class UnauthorizedAndForceLogout(BaseHTTPException): error_code = "unauthorized_and_force_logout" description = "Unauthorized and force logout." code = 401 + + +class AccountInFreezeError(BaseHTTPException): + error_code = "account_in_freeze" + code = 400 + description = ( + "This email account has been deleted within the past 30 days" + "and is temporarily unavailable for new account registration." + ) diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 9690677f6..c7f9fec32 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -4,7 +4,6 @@ from werkzeug.exceptions import InternalServerError import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, AudioTooLargeError, @@ -67,7 +66,7 @@ def post(self, installed_app): class ChatTextApi(InstalledAppResource): def post(self, installed_app): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore app_model = installed_app.app try: @@ -118,9 +117,3 @@ def post(self, installed_app): except Exception as e: logging.exception("internal server error.") raise InternalServerError() - - -api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") -api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") -# api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id', -# endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 2e24b9c91..394e029ed 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,12 +1,11 @@ import logging from datetime import UTC, datetime -from flask_login import current_user -from flask_restful import reqparse +from flask_login import current_user # type: ignore +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppUnavailableError, CompletionRequestError, @@ -20,7 +19,11 @@ from controllers.console.money_extend import money_limit from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper @@ -163,21 +166,3 @@ def post(self, installed_app, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"}, 200 - - -api.add_resource( - CompletionApi, "/installed-apps//completion-messages", 
endpoint="installed_app_completion" -) -api.add_resource( - CompletionStopApi, - "/installed-apps//completion-messages//stop", - endpoint="installed_app_stop_completion", -) -api.add_resource( - ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" -) -api.add_resource( - ChatStopApi, - "/installed-apps//chat-messages//stop", - endpoint="installed_app_stop_chat_completion", -) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 5e7a3da01..600e78e09 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,10 +1,9 @@ -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom @@ -33,7 +32,7 @@ def get(self, installed_app): pinned = None if "pinned" in args and args["pinned"] is not None: - pinned = True if args["pinned"] == "true" else False + pinned = args["pinned"] == "true" try: with Session(db.engine) as session: @@ -118,28 +117,3 @@ def patch(self, installed_app, c_id): WebConversationService.unpin(app_model, conversation_id, current_user) return {"result": "success"} - - -api.add_resource( - ConversationRenameApi, - "/installed-apps//conversations//name", - endpoint="installed_app_conversation_rename", -) -api.add_resource( - ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" -) -api.add_resource( - ConversationApi, - "/installed-apps//conversations/", - endpoint="installed_app_conversation", -) -api.add_resource( - ConversationPinApi, - "/installed-apps//conversations//pin", - endpoint="installed_app_conversation_pin", -) -api.add_resource( - ConversationUnPinApi, - "/installed-apps//conversations//unpin", - endpoint="installed_app_conversation_unpin", -) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3de179164..86550b2bd 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,8 +1,9 @@ from datetime import UTC, datetime +from typing import Any from flask import request -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore from sqlalchemy import and_ from werkzeug.exceptions import BadRequest, Forbidden, NotFound @@ -34,7 +35,7 @@ def get(self): installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) - installed_apps = [ + installed_app_list: list[dict[str, Any]] = [ { "id": installed_app.id, "app": installed_app.app, @@ -47,7 +48,7 @@ def get(self): for installed_app in installed_apps if installed_app.app is not None ] - installed_apps.sort( + installed_app_list.sort( key=lambda app: ( 
-app["is_pinned"], app["last_used_at"] is None, @@ -55,7 +56,7 @@ def get(self): ) ) - return {"installed_apps": installed_apps} + return {"installed_apps": installed_app_list} @login_required @account_initialization_required diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 4e11d8005..405d5ed60 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,12 +1,11 @@ import logging -from flask_login import current_user -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services -from controllers.console import api from controllers.console.app.error import ( AppMoreLikeThisDisabledError, CompletionRequestError, @@ -67,10 +66,17 @@ def post(self, installed_app, message_id): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=current_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -153,21 +159,3 @@ def get(self, installed_app, message_id): raise InternalServerError() return {"data": questions} - - -api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") -api.add_resource( - MessageFeedbackApi, - "/installed-apps//messages//feedbacks", - endpoint="installed_app_message_feedback", -) -api.add_resource( - MessageMoreLikeThisApi, - "/installed-apps//messages//more-like-this", - endpoint="installed_app_more_like_this", -) -api.add_resource( - MessageSuggestedQuestionApi, - "/installed-apps//messages//suggested-questions", - endpoint="installed_app_suggested_question", -) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index fee52248a..5bc74d16e 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index ce85f495a..be6b1f5d2 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from constants.languages import languages from controllers.console import api diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 0fc963747..9f0c49664 100644 --- a/api/controllers/console/explore/saved_message.py 
+++ b/api/controllers/console/explore/saved_message.py @@ -1,6 +1,6 @@ -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_login import current_user # type: ignore +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.console import api diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 0428ac13d..0dd3252a4 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,9 +1,8 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError -from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, ProviderModelCurrentlyNotSupportError, @@ -15,7 +14,11 @@ from controllers.console.money_extend import money_limit from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_user @@ -83,9 +86,3 @@ def post(self, installed_app: InstalledApp, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) return {"result": "success"} - - -api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") -api.add_resource( - InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" -) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 49ea81a8a..b7ba81fba 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,7 +1,7 @@ from functools import wraps -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound from controllers.console.wraps import account_initialization_required diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 4ac0aa497..ed6cedb22 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from constants import HIDDEN_VALUE from controllers.console import api diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 70ab4ff86..da1171412 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import Resource +from flask_login import current_user # type: ignore +from flask_restful import Resource # type: ignore from libs.login import login_required from services.feature_service import FeatureService diff --git a/api/controllers/console/files.py 
b/api/controllers/console/files.py index ca32d29ef..8cf754bbd 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -1,6 +1,8 @@ +from typing import Literal + from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with # type: ignore from werkzeug.exceptions import Forbidden import services @@ -48,7 +50,8 @@ def get(self): @cloud_edition_billing_resource_check("documents") def post(self): file = request.files["file"] - source = request.form.get("source") + source_str = request.form.get("source") + source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None if "file" not in request.files: raise NoFileUploadedError() diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index ae759bb75..d9ae5cf29 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,7 +1,7 @@ import os from flask import session -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index cd28cc946..2a116112a 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.console import api diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index b8cf019e4..30afc930a 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,8 +2,8 @@ from typing import cast import httpx -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore import services from controllers.common import helpers diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e0b728d97..aba6f0aad 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from configs import dify_config from libs.helper import StrLen, email, extract_remote_ip diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index c4577f87b..4d18d59e1 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,6 +1,6 @@ from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -24,7 +24,7 @@ class TagListApi(Resource): @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get("type", type=str) + tag_type = request.args.get("type", type=str, default="") keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) diff --git 
a/api/controllers/console/version.py b/api/controllers/console/version.py index 7dea8e554..7773c9994 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -2,7 +2,7 @@ import logging import requests -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from packaging import version from configs import dify_config diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index f704783cf..f1ec0f3d2 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -2,8 +2,8 @@ import pytz from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from configs import dify_config from constants.languages import supported_language @@ -11,6 +11,7 @@ from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, + InvalidAccountDeletionCodeError, InvalidInvitationCodeError, RepeatPasswordNotMatchError, ) @@ -21,6 +22,7 @@ from libs.login import login_required from models import AccountIntegrate, InvitationCode from services.account_service import AccountService +from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -242,6 +244,54 @@ def get(self): return {"data": integrate_data} +class AccountDeleteVerifyApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + account = current_user + + token, code = AccountService.generate_account_deletion_verification_code(account) + AccountService.send_account_deletion_verification_email(account, code) + + return {"result": "success", "data": token} + + +class AccountDeleteApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + account = current_user + + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + args = parser.parse_args() + + if not AccountService.verify_account_deletion_code(args["token"], args["code"]): + raise InvalidAccountDeletionCodeError() + + AccountService.delete_account(account) + + return {"result": "success"} + + +class AccountDeleteUpdateFeedbackApi(Resource): + @setup_required + def post(self): + account = current_user + + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("feedback", type=str, required=True, location="json") + args = parser.parse_args() + + BillingService.update_account_deletion_feedback(args["email"], args["feedback"]) + + return {"result": "success"} + + # Register API resources api.add_resource(AccountInitApi, "/account/init") api.add_resource(AccountProfileApi, "/account/profile") @@ -252,5 +302,8 @@ def get(self): api.add_resource(AccountTimezoneApi, "/account/timezone") api.add_resource(AccountPasswordApi, "/account/password") api.add_resource(AccountIntegrateApi, "/account/integrates") +api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify") +api.add_resource(AccountDeleteApi, "/account/delete") +api.add_resource(AccountDeleteUpdateFeedbackApi, 
"/account/delete/feedback") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 9e13c7b92..8b70ca62b 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -35,3 +35,9 @@ class AccountNotInitializedError(BaseHTTPException): error_code = "account_not_initialized" description = "The account has not been initialized yet. Please proceed with the initialization process first." code = 400 + + +class InvalidAccountDeletionCodeError(BaseHTTPException): + error_code = "invalid_account_deletion_code" + description = "Invalid account deletion code." + code = 400 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index d2b2092b7..7009343d9 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -37,7 +37,7 @@ def post(self, provider: str): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: model_load_balancing_service.validate_load_balancing_credentials( @@ -86,7 +86,7 @@ def post(self, provider: str, config_id: str): model_load_balancing_service = ModelLoadBalancingService() result = True - error = None + error = "" try: model_load_balancing_service.validate_load_balancing_credentials( diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 38ed2316a..a2b41c1d3 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,7 +1,7 @@ from urllib import parse -from flask_login import current_user -from flask_restful import Resource, abort, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore import services from configs import dify_config @@ -89,19 +89,19 @@ class MemberCancelInviteApi(Resource): @account_initialization_required def delete(self, member_id): member = db.session.query(Account).filter(Account.id == str(member_id)).first() - if not member: + if member is None: abort(404) - - try: - TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) - except services.errors.account.CannotOperateSelfError as e: - return {"code": "cannot-operate-self", "message": str(e)}, 400 - except services.errors.account.NoPermissionError as e: - return {"code": "forbidden", "message": str(e)}, 403 - except services.errors.account.MemberNotInTenantError as e: - return {"code": "member-not-found", "message": str(e)}, 404 - except Exception as e: - raise ValueError(str(e)) + else: + try: + TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) + except services.errors.account.CannotOperateSelfError as e: + return {"code": "cannot-operate-self", "message": str(e)}, 400 + except services.errors.account.NoPermissionError as e: + return {"code": "forbidden", "message": str(e)}, 403 + except services.errors.account.MemberNotInTenantError as e: + return {"code": "member-not-found", "message": str(e)}, 404 + 
except Exception as e: + raise ValueError(str(e)) return {"result": "success"}, 204 @@ -126,6 +126,7 @@ def put(self, member_id): abort(404) try: + assert member is not None, "Member not found" TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) except Exception as e: raise ValueError(str(e)) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 0e5412606..2d11295b0 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -66,7 +66,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.provider_credentials_validate( @@ -132,7 +132,8 @@ def get(self, provider: str, icon_type: str, lang: str): icon_type=icon_type, lang=lang, ) - + if icon is None: + raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}") return send_file(io.BytesIO(icon), mimetype=mimetype) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index efb48783e..94a63830b 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,7 +1,7 @@ import logging -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden from controllers.console import api @@ -318,7 +318,7 @@ def post(self, provider: str): model_provider_service = ModelProviderService() result = True - error = None + error = "" try: model_provider_service.model_credentials_validate( diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 9e62a5469..964f38622 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,8 +1,8 @@ import io from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, reqparse # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 22daf685b..9ba03d1e3 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,8 @@ import logging from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from flask_login import current_user # type: ignore +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore from werkzeug.exceptions import Unauthorized import services @@ -86,11 +86,7 @@ def get(self): parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = 
parser.parse_args() - tenants = ( - db.session.query(Tenant) - .order_by(Tenant.created_at.desc()) - .paginate(page=args["page"], per_page=args["limit"]) - ) + tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"]) has_more = False if len(tenants.items) == args["limit"]: @@ -155,6 +151,8 @@ def post(self): raise AccountNotLinkTenantError("Account not link tenant") new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + if new_tenant is None: + raise ValueError("Tenant not found") return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} @@ -170,7 +168,7 @@ def post(self): parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() + tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index d0df296c2..111db7ccf 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -3,7 +3,7 @@ from functools import wraps from flask import abort, request -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError @@ -121,8 +121,8 @@ def decorated(*args, **kwargs): utm_info = request.cookies.get("utm_info") if utm_info: - utm_info = json.loads(utm_info) - OperationService.record_utm(current_user.current_tenant_id, utm_info) + utm_info_dict: dict = json.loads(utm_info) + OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) except Exception as e: pass return view(*args, **kwargs) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 6b3ac93cd..2357288a5 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,5 +1,5 @@ from flask import Response, request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import NotFound import services diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index a298701a2..cfcce8124 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,5 +1,5 @@ from flask import Response -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 99d32af59..d7346b13b 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from controllers.console.wraps import setup_required from controllers.inner_api import api diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 51ffe683f..d4587235f 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -45,14 +45,14 @@ def decorated(*args, **kwargs): if " " in user_id: 
user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get("X-Inner-Api-Key") + inner_api_key = request.headers.get("X-Inner-Api-Key", "") data_to_sign = f"DIFY {user_id}" signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) - signature = b64encode(signature.digest()).decode("utf-8") + signature_base64 = b64encode(signature.digest()).decode("utf-8") - if signature != token: + if signature_base64 != token: return view(*args, **kwargs) kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d6ab96c32..aba9e3ecb 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -7,4 +7,4 @@ from . import index from .app import app, audio, completion, conversation, file, message, workflow -from .dataset import dataset, document, hit_testing, segment +from .dataset import dataset, document, hit_testing, segment, upload_file diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecff7d07e..8388e2045 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,4 @@ -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 5db416364..e6bcc0bfd 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError import services @@ -83,7 +83,7 @@ def post(self, app_model: App, end_user: EndUser): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5ee2ffb27..5c2c859e5 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,6 +1,6 @@ import logging -from flask_restful import Resource, reqparse +from flask_restful import Resource, reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services @@ -18,7 +18,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( - AppInvokeQuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError, @@ -78,7 +77,7 @@ def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # 二 raise ProviderModelCurrentlyNotSupportError() except InvokeError as e: raise CompletionRequestError(e.description) - except (ValueError, AppInvokeQuotaExceededError) as e: + except ValueError as e: raise e except Exception as e: logging.exception("internal server error.") @@ -144,7 +143,7 @@ def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # 二 raise ProviderModelCurrentlyNotSupportError() 
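The `features_dict.get("text_to_speech", {})` default introduced above (and repeated in `api/controllers/web/audio.py` later in this diff) is a small defensive change; the snippet below is a minimal sketch, not code from this PR, of the `None` failure it avoids:

```python
# Minimal sketch (illustrative only): why `.get("text_to_speech", {})` is safer
# than `.get("text_to_speech")` when the feature key may be absent.
features_dict = {"file_upload": {"enabled": True}}  # hypothetical features payload

tts = features_dict.get("text_to_speech")
# tts is None here, so `tts.get("voice")` would raise:
#   AttributeError: 'NoneType' object has no attribute 'get'

tts = features_dict.get("text_to_speech", {})
voice = None or tts.get("voice")  # mirrors `args.get("voice") or text_to_speech.get("voice")`
print(voice)  # prints None instead of raising
```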
except InvokeError as e: raise CompletionRequestError(e.description) - except (ValueError, AppInvokeQuotaExceededError) as e: + except ValueError as e: raise e except Exception as e: logging.exception("internal server error.") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 32940cbc2..334f2c562 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index b0fd8e65e..27b21b9f5 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import Resource, marshal_with +from flask_restful import Resource, marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 7cc75f4c0..278a69159 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import BadRequest, InternalServerError, NotFound import services @@ -108,7 +108,13 @@ def post(self, app_model: App, end_user: EndUser, message_id): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=end_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 631383b02..c3fe75e93 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,7 +1,7 @@ import logging -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -16,7 +16,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( - AppInvokeQuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError, @@ -98,7 +97,7 @@ def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # 二 raise ProviderModelCurrentlyNotSupportError() except InvokeError as e: raise CompletionRequestError(e.description) - except (ValueError, AppInvokeQuotaExceededError) as e: + except ValueError as e: raise e except 
Exception as e: logging.exception("internal server error.") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 799fccc22..49acdd693 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound import services.dataset_service @@ -31,8 +31,11 @@ def get(self, tenant_id): # provider = request.args.get("provider", default="vendor") search = request.args.get("keyword", default=None, type=str) tag_ids = request.args.getlist("tag_ids") + include_all = request.args.get("include_all", default="false").lower() == "true" - datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets( + page, limit, tenant_id, current_user, search, tag_ids, include_all + ) # check embedding setting provider_manager = ProviderManager() configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 5c3fc7b24..2e148dd84 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,19 +1,23 @@ import json from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, reqparse # type: ignore from sqlalchemy import desc from werkzeug.exceptions import NotFound import services.dataset_service from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api -from controllers.service_api.app.error import ProviderNotInitializeError +from controllers.service_api.app.error import ( + FileTooLargeError, + NoFileUploadedError, + ProviderNotInitializeError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, DocumentIndexingError, - NoFileUploadedError, - TooManyFilesError, ) from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.errors.error import ProviderTokenNotInitError @@ -22,6 +26,7 @@ from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -67,13 +72,14 @@ def post(self, tenant_id, dataset_id): "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } args["data_source"] = data_source + knowledge_config = KnowledgeConfig(**args) # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -122,12 +128,13 @@ def post(self, tenant_id, dataset_id, document_id): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - 
DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -183,15 +190,19 @@ def post(self, tenant_id, dataset_id): user=current_user, source="datasets", ) - data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + data_source = { + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + } args["data_source"] = data_source # validate args - DocumentService.document_create_args_validate(args) + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", @@ -234,23 +245,33 @@ def post(self, tenant_id, dataset_id, document_id): if not file.filename: raise FilenameNotExistsError - upload_file = FileService.upload_file( - filename=file.filename, - content=file.read(), - mimetype=file.mimetype, - user=current_user, - source="datasets", - ) - data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + try: + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + data_source = { + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + } args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - DocumentService.document_create_args_validate(args) + + knowledge_config = KnowledgeConfig(**args) + DocumentService.document_create_args_validate(knowledge_config) try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, - document_data=args, + knowledge_config=knowledge_config, account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index e68f6b4dc..1c500f51b 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,5 +1,5 @@ -from flask_login import current_user -from flask_restful import marshal, reqparse +from flask_login import current_user # type: ignore +from flask_restful import marshal, reqparse # type: ignore from werkzeug.exceptions import NotFound from controllers.service_api import api @@ -16,6 +16,7 @@ from fields.segment_fields import segment_fields from models.dataset import Dataset, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from 
services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs class SegmentApi(DatasetApiResource): @@ -193,7 +194,7 @@ def post(self, tenant_id, dataset_id, document_id, segment_id): args = parser.parse_args() SegmentService.segment_create_args_validate(args["segment"], document) - segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py new file mode 100644 index 000000000..6382b63ea --- /dev/null +++ b/api/controllers/service_api/dataset/upload_file.py @@ -0,0 +1,54 @@ +from werkzeug.exceptions import NotFound + +from controllers.service_api import api +from controllers.service_api.wraps import ( + DatasetApiResource, +) +from core.file import helpers as file_helpers +from extensions.ext_database import db +from models.dataset import Dataset +from models.model import UploadFile +from services.dataset_service import DocumentService + + +class UploadFileApi(DatasetApiResource): + def get(self, tenant_id, dataset_id, document_id): + """Get upload file.""" + # check dataset + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + if not dataset: + raise NotFound("Dataset not found.") + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset.id, document_id) + if not document: + raise NotFound("Document not found.") + # check upload file + if document.data_source_type != "upload_file": + raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.") + data_source_info = document.data_source_info_dict + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + if not upload_file: + raise NotFound("UploadFile not found.") + else: + raise ValueError("Upload file id not found in document data source info.") + + url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": url, + "download_url": f"{url}&as_attachment=true", + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at.timestamp(), + }, 200 + + +api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file") diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index d24c4597e..75d9141a6 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from configs import dify_config from controllers.service_api import api diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 536ec38b1..2efa73624 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,14 +1,16 @@ import logging # ---------------------二开部分 密钥额度限制 --------------------- from collections.abc import Callable -from datetime import UTC, datetime
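For context on the new `UploadFileApi` resource registered above, here is a hedged client-side sketch of calling it. The host, `/v1` prefix, API key, and IDs are placeholders and assumptions rather than values from this diff; the bearer scheme and the response fields come from the code shown above:

```python
# Hedged usage sketch for the upload-file endpoint added in this PR.
import requests  # third-party HTTP client, used here only for illustration

API_BASE = "https://api.dify.example/v1"  # assumed service-API base URL
DATASET_API_KEY = "dataset-xxxxxxxxxxxx"  # hypothetical dataset API key
dataset_id = "00000000-0000-0000-0000-000000000001"   # placeholder UUID
document_id = "00000000-0000-0000-0000-000000000002"  # placeholder UUID

resp = requests.get(
    f"{API_BASE}/datasets/{dataset_id}/documents/{document_id}/upload-file",
    headers={"Authorization": f"Bearer {DATASET_API_KEY}"},
    timeout=10,
)
resp.raise_for_status()
info = resp.json()
# Fields returned by UploadFileApi.get(): id, name, size, extension, url,
# download_url, mime_type, created_by, created_at.
print(info["name"], info["download_url"])
```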
+from datetime import UTC, datetime, timedelta from enum import Enum from functools import wraps from typing import Optional from flask import current_app, request -from flask_login import user_logged_in -from flask_restful import Resource +from flask_login import user_logged_in # type: ignore +from flask_restful import Resource # type: ignore from pydantic import BaseModel +from sqlalchemy import select, update +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, Unauthorized from controllers.service_api.app.error_extend import ( @@ -68,6 +70,8 @@ def decorated_view(*args, **kwargs): raise Forbidden("The app's API service has been disabled.") tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() + if tenant is None: + raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") @@ -221,8 +225,8 @@ def decorated(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore else: raise Unauthorized("Tenant owner account does not exist.") else: @@ -239,7 +243,7 @@ def decorated(*args, **kwargs): return decorator -def validate_and_get_api_token(scope=None): +def validate_and_get_api_token(scope: str | None = None): """ Validate and get API token. """ @@ -253,20 +257,25 @@ def validate_and_get_api_token(scope=None): if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - api_token = ( - db.session.query(ApiToken) - .filter( - ApiToken.token == auth_token, - ApiToken.type == scope, + current_time = datetime.now(UTC).replace(tzinfo=None) + cutoff_time = current_time - timedelta(minutes=1) + with Session(db.engine, expire_on_commit=False) as session: + update_stmt = ( + update(ApiToken) + .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope) + .values(last_used_at=current_time) + .returning(ApiToken) ) - .first() - ) + result = session.execute(update_stmt) + api_token = result.scalar_one_or_none() - if not api_token: - raise Unauthorized("Access token is invalid") - - api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + if not api_token: + stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope) + api_token = session.scalar(stmt) + if not api_token: + raise Unauthorized("Access token is invalid") + else: + session.commit() return api_token @@ -294,7 +303,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] tenant_id=app_model.tenant_id, app_id=app_model.id, type="service_api", - is_anonymous=True if user_id == "DEFAULT-USER" else False, + is_anonymous=user_id == "DEFAULT-USER", session_id=user_id, ) db.session.add(end_user) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index cc8255ccf..20e071c83 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,4 @@ -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore from controllers.common import fields from controllers.common import helpers as controller_helpers diff --git a/api/controllers/web/audio.py 
b/api/controllers/web/audio.py index e8521307a..97d980d07 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -65,7 +65,7 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): - from flask_restful import reqparse + from flask_restful import reqparse # type: ignore try: parser = reqparse.RequestParser() @@ -82,7 +82,7 @@ def post(self, app_model: App, end_user): and app_model.workflow and app_model.workflow.features_dict ): - text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {}) voice = args.get("voice") or text_to_speech.get("voice") else: try: diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0fd3bb4ba..5364da4fb 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,7 +1,7 @@ import logging from flask import request # ----------------- start You must log in to access your account extend --------------- -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services @@ -24,7 +24,11 @@ from controllers.web.wraps import WebApiResource from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index fe0d7c74f..419247ea1 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,5 +1,5 @@ -from flask_restful import marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -39,7 +39,7 @@ def get(self, app_model, end_user): pinned = None if "pinned" in args and args["pinned"] is not None: - pinned = True if args["pinned"] == "true" else False + pinned = args["pinned"] == "true" try: with Session(db.engine) as session: diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 0563ed223..ce841a881 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -1,4 +1,4 @@ -from flask_restful import Resource +from flask_restful import Resource # type: ignore from controllers.web import api from services.feature_service import FeatureService diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index a282fc63a..1d4474015 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -1,5 +1,5 @@ from flask import request -from flask_restful import marshal_with +from flask_restful import marshal_with # type: ignore import services from controllers.common.errors import FilenameNotExistsError @@ -33,7 +33,7 @@ def post(self, app_model, end_user): content=file.read(), mimetype=file.mimetype, user=end_user, - source=source, + source="datasets" if source == "datasets" else None, ) except services.errors.file.FileTooLargeError 
as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index febaab532..2afc11f60 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,7 +1,7 @@ import logging -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import InternalServerError, NotFound import services @@ -105,10 +105,17 @@ def post(self, app_model, end_user, message_id): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json", default=None) args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) + MessageService.create_feedback( + app_model=app_model, + message_id=message_id, + user=end_user, + rating=args.get("rating"), + content=args.get("content"), + ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index a01ffd861..4625c1f43 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,7 +1,7 @@ import uuid from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index ae68df6bd..d559ab8e0 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,7 @@ import urllib.parse import httpx -from flask_restful import marshal_with, reqparse +from flask_restful import marshal_with, reqparse # type: ignore import services from controllers.common import helpers diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b0492e6b6..6a9b81890 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,5 +1,5 @@ -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range +from flask_restful import fields, marshal_with, reqparse # type: ignore +from flask_restful.inputs import int_range # type: ignore from werkzeug.exceptions import NotFound from controllers.web import api diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0564b15ea..e68dc7aa4 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,4 @@ -from flask_restful import fields, marshal_with +from flask_restful import fields, marshal_with # type: ignore from werkzeug.exceptions import Forbidden from configs import dify_config diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index c3da72d4c..61e666333 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,6 @@ import logging -from flask_restful import reqparse +from flask_restful import reqparse # type: ignore from werkzeug.exceptions import InternalServerError from controllers.web import api @@ -19,7 +19,11 @@ from controllers.web.wraps import WebApiResource from core.app.apps.base_app_queue_manager import AppQueueManager from 
core.app.entities.app_invoke_entities import InvokeFrom -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.model_runtime.errors.invoke import InvokeError from libs import helper from models.model import App, AppMode, EndUser diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index c327c3df1..1b4d263be 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,7 +1,7 @@ from functools import wraps from flask import request -from flask_restful import Resource +from flask_restful import Resource # type: ignore from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from controllers.web.error import WebSSOAuthRequiredError diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index ead293200..ae086ba8e 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,7 +1,6 @@ import json import logging import uuid -from collections.abc import Mapping, Sequence from datetime import UTC, datetime from typing import Optional, Union, cast @@ -53,6 +52,7 @@ class BaseAgentRunner(AppRunner): def __init__( self, + *, tenant_id: str, application_generate_entity: AgentChatAppGenerateEntity, conversation: Conversation, @@ -66,7 +66,7 @@ def __init__( prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance | None = None, + model_instance: ModelInstance, ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -117,7 +117,7 @@ def __init__( features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] - self.query = None + self.query: Optional[str] = "" self._current_thoughts: list[PromptMessage] = [] def _repack_app_generate_entity( @@ -145,7 +145,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P message_tool = PromptMessageTool( name=tool.tool_name, - description=tool_entity.description.llm, + description=tool_entity.description.llm if tool_entity.description else "", parameters={ "type": "object", "properties": {}, @@ -167,7 +167,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -187,8 +187,8 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe convert dataset retriever tool to prompt message tool """ prompt_tool = PromptMessageTool( - name=tool.identity.name, - description=tool.description.llm, + name=tool.identity.name if tool.identity else "unknown", + description=tool.description.llm if tool.description else "", parameters={ "type": "object", "properties": {}, @@ -210,14 +210,14 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe return prompt_tool - def _init_prompt_tools(self) -> 
tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: + def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: """ Init tools """ tool_instances = {} prompt_messages_tools = [] - for tool in self.app_config.agent.tools if self.app_config.agent else []: + for tool in self.app_config.agent.tools or [] if self.app_config.agent else []: try: prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool) except Exception: @@ -234,7 +234,8 @@ def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessage # save prompt tool prompt_messages_tools.append(prompt_tool) # save tool entity - tool_instances[dataset_tool.identity.name] = dataset_tool + if dataset_tool.identity is not None: + tool_instances[dataset_tool.identity.name] = dataset_tool return tool_instances, prompt_messages_tools @@ -258,7 +259,7 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: - enum = [option.value for option in parameter.options] + enum = [option.value for option in parameter.options] if parameter.options else [] prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, @@ -322,24 +323,29 @@ def save_agent_thought( tool_name: str, tool_input: Union[str, dict], thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], + observation: Union[str, dict, None], + tool_invoke_meta: Union[str, dict, None], answer: str, messages_ids: list[str], - llm_usage: LLMUsage = None, - ) -> MessageAgentThought: + llm_usage: LLMUsage | None = None, + ): """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + queried_thought = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() + ) + if not queried_thought: + raise ValueError(f"Agent thought {agent_thought.id} not found") + agent_thought = queried_thought - if thought is not None: + if thought: agent_thought.thought = thought - if tool_name is not None: + if tool_name: agent_thought.tool = tool_name - if tool_input is not None: + if tool_input: if isinstance(tool_input, dict): try: tool_input = json.dumps(tool_input, ensure_ascii=False) @@ -348,7 +354,7 @@ def save_agent_thought( agent_thought.tool_input = tool_input - if observation is not None: + if observation: if isinstance(observation, dict): try: observation = json.dumps(observation, ensure_ascii=False) @@ -357,7 +363,7 @@ def save_agent_thought( agent_thought.observation = observation - if answer is not None: + if answer: agent_thought.answer = answer if messages_ids is not None and len(messages_ids) > 0: @@ -404,7 +410,7 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab """ convert tool variables to db variables """ - db_variables = ( + queried_variables = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == self.message.conversation_id, @@ -412,6 +418,11 @@ def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variab .first() ) + if not queried_variables: + return + + db_variables = queried_variables + db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() @@ -421,7 +432,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P """ Organize agent 
history """ - result = [] + result: list[PromptMessage] = [] # check if there is a system message in the beginning of the conversation for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index d98ba5a3f..e936acb60 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod -from collections.abc import Generator -from typing import Optional, Union +from collections.abc import Generator, Mapping +from typing import Any, Optional from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit @@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageTool, ToolPromptMessage, UserPromptMessage, ) @@ -26,18 +27,18 @@ class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True _ignore_observation_providers = ["wenxin"] - _historic_prompt_messages: list[PromptMessage] = None - _agent_scratchpad: list[AgentScratchpadUnit] = None - _instruction: str = None - _query: str = None - _prompt_messages_tools: list[PromptMessage] = None + _historic_prompt_messages: list[PromptMessage] | None = None + _agent_scratchpad: list[AgentScratchpadUnit] | None = None + _instruction: str = "" # FIXME this must be str for now + _query: str | None = None + _prompt_messages_tools: list[PromptMessageTool] = [] def run( self, message: Message, query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + inputs: Mapping[str, str], + ) -> Generator: """ Run Cot agent application """ @@ -57,19 +58,19 @@ def run( # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs) iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 # convert tools into ModelRuntime Tool format tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" - def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage): if not final_llm_usage_dict["usage"]: final_llm_usage_dict["usage"] = usage else: @@ -90,7 +91,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools self._prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids @@ -105,7 +106,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): prompt_messages = self._organize_prompt_messages() self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model - chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( + chunks = model_instance.invoke_llm( prompt_messages=prompt_messages, 
model_parameters=app_generate_entity.model_conf.parameters, tools=[], @@ -115,11 +116,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): callbacks=[], ) + if not isinstance(chunks, Generator): + raise ValueError("Expected streaming response from LLM") + # check llm result if not chunks: raise ValueError("failed to invoke llm") - usage_dict = {} + usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None} react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( agent_response="", @@ -139,25 +143,30 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if isinstance(chunk, AgentScratchpadUnit.Action): action = chunk # detect action - scratchpad.agent_response += json.dumps(chunk.model_dump()) + if scratchpad.agent_response is not None: + scratchpad.agent_response += json.dumps(chunk.model_dump()) scratchpad.action_str = json.dumps(chunk.model_dump()) scratchpad.action = action else: - scratchpad.agent_response += chunk - scratchpad.thought += chunk + if scratchpad.agent_response is not None: + scratchpad.agent_response += chunk + if scratchpad.thought is not None: + scratchpad.thought += chunk yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, system_fingerprint="", delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - - scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" - self._agent_scratchpad.append(scratchpad) + if scratchpad.thought is not None: + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" + if self._agent_scratchpad is not None: + self._agent_scratchpad.append(scratchpad) # get llm usage if "usage" in usage_dict: - increase_usage(llm_usage, usage_dict["usage"]) + if usage_dict["usage"] is not None: + increase_usage(llm_usage, usage_dict["usage"]) else: usage_dict["usage"] = LLMUsage.empty_usage() @@ -166,9 +175,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_name=scratchpad.action.action_name if scratchpad.action else "", tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation="", - answer=scratchpad.agent_response, + answer=scratchpad.agent_response or "", messages_ids=[], llm_usage=usage_dict["usage"], ) @@ -209,7 +218,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): agent_thought=agent_thought, tool_name=scratchpad.action.action_name, tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, - thought=scratchpad.thought, + thought=scratchpad.thought or "", observation={scratchpad.action.action_name: tool_invoke_response}, tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, @@ -247,8 +256,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): answer=final_answer, messages_ids=[], ) - - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool is not None and self.db_variables_pool is not None: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -307,8 +316,9 @@ def _handle_invoke_action( # publish files for message_file_id, save_as in 
message_files: - if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if save_as is not None and self.variables_pool: + # FIXME the save_as type is confusing, it should be a string or not + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as)) # publish message file self.queue_manager.publish( @@ -325,7 +335,7 @@ def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: """ return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) - def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str: """ fill in inputs from external data tools """ @@ -376,11 +386,13 @@ def _organize_historic_prompt_messages( """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] - current_scratchpad: AgentScratchpadUnit = None + current_scratchpad: AgentScratchpadUnit | None = None for message in self.history_prompt_messages: if isinstance(message, AssistantPromptMessage): if not current_scratchpad: + if not isinstance(message.content, str | None): + raise NotImplementedError("expected str type") current_scratchpad = AgentScratchpadUnit( agent_response=message.content, thought=message.content or "I am thinking about how to help you", @@ -399,8 +411,12 @@ def _organize_historic_prompt_messages( except: pass elif isinstance(message, ToolPromptMessage): - if current_scratchpad: + if not current_scratchpad: + continue + if isinstance(message.content, str): current_scratchpad.observation = message.content + else: + raise NotImplementedError("expected str type") elif isinstance(message, UserPromptMessage): if scratchpads: result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index d8d047fe9..6a96c349b 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -19,7 +19,12 @@ def _organize_system_prompt(self) -> SystemPromptMessage: """ Organize system prompt """ + if not self.app_config.agent: + raise ValueError("Agent configuration is not set") + prompt_entity = self.app_config.agent.prompt + if not prompt_entity: + raise ValueError("Agent prompt configuration is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -75,6 +80,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: assistant_messages = [] else: assistant_message = AssistantPromptMessage(content="") + assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str for unit in agent_scratchpad: if unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 056309053..3a4d31e04 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -2,7 +2,12 @@ from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from 
core.model_runtime.utils.encoders import jsonable_encoder @@ -11,7 +16,11 @@ def _organize_instruction_prompt(self) -> str: """ Organize instruction prompt """ + if self.app_config.agent is None: + raise ValueError("Agent configuration is not set") prompt_entity = self.app_config.agent.prompt + if prompt_entity is None: + raise ValueError("prompt entity is not set") first_prompt = prompt_entity.first_prompt system_prompt = ( @@ -33,7 +42,13 @@ def _organize_historic_prompt(self, current_session_messages: Optional[list[Prom if isinstance(message, UserPromptMessage): historic_prompt += f"Question: {message.content}\n\n" elif isinstance(message, AssistantPromptMessage): - historic_prompt += message.content + "\n\n" + if isinstance(message.content, str): + historic_prompt += message.content + "\n\n" + elif isinstance(message.content, list): + for content in message.content: + if not isinstance(content, TextPromptMessageContent): + continue + historic_prompt += content.data return historic_prompt @@ -50,7 +65,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: # organize current assistant messages agent_scratchpad = self._agent_scratchpad assistant_prompt = "" - for unit in agent_scratchpad: + for unit in agent_scratchpad or []: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" else: diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 119a88fc7..2ae87dca3 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -78,5 +78,5 @@ class Strategy(Enum): model: str strategy: Strategy prompt: Optional[AgentPromptEntity] = None - tools: list[AgentToolEntity] = None + tools: list[AgentToolEntity] | None = None max_iteration: int = 5 diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index cd546dee1..b862c9607 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -40,6 +40,8 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul app_generate_entity = self.application_generate_entity app_config = self.app_config + assert app_config is not None, "app_config is required" + assert app_config.agent is not None, "app_config.agent is required" # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() @@ -49,7 +51,7 @@ def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResul # continue to run until there is not any tool call function_call_state = True - llm_usage = {"usage": None} + llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()} final_answer = "" # get tracing instance @@ -75,7 +77,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # the last iteration, remove all tools prompt_messages_tools = [] - message_file_ids = [] + message_file_ids: list[str] = [] agent_thought = self.create_agent_thought( message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) @@ -105,7 +107,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): current_llm_usage = None - if self.stream_tool_call: + if self.stream_tool_call and isinstance(chunks, Generator): is_first_chunk = True for chunk in chunks: if is_first_chunk: @@ -116,7 +118,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk)) + 
tool_calls.extend(self.extract_tool_calls(chunk) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -131,19 +133,19 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in chunk.delta.message.content: response += content.data else: - response += chunk.delta.message.content + response += str(chunk.delta.message.content) if chunk.delta.usage: increase_usage(llm_usage, chunk.delta.usage) current_llm_usage = chunk.delta.usage yield chunk - else: - result: LLMResult = chunks + elif not self.stream_tool_call and isinstance(chunks, LLMResult): + result = chunks # check if there is any tool call if self.check_blocking_tool_calls(result): function_call_state = True - tool_calls.extend(self.extract_blocking_tool_calls(result)) + tool_calls.extend(self.extract_blocking_tool_calls(result) or []) tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps( @@ -162,7 +164,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): for content in result.message.content: response += content.data else: - response += result.message.content + response += str(result.message.content) if not result.message.content: result.message.content = "" @@ -181,6 +183,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): usage=result.usage, ), ) + else: + raise RuntimeError(f"invalid chunks type: {type(chunks)}") assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: @@ -243,7 +247,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) + if self.variables_pool: + self.variables_pool.set_file( + tool_name=tool_call_name, value=message_file_id, name=save_as + ) # publish message file self.queue_manager.publish( @@ -263,7 +270,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response["tool_response"], + content=str(tool_response["tool_response"]), tool_call_id=tool_call_id, name=tool_call_name, ) @@ -273,9 +280,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name=None, - tool_input=None, - thought=None, + tool_name="", + tool_input="", + thought="", tool_invoke_meta={ tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, @@ -283,7 +290,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, - answer=None, + answer="", messages_ids=message_file_ids, ) self.queue_manager.publish( @@ -296,7 +303,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): iteration_step += 1 - self.update_db_variables(self.variables_pool, self.db_variables_pool) + if self.variables_pool and self.db_variables_pool: + self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event self.queue_manager.publish( QueueMessageEndEvent( @@ -389,9 +397,9 @@ def _init_system_message( if prompt_messages and not isinstance(prompt_messages[0], 
SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) - return prompt_messages + return prompt_messages or [] - def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ @@ -449,7 +457,7 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage] def _organize_prompt_messages(self): prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) - query_prompt_messages = self._organize_user_query(self.query, []) + query_prompt_messages = self._organize_user_query(self.query or "", []) self.history_prompt_messages = AgentHistoryPromptTransform( model_config=self.model_config, diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 085bac860..61fa774ea 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -38,7 +38,7 @@ def parse_action(json_str): except: return json_str or "" - def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: + def extra_json_from_code_block(code_block) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return @@ -67,15 +67,15 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, for response in llm_response: if response.delta.usage: usage_dict["usage"] = response.delta.usage - response = response.delta.message.content - if not isinstance(response, str): + response_content = response.delta.message.content + if not isinstance(response_content, str): continue # stream index = 0 - while index < len(response): + while index < len(response_content): steps = 1 - delta = response[index : index + steps] + delta = response_content[index : index + steps] yield_delta = False if delta == "`": diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index b9aae7904..646c4badb 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -66,6 +66,8 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: dataset_configs = config.get("dataset_configs") else: dataset_configs = {"retrieval_model": "multiple"} + if dataset_configs is None: + return None query_variable = config.get("dataset_query_variable") if dataset_configs["retrieval_model"] == "single": diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 5adcf26f1..642686511 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -94,7 +94,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> config["model"]["completion_params"] ) - return config, ["model"] + return dict(config), ["model"] @classmethod def validate_model_completion_params(cls, cp: dict) -> dict: diff --git a/api/core/app/app_config/features/opening_statement/manager.py 
b/api/core/app/app_config/features/opening_statement/manager.py index b4dacbc40..92b4185ab 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -7,10 +7,10 @@ def convert(cls, config: dict) -> tuple[str, list]: :param config: model config args """ # opening statement - opening_statement = config.get("opening_statement") + opening_statement = config.get("opening_statement", "") # suggested questions - suggested_questions_list = config.get("suggested_questions") + suggested_questions_list = config.get("suggested_questions", []) return opening_statement, suggested_questions_list diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index cdf8e47b5..447094d98 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -24,7 +24,7 @@ InvokeFrom, ) from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from extensions.ext_database import db @@ -32,6 +32,7 @@ from models.account import Account from models.model import ApiToken, App, Conversation, EndUser, Message # 二开部分 - 密钥额度限制,新增ApiToken from models.workflow import Workflow +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -155,7 +156,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -323,6 +324,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = AdvancedChatAppRunner( @@ -343,7 +346,7 @@ def _generate_worker( except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: + except ValueError as e: if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) @@ -390,7 +393,7 @@ def _handle_advanced_chat_response( try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}") diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 29709914b..a50644767 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -5,6 +5,7 @@ import re import threading from collections.abc import Iterable +from typing import Optional from core.app.entities.queue_entities 
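Two small conventions recur above: passing a default to `dict.get` so the declared return type never needs to admit None, and checking `len(e.args)` before indexing an exception's args. A short sketch under those assumptions; `convert` and `is_closed_file_error` are illustrative names only:

    def convert(config: dict) -> tuple[str, list]:
        # Defaults keep the tuple[str, list] contract honest when keys are absent.
        opening_statement = config.get("opening_statement", "")
        suggested_questions = config.get("suggested_questions", [])
        return opening_statement, suggested_questions


    def is_closed_file_error(e: ValueError) -> bool:
        # e.args may be empty, so guard before indexing.
        return len(e.args) > 0 and e.args[0] == "I/O operation on closed file."


    print(convert({}))
    print(is_closed_file_error(ValueError()))
    print(is_closed_file_error(ValueError("I/O operation on closed file.")))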
import ( MessageQueueMessage, @@ -15,6 +16,7 @@ WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.message_entities import TextPromptMessageContent from core.model_runtime.entities.model_entities import ModelType @@ -71,8 +73,9 @@ def __init__(self, tenant_id: str, voice: str): if not voice or voice not in values: self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 - self._last_audio_event = None - self._runtime_thread = threading.Thread(target=self._runtime).start() + self._last_audio_event: Optional[AudioTrunk] = None + # FIXME better way to handle this threading.start + threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): @@ -92,10 +95,21 @@ def _runtime(self): future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): - self.msg_text += message.event.chunk.delta.message.content + message_content = message.event.chunk.delta.message.content + if not message_content: + continue + if isinstance(message_content, str): + self.msg_text += message_content + elif isinstance(message_content, list): + for content in message_content: + if not isinstance(content, TextPromptMessageContent): + continue + self.msg_text += content.data elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): + if message.event.outputs is None: + continue self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) @@ -121,11 +135,10 @@ def check_and_get_audio(self): if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: self.executor.shutdown(wait=False) - return self.last_message + return self._last_audio_event audio = self._audio_queue.get_nowait() if audio and audio.status == "finish": self.executor.shutdown(wait=False) - self._runtime_thread = None if audio: self._last_audio_event = audio return audio diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index cf0c9d759..6339d7989 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -109,18 +109,18 @@ def run(self) -> None: ConversationVariable.conversation_id == self.conversation.id, ) with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: + db_conversation_variables = session.scalars(stmt).all() + if not db_conversation_variables: # Create conversation variables if they don't exist. - conversation_variables = [ + db_conversation_variables = [ ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) for variable in workflow.conversation_variables ] - session.add_all(conversation_variables) + session.add_all(db_conversation_variables) # Convert database entities to variables. 
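The TTS-publisher hunk accumulates speakable text from content that may be a plain string or a list of structured parts, keeping only the text parts and skipping empty payloads. A reduced sketch of that loop; `TextPart`, `ImagePart`, and `append_text` are placeholders for the real content classes:

    from dataclasses import dataclass
    from typing import Union


    @dataclass
    class TextPart:
        data: str


    @dataclass
    class ImagePart:
        url: str


    Part = Union[TextPart, ImagePart]


    def append_text(buffer: str, content: Union[str, list[Part], None]) -> str:
        if not content:
            return buffer
        if isinstance(content, str):
            return buffer + content
        # Keep only text parts; other part types carry no speakable text.
        for part in content:
            if isinstance(part, TextPart):
                buffer += part.data
        return buffer


    text = ""
    text = append_text(text, "Hello, ")
    text = append_text(text, [TextPart("world"), ImagePart("https://example.com/a.png")])
    print(text)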
- conversation_variables = [item.to_variable() for item in conversation_variables] + conversation_variables = [item.to_variable() for item in db_conversation_variables] session.commit() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index ea5d8ee73..146ba1787 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -2,8 +2,12 @@ import logging import time from collections.abc import Generator, Mapping +from threading import Thread from typing import Any, Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -64,25 +68,17 @@ from models.enums import CreatedByRole from models.workflow import ( Workflow, - WorkflowNodeExecution, WorkflowRunStatus, ) logger = logging.getLogger(__name__) -class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage, MessageCycleManage): +class AdvancedChatAppGenerateTaskPipeline: """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: WorkflowTaskState - _application_generate_entity: AdvancedChatAppGenerateEntity - _workflow: Workflow - _user: Union[Account, EndUser] - _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, @@ -94,61 +90,66 @@ def __init__( stream: bool, dialogue_count: int, ) -> None: - """ - Initialize AdvancedChatAppGenerateTaskPipeline. 
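The pipeline class above drops its multiple-inheritance bases and instead holds the base pipeline and the cycle managers as attributes. A minimal sketch of that composition-over-inheritance shape, using heavily simplified stand-in classes (`BasePipeline`, `MessageCycleManager`, `ChatTaskPipeline`):

    class BasePipeline:
        def __init__(self, stream: bool) -> None:
            self._stream = stream

        def ping(self) -> str:
            return "ping"


    class MessageCycleManager:
        def message_to_stream_response(self, answer: str) -> dict:
            return {"event": "message", "answer": answer}


    class ChatTaskPipeline:
        # Composition: collaborators are attributes, not base classes, so their
        # state stays encapsulated and each collaborator can be typed on its own.
        def __init__(self, stream: bool) -> None:
            self._base_task_pipeline = BasePipeline(stream=stream)
            self._message_cycle_manager = MessageCycleManager()

        def process(self, answer: str) -> dict:
            if self._base_task_pipeline._stream:
                return self._message_cycle_manager.message_to_stream_response(answer)
            return {"event": "blocking", "answer": answer}


    print(ChatTaskPipeline(stream=True).process("hi"))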
- :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - :param dialogue_count: dialogue count - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + self._base_task_pipeline = BasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id - - self._workflow = workflow - self._conversation = conversation - self._message = message - self._workflow_system_variables = { - SystemVariableKey.QUERY: message.query, - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: dialogue_count, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, - } + raise NotImplementedError(f"User type not supported: {type(user)}") + + self._workflow_cycle_manager = WorkflowCycleManage( + application_generate_entity=application_generate_entity, + workflow_system_variables={ + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_session_id, + SystemVariableKey.DIALOGUE_COUNT: dialogue_count, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, + }, + ) self._task_state = WorkflowTaskState() - self._wip_workflow_node_executions = {} + self._message_cycle_manager = MessageCycleManage( + application_generate_entity=application_generate_entity, task_state=self._task_state + ) - self._conversation_name_generate_thread = None + self._application_generate_entity = application_generate_entity + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) + self._conversation_name_generate_thread: Thread | None = None self._recorded_files: list[Mapping[str, Any]] = [] + self._workflow_run_id: str = "" - def process(self): + def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: """ Process generate task pipeline. 
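Rather than keeping ORM objects on the pipeline, the rewritten constructor snapshots the scalar values it needs (ids, created-at timestamp) and maps the user to an explicit role. A condensed sketch of that constructor shape; the dataclasses here are placeholders, not the real models:

    from dataclasses import dataclass
    from datetime import datetime
    from enum import Enum
    from typing import Union


    class CreatedByRole(str, Enum):
        ACCOUNT = "account"
        END_USER = "end_user"


    @dataclass
    class Account:
        id: str


    @dataclass
    class EndUser:
        id: str
        session_id: str


    @dataclass
    class Message:
        id: str
        created_at: datetime


    class Pipeline:
        def __init__(self, user: Union[Account, EndUser], message: Message) -> None:
            # Snapshot scalars up front; ORM objects are re-fetched per session later.
            if isinstance(user, EndUser):
                self._user_id = user.id
                self._created_by_role = CreatedByRole.END_USER
            elif isinstance(user, Account):
                self._user_id = user.id
                self._created_by_role = CreatedByRole.ACCOUNT
            else:
                raise NotImplementedError(f"User type not supported: {type(user)}")
            self._message_id = message.id
            self._message_created_at = int(message.created_at.timestamp())


    p = Pipeline(EndUser(id="u1", session_id="s1"), Message(id="m1", created_at=datetime.now()))
    print(p._created_by_role, p._message_created_at)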
:return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - # start generate conversation name thread - self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( + conversation_id=self._conversation_id, query=self._application_generate_entity.query ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self._base_task_pipeline._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -169,12 +170,12 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] return ChatbotAppBlockingResponse( task_id=stream_response.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, answer=self._task_state.answer, - created_at=int(self._message.created_at.timestamp()), + created_at=self._message_created_at, **extras, ), ) @@ -192,9 +193,9 @@ def _to_stream_response( """ for stream_response in generator: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -212,7 +213,7 @@ def _wrapper_process_stream_response( tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -263,255 +264,352 @@ def _process_stream_response( :return: """ # init fake graph runtime state - graph_runtime_state = None - workflow_run = None + graph_runtime_state: Optional[GraphRuntimeState] = None - for queue_message in self._queue_manager.listen(): + for queue_message in self._base_task_pipeline._queue_manager.listen(): event = queue_message.event if isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self._base_task_pipeline._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) - yield self._error_to_stream_response(err) + with Session(db.engine, expire_on_commit=False) as session: + err = self._base_task_pipeline._handle_error( + event=event, session=session, message_id=self._message_id + ) + session.commit() + yield self._base_task_pipeline._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() - - self._refetch_message() - # ------------------- 二开部分Begin - 密钥额度限制 ------------------- - app_token_id = self._application_generate_entity.extras.get("app_token_id") - if app_token_id: - ApiTokenMessageJoinsExtend( - app_token_id=app_token_id, record_id=workflow_run.id, app_mode=AppMode.ADVANCED_CHAT.value - ).add_app_token_record_id() - # ------------------- 二开部分End - 密钥额度限制 
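From here on, each queue event is handled inside its own short-lived SQLAlchemy session with `expire_on_commit=False`, and the stream response is yielded only after `session.commit()`. A runnable sketch of that per-event session pattern against an in-memory SQLite database; the `Message` model and `handle_event` function are simplified stand-ins:

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Message(Base):
        __tablename__ = "messages"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        answer: Mapped[str] = mapped_column(String, default="")


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as setup:
        setup.add(Message(id="m1"))
        setup.commit()


    def handle_event(message_id: str, delta: str) -> str:
        # One session per event: open, mutate, commit, close. expire_on_commit=False
        # keeps the loaded attribute values usable once the session is gone.
        with Session(engine, expire_on_commit=False) as session:
            message = session.scalar(select(Message).where(Message.id == message_id))
            if message is None:
                raise ValueError(f"Message not found: {message_id}")
            message.answer += delta
            session.commit()
            return message.answer


    print(handle_event("m1", "hello"))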
------------------- - - self._refetch_message() - self._message.workflow_run_id = workflow_run.id + with Session(db.engine, expire_on_commit=False) as session: + # init workflow run + workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) - db.session.commit() - db.session.refresh(self._message) - db.session.close() + # ------------------- 二开部分Begin - 密钥额度限制 ------------------- + app_token_id = self._application_generate_entity.extras.get("app_token_id") + if app_token_id: + ApiTokenMessageJoinsExtend( + app_token_id=app_token_id, record_id=workflow_run.id, app_mode=AppMode.ADVANCED_CHAT.value + ).add_app_token_record_id() + # ------------------- 二开部分End - 密钥额度限制 ------------------- + + self._workflow_run_id = workflow_run.id + message = self._get_message(session=session) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + message.workflow_run_id = workflow_run.id + workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + yield workflow_start_resp elif isinstance( event, QueueNodeRetryEvent, ): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - if response: - yield response + if node_retry_resp: + yield node_retry_resp elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) - response = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + 
session.commit() - if response: - yield response + if node_start_resp: + yield node_start_resp elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: - self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) + self._recorded_files.extend( + self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {}) + ) - response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + session=session, event=event + ) - if response: - yield response - elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + if node_finish_resp: + yield node_finish_resp + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): + with Session(db.engine, expire_on_commit=False) as session: + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + session=session, event=event + ) - if response: - yield response + node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() + if node_finish_resp: + yield node_finish_resp elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + parallel_start_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + ) + + yield parallel_start_resp elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, 
workflow_run_id=self._workflow_run_id + ) + parallel_finish_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + ) + + yield parallel_finish_resp elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_finish_resp elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") if not graph_runtime_state: raise ValueError("workflow run not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) - yield 
self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + yield workflow_finish_resp + self._base_task_pipeline._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE + ) elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + yield workflow_finish_resp + self._base_task_pipeline._queue_manager.publish( + QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE ) - - self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=self._conversation.id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break - elif isinstance(event, QueueStopEvent): - if workflow_run and graph_runtime_state: - workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, 
start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.STOPPED, - error=event.get_stop_reason(), - conversation_id=self._conversation.id, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation_id, trace_manager=trace_manager, + exceptions_count=event.exceptions_count, ) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + err = self._base_task_pipeline._handle_error( + event=err_event, session=session, message_id=self._message_id ) + session.commit() - # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + yield workflow_finish_resp + yield self._base_task_pipeline._error_to_stream_response(err) + break + elif isinstance(event, QueueStopEvent): + if self._workflow_run_id and graph_runtime_state: + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + # Save message + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() + + yield workflow_finish_resp yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): - self._handle_retriever_resources(event) - - self._refetch_message() + self._message_cycle_manager._handle_retriever_resources(event) - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine, expire_on_commit=False) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueAnnotationReplyEvent): - self._handle_annotation_reply(event) + self._message_cycle_manager._handle_annotation_reply(event) - self._refetch_message() - - self._message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) - - db.session.commit() - db.session.refresh(self._message) - db.session.close() + with Session(db.engine, expire_on_commit=False) as session: + message = self._get_message(session=session) + message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -527,23 +625,29 @@ def 
_process_stream_response( tts_publisher.publish(queue_message) self._task_state.answer += delta_text - yield self._message_to_stream_response( - answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + yield self._message_cycle_manager._message_to_stream_response( + answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation - yield self._message_replace_to_stream_response(answer=event.text) + yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + self._task_state.answer + ) if output_moderation_answer: self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) + yield self._message_cycle_manager._message_replace_to_stream_response( + answer=output_moderation_answer + ) # Save message - self._save_message(graph_runtime_state=graph_runtime_state) + with Session(db.engine, expire_on_commit=False) as session: + self._save_message(session=session, graph_runtime_state=graph_runtime_state) + session.commit() yield self._message_end_to_stream_response() else: @@ -556,54 +660,46 @@ def _process_stream_response( if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: - self._refetch_message() - - self._message.answer = self._task_state.answer - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = ( + def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: + message = self._get_message(session=session) + message.answer = self._task_state.answer + message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at + message.message_metadata = ( json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) message_files = [ MessageFile( - message_id=self._message.id, + message_id=message.id, type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], created_by_role=CreatedByRole.ACCOUNT - if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else CreatedByRole.END_USER, - created_by=self._message.from_account_id or self._message.from_end_user_id or "", + created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files ] - db.session.add_all(message_files) + session.add_all(message_files) if graph_runtime_state and graph_runtime_state.llm_usage: usage = graph_runtime_state.llm_usage - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - 
self._message.answer_price_unit = usage.completion_price_unit - self._message.total_price = usage.total_price - self._message.currency = usage.currency - + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.total_price = usage.total_price + message.currency = usage.currency self._task_state.metadata["usage"] = jsonable_encoder(usage) else: self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) - - db.session.commit() - message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -619,7 +715,10 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras + task_id=self._application_generate_entity.task_id, + id=self._message_id, + files=self._recorded_files, + metadata=extras.get("metadata", {}), ) def _handle_output_moderation_chunk(self, text: str) -> bool: @@ -628,28 +727,26 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: :param text: text :return: True if output moderation should direct output, otherwise False """ - if self._output_moderation_handler: - if self._output_moderation_handler.should_direct_output(): + if self._base_task_pipeline._output_moderation_handler: + if self._base_task_pipeline._output_moderation_handler.should_direct_output(): # stop subscribe new token when output moderation should direct output - self._task_state.answer = self._output_moderation_handler.get_final_output() - self._queue_manager.publish( + self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() + self._base_task_pipeline._queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) - self._queue_manager.publish( + self._base_task_pipeline._queue_manager.publish( QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: - self._output_moderation_handler.append_new_token(text) + self._base_task_pipeline._output_moderation_handler.append_new_token(text) return False - def _refetch_message(self) -> None: - """ - Refetch message. 
- :return: - """ - message = db.session.query(Message).filter(Message.id == self._message.id).first() - if message: - self._message = message + def _get_message(self, *, session: Session): + stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(stmt) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + return message diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 417d23ecc..55b6ee510 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -61,7 +61,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 884354579..72eed0c4e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -18,11 +18,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser +from services.errors.message import MessageNotExistsError from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制 logger = logging.getLogger(__name__) @@ -98,7 +99,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -154,7 +155,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=streaming, @@ -189,7 +190,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -208,8 +209,8 @@ def generate( user=user, stream=streaming, ) - - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # FIXME: Type hinting issue here, ignore it for now, will fix it later + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore def _generate_worker( self, @@ -233,6 +234,8 @@ def _generate_worker( # get conversation and message conversation = 
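The worker threads above are handed `current_app._get_current_object()` (now with a `# type: ignore`): `current_app` is a proxy that only resolves inside a context, so the real application object has to be unwrapped before crossing threads. A small sketch of that hand-off, assuming nothing beyond stock Flask:

    import threading

    from flask import Flask, current_app

    app = Flask(__name__)


    def worker(flask_app: Flask) -> None:
        # The worker thread has no ambient context, so it pushes one from the
        # real app object captured on the calling thread.
        with flask_app.app_context():
            print("worker sees:", current_app.name)


    with app.app_context():
        # current_app is a proxy; _get_current_object() unwraps it so the real
        # app can be handed to another thread (hence the type: ignore upstream).
        real_app = current_app._get_current_object()  # type: ignore

    t = threading.Thread(target=worker, kwargs={"flask_app": real_app})
    t.start()
    t.join()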
self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = AgentChatAppRunner() @@ -251,7 +254,7 @@ def _generate_worker( except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: + except ValueError as e: if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 45b1bf009..ac71f02b6 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -173,6 +173,8 @@ def run( return agent_entity = app_config.agent + if not agent_entity: + raise ValueError("Agent entity not found") # load tool variables tool_conversation_variables = self._load_tool_variables( @@ -200,14 +202,21 @@ def run( # change function call strategy based on LLM model llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema or not model_schema.features: + raise ValueError("Model schema not found") if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING - conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() - message = db.session.query(Message).filter(Message.id == message.id).first() + conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() + if conversation_result is None: + raise ValueError("Conversation not found") + message_result = db.session.query(Message).filter(Message.id == message.id).first() + if message_result is None: + raise ValueError("Message not found") db.session.close() + runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner] # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: # check LLM mode @@ -225,12 +234,12 @@ def run( runner = runner_cls( tenant_id=app_config.tenant_id, application_generate_entity=application_generate_entity, - conversation=conversation, + conversation=conversation_result, app_config=app_config, model_config=application_generate_entity.model_conf, config=agent_entity, queue_manager=queue_manager, - message=message, + message=message_result, user_id=application_generate_entity.user_id, memory=memory, prompt_messages=prompt_message, @@ -257,7 +266,7 @@ def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: st """ load tool variables from database """ - tool_variables: ToolConversationVariables = ( + tool_variables: ToolConversationVariables | None = ( db.session.query(ToolConversationVariables) .filter( ToolConversationVariables.conversation_id == conversation_id, diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 629c309c0..ce331d904 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = 
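The agent runner hunks add early guard clauses (`if not agent_entity: raise ...`, `if conversation_result is None: raise ...`) so the rest of the function works with non-Optional values. A tiny sketch of the narrowing effect; `AgentEntity` and `AppConfig` are placeholders:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class AgentEntity:
        strategy: str


    @dataclass
    class AppConfig:
        agent: Optional[AgentEntity]


    def run(app_config: AppConfig) -> str:
        agent_entity = app_config.agent
        if not agent_entity:
            # Raising here narrows agent_entity to AgentEntity for the rest of
            # the function, instead of sprinkling Optional checks below.
            raise ValueError("Agent entity not found")
        return agent_entity.strategy


    print(run(AppConfig(agent=AgentEntity(strategy="function_calling"))))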
ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,8 +51,9 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR return response @classmethod - def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_full_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream full response. @@ -82,8 +83,9 @@ def convert_stream_full_response( yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + def convert_stream_simple_response( # type: ignore[override] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], ) -> Generator[str, None, None]: """ Convert stream simple response. diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 3725c6e6d..1842fc430 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -50,7 +50,7 @@ def listen(self): # wait for APP_MAX_EXECUTION_TIME seconds to stop listen listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() - last_ping_time = 0 + last_ping_time: int | float = 0 while True: try: message = self._q.get(timeout=1) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 609fd03f2..07a248d77 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,5 +1,5 @@ import time -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity @@ -36,8 +36,8 @@ def get_pre_calculate_rest_tokens( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, ) -> int: """ @@ -64,7 +64,7 @@ def get_pre_calculate_rest_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -85,7 +85,7 @@ def get_pre_calculate_rest_tokens( prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - rest_tokens = model_context_tokens - max_tokens - prompt_tokens + rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: raise InvokeBadRequestError( "Query or prefix 
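The converter overrides are annotated with `# type: ignore[override]` because they deliberately narrow the parameter type from the base class, which the checker flags as a Liskov-substitution violation. A compact illustration with stand-in classes:

    class BaseResponse:
        pass


    class ChatResponse(BaseResponse):
        answer: str = "hi"


    class BaseConverter:
        @classmethod
        def convert_blocking_full_response(cls, blocking_response: BaseResponse) -> dict:
            return {}


    class ChatConverter(BaseConverter):
        @classmethod
        def convert_blocking_full_response(cls, blocking_response: ChatResponse) -> dict:  # type: ignore[override]
            # Narrowing the parameter type breaks substitutability in the
            # checker's eyes, so the override is silenced explicitly rather
            # than widened, matching how the converters are actually called.
            return {"answer": blocking_response.answer}


    print(ChatConverter.convert_blocking_full_response(ChatResponse()))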
prompt is too long, you can reduce the prefix prompt, " @@ -111,7 +111,7 @@ def recalc_llm_max_tokens( ): max_tokens = ( model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 if model_context_tokens is None: @@ -136,8 +136,8 @@ def organize_prompt_messages( app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["File"], + inputs: Mapping[str, str], + files: Sequence["File"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None, @@ -156,6 +156,7 @@ def organize_prompt_messages( """ # get prompt without memory and context if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + prompt_transform: Union[SimplePromptTransform, AdvancedPromptTransform] prompt_transform = SimplePromptTransform() prompt_messages, stop = prompt_transform.get_prompt( app_mode=AppMode.value_of(app_record.mode), @@ -171,8 +172,11 @@ def organize_prompt_messages( memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) + prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template + if not advanced_completion_prompt_template: + raise InvokeBadRequestError("Advanced completion prompt template is required.") prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: @@ -181,6 +185,8 @@ def organize_prompt_messages( assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: + if not prompt_template_entity.advanced_chat_prompt_template: + raise InvokeBadRequestError("Advanced chat prompt template is required.") prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) @@ -246,7 +252,7 @@ def direct_output( def _handle_invoke_result( self, - invoke_result: Union[LLMResult, Generator], + invoke_result: Union[LLMResult, Generator[Any, None, None]], queue_manager: AppQueueManager, stream: bool, agent: bool = False, @@ -259,10 +265,12 @@ def _handle_invoke_result( :param agent: agent :return: """ - if not stream: + if not stream and isinstance(invoke_result, LLMResult): self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - else: + elif stream and isinstance(invoke_result, Generator): self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) + else: + raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") def _handle_invoke_result_direct( self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool @@ -291,8 +299,8 @@ def _handle_invoke_result_stream( :param agent: agent :return: """ - model = None - prompt_messages = [] + model: str = "" + prompt_messages: list[PromptMessage] = [] text = "" usage = None for result in invoke_result: @@ -328,13 +336,14 @@ def _handle_invoke_result_stream( def moderation_for_inputs( self, + *, app_id: str, tenant_id: str, app_generate_entity: AppGenerateEntity, inputs: Mapping[str, Any], - query: str, + query: str 
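`_handle_invoke_result` now branches on both the `stream` flag and the runtime type, and raises `NotImplementedError` for anything unexpected instead of silently assuming one shape. A small sketch of that dispatch; `LLMResult` here is a bare placeholder class:

    from collections.abc import Generator
    from typing import Union


    class LLMResult:
        text = "blocking result"


    def handle(invoke_result: Union[LLMResult, Generator[str, None, None]], stream: bool) -> list[str]:
        # Branch on both the flag and the runtime type so each branch sees a
        # precisely typed value; anything else is an explicit error.
        if not stream and isinstance(invoke_result, LLMResult):
            return [invoke_result.text]
        elif stream and isinstance(invoke_result, Generator):
            return [chunk for chunk in invoke_result]
        else:
            raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")


    def chunks() -> Generator[str, None, None]:
        yield "a"
        yield "b"


    print(handle(LLMResult(), stream=False))
    print(handle(chunks(), stream=True))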
| None = None, message_id: str, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. :param app_id: app id @@ -350,7 +359,7 @@ def moderation_for_inputs( app_id=app_id, tenant_id=tenant_id, app_config=app_generate_entity.app_config, - inputs=inputs, + inputs=dict(inputs), query=query or "", message_id=message_id, trace_manager=app_generate_entity.trace_manager, @@ -390,9 +399,9 @@ def fill_in_inputs_from_external_data_tools( tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index edeb6f12b..bc8479603 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -18,13 +18,14 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models.account import Account from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制 from models.model import App, EndUser +from services.errors.message import MessageNotExistsError logger = logging.getLogger(__name__) @@ -92,7 +93,7 @@ def generate( # get conversation conversation = None if args.get("conversation_id"): - conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id", ""), user) # get app model config app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) @@ -105,7 +106,7 @@ def generate( # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # always enable retriever resource in debugger mode @@ -147,7 +148,7 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, invoke_from=invoke_from, @@ -181,7 +182,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -225,6 +226,8 @@ def _generate_worker( # get conversation and message conversation = self._get_conversation(conversation_id) message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError("Message not exists") # chatbot app runner = ChatAppRunner() @@ -243,7 +246,7 @@ def _generate_worker( except ValidationError as e: logger.exception("Validation Error when 
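The moderation signature becomes keyword-only and accepts a `Mapping`, converting to a concrete `dict` only at the point where one is needed. A brief sketch of that boundary convention; the body is illustrative, not the real moderation logic:

    from collections.abc import Mapping
    from typing import Any


    def moderation_for_inputs(
        *,
        app_id: str,
        inputs: Mapping[str, Any],
        query: str | None = None,
    ) -> tuple[bool, Mapping[str, Any], str]:
        # Keyword-only arguments keep long call sites readable and safe to reorder;
        # Mapping in the signature signals the caller's dict will not be mutated.
        checked_inputs = dict(inputs)  # copy at the boundary before any local changes
        flagged = "forbidden" in (query or "")
        return flagged, checked_inputs, query or ""


    print(moderation_for_inputs(app_id="app-1", inputs={"name": "x"}, query=None))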
generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: + except ValueError as e: if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 0fa7af0a7..9024c3a98 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = ChatbotAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -52,7 +52,8 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR @classmethod def convert_stream_full_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -83,7 +84,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + cls, + stream_response: Generator[ChatbotAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. 
diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index 1193c4b7a..02e5d4756 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -42,7 +42,7 @@ def get_app_config( app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict + config_dict = override_config_dict or {} app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index a882e0971..d4c068d77 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -17,7 +17,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory @@ -84,8 +84,6 @@ def generate( query = query.replace("\x00", "") inputs = args["inputs"] - extras = {} - # get conversation conversation = None @@ -100,7 +98,7 @@ def generate( # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, config=args.get("model_config", {}) ) # parse files @@ -133,11 +131,11 @@ def generate( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), query=query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=streaming, invoke_from=invoke_from, - extras=extras, + extras={}, trace_manager=trace_manager, ) @@ -166,7 +164,7 @@ def generate( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, @@ -206,6 +204,8 @@ def _generate_worker( try: # get message message = self._get_message(message_id) + if message is None: + raise MessageNotExistsError() # chatbot app runner = CompletionAppRunner() @@ -223,7 +223,7 @@ def _generate_worker( except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: + except ValueError as e: if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) @@ -240,7 +240,7 @@ def generate_more_like_this( user: Union[Account, EndUser], invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[str, None, None]]: + ) -> Union[Mapping[str, Any], Generator[str, None, None]]: """ Generate App response. 
@@ -302,7 +302,7 @@ def generate_more_like_this( model_conf=ModelConfigConverter.convert(app_config), inputs=message.inputs, query=message.query, - files=file_objs, + files=list(file_objs), user_id=user.id, stream=stream, invoke_from=invoke_from, @@ -326,7 +326,7 @@ def generate_more_like_this( worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "message_id": message.id, diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 908d74ff5..41278b75b 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -76,7 +76,7 @@ def run( tenant_id=app_config.tenant_id, app_generate_entity=application_generate_entity, inputs=inputs, - query=query, + query=query or "", message_id=message.id, ) except ModerationError as e: @@ -122,7 +122,7 @@ def run( tenant_id=app_record.tenant_id, model_config=application_generate_entity.model_conf, config=dataset_config, - query=query, + query=query or "", invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 697f0273a..73f38c3d0 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = CompletionAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response @@ -36,7 +36,7 @@ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlocking return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -51,7 +51,8 @@ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlocki @classmethod def convert_stream_full_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -81,7 +82,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + cls, + stream_response: Generator[CompletionAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. 
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 95ae798ec..4e3aa840c 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -2,11 +2,11 @@ import logging from collections.abc import Generator from datetime import UTC, datetime -from typing import Optional, Union +from typing import Optional, Union, cast from sqlalchemy import and_ -from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( @@ -42,7 +42,7 @@ def _handle_response( ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity, + AgentChatAppGenerateEntity, ], queue_manager: AppQueueManager, conversation: Conversation, @@ -70,14 +70,13 @@ def _handle_response( queue_manager=queue_manager, conversation=conversation, message=message, - user=user, stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception(f"Failed to handle response, conversation_id: {conversation.id}") @@ -144,7 +143,7 @@ def _init_generate_records( :conversation conversation :return: """ - app_config = application_generate_entity.app_config + app_config: EasyUIBasedAppConfig = cast(EasyUIBasedAppConfig, application_generate_entity.app_config) # get from source end_user_id = None @@ -267,7 +266,7 @@ def _get_conversation_introduction(self, application_generate_entity: AppGenerat except KeyError: pass - return introduction + return introduction or "" def _get_conversation(self, conversation_id: str): """ @@ -282,7 +281,7 @@ def _get_conversation(self, conversation_id: str): return conversation - def _get_message(self, message_id: str) -> Message: + def _get_message(self, message_id: str) -> Optional[Message]: """ Get message by message id :param message_id: message id diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index a7f068616..4d1cad241 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -20,7 +20,7 @@ from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory @@ -124,7 +124,7 @@ def generate( inputs=self._prepare_user_inputs( user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id ), - files=system_files, + files=list(system_files), user_id=user.id, stream=streaming, invoke_from=invoke_from, @@ -230,6 +230,7 @@ def single_iteration_generate( 
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args["inputs"] ), + workflow_run_id=str(uuid.uuid4()), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -279,7 +280,7 @@ def _generate_worker( except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - except (ValueError, InvokeError) as e: + except ValueError as e: if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) @@ -318,7 +319,7 @@ def _handle_response( try: return generate_task_pipeline.process() except ValueError as e: - if e.args[0] == "I/O operation on closed file.": # ignore this error + if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error raise GenerateTaskStoppedError() else: logger.exception( diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 08d00ee18..5cdac6ad2 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -17,16 +17,16 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): _blocking_response_type = WorkflowAppBlockingResponse @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.to_dict() + return dict(blocking_response.to_dict()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] """ Convert blocking simple response. :param blocking_response: blocking response @@ -36,7 +36,8 @@ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlocking @classmethod def convert_stream_full_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream full response. @@ -65,7 +66,8 @@ def convert_stream_full_response( @classmethod def convert_stream_simple_response( - cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + cls, + stream_response: Generator[WorkflowAppStreamResponse, None, None], # type: ignore[override] ) -> Generator[str, None, None]: """ Convert stream simple response. 
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 90a17236b..62c495fb0 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,7 +1,9 @@ import logging import time from collections.abc import Generator -from typing import Any, Optional, Union +from typing import Optional, Union + +from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk @@ -50,13 +52,13 @@ from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制 from models.model import AppMode, EndUser # 二开部分End - 密钥额度限制,新增AppMode from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus, ) @@ -64,18 +66,11 @@ logger = logging.getLogger(__name__) -class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleManage): +class WorkflowAppGenerateTaskPipeline: """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: WorkflowTaskState - _application_generate_entity: WorkflowAppGenerateEntity - _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - def __init__( self, application_generate_entity: WorkflowAppGenerateEntity, @@ -84,44 +79,47 @@ def __init__( user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. 
- :param application_generate_entity: application generate entity - :param workflow: workflow - :param queue_manager: queue manager - :param user: user - :param stream: is streamed - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + self._base_task_pipeline = BasedGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) - if isinstance(self._user, EndUser): - user_id = self._user.session_id + if isinstance(user, EndUser): + self._user_id = user.id + user_session_id = user.session_id + self._created_by_role = CreatedByRole.END_USER + elif isinstance(user, Account): + self._user_id = user.id + user_session_id = user.id + self._created_by_role = CreatedByRole.ACCOUNT else: - user_id = self._user.id - - self._workflow = workflow - self._workflow_system_variables = { - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, - } + raise ValueError(f"Invalid user type: {type(user)}") + + self._workflow_cycle_manager = WorkflowCycleManage( + application_generate_entity=application_generate_entity, + workflow_system_variables={ + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_session_id, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, + }, + ) + self._application_generate_entity = application_generate_entity + self._workflow_id = workflow.id + self._workflow_features_dict = workflow.features_dict self._task_state = WorkflowTaskState() - self._wip_workflow_node_executions = {} + self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Process generate task pipeline. 
:return: """ - db.session.refresh(self._workflow) - db.session.refresh(self._user) - db.session.close() - generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) - if self._stream: + if self._base_task_pipeline._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) @@ -186,7 +184,7 @@ def _wrapper_process_stream_response( tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id - features_dict = self._workflow.features_dict + features_dict = self._workflow_features_dict if ( features_dict.get("text_to_speech") @@ -235,197 +233,297 @@ def _process_stream_response( :return: """ graph_runtime_state = None - workflow_run = None - for queue_message in self._queue_manager.listen(): + for queue_message in self._base_task_pipeline._queue_manager.listen(): event = queue_message.event if isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self._base_task_pipeline._ping_stream_response() elif isinstance(event, QueueErrorEvent): - err = self._handle_error(event) - yield self._error_to_stream_response(err) + err = self._base_task_pipeline._handle_error(event=event) + yield self._base_task_pipeline._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): # override graph runtime state graph_runtime_state = event.graph_runtime_state - # init workflow run - workflow_run = self._handle_workflow_run_start() - - # ------------------- 二开部分Begin - 密钥额度限制 ------------------- - app_token_id = self._application_generate_entity.extras.get("app_token_id") - if app_token_id: - ApiTokenMessageJoinsExtend( - app_token_id=app_token_id, record_id=workflow_run.id, app_mode=AppMode.WORKFLOW.value - ).add_app_token_record_id() - # ------------------- 二开部分End - 密钥额度限制 ------------------- - - yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine, expire_on_commit=False) as session: + # init workflow run + workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( + session=session, + workflow_id=self._workflow_id, + user_id=self._user_id, + created_by_role=self._created_by_role, + ) + + # ------------------- 二开部分Begin - 密钥额度限制 ------------------- + app_token_id = self._application_generate_entity.extras.get("app_token_id") + if app_token_id: + ApiTokenMessageJoinsExtend( + app_token_id=app_token_id, record_id=workflow_run.id, app_mode=AppMode.WORKFLOW.value + ).add_app_token_record_id() + # ------------------- 二开部分End - 密钥额度限制 ------------------- + + self._workflow_run_id = workflow_run.id + start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield start_resp elif isinstance( event, QueueNodeRetryEvent, ): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - - response = self._workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = 
self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( + session=session, workflow_run=workflow_run, event=event + ) + response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if response: yield response elif isinstance(event, QueueNodeStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - - node_start_response = self._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( + session=session, workflow_run=workflow_run, event=event + ) + node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._handle_workflow_node_execution_success(event) - - node_success_response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + session=session, event=event + ) + node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() if node_success_response: yield node_success_response elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): - workflow_node_execution = self._handle_workflow_node_execution_failed(event) + with Session(db.engine, expire_on_commit=False) as session: + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + session=session, + event=event, + ) + node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + session=session, + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + session.commit() - node_failed_response = self._workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) if node_failed_response: yield node_failed_response elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield 
self._workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + parallel_start_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + ) + + yield parallel_start_resp + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + parallel_finish_resp = ( + self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + ) + + yield parallel_finish_resp + elif isinstance(event, QueueIterationStartEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_start_resp + elif isinstance(event, QueueIterationNextEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_next_resp + elif isinstance(event, QueueIterationCompletedEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - yield self._workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._get_workflow_run( + session=session, workflow_run_id=self._workflow_run_id + ) + iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( + session=session, + 
task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + event=event, + ) + + yield iter_finish_resp + elif isinstance(event, QueueWorkflowSucceededEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, - ) - - # save workflow app log - self._save_workflow_app_log(workflow_run) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_run=workflow_run, + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_partial_success( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) - - # save workflow app log - self._save_workflow_app_log(workflow_run) - - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): - if not workflow_run: + if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") - 
workflow_run = self._handle_workflow_run_failed( - workflow_run=workflow_run, - start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - ) - - # save workflow app log - self._save_workflow_app_log(workflow_run) - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( + session=session, + workflow_run_id=self._workflow_run_id, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + ) + + # save workflow app log + self._save_workflow_app_log(session=session, workflow_run=workflow_run) + + workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + session.commit() + + yield workflow_finish_resp elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: @@ -445,7 +543,7 @@ def _process_stream_response( if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: + def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: """ Save workflow app log. 
:return: @@ -467,12 +565,10 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.created_from = created_from.value - workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" - workflow_app_log.created_by = self._user.id + workflow_app_log.created_by_role = self._created_by_role + workflow_app_log.created_by = self._user_id - db.session.add(workflow_app_log) - db.session.commit() - db.session.close() + session.add(workflow_app_log) def _text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 885283504..63f516bcc 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -24,6 +24,7 @@ QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, @@ -190,16 +191,15 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, NodeRunRetryEvent): node_run_result = event.route_node_state.node_run_result + inputs: Mapping[str, Any] | None = {} + process_data: Mapping[str, Any] | None = {} + outputs: Mapping[str, Any] | None = {} + execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {} if node_run_result: inputs = node_run_result.inputs process_data = node_run_result.process_data outputs = node_run_result.outputs execution_metadata = node_run_result.metadata - else: - inputs = {} - process_data = {} - outputs = {} - execution_metadata = {} self._publish_event( QueueNodeRetryEvent( node_execution_id=event.id, @@ -289,7 +289,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, error=event.route_node_state.node_run_result.error @@ -349,7 +349,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) process_data=event.route_node_state.node_run_result.process_data if event.route_node_state.node_run_result else {}, - outputs=event.route_node_state.node_run_result.outputs + outputs=event.route_node_state.node_run_result.outputs or {} if event.route_node_state.node_run_result else {}, execution_metadata=event.route_node_state.node_run_result.metadata diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 31c3a996e..7cb4e5903 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL -from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig +from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration 
import ProviderModelBundle from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity @@ -79,7 +79,7 @@ class AppGenerateEntity(BaseModel): task_id: str # app config - app_config: AppConfig + app_config: Any file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] @@ -195,7 +195,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): # app config app_config: WorkflowUIBasedAppConfig - workflow_run_id: Optional[str] = None + workflow_run_id: str class SingleIterationRunEntity(BaseModel): """ diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d73c2eb53..a93e533ff 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -308,7 +308,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None error: Optional[str] = None """single iteration duration map""" diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index dd088a897..5e845eba2 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -70,7 +70,7 @@ class StreamResponse(BaseModel): event: StreamEvent task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) @@ -474,8 +474,8 @@ class Data(BaseModel): title: str created_at: int extras: dict = {} - metadata: dict = {} - inputs: dict = {} + metadata: Mapping = {} + inputs: Mapping = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None @@ -526,15 +526,15 @@ class Data(BaseModel): node_id: str node_type: str title: str - outputs: Optional[dict] = None + outputs: Optional[Mapping] = None created_at: int extras: Optional[dict] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float total_tokens: int - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping] = None finished_at: int steps: int parallel_id: Optional[str] = None @@ -628,7 +628,7 @@ class AppBlockingResponse(BaseModel): task_id: str - def to_dict(self) -> dict: + def to_dict(self): return jsonable_encoder(self) diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 77b6bb554..83fd3deba 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -58,7 +58,7 @@ def query( query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) - if documents: + if documents and documents[0].metadata: annotation_id = documents[0].metadata["annotation_id"] score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 8fe1d96b3..dcc2b4e55 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -17,7 +17,7 @@ class RateLimit: _UNLIMITED_REQUEST_ID = "unlimited_request_id" _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes 
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes - _instance_dict = {} + _instance_dict: dict[str, "RateLimit"] = {} def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 51d610e2c..a2e06d4e1 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,9 @@ import logging import time -from typing import Optional, Union +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ( @@ -12,14 +15,11 @@ from core.app.entities.task_entities import ( ErrorStreamResponse, PingStreamResponse, - TaskState, ) from core.errors.error import QuotaExceededError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, Message +from models.model import Message logger = logging.getLogger(__name__) @@ -29,39 +29,22 @@ class BasedGenerateTaskPipeline: BasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: TaskState - _application_generate_entity: AppGenerateEntity - def __init__( self, application_generate_entity: AppGenerateEntity, queue_manager: AppQueueManager, - user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. - :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param user: user - :param stream: stream - """ self._application_generate_entity = application_generate_entity self._queue_manager = queue_manager - self._user = user self._start_at = time.perf_counter() self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): - """ - Handle error event. 
- :param event: event - :param message: message - :return: - """ + def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""): logger.debug("error: %s", event.error) e = event.error + err: Exception if isinstance(e, InvokeAuthorizationError): err = InvokeAuthorizationError("Incorrect API key provided") @@ -70,16 +53,17 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non else: err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) - if message: - refetch_message = db.session.query(Message).filter(Message.id == message.id).first() - - if refetch_message: - err_desc = self._error_to_desc(err) - refetch_message.status = "error" - refetch_message.error = err_desc + if not message_id or not session: + return err - db.session.commit() + stmt = select(Message).where(Message.id == message_id) + message = session.scalar(stmt) + if not message: + return err + err_desc = self._error_to_desc(err) + message.status = "error" + message.error = err_desc return err def _error_to_desc(self, e: Exception) -> str: @@ -130,6 +114,7 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), queue_manager=self._queue_manager, ) + return None def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: """ diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 4216cd46c..c84f8ba3e 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,8 +2,12 @@ import logging import time from collections.abc import Generator +from threading import Thread from typing import Optional, Union, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -54,8 +58,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db -from models.account import Account -from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought +from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -76,23 +79,21 @@ def __init__( queue_manager: AppQueueManager, conversation: Conversation, message: Message, - user: Union[Account, EndUser], stream: bool, ) -> None: - """ - Initialize GenerateTaskPipeline. 
- :param application_generate_entity: application generate entity - :param queue_manager: queue manager - :param conversation: conversation - :param message: message - :param user: user - :param stream: stream - """ - super().__init__(application_generate_entity, queue_manager, user, stream) + super().__init__( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + stream=stream, + ) self._model_config = application_generate_entity.model_conf self._app_config = application_generate_entity.app_config - self._conversation = conversation - self._message = message + + self._conversation_id = conversation.id + self._conversation_mode = conversation.mode + + self._message_id = message.id + self._message_created_at = int(message.created_at.timestamp()) self._task_state = EasyUITaskState( llm_result=LLMResult( @@ -103,7 +104,7 @@ def __init__( ) ) - self._conversation_name_generate_thread = None + self._conversation_name_generate_thread: Optional[Thread] = None def process( self, @@ -112,18 +113,10 @@ def process( CompletionAppBlockingResponse, Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: - """ - Process generate task pipeline. - :return: - """ - db.session.refresh(self._conversation) - db.session.refresh(self._message) - db.session.close() - if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, self._application_generate_entity.query + conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) @@ -146,16 +139,16 @@ def _to_blocking_response( extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: extras["metadata"] = self._task_state.metadata - - if self._conversation.mode == AppMode.COMPLETION.value: + response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] + if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=CompletionAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - message_id=self._message.id, - answer=self._task_state.llm_result.message.content, - created_at=int(self._message.created_at.timestamp()), + id=self._message_id, + mode=self._conversation_mode, + message_id=self._message_id, + answer=cast(str, self._task_state.llm_result.message.content), + created_at=self._message_created_at, **extras, ), ) @@ -163,12 +156,12 @@ def _to_blocking_response( response = ChatbotAppBlockingResponse( task_id=self._application_generate_entity.task_id, data=ChatbotAppBlockingResponse.Data( - id=self._message.id, - mode=self._conversation.mode, - conversation_id=self._conversation.id, - message_id=self._message.id, - answer=self._task_state.llm_result.message.content, - created_at=int(self._message.created_at.timestamp()), + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, + answer=cast(str, self._task_state.llm_result.message.content), + created_at=self._message_created_at, **extras, ), ) @@ -177,7 +170,7 @@ def _to_blocking_response( else: continue - raise Exception("Queue listening stopped unexpectedly.") + raise RuntimeError("queue listening 
stopped unexpectedly.") def _to_stream_response( self, generator: Generator[StreamResponse, None, None] @@ -189,15 +182,15 @@ def _to_stream_response( for stream_response in generator: if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): yield CompletionAppStreamResponse( - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( - conversation_id=self._conversation.id, - message_id=self._message.id, - created_at=int(self._message.created_at.timestamp()), + conversation_id=self._conversation_id, + message_id=self._message_id, + created_at=self._message_created_at, stream_response=stream_response, ) @@ -252,7 +245,7 @@ def _wrapper_process_stream_response( yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None + self, publisher: Optional[AppGeneratorTTSPublisher], trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. @@ -264,27 +257,32 @@ def _process_stream_response( event = message.event if isinstance(event, QueueErrorEvent): - err = self._handle_error(event, self._message) + with Session(db.engine) as session: + err = self._handle_error(event=event, session=session, message_id=self._message_id) + session.commit() yield self._error_to_stream_response(err) break elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): - self._task_state.llm_result = event.llm_result + if event.llm_result: + self._task_state.llm_result = event.llm_result else: self._handle_stop(event) # handle output moderation output_moderation_answer = self._handle_output_moderation_when_task_finished( - self._task_state.llm_result.message.content + cast(str, self._task_state.llm_result.message.content) ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer yield self._message_replace_to_stream_response(answer=output_moderation_answer) - # Save message - self._save_message(trace_manager) - - yield self._message_end_to_stream_response() + with Session(db.engine) as session: + # Save message + self._save_message(session=session, trace_manager=trace_manager) + session.commit() + message_end_resp = self._message_end_to_stream_response() + yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): @@ -292,7 +290,9 @@ def _process_stream_response( if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): - yield self._agent_thought_to_stream_response(event) + agent_thought_response = self._agent_thought_to_stream_response(event) + if agent_thought_response is not None: + yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): response = self._message_file_to_stream_response(event) if response: @@ -307,16 +307,24 @@ def _process_stream_response( self._task_state.llm_result.prompt_messages = chunk.prompt_messages # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) + should_direct_answer = self._handle_output_moderation_chunk(cast(str, delta_text)) if should_direct_answer: continue - 
self._task_state.llm_result.message.content += delta_text + current_content = cast(str, self._task_state.llm_result.message.content) + current_content += cast(str, delta_text) + self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) else: - yield self._agent_message_to_stream_response(delta_text, self._message.id) + yield self._agent_message_to_stream_response( + answer=cast(str, delta_text), + message_id=self._message_id, + ) elif isinstance(event, QueueMessageReplaceEvent): yield self._message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): @@ -328,7 +336,7 @@ def _process_stream_response( if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: + def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -336,46 +344,46 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No llm_result = self._task_state.llm_result usage = llm_result.usage - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() - self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() + message_stmt = select(Message).where(Message.id == self._message_id) + message = session.scalar(message_stmt) + if not message: + raise ValueError(f"message {self._message_id} not found") + conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id) + conversation = session.scalar(conversation_stmt) + if not conversation: + raise ValueError(f"Conversation {self._conversation_id} not found") - self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( self._model_config.mode, self._task_state.llm_result.prompt_messages ) - self._message.message_tokens = usage.prompt_tokens - self._message.message_unit_price = usage.prompt_unit_price - self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = ( - PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + message.message_tokens = usage.prompt_tokens + message.message_unit_price = usage.prompt_unit_price + message.message_price_unit = usage.prompt_price_unit + message.answer = ( + PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) if llm_result.message.content else "" ) - self._message.answer_tokens = usage.completion_tokens - self._message.answer_unit_price = usage.completion_unit_price - self._message.answer_price_unit = usage.completion_price_unit - self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.total_price = usage.total_price - self._message.currency = usage.currency - self._message.message_metadata = ( + message.answer_tokens = usage.completion_tokens + message.answer_unit_price = usage.completion_unit_price + message.answer_price_unit = usage.completion_price_unit + message.provider_response_latency = time.perf_counter() - self._start_at + message.total_price = usage.total_price + message.currency = usage.currency + message.message_metadata = ( 
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None ) - db.session.commit() - if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id ) ) message_was_created.send( - self._message, + message, application_generate_entity=self._application_generate_entity, - conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} - and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -420,7 +428,9 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, id=self._message.id, **extras + task_id=self._application_generate_entity.task_id, + id=self._message_id, + metadata=extras.get("metadata", {}), ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -440,7 +450,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op :param event: agent thought event :return: """ - agent_thought: MessageAgentThought = ( + agent_thought: Optional[MessageAgentThought] = ( db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py new file mode 100644 index 000000000..e4b4168d0 --- /dev/null +++ b/api/core/app/task_pipeline/exc.py @@ -0,0 +1,17 @@ +class TaskPipelineError(ValueError): + pass + + +class RecordNotFoundError(TaskPipelineError): + def __init__(self, record_name: str, record_id: str): + super().__init__(f"{record_name} with id {record_id} not found") + + +class WorkflowRunNotFoundError(RecordNotFoundError): + def __init__(self, workflow_run_id: str): + super().__init__("WorkflowRun", workflow_run_id) + + +class WorkflowNodeExecutionNotFoundError(RecordNotFoundError): + def __init__(self, workflow_node_execution_id: str): + super().__init__("WorkflowNodeExecution", workflow_node_execution_id) diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index e818a090e..6a4ab259b 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -31,12 +31,21 @@ class MessageCycleManage: - _application_generate_entity: Union[ - ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity - ] - _task_state: Union[EasyUITaskState, WorkflowTaskState] - - def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: + def __init__( + self, + *, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + task_state: Union[EasyUITaskState, WorkflowTaskState], + ) -> None: + self._application_generate_entity = application_generate_entity + self._task_state = task_state + + def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation
name. :param conversation: conversation @@ -56,7 +65,7 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> target=self._generate_conversation_name_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore - "conversation_id": conversation.id, + "conversation_id": conversation_id, "query": query, }, ) @@ -128,7 +137,7 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti """ message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() - if message_file: + if message_file and message_file.url is not None: # get tool file id tool_file_id = message_file.url.split("/")[-1] # trim extension diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 5d26c25b0..d8017c39b 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast from uuid import uuid4 +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -33,7 +34,6 @@ ParallelBranchStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, ) from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder @@ -45,7 +45,6 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -61,30 +60,45 @@ update_account_money_when_workflow_node_execution_created_extend, # 二开部分End - 密钥额度限制 ) +from .exc import WorkflowRunNotFoundError + class WorkflowCycleManage: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: WorkflowTaskState - _workflow_system_variables: dict[SystemVariableKey, Any] - _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] - - def _handle_workflow_run_start(self) -> WorkflowRun: - max_sequence = ( - db.session.query(db.func.max(WorkflowRun.sequence_number)) - .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) - .filter(WorkflowRun.app_id == self._workflow.app_id) - .scalar() - or 0 + def __init__( + self, + *, + application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], + workflow_system_variables: dict[SystemVariableKey, Any], + ) -> None: + self._workflow_run: WorkflowRun | None = None + self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} + self._application_generate_entity = application_generate_entity + self._workflow_system_variables = workflow_system_variables + + def _handle_workflow_run_start( + self, + *, + session: Session, + workflow_id: str, + user_id: str, + created_by_role: CreatedByRole, + ) -> WorkflowRun: + workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) + workflow = session.scalar(workflow_stmt) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_id}") + + max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( + WorkflowRun.tenant_id == workflow.tenant_id, + WorkflowRun.app_id == 
workflow.app_id, ) + max_sequence = session.scalar(max_sequence_stmt) or 0 new_sequence_number = max_sequence + 1 inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): if key.value == "conversation": continue - inputs[f"sys.{key.value}"] = value triggered_from = ( @@ -94,37 +108,37 @@ def _handle_workflow_run_start(self) -> WorkflowRun: ) # handle special values - inputs = WorkflowEntry.handle_special_values(inputs) + inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = WorkflowRun() - system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] - workflow_run.id = system_id or str(uuid4()) - workflow_run.tenant_id = self._workflow.tenant_id - workflow_run.app_id = self._workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = self._workflow.id - workflow_run.type = self._workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = self._workflow.version - workflow_run.graph = self._workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = ( - CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER - ) - workflow_run.created_by = self._user.id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - - session.add(workflow_run) - session.commit() + # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this + workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) + + workflow_run = WorkflowRun() + workflow_run.id = workflow_run_id + workflow_run.tenant_id = workflow.tenant_id + workflow_run.app_id = workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = workflow.id + workflow_run.type = workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = workflow.version + workflow_run.graph = workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.created_by_role = created_by_role + workflow_run.created_by = user_id + workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) + + session.add(workflow_run) return workflow_run def _handle_workflow_run_success( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -142,7 +156,7 @@ def _handle_workflow_run_success( :param conversation_id: conversation id :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) outputs = WorkflowEntry.handle_special_values(outputs) @@ -153,9 +167,6 @@ def _handle_workflow_run_success( workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - db.session.refresh(workflow_run) - if trace_manager: trace_manager.add_trace_task( TraceTask( @@ -166,13 +177,13 @@ def _handle_workflow_run_success( ) ) - db.session.close() - return workflow_run def _handle_workflow_run_partial_success( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -181,19 +192,8 @@ def 
_handle_workflow_run_partial_success( conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: - """ - Workflow run success - :param workflow_run: workflow run - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param outputs: outputs - :param conversation_id: conversation id - :return: - """ - workflow_run = self._refetch_workflow_run(workflow_run.id) - - outputs = WorkflowEntry.handle_special_values(outputs) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value workflow_run.outputs = json.dumps(outputs or {}) @@ -202,8 +202,6 @@ def _handle_workflow_run_partial_success( workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - db.session.refresh(workflow_run) if trace_manager: trace_manager.add_trace_task( @@ -215,13 +213,13 @@ def _handle_workflow_run_partial_success( ) ) - db.session.close() - return workflow_run def _handle_workflow_run_failed( self, - workflow_run: WorkflowRun, + *, + session: Session, + workflow_run_id: str, start_at: float, total_tokens: int, total_steps: int, @@ -241,7 +239,7 @@ def _handle_workflow_run_failed( :param error: error message :return: """ - workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) workflow_run.status = status.value workflow_run.error = error @@ -250,35 +248,27 @@ def _handle_workflow_run_failed( workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - db.session.commit() - - running_workflow_node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, - ) - .all() + + stmt = select(WorkflowNodeExecution.node_execution_id).where( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, ) + ids = session.scalars(stmt).all() + # Use self._get_workflow_node_execution here to make sure the cache is updated + running_workflow_node_executions = [ + self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id + ] for workflow_node_execution in running_workflow_node_executions: + now = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error - workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.elapsed_time = ( - 
workflow_node_execution.finished_at - workflow_node_execution.created_at - ).total_seconds() - db.session.commit() - - db.session.close() - - # with Session(db.engine, expire_on_commit=False) as session: - # session.add(workflow_run) - # session.refresh(workflow_run) + workflow_node_execution.finished_at = now + workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() if trace_manager: trace_manager.add_trace_task( @@ -293,49 +283,43 @@ def _handle_workflow_run_failed( return workflow_run def _handle_node_execution_start( - self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: - # init workflow node execution - - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - } - ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - session.add(workflow_node_execution) - session.commit() - session.refresh(workflow_node_execution) + session.add(workflow_node_execution) - 
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution + self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution - def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - """ - Workflow node execution success - :param event: queue node succeeded event - :return: - """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) - + def _handle_workflow_node_execution_success( + self, *, session: Session, event: QueueNodeSucceededEvent + ) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) @@ -345,20 +329,6 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.execution_metadata: execution_metadata, - WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - } - ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value @@ -369,7 +339,7 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) + workflow_node_execution = session.merge(workflow_node_execution) # 二开部分Begin - 额度限制 workflow_node_execution_dict = jsonable_encoder(workflow_node_execution) # 转化为json字典 @@ -379,14 +349,19 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent return workflow_node_execution def _handle_workflow_node_execution_failed( - self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent + self, + *, + session: Session, + event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent, ) -> WorkflowNodeExecution: """ Workflow node execution failed :param event: queue node failed event :return: """ - workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + workflow_node_execution = self._get_workflow_node_execution( + session=session, node_execution_id=event.node_execution_id + ) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) @@ -396,25 +371,6 @@ def _handle_workflow_node_execution_failed( execution_metadata = ( json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None ) - 
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( - { - WorkflowNodeExecution.status: ( - WorkflowNodeExecutionStatus.FAILED.value - if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value - ), - WorkflowNodeExecution.error: event.error, - WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, - WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, - WorkflowNodeExecution.finished_at: finished_at, - WorkflowNodeExecution.elapsed_time: elapsed_time, - WorkflowNodeExecution.execution_metadata: execution_metadata, - } - ) - - db.session.commit() - db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = ( WorkflowNodeExecutionStatus.FAILED.value @@ -429,12 +385,11 @@ def _handle_workflow_node_execution_failed( workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata - self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) - + workflow_node_execution = session.merge(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_retried( - self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -458,6 +413,7 @@ def _handle_workflow_node_execution_retried( execution_metadata = json.dumps(merged_metadata) workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) workflow_node_execution.tenant_id = workflow_run.tenant_id workflow_node_execution.app_id = workflow_run.app_id workflow_node_execution.workflow_id = workflow_run.workflow_id @@ -480,10 +436,9 @@ def _handle_workflow_node_execution_retried( workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.index = event.node_run_index - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) + session.add(workflow_node_execution) + self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution ################################################# @@ -491,14 +446,14 @@ def _handle_workflow_node_execution_retried( ################################################# def _workflow_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - """ - Workflow start to stream response. 
- :param task_id: task id - :param workflow_run: workflow run - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return WorkflowStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -506,42 +461,38 @@ def _workflow_start_to_stream_response( id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, + inputs=dict(workflow_run.inputs_dict or {}), created_at=int(workflow_run.created_at.timestamp()), ), ) def _workflow_finish_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun + self, + *, + session: Session, + task_id: str, + workflow_run: WorkflowRun, ) -> WorkflowFinishStreamResponse: - """ - Workflow finish to stream response. - :param task_id: task id - :param workflow_run: workflow run - :return: - """ - # Attach WorkflowRun to an active session so "created_by_role" can be accessed. - workflow_run = db.session.merge(workflow_run) - - # Refresh to ensure any expired attributes are fully loaded - db.session.refresh(workflow_run) - created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: - created_by_account = workflow_run.created_by_account - if created_by_account: + if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + stmt = select(Account).where(Account.id == workflow_run.created_by) + account = session.scalar(stmt) + if account: created_by = { - "id": created_by_account.id, - "name": created_by_account.name, - "email": created_by_account.email, + "id": account.id, + "name": account.name, + "email": account.email, } - else: - created_by_end_user = workflow_run.created_by_end_user - if created_by_end_user: + elif workflow_run.created_by_role == CreatedByRole.END_USER: + stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) + end_user = session.scalar(stmt) + if end_user: created_by = { - "id": created_by_end_user.id, - "user": created_by_end_user.session_id, + "id": end_user.id, + "user": end_user.session_id, } + else: + raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") return WorkflowFinishStreamResponse( task_id=task_id, @@ -551,7 +502,7 @@ def _workflow_finish_to_stream_response( workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, status=workflow_run.status, - outputs=workflow_run.outputs_dict, + outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, error=workflow_run.error, elapsed_time=workflow_run.elapsed_time, total_tokens=workflow_run.total_tokens, @@ -559,23 +510,26 @@ def _workflow_finish_to_stream_response( created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), + files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), exceptions_count=workflow_run.exceptions_count, ), ) def _workflow_node_start_to_stream_response( - self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + self, + *, + session: Session, + event: QueueNodeStartedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeStartStreamResponse]: - """ - Workflow node start to stream response. 
- :param event: queue node started event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None response = NodeStartStreamResponse( task_id=task_id, @@ -611,6 +565,8 @@ def _workflow_node_start_to_stream_response( def _workflow_node_finish_to_stream_response( self, + *, + session: Session, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent @@ -618,15 +574,14 @@ def _workflow_node_finish_to_stream_response( task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: - """ - Workflow node finish to stream response. - :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeFinishStreamResponse( task_id=task_id, @@ -658,19 +613,20 @@ def _workflow_node_finish_to_stream_response( def _workflow_node_retry_to_stream_response( self, + *, + session: Session, event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeFinishStreamResponse]: - """ - Workflow node finish to stream response. 
- :param event: queue node succeeded or failed event - :param task_id: task id - :param workflow_node_execution: workflow node execution - :return: - """ + ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None + if not workflow_node_execution.workflow_run_id: + return None + if not workflow_node_execution.finished_at: + return None return NodeRetryStreamResponse( task_id=task_id, @@ -702,15 +658,10 @@ def _workflow_node_retry_to_stream_response( ) def _workflow_parallel_branch_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: - """ - Workflow parallel branch start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run started event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -726,17 +677,14 @@ def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_finished_to_stream_response( self, + *, + session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - """ - Workflow parallel branch finished to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: parallel branch run succeeded or failed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -753,15 +701,10 @@ def _workflow_parallel_branch_finished_to_stream_response( ) def _workflow_iteration_start_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent ) -> IterationNodeStartStreamResponse: - """ - Workflow iteration start to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration start event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -780,15 +723,10 @@ def _workflow_iteration_start_to_stream_response( ) def _workflow_iteration_next_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent ) -> IterationNodeNextStreamResponse: - """ - Workflow iteration next to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration next event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeNextStreamResponse( 
task_id=task_id, workflow_run_id=workflow_run.id, @@ -809,15 +747,10 @@ def _workflow_iteration_next_to_stream_response( ) def _workflow_iteration_completed_to_stream_response( - self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent ) -> IterationNodeCompletedStreamResponse: - """ - Workflow iteration completed to stream response - :param task_id: task id - :param workflow_run: workflow run - :param event: iteration completed event - :return: - """ + # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this + _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, workflow_run_id=workflow_run.id, @@ -844,7 +777,7 @@ def _workflow_iteration_completed_to_stream_response( ), ) - def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: + def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -857,9 +790,11 @@ def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping # Remove None files = [file for file in files if file] # Flatten list - files = [file for sublist in files for file in sublist] + # Flatten the list of sequences into a single list of mappings + flattened_files = [file for sublist in files if sublist for file in sublist] - return files + # Convert to tuple to match Sequence type + return tuple(flattened_files) def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: """ @@ -897,28 +832,23 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any elif isinstance(value, File): return value.to_dict() - def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: - """ - Refetch workflow run - :param workflow_run_id: workflow run id - :return: - """ - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + return None + def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: + if self._workflow_run and self._workflow_run.id == workflow_run_id: + cached_workflow_run = self._workflow_run + cached_workflow_run = session.merge(cached_workflow_run) + return cached_workflow_run + stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalar(stmt) if not workflow_run: - raise Exception(f"Workflow run not found: {workflow_run_id}") + raise WorkflowRunNotFoundError(workflow_run_id) + self._workflow_run = workflow_run return workflow_run - def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - """ - Refetch workflow node execution - :param node_execution_id: workflow node execution id - :return: - """ - workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) - - if not workflow_node_execution: - raise Exception(f"Workflow node execution not found: {node_execution_id}") - - return workflow_node_execution + def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: + if node_execution_id not in self._workflow_node_executions: + raise ValueError(f"Workflow node execution not found: {node_execution_id}") + cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] + return 
cached_workflow_node_execution diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index d826edf6a..effc7eff9 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -57,7 +57,7 @@ def on_tool_end( self, tool_name: str, tool_inputs: Mapping[str, Any], - tool_outputs: Sequence[ToolInvokeMessage], + tool_outputs: Sequence[ToolInvokeMessage] | str, message_id: Optional[str] = None, timer: Optional[Any] = None, trace_manager: Optional[TraceQueueManager] = None, diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 148157863..8f8aaa93d 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -40,17 +40,18 @@ def on_query(self, query: str, dataset_id: str) -> None: def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py new file mode 100644 index 000000000..90c987973 --- /dev/null +++ b/api/core/entities/knowledge_entities.py @@ -0,0 +1,19 @@ +from typing import Optional + +from pydantic import BaseModel + + +class PreviewDetail(BaseModel): + content: str + child_chunks: Optional[list[str]] = None + + +class QAPreviewDetail(BaseModel): + question: str + answer: str + + +class IndexingEstimate(BaseModel): + total_segments: int + preview: list[PreviewDetail] + qa_preview: Optional[list[QAPreviewDetail]] = None diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 9ed5528e4..501783556 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from enum import Enum from typing import Optional @@ -72,7 +73,7 @@ class DefaultModelProviderEntity(BaseModel): label: I18nObject icon_small: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None - supported_model_types: list[ModelType] + supported_model_types: Sequence[ModelType] = [] class DefaultModelEntity(BaseModel): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index c68678e68..fa1875041 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) 
-original_provider_configurate_methods = {} +original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {} class ProviderConfiguration(BaseModel): @@ -99,7 +99,8 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional continue restrict_models = quota_configuration.restrict_models - + if self.system_configuration.credentials is None: + return None copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: @@ -124,7 +125,7 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional return credentials - def get_system_configuration_status(self) -> SystemConfigurationStatus: + def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]: """ Get system configuration status. :return: @@ -136,6 +137,8 @@ def get_system_configuration_status(self) -> SystemConfigurationStatus: current_quota_configuration = next( (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) + if current_quota_configuration is None: + return None return ( SystemConfigurationStatus.ACTIVE @@ -150,7 +153,7 @@ def is_custom_configuration_available(self) -> bool: """ return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 - def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: + def get_custom_credentials(self, obfuscated: bool = False): """ Get custom credentials. @@ -172,7 +175,7 @@ def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: else [], ) - def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: + def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]: """ Validate custom credentials. :param credentials: provider credentials @@ -417,7 +420,7 @@ def get_custom_model_credentials( def custom_model_credentials_validate( self, model_type: ModelType, model: str, credentials: dict - ) -> tuple[ProviderModel, dict]: + ) -> tuple[Optional[ProviderModel], dict]: """ Validate custom model credentials. 
@@ -937,10 +940,10 @@ def get_provider_models( if model_type: model_types.append(model_type) else: - model_types = provider_instance.get_provider_schema().supported_model_types + model_types = list(provider_instance.get_provider_schema().supported_model_types) # Group model settings by model type and model - model_setting_map = defaultdict(dict) + model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) for model_setting in self.model_settings: model_setting_map[model_setting.model_type][model_setting.model] = model_setting @@ -1019,54 +1022,57 @@ def _get_system_provider_models( ]: # only customizable model for restrict_model in restrict_models: - copy_credentials = self.system_configuration.credentials.copy() - if restrict_model.base_model_name: - copy_credentials["base_model_name"] = restrict_model.base_model_name - - try: - custom_model_schema = provider_instance.get_model_instance( - restrict_model.model_type - ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) - except Exception as ex: - logger.warning(f"get custom model schema failed, {ex}") - continue - - if not custom_model_schema: - continue - - if custom_model_schema.model_type not in model_types: - continue - - status = ModelStatus.ACTIVE - if ( - custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] - ): - model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] - if model_setting.enabled is False: - status = ModelStatus.DISABLED - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=status, + if self.system_configuration.credentials is not None: + copy_credentials = self.system_configuration.credentials.copy() + if restrict_model.base_model_name: + copy_credentials["base_model_name"] = restrict_model.base_model_name + + try: + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) + except Exception as ex: + logger.warning(f"get custom model schema failed, {ex}") + continue + + if not custom_model_schema: + continue + + if custom_model_schema.model_type not in model_types: + continue + + status = ModelStatus.ACTIVE + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): + model_setting = model_setting_map[custom_model_schema.model_type][ + custom_model_schema.model + ] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + provider_models.append( + ModelWithProviderEntity( + model=custom_model_schema.model, + label=custom_model_schema.label, + model_type=custom_model_schema.model_type, + features=custom_model_schema.features, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties=custom_model_schema.model_properties, + deprecated=custom_model_schema.deprecated, + provider=SimpleModelProviderEntity(self.provider), + status=status, + ) ) - ) # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] 
- for m in provider_models: - if m.model_type == ModelType.LLM and m.model not in restrict_model_names: - m.status = ModelStatus.NO_PERMISSION + for model in provider_models: + if model.model_type == ModelType.LLM and model.model not in restrict_model_names: + model.status = ModelStatus.NO_PERMISSION elif not quota_configuration.is_valid: - m.status = ModelStatus.QUOTA_EXCEEDED + model.status = ModelStatus.QUOTA_EXCEEDED return provider_models @@ -1240,7 +1246,7 @@ def __iter__(self): return iter(self.configurations) def values(self) -> Iterator[ProviderConfiguration]: - return self.configurations.values() + return iter(self.configurations.values()) def get(self, key, default=None): return self.configurations.get(key, default) diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 38cebb6b6..3f4e20ec2 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,3 +1,5 @@ +from typing import cast + import requests from configs import dify_config @@ -5,7 +7,7 @@ class APIBasedExtensionRequestor: - timeout: (int, int) = (5, 60) + timeout: tuple[int, int] = (5, 60) """timeout for request connect and read""" def __init__(self, api_endpoint: str, api_key: str) -> None: @@ -51,4 +53,4 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) ) - return response.json() + return cast(dict, response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 97dbaf202..231743bf2 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -38,8 +38,8 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: @classmethod def scan_extensions(cls): - extensions: list[ModuleExtension] = [] - position_map = {} + extensions = [] + position_map: dict[str, int] = {} # get the path of the current class current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") @@ -58,7 +58,8 @@ def scan_extensions(cls): # is builtin extension, builtin extension # in the front-end page and business logic, there are special treatments. 
builtin = False - position = None + # default position is 0 can not be None for sort_to_dict_by_position_map + position = 0 if "__builtin__" in file_names: builtin = True @@ -89,7 +90,7 @@ def scan_extensions(cls): logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") continue - json_data = {} + json_data: dict[str, Any] = {} if not builtin: if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 3da170455..9eb9e0306 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -1,4 +1,6 @@ -from core.extension.extensible import ExtensionModule, ModuleExtension +from typing import cast + +from core.extension.extensible import Extensible, ExtensionModule, ModuleExtension from core.external_data_tool.base import ExternalDataTool from core.moderation.base import Moderation @@ -10,7 +12,8 @@ class Extension: def init(self): for module, module_class in self.module_classes.items(): - self.__module_extensions[module.value] = module_class.scan_extensions() + m = cast(Extensible, module_class) + self.__module_extensions[module.value] = m.scan_extensions() def module_extensions(self, module: str) -> list[ModuleExtension]: module_extensions = self.__module_extensions.get(module) @@ -35,7 +38,8 @@ def module_extension(self, module: ExtensionModule, extension_name: str) -> Modu def extension_class(self, module: ExtensionModule, extension_name: str) -> type: module_extension = self.module_extension(module, extension_name) - return module_extension.extension_class + t: type = module_extension.extension_class + return t def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None: module_extension = self.module_extension(module, extension_name) diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 54ec97a49..9989c8a09 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -48,7 +48,10 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: :return: the tool query result """ # get params from config + if not self.config: + raise ValueError("config is required, config: {}".format(self.config)) api_based_extension_id = self.config.get("api_based_extension_id") + assert api_based_extension_id is not None, "api_based_extension_id is required" # get api_based_extension api_based_extension = ( diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 84b94e117..6a9703a56 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -1,7 +1,7 @@ -import concurrent import logging -from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from collections.abc import Mapping +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from typing import Any, Optional from flask import Flask, current_app @@ -17,9 +17,9 @@ def fetch( tenant_id: str, app_id: str, external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, + inputs: Mapping[str, Any], query: str, - ) -> dict: + ) -> Mapping[str, Any]: """ Fill in variable inputs from external data tools if exists. 
@@ -30,13 +30,14 @@ def fetch( :param query: the query :return: the filled inputs """ - results = {} + results: dict[str, Any] = {} + inputs = dict(inputs) with ThreadPoolExecutor() as executor: futures = {} for tool in external_data_tools: - future = executor.submit( + future: Future[tuple[str | None, str | None]] = executor.submit( self._query_external_data_tool, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore tenant_id, app_id, tool, @@ -46,9 +47,10 @@ def fetch( futures[future] = tool - for future in concurrent.futures.as_completed(futures): + for future in as_completed(futures): tool_variable, result = future.result() - results[tool_variable] = result + if tool_variable is not None: + results[tool_variable] = result inputs.update(results) return inputs @@ -59,7 +61,7 @@ def _query_external_data_tool( tenant_id: str, app_id: str, external_data_tool: ExternalDataVariableEntity, - inputs: dict, + inputs: Mapping[str, Any], query: str, ) -> tuple[Optional[str], Optional[str]]: """ diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 287210985..245507e17 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -1,4 +1,5 @@ -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from core.extension.extensible import ExtensionModule from extensions.ext_code_based_extension import code_based_extension @@ -23,9 +24,10 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) - extension_class.validate_config(tenant_id, config) + # FIXME mypy issue here, figure out how to fix it + extension_class.validate_config(tenant_id, config) # type: ignore - def query(self, inputs: dict, query: Optional[str] = None) -> str: + def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str: """ Query the external data tool. 
@@ -33,4 +35,4 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: :param query: the query of chat app :return: the tool query result """ - return self.__extension_instance.query(inputs, query) + return cast(str, self.__extension_instance.query(inputs, query)) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 15eb351a7..4a50fb85c 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -1,4 +1,5 @@ import base64 +from collections.abc import Mapping from configs import dify_config from core.helper import ssrf_proxy @@ -55,7 +56,7 @@ def to_prompt_message_content( if f.type == FileType.IMAGE: params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - prompt_class_map = { + prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { FileType.IMAGE: ImagePromptMessageContent, FileType.AUDIO: AudioPromptMessageContent, FileType.VIDEO: VideoPromptMessageContent, @@ -63,7 +64,7 @@ def to_prompt_message_content( } try: - return prompt_class_map[f.type](**params) + return prompt_class_map[f.type].model_validate(params) except KeyError: raise ValueError(f"file type {f.type} is not supported") diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index a17b7be36..6fa101cf3 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from core.tools.tool_file_manager import ToolFileManager @@ -9,4 +9,4 @@ class ToolFileParser: @staticmethod def get_tool_file_manager() -> "ToolFileManager": - return tool_file_manager["manager"] + return cast("ToolFileManager", tool_file_manager["manager"]) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 584e3e969..15b501780 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -38,7 +38,7 @@ class CodeLanguage(StrEnum): class CodeExecutor: - dependencies_cache = {} + dependencies_cache: dict[str, str] = {} dependencies_cache_lock = Lock() code_template_transformers: dict[CodeLanguage, type[TemplateTransformer]] = { @@ -103,19 +103,19 @@ def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: ) try: - response = response.json() + response_data = response.json() except: raise CodeExecutionError("Failed to parse response") - if (code := response.get("code")) != 0: - raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") + if (code := response_data.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response_data.get('message')}") - response = CodeExecutionResponse(**response) + response_code = CodeExecutionResponse(**response_data) - if response.data.error: - raise CodeExecutionError(response.data.error) + if response_code.data.error: + raise CodeExecutionError(response_code.data.error) - return response.data.stdout or "" + return response_code.data.stdout or "" @classmethod def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index db2eb5ebb..264947b56 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -1,9 +1,11 @@ +from collections.abc import Mapping + from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: dict) -> str: + def format(cls, template: str, inputs: Mapping[str, str]) -> str: """ Format template :param template: template @@ -11,5 +13,4 @@ def format(cls, template: str, inputs: dict) -> str: :return: """ result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - - return result["result"] + return str(result.get("result", "")) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 605719747..baa792b5b 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -29,8 +29,7 @@ def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") - result = result.group(1) - return result + return result.group(1) @classmethod def transform_response(cls, response: str) -> Mapping[str, Any]: diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py index 518962c16..81501d2e4 100644 --- a/api/core/helper/lru_cache.py +++ b/api/core/helper/lru_cache.py @@ -4,7 +4,7 @@ class LRUCache: def __init__(self, capacity: int): - self.cache = OrderedDict() + self.cache: OrderedDict[Any, Any] = OrderedDict() self.capacity = capacity def get(self, key: Any) -> Any: diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 5e274f891..35349210b 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -30,7 +30,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index da0fd0031..543444463 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -22,6 +22,7 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) provider_name = model_config.provider if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: hosting_openai_config = hosting_configuration.provider_map["openai"] + assert hosting_openai_config is not None # 2000 text per chunk length = 2000 @@ -34,8 +35,9 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: model_type_instance = OpenAIModerationModel() + # FIXME, for type 
hints: would an assert or raising a ValueError be better here? moderation_result = model_type_instance.invoke( - model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials or {}, text=text_chunk ) if moderation_result is True: diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 1e2fefce8..9a041667e 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -14,12 +14,13 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz if existed_spec: spec = existed_spec if not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") else: # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly - spec = importlib.util.spec_from_file_location(module_name, py_file_path) + # FIXME: mypy does not support the type of spec.loader + spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore if not spec or not spec.loader: - raise Exception(f"Failed to load module {module_name} from {py_file_path}") + raise Exception(f"Failed to load module {module_name} from {py_file_path!r}") if use_lazy_loader: # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports spec.loader = importlib.util.LazyLoader(spec.loader) @@ -29,7 +30,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz spec.loader.exec_module(module) return module except Exception as e: - logging.exception(f"Failed to load module {module_name} from script file '{py_file_path}'") + logging.exception(f"Failed to load module {module_name} from script file {py_file_path!r}") raise e @@ -57,6 +58,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path!r}") case _: - raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path!r}") diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index e848b46c5..3b67b3f84 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -33,7 +33,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_tool_parameter + return dict(cached_tool_parameter) else: return None diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 94b02cf98..6de5e704a 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -28,7 +28,7 @@ def get(self) -> Optional[dict]: except JSONDecodeError: return None - return cached_provider_credentials + return dict(cached_provider_credentials) else: return None diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index b47ba67f2..f9fb7275f 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -42,7 +42,7 @@ class HostedModerationConfig(BaseModel): class HostingConfiguration: provider_map: dict[str, HostingProvider] = {} - moderation_config: HostedModerationConfig = None + moderation_config:
Optional[HostedModerationConfig] = None def init_app(self, app: Flask) -> None: if dify_config.EDITION != "CLOUD": @@ -67,7 +67,7 @@ def init_azure_openai() -> HostingProvider: "base_model_name": "gpt-35-turbo", } - quotas = [] + quotas: list[HostingQuota] = [] hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, @@ -123,7 +123,7 @@ def init_azure_openai() -> HostingProvider: def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT @@ -157,7 +157,7 @@ def init_openai(self) -> HostingProvider: @staticmethod def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - quotas = [] + quotas: list[HostingQuota] = [] if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT @@ -187,7 +187,7 @@ def init_anthropic() -> HostingProvider: def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_MINIMAX_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -205,7 +205,7 @@ def init_minimax() -> HostingProvider: def init_spark() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_SPARK_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, @@ -223,7 +223,7 @@ def init_spark() -> HostingProvider: def init_zhipuai() -> HostingProvider: quota_unit = QuotaUnit.TOKENS if dify_config.HOSTED_ZHIPUAI_ENABLED: - quotas = [FreeHostingQuota()] + quotas: list[HostingQuota] = [FreeHostingQuota()] return HostingProvider( enabled=True, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 29e161cb7..1bc4baf9c 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -6,36 +6,36 @@ import threading import time import uuid -from typing import Optional, cast +from typing import Any, Optional, cast -from flask import Flask, current_app -from flask_login import current_user +from flask import current_app +from flask_login import current_user # type: ignore from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config +from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail from core.errors.error import ProviderTokenNotInitError -from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter -from 
core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper -from models.dataset import Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.feature_service import FeatureService @@ -62,6 +62,8 @@ def run(self, dataset_documents: list[DatasetDocument]): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract @@ -113,6 +115,9 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): for document_segment in document_segments: db.session.delete(document_segment) + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + # delete child chunks + db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule processing_rule = ( @@ -120,6 +125,8 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) + if not processing_rule: + raise ValueError("no process rule found") index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -179,7 +186,22 @@ def run_in_indexing_status(self, dataset_document: DatasetDocument): "dataset_id": document_segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = document_segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_segment.document_id, + "dataset_id": document_segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # build index @@ -218,7 +240,7 @@ def indexing_estimate( doc_language: str = "English", dataset_id: Optional[str] = None, indexing_technique: str = "economy", - ) -> dict: + ) -> IndexingEstimate: """ Estimate the indexing for the document. 
""" @@ -254,38 +276,46 @@ def indexing_estimate( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts = [] + preview_texts = [] # type: ignore + total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - all_text_docs = [] for extract_setting in extract_settings: # extract - text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) - all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._split_to_documents_for_estimate( - text_docs=text_docs, splitter=splitter, processing_rule=processing_rule + text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule.to_dict(), + tenant_id=current_user.current_tenant_id, + doc_language=doc_language, + preview=True, ) - total_segments += len(documents) for document in documents: - if len(preview_texts) < 5: - preview_texts.append(document.page_content) + if len(preview_texts) < 10: + if doc_form and doc_form == "qa_model": + preview_detail = QAPreviewDetail( + question=document.page_content, answer=document.metadata.get("answer") or "" + ) + preview_texts.append(preview_detail) + else: + preview_detail = PreviewDetail(content=document.page_content) # type: ignore + if document.children: + preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore + preview_texts.append(preview_detail) # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() try: - storage.delete(image_file.key) + if image_file: + storage.delete(image_file.key) except Exception: logging.exception( "Delete image_files failed while indexing_estimate, \ @@ -294,15 +324,8 @@ def indexing_estimate( db.session.delete(image_file) if doc_form and doc_form == "qa_model": - if len(preview_texts) > 0: - # qa model document - response = LLMGenerator.generate_qa_document( - current_user.current_tenant_id, preview_texts[0], doc_language - ) - document_qa_list = self.format_split_text(response) - - return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} - return {"total_segments": total_segments, "preview": preview_texts} + return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict @@ -379,8 +402,9 @@ def _extract( # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata["document_id"] = dataset_document.id - text_doc.metadata["dataset_id"] = dataset_document.dataset_id + if text_doc.metadata is not None: + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs @@ -395,30 +419,26 @@ def filter_string(text): 
@staticmethod def _get_splitter( - processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + processing_rule_mode: str, + max_tokens: int, + chunk_overlap: int, + separator: str, + embedding_model_instance: Optional[ModelInstance], ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ - if processing_rule.mode == "custom": + if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule - rules = json.loads(processing_rule.rules) - segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - separator = segmentation["separator"] if separator: separator = separator.replace("\\n", "\n") - if segmentation.get("chunk_overlap"): - chunk_overlap = segmentation["chunk_overlap"] - else: - chunk_overlap = 0 - character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( - chunk_size=segmentation["max_tokens"], + chunk_size=max_tokens, chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], @@ -426,148 +446,15 @@ def _get_splitter( ) else: # Automatic segmentation + automatic_rules: dict[str, Any] = dict(DatasetProcessRule.AUTOMATIC_RULES["segmentation"]) character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], + chunk_size=automatic_rules["max_tokens"], + chunk_overlap=automatic_rules["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance, ) - return character_splitter - - def _step_split( - self, - text_docs: list[Document], - splitter: TextSplitter, - dataset: Dataset, - dataset_document: DatasetDocument, - processing_rule: DatasetProcessRule, - ) -> list[Document]: - """ - Split the text documents into documents and save them to the document segment. 
- """ - documents = self._split_to_documents( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule, - tenant_id=dataset.tenant_id, - document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language, - ) - - # save node to document segment - doc_store = DatasetDocumentStore( - dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id - ) - - # add document segments - doc_store.add_documents(documents) - - # update document status to indexing - cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - self._update_document_index_status( - document_id=dataset_document.id, - after_indexing_status="indexing", - extra_update_params={ - DatasetDocument.cleaning_completed_at: cur_time, - DatasetDocument.splitting_completed_at: cur_time, - }, - ) - - # update segment status to indexing - self._update_segments_by_document( - dataset_document_id=dataset_document.id, - update_params={ - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - }, - ) - - return documents - - def _split_to_documents( - self, - text_docs: list[Document], - splitter: TextSplitter, - processing_rule: DatasetProcessRule, - tenant_id: str, - document_form: str, - document_language: str, - ) -> list[Document]: - """ - Split the text documents into nodes. - """ - all_documents = [] - all_qa_documents = [] - for text_doc in text_docs: - # document clean - document_text = self._document_clean(text_doc.page_content, processing_rule) - text_doc.page_content = document_text - - # parse document to nodes - documents = splitter.split_documents([text_doc]) - split_documents = [] - for document_node in documents: - if document_node.page_content.strip(): - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash - # delete Splitter character - page_content = document_node.page_content - document_node.page_content = remove_leading_symbols(page_content) - - if document_node.page_content: - split_documents.append(document_node) - all_documents.extend(split_documents) - # processing qa document - if document_form == "qa_model": - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self.format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), - "tenant_id": tenant_id, - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": document_language, - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() - return all_qa_documents - return all_documents - - def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): - format_documents = [] - if document_node.page_content is None or not document_node.page_content.strip(): - return - with flask_app.app_context(): - try: - # qa model document - response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) - document_qa_list = self.format_split_text(response) - qa_documents = [] - for result in document_qa_list: - qa_document = Document( - page_content=result["question"], metadata=document_node.metadata.model_copy() - ) - doc_id = str(uuid.uuid4()) - hash = 
helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash - qa_documents.append(qa_document) - format_documents.extend(qa_documents) - except Exception as e: - logging.exception("Failed to format qa document") - - all_qa_documents.extend(format_documents) + return character_splitter # type: ignore def _split_to_documents_for_estimate( self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule @@ -575,7 +462,7 @@ def _split_to_documents_for_estimate( """ Split the text documents into nodes. """ - all_documents = [] + all_documents: list[Document] = [] for text_doc in text_docs: # document clean document_text = self._document_clean(text_doc.page_content, processing_rule) @@ -588,11 +475,11 @@ def _split_to_documents_for_estimate( for document in documents: if document.page_content is None or not document.page_content.strip(): continue - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document.page_content) - - document.metadata["doc_id"] = doc_id - document.metadata["doc_hash"] = hash + if document.metadata is not None: + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -614,11 +501,11 @@ def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: return document_text @staticmethod - def format_split_text(text): + def format_split_text(text: str) -> list[QAPreviewDetail]: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] + return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a] def _load( self, @@ -643,23 +530,34 @@ def _load( # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 - chunk_size = 10 + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + # create keyword index + create_keyword_thread = threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore + ) + create_keyword_thread.start() - # create keyword index - create_keyword_thread = threading.Thread( - target=self._process_keyword_index, - args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), - ) - create_keyword_thread.start() + max_workers = 10 if dataset.indexing_technique == "high_quality": - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] - for i in range(0, len(documents), chunk_size): - chunk_documents = documents[i : i + chunk_size] + + # Distribute documents into multiple groups based on the hash values of page_content + # This is done to prevent multiple threads from processing the same document, + # Thereby avoiding potential database insertion deadlocks + document_groups: list[list[Document]] = [[] for _ in range(max_workers)] + for document in documents: + hash = helper.generate_text_hash(document.page_content) + group_index = int(hash, 16) % max_workers + document_groups[group_index].append(document) + for chunk_documents in document_groups: + if len(chunk_documents) == 0: + continue 
futures.append( executor.submit( self._process_chunk, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore index_processor, chunk_documents, dataset, @@ -670,8 +568,8 @@ def _load( for future in futures: tokens += future.result() - - create_keyword_thread.join() + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + create_keyword_thread.join() indexing_end_at = time.perf_counter() # update document status to completed @@ -783,28 +681,6 @@ def _update_segments_by_document(dataset_document_id: str, update_params: dict) DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - @staticmethod - def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): - """ - Batch add segments index processing - """ - documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - documents.append(document) - # save vector index - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) - def _transform( self, index_processor: BaseIndexProcessor, @@ -846,7 +722,7 @@ def _load_segments(self, dataset, dataset_document, documents): ) # add document segments - doc_store.add_documents(documents) + doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) # update document status to indexing cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 3a92c8d9d..9fe3f68f2 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -1,7 +1,7 @@ import json import logging import re -from typing import Optional +from typing import Optional, cast from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser @@ -13,6 +13,7 @@ WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError @@ -44,10 +45,13 @@ def generate_conversation_name( prompts = [UserPromptMessage(content=prompt)] with measure_time() as timer: - response = model_instance.invoke_llm( - prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + ), ) - answer = response.message.content + answer = cast(str, response.message.content) cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) if cleaned_answer is None: return "" @@ -94,11 +98,16 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st prompt_messages = [UserPromptMessage(content=prompt)] try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, 
model_parameters={"max_tokens": 256, "temperature": 0}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"max_tokens": 256, "temperature": 0}, + stream=False, + ), ) - questions = output_parser.parse(response.message.content) + questions = output_parser.parse(cast(str, response.message.content)) except InvokeError: questions = [] except Exception as e: @@ -138,11 +147,14 @@ def generate_rule_config( ) try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["prompt"] = response.message.content + rule_config["prompt"] = cast(str, response.message.content) except InvokeError as e: error = str(e) @@ -178,15 +190,18 @@ def generate_rule_config( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) try: try: # the first step to generate the task prompt - prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + prompt_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) except InvokeError as e: error = str(e) @@ -195,8 +210,10 @@ def generate_rule_config( return rule_config - rule_config["prompt"] = prompt_content.message.content + rule_config["prompt"] = cast(str, prompt_content.message.content) + if not isinstance(prompt_content.message.content, str): + raise NotImplementedError("prompt content is not a string") parameter_generate_prompt = parameter_template.format( inputs={ "INPUT_TEXT": prompt_content.message.content, @@ -216,19 +233,25 @@ def generate_rule_config( statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: - parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + parameter_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) + rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) except InvokeError as e: error = str(e) error_step = "generate variables" try: - statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + statement_content = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False + ), ) - rule_config["opening_statement"] = statement_content.message.content + rule_config["opening_statement"] = cast(str, statement_content.message.content) except InvokeError as e: error = str(e) error_step = "generate conversation opener" @@ -267,19 +290,22 @@ def generate_code( model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, - provider=model_config.get("provider") if model_config else None, - 
model=model_config.get("name") if model_config else None, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) prompt_messages = [UserPromptMessage(content=prompt)] model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} try: - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ), ) - generated_code = response.message.content + generated_code = cast(str, response.message.content) return {"code": generated_code, "language": code_language, "error": ""} except InvokeError as e: @@ -303,9 +329,14 @@ def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] - response = model_instance.invoke_llm( - prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters={"temperature": 0.01, "max_tokens": 2000}, + stream=False, + ), ) - answer = response.message.content + answer = cast(str, response.message.content) return answer.strip() diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 81d08dc88..003a0c85b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -68,7 +68,7 @@ def get_history_prompt_messages( messages = list(reversed(thread_messages)) - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 198668855..d1e71148c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -124,17 +124,20 @@ def invoke_llm( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, + return cast( + Union[LLMResult, Generator], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ), ) def get_llm_num_tokens( @@ -151,12 +154,15 @@ def get_llm_num_tokens( raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - prompt_messages=prompt_messages, - tools=tools, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + prompt_messages=prompt_messages, + tools=tools, + ), ) def invoke_text_embedding( @@ -174,13 +180,16 @@ def invoke_text_embedding( raise Exception("Model type instance is 
not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - texts=texts, - user=user, - input_type=input_type, + return cast( + TextEmbeddingResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + texts=texts, + user=user, + input_type=input_type, + ), ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -194,11 +203,14 @@ def get_text_embedding_num_tokens(self, texts: list[str]) -> int: raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, - model=self.model, - credentials=self.credentials, - texts=texts, + return cast( + int, + self._round_robin_invoke( + function=self.model_type_instance.get_num_tokens, + model=self.model, + credentials=self.credentials, + texts=texts, + ), ) def invoke_rerank( @@ -223,15 +235,18 @@ def invoke_rerank( raise Exception("Model type instance is not RerankModel") self.model_type_instance = cast(RerankModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user, + return cast( + RerankResult, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + user=user, + ), ) def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: @@ -246,12 +261,15 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: raise Exception("Model type instance is not ModerationModel") self.model_type_instance = cast(ModerationModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - text=text, - user=user, + return cast( + bool, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + text=text, + user=user, + ), ) def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: @@ -266,12 +284,15 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str raise Exception("Model type instance is not Speech2TextModel") self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - credentials=self.credentials, - file=file, - user=user, + return cast( + str, + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + file=file, + user=user, + ), ) def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: @@ -288,17 +309,20 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option raise Exception("Model type instance is not TTSModel") self.model_type_instance = cast(TTSModel, self.model_type_instance) - return self._round_robin_invoke( - function=self.model_type_instance.invoke, - model=self.model, - 
credentials=self.credentials, - content_text=content_text, - user=user, - tenant_id=tenant_id, - voice=voice, + return cast( + Iterable[bytes], + self._round_robin_invoke( + function=self.model_type_instance.invoke, + model=self.model, + credentials=self.credentials, + content_text=content_text, + user=user, + tenant_id=tenant_id, + voice=voice, + ), ) - def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): + def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: """ Round-robin invoke :param function: function to invoke @@ -309,7 +333,7 @@ def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): if not self.load_balancing_manager: return function(*args, **kwargs) - last_exception = None + last_exception: Union[InvokeRateLimitError, InvokeAuthorizationError, InvokeConnectionError, None] = None while True: lb_config = self.load_balancing_manager.fetch_next() if not lb_config: @@ -463,7 +487,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: if real_index > max_index: real_index = 0 - config = self._load_balancing_configs[real_index] + config: ModelLoadBalancingConfiguration = self._load_balancing_configs[real_index] if self.in_cooldown(config): cooldown_load_balancing_configs.append(config) @@ -507,8 +531,7 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - res = redis_client.exists(cooldown_cache_key) - res = cast(bool, res) + res: bool = redis_client.exists(cooldown_cache_key) return res @staticmethod diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 3b6b82524..1f21a2d37 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,8 @@ import json import logging import sys -from typing import Optional +from collections.abc import Sequence +from typing import Optional, cast from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -20,7 +21,7 @@ def on_before_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -76,7 +77,7 @@ def on_new_chunk( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -94,7 +95,7 @@ def on_new_chunk( :param stream: is stream response :param user: unique user id """ - sys.stdout.write(chunk.delta.message.content) + sys.stdout.write(cast(str, chunk.delta.message.content)) sys.stdout.flush() def on_after_invoke( @@ -106,7 +107,7 @@ def on_after_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -147,7 +148,7 @@ def on_invoke_error( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, 
user: Optional[str] = None, ) -> None: diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 0efe46f87..2f682ceef 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -3,7 +3,7 @@ from enum import Enum, StrEnum from typing import Optional -from pydantic import BaseModel, Field, computed_field, field_validator +from pydantic import BaseModel, Field, field_validator class PromptMessageRole(Enum): @@ -89,7 +89,6 @@ class MultiModalPromptMessageContent(PromptMessageContent): url: str = Field(default="", description="the url of multi-modal file") mime_type: str = Field(default=..., description="the mime type of multi-modal file") - @computed_field(return_type=str) @property def data(self): return self.url or f"data:{self.mime_type};base64,{self.base64_data}" diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 79a1d28eb..e2b956033 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,7 +1,6 @@ import decimal import os from abc import ABC, abstractmethod -from collections.abc import Mapping from typing import Optional from pydantic import ConfigDict @@ -36,7 +35,7 @@ class AIModel(ABC): model_config = ConfigDict(protected_namespaces=()) @abstractmethod - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials @@ -214,7 +213,7 @@ def predefined_models(self) -> list[AIModelEntity]: return model_schemas - def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]: + def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: """ Get model schema by model name and credentials @@ -236,9 +235,7 @@ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> return None - def get_customizable_model_schema_from_credentials( - self, model: str, credentials: Mapping - ) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -248,7 +245,7 @@ def get_customizable_model_schema_from_credentials( """ return self._get_customizable_model_schema(model, credentials) - def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema and fill in the template """ @@ -301,7 +298,7 @@ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Op return schema - def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 8faeffa87..402a30376 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import 
re import time from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence +from collections.abc import Generator, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -48,7 +48,7 @@ def invoke( prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[Sequence[str]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -291,12 +291,12 @@ def _code_block_mode_stream_processor( content = piece.delta.message.content piece.delta.message.content = "" yield piece - piece = content + content_piece = content else: yield piece continue new_piece: str = "" - for char in piece: + for char in content_piece: char = str(char) if state == "normal": if char == "`": @@ -350,7 +350,7 @@ def _code_block_mode_stream_processor_with_backtick( piece.delta.message.content = "" # Yield a piece with cleared content before processing it to maintain the generator structure yield piece - piece = content + content_piece = content else: # Yield pieces without content directly yield piece @@ -360,7 +360,7 @@ def _code_block_mode_stream_processor_with_backtick( continue new_piece: str = "" - for char in piece: + for char in content_piece: if state == "search_start": if char == "`": backtick_count += 1 @@ -535,7 +535,7 @@ def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRu return [] - def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode: + def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode: """ Get model mode diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 4374093de..36e3e7bd5 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -104,9 +104,10 @@ def get_model_instance(self, model_type: ModelType) -> AIModel: mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) + # FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later model_class = next( filter( - lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, + lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore get_subclasses_from_module(mod, AIModel), ), None, diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 2d38fba95..331351290 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -89,7 +89,8 @@ def _get_context_size(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] + return content_size return 1000 @@ -104,6 +105,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in 
model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 5fe6dda6a..2f6f4fbbe 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -1,10 +1,10 @@ -from os.path import abspath, dirname, join +import logging from threading import Lock from typing import Any -from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer +logger = logging.getLogger(__name__) -_tokenizer = None +_tokenizer: Any = None _lock = Lock() @@ -15,11 +15,16 @@ def _get_num_tokens_by_gpt2(text: str) -> int: use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text, verbose=False) + tokens = _tokenizer.encode(text) return len(tokens) @staticmethod def get_num_tokens(text: str) -> int: + # Because this process needs more CPU resources, we revert to the direct call until we find a better way to handle it. + # + # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) + # result = future.result() + # return cast(int, result) return GPT2Tokenizer._get_num_tokens_by_gpt2(text) @staticmethod @@ -27,8 +32,20 @@ def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), "gpt2") - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) + # Try to use tiktoken to get the tokenizer because it is faster + # + try: + import tiktoken + + _tokenizer = tiktoken.get_encoding("gpt2") + except Exception: + from os.path import abspath, dirname, join + + from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore + + base_path = abspath(__file__) + gpt2_tokenizer_path = join(dirname(base_path), "gpt2") + _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) + logger.info("Falling back from tiktoken to Transformers' GPT-2 tokenizer") return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index b394ea4e9..6ce316b13 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -127,7 +127,8 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties: raise ValueError("this model does not support audio type") - return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + return audio_type def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ @@ -138,8 +139,9 @@ def _get_model_word_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties: raise ValueError("this model does not support word limit") + word_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] - return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] + return word_limit def _get_model_workers_limit(self, model: str, credentials: dict) -> 
int: """ @@ -150,8 +152,9 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int: if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties: raise ValueError("this model does not support max workers") + workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] - return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] + return workers_limit @staticmethod def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index c5d7a83a4..03818741f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -113,7 +113,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - if "o1" in model: + if model.startswith("o1"): client.chat.completions.create( messages=[{"role": "user", "content": "ping"}], model=model, @@ -311,7 +311,10 @@ def _chat_generate( prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) block_as_stream = False - if "o1" in model: + if model.startswith("o1"): + if "max_tokens" in model_parameters: + model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] + del model_parameters["max_tokens"] if stream: block_as_stream = True stream = False @@ -404,7 +407,7 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp ] ) - if "o1" in model: + if model.startswith("o1"): system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) if system_message_count > 0: new_prompt_messages = [] diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index a2b14cf3d..4aa09e61f 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -64,10 +64,12 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return ai_model_entity.entity @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: ai_model_entity_copy = copy.deepcopy(ai_model_entity) diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 173b9d250..6d50ba916 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -114,6 +114,8 @@ def _process_sentence(self, sentence: str, model: str, voice, credentials: dict) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) + if not ai_model_entity: + return None return 
ai_model_entity.entity @staticmethod diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 75ed7ad62..29bd673d5 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -6,9 +6,9 @@ from typing import Optional, Union, cast # 3rd import -import boto3 -from botocore.config import Config -from botocore.exceptions import ( +import boto3 # type: ignore +from botocore.config import Config # type: ignore +from botocore.exceptions import ( # type: ignore ClientError, EndpointConnectionError, NoRegionError, diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index aba8fedbc..3a0a241f7 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -44,7 +44,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) @@ -62,7 +62,7 @@ def _invoke( # format document rerank_document = RerankDocument( index=result.index, - text=result.document.text, + text=result.document.text if result.document else "", score=result.relevance_score, ) diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py index 378ced3a4..38d0a9dfb 100644 --- a/api/core/model_runtime/model_providers/fireworks/_common.py +++ b/api/core/model_runtime/model_providers/fireworks/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from core.model_runtime.errors.invoke import ( @@ -13,7 +11,7 @@ class _CommonFireworks: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index c745a7e97..4c0362838 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -1,5 +1,4 @@ import time -from collections.abc import Mapping from typing import Optional, Union import numpy as np @@ -93,7 +92,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int """ return sum(self._get_num_tokens_by_gpt2(text) for text in texts) - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py index 0750f3b75..ad6600faf 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/_common.py +++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py 
b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py index 832ba9274..737d3d5c9 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import httpx @@ -51,7 +51,7 @@ def _invoke( base_url = base_url.removesuffix("/") try: - body = {"model": model, "query": query, "documents": docs} + body: dict[str, Any] = {"model": model, "query": query, "documents": docs} if top_n is not None: body["top_n"] = top_n response = httpx.post( diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py index b833c5652..97d5ecc24 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py @@ -25,7 +25,4 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @staticmethod def _add_custom_parameters(credentials: dict, model: str) -> None: - if model is None: - model = "bge-m3" - - credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/" + credentials["endpoint_url"] = "https://ai.gitee.com/v1" diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py index 36dcea405..dc91257da 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import requests @@ -13,9 +13,10 @@ class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel): Model class for OpenAI text2speech model. 
""" + # FIXME this Any return will be better type def _invoke( self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None - ) -> any: + ) -> Any: """ _invoke text2speech model @@ -47,7 +48,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + # FIXME this Any return will be better type + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model :param model: model name diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 7d19ccbb7..98273f60a 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -7,7 +7,7 @@ from typing import Optional, Union import google.ai.generativelanguage as glm -import google.generativeai as genai +import google.generativeai as genai # type: ignore import requests from google.api_core import exceptions from google.generativeai.types import ContentType, File, GenerateContentResponse diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml index ee4a3c159..6a1673fa2 100644 --- a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml +++ b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml @@ -9,6 +9,8 @@ supported_model_types: - llm - text-embedding - rerank + - speech2text + - tts configurate_methods: - customizable-model model_credential_schema: @@ -118,3 +120,19 @@ model_credential_schema: label: en_US: Not Support zh_Hans: 不支持 + - variable: voices + show_on: + - variable: __model_type + value: tts + label: + en_US: Available Voices (comma-separated) + zh_Hans: 可用声音(用英文逗号分隔) + type: text-input + required: false + default: "Chinese Female" + placeholder: + en_US: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female" + zh_Hans: "Chinese Female, Chinese Male, Japanese Male, Cantonese Female, English Female, English Male, Korean Female" + help: + en_US: "List voice names separated by commas. First voice will be used as default." 
+ zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。" diff --git a/api/core/model_runtime/model_providers/gpustack/llm/llm.py b/api/core/model_runtime/model_providers/gpustack/llm/llm.py index ce6780b6a..429c76183 100644 --- a/api/core/model_runtime/model_providers/gpustack/llm/llm.py +++ b/api/core/model_runtime/model_providers/gpustack/llm/llm.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from yarl import URL - from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import ( PromptMessage, @@ -24,9 +22,10 @@ def _invoke( stream: bool = True, user: str | None = None, ) -> LLMResult | Generator: + compatible_credentials = self._get_compatible_credentials(credentials) return super()._invoke( model, - credentials, + compatible_credentials, prompt_messages, model_parameters, tools, @@ -36,10 +35,15 @@ def _invoke( ) def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) + compatible_credentials = self._get_compatible_credentials(credentials) + super().validate_credentials(model, compatible_credentials) + + def _get_compatible_credentials(self, credentials: dict) -> dict: + credentials = credentials.copy() + base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai") + credentials["endpoint_url"] = f"{base_url}/v1-openai" + return credentials @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") credentials["mode"] = "chat" diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/model_runtime/model_providers/gpustack/speech2text/__init__.py similarity index 100% rename from api/core/app/task_pipeline/workflow_cycle_state_manager.py rename to api/core/model_runtime/model_providers/gpustack/speech2text/__init__.py diff --git a/api/core/model_runtime/model_providers/gpustack/speech2text/speech2text.py b/api/core/model_runtime/model_providers/gpustack/speech2text/speech2text.py new file mode 100644 index 000000000..e8ee90db6 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/speech2text/speech2text.py @@ -0,0 +1,43 @@ +from typing import IO, Optional + +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel + + +class GPUStackSpeech2TextModel(OAICompatSpeech2TextModel): + """ + Model class for GPUStack Speech to text model. 
+ """ + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke speech2text model + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + compatible_credentials = self._get_compatible_credentials(credentials) + return super()._invoke(model, compatible_credentials, file) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + """ + compatible_credentials = self._get_compatible_credentials(credentials) + super().validate_credentials(model, compatible_credentials) + + def _get_compatible_credentials(self, credentials: dict) -> dict: + """ + Get compatible credentials + + :param credentials: model credentials + :return: compatible credentials + """ + compatible_credentials = credentials.copy() + base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai") + compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai" + return compatible_credentials diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py index eb324491a..35b499e51 100644 --- a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py @@ -1,7 +1,5 @@ from typing import Optional -from yarl import URL - from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.text_embedding_entities import ( TextEmbeddingResult, @@ -24,12 +22,15 @@ def _invoke( user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> TextEmbeddingResult: - return super()._invoke(model, credentials, texts, user, input_type) + compatible_credentials = self._get_compatible_credentials(credentials) + return super()._invoke(model, compatible_credentials, texts, user, input_type) def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials) - super().validate_credentials(model, credentials) + compatible_credentials = self._get_compatible_credentials(credentials) + super().validate_credentials(model, compatible_credentials) - @staticmethod - def _add_custom_parameters(credentials: dict) -> None: - credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") + def _get_compatible_credentials(self, credentials: dict) -> dict: + credentials = credentials.copy() + base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai") + credentials["endpoint_url"] = f"{base_url}/v1-openai" + return credentials diff --git a/api/core/model_runtime/model_providers/gpustack/tts/__init__.py b/api/core/model_runtime/model_providers/gpustack/tts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/model_runtime/model_providers/gpustack/tts/tts.py b/api/core/model_runtime/model_providers/gpustack/tts/tts.py new file mode 100644 index 000000000..f144ddff4 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/tts/tts.py @@ -0,0 +1,57 @@ +from typing import Any, Optional + +from core.model_runtime.model_providers.openai_api_compatible.tts.tts import OAICompatText2SpeechModel + + +class GPUStackText2SpeechModel(OAICompatText2SpeechModel): + """ + Model 
class for GPUStack Text to Speech model. + """ + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> Any: + """ + Invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param user: unique user id + :return: text translated to audio file + """ + compatible_credentials = self._get_compatible_credentials(credentials) + return super()._invoke( + model=model, + tenant_id=tenant_id, + credentials=compatible_credentials, + content_text=content_text, + voice=voice, + user=user, + ) + + def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :param user: unique user id + """ + compatible_credentials = self._get_compatible_credentials(credentials) + super().validate_credentials(model, compatible_credentials) + + def _get_compatible_credentials(self, credentials: dict) -> dict: + """ + Get compatible credentials + + :param credentials: model credentials + :return: compatible credentials + """ + compatible_credentials = credentials.copy() + base_url = credentials["endpoint_url"].rstrip("/").removesuffix("/v1-openai") + compatible_credentials["endpoint_url"] = f"{base_url}/v1-openai" + + return compatible_credentials diff --git a/api/core/model_runtime/model_providers/groq/llm/gemma-7b-it.yaml b/api/core/model_runtime/model_providers/groq/llm/gemma-7b-it.yaml index 02f84e95f..157baaf31 100644 --- a/api/core/model_runtime/model_providers/groq/llm/gemma-7b-it.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/gemma-7b-it.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/gemma2-9b-it.yaml b/api/core/model_runtime/model_providers/groq/llm/gemma2-9b-it.yaml index dad496f66..d0294ac6a 100644 --- a/api/core/model_runtime/model_providers/groq/llm/gemma2-9b-it.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/gemma2-9b-it.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-405b-reasoning.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-405b-reasoning.yaml index 217785cea..3cbce0c05 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-405b-reasoning.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-405b-reasoning.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git 
a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml index 01323a1b8..07a0187e4 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-70b-versatile.yaml @@ -6,6 +6,7 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 131072 @@ -19,6 +20,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-8b-instant.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-8b-instant.yaml index a82e64532..04eae49b9 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.1-8b-instant.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.1-8b-instant.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 131072 @@ -18,6 +19,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml index 3f30d81ae..e6eadeb07 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml @@ -19,6 +19,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml index 563221879..241a7bed1 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml @@ -19,6 +19,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-1b-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-1b-preview.yaml index a44e4ff50..a6087d344 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-1b-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-1b-preview.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + 
label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-3b-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-3b-preview.yaml index f2fdd0a05..93a8127ec 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-3b-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-3b-preview.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml index 0391a7c89..f9361bff6 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml @@ -19,6 +19,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml index e7b93101e..145b45792 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml @@ -19,6 +19,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.1' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-specdec.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-specdec.yaml index bda9ec530..916dfee39 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-specdec.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-specdec.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 131072 @@ -18,6 +19,18 @@ parameter_rules: default: 1024 min: 1 max: 32768 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: "0.05" output: "0.1" diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-versatile.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-versatile.yaml index eb609f4db..a5de4e752 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-versatile.yaml +++ 
b/api/core/model_runtime/model_providers/groq/llm/llama-3.3-70b-versatile.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 131072 @@ -18,6 +19,18 @@ parameter_rules: default: 1024 min: 1 max: 32768 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: "0.05" output: "0.1" diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-guard-3-8b.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-guard-3-8b.yaml index 03779ccc6..bd8e5d2a3 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama-guard-3-8b.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama-guard-3-8b.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.20' output: '0.20' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama2-70b-4096.yaml b/api/core/model_runtime/model_providers/groq/llm/llama2-70b-4096.yaml index 384912b0d..6e7ffd7a9 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama2-70b-4096.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama2-70b-4096.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.7' output: '0.8' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml b/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml index 91d0e3076..2c25bb743 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-70b-8192.yaml @@ -18,6 +18,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.59' output: '0.79' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml b/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml index b6154f761..d8a708eaf 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-8b-8192.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 8192 @@ -18,6 +19,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.08' diff --git a/api/core/model_runtime/model_providers/groq/llm/llama3-groq-70b-8192-tool-use-preview.yaml 
b/api/core/model_runtime/model_providers/groq/llm/llama3-groq-70b-8192-tool-use-preview.yaml index 32ccbf1f4..61c83c980 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llama3-groq-70b-8192-tool-use-preview.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/llama3-groq-70b-8192-tool-use-preview.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 8192 @@ -18,6 +19,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.08' diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index 3c4020b6e..d8a09265e 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -1,4 +1,4 @@ -from huggingface_hub.utils import BadRequestError, HfHubHTTPError +from huggingface_hub.utils import BadRequestError, HfHubHTTPError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 9d29237fd..cdb4103cd 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from huggingface_hub import InferenceClient -from huggingface_hub.hf_api import HfApi -from huggingface_hub.utils import BadRequestError +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.hf_api import HfApi # type: ignore +from huggingface_hub.utils import BadRequestError # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 8278d1e64..4ca537940 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import numpy as np import requests -from huggingface_hub import HfApi, InferenceClient +from huggingface_hub import HfApi, InferenceClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 284429b74..a8a13313d 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -157,7 +157,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None: headers["Authorization"] = f"Bearer {api_key}" extra_args = 
TeiHelper.get_tei_extra_parameter(server_url, model, headers) - print(extra_args) if extra_args.model_type != "embedding": raise CredentialsValidateFailedError("Current model is not a embedding model") diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 2014de851..9b7686ded 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -3,11 +3,11 @@ from collections.abc import Generator from typing import cast -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -54,6 +54,7 @@ def _invoke( "Model": model, "Messages": messages_dict, "Stream": stream, + "Stop": stop, **custom_parameters, } # add Tools and ToolChoice @@ -305,7 +306,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, ToolPromptMessage): message_text = f"{tool_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = content if isinstance(content, str) else "" else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index b6d857cb3..856cda90d 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -3,11 +3,11 @@ import time from typing import Optional -from tencentcloud.common import credential -from tencentcloud.common.exception import TencentCloudSDKException -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from tencentcloud.common import credential # type: ignore +from tencentcloud.common.exception import TencentCloudSDKException # type: ignore +from tencentcloud.common.profile.client_profile import ClientProfile # type: ignore +from tencentcloud.common.profile.http_profile import HttpProfile # type: ignore +from tencentcloud.hunyuan.v20230901 import hunyuan_client, models # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index d80cbfa83..1fc0f8c02 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ 
b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -1,11 +1,11 @@ from os.path import abspath, dirname, join from threading import Lock -from transformers import AutoTokenizer +from transformers import AutoTokenizer # type: ignore class JinaTokenizer: - _tokenizer = None + _tokenizer: AutoTokenizer | None = None _lock = Lock() @classmethod diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 88cc0e8e0..357631b2d 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -40,7 +40,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ -117,19 +117,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -139,10 +139,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] @@ -162,5 +162,5 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator continue for choice in choices: - message = choice["delta"] - yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) + message_choice = choice["delta"] + yield MinimaxMessage(content=message_choice, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 8b8fdbb6b..284b61829 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ 
b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -41,7 +41,7 @@ def generate( url = f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] @@ -122,19 +122,19 @@ def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ handle chat generate response """ - response = response.json() - if "base_resp" in response and response["base_resp"]["status_code"] != 0: - code = response["base_resp"]["status_code"] - msg = response["base_resp"]["status_msg"] + response_data = response.json() + if "base_resp" in response_data and response_data["base_resp"]["status_code"] != 0: + code = response_data["base_resp"]["status_code"] + msg = response_data["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) + message = MinimaxMessage(content=response_data["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { "prompt_tokens": 0, - "completion_tokens": response["usage"]["total_tokens"], - "total_tokens": response["usage"]["total_tokens"], + "completion_tokens": response_data["usage"]["total_tokens"], + "total_tokens": response_data["usage"]["total_tokens"], } - message.stop_reason = response["choices"][0]["finish_reason"] + message.stop_reason = response_data["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: @@ -144,10 +144,10 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() - data = loads(line) + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() + data = loads(line_str) if "base_resp" in data and data["base_resp"]["status_code"] != 0: code = data["base_resp"]["status_code"] diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 88ebe5e2e..c248db374 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -11,9 +11,9 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: dict[str, int] | None = None stop_reason: str = "" - function_call: dict[str, Any] = None + function_call: dict[str, Any] | None = None def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 90d015942..cfee0b91e 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -252,7 +252,7 @@ def get_tool_call(tool_name: str): # ignore sse comments if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().removeprefix("data: ") + decoded_chunk = chunk.strip().removeprefix("data:").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) diff --git 
a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 56a707333..8a4c19d4d 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -2,8 +2,8 @@ from functools import wraps from typing import Optional -from nomic import embed -from nomic import login as nomic_login +from nomic import embed # type: ignore +from nomic import login as nomic_login # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py index 1e1fc5b3e..9f676573f 100644 --- a/api/core/model_runtime/model_providers/oci/llm/llm.py +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -5,8 +5,8 @@ from collections.abc import Generator from typing import Optional, Union -import oci -from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse +import oci # type: ignore +from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 50fa63768..5a428c9fe 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import numpy as np -import oci +import oci # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 83c4facc8..3543fe58b 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -61,6 +61,7 @@ def _invoke( headers = {"Content-Type": "application/json"} endpoint_url = credentials.get("base_url") + assert endpoint_url is not None, "Base URL is required for Ollama API" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 2181bb4f0..ac2b3e688 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonOpenAI: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml index a4681fe18..d6be36ad7 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml +++ 
b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '2.50' output: '10.00' diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 73cd7e3c3..86042de6f 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -739,6 +739,12 @@ def _handle_chat_generate_stream_response( delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None + # fix for issue #12215: the yi model has a special case for lightning + # FIXME: drop this case when the yi model is updated + if model.startswith("yi-"): + if isinstance(delta.finish_reason, str): + # doc: https://platform.lingyiwanwu.com/docs/api-reference + has_finish_reason = delta.finish_reason.startswith(("length", "stop", "content_filter")) if ( not has_finish_reason diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index 619044d80..227e4b0c1 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -93,7 +93,8 @@ def _get_max_characters_per_chunk(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK] + return max_characters_per_chunk return 2000 @@ -108,6 +109,7 @@ def _get_max_chunks(self, model: str, credentials: dict) -> int: model_schema = self.get_model_schema(model, credentials) if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + return max_chunks return 1 diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index aa6f38ce9..c546441af 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -1,5 +1,4 @@ import logging -from collections.abc import Mapping from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -9,7 +8,7 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: + def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials if validate failed, raise exception diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 8e07d56f4..734cf28b1 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -332,6 +332,23 @@ def _generate( if not endpoint_url.endswith("/"): endpoint_url += "/" + response_format = model_parameters.get("response_format") + if response_format: + if 
response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except: + raise ValueError(f"not correct json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} + else: + model_parameters["response_format"] = {"type": response_format} + elif "json_schema" in model_parameters: + del model_parameters["json_schema"] + data = {"model": model, "stream": stream, **model_parameters} completion_type = LLMMode.value_of(credentials["mode"]) @@ -462,7 +479,7 @@ def get_tool_call(tool_call_id: str): # ignore sse comments if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().removeprefix("data: ") + decoded_chunk = chunk.strip().removeprefix("data:").lstrip() if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" continue diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index a490537e5..74229a089 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -33,6 +33,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/transcriptions") diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 9da8f55d0..b4d6c6c6c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -55,6 +55,7 @@ def _invoke( headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py index 8239c625f..53e895b0e 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py @@ -44,6 +44,7 @@ def _invoke( # Construct endpoint URL endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "audio/speech") diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 2789a9250..e9509b544 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,7 +1,7 @@ from collections.abc import Generator 
from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema @@ -20,7 +20,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -165,17 +165,17 @@ def _handle_chat_stream_generate_response( if not line: continue - line: str = line.decode("utf-8") - if line.startswith("data: "): - line = line[6:].strip() + line_str: str = line.decode("utf-8") + if line_str.startswith("data: "): + line_str = line_str[6:].strip() - if line == "[DONE]": + if line_str == "[DONE]": return try: - data = loads(line) + data = loads(line_str) except Exception as e: - raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") + raise InternalServerError(f"Failed to convert response to json: {e} with text: {line_str}") output = data["outputs"] diff --git a/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml index e829048e5..caed5a901 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 7bbd31e87..40ea4dc01 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -53,14 +53,16 @@ def _invoke( api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} @@ -142,13 +144,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if api_key: headers["Authorization"] = f"Bearer {api_key}" + endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": endpoint_url = "https://cloud.perfxlab.cn/v1/" else: endpoint_url = credentials.get("endpoint_url") + assert endpoint_url is not None, "endpoint_url is required in credentials" if not endpoint_url.endswith("/"): endpoint_url += "/" + assert isinstance(endpoint_url, str) endpoint_url = urljoin(endpoint_url, "embeddings") payload = {"input": "ping", "model": model} diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 915f6e0ee..3e2cf2adb 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -1,4 +1,4 @@ -from replicate.exceptions import ModelError, ReplicateError +from 
replicate.exceptions import ModelError, ReplicateError # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 3641b35dc..1e7858100 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,9 +1,9 @@ from collections.abc import Generator from typing import Optional, Union -from replicate import Client as ReplicateClient -from replicate.exceptions import ReplicateError -from replicate.prediction import Prediction +from replicate import Client as ReplicateClient # type: ignore +from replicate.exceptions import ReplicateError # type: ignore +from replicate.prediction import Prediction # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 41759fe07..aaf825388 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -2,11 +2,11 @@ import time from typing import Optional -from replicate import Client as ReplicateClient +from replicate import Client as ReplicateClient # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -86,7 +86,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={"context_size": 4096, "max_chunks": 1}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096, ModelPropertyKey.MAX_CHUNKS: 1}, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 5ff00f008..b8c979b1f 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Iterator from typing import Any, Optional, Union, cast -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -83,7 +83,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): sagemaker_session: Any = None predictor: Any = None - sagemaker_endpoint: str = None + sagemaker_endpoint: str | None = None def _handle_chat_generate_response( self, @@ -209,8 +209,8 @@ def _invoke( :param user: unique user id :return: full response or 
stream response chunk generator result """ - from sagemaker import Predictor, serializers - from sagemaker.session import Session + from sagemaker import Predictor, serializers # type: ignore + from sagemaker.session import Session # type: ignore if not self.sagemaker_session: access_key = credentials.get("aws_access_key_id") diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index df797bae2..7daab6d86 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -3,7 +3,7 @@ import operator from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -114,6 +114,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke rerank model, model: {model}") + raise InvokeError(f"Failed to invoke rerank model, model: {model}, error: {str(e)}") def validate_credentials(self, model: str, credentials: dict) -> None: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 2d50e9c7b..a6aca1304 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -2,7 +2,7 @@ import logging from typing import IO, Any, Optional -import boto3 +import boto3 # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -67,6 +67,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional s3_prefix = "dify/speech2text/" sagemaker_endpoint = credentials.get("sagemaker_endpoint") bucket = credentials.get("audio_s3_cache_bucket") + assert bucket is not None, "audio_s3_cache_bucket is required in credentials" s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix) payload = {"audio_s3_presign_uri": s3_presign_url} diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index ef4ddcd6a..e7eccd997 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ import time from typing import Any, Optional -import boto3 +import boto3 # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -118,6 +118,7 @@ def _invoke( except Exception as e: logger.exception(f"Failed to invoke text embedding model, model: {model}, line: {line}") + raise InvokeError(str(e)) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py index 6a5946453..62231c518 100644 --- a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Optional -import boto3 
+import boto3 # type: ignore import requests from core.model_runtime.entities.common_entities import I18nObject diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml index 8703a97ed..8361be91b 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml @@ -1,4 +1,3 @@ -- Tencent/Hunyuan-A52B-Instruct - Qwen/QwQ-32B-Preview - Qwen/Qwen2.5-72B-Instruct - Qwen/Qwen2.5-32B-Instruct @@ -6,11 +5,11 @@ - Qwen/Qwen2.5-7B-Instruct - Qwen/Qwen2.5-Coder-32B-Instruct - Qwen/Qwen2.5-Coder-7B-Instruct -- Qwen/Qwen2.5-Math-72B-Instruct - Qwen/Qwen2-VL-72B-Instruct - Qwen/Qwen2-1.5B-Instruct +- Qwen/Qwen2.5-72B-Instruct-128K +- Vendor-A/Qwen/Qwen2.5-72B-Instruct - Pro/Qwen/Qwen2-VL-7B-Instruct -- OpenGVLab/InternVL2-Llama3-76B - OpenGVLab/InternVL2-26B - Pro/OpenGVLab/InternVL2-8B - deepseek-ai/DeepSeek-V2.5 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/hunyuan-a52b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/hunyuan-a52b-instruct.yaml index c5489554a..51d6c024f 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/hunyuan-a52b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/hunyuan-a52b-instruct.yaml @@ -82,3 +82,4 @@ pricing: output: '21' unit: '0.000001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internvl2-llama3-76b.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internvl2-llama3-76b.yaml index 65386d317..b5443df18 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/internvl2-llama3-76b.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internvl2-llama3-76b.yaml @@ -82,3 +82,4 @@ pricing: output: '21' unit: '0.000001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index e3a323a49..f61e8b82e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -43,7 +43,7 @@ def _add_custom_parameters(cls, credentials: dict) -> None: credentials["mode"] = "chat" credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" - def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qvq-72B-preview.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qvq-72B-preview.yaml new file mode 100644 index 000000000..dada6bb80 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qvq-72B-preview.yaml @@ -0,0 +1,54 @@ +model: Qwen/QVQ-72B-Preview +label: + en_US: Qwen/QVQ-72B-Preview +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 16384 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. 
If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '9.90' + output: '9.90' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qwq-32B-preview.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qwq-32B-preview.yaml index c949de4d7..e73c5d203 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qwq-32B-preview.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen-qwq-32B-preview.yaml @@ -15,9 +15,9 @@ parameter_rules: - name: max_tokens use_template: max_tokens type: int - default: 512 + default: 4096 min: 1 - max: 4096 + max: 8192 help: zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-72b-instruct.yaml index 1866a684b..f5180b41f 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-72b-instruct.yaml @@ -78,7 +78,7 @@ parameter_rules: - text - json_object pricing: - input: '21' - output: '21' + input: '4.13' + output: '4.13' unit: '0.000001' currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-7b-Instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-7b-Instruct.yaml index a50834468..0ffbaee38 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-7b-Instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-vl-7b-Instruct.yaml @@ -78,7 +78,7 @@ parameter_rules: - text - json_object pricing: - input: '21' - output: '21' + input: '0.35' + output: '0.35' unit: '0.000001' currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct-128k.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct-128k.yaml new file mode 100644 index 000000000..79f94da37 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct-128k.yaml @@ -0,0 +1,51 @@ +model: Qwen/Qwen2.5-72B-Instruct-128K +label: + en_US: Qwen/Qwen2.5-72B-Instruct-128K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. 
+ - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '4.13' + output: '4.13' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct-vendorA.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct-vendorA.yaml new file mode 100644 index 000000000..fdbe38ff2 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct-vendorA.yaml @@ -0,0 +1,51 @@ +model: Vendor-A/Qwen/Qwen2.5-72B-Instruct +label: + en_US: Vendor-A/Qwen/Qwen2.5-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '1.00' + output: '1.00' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml index c80cd45dd..de9d9d97b 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml @@ -15,7 +15,7 @@ parameter_rules: type: int default: 512 min: 1 - max: 8192 + max: 4096 help: zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. 
diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-math-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-math-72b-instruct.yaml index 1b6f2603f..40c9ab48c 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-math-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-math-72b-instruct.yaml @@ -82,3 +82,4 @@ pricing: output: '4.13' unit: '0.000001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.5.yaml b/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.5.yaml new file mode 100644 index 000000000..e2cc86d06 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.5.yaml @@ -0,0 +1,37 @@ +model: fishaudio/fish-speech-1.5 +model_type: tts +model_properties: + default_voice: 'fishaudio/fish-speech-1.5:alex' + voices: + - mode: "fishaudio/fish-speech-1.5:alex" + name: "Alex(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.5:benjamin" + name: "Benjamin(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.5:charles" + name: "Charles(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.5:david" + name: "David(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.5:anna" + name: "Anna(女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.5:bella" + name: "Bella(女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.5:claire" + name: "Claire(女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.5:diana" + name: "Diana(女声)" + language: [ "zh-Hans", "en-US" ] + audio_type: 'mp3' + max_workers: 5 + # stream: false +pricing: + input: '0.015' + output: '0' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 1181ba699..cb6f28b6c 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,6 +1,6 @@ import threading from collections.abc import Generator -from typing import Optional, Union +from typing import Optional, Union, cast from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -270,7 +270,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" elif isinstance(message, SystemPromptMessage): - message_text = content + message_text = cast(str, content) else: raise ValueError(f"Got unknown type {message}") diff --git a/api/core/model_runtime/model_providers/stepfun/llm/llm.py b/api/core/model_runtime/model_providers/stepfun/llm/llm.py index 686809ff2..b14d5ec2e 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/llm.py +++ b/api/core/model_runtime/model_providers/stepfun/llm/llm.py @@ -250,7 +250,7 @@ def get_tool_call(tool_name: str): # ignore sse comments if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().removeprefix("data: ") + decoded_chunk = chunk.strip().removeprefix("data:").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index 
b96d43979..03eac1942 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -12,6 +12,7 @@ AIModelEntity, DefaultParameterName, FetchFrom, + ModelFeature, ModelPropertyKey, ModelType, ParameterRule, @@ -67,7 +68,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode cred_with_endpoint = self._update_endpoint_url(credentials=credentials) REPETITION_PENALTY = "repetition_penalty" TOP_K = "top_k" - features = [] + features: list[ModelFeature] = [] entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index 8a50c7aa0..bb6831955 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -1,4 +1,4 @@ -from dashscope.common.error import ( +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 0c1f65188..61ebd45ed 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -7,9 +7,9 @@ from pathlib import Path from typing import Optional, Union, cast -from dashscope import Generation, MultiModalConversation, get_tokenizer -from dashscope.api_entities.dashscope_response import GenerationResponse -from dashscope.common.error import ( +from dashscope import Generation, MultiModalConversation, get_tokenizer # type: ignore +from dashscope.api_entities.dashscope_response import GenerationResponse # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py index a5ce9ead6..ed682cb0f 100644 --- a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py @@ -1,7 +1,7 @@ from typing import Optional -import dashscope -from dashscope.common.error import ( +import dashscope # type: ignore +from dashscope.common.error import ( # type: ignore AuthenticationError, InvalidParameter, RequestFailure, @@ -51,7 +51,7 @@ def _invoke( :return: rerank result """ if len(docs) == 0: - return RerankResult(model=model, docs=docs) + return RerankResult(model=model, docs=[]) # initialize client dashscope.api_key = credentials["dashscope_api_key"] @@ -64,7 +64,7 @@ def _invoke( return_documents=True, ) - rerank_documents = [] + rerank_documents: list[RerankDocument] = [] if not response.output: return RerankResult(model=model, docs=rerank_documents) for _, result in enumerate(response.output.results): diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 2ef7f3f57..8c53be413 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -import dashscope +import dashscope # type: ignore import numpy as np from core.entities.embedding_type import EmbeddingInputType diff --git 
a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index ca3b9fbc1..a654e2d76 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -2,10 +2,10 @@ from queue import Queue from typing import Any, Optional -import dashscope -from dashscope import SpeechSynthesizer -from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse -from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult +import dashscope # type: ignore +from dashscope import SpeechSynthesizer # type: ignore +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse # type: ignore +from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult # type: ignore from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 47ebaccd8..f6609bba7 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,5 +1,3 @@ -from collections.abc import Mapping - import openai from httpx import Timeout @@ -14,7 +12,7 @@ class _CommonUpstage: - def _to_credential_kwargs(self, credentials: Mapping) -> dict: + def _to_credential_kwargs(self, credentials: dict) -> dict: """ Transform credentials to kwargs for model instance diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index a18ee9062..2bf6796ca 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -6,7 +6,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 5b340e53b..87693eca7 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -1,11 +1,10 @@ import base64 import time -from collections.abc import Mapping from typing import Union import numpy as np from openai import OpenAI -from tokenizers import Tokenizer +from tokenizers import Tokenizer # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType @@ -132,7 +131,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate model credentials diff --git a/api/core/model_runtime/model_providers/vertex_ai/_common.py 
b/api/core/model_runtime/model_providers/vertex_ai/_common.py index 8f7c859e3..4e3df7574 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/_common.py +++ b/api/core/model_runtime/model_providers/vertex_ai/_common.py @@ -12,4 +12,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ - pass + raise NotImplementedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index c50e0f794..85be34f3f 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -6,7 +6,7 @@ from collections.abc import Generator from typing import TYPE_CHECKING, Optional, Union, cast -import google.auth.transport.requests +import google.auth.transport.requests # type: ignore import requests from anthropic import AnthropicVertex, Stream from anthropic.types import ( diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py index 034c066ab..782e4fd62 100644 --- a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -17,14 +17,12 @@ class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - features = [] - entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - features=features, + features=[], model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), }, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 1cffd902c..a8a015167 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,8 +1,8 @@ from collections.abc import Generator from typing import Optional, cast -from volcenginesdkarkruntime import Ark -from volcenginesdkarkruntime.types.chat import ( +from volcenginesdkarkruntime import Ark # type: ignore +from volcenginesdkarkruntime.types.chat import ( # type: ignore ChatCompletion, ChatCompletionAssistantMessageParam, ChatCompletionChunk, @@ -15,10 +15,10 @@ ChatCompletionToolParam, ChatCompletionUserMessageParam, ) -from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL -from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function -from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse -from volcenginesdkarkruntime.types.shared_params import FunctionDefinition +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL # type: ignore +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function # type: ignore +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse # type: ignore +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition # type: ignore from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py index 91dbe21a6..aa837b831 
100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -152,5 +152,6 @@ class ServiceNotOpenError(MaasError): def wrap_error(e: MaasError) -> Exception: if ErrorCodeMap.get(e.code): - return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) + # FIXME: mypy type error, try to fix it instead of using type: ignore + return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) # type: ignore return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 9e19b7ded..f0b2b101b 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -2,7 +2,7 @@ from collections.abc import Generator from typing import Optional -from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index cf3cf23cf..7c3736808 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMMode @@ -102,7 +104,7 @@ def get_model_config(credentials: dict) -> ModelConfig: def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: @@ -130,7 +132,7 @@ def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): - req_params = {} + req_params: dict[str, Any] = {} # predefined properties model_configs = get_model_config(credentials) if model_configs: diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index c77a49998..1247a11fe 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -122,6 +122,7 @@ class _CommonWenxin: "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base", + "ernie-lite-pro-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-pro-128k", } function_calling_supports = [ diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-pro-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-pro-128k.yaml new file mode 100644 index 000000000..4f5832c85 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-lite-pro-128k.yaml @@ -0,0 +1,42 @@ +model: ernie-lite-pro-128k 
+label: + en_US: Ernie-Lite-Pro-128K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.8 + - name: top_p + use_template: top_p + - name: min_output_tokens + label: + en_US: "Min Output Tokens" + zh_Hans: "最小输出Token数" + use_template: max_tokens + min: 2 + max: 2048 + help: + zh_Hans: 指定模型最小输出token数 + en_US: Specifies the lower limit on the length of generated results. + - name: max_output_tokens + label: + en_US: "Max Output Tokens" + zh_Hans: "最大输出Token数" + use_template: max_tokens + min: 2 + max: 2048 + default: 2048 + help: + zh_Hans: 指定模型最大输出token数 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 07b970f81..d28997956 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,7 +1,7 @@ from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Union +from typing import Any, Optional, Union from requests import Response, post @@ -22,7 +22,7 @@ class Role(Enum): role: str = Role.USER.value content: str - usage: dict[str, int] = None + usage: Optional[dict[str, int]] = None stop_reason: str = "" def to_dict(self) -> dict[str, Any]: @@ -135,6 +135,7 @@ def _build_function_calling_request_body( """ TODO: implement function calling """ + raise NotImplementedError("Function calling is not supported yet.") def _build_chat_request_body( self, diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index 19135deb2..816b3b98c 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -1,6 +1,5 @@ import time from abc import abstractmethod -from collections.abc import Mapping from json import dumps from typing import Any, Optional @@ -23,12 +22,12 @@ class TextEmbedding: @abstractmethod - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: raise NotImplementedError class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): - def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]: access_token = self._get_access_token() url = f"{self.api_bases[model]}?access_token={access_token}" body = self._build_embed_request_body(model, texts, user) @@ -50,7 +49,7 @@ def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> } return body - def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]: data = response.json() if "error_code" in data: code = 
data["error_code"] @@ -147,7 +146,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens - def validate_credentials(self, model: str, credentials: Mapping) -> None: + def validate_credentials(self, model: str, credentials: dict) -> None: api_key = credentials["api_key"] secret_key = credentials["secret_key"] try: diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 8d86d6937..7db120364 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -17,7 +17,7 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulGenerateModelHandle, diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index efaf11485..078ec0537 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,6 +1,6 @@ from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 3d7aefeb6..5f330ece1 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -1,6 +1,6 @@ from typing import IO, Optional -from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index e51e6a941..9054aabab 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject @@ -134,7 +134,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: handle = client.get_model(model_uid=model_uid) except 
RuntimeError as e: - raise InvokeAuthorizationError(e) + raise InvokeAuthorizationError(str(e)) if not isinstance(handle, RESTfulEmbeddingModelHandle): raise InvokeBadRequestError( diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index ad7b64efb..8aa39d4de 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -1,7 +1,7 @@ import concurrent.futures from typing import Any, Optional -from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle +from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle # type: ignore from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -74,11 +74,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") + api_key = credentials.get("api_key") + if api_key is None: + raise CredentialsValidateFailedError("api_key is required") extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials["server_url"], model_uid=credentials["model_uid"], - api_key=credentials.get("api_key"), + api_key=api_key, ) if "text-to-audio" not in extra_param.model_ability: diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index baa3ccbe8..b51423f4e 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,6 +1,6 @@ from threading import Lock from time import time -from typing import Optional +from typing import Any, Optional from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout @@ -39,13 +39,15 @@ def __init__( self.model_family = model_family -cache = {} +cache: dict[str, dict[str, Any]] = {} cache_lock = Lock() class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: @@ -66,7 +68,9 @@ def _clean_cache() -> None: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter( + server_url: str, model_uid: str, api_key: str | None + ) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index 0642e72ed..f5b61e207 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -136,7 +136,7 @@ def _add_custom_parameters(credentials: dict) -> None: parsed_url = urlparse(credentials["endpoint_url"]) credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" - def get_customizable_model_schema(self, model: str, credentials: 
dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml index 035d9881e..bf1dad81f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml @@ -47,6 +47,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.1' output: '0.1' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml index c3ee76141..a88779ed1 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml @@ -47,6 +47,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.001' output: '0.001' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml index 1926db7ac..868dab924 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml @@ -47,6 +47,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
+ - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.01' output: '0.01' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml index e54b5de4a..5ee67aade 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml @@ -47,6 +47,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0' output: '0' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml index 724fe4890..abd95a3a9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml @@ -47,6 +47,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0' output: '0' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index e1eb13df3..127f81c5a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -47,6 +47,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
+ - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.1' output: '0.1' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml index c0c4e04d3..ebcb7b4f7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml @@ -49,6 +49,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.001' output: '0.001' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml index c4f26f8ba..7dd113c06 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml @@ -47,6 +47,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.05' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml index 0d99f89cb..ceeeacced 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml @@ -45,6 +45,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
+ - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.05' output: '0.05' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_flash.yaml index c2047b2cd..9e3a7f203 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_flash.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_flash.yaml @@ -45,6 +45,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.00' output: '0.00' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml index 5cd0e16b0..3aa981def 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml @@ -46,6 +46,18 @@ parameter_rules: help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
+ - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.01' output: '0.01' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 59861507e..6199fbb45 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,9 +1,10 @@ +import json from collections.abc import Generator from typing import Optional, Union -from zhipuai import ZhipuAI -from zhipuai.types.chat.chat_completion import Completion -from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk +from zhipuai import ZhipuAI # type: ignore +from zhipuai.types.chat.chat_completion import Completion # type: ignore +from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk # type: ignore from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -188,6 +189,23 @@ def _generate( else: model_parameters["tools"] = [web_search_params] + response_format = model_parameters.get("response_format") + if response_format: + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except: + raise ValueError(f"not correct json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} + else: + model_parameters["response_format"] = {"type": response_format} + elif "json_schema" in model_parameters: + del model_parameters["json_schema"] + if model.startswith("glm-4v"): params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 2428284ba..a700304db 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index 029ec1a58..810a7c4c4 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Union, cast from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType @@ -38,7 +38,7 @@ def _validate_and_filter_credential_form_schemas( def _validate_credential_form_schema( self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Optional[str]: + ) -> Union[str, bool, None]: """ Validate credential form schema @@ -47,6 +47,7 @@ def _validate_credential_form_schema( :return: validated credential 
form schema value """ # If the variable does not exist in credentials + value: Union[str, bool, None] = None if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: @@ -61,7 +62,7 @@ def _validate_credential_form_schema( return None # Get the value corresponding to the variable from credentials - value = credentials[credential_form_schema.variable] + value = cast(str, credentials[credential_form_schema.variable]) # If max_length=0, no validation is performed if credential_form_schema.max_length: @@ -86,6 +87,6 @@ def _validate_credential_form_schema( if value.lower() not in {"true", "false"}: raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - value = True if value.lower() == "true" else False + value = value.lower() == "true" return value diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index ec1bad569..03e350627 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -129,7 +129,8 @@ def jsonable_encoder( sqlalchemy_safe=sqlalchemy_safe, ) if dataclasses.is_dataclass(obj): - obj_dict = dataclasses.asdict(obj) + # FIXME: mypy error, try to fix it instead of using type: ignore + obj_dict = dataclasses.asdict(obj) # type: ignore return jsonable_encoder( obj_dict, by_alias=by_alias, diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index 2067092d8..5e8a723ec 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -4,6 +4,7 @@ def dump_model(model: BaseModel) -> dict: if hasattr(pydantic, "model_dump"): - return pydantic.model_dump(model) + # FIXME mypy error, try to fix it instead of using type: ignore + return pydantic.model_dump(model) # type: ignore else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 094ad7863..c65a3885f 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor @@ -43,6 +45,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) @@ -57,6 +61,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: params = ModerationOutputParams(app_id=self.app_id, text=text) @@ -69,14 +75,18 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: - extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) + if self.config is None: + raise ValueError("The config is not set.") + extension = self._get_api_based_extension(self.tenant_id, 
self.config.get("api_based_extension_id", "")) + if not extension: + raise ValueError("API-based Extension not found. Please check it again.") requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key)) result = requestor.request(extension_point, params) return result @staticmethod - def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: extension = ( db.session.query(APIBasedExtension) .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 60898d554..d8c392d09 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -100,14 +100,14 @@ def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_re if not inputs_config.get("preset_response"): raise ValueError("inputs_config.preset_response is required") - if len(inputs_config.get("preset_response")) > 100: + if len(inputs_config.get("preset_response", 0)) > 100: raise ValueError("inputs_config.preset_response must be less than 100 characters") if outputs_config_enabled: if not outputs_config.get("preset_response"): raise ValueError("outputs_config.preset_response is required") - if len(outputs_config.get("preset_response")) > 100: + if len(outputs_config.get("preset_response", 0)) > 100: raise ValueError("outputs_config.preset_response must be less than 100 characters") diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index 96bf2ab54..0ad4438c1 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -22,7 +22,8 @@ def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: """ code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config) extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) - extension_class.validate_config(tenant_id, config) + # FIXME: mypy error, try to fix it instead of using type: ignore + extension_class.validate_config(tenant_id, config) # type: ignore def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: """ diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 46d3963bd..3ac33966c 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -1,5 +1,6 @@ import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from core.app.app_config.entities import AppConfig from core.moderation.base import ModerationAction, ModerationError @@ -17,11 +18,11 @@ def check( app_id: str, tenant_id: str, app_config: AppConfig, - inputs: dict, + inputs: Mapping[str, Any], query: str, message_id: str, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[bool, dict, str]: + ) -> tuple[bool, Mapping[str, Any], str]: """ Process sensitive_word_avoidance. 
:param app_id: app id @@ -33,6 +34,7 @@ def check( :param trace_manager: trace manager :return: """ + inputs = dict(inputs) if not app_config.sensitive_word_avoidance: return False, inputs, query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 00b3c56c0..9dd2665c3 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -21,7 +21,7 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") - if len(config.get("keywords")) > 10000: + if len(config.get("keywords", [])) > 10000: raise ValueError("keywords length must be less than 10000") keywords_row_len = config["keywords"].split("\n") @@ -31,6 +31,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -50,6 +52,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: # Filter out empty values diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 6465de23b..d64f17b38 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -20,6 +20,8 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["inputs_config"]["enabled"]: preset_response = self.config["inputs_config"]["preset_response"] @@ -35,6 +37,8 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" + if self.config is None: + raise ValueError("The config is not set.") if self.config["outputs_config"]["enabled"]: flagged = self._is_violated({"text": text}) diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 4635bd9c2..e595be126 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -70,7 +70,7 @@ def start_thread(self) -> threading.Thread: thread = threading.Thread( target=self.worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, }, ) diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index ef0f9c708..b484242b6 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -6,6 +6,7 @@ class TracingProviderEnum(Enum): LANGFUSE = "langfuse" LANGSMITH = "langsmith" + OPIK = "opik" class BaseTracingConfig(BaseModel): @@ -56,5 +57,36 @@ def set_value(cls, v, info: ValidationInfo): return v 
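Most of the hunks above share one pattern: values that mypy now treats as Optional are either guarded by an explicit None check that raises early (the repeated "The config is not set." guard added to the moderation classes) or narrowed with cast() where the surrounding validation already guarantees the runtime type. A minimal standalone sketch of that pattern, using illustrative names rather than the actual Dify classes:

from typing import Optional, cast


class ExampleModeration:
    """Toy stand-in for a moderation extension with an optional config."""

    def __init__(self, config: Optional[dict] = None) -> None:
        self.config = config

    def moderation_for_inputs(self, inputs: dict) -> bool:
        # Without this guard, mypy flags `self.config["inputs_config"]`
        # because self.config may still be None at this point.
        if self.config is None:
            raise ValueError("The config is not set.")
        return bool(self.config["inputs_config"]["enabled"]) and bool(inputs)


def read_required(credentials: dict, variable: str) -> str:
    # cast() documents an invariant the caller has already checked;
    # it does not convert or validate the value at runtime.
    if variable not in credentials or not credentials[variable]:
        raise ValueError(f"Variable {variable} is required")
    return cast(str, credentials[variable])

The guard raises with a clear message instead of letting a None slip through to a KeyError later, which is also why the same check is duplicated in both moderation_for_inputs and moderation_for_outputs rather than hoisted into __init__.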
+class OpikConfig(BaseTracingConfig): + """ + Model class for Opik tracing config. + """ + + api_key: str | None = None + project: str | None = None + workspace: str | None = None + url: str = "https://www.comet.com/opik/api/" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + if v is None or v == "": + v = "Default Project" + + return v + + @field_validator("url") + @classmethod + def url_validator(cls, v, info: ValidationInfo): + if v is None or v == "": + v = "https://www.comet.com/opik/api/" + if not v.startswith(("https://", "http://")): + raise ValueError("url must start with https:// or http://") + if not v.endswith("/api/"): + raise ValueError("url should ends with /api/") + + return v + + OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 71ff03b6e..f0e34c0cd 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -38,8 +39,8 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_id: str workflow_run_elapsed_time: Union[int, float] workflow_run_status: str - workflow_run_inputs: dict[str, Any] - workflow_run_outputs: dict[str, Any] + workflow_run_inputs: Mapping[str, Any] + workflow_run_outputs: Mapping[str, Any] workflow_run_version: str error: Optional[str] = None total_tokens: int diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 29fdebd8f..b9ba068b1 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -77,8 +77,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=name, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), metadata=metadata, session_id=trace_info.conversation_id, tags=["message", "workflow"], @@ -87,8 +87,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): workflow_span_data = LangfuseSpan( id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), trace_id=trace_id, start_time=trace_info.start_time, end_time=trace_info.end_time, @@ -102,8 +102,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=trace_id, user_id=user_id, name=TraceTaskName.WORKFLOW_TRACE.value, - input=trace_info.workflow_run_inputs, - output=trace_info.workflow_run_outputs, + input=dict(trace_info.workflow_run_inputs), + output=dict(trace_info.workflow_run_outputs), metadata=metadata, session_id=trace_info.conversation_id, tags=["workflow"], diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 99221d669..348b7ba50 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -49,7 +49,6 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") 
input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") - dotted_order: Optional[str] = Field(None, description="Dotted order of the run") @field_validator("inputs", "outputs") @classmethod diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 672843e5a..4ffd888bd 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -3,6 +3,7 @@ import os import uuid from datetime import datetime, timedelta +from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase @@ -63,6 +64,8 @@ def trace(self, trace_info: BaseTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo): trace_id = trace_info.message_id or trace_info.workflow_run_id + if trace_info.start_time is None: + trace_info.start_time = datetime.now() message_dotted_order = ( generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None ) @@ -78,8 +81,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): message_run = LangSmithRunModel( id=trace_info.message_id, name=TraceTaskName.MESSAGE_TRACE.value, - inputs=trace_info.workflow_run_inputs, - outputs=trace_info.workflow_run_outputs, + inputs=dict(trace_info.workflow_run_inputs), + outputs=dict(trace_info.workflow_run_outputs), run_type=LangSmithRunType.chain, start_time=trace_info.start_time, end_time=trace_info.end_time, @@ -90,6 +93,15 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): error=trace_info.error, trace_id=trace_id, dotted_order=message_dotted_order, + file_list=[], + serialized=None, + parent_run_id=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(message_run) @@ -98,11 +110,11 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): total_tokens=trace_info.total_tokens, id=trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, - inputs=trace_info.workflow_run_inputs, + inputs=dict(trace_info.workflow_run_inputs), run_type=LangSmithRunType.tool, start_time=trace_info.workflow_data.created_at, end_time=trace_info.workflow_data.finished_at, - outputs=trace_info.workflow_run_outputs, + outputs=dict(trace_info.workflow_run_outputs), extra={ "metadata": metadata, }, @@ -111,6 +123,13 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): parent_run_id=trace_info.message_id or None, trace_id=trace_id, dotted_order=workflow_dotted_order, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) @@ -211,25 +230,35 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): id=node_execution_id, trace_id=trace_id, dotted_order=node_dotted_order, + error="", + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, ) self.add_run(langsmith_run) def message_trace(self, trace_info: MessageTraceInfo): # get message file data - file_list = trace_info.file_list - message_file_data: MessageFile = trace_info.message_file_data + file_list = cast(list[str], trace_info.file_list) or [] + message_file_data: Optional[MessageFile] = trace_info.message_file_data file_url = 
f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" file_list.append(file_url) metadata = trace_info.metadata message_data = trace_info.message_data + if message_data is None: + return message_id = message_data.id user_id = message_data.from_account_id metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser = ( + end_user_data: Optional[EndUser] = ( db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() ) if end_user_data is not None: @@ -247,12 +276,20 @@ def message_trace(self, trace_info: MessageTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, tags=["message", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + parent_run_id=None, ) self.add_run(message_run) @@ -267,17 +304,27 @@ def message_trace(self, trace_info: MessageTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, outputs=message_data.answer, - extra={ - "metadata": metadata, - }, + extra={"metadata": metadata}, parent_run_id=message_id, tags=["llm", str(trace_info.conversation_mode)], error=trace_info.error, file_list=file_list, + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + id=str(uuid.uuid4()), ) self.add_run(llm_run) def moderation_trace(self, trace_info: ModerationTraceInfo): + if trace_info.message_data is None: + return langsmith_run = LangSmithRunModel( name=TraceTaskName.MODERATION_TRACE.value, inputs=trace_info.inputs, @@ -288,48 +335,82 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): "inputs": trace_info.inputs, }, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["moderation"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(langsmith_run) def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data + if message_data is None: + return suggested_question_run = LangSmithRunModel( name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, inputs=trace_info.inputs, outputs=trace_info.suggested_question, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["suggested_question"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or message_data.created_at, end_time=trace_info.end_time or message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(suggested_question_run) def dataset_retrieval_trace(self, 
trace_info: DatasetRetrievalTraceInfo): + if trace_info.message_data is None: + return dataset_retrieval_run = LangSmithRunModel( name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, inputs=trace_info.inputs, outputs={"documents": trace_info.documents}, run_type=LangSmithRunType.retriever, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["dataset_retrieval"], parent_run_id=trace_info.message_id, start_time=trace_info.start_time or trace_info.message_data.created_at, end_time=trace_info.end_time or trace_info.message_data.updated_at, + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], ) self.add_run(dataset_retrieval_run) @@ -347,7 +428,18 @@ def tool_trace(self, trace_info: ToolTraceInfo): parent_run_id=trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, - file_list=[trace_info.file_url], + file_list=[cast(str, trace_info.file_url)], + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error=trace_info.error or "", ) self.add_run(tool_run) @@ -358,12 +450,23 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): inputs=trace_info.inputs, outputs=trace_info.outputs, run_type=LangSmithRunType.tool, - extra={ - "metadata": trace_info.metadata, - }, + extra={"metadata": trace_info.metadata}, tags=["generate_name"], start_time=trace_info.start_time or datetime.now(), end_time=trace_info.end_time or datetime.now(), + id=str(uuid.uuid4()), + serialized=None, + events=[], + session_id=None, + session_name=None, + reference_example_id=None, + input_attachments={}, + output_attachments={}, + trace_id=None, + dotted_order=None, + error="", + file_list=[], + parent_run_id=None, ) self.add_run(name_run) diff --git a/api/core/ops/opik_trace/__init__.py b/api/core/ops/opik_trace/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py new file mode 100644 index 000000000..fabf38fbd --- /dev/null +++ b/api/core/ops/opik_trace/opik_trace.py @@ -0,0 +1,469 @@ +import json +import logging +import os +import uuid +from datetime import datetime, timedelta +from typing import Optional, cast + +from opik import Opik, Trace +from opik.id_helpers import uuid4_to_uuid7 + +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.config_entity import OpikConfig +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from extensions.ext_database import db +from models.model import EndUser, MessageFile +from models.workflow import WorkflowNodeExecution + +logger = logging.getLogger(__name__) + + +def wrap_dict(key_name, data): + """Make sure that the input data is a dict""" + if not isinstance(data, dict): + return {key_name: data} + + return data + + +def wrap_metadata(metadata, **kwargs): + """Add common metatada to all Traces and Spans""" + metadata["created_from"] = "dify" + + metadata.update(kwargs) + + return metadata + + +def prepare_opik_uuid(user_datetime: 
Optional[datetime], user_uuid: Optional[str]): + """Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most + messages and objects. The type-hints of BaseTraceInfo indicates that + objects start_time and message_id could be null which means we cannot map + it to a UUIDv7. Given that we have no way to identify that object + uniquely, generate a new random one UUIDv7 in that case. + """ + + if user_datetime is None: + user_datetime = datetime.now() + + if user_uuid is None: + user_uuid = str(uuid.uuid4()) + + return uuid4_to_uuid7(user_datetime, user_uuid) + + +class OpikDataTrace(BaseTraceInstance): + def __init__( + self, + opik_config: OpikConfig, + ): + super().__init__(opik_config) + self.opik_client = Opik( + project_name=opik_config.project, + workspace=opik_config.workspace, + host=opik_config.url, + api_key=opik_config.api_key, + ) + self.project = opik_config.project + self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + + def trace(self, trace_info: BaseTraceInfo): + if isinstance(trace_info, WorkflowTraceInfo): + self.workflow_trace(trace_info) + if isinstance(trace_info, MessageTraceInfo): + self.message_trace(trace_info) + if isinstance(trace_info, ModerationTraceInfo): + self.moderation_trace(trace_info) + if isinstance(trace_info, SuggestedQuestionTraceInfo): + self.suggested_question_trace(trace_info) + if isinstance(trace_info, DatasetRetrievalTraceInfo): + self.dataset_retrieval_trace(trace_info) + if isinstance(trace_info, ToolTraceInfo): + self.tool_trace(trace_info) + if isinstance(trace_info, GenerateNameTraceInfo): + self.generate_name_trace(trace_info) + + def workflow_trace(self, trace_info: WorkflowTraceInfo): + dify_trace_id = trace_info.workflow_run_id + opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) + workflow_metadata = wrap_metadata( + trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id + ) + root_span_id = None + + if trace_info.message_id: + dify_trace_id = trace_info.message_id + opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) + + trace_data = { + "id": opik_trace_id, + "name": TraceTaskName.MESSAGE_TRACE.value, + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": workflow_metadata, + "input": wrap_dict("input", trace_info.workflow_run_inputs), + "output": wrap_dict("output", trace_info.workflow_run_outputs), + "tags": ["message", "workflow"], + "project_name": self.project, + } + self.add_trace(trace_data) + + root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) + span_data = { + "id": root_span_id, + "parent_span_id": None, + "trace_id": opik_trace_id, + "name": TraceTaskName.WORKFLOW_TRACE.value, + "input": wrap_dict("input", trace_info.workflow_run_inputs), + "output": wrap_dict("output", trace_info.workflow_run_outputs), + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": workflow_metadata, + "tags": ["workflow"], + "project_name": self.project, + } + self.add_span(span_data) + else: + trace_data = { + "id": opik_trace_id, + "name": TraceTaskName.MESSAGE_TRACE.value, + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": workflow_metadata, + "input": wrap_dict("input", trace_info.workflow_run_inputs), + "output": wrap_dict("output", trace_info.workflow_run_outputs), + "tags": ["workflow"], + "project_name": self.project, + } + self.add_trace(trace_data) + + # through workflow_run_id get 
all_nodes_execution + workflow_nodes_execution_id_records = ( + db.session.query(WorkflowNodeExecution.id) + .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) + .all() + ) + + for node_execution_id_record in workflow_nodes_execution_id_records: + node_execution = ( + db.session.query( + WorkflowNodeExecution.id, + WorkflowNodeExecution.tenant_id, + WorkflowNodeExecution.app_id, + WorkflowNodeExecution.title, + WorkflowNodeExecution.node_type, + WorkflowNodeExecution.status, + WorkflowNodeExecution.inputs, + WorkflowNodeExecution.outputs, + WorkflowNodeExecution.created_at, + WorkflowNodeExecution.elapsed_time, + WorkflowNodeExecution.process_data, + WorkflowNodeExecution.execution_metadata, + ) + .filter(WorkflowNodeExecution.id == node_execution_id_record.id) + .first() + ) + + if not node_execution: + continue + + node_execution_id = node_execution.id + tenant_id = node_execution.tenant_id + app_id = node_execution.app_id + node_name = node_execution.title + node_type = node_execution.node_type + status = node_execution.status + if node_type == "llm": + inputs = ( + json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} + ) + else: + inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} + outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + created_at = node_execution.created_at or datetime.now() + elapsed_time = node_execution.elapsed_time + finished_at = created_at + timedelta(seconds=elapsed_time) + + execution_metadata = ( + json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + ) + metadata = execution_metadata.copy() + metadata.update( + { + "workflow_run_id": trace_info.workflow_run_id, + "node_execution_id": node_execution_id, + "tenant_id": tenant_id, + "app_id": app_id, + "app_name": node_name, + "node_type": node_type, + "status": status, + } + ) + + process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + + provider = None + model = None + total_tokens = 0 + completion_tokens = 0 + prompt_tokens = 0 + + if process_data and process_data.get("model_mode") == "chat": + run_type = "llm" + provider = process_data.get("model_provider", None) + model = process_data.get("model_name", "") + metadata.update( + { + "ls_provider": provider, + "ls_model_name": model, + } + ) + + try: + if outputs.get("usage"): + total_tokens = outputs["usage"].get("total_tokens", 0) + prompt_tokens = outputs["usage"].get("prompt_tokens", 0) + completion_tokens = outputs["usage"].get("completion_tokens", 0) + except Exception: + logger.error("Failed to extract usage", exc_info=True) + + else: + run_type = "tool" + + parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id + + if not total_tokens: + total_tokens = execution_metadata.get("total_tokens", 0) + + span_data = { + "trace_id": opik_trace_id, + "id": prepare_opik_uuid(created_at, node_execution_id), + "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id), + "name": node_type, + "type": run_type, + "start_time": created_at, + "end_time": finished_at, + "metadata": wrap_metadata(metadata), + "input": wrap_dict("input", inputs), + "output": wrap_dict("output", outputs), + "tags": ["node_execution"], + "project_name": self.project, + "usage": { + "total_tokens": total_tokens, + "completion_tokens": completion_tokens, + "prompt_tokens": prompt_tokens, + }, + "model": model, + "provider": provider, + } + + 
self.add_span(span_data) + + def message_trace(self, trace_info: MessageTraceInfo): + # get message file data + file_list = cast(list[str], trace_info.file_list) or [] + message_file_data: Optional[MessageFile] = trace_info.message_file_data + + if message_file_data is not None: + file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" + file_list.append(file_url) + + message_data = trace_info.message_data + if message_data is None: + return + + metadata = trace_info.metadata + message_id = trace_info.message_id + + user_id = message_data.from_account_id + metadata["user_id"] = user_id + metadata["file_list"] = file_list + + if message_data.from_end_user_id: + end_user_data: Optional[EndUser] = ( + db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() + ) + if end_user_data is not None: + end_user_id = end_user_data.session_id + metadata["end_user_id"] = end_user_id + + trace_data = { + "id": prepare_opik_uuid(trace_info.start_time, message_id), + "name": TraceTaskName.MESSAGE_TRACE.value, + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": wrap_metadata(metadata), + "input": trace_info.inputs, + "output": message_data.answer, + "tags": ["message", str(trace_info.conversation_mode)], + "project_name": self.project, + } + trace = self.add_trace(trace_data) + + span_data = { + "trace_id": trace.id, + "name": "llm", + "type": "llm", + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": wrap_metadata(metadata), + "input": {"input": trace_info.inputs}, + "output": {"output": message_data.answer}, + "tags": ["llm", str(trace_info.conversation_mode)], + "usage": { + "completion_tokens": trace_info.answer_tokens, + "prompt_tokens": trace_info.message_tokens, + "total_tokens": trace_info.total_tokens, + }, + "project_name": self.project, + } + self.add_span(span_data) + + def moderation_trace(self, trace_info: ModerationTraceInfo): + if trace_info.message_data is None: + return + + start_time = trace_info.start_time or trace_info.message_data.created_at + + span_data = { + "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "name": TraceTaskName.MODERATION_TRACE.value, + "type": "tool", + "start_time": start_time, + "end_time": trace_info.end_time or trace_info.message_data.updated_at, + "metadata": wrap_metadata(trace_info.metadata), + "input": wrap_dict("input", trace_info.inputs), + "output": { + "action": trace_info.action, + "flagged": trace_info.flagged, + "preset_response": trace_info.preset_response, + "inputs": trace_info.inputs, + }, + "tags": ["moderation"], + } + + self.add_span(span_data) + + def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): + message_data = trace_info.message_data + if message_data is None: + return + + start_time = trace_info.start_time or message_data.created_at + + span_data = { + "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value, + "type": "tool", + "start_time": start_time, + "end_time": trace_info.end_time or message_data.updated_at, + "metadata": wrap_metadata(trace_info.metadata), + "input": wrap_dict("input", trace_info.inputs), + "output": wrap_dict("output", trace_info.suggested_question), + "tags": ["suggested_question"], + } + + self.add_span(span_data) + + def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): + if trace_info.message_data is None: + return + + start_time = 
trace_info.start_time or trace_info.message_data.created_at + + span_data = { + "trace_id": prepare_opik_uuid(start_time, trace_info.message_id), + "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value, + "type": "tool", + "start_time": start_time, + "end_time": trace_info.end_time or trace_info.message_data.updated_at, + "metadata": wrap_metadata(trace_info.metadata), + "input": wrap_dict("input", trace_info.inputs), + "output": {"documents": trace_info.documents}, + "tags": ["dataset_retrieval"], + } + + self.add_span(span_data) + + def tool_trace(self, trace_info: ToolTraceInfo): + span_data = { + "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id), + "name": trace_info.tool_name, + "type": "tool", + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": wrap_metadata(trace_info.metadata), + "input": wrap_dict("input", trace_info.tool_inputs), + "output": wrap_dict("output", trace_info.tool_outputs), + "tags": ["tool", trace_info.tool_name], + } + + self.add_span(span_data) + + def generate_name_trace(self, trace_info: GenerateNameTraceInfo): + trace_data = { + "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id), + "name": TraceTaskName.GENERATE_NAME_TRACE.value, + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": wrap_metadata(trace_info.metadata), + "input": trace_info.inputs, + "output": trace_info.outputs, + "tags": ["generate_name"], + "project_name": self.project, + } + + trace = self.add_trace(trace_data) + + span_data = { + "trace_id": trace.id, + "name": TraceTaskName.GENERATE_NAME_TRACE.value, + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": wrap_metadata(trace_info.metadata), + "input": wrap_dict("input", trace_info.inputs), + "output": wrap_dict("output", trace_info.outputs), + "tags": ["generate_name"], + } + + self.add_span(span_data) + + def add_trace(self, opik_trace_data: dict) -> Trace: + try: + trace = self.opik_client.trace(**opik_trace_data) + logger.debug("Opik Trace created successfully") + return trace + except Exception as e: + raise ValueError(f"Opik Failed to create trace: {str(e)}") + + def add_span(self, opik_span_data: dict): + try: + self.opik_client.span(**opik_span_data) + logger.debug("Opik Span created successfully") + except Exception as e: + raise ValueError(f"Opik Failed to create span: {str(e)}") + + def api_check(self): + try: + self.opik_client.auth_check() + return True + except Exception as e: + logger.info(f"Opik API check failed: {str(e)}", exc_info=True) + raise ValueError(f"Opik API check failed: {str(e)}") + + def get_project_url(self): + try: + return self.opik_client.get_project_url(project_name=self.project) + except Exception as e: + logger.info(f"Opik get run url failed: {str(e)}", exc_info=True) + raise ValueError(f"Opik get run url failed: {str(e)}") diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index a04fc6ee7..c153e3f9d 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -9,12 +9,15 @@ from uuid import UUID, uuid4 from flask import current_app +from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( OPS_FILE_PATH, LangfuseConfig, LangSmithConfig, + OpikConfig, TracingProviderEnum, ) from core.ops.entities.trace_entity import ( @@ -30,14 +33,15 @@ ) from 
core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace +from core.ops.opik_trace.opik_trace import OpikDataTrace from core.ops.utils import get_message_data from extensions.ext_database import db from extensions.ext_storage import storage -from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig +from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog, WorkflowRun from tasks.ops_trace_task import process_trace_tasks -provider_config_map = { +provider_config_map: dict[str, dict[str, Any]] = { TracingProviderEnum.LANGFUSE.value: { "config_class": LangfuseConfig, "secret_keys": ["public_key", "secret_key"], @@ -50,6 +54,12 @@ "other_keys": ["project", "endpoint"], "trace_instance": LangSmithDataTrace, }, + TracingProviderEnum.OPIK.value: { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + }, } @@ -145,7 +155,7 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -155,7 +165,11 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + app = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") + + tenant_id = app.tenant_id decrypt_tracing_config = cls.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -178,7 +192,7 @@ def get_ops_trace_instance( if app_id is None: return None - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() if app is None: return None @@ -209,8 +223,12 @@ def get_ops_trace_instance( def get_app_config_through_message_id(cls, message_id: str): app_model_config = None message_data = db.session.query(Message).filter(Message.id == message_id).first() + if not message_data: + return None conversation_id = message_data.conversation_id conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + if not conversation_data: + return None if conversation_data.app_model_config_id: app_model_config = ( @@ -236,7 +254,9 @@ def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: if tracing_provider not in provider_config_map and tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App = db.session.query(App).filter(App.id == app_id).first() + app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app_config: + raise ValueError("App not found") app_config.tracing = json.dumps( { "enabled": enabled, @@ -252,7 +272,9 @@ def get_app_tracing_config(cls, app_id: str): :param app_id: app id :return: """ - app: App = db.session.query(App).filter(App.id == app_id).first() + app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError("App not found") if not 
app.tracing: return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) @@ -317,15 +339,15 @@ def __init__( ): self.trace_type = trace_type self.message_id = message_id - self.workflow_run = workflow_run + self.workflow_run_id = workflow_run.id if workflow_run else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer - self.kwargs = kwargs self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") - self.app_id = None + self.kwargs = kwargs + def execute(self): return self.preprocess() @@ -333,19 +355,23 @@ def preprocess(self): preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - self.workflow_run, self.conversation_id, self.user_id + workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + ), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs ), TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( - self.message_id, self.timer, **self.kwargs + message_id=self.message_id, timer=self.timer, **self.kwargs + ), + TraceTaskName.TOOL_TRACE: lambda: self.tool_trace( + message_id=self.message_id, timer=self.timer, **self.kwargs ), - TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( - self.conversation_id, self.timer, **self.kwargs + conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), } @@ -355,86 +381,100 @@ def preprocess(self): def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): - if not workflow_run: - raise ValueError("Workflow run not found") - - db.session.merge(workflow_run) - db.sessoin.refresh(workflow_run) - - workflow_id = workflow_run.workflow_id - tenant_id = workflow_run.tenant_id - workflow_run_id = workflow_run.id - workflow_run_elapsed_time = workflow_run.elapsed_time - workflow_run_status = workflow_run.status - workflow_run_inputs = workflow_run.inputs_dict - workflow_run_outputs = workflow_run.outputs_dict - workflow_run_version = workflow_run.version - error = workflow_run.error or "" - - total_tokens = workflow_run.total_tokens - - file_list = workflow_run_inputs.get("sys.file") or [] - query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" - - # get workflow_app_log_id - workflow_app_log_data = ( - db.session.query(WorkflowAppLog) - .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) - .first() - ) - workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None - # get message_id - message_data = ( - db.session.query(Message.id) - .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) - .first() - ) - message_id = str(message_data.id) if message_data else 
None - - metadata = { - "workflow_id": workflow_id, - "conversation_id": conversation_id, - "workflow_run_id": workflow_run_id, - "tenant_id": tenant_id, - "elapsed_time": workflow_run_elapsed_time, - "status": workflow_run_status, - "version": workflow_run_version, - "total_tokens": total_tokens, - "file_list": file_list, - "triggered_form": workflow_run.triggered_from, - "user_id": user_id, - } + def workflow_trace( + self, + *, + workflow_run_id: str | None, + conversation_id: str | None, + user_id: str | None, + ): + if not workflow_run_id: + return {} - workflow_trace_info = WorkflowTraceInfo( - workflow_data=workflow_run.to_dict(), - conversation_id=conversation_id, - workflow_id=workflow_id, - tenant_id=tenant_id, - workflow_run_id=workflow_run_id, - workflow_run_elapsed_time=workflow_run_elapsed_time, - workflow_run_status=workflow_run_status, - workflow_run_inputs=workflow_run_inputs, - workflow_run_outputs=workflow_run_outputs, - workflow_run_version=workflow_run_version, - error=error, - total_tokens=total_tokens, - file_list=file_list, - query=query, - metadata=metadata, - workflow_app_log_id=workflow_app_log_id, - message_id=message_id, - start_time=workflow_run.created_at, - end_time=workflow_run.finished_at, - ) + with Session(db.engine) as session: + workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) + workflow_run = session.scalars(workflow_run_stmt).first() + if not workflow_run: + raise ValueError("Workflow run not found") + + workflow_id = workflow_run.workflow_id + tenant_id = workflow_run.tenant_id + workflow_run_id = workflow_run.id + workflow_run_elapsed_time = workflow_run.elapsed_time + workflow_run_status = workflow_run.status + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict + workflow_run_version = workflow_run.version + error = workflow_run.error or "" + + total_tokens = workflow_run.total_tokens + + file_list = workflow_run_inputs.get("sys.file") or [] + query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" + + # get workflow_app_log_id + workflow_app_log_data_stmt = select(WorkflowAppLog.id).where( + WorkflowAppLog.tenant_id == tenant_id, + WorkflowAppLog.app_id == workflow_run.app_id, + WorkflowAppLog.workflow_run_id == workflow_run.id, + ) + workflow_app_log_id = session.scalar(workflow_app_log_data_stmt) + # get message_id + message_id = None + if conversation_id: + message_data_stmt = select(Message.id).where( + Message.conversation_id == conversation_id, + Message.workflow_run_id == workflow_run_id, + ) + message_id = session.scalar(message_data_stmt) + + metadata = { + "workflow_id": workflow_id, + "conversation_id": conversation_id, + "workflow_run_id": workflow_run_id, + "tenant_id": tenant_id, + "elapsed_time": workflow_run_elapsed_time, + "status": workflow_run_status, + "version": workflow_run_version, + "total_tokens": total_tokens, + "file_list": file_list, + "triggered_form": workflow_run.triggered_from, + "user_id": user_id, + } + workflow_trace_info = WorkflowTraceInfo( + workflow_data=workflow_run.to_dict(), + conversation_id=conversation_id, + workflow_id=workflow_id, + tenant_id=tenant_id, + workflow_run_id=workflow_run_id, + workflow_run_elapsed_time=workflow_run_elapsed_time, + workflow_run_status=workflow_run_status, + workflow_run_inputs=workflow_run_inputs, + workflow_run_outputs=workflow_run_outputs, + workflow_run_version=workflow_run_version, + error=error, + total_tokens=total_tokens, + file_list=file_list, + 
query=query, + metadata=metadata, + workflow_app_log_id=workflow_app_log_id, + message_id=message_id, + start_time=workflow_run.created_at, + end_time=workflow_run.finished_at, + ) return workflow_trace_info - def message_trace(self, message_id): + def message_trace(self, message_id: str | None): + if not message_id: + return {} message_data = get_message_data(message_id) if not message_data: return {} - conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() + conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) + conversation_mode = db.session.scalars(conversation_mode_stmt).all() + if not conversation_mode or len(conversation_mode) == 0: + return {} conversation_mode = conversation_mode[0] created_at = message_data.created_at inputs = message_data.message @@ -483,6 +523,8 @@ def message_trace(self, message_id): def moderation_trace(self, message_id, timer, **kwargs): moderation_result = kwargs.get("moderation_result") + if not moderation_result: + return {} inputs = kwargs.get("inputs") message_data = get_message_data(message_id) if not message_data: @@ -518,7 +560,7 @@ def moderation_trace(self, message_id, timer, **kwargs): return moderation_trace_info def suggested_question_trace(self, message_id, timer, **kwargs): - suggested_question = kwargs.get("suggested_question") + suggested_question = kwargs.get("suggested_question", []) message_data = get_message_data(message_id) if not message_data: return {} @@ -586,7 +628,7 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents], + documents=[doc.model_dump() for doc in documents] if documents else [], start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -596,9 +638,9 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get("tool_name") - tool_inputs = kwargs.get("tool_inputs") - tool_outputs = kwargs.get("tool_outputs") + tool_name = kwargs.get("tool_name", "") + tool_inputs = kwargs.get("tool_inputs", {}) + tool_outputs = kwargs.get("tool_outputs", {}) message_data = get_message_data(message_id) if not message_data: return {} @@ -608,7 +650,7 @@ def tool_trace(self, message_id, timer, **kwargs): tool_parameters = {} created_time = message_data.created_at end_time = message_data.updated_at - agent_thoughts: list[MessageAgentThought] = message_data.agent_thoughts + agent_thoughts = message_data.agent_thoughts for agent_thought in agent_thoughts: if tool_name in agent_thought.tools: created_time = agent_thought.created_at @@ -672,6 +714,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): generate_conversation_name = kwargs.get("generate_conversation_name") inputs = kwargs.get("inputs") tenant_id = kwargs.get("tenant_id") + if not tenant_id: + return {} start_time = timer.get("start") end_time = timer.get("end") @@ -693,8 +737,8 @@ def generate_name_trace(self, conversation_id, timer, **kwargs): return generate_name_trace_info -trace_manager_timer = None -trace_manager_queue = queue.Queue() +trace_manager_timer: Optional[threading.Timer] = None +trace_manager_queue: queue.Queue = queue.Queue() trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5)) 
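The workflow_trace rewrite above stops carrying a detached WorkflowRun ORM object into the trace task and stores only workflow_run_id, re-fetching the row inside a short-lived Session with 2.0-style select(), presumably to avoid detached-instance errors once the task runs outside the original request session. A self-contained sketch of that query pattern, against a toy model and an in-memory SQLite engine rather than Dify's real models:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowRunStub(Base):
    __tablename__ = "workflow_runs"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(WorkflowRunStub(id="run-1", status="succeeded"))
    session.commit()

    # build the statement first, then execute it inside the session
    stmt = select(WorkflowRunStub).where(WorkflowRunStub.id == "run-1")
    run = session.scalars(stmt).first()  # first matching ORM object, or None
    if run is None:
        raise ValueError("Workflow run not found")

    # single-column selects can be read back with session.scalar()
    status = session.scalar(
        select(WorkflowRunStub.status).where(WorkflowRunStub.id == run.id)
    )
    print(run.id, status)

Because everything the trace needs is read inside the `with Session(...)` block and copied into plain values (ids, dicts, strings) before the session closes, nothing downstream depends on a live ORM identity map.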
trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100)) @@ -706,7 +750,7 @@ def __init__(self, app_id=None, user_id=None): self.app_id = app_id self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - self.flask_app = current_app._get_current_object() + self.flask_app = current_app._get_current_object() # type: ignore if trace_manager_timer is None: self.start_timer() @@ -723,7 +767,7 @@ def add_trace_task(self, trace_task: TraceTask): def collect_tasks(self): global trace_manager_queue - tasks = [] + tasks: list[TraceTask] = [] while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty(): task = trace_manager_queue.get_nowait() tasks.append(task) @@ -749,6 +793,8 @@ def start_timer(self): def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: + if task.app_id is None: + continue file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 998eba9ea..8b06df193 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -18,7 +18,7 @@ def filter_none_values(data: dict): return new_data -def get_message_data(message_id): +def get_message_data(message_id: str): return db.session.query(Message).filter(Message.id == message_id).first() diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 0f3f82496..87c7a79fb 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,5 +1,5 @@ -from collections.abc import Sequence -from typing import Optional +from collections.abc import Mapping, Sequence +from typing import Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.file import file_manager @@ -39,7 +39,7 @@ def get_prompt( self, *, prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, - inputs: dict[str, str], + inputs: Mapping[str, str], query: str, files: Sequence[File], context: Optional[str], @@ -77,7 +77,7 @@ def get_prompt( def _get_completion_model_prompt_messages( self, prompt_template: CompletionModelPromptTemplate, - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -90,15 +90,15 @@ def _get_completion_model_prompt_messages( """ raw_prompt = prompt_template.text - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] if prompt_template.edition_type == "basic" or not prompt_template.edition_type: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) - if memory and memory_config: + if memory and memory_config and memory_config.role_prefix: role_prefix = memory_config.role_prefix prompt_inputs = self._set_histories_variable( memory=memory, @@ -135,7 +135,7 @@ def _get_completion_model_prompt_messages( def _get_chat_model_prompt_messages( self, prompt_template: list[ChatModelMessage], - inputs: dict, + inputs: Mapping[str, str], query: Optional[str], files: Sequence[File], context: Optional[str], @@ -146,7 +146,7 @@ def _get_chat_model_prompt_messages( """ Get chat model prompt messages. 
""" - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] for prompt_item in prompt_template: raw_prompt = prompt_item.text @@ -160,7 +160,7 @@ def _get_chat_model_prompt_messages( prompt = vp.convert_template(raw_prompt).text else: parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs: Mapping[str, str] = {k: inputs[k] for k in parser.variable_keys if k in inputs} prompt_inputs = self._set_context_variable( context=context, parser=parser, prompt_inputs=prompt_inputs ) @@ -207,7 +207,7 @@ def _get_chat_model_prompt_messages( last_message = prompt_messages[-1] if prompt_messages else None if last_message and last_message.role == PromptMessageRole.USER: # get last user message content and add files - prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] + prompt_message_contents = [TextPromptMessageContent(data=cast(str, last_message.content))] for file in files: prompt_message_contents.append(file_manager.to_prompt_message_content(file)) @@ -229,7 +229,10 @@ def _get_chat_model_prompt_messages( return prompt_messages - def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_context_variable( + self, context: str | None, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#context#" in parser.variable_keys: if context: prompt_inputs["#context#"] = context @@ -238,7 +241,10 @@ def _set_context_variable(self, context: str | None, parser: PromptTemplateParse return prompt_inputs - def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + def _set_query_variable( + self, query: str, parser: PromptTemplateParser, prompt_inputs: Mapping[str, str] + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#query#" in parser.variable_keys: if query: prompt_inputs["#query#"] = query @@ -254,9 +260,10 @@ def _set_histories_variable( raw_prompt: str, role_prefix: MemoryConfig.RolePrefix, parser: PromptTemplateParser, - prompt_inputs: dict, + prompt_inputs: Mapping[str, str], model_config: ModelConfigWithCredentialsEntity, - ) -> dict: + ) -> Mapping[str, str]: + prompt_inputs = dict(prompt_inputs) if "#histories#" in parser.variable_keys: if memory: inputs = {"#histories#": "", **prompt_inputs} diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index caa1793ea..09f017a7d 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -31,7 +31,7 @@ def __init__( self.memory = memory def get_prompt(self) -> list[PromptMessage]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] num_system = 0 for prompt_message in self.history_messages: if isinstance(prompt_message, SystemPromptMessage): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 87acdb3c4..1f040599b 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -42,7 +42,7 @@ def _calculate_rest_token( ): max_tokens = ( 
model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template) + or model_config.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -59,7 +59,7 @@ def _get_history_messages_from_memory( ai_prefix: Optional[str] = None, ) -> str: """Get memory messages.""" - kwargs = {"max_token_limit": max_token_limit} + kwargs: dict[str, Any] = {"max_token_limit": max_token_limit} if human_prefix: kwargs["human_prefix"] = human_prefix @@ -76,11 +76,15 @@ def _get_history_messages_list_from_memory( self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int ) -> list[PromptMessage]: """Get memory messages.""" - return memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=memory_config.window.size - if ( - memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + return list( + memory.get_history_prompt_messages( + max_token_limit=max_token_limit, + message_limit=memory_config.window.size + if ( + memory_config.window.enabled + and memory_config.window.size is not None + and memory_config.window.size > 0 + ) + else None, ) - else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 93dd92f18..e75877de9 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,7 +1,8 @@ import enum import json import os -from typing import TYPE_CHECKING, Optional +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -41,7 +42,7 @@ def value_of(cls, value: str) -> "ModelMode": raise ValueError(f"invalid mode value {value}") -prompt_file_contents = {} +prompt_file_contents: dict[str, Any] = {} class SimplePromptTransform(PromptTransform): @@ -53,9 +54,9 @@ def get_prompt( self, app_mode: AppMode, prompt_template_entity: PromptTemplateEntity, - inputs: dict, + inputs: Mapping[str, str], query: str, - files: list["File"], + files: Sequence["File"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, @@ -66,7 +67,7 @@ def get_prompt( if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -77,7 +78,7 @@ def get_prompt( else: prompt_messages, stops = self._get_completion_model_prompt_messages( app_mode=app_mode, - pre_prompt=prompt_template_entity.simple_prompt_template, + pre_prompt=prompt_template_entity.simple_prompt_template or "", inputs=inputs, query=query, files=files, @@ -171,11 +172,11 @@ def _get_chat_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] # get prompt prompt, _ = self.get_prompt_str_and_rules( @@ -216,7 +217,7 @@ def _get_completion_model_prompt_messages( inputs: dict, query: str, context: Optional[str], - files: 
list["File"], + files: Sequence["File"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity, ) -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -263,7 +264,7 @@ def _get_completion_model_prompt_messages( return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage: if files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=prompt)) @@ -288,7 +289,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: - return prompt_file_contents[prompt_file_name] + return cast(dict, prompt_file_contents[prompt_file_name]) # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") @@ -301,7 +302,7 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content - return content + return cast(dict, content) def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index aa175153b..2f4e65146 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import cast +from typing import Any, cast from core.model_runtime.entities import ( AssistantPromptMessage, @@ -72,7 +72,7 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) prompt = {"role": role, "text": text, "files": files} @@ -99,9 +99,9 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Seque } ) else: - text = prompt_message.content + text = cast(str, prompt_message.content) - params = { + params: dict[str, Any] = { "role": "user", "text": text, } diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 0fd08c5d3..8e40674bc 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{0,29}|#histories#|#query#|#context#)\}\}") WITH_VARIABLE_TMPL_REGEX = re.compile( @@ -28,7 +29,7 @@ def extract(self) -> list: # Regular expression to match the template rules return re.findall(self.regex, self.template) - def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + def format(self, inputs: Mapping[str, str], remove_template_variables: bool = True) -> str: def replacer(match): key = match.group(1) value = inputs.get(key, match.group(0)) # return original matched string if key not found diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 3a1fe300d..010abd12d 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,7 +1,7 @@ import json from collections import defaultdict from json import JSONDecodeError -from typing import Optional +from typing import Optional, cast from sqlalchemy.exc import 
IntegrityError @@ -15,6 +15,7 @@ ModelLoadBalancingConfiguration, ModelSettings, QuotaConfiguration, + QuotaUnit, SystemConfiguration, ) from core.helper import encrypter @@ -116,8 +117,8 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: for provider_entity in provider_entities: # handle include, exclude if is_filtered( - include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, - exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET), data=provider_entity, name_func=lambda x: x.provider, ): @@ -490,12 +491,13 @@ def _init_trial_provider_records( # Init trial provider records if not exists if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: + # FIXME ignore the type error; only TrialHostingQuota has a limit, need to change the logic provider_record = Provider( tenant_id=tenant_id, provider_name=provider_name, provider_type=ProviderType.SYSTEM.value, quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota.quota_limit, + quota_limit=quota.quota_limit, # type: ignore quota_used=0, is_valid=True, ) @@ -589,7 +591,9 @@ def _to_custom_configuration( if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa + provider_credentials.get(variable) or "", # type: ignore + self.decoding_rsa_key, + self.decoding_cipher_rsa, ) except ValueError: pass @@ -671,13 +675,9 @@ def _to_system_configuration( # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if ( - provider_entity.provider not in hosting_configuration.provider_map - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled - ): - return SystemConfiguration(enabled=False) - provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) + if provider_hosting_configuration is None or not provider_hosting_configuration.enabled: + return SystemConfiguration(enabled=False) # Convert provider_records to dict quota_type_to_provider_records_dict = {} @@ -688,14 +688,13 @@ def _to_system_configuration( quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( provider_record ) - quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: if provider_quota.quota_type not in quota_type_to_provider_records_dict: if provider_quota.quota_type == ProviderQuotaType.FREE: quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=0, quota_limit=0, is_valid=False, @@ -708,7 +707,7 @@ def _to_system_configuration( quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, - quota_unit=provider_hosting_configuration.quota_unit, + quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, is_valid=provider_record.quota_limit > provider_record.quota_used @@ -725,12 +724,12 @@ def _to_system_configuration( current_using_credentials = provider_hosting_configuration.credentials if current_quota_type == ProviderQuotaType.FREE: - provider_record = 
quota_type_to_provider_records_dict.get(current_quota_type) + provider_record_quota_free = quota_type_to_provider_records_dict.get(current_quota_type) - if provider_record: + if provider_record_quota_free: provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, - identity_id=provider_record.id, + identity_id=provider_record_quota_free.id, cache_type=ProviderCredentialsCacheType.PROVIDER, ) @@ -763,7 +762,7 @@ def _to_system_configuration( except ValueError: pass - current_using_credentials = provider_credentials + current_using_credentials = provider_credentials or {} # cache provider credentials provider_credentials_cache.set(credentials=current_using_credentials) @@ -842,7 +841,7 @@ def _to_model_settings( else [] ) - model_settings = [] + model_settings: list[ModelSettings] = [] if not provider_model_settings: return model_settings diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a0153c1e5..95a2316f1 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -32,8 +32,11 @@ def create(self, texts: list[Document], **kwargs) -> BaseKeyword: keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) @@ -58,20 +61,26 @@ def add_texts(self, texts: list[Document], **kwargs): keywords = keyword_table_handler.extract_keywords( text.page_content, self._config.max_keywords_per_chunk ) - self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) + if text.metadata is not None: + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, text.metadata["doc_id"], list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def text_exists(self, id: str) -> bool: keyword_table = self._get_dataset_keyword_table() + if keyword_table is None: + return False return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) + if keyword_table is not None: + keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) @@ -80,7 +89,7 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: k = kwargs.get("top_k", 4) - sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) + sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k) documents = [] for chunk_index in sorted_chunk_indices: @@ -137,7 +146,7 @@ def _get_dataset_keyword_table(self) -> Optional[dict]: if dataset_keyword_table: keyword_table_dict = 
dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict["__data__"]["table"] + return dict(keyword_table_dict["__data__"]["table"]) else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( @@ -188,8 +197,8 @@ def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): # go through text chunks in order of most matching keywords chunk_indices_count: dict[str, int] = defaultdict(int) - keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] - for keyword in keywords: + keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] + for keyword in keywords_list: for node_id in keyword_table[keyword]: chunk_indices_count[node_id] += 1 @@ -215,7 +224,7 @@ def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list def create_segment_keywords(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() self._update_segment_keywords(self.dataset.id, node_id, keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) def multi_create_segment_keywords(self, pre_segment_data_list: list): @@ -226,17 +235,19 @@ def multi_create_segment_keywords(self, pre_segment_data_list: list): if pre_segment_data["keywords"]: segment.keywords = pre_segment_data["keywords"] keyword_table = self._add_text_to_keyword_table( - keyword_table, segment.index_node_id, pre_segment_data["keywords"] + keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"] ) else: keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) + keyword_table = self._add_text_to_keyword_table( + keyword_table or {}, segment.index_node_id, list(keywords) + ) self._save_dataset_keyword_table(keyword_table) def update_segment_keywords_index(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() - keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) + keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ec809cf32..a6214d955 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,25 +1,27 @@ import re -from typing import Optional +from typing import Optional, cast class JiebaKeywordTableHandler: def __init__(self): - import jieba.analyse + import jieba.analyse # type: ignore from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS - jieba.analyse.default_tfidf.stop_words = STOPWORDS + jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" - import jieba + import jieba.analyse # type: ignore keywords = jieba.analyse.extract_tags( sentence=text, topK=max_keywords_per_chunk, ) + # jieba.analyse.extract_tags 
returns list[Any] when withFlag is False by default. + keywords = cast(list[str], keywords) - return set(self._expand_tokens_with_subtokens(keywords)) + return set(self._expand_tokens_with_subtokens(set(keywords))) def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: """Get subtokens from a list of tokens., filtering for stopwords.""" diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index be00687ab..b261b40b7 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -37,6 +37,8 @@ def search(self, query: str, **kwargs: Any) -> list[Document]: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: @@ -45,4 +47,4 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata] diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 18f8d4e83..3a8200bc7 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,10 +6,14 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from models.dataset import Dataset +from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -31,7 +35,7 @@ def retrieve( top_k: int, score_threshold: Optional[float] = 0.0, reranking_model: Optional[dict] = None, - reranking_mode: Optional[str] = "reranking_model", + reranking_mode: str = "reranking_model", weights: Optional[dict] = None, ): if not query: @@ -42,15 +46,15 @@ def retrieve( if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] - all_documents = [] - threads = [] - exceptions = [] + all_documents: list[Document] = [] + threads: list[threading.Thread] = [] + exceptions: list[str] = [] # retrieval_model source with keyword if retrieval_method == "keyword_search": keyword_thread = threading.Thread( target=RetrievalService.keyword_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -65,7 +69,7 @@ def retrieve( embedding_thread = threading.Thread( target=RetrievalService.embedding_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "top_k": top_k, @@ -84,7 +88,7 @@ def retrieve( full_text_index_thread = threading.Thread( 
target=RetrievalService.full_text_index_search, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "retrieval_method": retrieval_method, @@ -124,7 +128,7 @@ def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model if not dataset: return [] all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - dataset.tenant_id, dataset_id, query, external_retrieval_model + dataset.tenant_id, dataset_id, query, external_retrieval_model or {} ) return all_documents @@ -135,6 +139,8 @@ def keyword_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") keyword = Keyword(dataset=dataset) @@ -159,6 +165,8 @@ def embedding_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector = Vector(dataset=dataset) @@ -209,6 +217,8 @@ def full_text_index_search( with flask_app.app_context(): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("dataset not found") vector_processor = Vector( dataset=dataset, @@ -241,3 +251,89 @@ def full_text_index_search( @staticmethod def escape_query_for_search(query: str) -> str: return query.replace('"', '\\"') + + @staticmethod + def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]: + records = [] + include_segment_ids = [] + segment_child_map = {} + for document in documents: + document_id = document.metadata.get("document_id") + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + if dataset_document: + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_index_node_id = document.metadata.get("doc_id") + result = ( + db.session.query(ChildChunk, DocumentSegment) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + ChildChunk.index_node_id == child_index_node_id, + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + ) + .first() + ) + if result: + child_chunk, segment = result + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.append(segment.id) + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + records.append(record) + else: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + else: + continue + else: + index_node_id = document.metadata["doc_id"] + + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + 
DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) + + if not segment: + continue + include_segment_ids.append(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score", None), + } + + records.append(record) + for record in records: + if record["segment"].id in segment_child_map: + record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) + record["score"] = segment_child_map[record["segment"].id]["max_score"] + + return [RetrievalSegments(**record) for record in records] diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 09104ae42..603d3fdbc 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -17,12 +17,19 @@ class AnalyticdbVector(BaseVector): def __init__( - self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig + self, + collection_name: str, + api_config: AnalyticdbVectorOpenAPIConfig | None, + sql_config: AnalyticdbVectorBySqlConfig | None, ): super().__init__(collection_name) if api_config is not None: - self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config) + self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI( + collection_name, api_config + ) else: + if sql_config is None: + raise ValueError("Either api_config or sql_config must be provided") self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config) def get_type(self) -> str: @@ -33,8 +40,8 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.analyticdb_vector._create_collection_if_not_exists(dimension) self.analyticdb_vector.add_texts(texts, embeddings) - def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - self.analyticdb_vector.add_texts(texts, embeddings) + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + self.analyticdb_vector.add_texts(documents, embeddings) def text_exists(self, id: str) -> bool: return self.analyticdb_vector.text_exists(id) @@ -68,13 +75,13 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings if dify_config.ANALYTICDB_HOST is None: # implemented through OpenAPI apiConfig = AnalyticdbVectorOpenAPIConfig( - access_key_id=dify_config.ANALYTICDB_KEY_ID, - access_key_secret=dify_config.ANALYTICDB_KEY_SECRET, - region_id=dify_config.ANALYTICDB_REGION_ID, - instance_id=dify_config.ANALYTICDB_INSTANCE_ID, - account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, - namespace=dify_config.ANALYTICDB_NAMESPACE, + access_key_id=dify_config.ANALYTICDB_KEY_ID or "", + access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "", + region_id=dify_config.ANALYTICDB_REGION_ID or "", + instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "", + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", + namespace=dify_config.ANALYTICDB_NAMESPACE or "", namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD, ) sqlConfig = None @@ -83,11 +90,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings sqlConfig = AnalyticdbVectorBySqlConfig( host=dify_config.ANALYTICDB_HOST, port=dify_config.ANALYTICDB_PORT, - 
account=dify_config.ANALYTICDB_ACCOUNT, - account_password=dify_config.ANALYTICDB_PASSWORD, + account=dify_config.ANALYTICDB_ACCOUNT or "", + account_password=dify_config.ANALYTICDB_PASSWORD or "", min_connection=dify_config.ANALYTICDB_MIN_CONNECTION, max_connection=dify_config.ANALYTICDB_MAX_CONNECTION, - namespace=dify_config.ANALYTICDB_NAMESPACE, + namespace=dify_config.ANALYTICDB_NAMESPACE or "", ) apiConfig = None return AnalyticdbVector( diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 05e0ebc54..095752ea8 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -1,5 +1,5 @@ import json -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): account: str account_password: str namespace: str = "dify" - namespace_password: str = (None,) + namespace_password: Optional[str] = None metrics: str = "cosine" read_timeout: int = 60000 @@ -55,8 +55,8 @@ def to_analyticdb_client_params(self): class AnalyticdbVectorOpenAPI: def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig): try: - from alibabacloud_gpdb20160503.client import Client - from alibabacloud_tea_openapi import models as open_api_models + from alibabacloud_gpdb20160503.client import Client # type: ignore + from alibabacloud_tea_openapi import models as open_api_models # type: ignore except: raise ImportError(_import_err_msg) self._collection_name = collection_name.lower() @@ -77,7 +77,7 @@ def _initialize(self) -> None: redis_client.set(database_exist_cache_key, 1, ex=3600) def _initialize_vector_database(self) -> None: - from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, @@ -89,7 +89,7 @@ def _initialize_vector_database(self) -> None: def _create_namespace_if_not_exists(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - from Tea.exceptions import TeaException + from Tea.exceptions import TeaException # type: ignore try: request = gpdb_20160503_models.DescribeNamespaceRequest( @@ -159,17 +159,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): - metadata = { - "ref_doc_id": doc.metadata["doc_id"], - "page_content": doc.page_content, - "metadata_": json.dumps(doc.metadata), - } - rows.append( - gpdb_20160503_models.UpsertCollectionDataRequestRows( - vector=embedding, - metadata=metadata, + if doc.metadata is not None: + metadata = { + "ref_doc_id": doc.metadata["doc_id"], + "page_content": doc.page_content, + "metadata_": json.dumps(doc.metadata), + } + rows.append( + gpdb_20160503_models.UpsertCollectionDataRequestRows( + vector=embedding, + metadata=metadata, + ) ) - ) request = gpdb_20160503_models.UpsertCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -258,7 +259,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: 
x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -290,7 +291,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: metadata=metadata, ) documents.append(doc) - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return documents def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index e474db5cb..4d8f79294 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from core.rag.models.document import Document @@ -75,6 +75,7 @@ def _create_connection_pool(self): @contextmanager def _get_cursor(self): + assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() try: @@ -156,16 +157,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s)); """ for i, doc in enumerate(documents): - values.append( - ( - id_prefix + str(i), - doc.metadata.get("doc_id", str(uuid.uuid4())), - embeddings[i], - doc.page_content, - json.dumps(doc.metadata), - doc.page_content, + if doc.metadata is not None: + values.append( + ( + id_prefix + str(i), + doc.metadata.get("doc_id", str(uuid.uuid4())), + embeddings[i], + doc.page_content, + json.dumps(doc.metadata), + doc.page_content, + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_batch(cur, sql, values) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index eb78e8aa6..a658495af 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -5,13 +5,13 @@ import numpy as np from pydantic import BaseModel, model_validator -from pymochow import MochowClient -from pymochow.auth.bce_credentials import BceCredentials -from pymochow.configuration import Configuration -from pymochow.exception import ServerError -from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState -from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex -from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row +from pymochow import MochowClient # type: ignore +from pymochow.auth.bce_credentials import BceCredentials # type: ignore +from pymochow.configuration import Configuration # type: ignore +from pymochow.exception import ServerError # type: ignore +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore +from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore +from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -75,7 +75,7 @@ def 
create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] + metadatas = [doc.metadata for doc in documents if doc.metadata is not None] total_count = len(documents) batch_size = 1000 @@ -84,6 +84,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for start in range(0, total_count, batch_size): end = min(start + batch_size, total_count) rows = [] + assert len(metadatas) == total_count, "metadatas length should be equal to total_count" + # FIXME do you need this assert? for i in range(start, end, 1): row = Row( id=metadatas[i].get("doc_id", str(uuid.uuid4())), @@ -111,6 +113,8 @@ def text_exists(self, id: str) -> bool: return False def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return quoted_ids = [f"'{id}'" for id in ids] self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") @@ -136,7 +140,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # baidu vector database doesn't support bm25 search on current version return [] - def _get_search_res(self, res, score_threshold): + def _get_search_res(self, res, score_threshold) -> list[Document]: docs = [] for row in res.rows: row_data = row.get("row", {}) @@ -276,11 +280,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return BaiduVector( collection_name=collection_name, config=BaiduConfig( - endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT, + endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "", connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS, - account=dify_config.BAIDU_VECTOR_DB_ACCOUNT, - api_key=dify_config.BAIDU_VECTOR_DB_API_KEY, - database=dify_config.BAIDU_VECTOR_DB_DATABASE, + account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "", + api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "", + database=dify_config.BAIDU_VECTOR_DB_DATABASE or "", shard=dify_config.BAIDU_VECTOR_DB_SHARD, replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, ), diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index a9e1486ed..907c4d228 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -71,16 +71,20 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [d.metadata for d in documents] collection = self._client.get_or_create_collection(self._collection_name) - collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) + # FIXME: chromadb using numpy array, fix the type error later + collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {"$eq": value}}) + # FIXME: fix the type error later + collection.delete(where={key: {"$eq": value}}) # type: ignore def delete(self): self._client.delete_collection(self._collection_name) def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return collection = self._client.get_or_create_collection(self._collection_name) collection.delete(ids=ids) @@ -94,15 +98,19 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) 
-> list[Doc results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) - ids: list[str] = results["ids"][0] - documents: list[str] = results["documents"][0] - metadatas: dict[str, Any] = results["metadatas"][0] - distances: list[float] = results["distances"][0] + # Check if results contain data + if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]: + return [] + + ids = results["ids"][0] + documents = results["documents"][0] + metadatas = results["metadatas"][0] + distances = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] - metadata = metadatas[index] + metadata = dict(metadatas[index]) if distance >= score_threshold: metadata["score"] = distance doc = Document( @@ -111,7 +119,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -133,7 +141,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ChromaVector( collection_name=collection_name, config=ChromaConfig( - host=dify_config.CHROMA_HOST, + host=dify_config.CHROMA_HOST or "", port=dify_config.CHROMA_PORT, tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py index d26726e86..68a995278 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py @@ -5,14 +5,14 @@ from datetime import timedelta from typing import Any -from couchbase import search -from couchbase.auth import PasswordAuthenticator -from couchbase.cluster import Cluster -from couchbase.management.search import SearchIndex +from couchbase import search # type: ignore +from couchbase.auth import PasswordAuthenticator # type: ignore +from couchbase.cluster import Cluster # type: ignore +from couchbase.management.search import SearchIndex # type: ignore # needed for options -- cluster, timeout, SQL++ (N1QL) query, etc. 
-from couchbase.options import ClusterOptions, SearchOptions -from couchbase.vector_search import VectorQuery, VectorSearch +from couchbase.options import ClusterOptions, SearchOptions # type: ignore +from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore from flask import current_app from pydantic import BaseModel, model_validator @@ -231,7 +231,7 @@ def text_exists(self, id: str) -> bool: # Pass the id as a parameter to the query result = self._cluster.query(query, named_parameters={"doc_id": id}).execute() for row in result: - return row["count"] > 0 + return bool(row["count"] > 0) return False # Return False if no rows are returned def delete_by_ids(self, ids: list[str]) -> None: @@ -369,10 +369,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return CouchbaseVector( collection_name=collection_name, config=CouchbaseConfig( - connection_string=config.get("COUCHBASE_CONNECTION_STRING"), - user=config.get("COUCHBASE_USER"), - password=config.get("COUCHBASE_PASSWORD"), - bucket_name=config.get("COUCHBASE_BUCKET_NAME"), - scope_name=config.get("COUCHBASE_SCOPE_NAME"), + connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""), + user=config.get("COUCHBASE_USER", ""), + password=config.get("COUCHBASE_PASSWORD", ""), + bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""), + scope_name=config.get("COUCHBASE_SCOPE_NAME", ""), ), ) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py new file mode 100644 index 000000000..27575197f --- /dev/null +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py @@ -0,0 +1,104 @@ +import json +import logging +from typing import Any, Optional + +from flask import current_app + +from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ( + ElasticSearchConfig, + ElasticSearchVector, + ElasticSearchVectorFactory, +) +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +class ElasticSearchJaVector(ElasticSearchVector): + def create_collection( + self, + embeddings: list[list[float]], + metadatas: Optional[list[dict[Any, Any]]] = None, + index_params: Optional[dict] = None, + ): + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + logger.info(f"Collection {self._collection_name} already exists.") + return + + if not self._client.indices.exists(index=self._collection_name): + dim = len(embeddings[0]) + settings = { + "analysis": { + "analyzer": { + "ja_analyzer": { + "type": "custom", + "char_filter": [ + "icu_normalizer", + "kuromoji_iteration_mark", + ], + "tokenizer": "kuromoji_tokenizer", + "filter": [ + "kuromoji_baseform", + "kuromoji_part_of_speech", + "ja_stop", + "kuromoji_number", + "kuromoji_stemmer", + ], + } + } + } + } + mappings = { + "properties": { + Field.CONTENT_KEY.value: { + "type": "text", + "analyzer": "ja_analyzer", + "search_analyzer": "ja_analyzer", + }, + Field.VECTOR.value: { # Make sure the dimension is correct here + "type": "dense_vector", + "dims": dim, + "index": True, + "similarity": "cosine", + }, + 
Field.METADATA_KEY.value: { + "type": "object", + "properties": { + "doc_id": {"type": "keyword"} # Map doc_id to keyword type + }, + }, + } + } + self._client.indices.create(index=self._collection_name, settings=settings, mappings=mappings) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + +class ElasticSearchJaVectorFactory(ElasticSearchVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchJaVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + + config = current_app.config + return ElasticSearchJaVector( + index_name=collection_name, + config=ElasticSearchConfig( + host=config.get("ELASTICSEARCH_HOST", "localhost"), + port=config.get("ELASTICSEARCH_PORT", 9200), + username=config.get("ELASTICSEARCH_USERNAME", ""), + password=config.get("ELASTICSEARCH_PASSWORD", ""), + ), + attributes=[], + ) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index b08811a02..cca696bae 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,7 +1,7 @@ import json import logging import math -from typing import Any, Optional +from typing import Any, Optional, cast from urllib.parse import urlparse import requests @@ -70,7 +70,7 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: def _get_version(self) -> str: info = self._client.info() - return info["version"]["number"] + return cast(str, info["version"]["number"]) def _check_version(self): if self._version < "8.0.0": @@ -98,6 +98,8 @@ def text_exists(self, id: str) -> bool: return bool(self._client.exists(index=self._collection_name, id=id)) def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return for id in ids: self._client.delete(index=self._collection_name, id=id) @@ -135,7 +137,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -156,12 +159,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return docs def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, + embeddings: list[list[float]], + metadatas: Optional[list[dict[Any, Any]]] = None, + index_params: Optional[dict] = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): @@ -208,10 +214,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return ElasticSearchVector( index_name=collection_name, config=ElasticSearchConfig( - 
host=config.get("ELASTICSEARCH_HOST"), - port=config.get("ELASTICSEARCH_PORT"), - username=config.get("ELASTICSEARCH_USERNAME"), - password=config.get("ELASTICSEARCH_PASSWORD"), + host=config.get("ELASTICSEARCH_HOST", "localhost"), + port=config.get("ELASTICSEARCH_PORT", 9200), + username=config.get("ELASTICSEARCH_USERNAME", ""), + password=config.get("ELASTICSEARCH_PASSWORD", ""), ), attributes=[], ) diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 1c16e4d9c..a64407bce 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -6,6 +6,8 @@ class Field(Enum): METADATA_KEY = "metadata" GROUP_KEY = "group_id" VECTOR = "vector" + # Sparse Vector aims to support full text search + SPARSE_VECTOR = "sparse_vector" TEXT_KEY = "text" PRIMARY_KEY = "id" DOC_ID = "metadata.doc_id" diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py index 8646e52cf..66fba763e 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -42,7 +42,7 @@ def validate_config(cls, values: dict) -> dict: return values def to_opensearch_params(self) -> dict[str, Any]: - params = {"hosts": self.hosts} + params: dict[str, Any] = {"hosts": self.hosts} if self.username and self.password: params["http_auth"] = (self.username, self.password) return params @@ -53,7 +53,7 @@ def __init__(self, collection_name: str, config: LindormVectorStoreConfig, using self._routing = None self._routing_field = None if using_ugc: - routing_value: str = kwargs.get("routing_value") + routing_value: str | None = kwargs.get("routing_value") if routing_value is None: raise ValueError("UGC index should init vector with valid 'routing_value' parameter value") self._routing = routing_value.lower() @@ -87,14 +87,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** "_id": uuids[i], } } - action_values = { + action_values: dict[str, Any] = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, } if self._using_ugc: action_header["index"]["routing"] = self._routing - action_values[self._routing_field] = self._routing + if self._routing_field is not None: + action_values[self._routing_field] = self._routing actions.append(action_header) actions.append(action_values) response = self._client.bulk(actions) @@ -105,7 +106,9 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** self.refresh() def get_ids_by_metadata_field(self, key: str, value: str): - query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}} + query: dict[str, Any] = { + "query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}} + } if self._using_ugc: query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}}) response = self._client.search(index=self._collection_name, body=query) @@ -191,7 +194,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for doc, score in docs_and_scores: score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 if score > score_threshold: - doc.metadata["score"] = score + if doc.metadata is not None: + doc.metadata["score"] = score docs.append(doc) return docs @@ -254,7 +258,7 @@ def create_collection(self, dimension: 
int, **kwargs): hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500) ivfpq_m = kwargs.pop("ivfpq_m", dimension) nlist = kwargs.pop("nlist", 1000) - centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False) + centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000) centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24) centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500) centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100) @@ -301,7 +305,7 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic if method_name == "ivfpq": ivfpq_m = kwargs["ivfpq_m"] nlist = kwargs["nlist"] - centroids_use_hnsw = True if nlist > 10000 else False + centroids_use_hnsw = nlist > 10000 centroids_hnsw_m = 24 centroids_hnsw_ef_construct = 500 centroids_hnsw_ef_search = 100 @@ -366,6 +370,7 @@ def default_text_search_query( routing_field: Optional[str] = None, **kwargs, ) -> dict: + query_clause: dict[str, Any] = {} if routing is not None: query_clause = { "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]} @@ -386,7 +391,7 @@ def default_text_search_query( else: must = [query_clause] - boolean_query = {"must": must} + boolean_query: dict[str, Any] = {"must": must} if must_not: if not isinstance(must_not, list): @@ -426,7 +431,7 @@ def default_vector_search_query( filter_type = "post_filter" if filter_type is None else filter_type if not isinstance(filters, list): raise RuntimeError(f"unexpected filter with {type(filters)}") - final_ext = {"lvector": {}} + final_ext: dict[str, Any] = {"lvector": {}} if min_score != "0.0": final_ext["lvector"]["min_score"] = min_score if ef_search: @@ -438,7 +443,7 @@ def default_vector_search_query( if client_refactor: final_ext["lvector"]["client_refactor"] = client_refactor - search_query = { + search_query: dict[str, Any] = { "size": k, "_source": True, # force return '_source' "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}}, @@ -446,8 +451,8 @@ def default_vector_search_query( if filters is not None: # when using filter, transform filter from List[Dict] to Dict as valid format - filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] - search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict + filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0] + search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict if filter_type: final_ext["lvector"]["filter_type"] = filter_type @@ -459,17 +464,19 @@ def default_vector_search_query( class LindormVectorStoreFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: lindorm_config = LindormVectorStoreConfig( - hosts=dify_config.LINDORM_URL, + hosts=dify_config.LINDORM_URL or "", username=dify_config.LINDORM_USERNAME, password=dify_config.LINDORM_PASSWORD, using_ugc=dify_config.USING_UGC_INDEX, ) using_ugc = dify_config.USING_UGC_INDEX + if using_ugc is None: + raise ValueError("USING_UGC_INDEX is not set") routing_value = None if dataset.index_struct: # if an existed record's index_struct_dict doesn't contain using_ugc field, # it actually stores in the normal index format - stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False) + stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False) using_ugc = stored_in_ugc if stored_in_ugc: dimension = 
dataset.index_struct_dict["dimension"] diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 5a263d6e7..9a184f7dd 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -2,9 +2,10 @@ import logging from typing import Any, Optional +from packaging import version from pydantic import BaseModel, model_validator -from pymilvus import MilvusClient, MilvusException -from pymilvus.milvus_client import IndexParams +from pymilvus import MilvusClient, MilvusException # type: ignore +from pymilvus.milvus_client import IndexParams # type: ignore from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -20,16 +21,25 @@ class MilvusConfig(BaseModel): - uri: str - token: Optional[str] = None - user: str - password: str - batch_size: int = 100 - database: str = "default" + """ + Configuration class for Milvus connection. + """ + + uri: str # Milvus server URI + token: Optional[str] = None # Optional token for authentication + user: str # Username for authentication + password: str # Password for authentication + batch_size: int = 100 # Batch size for operations + database: str = "default" # Database name + enable_hybrid_search: bool = False # Flag to enable hybrid search @model_validator(mode="before") @classmethod def validate_config(cls, values: dict) -> dict: + """ + Validate the configuration values. + Raises ValueError if required fields are missing. + """ if not values.get("uri"): raise ValueError("config MILVUS_URI is required") if not values.get("user"): @@ -39,6 +49,9 @@ def validate_config(cls, values: dict) -> dict: return values def to_milvus_params(self): + """ + Convert the configuration to a dictionary of Milvus connection parameters. + """ return { "uri": self.uri, "token": self.token, @@ -49,26 +62,57 @@ def to_milvus_params(self): class MilvusVector(BaseVector): + """ + Milvus vector storage implementation. + """ + def __init__(self, collection_name: str, config: MilvusConfig): super().__init__(collection_name) self._client_config = config self._client = self._init_client(config) - self._consistency_level = "Session" - self._fields = [] + self._consistency_level = "Session" # Consistency level for Milvus operations + self._fields: list[str] = [] # List of fields in the collection + self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported + + def _check_hybrid_search_support(self) -> bool: + """ + Check if the current Milvus version supports hybrid search. + Returns True if the version is >= 2.5.0, otherwise False. + """ + if not self._client_config.enable_hybrid_search: + return False + + try: + milvus_version = self._client.get_server_version() + return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version + except Exception as e: + logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.") + return False def get_type(self) -> str: + """ + Get the type of vector storage (Milvus). + """ return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + """ + Create a collection and add texts with embeddings. 
+ """ index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + """ + Add texts and their embeddings to the collection. + """ insert_dict_list = [] for i in range(len(documents)): insert_dict = { + # Do not need to insert the sparse_vector field separately, as the text_bm25_emb + # function will automatically convert the native text into a sparse vector for us. Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], Field.METADATA_KEY.value: documents[i].metadata, @@ -76,12 +120,11 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** insert_dict_list.append(insert_dict) # Total insert count total_count = len(insert_dict_list) - pks: list[str] = [] for i in range(0, total_count, 1000): - batch_insert_list = insert_dict_list[i : i + 1000] # Insert into the collection. + batch_insert_list = insert_dict_list[i : i + 1000] try: ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) pks.extend(ids) @@ -91,6 +134,9 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** return pks def get_ids_by_metadata_field(self, key: str, value: str): + """ + Get document IDs by metadata field key and value. + """ result = self._client.query( collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] ) @@ -100,12 +146,18 @@ def get_ids_by_metadata_field(self, key: str, value: str): return None def delete_by_metadata_field(self, key: str, value: str): + """ + Delete documents by metadata field key and value. + """ if self._client.has_collection(self._collection_name): ids = self.get_ids_by_metadata_field(key, value) if ids: self._client.delete(collection_name=self._collection_name, pks=ids) def delete_by_ids(self, ids: list[str]) -> None: + """ + Delete documents by their IDs. + """ if self._client.has_collection(self._collection_name): result = self._client.query( collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] @@ -115,10 +167,16 @@ def delete_by_ids(self, ids: list[str]) -> None: self._client.delete(collection_name=self._collection_name, pks=ids) def delete(self) -> None: + """ + Delete the entire collection. + """ if self._client.has_collection(self._collection_name): self._client.drop_collection(self._collection_name, None) def text_exists(self, id: str) -> bool: + """ + Check if a text with the given ID exists in the collection. + """ if not self._client.has_collection(self._collection_name): return False @@ -128,32 +186,80 @@ def text_exists(self, id: str) -> bool: return len(result) > 0 + def field_exists(self, field: str) -> bool: + """ + Check if a field exists in the collection. 
+ """ + return field in self._fields + + def _process_search_results( + self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0 + ) -> list[Document]: + """ + Common method to process search results + + :param results: Search results + :param output_fields: Fields to be output + :param score_threshold: Score threshold for filtering + :return: List of documents + """ + docs = [] + for result in results[0]: + metadata = result["entity"].get(output_fields[1], {}) + metadata["score"] = result["distance"] + + if result["distance"] > score_threshold: + doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata) + docs.append(doc) + + return docs + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Set search parameters. + """ + Search for documents by vector similarity. + """ results = self._client.search( collection_name=self._collection_name, data=[query_vector], + anns_field=Field.VECTOR.value, limit=kwargs.get("top_k", 4), output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], ) - # Organize results. - docs = [] - for result in results[0]: - metadata = result["entity"].get(Field.METADATA_KEY.value) - metadata["score"] = result["distance"] - score_threshold = float(kwargs.get("score_threshold") or 0.0) - if result["distance"] > score_threshold: - doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) - docs.append(doc) - return docs + + return self._process_search_results( + results, + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + score_threshold=float(kwargs.get("score_threshold") or 0.0), + ) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # milvus/zilliz doesn't support bm25 search - return [] + """ + Search for documents by full-text search (if hybrid search is enabled). + """ + if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): + logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") + return [] + + results = self._client.search( + collection_name=self._collection_name, + data=[query], + anns_field=Field.SPARSE_VECTOR.value, + limit=kwargs.get("top_k", 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) + + return self._process_search_results( + results, + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + score_threshold=float(kwargs.get("score_threshold") or 0.0), + ) def create_collection( self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): + """ + Create a new collection in Milvus with the specified schema and index parameters. 
+ """ lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) @@ -161,8 +267,8 @@ def create_collection( return # Grab the existing collection if it exists if not self._client.has_collection(self._collection_name): - from pymilvus import CollectionSchema, DataType, FieldSchema - from pymilvus.orm.types import infer_dtype_bydata + from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore + from pymilvus.orm.types import infer_dtype_bydata # type: ignore # Determine embedding dim dim = len(embeddings[0]) @@ -170,16 +276,36 @@ def create_collection( if metadatas: fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) - # Create the text field - fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) + # Create the text field, enable_analyzer will be set True to support milvus automatically + # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md + fields.append( + FieldSchema( + Field.CONTENT_KEY.value, + DataType.VARCHAR, + max_length=65_535, + enable_analyzer=self._hybrid_search_enabled, + ) + ) # Create the primary key field fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) + # Create Sparse Vector Index for the collection + if self._hybrid_search_enabled: + fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR)) - # Create the schema for the collection schema = CollectionSchema(fields) + # Create custom function to support text to sparse vector by BM25 + if self._hybrid_search_enabled: + bm25_function = Function( + name="text_bm25_emb", + input_field_names=[Field.CONTENT_KEY.value], + output_field_names=[Field.SPARSE_VECTOR.value], + function_type=FunctionType.BM25, + ) + schema.add_function(bm25_function) + for x in schema.fields: self._fields.append(x.name) # Since primary field is auto-id, no need to track it @@ -189,10 +315,15 @@ def create_collection( index_params_obj = IndexParams() index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) + # Create Sparse Vector Index for the collection + if self._hybrid_search_enabled: + index_params_obj.add_index( + field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25" + ) + # Create the collection - collection_name = self._collection_name self._client.create_collection( - collection_name=collection_name, + collection_name=self._collection_name, schema=schema, index_params=index_params_obj, consistency_level=self._consistency_level, @@ -200,12 +331,22 @@ def create_collection( redis_client.set(collection_exist_cache_key, 1, ex=3600) def _init_client(self, config) -> MilvusClient: + """ + Initialize and return a Milvus client. + """ client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) return client class MilvusVectorFactory(AbstractVectorFactory): + """ + Factory class for creating MilvusVector instances. + """ + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: + """ + Initialize a MilvusVector instance for the given dataset. 
+ """ if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix @@ -217,10 +358,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return MilvusVector( collection_name=collection_name, config=MilvusConfig( - uri=dify_config.MILVUS_URI, - token=dify_config.MILVUS_TOKEN, - user=dify_config.MILVUS_USER, - password=dify_config.MILVUS_PASSWORD, - database=dify_config.MILVUS_DATABASE, + uri=dify_config.MILVUS_URI or "", + token=dify_config.MILVUS_TOKEN or "", + user=dify_config.MILVUS_USER or "", + password=dify_config.MILVUS_PASSWORD or "", + database=dify_config.MILVUS_DATABASE or "", + enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b7b6b803a..556b952ec 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -74,15 +74,16 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** columns = ["id", "text", "vector", "metadata"] values = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - row = ( - doc_id, - self.escape_str(doc.page_content), - embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {}, - ) - values.append(str(row)) - ids.append(doc_id) + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + row = ( + doc_id, + self.escape_str(doc.page_content), + embeddings[i], + json.dumps(doc.metadata) if doc.metadata else {}, + ) + values.append(str(row)) + ids.append(doc_id) sql = f""" INSERT INTO {self._config.database}.{self._collection_name} ({",".join(columns)}) VALUES {",".join(values)} @@ -99,6 +100,8 @@ def text_exists(self, id: str) -> bool: return results.row_count > 0 def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return self._client.command( f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" ) diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index c44338d42..3c2d53ce7 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, model_validator -from pyobvector import VECTOR, ObVecClient +from pyobvector import VECTOR, ObVecClient # type: ignore from sqlalchemy import JSON, Column, String, func from sqlalchemy.dialects.mysql import LONGTEXT @@ -131,9 +131,11 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def text_exists(self, id: str) -> bool: cur = self._client.get(table_name=self._collection_name, id=id) - return cur.rowcount != 0 + return bool(cur.rowcount != 0) def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return self._client.delete(table_name=self._collection_name, ids=ids) def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 7a976d7c3..72a150220 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -66,7 +66,7 @@ def get_type(self) -> str: 
return VectorType.OPENSEARCH def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - metadatas = [d.metadata for d in texts] + metadatas = [d.metadata if d.metadata is not None else {} for d in texts] self.create_collection(embeddings, metadatas) self.add_texts(texts, embeddings) @@ -244,7 +244,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( - host=dify_config.OPENSEARCH_HOST, + host=dify_config.OPENSEARCH_HOST or "localhost", port=dify_config.OPENSEARCH_PORT, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 74608f1e1..a58df7eb9 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from typing import Any -import jieba.posseg as pseg +import jieba.posseg as pseg # type: ignore import numpy import oracledb from pydantic import BaseModel, model_validator @@ -88,12 +88,11 @@ def input_type_handler(self, cursor, value, arraysize): def numpy_converter_out(self, value): if value.typecode == "b": - dtype = numpy.int8 + return numpy.array(value, copy=False, dtype=numpy.int8) elif value.typecode == "f": - dtype = numpy.float32 + return numpy.array(value, copy=False, dtype=numpy.float32) else: - dtype = numpy.float64 - return numpy.array(value, copy=False, dtype=dtype) + return numpy.array(value, copy=False, dtype=numpy.float64) def output_type_handler(self, cursor, metadata): if metadata.type_code is oracledb.DB_TYPE_VECTOR: @@ -135,17 +134,18 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - # array.array("f", embeddings[i]), - numpy.array(embeddings[i]), + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + # array.array("f", embeddings[i]), + numpy.array(embeddings[i]), + ) ) - ) # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: cur.executemany( @@ -167,6 +167,8 @@ def get_by_ids(self, ids: list[str]) -> list[Document]: return docs def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) @@ -201,8 +203,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: # lazy import - import nltk - from nltk.corpus import stopwords + import nltk # type: ignore + from nltk.corpus import stopwords # type: ignore top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later @@ -285,10 +287,10 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return OracleVector( collection_name=collection_name, config=OracleVectorConfig( - host=dify_config.ORACLE_HOST, + host=dify_config.ORACLE_HOST or "localhost", 
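The `if not ids: return` guard added to `delete_by_ids` here, and repeated for MyScale, OceanBase, PGVector and Tencent elsewhere in this patch, exists because an empty list renders as `IN ()` in the SQL-backed stores, which the engine rejects (and it is pointless work everywhere else). A minimal sketch of the guarded form, assuming a DB-API cursor and psycopg2-style parameter binding:

```python
def delete_by_ids(cur, table_name: str, ids: list[str]) -> None:
    # A failed extraction followed by a retry can hand us an empty id list;
    # bail out instead of emitting "DELETE ... WHERE id IN ()".
    if not ids:
        return
    cur.execute(f"DELETE FROM {table_name} WHERE id IN %s", (tuple(ids),))
```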
port=dify_config.ORACLE_PORT, - user=dify_config.ORACLE_USER, - password=dify_config.ORACLE_PASSWORD, - database=dify_config.ORACLE_DATABASE, + user=dify_config.ORACLE_USER or "system", + password=dify_config.ORACLE_PASSWORD or "oracle", + database=dify_config.ORACLE_DATABASE or "orcl", ), ) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index 7cbbdcc81..221bc68d6 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -4,7 +4,7 @@ from uuid import UUID, uuid4 from numpy import ndarray -from pgvecto_rs.sqlalchemy import VECTOR +from pgvecto_rs.sqlalchemy import VECTOR # type: ignore from pydantic import BaseModel, model_validator from sqlalchemy import Float, String, create_engine, insert, select, text from sqlalchemy import text as sql_text @@ -58,7 +58,7 @@ def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) session.commit() - self._fields = [] + self._fields: list[str] = [] class _Table(CollectionORM): __tablename__ = collection_name @@ -222,11 +222,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVectoRS( collection_name=collection_name, config=PgvectoRSConfig( - host=dify_config.PGVECTO_RS_HOST, - port=dify_config.PGVECTO_RS_PORT, - user=dify_config.PGVECTO_RS_USER, - password=dify_config.PGVECTO_RS_PASSWORD, - database=dify_config.PGVECTO_RS_DATABASE, + host=dify_config.PGVECTO_RS_HOST or "localhost", + port=dify_config.PGVECTO_RS_PORT or 5432, + user=dify_config.PGVECTO_RS_USER or "postgres", + password=dify_config.PGVECTO_RS_PASSWORD or "", + database=dify_config.PGVECTO_RS_DATABASE or "postgres", ), dim=dim, ) diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 40a9cdd13..de443ba58 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -3,8 +3,8 @@ from contextlib import contextmanager from typing import Any -import psycopg2.extras -import psycopg2.pool +import psycopg2.extras # type: ignore +import psycopg2.pool # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -98,16 +98,17 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** values = [] pks = [] for i, doc in enumerate(documents): - doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) - pks.append(doc_id) - values.append( - ( - doc_id, - doc.page_content, - json.dumps(doc.metadata), - embeddings[i], + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + pks.append(doc_id) + values.append( + ( + doc_id, + doc.page_content, + json.dumps(doc.metadata), + embeddings[i], + ) ) - ) with self._get_cursor() as cur: psycopg2.extras.execute_values( cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values @@ -128,6 +129,11 @@ def get_by_ids(self, ids: list[str]) -> list[Document]: return docs def delete_by_ids(self, ids: list[str]) -> None: + # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios + # Scenario 1: extract a document fails, resulting in a table not being created. + # Then clicking the retry button triggers a delete operation on an empty list. 
+ if not ids: + return with self._get_cursor() as cur: cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) @@ -216,11 +222,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return PGVector( collection_name=collection_name, config=PGVectorConfig( - host=dify_config.PGVECTOR_HOST, + host=dify_config.PGVECTOR_HOST or "localhost", port=dify_config.PGVECTOR_PORT, - user=dify_config.PGVECTOR_USER, - password=dify_config.PGVECTOR_PASSWORD, - database=dify_config.PGVECTOR_DATABASE, + user=dify_config.PGVECTOR_USER or "postgres", + password=dify_config.PGVECTOR_PASSWORD or "", + database=dify_config.PGVECTOR_DATABASE or "postgres", min_connection=dify_config.PGVECTOR_MIN_CONNECTION, max_connection=dify_config.PGVECTOR_MAX_CONNECTION, ), diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 3811458e0..6e94cb69d 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -51,6 +51,8 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): + if not self.root_path: + raise ValueError("Root path is not set") path = os.path.join(self.root_path, path) return {"path": path} @@ -149,9 +151,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] - added_ids = [] - for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + # Filter out None values from metadatas list to match expected type + filtered_metadatas = [m for m in metadatas if m is not None] + for batch_ids, points in self._generate_rest_batches( + texts, embeddings, filtered_metadatas, uuids, 64, self._group_id + ): self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) @@ -194,7 +199,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", # Ensure group_id is never None Field.GROUP_KEY.value, ), ) @@ -337,18 +342,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -432,9 +439,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( - endpoint=dify_config.QDRANT_URL, + endpoint=dify_config.QDRANT_URL or "", api_key=dify_config.QDRANT_API_KEY, - root_path=current_app.config.root_path, + 
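Two small hardening patterns recur in the Qdrant search code above: skip results whose `payload` is `None`, and give the score sort a fallback key so documents without metadata cannot raise `KeyError`. Roughly, with a `Hit` namedtuple standing in for Qdrant's scored points:

```python
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Doc:
    page_content: str
    metadata: Optional[dict[str, Any]] = None

def collect_hits(results, score_threshold: float = 0.0) -> list[Doc]:
    docs = []
    for result in results:
        if result.payload is None:
            continue  # nothing usable in this hit
        metadata = result.payload.get("metadata") or {}
        if result.score > score_threshold:
            metadata["score"] = result.score
            docs.append(Doc(result.payload.get("page_content", ""), metadata))
    # Missing metadata sorts as score 0 instead of raising KeyError.
    return sorted(docs, key=lambda d: d.metadata["score"] if d.metadata is not None else 0, reverse=True)

Hit = namedtuple("Hit", ["payload", "score"])
hits = [Hit({"page_content": "a", "metadata": {}}, 0.9), Hit(None, 0.99)]
print(collect_hits(hits))  # the payload-less hit is skipped
```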
root_path=str(current_app.config.root_path), timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index f373dcfea..a3a20448f 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional from pydantic import BaseModel, model_validator -from sqlalchemy import Column, Sequence, String, Table, create_engine, insert +from sqlalchemy import Column, String, Table, create_engine, insert from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session @@ -58,14 +58,14 @@ def __init__(self, collection_name: str, config: RelytConfig, group_id: str): f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" ) self.client = create_engine(self._url) - self._fields = [] + self._fields: list[str] = [] self._group_id = group_id def get_type(self) -> str: return VectorType.RELYT - def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = {} + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None: + index_params: dict[str, Any] = {} metadatas = [d.metadata for d in texts] self.create_collection(len(embeddings[0])) self.embedding_dimension = len(embeddings[0]) @@ -107,10 +107,10 @@ def create_collection(self, dimension: int): redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): - from pgvecto_rs.sqlalchemy import VECTOR + from pgvecto_rs.sqlalchemy import VECTOR # type: ignore ids = [str(uuid.uuid1()) for _ in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] for metadata in metadatas: metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] @@ -242,10 +242,6 @@ def similarity_search_with_score_by_vector( filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided - try: - from sqlalchemy.engine import Row - except ImportError: - raise ImportError("Could not import Row from sqlalchemy.engine. 
Please 'pip install sqlalchemy>=1.4'.") filter_condition = "" if filter is not None: @@ -275,7 +271,7 @@ def similarity_search_with_score_by_vector( # Execute the query and fetch the results with self.client.connect() as conn: - results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall() + results = conn.execute(sql_text(sql_query), params).fetchall() documents_with_scores = [ ( @@ -307,11 +303,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return RelytVector( collection_name=collection_name, config=RelytConfig( - host=dify_config.RELYT_HOST, + host=dify_config.RELYT_HOST or "localhost", port=dify_config.RELYT_PORT, - user=dify_config.RELYT_USER, - password=dify_config.RELYT_PASSWORD, - database=dify_config.RELYT_DATABASE, + user=dify_config.RELYT_USER or "", + password=dify_config.RELYT_PASSWORD or "", + database=dify_config.RELYT_DATABASE or "default", ), group_id=dataset.id, ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index f971a9c5e..1a4fa7b87 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -2,10 +2,10 @@ from typing import Any, Optional from pydantic import BaseModel -from tcvectordb import VectorDBClient -from tcvectordb.model import document, enum -from tcvectordb.model import index as vdb_index -from tcvectordb.model.document import Filter +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model import document, enum # type: ignore +from tcvectordb.model import index as vdb_index # type: ignore +from tcvectordb.model.document import Filter # type: ignore from configs import dify_config from core.rag.datasource.vdb.vector_base import BaseVector @@ -25,8 +25,8 @@ class TencentConfig(BaseModel): database: Optional[str] index_type: str = "HNSW" metric_type: str = "L2" - shard: int = (1,) - replicas: int = (2,) + shard: int = 1 + replicas: int = 2 def to_tencent_params(self): return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout} @@ -120,15 +120,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [doc.metadata for doc in documents] total_count = len(embeddings) docs = [] - for id in range(0, total_count): + for i in range(0, total_count): if metadatas is None: continue - metadata = json.dumps(metadatas[id]) + metadata = metadatas[i] or {} doc = document.Document( - id=metadatas[id]["doc_id"], - vector=embeddings[id], - text=texts[id], - metadata=metadata, + id=metadata.get("doc_id"), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadata), ) docs.append(doc) self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout) @@ -140,6 +140,8 @@ def text_exists(self, id: str) -> bool: return False def delete_by_ids(self, ids: list[str]) -> None: + if not ids: + return self._db.collection(self._collection_name).delete(document_ids=ids) def delete_by_metadata_field(self, key: str, value: str) -> None: @@ -159,8 +161,8 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] - def _get_search_res(self, res, score_threshold): - docs = [] + def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]: + docs: list[Document] = [] if res is None or len(res) == 0: return docs @@ -193,7 +195,7 @@ def 
init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TencentVector( collection_name=collection_name, config=TencentConfig( - url=dify_config.TENCENT_VECTOR_DB_URL, + url=dify_config.TENCENT_VECTOR_DB_URL or "", api_key=dify_config.TENCENT_VECTOR_DB_API_KEY, timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT, username=dify_config.TENCENT_VECTOR_DB_USERNAME, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index cfd47aac5..549f0175e 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -54,7 +54,10 @@ def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): path = self.endpoint.replace("path:", "") if not os.path.isabs(path): - path = os.path.join(self.root_path, path) + if self.root_path: + path = os.path.join(self.root_path, path) + else: + raise ValueError("root_path is required") return {"path": path} else: @@ -157,7 +160,7 @@ def create_collection(self, collection_name: str, vector_size: int): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] + metadatas = [d.metadata for d in documents if d.metadata is not None] added_ids = [] for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): @@ -203,7 +206,7 @@ def _generate_rest_batches( batch_metadatas, Field.CONTENT_KEY.value, Field.METADATA_KEY.value, - group_id, + group_id or "", Field.GROUP_KEY.value, ), ) @@ -334,18 +337,20 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc ) docs = [] for result in results: + if result.payload is None: + continue metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold score_threshold = kwargs.get("score_threshold") or 0.0 if result.score > score_threshold: metadata["score"] = result.score doc = Document( - page_content=result.payload.get(Field.CONTENT_KEY.value), + page_content=result.payload.get(Field.CONTENT_KEY.value, ""), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -404,35 +409,35 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() ) if not tidb_auth_binding: - idle_tidb_auth_binding = ( - db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") - .limit(1) - .one_or_none() - ) - if idle_tidb_auth_binding: - idle_tidb_auth_binding.active = True - idle_tidb_auth_binding.tenant_id = dataset.tenant_id - db.session.commit() - TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" - else: - with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): - tidb_auth_binding = ( + with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): + tidb_auth_binding = ( + db.session.query(TidbAuthBinding) + 
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id) + .one_or_none() + ) + if tidb_auth_binding: + TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + + else: + idle_tidb_auth_binding = ( db.session.query(TidbAuthBinding) - .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) + .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") + .limit(1) .one_or_none() ) - if tidb_auth_binding: - TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" - + if idle_tidb_auth_binding: + idle_tidb_auth_binding.active = True + idle_tidb_auth_binding.tenant_id = dataset.tenant_id + db.session.commit() + TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" else: new_cluster = TidbService.create_tidb_serverless_cluster( - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + dify_config.TIDB_PROJECT_ID or "", + dify_config.TIDB_API_URL or "", + dify_config.TIDB_IAM_API_URL or "", + dify_config.TIDB_PUBLIC_KEY or "", + dify_config.TIDB_PRIVATE_KEY or "", + dify_config.TIDB_REGION or "", ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], @@ -446,7 +451,6 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings db.session.add(new_tidb_auth_binding) db.session.commit() TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" - else: TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" @@ -464,9 +468,9 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL, + endpoint=dify_config.TIDB_ON_QDRANT_URL or "", api_key=TIDB_ON_QDRANT_API_KEY, - root_path=config.root_path, + root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 8dd5922ad..0a48c7951 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -146,7 +146,7 @@ def batch_update_tidb_serverless_cluster_status( iam_url: str, public_key: str, private_key: str, - ) -> list[dict]: + ): """ Update the status of a new TiDB Serverless cluster. :param project_id: The project ID of the TiDB Cloud project (required). @@ -159,7 +159,6 @@ def batch_update_tidb_serverless_cluster_status( :return: The response from the API. 
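The restructured `init_vector` above mainly moves the whole lookup-or-provision sequence inside `create_tidb_serverless_cluster_lock`, so two workers cannot provision clusters for the same tenant at once. The control flow, stripped of the ORM and API details (`find_binding`, `claim_idle_binding` and `provision_cluster` are hypothetical stubs, and a reachable Redis is assumed):

```python
from dataclasses import dataclass
from typing import Optional

import redis

@dataclass
class Binding:
    account: str
    password: str

# Hypothetical persistence helpers; in Dify these are SQLAlchemy queries and TidbService calls.
def find_binding(tenant_id: str) -> Optional[Binding]:
    return None

def claim_idle_binding(tenant_id: str) -> Optional[Binding]:
    return None

def provision_cluster(tenant_id: str) -> Binding:
    return Binding("root", "secret")

r = redis.Redis()

def get_api_key(tenant_id: str) -> str:
    binding = find_binding(tenant_id)
    if binding is None:
        # Serialize the provision path so concurrent workers do not both create a cluster.
        with r.lock("create_tidb_serverless_cluster_lock", timeout=900):
            binding = find_binding(tenant_id)  # re-check once the lock is held
            binding = binding or claim_idle_binding(tenant_id) or provision_cluster(tenant_id)
    return f"{binding.account}:{binding.password}"
```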
""" - clusters = [] tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} cluster_ids = [item.cluster_id for item in tidb_serverless_list] params = {"clusterIds": cluster_ids, "view": "BASIC"} @@ -169,7 +168,6 @@ def batch_update_tidb_serverless_cluster_status( if response.status_code == 200: response_data = response.json() - cluster_infos = [] for item in response_data["clusters"]: state = item["state"] userPrefix = item["userPrefix"] @@ -236,16 +234,17 @@ def batch_create_tidb_serverless_cluster( cluster_infos = [] for item in response_data["clusters"]: cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" - password = redis_client.get(cache_key) - if not password: + cached_password = redis_client.get(cache_key) + if not cached_password: continue cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", - "password": password.decode("utf-8"), + "password": cached_password.decode("utf-8"), } cluster_infos.append(cluster_info) return cluster_infos else: response.raise_for_status() + return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 39ab6ea71..be3a41739 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -49,7 +49,7 @@ def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: - from tidb_vector.sqlalchemy import VectorType + from tidb_vector.sqlalchemy import VectorType # type: ignore return Table( self._collection_name, @@ -241,11 +241,11 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return TiDBVector( collection_name=collection_name, config=TiDBVectorConfig( - host=dify_config.TIDB_VECTOR_HOST, - port=dify_config.TIDB_VECTOR_PORT, - user=dify_config.TIDB_VECTOR_USER, - password=dify_config.TIDB_VECTOR_PASSWORD, - database=dify_config.TIDB_VECTOR_DATABASE, + host=dify_config.TIDB_VECTOR_HOST or "", + port=dify_config.TIDB_VECTOR_PORT or 0, + user=dify_config.TIDB_VECTOR_USER or "", + password=dify_config.TIDB_VECTOR_PASSWORD or "", + database=dify_config.TIDB_VECTOR_DATABASE or "", program_name=dify_config.APPLICATION_NAME, ), ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 22e191340..edfce2edd 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -51,15 +51,16 @@ def delete(self) -> None: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): - doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if text.metadata and "doc_id" in text.metadata: + doc_id = text.metadata["doc_id"] + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata["doc_id"] for text in texts] + return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 6d2e04fc0..bdc40e29c 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ 
b/api/core/rag/datasource/vdb/vector_factory.py @@ -90,6 +90,12 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory return ElasticSearchVectorFactory + case VectorType.ELASTICSEARCH_JA: + from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import ( + ElasticSearchJaVectorFactory, + ) + + return ElasticSearchJaVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory @@ -193,10 +199,13 @@ def _get_embeddings(self) -> Embeddings: def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: for text in texts.copy(): + if text.metadata is None: + continue doc_id = text.metadata["doc_id"] - exists_duplicate_node = self.text_exists(doc_id) - if exists_duplicate_node: - texts.remove(text) + if doc_id: + exists_duplicate_node = self.text_exists(doc_id) + if exists_duplicate_node: + texts.remove(text) return texts diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 05183c037..e73411aa0 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -16,6 +16,7 @@ class VectorType(StrEnum): TENCENT = "tencent" ORACLE = "oracle" ELASTICSEARCH = "elasticsearch" + ELASTICSEARCH_JA = "elasticsearch-ja" LINDORM = "lindorm" COUCHBASE = "couchbase" BAIDU = "baidu" diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 4f927f289..9de8761a9 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -2,7 +2,7 @@ from typing import Any from pydantic import BaseModel -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Data, DistanceType, Field, @@ -121,11 +121,12 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, page_content in enumerate(page_contents): metadata = {} if metadatas is not None: - for key, val in metadatas[i].items(): + for key, val in (metadatas[i] or {}).items(): metadata[key] = val + # FIXME: fix the type of metadata later doc = Data( { - vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], + vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, vdb_Field.CONTENT_KEY.value: page_content, vdb_Field.METADATA_KEY.value: json.dumps(metadata), @@ -178,7 +179,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(results, score_threshold) - def _get_search_res(self, results, score_threshold): + def _get_search_res(self, results, score_threshold) -> list[Document]: if len(results) == 0: return [] @@ -191,7 +192,7 @@ def _get_search_res(self, results, score_threshold): metadata["score"] = result.score doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py 
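The `vector_factory` and `vector_type` additions above register a new `elasticsearch-ja` backend next to the plain Elasticsearch one (judging by the name, a Japanese-analyzer variant; its factory body is not shown in this hunk). The dispatch itself is an ordinary `match` over a `StrEnum` (Python 3.11+); a trimmed sketch with the factory classes replaced by their names:

```python
from enum import StrEnum

class VectorType(StrEnum):
    ELASTICSEARCH = "elasticsearch"
    ELASTICSEARCH_JA = "elasticsearch-ja"

def get_vector_factory_name(vector_type: str) -> str:
    match vector_type:
        case VectorType.ELASTICSEARCH:
            return "ElasticSearchVectorFactory"
        case VectorType.ELASTICSEARCH_JA:
            # Dedicated factory so the Japanese variant can configure its own index settings.
            return "ElasticSearchJaVectorFactory"
        case _:
            raise ValueError(f"Vector store {vector_type} is not supported.")

print(get_vector_factory_name("elasticsearch-ja"))  # ElasticSearchJaVectorFactory
```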
b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 649cfbfea..68d043a19 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -3,7 +3,7 @@ from typing import Any, Optional import requests -import weaviate +import weaviate # type: ignore from pydantic import BaseModel, model_validator from configs import dify_config @@ -107,7 +107,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** for i, text in enumerate(texts): data_properties = {Field.TEXT_KEY.value: text} if metadatas is not None: - for key, val in metadatas[i].items(): + # metadata maybe None + for key, val in (metadatas[i] or {}).items(): data_properties[key] = self._json_serializable(val) batch.add_data_object( @@ -208,10 +209,11 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc score_threshold = float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: - doc.metadata["score"] = score - docs.append(doc) + if doc.metadata is not None: + doc.metadata["score"] = score + docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -275,7 +277,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( - endpoint=dify_config.WEAVIATE_ENDPOINT, + endpoint=dify_config.WEAVIATE_ENDPOINT or "", api_key=dify_config.WEAVIATE_API_KEY, batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 319a2612c..8b95d81cc 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -7,7 +7,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.models.document import Document from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment class DatasetDocumentStore: @@ -60,7 +60,7 @@ def docs(self) -> dict[str, Document]: return output - def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id == self._document_id) @@ -83,6 +83,9 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> if not isinstance(doc, Document): raise ValueError("doc must be a Document") + if doc.metadata is None: + raise ValueError("doc.metadata must be a dict") + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it @@ -117,13 +120,55 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) + db.session.flush() + if save_child: + if doc.children: + for postion, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + 
dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=postion, + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) else: segment_document.content = doc.page_content if doc.metadata.get("answer"): segment_document.answer = doc.metadata.pop("answer", "") - segment_document.index_node_hash = doc.metadata["doc_hash"] + segment_document.index_node_hash = doc.metadata.get("doc_hash") segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens + if save_child and doc.children: + # delete the existing child chunks + db.session.query(ChildChunk).filter( + ChildChunk.tenant_id == self._dataset.tenant_id, + ChildChunk.dataset_id == self._dataset.id, + ChildChunk.document_id == self._document_id, + ChildChunk.segment_id == segment_document.id, + ).delete() + # add new child chunks + for position, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=position, + index_node_id=child.metadata.get("doc_id"), + index_node_hash=child.metadata.get("doc_hash"), + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) db.session.commit() @@ -179,10 +224,10 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: if document_segment is None: return None + data: Optional[str] = document_segment.index_node_hash + return data - return document_segment.index_node_hash - - def get_document_segment(self, doc_id: str) -> DocumentSegment: + def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: document_segment = ( db.session.query(DocumentSegment) .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 8ddda7e98..a2c8737da 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import Optional, cast +from typing import Any, Optional, cast import numpy as np from sqlalchemy.exc import IntegrityError @@ -27,7 +27,7 @@ def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" # use doc embedding cache or store if not exists - text_embeddings = [None for _ in range(len(texts))] + text_embeddings: list[Any] = [None for _ in range(len(texts))] embedding_queue_indices = [] for i, text in enumerate(texts): hash = helper.generate_text_hash(text) @@ -64,7 +64,8 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: for vector in embedding_result.embeddings: try: - normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + # FIXME: type ignore for numpy here + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan if np.isnan(normalized_embedding).any(): # for issue #11827 float values are not json compliant @@ -77,8 +78,8 @@ def embed_documents(self, 
texts: list[str]) -> list[list[float]]: logging.exception("Failed transform embedding") cache_embeddings = [] try: - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding + for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = n_embedding hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: embedding_cache = Embedding( @@ -86,7 +87,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: hash=hash, provider_name=self._model_instance.provider, ) - embedding_cache.set_embedding(embedding) + embedding_cache.set_embedding(n_embedding) db.session.add(embedding_cache) cache_embeddings.append(hash) db.session.commit() @@ -115,7 +116,8 @@ def embed_query(self, text: str) -> list[float]: ) embedding_results = embedding_result.embeddings[0] - embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + # FIXME: type ignore for numpy here + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore if np.isnan(embedding_results).any(): raise ValueError("Normalized embedding is nan please try again") except Exception as ex: diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py new file mode 100644 index 000000000..800422d88 --- /dev/null +++ b/api/core/rag/embedding/retrieval.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import BaseModel + +from models.dataset import DocumentSegment + + +class RetrievalChildChunk(BaseModel): + """Retrieval segments.""" + + id: str + content: str + score: float + position: int + + +class RetrievalSegments(BaseModel): + """Retrieval segments.""" + + model_config = {"arbitrary_types_allowed": True} + segment: DocumentSegment + child_chunks: Optional[list[RetrievalChildChunk]] = None + score: Optional[float] = None diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 3692b5d19..7c00c668d 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -14,7 +14,7 @@ class NotionInfo(BaseModel): notion_workspace_id: str notion_obj_id: str notion_page_type: str - document: Document = None + document: Optional[Document] = None tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index fc3316571..a3b35458d 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,10 +1,10 @@ """Abstract interface for document loader implementations.""" import os -from typing import Optional +from typing import Optional, cast import pandas as pd -from openpyxl import load_workbook +from openpyxl import load_workbook # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -47,7 +47,7 @@ def extract(self) -> list[Document]: for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): cell = sheet.cell( - row=index + 2, column=col_index + 1 + row=cast(int, index) + 2, column=col_index + 1 ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" @@ -60,8 +60,8 @@ def extract(self) -> list[Document]: elif file_extension == ".xls": excel_file = pd.ExcelFile(self._file_path, engine="xlrd") - for sheet_name in excel_file.sheet_names: - df 
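The cached-embedding changes in this hunk keep the same behaviour in both `embed_documents` and `embed_query`: L2-normalize each vector and refuse to cache anything containing NaN, since non-JSON-compliant floats were the subject of issue #11827. The core of that, in isolation:

```python
import numpy as np

def normalize_for_cache(vector: list[float]) -> list[float]:
    arr = np.asarray(vector, dtype=np.float64)
    normalized = (arr / np.linalg.norm(arr)).tolist()
    # NaN (for example from an all-zero vector) is not JSON compliant, so fail loudly
    # instead of writing it into the embedding cache.
    if np.isnan(normalized).any():
        raise ValueError("Normalized embedding is nan please try again")
    return normalized

print(normalize_for_cache([3.0, 4.0]))  # [0.6, 0.8]
```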
= excel_file.parse(sheet_name=sheet_name) + for excel_sheet_name in excel_file.sheet_names: + df = excel_file.parse(sheet_name=excel_sheet_name) df.dropna(how="all", inplace=True) for _, row in df.iterrows(): diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 69659e310..f9fd7f92a 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -10,6 +10,7 @@ from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.excel_extractor import ExcelExtractor +from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor from core.rag.extractor.html_extractor import HtmlExtractor from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor @@ -23,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor -from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor from core.rag.extractor.word_extractor import WordExtractor from core.rag.models.document import Document @@ -66,9 +66,13 @@ def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Docume filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = "." + re.search(r"\.(\w+)$", filename).group(1) - - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + match = re.search(r"\.(\w+)$", filename) + if match: + suffix = "." 
+ match.group(1) + else: + suffix = "" + # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore Path(file_path).write_bytes(response.content) extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: @@ -89,16 +93,20 @@ def extract( if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: + assert extract_setting.upload_file is not None, "upload_file is required" upload_file: UploadFile = extract_setting.upload_file suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() etl_type = dify_config.ETL_TYPE - unstructured_api_url = dify_config.UNSTRUCTURED_API_URL - unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY + extractor: Optional[BaseExtractor] = None if etl_type == "Unstructured": + unstructured_api_url = dify_config.UNSTRUCTURED_API_URL + unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or "" + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": @@ -131,11 +139,7 @@ def extract( extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) else: # txt - extractor = ( - UnstructuredTextExtractor(file_path, unstructured_api_url) - if is_automatic - else TextExtractor(file_path, autodetect_encoding=True) - ) + extractor = TextExtractor(file_path, autodetect_encoding=True) else: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) @@ -156,6 +160,7 @@ def extract( extractor = TextExtractor(file_path, autodetect_encoding=True) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.NOTION.value: + assert extract_setting.notion_info is not None, "notion_info is required" extractor = NotionExtractor( notion_workspace_id=extract_setting.notion_info.notion_workspace_id, notion_obj_id=extract_setting.notion_info.notion_obj_id, @@ -165,6 +170,7 @@ def extract( ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: + assert extract_setting.website_info is not None, "website_info is required" if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 17c2087a0..8ae4579c7 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -1,5 +1,6 @@ import json import time +from typing import cast import requests @@ -20,9 +21,9 @@ def scrape_url(self, url, params=None) -> dict: json_data.update(params) response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: - response = response.json() - if response["success"] == True: - data = response["data"] + response_data = response.json() + if response_data["success"] == True: + data = response_data["data"] return { "title": 
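In `extract_processor.py` above, the suffix derivation now checks the second regex result before calling `.group()` (the filename match already had a guard). Pulled out into a helper whose name is ours, not the codebase's:

```python
import re
from urllib.parse import unquote

def suffix_from_content_disposition(header: str) -> str:
    filename_match = re.search(r'filename="([^"]+)"', header)
    if not filename_match:
        return ""
    filename = unquote(filename_match.group(1))
    match = re.search(r"\.(\w+)$", filename)
    # Guard before .group(): a filename with no extension just yields an empty suffix.
    return "." + match.group(1) if match else ""

print(suffix_from_content_disposition('attachment; filename="report.pdf"'))  # ".pdf"
print(suffix_from_content_disposition('attachment; filename="README"'))      # ""
```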
data.get("metadata").get("title"), "description": data.get("metadata").get("description"), @@ -30,7 +31,7 @@ def scrape_url(self, url, params=None) -> dict: "markdown": data.get("markdown"), } else: - raise Exception(f'Failed to scrape URL. Error: {response["error"]}') + raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}') elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") @@ -46,9 +47,11 @@ def crawl_url(self, url, params=None) -> str: response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: job_id = response.json().get("jobId") - return job_id + return cast(str, job_id) else: self._handle_error(response, "start crawl job") + # FIXME: unreachable code for mypy + return "" # unreachable def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() @@ -64,9 +67,9 @@ def check_crawl_status(self, job_id) -> dict: for item in data: if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - "title": item.get("metadata").get("title"), - "description": item.get("metadata").get("description"), - "source_url": item.get("metadata").get("sourceURL"), + "title": item.get("metadata", {}).get("title"), + "description": item.get("metadata", {}).get("description"), + "source_url": item.get("metadata", {}).get("sourceURL"), "markdown": item.get("markdown"), } url_data_list.append(url_data) @@ -92,6 +95,8 @@ def check_crawl_status(self, job_id) -> dict: else: self._handle_error(response, "check crawl status") + # FIXME: unreachable code for mypy + return {} # unreachable def _prepare_headers(self): return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index 560c2d1d8..350b52234 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -23,6 +23,7 @@ def extract(self) -> list[Document]: return [Document(page_content=self._load_as_text())] def _load_as_text(self) -> str: + text: str = "" with open(self._file_path, "rb") as fp: soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 87a4ce08b..41355d3fa 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Optional +from typing import Any, Optional, cast import requests @@ -78,6 +78,7 @@ def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" + assert self._notion_access_token is not None, "Notion access token is required" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), headers={ @@ -96,6 +97,7 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] for result in data["results"]: properties = result["properties"] data = {} + value: Any for property_name, property_value in properties.items(): 
type = property_value["type"] if type == "multi_select": @@ -130,22 +132,30 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) while True: query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} - res = requests.request( - "GET", - block_url, - headers={ - "Authorization": "Bearer " + self._notion_access_token, - "Content-Type": "application/json", - "Notion-Version": "2022-06-28", - }, - params=query_dict, - ) - data = res.json() + try: + res = requests.request( + "GET", + block_url, + headers={ + "Authorization": "Bearer " + self._notion_access_token, + "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + }, + params=query_dict, + ) + if res.status_code != 200: + raise ValueError(f"Error fetching Notion block data: {res.text}") + data = res.json() + except requests.RequestException as e: + raise ValueError("Error fetching Notion block data") from e + if "results" not in data or not isinstance(data["results"], list): + raise ValueError("Error fetching Notion block data") for result in data["results"]: result_type = result["type"] result_obj = result[result_type] @@ -184,6 +194,7 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: def _read_block(self, block_id: str, num_tabs: int = 0) -> str: """Read a block.""" + assert self._notion_access_token is not None, "Notion access token is required" result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) @@ -242,6 +253,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: def _read_table_rows(self, block_id: str) -> str: """Read table rows.""" + assert self._notion_access_token is not None, "Notion access token is required" done = False result_lines_arr = [] start_cursor = None @@ -296,7 +308,7 @@ def _read_table_rows(self, block_id: str) -> str: result_lines = "\n".join(result_lines_arr) return result_lines - def update_last_edited_time(self, document_model: DocumentModel): + def update_last_edited_time(self, document_model: Optional[DocumentModel]): if not document_model: return @@ -309,6 +321,7 @@ def update_last_edited_time(self, document_model: DocumentModel): db.session.commit() def get_notion_last_edited_time(self) -> str: + assert self._notion_access_token is not None, "Notion access token is required" obj_id = self._notion_obj_id page_type = self._notion_page_type if page_type == "database": @@ -330,7 +343,7 @@ def get_notion_last_edited_time(self) -> str: ) data = res.json() - return data["last_edited_time"] + return cast(str, data["last_edited_time"]) @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: @@ -349,4 +362,4 @@ def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: f"and notion workspace {notion_workspace_id}" ) - return data_source_binding.access_token + return cast(str, data_source_binding.access_token) diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 57cb9610b..04033dec3 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,7 +1,7 @@ """Abstract interface for document loader implementations.""" from 
collections.abc import Iterator -from typing import Optional +from typing import Optional, cast from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor @@ -23,11 +23,10 @@ def __init__(self, file_path: str, file_cache_key: Optional[str] = None): self._file_cache_key = file_cache_key def extract(self) -> list[Document]: - plaintext_file_key = "" plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode("utf-8") + text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -39,8 +38,8 @@ def extract(self) -> list[Document]: text = "\n\n".join(text_list) # save plaintext file for caching - if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode("utf-8")) + if not plaintext_file_exists and self._file_cache_key: + storage.save(self._file_cache_key, text.encode("utf-8")) return documents @@ -53,7 +52,7 @@ def load( def parse(self, blob: Blob) -> Iterator[Document]: """Lazily parse the blob.""" - import pypdfium2 + import pypdfium2 # type: ignore with blob.as_bytes_io() as file_path: pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index bd669bbad..f1fa5dde5 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -1,7 +1,8 @@ import base64 import logging +from typing import Optional -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup # type: ignore from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -15,7 +16,7 @@ class UnstructuredEmailExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 35220b558..35ca686f6 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -19,7 +19,7 @@ def __init__( self, file_path: str, api_url: Optional[str] = None, - api_key: Optional[str] = None, + api_key: str = "", ): """Initialize with file path.""" self._file_path = file_path diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 4173d4d12..d5418e612 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -24,7 +25,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor): if the specified encoding fails. 
""" - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index 57affb8d3..d363449c2 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index 0fdcd58b2..ecc272a2f 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredPPTExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -27,9 +28,11 @@ def extract(self) -> list[Document]: elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) else: raise NotImplementedError("Unstructured API Url is not configured") - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number + if page is None: + continue text = element.text if page in text_by_page: text_by_page[page] += "\n" + text diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index ab41290fb..e7bf6fd2e 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredPPTXExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -29,14 +30,15 @@ def extract(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path) - text_by_page = {} + text_by_page: dict[int, str] = {} for element in elements: page = element.metadata.page_number text = element.text - if page in text_by_page: - text_by_page[page] += "\n" + text - else: - text_by_page[page] = text + if page is not None: + if page in text_by_page: + text_by_page[page] += "\n" + text + else: + text_by_page[page] = text combined_texts = list(text_by_page.values()) documents = [] diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index ef46ab0e7..916cdc3f2 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -14,7 +15,7 @@ class UnstructuredXmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__(self, file_path: str, api_url: str, api_key: str): + def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0c38a9c07..d93de5fef 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -89,6 +89,8 @@ def _extract_images_from_docx(self, doc, image_folder): response = ssrf_proxy.get(url) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) + if image_ext is None: + continue file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) @@ -97,6 +99,8 @@ def _extract_images_from_docx(self, doc, image_folder): continue else: image_ext = rel.target_ref.split(".")[-1] + if image_ext is None: + continue # user uuid as file name file_uuid = str(uuid.uuid4()) file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." 
+ image_ext
@@ -226,6 +230,8 @@ def parse_docx(self, docx_path, image_folder):
                             if x_child is None:
                                 continue
                             if x.tag.endswith("instrText"):
+                                if x.text is None:
+                                    continue
                                 for i in url_pattern.findall(x.text):
                                     hyperlinks_url = str(i)
                     except Exception as e:
@@ -261,8 +267,10 @@ def parse_paragraph(paragraph):
             if isinstance(element.tag, str) and element.tag.endswith("p"):  # paragraph
                 para = paragraphs.pop(0)
                 parsed_paragraph = parse_paragraph(para)
-                if parsed_paragraph:
+                if parsed_paragraph.strip():
                     content.append(parsed_paragraph)
+                else:
+                    content.append("\n")
             elif isinstance(element.tag, str) and element.tag.endswith("tbl"):  # table
                 table = tables.pop(0)
                 content.append(self._table_to_markdown(table, image_map))
diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py
index e42cc44c6..0845b58e2 100644
--- a/api/core/rag/index_processor/constant/index_type.py
+++ b/api/core/rag/index_processor/constant/index_type.py
@@ -1,8 +1,7 @@
 from enum import Enum


-class IndexType(Enum):
+class IndexType(str, Enum):
     PARAGRAPH_INDEX = "text_model"
     QA_INDEX = "qa_model"
-    PARENT_CHILD_INDEX = "parent_child_index"
-    SUMMARY_INDEX = "summary_index"
+    PARENT_CHILD_INDEX = "hierarchical_model"
diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py
index be857bd12..2bcd1c79b 100644
--- a/api/core/rag/index_processor/index_processor_base.py
+++ b/api/core/rag/index_processor/index_processor_base.py
@@ -27,10 +27,10 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         raise NotImplementedError

     @abstractmethod
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
         raise NotImplementedError

-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
         raise NotImplementedError

     @abstractmethod
@@ -45,25 +45,29 @@ def retrieve(
     ) -> list[Document]:
         raise NotImplementedError

-    def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
+    def _get_splitter(
+        self,
+        processing_rule_mode: str,
+        max_tokens: int,
+        chunk_overlap: int,
+        separator: str,
+        embedding_model_instance: Optional[ModelInstance],
+    ) -> TextSplitter:
         """
         Get the NodeParser object according to the processing rule.
         """
-        if processing_rule["mode"] == "custom":
+        if processing_rule_mode in ["custom", "hierarchical"]:
             # The user-defined segmentation rule
-            rules = processing_rule["rules"]
-            segmentation = rules["segmentation"]
             max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
-            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
+            if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
                 raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
-            separator = segmentation["separator"]
             if separator:
                 separator = separator.replace("\\n", "\n")

             character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
-                chunk_size=segmentation["max_tokens"],
-                chunk_overlap=segmentation.get("chunk_overlap", 0) or 0,
+                chunk_size=max_tokens,
+                chunk_overlap=chunk_overlap,
                 fixed_separator=separator,
                 separators=["\n\n", "。", ". 
", " ", ""], embedding_model_instance=embedding_model_instance, @@ -77,4 +81,4 @@ def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optiona embedding_model_instance=embedding_model_instance, ) - return character_splitter + return character_splitter # type: ignore diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index 9b855ece2..c987edf34 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -3,13 +3,14 @@ from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor class IndexProcessorFactory: """IndexProcessorInit.""" - def __init__(self, index_type: str): + def __init__(self, index_type: str | None): self._index_type = index_type def init_index_processor(self) -> BaseIndexProcessor: @@ -18,9 +19,11 @@ def init_index_processor(self) -> BaseIndexProcessor: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX.value: + if self._index_type == IndexType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX.value: + elif self._index_type == IndexType.QA_INDEX: return QAIndexProcessor() + elif self._index_type == IndexType.PARENT_CHILD_INDEX: + return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a631f953c..dca84b904 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -13,26 +13,46 @@ from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset +from models.dataset import Dataset, DatasetProcessRule +from services.entities.knowledge_entities.knowledge_entities import Rule class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if process_rule.get("mode") == "automatic": + automatic_rule = DatasetProcessRule.AUTOMATIC_RULES + rules = Rule(**automatic_rule) + else: + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) # Split the text documents into nodes. 
+        if not rules.segmentation:
+            raise ValueError("No segmentation found in rules.")
         splitter = self._get_splitter(
-            processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
+            processing_rule_mode=process_rule.get("mode"),
+            max_tokens=rules.segmentation.max_tokens,
+            chunk_overlap=rules.segmentation.chunk_overlap,
+            separator=rules.segmentation.separator,
+            embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
         all_documents = []
         for document in documents:
             # document clean
-            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {}))
             document.page_content = document_text
             # parse document to nodes
             document_nodes = splitter.split_documents([document])
@@ -41,8 +61,9 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
                 if document_node.page_content.strip():
                     doc_id = str(uuid.uuid4())
                     hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata["doc_id"] = doc_id
-                    document_node.metadata["doc_hash"] = hash
+                    if document_node.metadata is not None:
+                        document_node.metadata["doc_id"] = doc_id
+                        document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = remove_leading_symbols(document_node.page_content).strip()
                     if len(page_content) > 0:
@@ -51,15 +72,19 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]:
             all_documents.extend(split_documents)

         return all_documents

-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             vector.create(documents)
         if with_keywords:
+            keywords_list = kwargs.get("keywords_list")
             keyword = Keyword(dataset)
-            keyword.create(documents)
+            if keywords_list and len(keywords_list) > 0:
+                keyword.add_texts(documents, keywords_list=keywords_list)
+            else:
+                keyword.add_texts(documents)

-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             if node_ids:
diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
new file mode 100644
index 000000000..314012208
--- /dev/null
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -0,0 +1,200 @@
+"""Parent child index processor."""
+
+import uuid
+from typing import Optional
+
+from configs import dify_config
+from core.model_manager import ModelInstance
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import ChildDocument, Document
+from extensions.ext_database import db
+from libs import helper
+from models.dataset import ChildChunk, Dataset, DocumentSegment
+from services.entities.knowledge_entities.knowledge_entities import
ParentMode, Rule
+
+
+class ParentChildIndexProcessor(BaseIndexProcessor):
+    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
+        text_docs = ExtractProcessor.extract(
+            extract_setting=extract_setting,
+            is_automatic=(
+                kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
+            ),
+        )
+
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
+        rules = Rule(**process_rule.get("rules"))
+        all_documents = []  # type: ignore
+        if rules.parent_mode == ParentMode.PARAGRAPH:
+            # Split the text documents into nodes.
+            splitter = self._get_splitter(
+                processing_rule_mode=process_rule.get("mode"),
+                max_tokens=rules.segmentation.max_tokens,
+                chunk_overlap=rules.segmentation.chunk_overlap,
+                separator=rules.segmentation.separator,
+                embedding_model_instance=kwargs.get("embedding_model_instance"),
+            )
+            for document in documents:
+                # document clean
+                document_text = CleanProcessor.clean(document.page_content, process_rule)
+                document.page_content = document_text
+                # parse document to nodes
+                document_nodes = splitter.split_documents([document])
+                split_documents = []
+                for document_node in document_nodes:
+                    if document_node.page_content.strip():
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(document_node.page_content)
+                        document_node.metadata["doc_id"] = doc_id
+                        document_node.metadata["doc_hash"] = hash
+                        # delete Splitter character
+                        page_content = document_node.page_content
+                        if page_content.startswith(".") or page_content.startswith("。"):
+                            page_content = page_content[1:].strip()
+                        else:
+                            page_content = page_content
+                        if len(page_content) > 0:
+                            document_node.page_content = page_content
+                            # parse document to child nodes
+                            child_nodes = self._split_child_nodes(
+                                document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
+                            )
+                            document_node.children = child_nodes
+                            split_documents.append(document_node)
+                all_documents.extend(split_documents)
+        elif rules.parent_mode == ParentMode.FULL_DOC:
+            page_content = "\n".join([document.page_content for document in documents])
+            document = Document(page_content=page_content, metadata=documents[0].metadata)
+            # parse document to child nodes
+            child_nodes = self._split_child_nodes(
+                document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
+            )
+            if kwargs.get("preview"):
+                if len(child_nodes) > dify_config.CHILD_CHUNKS_PREVIEW_NUMBER:
+                    child_nodes = child_nodes[: dify_config.CHILD_CHUNKS_PREVIEW_NUMBER]
+
+            document.children = child_nodes
+            doc_id = str(uuid.uuid4())
+            hash = helper.generate_text_hash(document.page_content)
+            document.metadata["doc_id"] = doc_id
+            document.metadata["doc_hash"] = hash
+            all_documents.append(document)
+
+        return all_documents
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
+        if dataset.indexing_technique == "high_quality":
+            vector = Vector(dataset)
+            for document in documents:
+                child_documents = document.children
+                if child_documents:
+                    formatted_child_documents = [
+                        Document(**child_document.model_dump()) for child_document in child_documents
+                    ]
+                    vector.create(formatted_child_documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True,
**kwargs): + # node_ids is segment's node_ids + if dataset.indexing_technique == "high_quality": + delete_child_chunks = kwargs.get("delete_child_chunks") or False + vector = Vector(dataset) + if node_ids: + child_node_ids = ( + db.session.query(ChildChunk.index_node_id) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] + vector.delete_by_ids(child_node_ids) + if delete_child_chunks: + db.session.query(ChildChunk).filter( + ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) + ).delete() + db.session.commit() + else: + vector.delete() + + if delete_child_chunks: + db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() + db.session.commit() + + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: + # Set search parameters. + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + # Organize results. + docs = [] + for result in results: + metadata = result.metadata + metadata["score"] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs + + def _split_child_nodes( + self, + document_node: Document, + rules: Rule, + process_rule_mode: str, + embedding_model_instance: Optional[ModelInstance], + ) -> list[ChildDocument]: + if not rules.subchunk_segmentation: + raise ValueError("No subchunk segmentation found in rules.") + child_splitter = self._get_splitter( + processing_rule_mode=process_rule_mode, + max_tokens=rules.subchunk_segmentation.max_tokens, + chunk_overlap=rules.subchunk_segmentation.chunk_overlap, + separator=rules.subchunk_segmentation.separator, + embedding_model_instance=embedding_model_instance, + ) + # parse document to child nodes + child_nodes = [] + child_documents = child_splitter.split_documents([document_node]) + for child_document_node in child_documents: + if child_document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(child_document_node.page_content) + child_document = ChildDocument( + page_content=child_document_node.page_content, metadata=document_node.metadata + ) + child_document.metadata["doc_id"] = doc_id + child_document.metadata["doc_hash"] = hash + child_page_content = child_document.page_content + if child_page_content.startswith(".") or child_page_content.startswith("。"): + child_page_content = child_page_content[1:].strip() + if len(child_page_content) > 0: + child_document.page_content = child_page_content + child_nodes.append(child_document) + return child_nodes diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 320f0157a..0055625e1 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,26 +21,41 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset +from 
services.entities.knowledge_entities.knowledge_entities import Rule class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: + preview = kwargs.get("preview") + process_rule = kwargs.get("process_rule") + if not process_rule: + raise ValueError("No process rule found.") + if not process_rule.get("rules"): + raise ValueError("No rules found in process rule.") + rules = Rule(**process_rule.get("rules")) splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, + chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0, + separator=rules.segmentation.separator if rules.segmentation else "", + embedding_model_instance=kwargs.get("embedding_model_instance"), ) # Split the text documents into nodes. - all_documents = [] - all_qa_documents = [] + all_documents: list[Document] = [] + all_qa_documents: list[Document] = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {}) document.page_content = document_text # parse document to nodes @@ -50,31 +65,41 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash + if document_node.metadata is not None: + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash # delete Splitter character page_content = document_node.page_content document_node.page_content = remove_leading_symbols(page_content) split_documents.append(document_node) all_documents.extend(split_documents) - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self._format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), - "tenant_id": kwargs.get("tenant_id"), - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": kwargs.get("doc_language", "English"), - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() + if preview: + self._format_qa_document( + current_app._get_current_object(), # type: ignore + kwargs.get("tenant_id"), # type: ignore + all_documents[0], + all_qa_documents, + kwargs.get("doc_language", "English"), + ) + else: + for i in range(0, len(all_documents), 10): + threads = [] + sub_documents = all_documents[i : i + 10] + for doc in sub_documents: + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "tenant_id": 
kwargs.get("tenant_id"),  # type: ignore
+                            "document_node": doc,
+                            "all_qa_documents": all_qa_documents,
+                            "document_language": kwargs.get("doc_language", "English"),
+                        },
+                    )
+                    threads.append(document_format_thread)
+                    document_format_thread.start()
+                for thread in threads:
+                    thread.join()
         return all_qa_documents

     def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
@@ -87,7 +112,7 @@ def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
             df = pd.read_csv(file)
             text_docs = []
             for index, row in df.iterrows():
-                data = Document(page_content=row[0], metadata={"answer": row[1]})
+                data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
                 text_docs.append(data)
             if len(text_docs) == 0:
                 raise ValueError("The CSV file is empty.")
@@ -96,12 +121,12 @@ def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
             raise ValueError(str(e))
         return text_docs

-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             vector.create(documents)

-    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
         vector = Vector(dataset)
         if node_ids:
             vector.delete_by_ids(node_ids)
@@ -148,11 +173,12 @@ def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, a
                 qa_documents = []
                 for result in document_qa_list:
                     qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(result["question"])
-                    qa_document.metadata["answer"] = result["answer"]
-                    qa_document.metadata["doc_id"] = doc_id
-                    qa_document.metadata["doc_hash"] = hash
+                    if qa_document.metadata is not None:
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(result["question"])
+                        qa_document.metadata["answer"] = result["answer"]
+                        qa_document.metadata["doc_id"] = doc_id
+                        qa_document.metadata["doc_hash"] = hash
                     qa_documents.append(qa_document)
                 format_documents.extend(qa_documents)
             except Exception as e:
diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py
index 1e9aaa24f..421cdc05d 100644
--- a/api/core/rag/models/document.py
+++ b/api/core/rag/models/document.py
@@ -2,7 +2,20 @@
 from collections.abc import Sequence
 from typing import Any, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel
+
+
+class ChildDocument(BaseModel):
+    """Class for storing a piece of text and associated metadata."""
+
+    page_content: str
+
+    vector: Optional[list[float]] = None
+
+    """Arbitrary metadata about the page content (e.g., source, relationships to other
+    documents, etc.).
+    """
+    metadata: dict = {}


 class Document(BaseModel):
@@ -15,10 +28,12 @@ class Document(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
     documents, etc.).
     """
-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}

     provider: Optional[str] = "dify"

+    children: Optional[list[ChildDocument]] = None
+

 class BaseDocumentTransformer(ABC):
     """Abstract base class for document transformation systems.
diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 6ae432a52..ac7a3f8bb 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -30,7 +30,11 @@ def run( doc_ids = set() unique_documents = [] for document in documents: - if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): doc_ids.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -54,7 +58,8 @@ def run( metadata=documents[result.index].metadata, provider=documents[result.index].provider, ) - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) return rerank_documents diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 4719be012..cbc96037b 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -39,7 +39,7 @@ def run( unique_documents = [] doc_ids = set() for document in documents: - if document.metadata["doc_id"] not in doc_ids: + if document.metadata is not None and document.metadata["doc_id"] not in doc_ids: doc_ids.add(document.metadata["doc_id"]) unique_documents.append(document) @@ -56,10 +56,11 @@ def run( ) if score_threshold and score < score_threshold: continue - document.metadata["score"] = score - rerank_documents.append(document) + if document.metadata is not None: + document.metadata["score"] = score + rerank_documents.append(document) - rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True) + rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -76,8 +77,9 @@ def _calculate_keyword_score(self, query: str, documents: list[Document]) -> lis for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -162,7 +164,7 @@ def _calculate_cosine( query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if "score" in document.metadata: + if document.metadata and "score" in document.metadata: query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 7a5bf39fa..e1d36aad1 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,7 +1,7 @@ import math import threading from collections import Counter -from typing import Optional, cast +from typing import Any, Optional, cast from flask import Flask, current_app @@ -34,7 +34,7 @@ from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model 
= { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -140,12 +140,12 @@ def retrieve( user_from, available_datasets, query, - retrieve_config.top_k, - retrieve_config.score_threshold, - retrieve_config.rerank_mode, + retrieve_config.top_k or 0, + retrieve_config.score_threshold or 0, + retrieve_config.rerank_mode or "reranking_model", retrieve_config.reranking_model, retrieve_config.weights, - retrieve_config.reranking_enabled, + retrieve_config.reranking_enabled or True, message_id, ) @@ -166,43 +166,29 @@ def retrieve( "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} # deal with dify documents if dify_documents: - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment if segment.answer: document_context_list.append( DocumentContext( content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) else: document_context_list.append( DocumentContext( content=segment.get_sign_content(), - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) if show_retrieve_source: - for segment in sorted_segments: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = DatasetDocument.query.filter( DatasetDocument.id == segment.document_id, @@ -218,7 +204,7 @@ def retrieve( "data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": invoke_from.to_source(), - "score": document_score_list.get(segment.index_node_id, 0.0), + "score": record.score or 0.0, } if invoke_from.to_source() == "dev": @@ -300,10 +286,11 @@ def single_retrieve( metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name results.append(document) else: retrieval_model_config = dataset.retrieval_model or default_retrieval_model @@ -325,7 +312,7 @@ def single_retrieve( score_threshold = 0.0 score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") if score_threshold_enabled: - score_threshold = 
retrieval_model_config.get("score_threshold") + score_threshold = retrieval_model_config.get("score_threshold", 0.0) with measure_time() as timer: results = RetrievalService.retrieve( @@ -358,14 +345,14 @@ def multiple_retrieve( score_threshold: float, reranking_mode: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, + weights: Optional[dict[str, Any]] = None, reranking_enable: bool = True, message_id: Optional[str] = None, ): if not available_datasets: return [] threads = [] - all_documents = [] + all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets @@ -392,15 +379,18 @@ def multiple_retrieve( "The configured knowledge base list have different embedding model, please set reranking model." ) if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE: - weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider - weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + if weights is not None: + weights["vector_setting"]["embedding_provider_name"] = available_datasets[ + 0 + ].embedding_model_provider + weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model for dataset in available_datasets: index_type = dataset.indexing_technique retrieval_thread = threading.Thread( target=self._retriever, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset.id, "query": query, "top_k": top_k, @@ -439,21 +429,22 @@ def _on_retrieval_end( """Handle retrieval end.""" dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: - query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + if document.metadata is not None: + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata["doc_id"] + ) - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + # if 'dataset_id' in document.metadata: + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + # add hit count to document segment + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) - db.session.commit() + db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = ( + trace_manager: Optional[TraceQueueManager] = ( self.application_generate_entity.trace_manager if self.application_generate_entity else None ) if trace_manager: @@ -504,10 +495,11 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + 
document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name all_documents.append(document) else: # get retrieval model , if the model is not setting , using default @@ -607,19 +599,20 @@ def to_dataset_retriever_tool( tools.append(tool) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: - tool = DatasetMultiRetrieverTool.from_dataset( - dataset_ids=[dataset.id for dataset in available_datasets], - tenant_id=tenant_id, - top_k=retrieve_config.top_k or 2, - score_threshold=retrieve_config.score_threshold, - hit_callbacks=[hit_callback], - return_resource=return_resource, - retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), - ) + if retrieve_config.reranking_model is not None: + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=[dataset.id for dataset in available_datasets], + tenant_id=tenant_id, + top_k=retrieve_config.top_k or 2, + score_threshold=retrieve_config.score_threshold, + hit_callbacks=[hit_callback], + return_resource=return_resource, + retriever_from=invoke_from.to_source(), + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + ) - tools.append(tool) + tools.append(tool) return tools @@ -635,10 +628,11 @@ def calculate_keyword_score(self, query: str, documents: list[Document], top_k: query_keywords = keyword_table_handler.extract_keywords(query, None) documents_keywords = [] for document in documents: - # get the document keywords - document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata["keywords"] = document_keywords - documents_keywords.append(document_keywords) + if document.metadata is not None: + # get the document keywords + document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) + document.metadata["keywords"] = document_keywords + documents_keywords.append(document_keywords) # Counter query keywords(TF) query_keyword_counts = Counter(query_keywords) @@ -696,8 +690,9 @@ def cosine_similarity(vec1, vec2): for document, score in zip(documents, similarities): # format document - document.metadata["score"] = score - documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) + if document.metadata is not None: + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True) return documents[:top_k] if top_k else documents def calculate_vector_score( @@ -705,10 +700,12 @@ def calculate_vector_score( ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold is None or document.metadata["score"] >= score_threshold: + if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold): filter_documents.append(document) if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True) + filter_documents = sorted( + filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True + ) return filter_documents[:top_k] if top_k else filter_documents diff --git 
a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 06147fe7b..b008d0df9 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,7 +1,8 @@ -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage @@ -27,11 +28,14 @@ def invoke( SystemPromptMessage(content="You are a helpful AI assistant."), UserPromptMessage(content=query), ] - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=dataset_tools, - stream=False, - model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + tools=dataset_tools, + stream=False, + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, + ), ) if result.message.tool_calls: # get retrieval model config diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 68fab0c12..05e8d043d 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,9 +1,9 @@ from collections.abc import Generator, Sequence -from typing import Union +from typing import Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate @@ -92,6 +92,7 @@ def _react_invoke( suffix: str = SUFFIX, format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: + prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate] if model_config.mode == "chat": prompt = self.create_chat_prompt( query=query, @@ -149,12 +150,15 @@ def _invoke_llm( :param stop: stop :return: """ - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=completion_param, - stop=stop, - stream=True, - user=user_id, + invoke_result = cast( + Generator[LLMResult, None, None], + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=completion_param, + stop=stop, + stream=True, + user=user_id, + ), ) # handle invoke result @@ -172,7 +176,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage :return: """ model = None - prompt_messages = [] + prompt_messages: list[PromptMessage] = [] full_text = "" usage = None for result in invoke_result: diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 53032b34d..3376bd7f7 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -26,8 +26,8 @@ class 
EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder( cls: type[TS], embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal[all], Set[str]] = set(), - disallowed_special: Union[Literal[all], Collection[str]] = "all", + allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037 + disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 **kwargs: Any, ): def _token_encoder(text: str) -> int: diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 7dd62f8de..4bfa541fd 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -92,7 +92,7 @@ def split_documents(self, documents: Iterable[Document]) -> list[Document]: texts, metadatas = [], [] for doc in documents: texts.append(doc.page_content) - metadatas.append(doc.metadata) + metadatas.append(doc.metadata or {}) return self.create_documents(texts, metadatas=metadatas) def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: @@ -143,7 +143,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter: """Text splitter that uses HuggingFace tokenizer to count length.""" try: - from transformers import PreTrainedTokenizerBase + from transformers import PreTrainedTokenizerBase # type: ignore if not isinstance(tokenizer, PreTrainedTokenizerBase): raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ddb148127..975c374ca 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -14,7 +14,7 @@ class UserTool(BaseModel): label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None - labels: list[str] = None + labels: list[str] | None = None UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index 0c15b2a37..7c365dc69 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -18,7 +18,7 @@ class ApiToolBundle(BaseModel): # summary summary: Optional[str] = None # operation_id - operation_id: str = None + operation_id: str | None = None # parameters parameters: Optional[list[ToolParameter]] = None # author diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 4fc383f91..c87a90c03 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -243,19 +243,22 @@ def get_simple_instance( :param options: the options of the parameter """ # convert options to ToolParameterOption + # FIXME fix the type error if options: options = [ - ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) # type: ignore + for option in options # type: ignore ] return cls( name=name, label=I18nObject(en_US="", zh_Hans=""), human_description=I18nObject(en_US="", zh_Hans=""), + placeholder=None, type=type, form=cls.ToolParameterForm.LLM, llm_description=llm_description, required=required, - options=options, + options=options, # type: ignore ) @@ -331,7 +334,7 @@ def to_dict(self) -> dict: 
"default": self.default, "options": self.options, "help": self.help.to_dict() if self.help else None, - "label": self.label.to_dict(), + "label": self.label.to_dict() if self.label else None, "url": self.url, "placeholder": self.placeholder.to_dict() if self.placeholder else None, } @@ -374,7 +377,10 @@ def __init__(self, **data: Any): pool[index] = ToolRuntimeImageVariable(**variable) super().__init__(**data) - def dict(self) -> dict: + def dict(self) -> dict: # type: ignore + """ + FIXME: just ignore the type check for now + """ return { "conversation_id": self.conversation_id, "user_id": self.user_id, diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 6febf137b..c5f9ca477 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -31,3 +31,7 @@ class ToolApiSchemaError(ValueError): class ToolEngineInvokeError(Exception): meta: ToolInvokeMeta + + def __init__(self, meta, **kwargs): + self.meta = meta + super().__init__(**kwargs) diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index d99314e33..f451edbf2 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,9 +1,14 @@ +from typing import Optional + from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolCredentialsOption, + ToolDescription, + ToolIdentity, ToolProviderCredentials, + ToolProviderIdentity, ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController @@ -64,21 +69,18 @@ def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "Ap pass else: raise ValueError(f"invalid auth type {auth_type}") - - user_name = db_provider.user.name if db_provider.user_id else "" - + user_name = db_provider.user.name if db_provider.user_id and db_provider.user is not None else "" return ApiToolProviderController( - **{ - "identity": { - "author": user_name, - "name": db_provider.name, - "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, - "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, - "icon": db_provider.icon, - }, - "credentials_schema": credentials_schema, - "provider_id": db_provider.id or "", - } + identity=ToolProviderIdentity( + author=user_name, + name=db_provider.name, + label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), + description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), + icon=db_provider.icon, + ), + credentials_schema=credentials_schema, + provider_id=db_provider.id or "", + tools=None, ) @property @@ -93,24 +95,22 @@ def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: :return: the tool """ return ApiTool( - **{ - "api_bundle": tool_bundle, - "identity": { - "author": tool_bundle.author, - "name": tool_bundle.operation_id, - "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, - "icon": self.identity.icon, - "provider": self.provider_id, - }, - "description": { - "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, - "llm": tool_bundle.summary or "", - }, - "parameters": tool_bundle.parameters or [], - } + api_bundle=tool_bundle, + identity=ToolIdentity( + author=tool_bundle.author, + name=tool_bundle.operation_id or "", + label=I18nObject(en_US=tool_bundle.operation_id, zh_Hans=tool_bundle.operation_id), + 
icon=self.identity.icon if self.identity else None, + provider=self.provider_id, + ), + description=ToolDescription( + human=I18nObject(en_US=tool_bundle.summary or "", zh_Hans=tool_bundle.summary or ""), + llm=tool_bundle.summary or "", + ), + parameters=tool_bundle.parameters or [], ) - def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: + def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[Tool]: """ load bundled tools @@ -121,7 +121,7 @@ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: return self.tools - def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ fetch tools from database @@ -131,6 +131,8 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: """ if self.tools is not None: return self.tools + if self.identity is None: + return None tools: list[Tool] = [] @@ -151,7 +153,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: self.tools = tools return tools - def get_tool(self, tool_name: str) -> ApiTool: + def get_tool(self, tool_name: str) -> Tool: """ get tool by name @@ -161,7 +163,9 @@ def get_tool(self, tool_name: str) -> ApiTool: if self.tools is None: self.get_tools() - for tool in self.tools: + for tool in self.tools or []: + if tool.identity is None: + continue if tool.identity.name == tool_name: return tool diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 582ad636b..fc29920ac 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -1,9 +1,10 @@ import logging -from typing import Any +from typing import Any, Optional from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.api_tool import ApiTool from core.tools.tool.tool import Tool from extensions.ext_database import db from models.model import App, AppModelConfig @@ -20,10 +21,10 @@ def provider_type(self) -> ToolProviderType: def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass - def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None: + def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: pass - def get_tools(self, user_id: str) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> list[Tool]: db_tools: list[PublishedAppTool] = ( db.session.query(PublishedAppTool) .filter( @@ -38,7 +39,7 @@ def get_tools(self, user_id: str) -> list[Tool]: tools: list[Tool] = [] for db_tool in db_tools: - tool = { + tool: dict[str, Any] = { "identity": { "author": db_tool.author, "name": db_tool.tool_name, @@ -52,7 +53,7 @@ def get_tools(self, user_id: str) -> list[Tool]: "parameters": [], } # get app from db - app: App = db_tool.app + app: Optional[App] = db_tool.app if not app: logger.error(f"app {db_tool.app_id} not found") @@ -79,6 +80,7 @@ def get_tools(self, user_id: str) -> list[Tool]: type=ToolParameter.ToolParameterType.STRING, required=required, default=default, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) elif form_type == "select": @@ -92,6 +94,7 @@ def get_tools(self, user_id: str) -> list[Tool]: type=ToolParameter.ToolParameterType.SELECT, required=required, 
default=default, + placeholder=I18nObject(en_US="", zh_Hans=""), options=[ ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options @@ -99,5 +102,5 @@ def get_tools(self, user_id: str) -> list[Tool]: ) ) - tools.append(Tool(**tool)) + tools.append(ApiTool(**tool)) return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 5c10f72fd..99a062f8c 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -5,7 +5,7 @@ class BuiltinToolProviderSort: - _position = {} + _position: dict[str, int] = {} @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 38123f125..cf10f5d25 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -4,7 +4,7 @@ from json import loads as json_loads from threading import Lock from time import sleep, time -from typing import Any +from typing import Any, Union from httpx import get, post from requests import get as requests_get @@ -21,23 +21,25 @@ class AIPPTGenerateToolAdapter: """ _api_base_url = URL("https://co.aippt.cn/api") - _api_token_cache = {} - _style_cache = {} + _api_token_cache: dict[str, dict[str, Union[str, float]]] = {} + _style_cache: dict[str, dict[str, Union[list[dict[str, Any]], float]]] = {} - _api_token_cache_lock = Lock() - _style_cache_lock = Lock() + _api_token_cache_lock: Lock = Lock() + _style_cache_lock: Lock = Lock() - _task = {} + _task: dict[str, Any] = {} _task_type_map = { "auto": 1, "markdown": 7, } - _tool: BuiltinTool + _tool: BuiltinTool | None - def __init__(self, tool: BuiltinTool = None): + def __init__(self, tool: BuiltinTool | None = None): self._tool = tool - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the AIPPT generate tool with the given user ID and tool parameters. 
@@ -68,8 +70,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe ) # get suit - color: str = tool_parameters.get("color") - style: str = tool_parameters.get("style") + color: str = tool_parameters.get("color", "") + style: str = tool_parameters.get("style", "") if color == "__default__": color_id = "" @@ -226,7 +228,7 @@ def _generate_content(self, task_id: str, model: str, user_id: str) -> str: return "" - def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: + def _generate_ppt(self, task_id: str, suit_id: int, user_id: str) -> tuple[str, str]: """ Generate a ppt @@ -362,7 +364,9 @@ def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str: ).decode("utf-8") @classmethod - def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: + def _get_styles( + cls, credentials: dict[str, str], user_id: str + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ Get styles """ @@ -415,7 +419,7 @@ def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[di return colors, styles - def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: + def get_styles(self, user_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ Get styles @@ -507,7 +511,9 @@ class AIPPTGenerateTool(BuiltinTool): def __init__(self, **kwargs: Any): super().__init__(**kwargs) - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters) def get_runtime_parameters(self) -> list[ToolParameter]: diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index 2d65ba2d6..8bd16050e 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -1,7 +1,7 @@ import logging from typing import Any, Optional -import arxiv +import arxiv # type: ignore from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/audio/tools/tts.py b/api/core/tools/provider/builtin/audio/tools/tts.py index f83a64d04..8a33ac405 100644 --- a/api/core/tools/provider/builtin/audio/tools/tts.py +++ b/api/core/tools/provider/builtin/audio/tools/tts.py @@ -11,19 +11,21 @@ class TTSTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: - provider, model = tool_parameters.get("model").split("#") - voice = tool_parameters.get(f"voice#{provider}#{model}") + provider, model = tool_parameters.get("model", "").split("#") + voice = tool_parameters.get(f"voice#{provider}#{model}", "") model_manager = ModelManager() + if not self.runtime: + raise ValueError("Runtime is required") model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, ) tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), + content_text=tool_parameters.get("text", ""), user=user_id, - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", voice=voice, ) buffer = io.BytesIO() @@ -41,8 +43,11 @@ def _invoke(self, user_id: str, 
tool_parameters: dict[str, Any]) -> list[ToolInv ] def get_available_models(self) -> list[tuple[str, str, list[Any]]]: + if not self.runtime: + raise ValueError("Runtime is required") model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts") + tid: str = self.runtime.tenant_id or "" + models = model_provider_service.get_models_by_model_type(tenant_id=tid, model_type="tts") items = [] for provider_model in models: provider = provider_model.provider @@ -62,6 +67,8 @@ def get_runtime_parameters(self) -> list[ToolParameter]: ToolParameter( name=f"voice#{provider}#{model}", label=I18nObject(en_US=f"Voice of {model}({provider})"), + human_description=I18nObject(en_US=f"Select a voice for {model} model"), + placeholder=I18nObject(en_US="Select a voice"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, options=[ @@ -83,6 +90,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=True, + placeholder=I18nObject(en_US="Select a model", zh_Hans="选择模型"), options=options, ), ) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index a04f5c0fe..b224ff525 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -2,8 +2,8 @@ import logging from typing import Any, Union -import boto3 -from botocore.exceptions import BotoCoreError +import boto3 # type: ignore +from botocore.exceptions import BotoCoreError # type: ignore from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py index 050b468b7..2e6a9740c 100644 --- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py @@ -14,12 +14,36 @@ class BedrockRetrieveTool(BuiltinTool): topk: int = None def _bedrock_retrieve( - self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None + self, + query_input: str, + knowledge_base_id: str, + num_results: int, + search_type: str, + rerank_model_id: str, + metadata_filter: Optional[dict] = None, ): try: retrieval_query = {"text": query_input} - retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}} + if search_type not in ["HYBRID", "SEMANTIC"]: + raise RuntimeError("search_type should be HYBRID or SEMANTIC") + + retrieval_configuration = { + "vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type} + } + + if rerank_model_id != "default": + model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}" + rerankingConfiguration = { + "bedrockRerankingConfiguration": { + "numberOfRerankedResults": num_results, + "modelConfiguration": {"modelArn": model_for_rerank_arn}, + }, + "type": "BEDROCK_RERANKING_MODEL", + } + + retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration + retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5 # If a metadata filter is provided, add it to the retrieval configuration if metadata_filter: @@ -81,12 +105,17 @@ def _invoke( metadata_filter_str =
tool_parameters.get("metadata_filter") metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None + search_type = tool_parameters.get("search_type") + rerank_model_id = tool_parameters.get("rerank_model_id") + line = 4 retrieved_docs = self._bedrock_retrieve( query_input=query, knowledge_base_id=self.knowledge_base_id, num_results=self.topk, - metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法 + search_type=search_type, + rerank_model_id=rerank_model_id, + metadata_filter=metadata_filter, ) line = 5 diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml index 9e51d52de..f8d1d1d49 100644 --- a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml @@ -59,6 +59,57 @@ parameters: max: 10 default: 5 + - name: search_type + type: select + required: false + label: + en_US: search type + zh_Hans: 搜索类型 + pt_BR: search type + human_description: + en_US: search type + zh_Hans: 搜索类型 + pt_BR: search type + llm_description: search type + default: SEMANTIC + options: + - value: SEMANTIC + label: + en_US: SEMANTIC + zh_Hans: 语义搜索 + - value: HYBRID + label: + en_US: HYBRID + zh_Hans: 混合搜索 + form: form + + - name: rerank_model_id + type: select + required: false + label: + en_US: rerank model id + zh_Hans: 重排模型ID + pt_BR: rerank model id + human_description: + en_US: rerank model id + zh_Hans: 重排模型ID + pt_BR: rerank model id + llm_description: rerank model id + options: + - value: default + label: + en_US: default + zh_Hans: 默认 + - value: cohere.rerank-v3-5:0 + label: + en_US: cohere.rerank-v3-5:0 + zh_Hans: cohere.rerank-v3-5:0 + - value: amazon.rerank-v1:0 + label: + en_US: amazon.rerank-v1:0 + zh_Hans: amazon.rerank-v1:0 + form: form + - name: aws_region type: string required: false @@ -73,9 +124,9 @@ parameters: llm_description: AWS region where the Bedrock Knowledge Base is located form: form - - name: metadata_filter - type: string - required: false + - name: metadata_filter # Additional parameter for metadata filtering + type: string # String type, expects JSON-formatted filter conditions + required: false # Optional field - can be omitted label: en_US: Metadata Filter zh_Hans: 元数据过滤器 diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 989608122..b6d16d275 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py index f43f3b6fe..01bc59634 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -2,7 +2,7 @@ import logging from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py
b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py index e05e2d9bf..3d88f28db 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py @@ -6,8 +6,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -# 定义标签映射 -LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"} +# Define label mappings +LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"} class ContentModerationTool(BuiltinTool): @@ -28,12 +28,12 @@ def _invoke_sagemaker(self, payload: dict, endpoint: str): # Handle nested JSON if present if isinstance(json_obj, dict) and "body" in json_obj: body_content = json.loads(json_obj["body"]) - raw_label = body_content.get("label") + prediction_result = body_content.get("prediction") else: - raw_label = json_obj.get("label") + prediction_result = json_obj.get("prediction") - # 映射标签并返回 - result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE + # Map labels and return + result = LABEL_MAPPING.get(prediction_result, "NO_SAFE") # If not found in mapping, default to NO_SAFE return result def _invoke( diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index bffcd058b..8320bd84e 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -2,7 +2,7 @@ import operator from typing import Any, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -11,7 +11,6 @@ class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None sagemaker_endpoint: str = None - topk: int = None def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): inputs = [query_input] * len(docs) @@ -47,8 +46,7 @@ def _invoke( self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 - if not self.topk: - self.topk = tool_parameters.get("topk", 5) + topk = tool_parameters.get("topk", 5) line = 3 query = tool_parameters.get("query", "") @@ -75,7 +73,7 @@ def _invoke( sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) line = 9 - return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]] except Exception as e: return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py index 1fafe09b4..55cff8979 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Optional, Union -import boto3 +import boto3 # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -17,7 +17,7 @@ class TTSModelType(Enum): class SageMakerTTSTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint: str = None + sagemaker_endpoint: str | None = None s3_client: Any = None comprehend_client: Any = None diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py 
b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py index 7f69e833c..a60062ca6 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py @@ -1,6 +1,6 @@ from typing import Any, Union -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py index a521f1c28..3e24b74d2 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py @@ -1,7 +1,7 @@ from typing import Any, Union import httpx -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 12b4173fa..9aa781709 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -1,7 +1,7 @@ import random from typing import Any, Union -from zhipuai import ZhipuAI +from zhipuai import ZhipuAI # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py index f994cdbf6..2bf10ce8f 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -125,7 +125,7 @@ def generate_image_by_prompt(self, prompt: dict) -> list[bytes]: for output in history["outputs"].values(): for img in output.get("images", []): image_data = self.get_image(img["filename"], img["subfolder"], img["type"]) - images.append(image_data) + images.append((image_data, img["filename"])) return images finally: ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py index 878373627..eb085f221 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py @@ -1,4 +1,5 @@ import json +import mimetypes from typing import Any from core.file import FileType @@ -75,10 +76,12 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe images = comfyui.generate_image_by_prompt(prompt) result = [] - for img in images: + for image_data, filename in images: result.append( self.create_blob_message( - blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + blob=image_data, + meta={"mime_type": mimetypes.guess_type(filename)[0]}, + save_as=self.VariableKey.IMAGE.value, ) ) return result diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py index c95949673..d58b42b82 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py @@ -7,18 +7,22 @@ class SearchRecordsTool(BuiltinTool): def _invoke(self, user_id: str, 
tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - view_id = tool_parameters.get("view_id") - field_names = tool_parameters.get("field_names") - sort = tool_parameters.get("sort") - filters = tool_parameters.get("filter") - page_token = tool_parameters.get("page_token") + app_token = tool_parameters.get("app_token", "") + table_id = tool_parameters.get("table_id", "") + table_name = tool_parameters.get("table_name", "") + view_id = tool_parameters.get("view_id", "") + field_names = tool_parameters.get("field_names", "") + sort = tool_parameters.get("sort", "") + filters = tool_parameters.get("filter", "") + page_token = tool_parameters.get("page_token", "") automatic_fields = tool_parameters.get("automatic_fields", False) user_id_type = tool_parameters.get("user_id_type", "open_id") page_size = tool_parameters.get("page_size", 20) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py index a7b036387..31cf8e18d 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py @@ -7,14 +7,18 @@ class UpdateRecordsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - app_token = tool_parameters.get("app_token") - table_id = tool_parameters.get("table_id") - table_name = tool_parameters.get("table_name") - records = tool_parameters.get("records") + app_token = tool_parameters.get("app_token", "") + table_id = tool_parameters.get("table_id", "") + table_name = tool_parameters.get("table_name", "") + records = tool_parameters.get("records", "") user_id_type = tool_parameters.get("user_id_type", "open_id") res = client.update_records(app_token, table_id, table_name, records, user_id_type) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py index 8f83aea5a..80287feca 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py @@ -7,12 +7,16 @@ class AddEventAttendeesTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = 
tool_parameters.get("event_id") - attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email") + event_id = tool_parameters.get("event_id", "") + attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email", "") need_notification = tool_parameters.get("need_notification", True) res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py index 144889692..02e9b4452 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py @@ -7,11 +7,15 @@ class DeleteEventTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") + event_id = tool_parameters.get("event_id", "") need_notification = tool_parameters.get("need_notification", True) res = client.delete_event(event_id, need_notification) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py index a2cd5a8b1..4dafe4b3b 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py @@ -7,8 +7,12 @@ class GetPrimaryCalendarTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) user_id_type = tool_parameters.get("user_id_type", "open_id") diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py index 8815b4c9c..2e8ca968b 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py @@ -7,14 +7,18 @@ class ListEventsTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - start_time = tool_parameters.get("start_time") - end_time = tool_parameters.get("end_time") - page_token = tool_parameters.get("page_token") - page_size = tool_parameters.get("page_size") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + page_token = tool_parameters.get("page_token", "") + page_size = tool_parameters.get("page_size", 50) res = 
client.list_events(start_time, end_time, page_token, page_size) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py index 85bcb1d3f..b20eb6c31 100644 --- a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py @@ -7,16 +7,20 @@ class UpdateEventTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - event_id = tool_parameters.get("event_id") - summary = tool_parameters.get("summary") - description = tool_parameters.get("description") + event_id = tool_parameters.get("event_id", "") + summary = tool_parameters.get("summary", "") + description = tool_parameters.get("description", "") need_notification = tool_parameters.get("need_notification", True) - start_time = tool_parameters.get("start_time") - end_time = tool_parameters.get("end_time") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") auto_record = tool_parameters.get("auto_record", False) res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py index 090a0828e..1533f5941 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -7,13 +7,17 @@ class CreateDocumentTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - title = tool_parameters.get("title") - content = tool_parameters.get("content") - folder_token = tool_parameters.get("folder_token") + title = tool_parameters.get("title", "") + content = tool_parameters.get("content", "") + folder_token = tool_parameters.get("folder_token", "") res = client.create_document(title, content, folder_token) return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py index dd57c6870..8ea68a2ed 100644 --- a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py @@ -7,11 +7,15 @@ class ListDocumentBlockTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + if not self.runtime or not self.runtime.credentials: + raise ValueError("Runtime is not set") app_id = self.runtime.credentials.get("app_id") app_secret = self.runtime.credentials.get("app_secret") + if not app_id or not 
app_secret: + raise ValueError("app_id and app_secret are required") client = FeishuRequest(app_id, app_secret) - document_id = tool_parameters.get("document_id") + document_id = tool_parameters.get("document_id", "") page_token = tool_parameters.get("page_token", "") user_id_type = tool_parameters.get("user_id_type", "open_id") page_size = tool_parameters.get("page_size", 500) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index fcab3d71a..06f6cacd5 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 793c74e5f..e825329a6 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index f91432ee7..193017ba9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index 383825c2d..feca0d8a7 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -1,7 +1,7 @@ import json from typing import Any, Union -from jsonpath_ng import parse +from jsonpath_ng import parse # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index 0c5b5e41c..d3a497d1c 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -1,7 +1,7 @@ import logging from typing import Any, Union -import numexpr as ne +import numexpr as ne # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index db4adfd4a..6473c509e 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -1,4 +1,4 @@ -from novita_client import ( +from novita_client import ( # type: 
ignore Txt2ImgV3Embedding, Txt2ImgV3HiresFix, Txt2ImgV3LoRA, diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index 0b4f2edff..097b234bd 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Union -from novita_client import ( +from novita_client import ( # type: ignore NovitaClient, ) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index 9c61eab9f..297a27abb 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Any, Union -from novita_client import ( +from novita_client import ( # type: ignore NovitaClient, ) diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py index 165e93956..704e0015d 100644 --- a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py @@ -13,7 +13,7 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - from pydub import AudioSegment + from pydub import AudioSegment # type: ignore class PodcastAudioGeneratorTool(BuiltinTool): diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index d8ca20bde..4a47c4211 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -2,10 +2,10 @@ import logging from typing import Any, Union -from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q -from qrcode.image.base import BaseImage -from qrcode.image.pure import PyPNGImage -from qrcode.main import QRCode +from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q # type: ignore +from qrcode.image.base import BaseImage # type: ignore +from qrcode.image.pure import PyPNGImage # type: ignore +from qrcode.main import QRCode # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/transcript/tools/transcript.py b/api/core/tools/provider/builtin/transcript/tools/transcript.py index 27f700efb..ac7565d9e 100644 --- a/api/core/tools/provider/builtin/transcript/tools/transcript.py +++ b/api/core/tools/provider/builtin/transcript/tools/transcript.py @@ -1,7 +1,7 @@ from typing import Any, Union from urllib.parse import parse_qs, urlparse -from youtube_transcript_api import YouTubeTranscriptApi +from youtube_transcript_api import YouTubeTranscriptApi # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 5ee839baa..98a108f4e 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ 
b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -37,7 +37,7 @@ class TwilioAPIWrapper(BaseModel): def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: - from twilio.rest import Client + from twilio.rest import Client # type: ignore except ImportError: raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.") account_sid = values.get("account_sid") diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index b1d100aad..649e03d18 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -1,7 +1,7 @@ from typing import Any -from twilio.base.exceptions import TwilioRestException -from twilio.rest import Client +from twilio.base.exceptions import TwilioRestException # type: ignore +from twilio.rest import Client # type: ignore from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index 1c7cb39c9..a6afd2ddd 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -1,6 +1,6 @@ from typing import Any, Union -from vanna.remote import VannaDefault +from vanna.remote import VannaDefault # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolProviderCredentialValidationError @@ -14,6 +14,9 @@ def _invoke( """ invoke tools """ + # Ensure runtime and credentials + if not self.runtime or not self.runtime.credentials: + raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") api_key = self.runtime.credentials.get("api_key", None) if not api_key: raise ToolProviderCredentialValidationError("Please input api key") diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index cb88e9519..edb96e722 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -1,6 +1,6 @@ from typing import Any, Optional, Union -import wikipedia +import wikipedia # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index f044fbe54..95a65ba22 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -3,7 +3,7 @@ import pandas as pd from requests.exceptions import HTTPError, ReadTimeout -from yfinance import download +from yfinance import download # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index ff820430f..c9ae0c4ca 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -1,6 +1,6 @@ from typing import Any, Union -import yfinance +import yfinance # type: ignore from 
requests.exceptions import HTTPError, ReadTimeout from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index dfc7e4604..74d0d25ad 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -1,7 +1,7 @@ from typing import Any, Union from requests.exceptions import HTTPError, ReadTimeout -from yfinance import Ticker +from yfinance import Ticker # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 95dec2eac..a24fe8967 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, Union -from googleapiclient.discovery import build +from googleapiclient.discovery import build # type: ignore from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 955a0add3..61de75ac5 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -1,6 +1,6 @@ from abc import abstractmethod from os import listdir, path -from typing import Any +from typing import Any, Optional from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType @@ -50,6 +50,8 @@ def _get_builtin_tools(self) -> list[Tool]: """ if self.tools: return self.tools + if not self.identity: + return [] provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") @@ -86,7 +88,7 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: return self.credentials_schema.copy() - def get_tools(self) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ returns a list of tools that the provider can provide @@ -94,11 +96,14 @@ def get_tools(self) -> list[Tool]: """ return self._get_builtin_tools() - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ returns the tool that the provider can provide """ - return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ValueError("tools not found") + return next((t for t in tools if t.identity and t.identity.name == tool_name), None) def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ @@ -107,10 +112,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]: :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool.parameters + return 
tool.parameters or [] @property def need_credentials(self) -> bool: @@ -144,6 +152,8 @@ def _get_tool_labels(self) -> list[ToolLabelEnum]: """ returns the labels of the provider """ + if self.identity is None: + return [] return self.identity.tags or [] def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: @@ -159,56 +169,56 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + for parameter_name in tool_parameters: + if parameter_name not in tool_parameters_need_to_validate: + raise ToolParameterValidationError(f"parameter {parameter_name} not found in tool {tool_name}") # check type - parameter_schema = tool_parameters_need_to_validate[parameter] + parameter_schema = tool_parameters_need_to_validate[parameter_name] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[parameter_name], str): + raise ToolParameterValidationError(f"parameter {parameter_name} should be string") elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") + if not isinstance(tool_parameters[parameter_name], int | float): + raise ToolParameterValidationError(f"parameter {parameter_name} should be number") - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: + if parameter_schema.min is not None and tool_parameters[parameter_name] < parameter_schema.min: raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" + f"parameter {parameter_name} should be greater than {parameter_schema.min}" ) - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: + if parameter_schema.max is not None and tool_parameters[parameter_name] > parameter_schema.max: raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" + f"parameter {parameter_name} should be less than {parameter_schema.max}" ) elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + if not isinstance(tool_parameters[parameter_name], bool): + raise ToolParameterValidationError(f"parameter {parameter_name} should be boolean") elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[parameter_name], str): + raise ToolParameterValidationError(f"parameter {parameter_name} should be string") options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") + raise ToolParameterValidationError(f"parameter {parameter_name} options should be list") - if tool_parameters[parameter] not in [x.value 
for x in options]: - raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + if tool_parameters[parameter_name] not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {parameter_name} should be one of {options}") - tool_parameters_need_to_validate.pop(parameter) + tool_parameters_need_to_validate.pop(parameter_name) - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] + for parameter_name in tool_parameters_need_to_validate: + parameter_schema = tool_parameters_need_to_validate[parameter_name] if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") + raise ToolParameterValidationError(f"parameter {parameter_name} is required") # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: default_value = parameter_schema.type.cast_value(parameter_schema.default) - tool_parameters[parameter] = default_value + tool_parameters[parameter_name] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index bc05a1156..e35207e4f 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -24,10 +24,12 @@ def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: :return: the credentials schema """ + if self.credentials_schema is None: + return {} return self.credentials_schema.copy() @abstractmethod - def get_tools(self) -> list[Tool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ returns a list of tools that the provider can provide @@ -36,7 +38,7 @@ def get_tools(self) -> list[Tool]: pass @abstractmethod - def get_tool(self, tool_name: str) -> Tool: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ returns a tool that the provider can provide @@ -51,10 +53,13 @@ def get_parameters(self, tool_name: str) -> list[ToolParameter]: :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ - tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + tools = self.get_tools() + if tools is None: + raise ToolNotFoundError(f"tool {tool_name} not found") + tool = next((t for t in tools if t.identity and t.identity.name == tool_name), None) if tool is None: raise ToolNotFoundError(f"tool {tool_name} not found") - return tool.parameters + return tool.parameters or [] @property def provider_type(self) -> ToolProviderType: @@ -78,55 +83,55 @@ def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dic for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter - for parameter in tool_parameters: - if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + for tool_parameter in tool_parameters: + if tool_parameter not in tool_parameters_need_to_validate: + raise ToolParameterValidationError(f"parameter {tool_parameter} not found in tool {tool_name}") # check type - parameter_schema = tool_parameters_need_to_validate[parameter] + parameter_schema = tool_parameters_need_to_validate[tool_parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: - if not isinstance(tool_parameters[parameter], str): - raise 
ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[tool_parameter], str): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be string") elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f"parameter {parameter} should be number") + if not isinstance(tool_parameters[tool_parameter], int | float): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be number") - if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: + if parameter_schema.min is not None and tool_parameters[tool_parameter] < parameter_schema.min: raise ToolParameterValidationError( - f"parameter {parameter} should be greater than {parameter_schema.min}" + f"parameter {tool_parameter} should be greater than {parameter_schema.min}" ) - if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: + if parameter_schema.max is not None and tool_parameters[tool_parameter] > parameter_schema.max: raise ToolParameterValidationError( - f"parameter {parameter} should be less than {parameter_schema.max}" + f"parameter {tool_parameter} should be less than {parameter_schema.max}" ) elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + if not isinstance(tool_parameters[tool_parameter], bool): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be boolean") elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f"parameter {parameter} should be string") + if not isinstance(tool_parameters[tool_parameter], str): + raise ToolParameterValidationError(f"parameter {tool_parameter} should be string") options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f"parameter {parameter} options should be list") + raise ToolParameterValidationError(f"parameter {tool_parameter} options should be list") - if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + if tool_parameters[tool_parameter] not in [x.value for x in options]: + raise ToolParameterValidationError(f"parameter {tool_parameter} should be one of {options}") - tool_parameters_need_to_validate.pop(parameter) + tool_parameters_need_to_validate.pop(tool_parameter) - for parameter in tool_parameters_need_to_validate: - parameter_schema = tool_parameters_need_to_validate[parameter] + for tool_parameter_validate in tool_parameters_need_to_validate: + parameter_schema = tool_parameters_need_to_validate[tool_parameter_validate] if parameter_schema.required: - raise ToolParameterValidationError(f"parameter {parameter} is required") + raise ToolParameterValidationError(f"parameter {tool_parameter_validate} is required") # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default) + tool_parameters[tool_parameter_validate] = parameter_schema.type.cast_value(parameter_schema.default) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ 
@@ -144,6 +149,8 @@ def validate_credentials_format(self, credentials: dict[str, Any]) -> None: for credential_name in credentials: if credential_name not in credentials_need_to_validate: + if self.identity is None: + raise ValueError("identity is not set") raise ToolProviderCredentialValidationError( f"credential {credential_name} not found in provider {self.identity.name}" ) diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index 5656dd09a..17fe2e20c 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -11,6 +11,7 @@ ToolProviderType, ) from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.tool import Tool from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from extensions.ext_database import db @@ -116,6 +117,7 @@ def fetch_workflow_variable(variable_name: str): llm_description=parameter.description, required=variable.required, options=options, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) elif features.file_upload: @@ -128,6 +130,7 @@ def fetch_workflow_variable(variable_name: str): llm_description=parameter.description, required=False, form=parameter.form, + placeholder=I18nObject(en_US="", zh_Hans=""), ) ) else: @@ -157,7 +160,7 @@ def fetch_workflow_variable(variable_name: str): label=db_provider.label, ) - def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: + def get_tools(self, user_id: str = "", tenant_id: str = "") -> Optional[list[Tool]]: """ fetch tools from database @@ -168,7 +171,7 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: if self.tools is not None: return self.tools - db_providers: WorkflowToolProvider = ( + db_providers: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter( WorkflowToolProvider.tenant_id == tenant_id, @@ -179,12 +182,14 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: if not db_providers: return [] + if not db_providers.app: + raise ValueError("app not found") self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: + def get_tool(self, tool_name: str) -> Optional[Tool]: """ get tool by name @@ -195,6 +200,8 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: return None for tool in self.tools: + if tool.identity is None: + continue if tool.identity.name == tool_name: return tool diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 48aac75db..7d27c4fcf 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -32,11 +32,13 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": :param meta: the meta data of a tool call processing, tenant_id is required :return: the new tool """ + if self.api_bundle is None: + raise ValueError("api_bundle is required") return self.__class__( identity=self.identity.model_copy() if self.identity else None, parameters=self.parameters.copy() if self.parameters else None, description=self.description.model_copy() if self.description else None, - api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, + api_bundle=self.api_bundle.model_copy(), runtime=Tool.Runtime(**runtime), ) @@ -61,6 +63,8 @@ def tool_provider_type(self) -> ToolProviderType: def 
assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers = {} + if self.runtime is None: + raise ValueError("runtime is required") credentials = self.runtime.credentials or {} if "auth_type" not in credentials: @@ -88,7 +92,7 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]: headers[api_key_header] = credentials["api_key_value"] - needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] + needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required] for parameter in needed_parameters: if parameter.required and parameter.name not in parameters: raise ToolParameterValidationError(f"Missing required parameter {parameter.name}") @@ -137,7 +141,8 @@ def do_http_request( params = {} path_params = {} - body = {} + # FIXME: body should be a dict[str, Any] but it changed a lot in this function + body: Any = {} cookies = {} files = [] @@ -197,8 +202,23 @@ def do_http_request( else: body = body - if method in {"get", "head", "post", "put", "delete", "patch"}: - response = getattr(ssrf_proxy, method)( + if method in { + "get", + "head", + "post", + "put", + "delete", + "patch", + "options", + "GET", + "POST", + "PUT", + "PATCH", + "DELETE", + "HEAD", + "OPTIONS", + }: + response: httpx.Response = getattr(ssrf_proxy, method.lower())( url, params=params, headers=headers, @@ -288,6 +308,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe """ invoke http request """ + response: httpx.Response | str = "" # assemble request headers = self.assembling_request(tool_parameters) diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index e2a81ed0a..adda4297f 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, cast from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage @@ -32,9 +32,12 @@ def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: :return: the model result """ # invoke model + if self.runtime is None or self.identity is None: + raise ValueError("runtime and identity are required") + return ModelInvocationUtils.invoke( user_id=user_id, - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", tool_type="builtin", tool_name=self.identity.name, prompt_messages=prompt_messages, @@ -50,8 +53,11 @@ def get_max_tokens(self) -> int: :param model_config: the model config :return: the max tokens """ + if self.runtime is None: + raise ValueError("runtime is required") + return ModelInvocationUtils.get_max_llm_context_tokens( - tenant_id=self.runtime.tenant_id, + tenant_id=self.runtime.tenant_id or "", ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -61,7 +67,12 @@ def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: :param prompt_messages: the prompt messages :return: the tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages) + if self.runtime is None: + raise ValueError("runtime is required") + + return ModelInvocationUtils.calculate_tokens( + tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + ) def summary(self, user_id: str, content: str) -> str: max_tokens = self.get_max_tokens() @@ -81,7 
+92,7 @@ def summarize(content: str) -> str: stop=[], ) - return summary.message.content + return cast(str, summary.message.content) lines = content.split("\n") new_lines = [] @@ -102,16 +113,16 @@ def summarize(content: str) -> str: # merge lines into messages with max tokens messages: list[str] = [] - for i in new_lines: + for j in new_lines: if len(messages) == 0: - messages.append(i) + messages.append(j) else: - if len(messages[-1]) + len(i) < max_tokens * 0.5: - messages[-1] += i - if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7: - messages.append(i) + if len(messages[-1]) + len(j) < max_tokens * 0.5: + messages[-1] += j + if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7: + messages.append(j) else: - messages[-1] += i + messages[-1] += j summaries = [] for i in range(len(messages)): diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index ab7b40a25..a4afea4b9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,4 +1,5 @@ import threading +from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field @@ -7,13 +8,14 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -44,12 +46,12 @@ def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): def _run(self, query: str) -> str: threads = [] - all_documents = [] + all_documents: list[RagDocument] = [] for dataset_id in self.dataset_ids: retrieval_thread = threading.Thread( target=self._retriever, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore "dataset_id": dataset_id, "query": query, "all_documents": all_documents, @@ -77,11 +79,11 @@ def _run(self, query: str) -> str: document_score_list = {} for item in all_documents: - if item.metadata.get("score"): + if item.metadata and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), @@ -139,6 +141,7 @@ def _run(self, query: str) -> str: hit_callback.return_retriever_resource_info(context_list) return str("\n".join(document_context_list)) + return "" def _retriever( self, diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py 
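The dataset_multi_retriever hunks above repeatedly guard against documents whose metadata is None before reading doc_id or score (`if item.metadata and item.metadata.get("score")`, `... if document.metadata`), which is what satisfies mypy once metadata is typed as optional. A minimal sketch of that guard pattern, using an assumed Doc stand-in rather than Dify's rag Document model:

    # Illustrative sketch -- `Doc` is an assumed stand-in, not core.rag.models.document.Document.
    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class Doc:
        page_content: str
        metadata: Optional[dict[str, Any]] = None  # may legitimately be None

    def score_by_doc_id(docs: list[Doc]) -> dict[str, float]:
        """Collect doc_id -> score, skipping documents without usable metadata."""
        scores: dict[str, float] = {}
        for doc in docs:
            # guard first: only touch metadata keys when metadata is actually present
            if doc.metadata and doc.metadata.get("score") is not None:
                scores[doc.metadata["doc_id"]] = doc.metadata["score"]
        return scores

    docs = [
        Doc("a", {"doc_id": "d1", "score": 0.92}),
        Doc("b", None),              # would crash an unguarded doc.metadata["doc_id"]
        Doc("c", {"doc_id": "d3"}),  # no score -> skipped
    ]
    print(score_by_doc_id(docs))     # {'d1': 0.92}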
index dad8c7735..a4d2de3b1 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Any, Optional -from msal_extensions.persistence import ABC +from msal_extensions.persistence import ABC # type: ignore from pydantic import BaseModel, ConfigDict from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 987f94a35..b38201647 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService @@ -69,25 +71,27 @@ def _run(self, query: str) -> str: metadata=external_document.get("metadata"), provider="external", ) - document.metadata["score"] = external_document.get("score") - document.metadata["title"] = external_document.get("title") - document.metadata["dataset_id"] = dataset.id - document.metadata["dataset_name"] = dataset.name - results.append(document) + if document.metadata is not None: + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) # deal with external documents context_list = [] for position, item in enumerate(results, start=1): - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } + if item.metadata is not None: + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) @@ -95,7 +99,7 @@ def _run(self, query: str) -> str: return str("\n".join([item.page_content for item in results])) else: # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model + retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -113,11 +117,11 @@ def _run(self, query: str) -> str: score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) + reranking_model=retrieval_model.get("reranking_model") if retrieval_model["reranking_enable"] else None, reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", 
None), + weights=retrieval_model.get("weights"), ) else: documents = [] @@ -127,7 +131,7 @@ def _run(self, query: str) -> str: document_score_list = {} if dataset.indexing_technique != "economy": for item in documents: - if item.metadata.get("score"): + if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in documents] @@ -155,20 +159,21 @@ def _run(self, query: str) -> str: context_list = [] resource_number = 1 for segment in sorted_segments: - context = {} - document = Document.query.filter( + document_segment = Document.query.filter( Document.id == segment.document_id, Document.enabled == True, Document.archived == False, ).first() - if dataset and document: + if not document_segment: + continue + if dataset and document_segment: source = { "position": resource_number, "dataset_id": dataset.id, "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, + "document_id": document_segment.id, + "document_name": document_segment.name, + "data_source_type": document_segment.data_source_type, "segment_id": segment.id, "retriever_from": self.retriever_from, "score": document_score_list.get(segment.index_node_id, None), diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 3c9295c49..2d7e193e1 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool): def get_dataset_tools( tenant_id: str, dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, + retrieve_config: Optional[DatasetRetrieveConfigEntity], return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, @@ -51,6 +51,8 @@ def get_dataset_tools( invoke_from=invoke_from, hit_callback=hit_callback, ) + if retrieval_tools is None: + return [] # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode @@ -83,6 +85,7 @@ def get_runtime_parameters(self) -> list[ToolParameter]: llm_description="Query for the dataset to be used to retrieve the dataset.", required=True, default="", + placeholder=I18nObject(en_US="", zh_Hans=""), ), ] @@ -102,7 +105,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(text=result) - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials for dataset retriever tool """ diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 8d4045038..4094207be 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -91,7 +91,7 @@ def tool_provider_type(self) -> ToolProviderType: :return: the tool provider type """ - def load_variables(self, variables: ToolRuntimeVariablePool): + def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None: """ load variables from database @@ -105,6 +105,8 @@ def set_image_variable(self, 
variable_name: str, image_key: str) -> None: """ if not self.variables: return + if self.identity is None: + return self.variables.set_file(self.identity.name, variable_name, image_key) @@ -114,6 +116,8 @@ def set_text_variable(self, variable_name: str, text: str) -> None: """ if not self.variables: return + if self.identity is None: + return self.variables.set_text(self.identity.name, variable_name, text) @@ -200,7 +204,11 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]: def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]: # update tool_parameters # TODO: Fix type error. + if self.runtime is None: + return [] if self.runtime.runtime_parameters: + # Convert Mapping to dict before updating + tool_parameters = dict(tool_parameters) tool_parameters.update(self.runtime.runtime_parameters) # try parse tool parameters into the correct type @@ -214,6 +222,12 @@ def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolI if not isinstance(result, list): result = [result] + if not all(isinstance(message, ToolInvokeMessage) for message in result): + raise ValueError( + f"Invalid return type from {self.__class__.__name__}._invoke method. " + "Expected ToolInvokeMessage or list of ToolInvokeMessage." + ) + return result def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]: @@ -221,7 +235,7 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> Transform tool parameters type """ # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials - result = deepcopy(tool_parameters) + result: dict[str, Any] = deepcopy(dict(tool_parameters)) for parameter in self.parameters or []: if parameter.name in tool_parameters: result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) @@ -234,12 +248,15 @@ def _invoke( ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: + def validate_credentials( + self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False + ) -> str | None: """ validate the credentials :param credentials: the credentials :param parameters: the parameters + :param format_only: only return the formatted """ pass diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 33b4ad021..eea66ee4e 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -1,12 +1,13 @@ import json import logging from copy import deepcopy -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType from core.tools.tool.tool import Tool from extensions.ext_database import db +from factories.file_factory import build_from_mapping from models.account import Account from models.model import App, EndUser from models.workflow import Workflow @@ -68,20 +69,20 @@ def _invoke( if data.get("error"): raise Exception(data.get("error")) - result = [] + r = [] outputs = data.get("outputs") if outputs == None: outputs = {} else: - outputs, files = self._extract_files(outputs) - for file in files: - result.append(self.create_file_message(file)) + outputs, extracted_files = self._extract_files(outputs) + for f in 
extracted_files: + r.append(self.create_file_message(f)) - result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) - result.append(self.create_json_message(outputs)) + r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) + r.append(self.create_json_message(outputs)) - return result + return r def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ @@ -194,10 +195,18 @@ def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]: if isinstance(value, list): for item in value: if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY: - file = File.model_validate(item) + item["tool_file_id"] = item.get("related_id") + file = build_from_mapping( + mapping=item, + tenant_id=str(cast(Tool.Runtime, self.runtime).tenant_id), + ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - file = File.model_validate(value) + value["tool_file_id"] = value.get("related_id") + file = build_from_mapping( + mapping=value, + tenant_id=str(cast(Tool.Runtime, self.runtime).tenant_id), + ) files.append(file) result[key] = value diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index f92b43608..f7a8ed63f 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -3,7 +3,7 @@ from copy import deepcopy from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from yarl import URL @@ -46,7 +46,7 @@ def agent_invoke( invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, trace_manager: Optional[TraceQueueManager] = None, - ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: + ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. """ @@ -69,6 +69,8 @@ def agent_invoke( raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool + if tool.identity is None: + raise ValueError("tool identity is not set") try: # hit the callback handler agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) @@ -111,7 +113,7 @@ def agent_invoke( error_response = f"tool invoke error: {e}" agent_tool_callback.on_tool_error(e) except ToolEngineInvokeError as e: - meta = e.args[0] + meta = e.meta error_response = f"tool invoke error: {meta.error}" agent_tool_callback.on_tool_error(e) return error_response, [], meta @@ -163,6 +165,8 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke """ Invoke the tool with the given arguments. 
""" + if tool.identity is None: + raise ValueError("tool identity is not set") started_at = datetime.now(UTC) meta = ToolInvokeMeta( time_cost=0.0, @@ -171,7 +175,7 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvoke "tool_name": tool.identity.name, "tool_provider": tool.identity.provider, "tool_provider_type": tool.tool_provider_type().value, - "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_parameters": deepcopy(tool.runtime.runtime_parameters) if tool.runtime else {}, "tool_icon": tool.identity.icon, }, ) @@ -194,9 +198,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: - result += response.message + result += str(response.message) if response.message is not None else "" elif response.type == ToolInvokeMessage.MessageType.LINK: - result += f"result link: {response.message}. please tell user to check it." + result += f"result link: {response.message!r}. please tell user to check it." elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, you do not need to create it," @@ -205,7 +209,7 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." else: - result += f"tool response: {response.message}." + result += f"tool response: {response.message!r}." return result @@ -223,7 +227,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype = response.meta.get("mime_type") else: try: - url = URL(response.message) + url = URL(cast(str, response.message)) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: @@ -237,7 +241,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -245,7 +249,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result.append( ToolInvokeMessageBinary( mimetype=response.meta.get("mime_type", "octet/stream"), - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) @@ -257,7 +261,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream", - url=response.message, + url=cast(str, response.message), save_as=response.save_as, ) ) diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 2a5a2944e..e53985951 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -84,13 +84,17 @@ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[ if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): raise ValueError("Unsupported tool type") - provider_ids = [controller.provider_id for controller in tool_providers] + provider_ids = [ + controller.provider_id + for controller in tool_providers + if isinstance(controller, (ApiToolProviderController, WorkflowToolProviderController)) + ] labels: list[ToolLabelBinding] = ( 
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() ) - tool_labels = {label.tool_id: [] for label in labels} + tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index ac333162b..5b2173a4d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ from collections.abc import Generator from os import listdir, path from threading import Lock, Thread -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from configs import dify_config from core.agent.entities import AgentToolEntity @@ -15,15 +15,18 @@ from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter -from core.tools.errors import ToolProviderNotFoundError +from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager +from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -33,9 +36,9 @@ class ToolManager: _builtin_provider_lock = Lock() - _builtin_providers = {} + _builtin_providers: dict[str, BuiltinToolProviderController] = {} _builtin_providers_loaded = False - _builtin_tools_labels = {} + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: @@ -55,7 +58,7 @@ def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: return cls._builtin_providers[provider] @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: + def get_builtin_tool(cls, provider: str, tool_name: str) -> Union[BuiltinTool, Tool]: """ get the builtin tool @@ -66,13 +69,15 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) + if tool is None: + raise ToolNotFoundError(f"tool {tool_name} not found") return tool @classmethod def get_tool( cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None - ) -> Union[BuiltinTool, ApiTool]: + ) -> Union[BuiltinTool, ApiTool, Tool]: """ get the tool @@ -103,7 +108,7 @@ def get_tool_runtime( tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, ApiTool]: + ) -> 
Union[BuiltinTool, ApiTool, Tool]: """ get the tool runtime @@ -113,6 +118,7 @@ def get_tool_runtime( :return: the tool """ + controller: Union[BuiltinToolProviderController, ApiToolProviderController, WorkflowToolProviderController] if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) @@ -129,7 +135,7 @@ def get_tool_runtime( ) # get credentials - builtin_provider: BuiltinToolProvider = ( + builtin_provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -177,7 +183,7 @@ def get_tool_runtime( } ) elif provider_type == "workflow": - workflow_provider = ( + workflow_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() @@ -187,8 +193,13 @@ def get_tool_runtime( raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) + controller_tools: Optional[list[Tool]] = controller.get_tools( + user_id="", tenant_id=workflow_provider.tenant_id + ) + if controller_tools is None or len(controller_tools) == 0: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + return controller_tools[0].fork_tool_runtime( runtime={ "tenant_id": tenant_id, "credentials": {}, @@ -215,7 +226,7 @@ def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: # check if tool_parameter_config in options - options = [x.value for x in parameter_rule.options] + options = [x.value for x in parameter_rule.options or []] if parameter_value is not None and parameter_value not in options: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" @@ -267,6 +278,8 @@ def get_agent_tool_runtime( identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -312,6 +325,9 @@ def get_workflow_tool_runtime( if runtime_parameters: runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: + raise ValueError("runtime not found or runtime parameters not found") + tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -326,6 +342,8 @@ def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ # get provider provider_controller = cls.get_builtin_provider(provider) + if provider_controller.identity is None: + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") absolute_path = path.join( path.dirname(path.realpath(__file__)), @@ -381,11 +399,15 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non ), parent_type=BuiltinToolProviderController, ) - provider: BuiltinToolProviderController = provider_class() - cls._builtin_providers[provider.identity.name] = provider - for tool in provider.get_tools(): + provider_controller: 
BuiltinToolProviderController = provider_class() + if provider_controller.identity is None: + continue + cls._builtin_providers[provider_controller.identity.name] = provider_controller + for tool in provider_controller.get_tools() or []: + if tool.identity is None: + continue cls._builtin_tools_labels[tool.identity.name] = tool.identity.label - yield provider + yield provider_controller except Exception as e: logger.exception(f"load builtin provider {provider}") @@ -449,9 +471,11 @@ def user_list_providers( # append builtin providers for provider in builtin_providers: # handle include, exclude + if provider.identity is None: + continue if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), + exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), data=provider, name_func=lambda x: x.identity.name, ): @@ -472,7 +496,7 @@ def user_list_providers( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() ) - api_provider_controllers = [ + api_provider_controllers: list[dict[str, Any]] = [ {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} for provider in db_api_providers ] @@ -495,7 +519,7 @@ def user_list_providers( db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() ) - workflow_provider_controllers = [] + workflow_provider_controllers: list[WorkflowToolProviderController] = [] for provider in workflow_providers: try: workflow_provider_controllers.append( @@ -505,7 +529,9 @@ def user_list_providers( # app has been deleted pass - labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers) + labels = ToolLabelManager.get_tools_labels( + [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] + ) for provider_controller in workflow_provider_controllers: user_provider = ToolTransformService.workflow_provider_to_user_provider( @@ -527,7 +553,7 @@ def get_api_provider_controller( :return: the provider controller, the credentials """ - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.id == provider_id, @@ -556,7 +582,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: get tool provider """ provider_name = provider - provider: ApiToolProvider = ( + provider_tool: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -565,17 +591,18 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: .first() ) - if provider is None: + if provider_tool is None: raise ValueError(f"you have not added provider {provider_name}") try: - credentials = json.loads(provider.credentials_str) or {} + credentials = json.loads(provider_tool.credentials_str) or {} except: credentials = {} # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE + provider_tool, + ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -584,25 +611,28 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: masked_credentials 
= tool_configuration.mask_tool_credentials(decrypted_credentials) try: - icon = json.loads(provider.icon) + icon = json.loads(provider_tool.icon) except: icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder( - { - "schema_type": provider.schema_type, - "schema": provider.schema, - "tools": provider.tools, - "icon": icon, - "description": provider.description, - "credentials": masked_credentials, - "privacy_policy": provider.privacy_policy, - "custom_disclaimer": provider.custom_disclaimer, - "labels": labels, - } + return cast( + dict, + jsonable_encoder( + { + "schema_type": provider_tool.schema_type, + "schema": provider_tool.schema, + "tools": provider_tool.tools, + "icon": icon, + "description": provider_tool.description, + "credentials": masked_credentials, + "privacy_policy": provider_tool.privacy_policy, + "custom_disclaimer": provider_tool.custom_disclaimer, + "labels": labels, + } + ), ) @classmethod @@ -617,6 +647,7 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> """ provider_type = provider_type provider_id = provider_id + provider: Optional[Union[BuiltinToolProvider, ApiToolProvider, WorkflowToolProvider]] = None if provider_type == "builtin": return ( dify_config.CONSOLE_API_URL @@ -626,16 +657,21 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> ) elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) .first() ) - return json.loads(provider.icon) + if provider is None: + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + icon = json.loads(provider.icon) + if isinstance(icon, (str, dict)): + return icon + return {"background": "#252525", "content": "\ud83d\ude01"} except: return {"background": "#252525", "content": "\ud83d\ude01"} elif provider_type == "workflow": - provider: WorkflowToolProvider = ( + provider = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .first() @@ -643,7 +679,13 @@ def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> if provider is None: raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - return json.loads(provider.icon) + try: + icon = json.loads(provider.icon) + if isinstance(icon, (str, dict)): + return icon + return {"background": "#252525", "content": "\ud83d\ude01"} + except: + return {"background": "#252525", "content": "\ud83d\ude01"} else: raise ValueError(f"provider type {provider_type} not found") diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 8b5e27f53..d77209286 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -72,9 +72,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return a deep copy of credentials with decrypted values """ + identity_id = "" + if self.provider_controller.identity: + identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}" + cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + identity_id=identity_id, 
cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() @@ -95,9 +99,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return credentials def delete_tool_credentials_cache(self): + identity_id = "" + if self.provider_controller.identity: + identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}" + cache = ToolProviderCredentialsCache( tenant_id=self.tenant_id, - identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + identity_id=identity_id, cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() @@ -199,6 +207,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: return a deep copy of parameters with decrypted values """ + if self.tool_runtime is None or self.tool_runtime.identity is None: + raise ValueError("tool_runtime is required") + cache = ToolParameterCache( tenant_id=self.tenant_id, provider=f"{self.provider_type}.{self.provider_name}", @@ -232,6 +243,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: return parameters def delete_tool_parameters_cache(self): + if self.tool_runtime is None or self.tool_runtime.identity is None: + raise ValueError("tool_runtime is required") + cache = ToolParameterCache( tenant_id=self.tenant_id, provider=f"{self.provider_type}.{self.provider_name}", diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py index ea28037df..ecf60045a 100644 --- a/api/core/tools/utils/feishu_api_utils.py +++ b/api/core/tools/utils/feishu_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -101,7 +101,7 @@ def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: """ url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -126,15 +126,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str: @@ -155,9 +156,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(str, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -173,9 +174,10 @@ def list_document_blocks( 
"page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -191,9 +193,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -203,7 +206,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -227,9 +230,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_thread_messages( @@ -245,9 +249,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -260,9 +265,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -278,9 +284,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -289,7 +296,7 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", payload=payload) return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -300,7 +307,7 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def 
get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -312,9 +319,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -322,9 +330,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -347,9 +356,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -363,7 +373,7 @@ def update_event( auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -376,7 +386,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -384,7 +394,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -395,9 +405,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -418,9 +429,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -431,9 +443,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -447,9 +460,10 @@ def create_spreadsheet( 
"title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -463,9 +477,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -477,9 +492,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -499,9 +515,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ -521,9 +538,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -545,9 +563,10 @@ def read_rows( "num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -569,9 +588,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -593,9 +613,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -609,9 +630,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -633,9 +655,10 @@ def add_records( payload = { "records": convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -657,9 +680,10 @@ def update_records( payload = { "records": convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = 
res.get("data", {}) + return data return res def delete_records( @@ -686,9 +710,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -740,7 +765,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -752,10 +777,11 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -767,9 +793,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_table( @@ -797,9 +824,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_tables( @@ -834,9 +862,10 @@ def delete_tables( "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -852,9 +881,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -882,7 +912,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params, payload=payload) + res: dict = self._send_request(url, method="GET", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/lark_api_utils.py b/api/core/tools/utils/lark_api_utils.py index 30cb0cb14..de394a39b 100644 --- a/api/core/tools/utils/lark_api_utils.py +++ b/api/core/tools/utils/lark_api_utils.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional, cast import httpx @@ -62,12 +62,10 @@ def convert_update_records(self, json_str): def tenant_access_token(self) -> str: feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" if redis_client.exists(feishu_tenant_access_token): - return redis_client.get(feishu_tenant_access_token).decode() - res = self.get_tenant_access_token(self.app_id, self.app_secret) + return str(redis_client.get(feishu_tenant_access_token).decode()) + res: dict[str, str] = 
self.get_tenant_access_token(self.app_id, self.app_secret) redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) - if "tenant_access_token" in res: - return res.get("tenant_access_token") - return "" + return res.get("tenant_access_token", "") def _send_request( self, @@ -91,7 +89,7 @@ def _send_request( def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" payload = {"app_id": app_id, "app_secret": app_secret} - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def create_document(self, title: str, content: str, folder_token: str) -> dict: @@ -101,15 +99,16 @@ def create_document(self, title: str, content: str, folder_token: str) -> dict: "content": content, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def write_document(self, document_id: str, content: str, position: str = "end") -> dict: url = f"{self.API_BASE_URL}/document/write_document" payload = {"document_id": document_id, "content": content, "position": position} - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) return res def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict: @@ -119,9 +118,9 @@ def get_document_content(self, document_id: str, mode: str = "markdown", lang: s "lang": lang, } url = f"{self.API_BASE_URL}/document/get_document_content" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data").get("content") + return cast(dict, res.get("data", {}).get("content")) return "" def list_document_blocks( @@ -134,9 +133,10 @@ def list_document_blocks( "page_token": page_token, } url = f"{self.API_BASE_URL}/document/list_document_blocks" - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: @@ -149,9 +149,10 @@ def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: @@ -161,7 +162,7 @@ def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dic "msg_type": msg_type, "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), } - res = self._send_request(url, require_token=False, payload=payload) + res: dict = self._send_request(url, require_token=False, payload=payload) return res def get_chat_messages( @@ -182,9 +183,10 @@ def get_chat_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", 
params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_thread_messages( @@ -197,9 +199,10 @@ def get_thread_messages( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: @@ -211,9 +214,10 @@ def create_task(self, summary: str, start_time: str, end_time: str, completed_ti "completed_at": completed_time, "description": description, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_task( @@ -228,9 +232,10 @@ def update_task( "completed_time": completed_time, "description": description, } - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_task(self, task_guid: str) -> dict: @@ -238,9 +243,10 @@ def delete_task(self, task_guid: str) -> dict: payload = { "task_guid": task_guid, } - res = self._send_request(url, method="DELETE", payload=payload) + res: dict = self._send_request(url, method="DELETE", payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: @@ -250,9 +256,10 @@ def add_members(self, task_guid: str, member_phone_or_email: str, member_role: s "member_phone_or_email": member_phone_or_email, "member_role": member_role, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: @@ -263,9 +270,10 @@ def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: @@ -273,9 +281,10 @@ def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: params = { "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_event( @@ -298,9 +307,10 @@ def create_event( "auto_record": auto_record, "attendee_ability": attendee_ability, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_event( @@ -314,7 +324,7 @@ def update_event( 
auto_record: bool, ) -> dict: url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" - payload = {} + payload: dict[str, Any] = {} if summary: payload["summary"] = summary if description: @@ -327,7 +337,7 @@ def update_event( payload["need_notification"] = need_notification if auto_record: payload["auto_record"] = auto_record - res = self._send_request(url, method="PATCH", payload=payload) + res: dict = self._send_request(url, method="PATCH", payload=payload) return res def delete_event(self, event_id: str, need_notification: bool = True) -> dict: @@ -335,7 +345,7 @@ def delete_event(self, event_id: str, need_notification: bool = True) -> dict: params = { "need_notification": need_notification, } - res = self._send_request(url, method="DELETE", params=params) + res: dict = self._send_request(url, method="DELETE", params=params) return res def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: @@ -346,9 +356,10 @@ def list_events(self, start_time: str, end_time: str, page_token: str, page_size "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_events( @@ -369,9 +380,10 @@ def search_events( "user_id_type": user_id_type, "page_size": page_size, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: @@ -381,9 +393,10 @@ def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_ "attendee_phone_or_email": attendee_phone_or_email, "need_notification": need_notification, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_spreadsheet( @@ -396,9 +409,10 @@ def create_spreadsheet( "title": title, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_spreadsheet( @@ -411,9 +425,10 @@ def get_spreadsheet( "spreadsheet_token": spreadsheet_token, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_spreadsheet_sheets( @@ -424,9 +439,10 @@ def list_spreadsheet_sheets( params = { "spreadsheet_token": spreadsheet_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_rows( @@ -445,9 +461,10 @@ def add_rows( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_cols( @@ 
-466,9 +483,10 @@ def add_cols( "length": length, "values": values, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_rows( @@ -489,9 +507,10 @@ def read_rows( "num_rows": num_rows, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_cols( @@ -512,9 +531,10 @@ def read_cols( "num_cols": num_cols, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_table( @@ -535,9 +555,10 @@ def read_table( "query": query, "user_id_type": user_id_type, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def create_base( @@ -550,9 +571,10 @@ def create_base( "name": name, "folder_token": folder_token, } - res = self._send_request(url, payload=payload) + res: dict = self._send_request(url, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def add_records( @@ -573,9 +595,10 @@ def add_records( payload = { "records": self.convert_add_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def update_records( @@ -596,9 +619,10 @@ def update_records( payload = { "records": self.convert_update_records(records), } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_records( @@ -624,9 +648,10 @@ def delete_records( payload = { "records": record_id_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def search_record( @@ -678,7 +703,7 @@ def search_record( except json.JSONDecodeError: raise ValueError("The input string is not valid JSON") - payload = {} + payload: dict[str, Any] = {} if view_id: payload["view_id"] = view_id @@ -690,9 +715,10 @@ def search_record( payload["filter"] = filter_dict if automatic_fields: payload["automatic_fields"] = automatic_fields - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def get_base_info( @@ -703,9 +729,10 @@ def get_base_info( params = { "app_token": app_token, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + 
return data return res def create_table( @@ -732,9 +759,10 @@ def create_table( } if default_view_name: payload["default_view_name"] = default_view_name - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def delete_tables( @@ -767,9 +795,10 @@ def delete_tables( "table_ids": table_id_list, "table_names": table_name_list, } - res = self._send_request(url, params=params, payload=payload) + res: dict = self._send_request(url, params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def list_tables( @@ -784,9 +813,10 @@ def list_tables( "page_token": page_token, "page_size": page_size, } - res = self._send_request(url, method="GET", params=params) + res: dict = self._send_request(url, method="GET", params=params) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res def read_records( @@ -814,7 +844,8 @@ def read_records( "record_ids": record_id_list, "user_id_type": user_id_type, } - res = self._send_request(url, method="POST", params=params, payload=payload) + res: dict = self._send_request(url, method="POST", params=params, payload=payload) if "data" in res: - return res.get("data") + data: dict = res.get("data", {}) + return data return res diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index e30c903a4..3509f1e6e 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -90,12 +90,12 @@ def transform_tool_invoke_messages( ) elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None - file = message.meta.get("file") - if isinstance(file, File): - if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) - if file.type == FileType.IMAGE: + file_mata = message.meta.get("file") + if isinstance(file_mata, File): + if file_mata.transfer_method == FileTransferMethod.TOOL_FILE: + assert file_mata.related_id is not None + url = cls.get_tool_file_url(tool_file_id=file_mata.related_id, extension=file_mata.extension) + if file_mata.type == FileType.IMAGE: result.append( ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 4e226810d..3689dcc9e 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -5,7 +5,7 @@ """ import json -from typing import cast +from typing import Optional, cast from core.model_manager import ModelManager from core.model_runtime.entities.llm_entities import LLMResult @@ -51,7 +51,7 @@ def get_max_llm_context_tokens( if not schema: raise InvokeModelError("No model schema found") - max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) + max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None) if max_tokens is None: return 2048 @@ -133,14 +133,17 @@ def invoke( db.session.commit() try: - response: LLMResult = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=[], - stop=[], - stream=False, - user=user_id, - callbacks=[], + 
response: LLMResult = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + callbacks=[], + ), ) except InvokeRateLimitError as e: raise InvokeModelError(f"Invoke rate limit error: {e}") diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index ae44b1b99..9d88d6d6e 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -5,8 +5,9 @@ from json.decoder import JSONDecodeError from typing import Optional +from flask import request from requests import get -from yaml import YAMLError, safe_load +from yaml import YAMLError, safe_load # type: ignore from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle @@ -29,6 +30,10 @@ def parse_openapi_to_tool_bundle( raise ToolProviderNotFoundError("No server found in the openapi yaml.") server_url = openapi["servers"][0]["url"] + request_env = request.headers.get("X-Request-Env") + if request_env: + matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env] + server_url = matched_servers[0] if matched_servers else server_url # list all interfaces interfaces = [] @@ -64,6 +69,9 @@ def parse_openapi_to_tool_bundle( default=parameter["schema"]["default"] if "schema" in parameter and "default" in parameter["schema"] else None, + placeholder=I18nObject( + en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "") + ), ) # check if there is a type @@ -108,6 +116,9 @@ def parse_openapi_to_tool_bundle( form=ToolParameter.ToolParameterForm.LLM, llm_description=property.get("description", ""), default=property.get("default", None), + placeholder=I18nObject( + en_US=property.get("description", ""), zh_Hans=property.get("description", "") + ), ) # check if there is a type @@ -158,9 +169,9 @@ def parse_openapi_to_tool_bundle( return bundles @staticmethod - def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: + def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: parameter = parameter or {} - typ = None + typ: Optional[str] = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE @@ -175,6 +186,8 @@ def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType return ToolParameter.ToolParameterType.BOOLEAN elif typ == "string": return ToolParameter.ToolParameterType.STRING + else: + return None @staticmethod def parse_openapi_yaml_to_tool_bundle( @@ -236,7 +249,8 @@ def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning: if ("summary" not in operation or len(operation["summary"]) == 0) and ( "description" not in operation or len(operation["description"]) == 0 ): - warning["missing_summary"] = f"No summary or description found in operation {method} {path}." + if warning is not None: + warning["missing_summary"] = f"No summary or description found in operation {method} {path}." openapi["paths"][path][method] = { "operationId": operation["operationId"], diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 6db9dfd0d..105823f89 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -12,5 +12,6 @@ def remove_leading_symbols(text: str) -> str: str: The text with leading punctuation or symbols removed. 
""" # Match Unicode ranges for punctuation and symbols - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+" + # FIXME this pattern is confused quick fix for #11868 maybe refactor it later + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" return re.sub(pattern, "", text) diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 3aae31e93..d42fd99fc 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -9,13 +9,13 @@ import unicodedata from contextlib import contextmanager from pathlib import Path -from typing import Optional +from typing import Any, Literal, Optional, cast from urllib.parse import unquote import chardet -import cloudscraper -from bs4 import BeautifulSoup, CData, Comment, NavigableString -from regex import regex +import cloudscraper # type: ignore +from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore +from regex import regex # type: ignore from core.helper import ssrf_proxy from core.rag.extractor import extract_processor @@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: return "Unsupported content-type [{}] of URL.".format(main_content_type) if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: - return ExtractProcessor.load_from_url(url, return_text=True) + return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: @@ -125,7 +125,7 @@ def extract_using_readabilipy(html): os.unlink(article_json_path) os.unlink(html_path) - article_json = { + article_json: dict[str, Any] = { "title": None, "byline": None, "date": None, @@ -300,7 +300,7 @@ def strip_control_characters(text): def normalize_unicode(text): """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" - normal_form = "NFKC" + normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" text = unicodedata.normalize(normal_form, text) return text @@ -332,6 +332,7 @@ def add_content_digest(element): def content_digest(element): + digest: Any if is_text(element): # Hash trimmed_string = element.string.strip() diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index d92bfb9b9..08a112cfd 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -7,7 +7,7 @@ class WorkflowToolConfigurationUtils: @classmethod - def check_parameter_configurations(cls, configurations: Mapping[str, Any]): + def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]): for configuration in configurations: WorkflowToolParameterConfiguration.model_validate(configuration) @@ -27,7 +27,7 @@ def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[Vari @classmethod def check_is_synced( cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] - ) -> None: + ) -> bool: """ check is synced diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py index 42c7f85bc..ee7ca11e0 100644 --- a/api/core/tools/utils/yaml_utils.py +++ b/api/core/tools/utils/yaml_utils.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any -import yaml +import yaml # type: ignore from yaml import 
YAMLError logger = logging.getLogger(__name__) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index f9268b52e..c32815b24 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import cast from uuid import uuid4 from pydantic import Field @@ -78,7 +79,7 @@ class SecretVariable(StringVariable): @property def log(self) -> str: - return encrypter.obfuscated_token(self.value) + return cast(str, encrypter.obfuscated_token(self.value)) class NoneVariable(NoneSegment, Variable): @@ -90,5 +91,5 @@ class FileVariable(FileSegment, Variable): pass -class ArrayFileVariable(ArrayFileSegment, Variable): +class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index ed737e731..b9c6b35ad 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -33,7 +33,7 @@ class WorkflowLoggingCallback(WorkflowCallback): def __init__(self) -> None: - self.current_node_id = None + self.current_node_id: Optional[str] = None def on_event(self, event: GraphEngineEvent) -> None: if isinstance(event, GraphRunStartedEvent): diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index ca01dcd7d..ae5f117bf 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -36,9 +36,9 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict[str, Any]] = None # process data + process_data: Optional[Mapping[str, Any]] = None # process data outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index bc3a15bd0..b8470aecb 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -5,7 +5,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): - def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState): """ Check if the condition can be executed diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 800dd136a..5c672c985 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,4 +1,5 @@ import uuid +from collections import defaultdict from collections.abc import Mapping from typing import Any, Optional, cast @@ -310,26 +311,17 @@ def _recursively_add_parallels( parallel = None if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_branch_node_ids = {} - condition_edge_mappings = {} + parallel_branch_node_ids = 
defaultdict(list) + condition_edge_mappings = defaultdict(list) for graph_edge in target_node_edges: if graph_edge.run_condition is None: - if "default" not in parallel_branch_node_ids: - parallel_branch_node_ids["default"] = [] - parallel_branch_node_ids["default"].append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash - if condition_hash not in condition_edge_mappings: - condition_edge_mappings[condition_hash] = [] - condition_edge_mappings[condition_hash].append(graph_edge) for condition_hash, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: - if condition_hash not in parallel_branch_node_ids: - parallel_branch_node_ids[condition_hash] = [] - for graph_edge in graph_edges: parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) @@ -418,7 +410,7 @@ def _recursively_add_parallels( if condition_edge_mappings: for condition_hash, graph_edges in condition_edge_mappings.items(): for graph_edge in graph_edges: - current_parallel: GraphParallel | None = cls._get_current_parallel( + current_parallel = cls._get_current_parallel( parallel_mapping=parallel_mapping, graph_edge=graph_edge, parallel=condition_parallels.get(condition_hash), @@ -621,10 +613,10 @@ def _fetch_all_node_ids_in_parallels( for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): # check which node is after if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): - if node_id in merge_branch_node_ids: + if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id2] elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): - if node_id2 in merge_branch_node_ids: + if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids: del merge_branch_node_ids[node_id] branches_merge_node_ids: dict[str, str] = {} diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index d7d33c65f..db1e01f14 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -40,6 +40,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor @@ -66,7 +67,7 @@ def __init__( self.max_submit_count = max_submit_count self.submit_count = 0 - def submit(self, fn, *args, **kwargs): + def submit(self, fn, /, *args, **kwargs): self.submit_count += 1 self.check_is_full() @@ -140,7 +141,8 @@ def __init__( def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event yield GraphRunStartedEvent() - handle_exceptions = [] + handle_exceptions: list[str] = [] + stream_processor: StreamProcessor try: if self.init_params.workflow_type == WorkflowType.CHAT: @@ -168,7 +170,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: self.graph_runtime_state.outputs = ( - item.route_node_state.node_run_result.outputs + dict(item.route_node_state.node_run_result.outputs) if item.route_node_state.node_run_result and 
item.route_node_state.node_run_result.outputs else {} @@ -350,7 +352,7 @@ def _run( if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results - condition_edge_mappings = {} + condition_edge_mappings: dict[str, list[GraphEdge]] = {} for edge in edge_mappings: if edge.run_condition: run_condition_hash = edge.run_condition.hash @@ -364,6 +366,9 @@ def _run( continue edge = cast(GraphEdge, sub_edge_mappings[0]) + if edge.run_condition is None: + logger.warning(f"Edge {edge.target_node_id} run condition is None") + continue result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -387,11 +392,11 @@ def _run( handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for parallel_result in parallel_generator: + if isinstance(parallel_result, str): + final_node_id = parallel_result else: - yield item + yield parallel_result break @@ -413,11 +418,11 @@ def _run( handle_exceptions=handle_exceptions, ) - for item in parallel_generator: - if isinstance(item, str): - final_node_id = item + for generated_item in parallel_generator: + if isinstance(generated_item, str): + final_node_id = generated_item else: - yield item + yield generated_item if not final_node_id: break @@ -612,8 +617,8 @@ def _run_node( max_retries = node_instance.node_data.retry_config.max_retries retry_interval = node_instance.node_data.retry_config.retry_interval_seconds retries = 0 - shoudl_continue_retry = True - while shoudl_continue_retry and retries <= max_retries: + should_continue_retry = True + while should_continue_retry and retries <= max_retries: try: # run node retry_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -653,7 +658,7 @@ def _run_node( parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - error=run_result.error, + error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, ) @@ -692,7 +697,7 @@ def _run_node( parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - shoudl_continue_retry = False + should_continue_retry = False else: yield NodeRunFailedEvent( error=route_node_state.failed_reason or "Unknown error.", @@ -706,7 +711,7 @@ def _run_node( parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - shoudl_continue_retry = False + should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if node_instance.should_continue_on_error and self.graph.edge_mapping.get( node_instance.node_id @@ -732,20 +737,20 @@ def _run_node( variable_value=variable_value, ) - # add parallel info to run result metadata - if parallel_id and parallel_start_node_id: - if not run_result.metadata: - run_result.metadata = {} + # When setting metadata, convert to dict first + if not run_result.metadata: + run_result.metadata = {} - run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id - run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = ( - parallel_start_node_id - ) + if parallel_id and parallel_start_node_id: + metadata_dict = dict(run_result.metadata) + metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id if parent_parallel_id and parent_parallel_start_node_id: - 
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( parent_parallel_start_node_id ) + run_result.metadata = metadata_dict yield NodeRunSucceededEvent( id=node_instance.id, @@ -758,7 +763,7 @@ def _run_node( parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - shoudl_continue_retry = False + should_continue_retry = False break elif isinstance(item, RunStreamChunkEvent): @@ -869,8 +874,8 @@ def _handle_continue_on_error( variable_pool.add([node_instance.node_id, "error_message"], error_result.error) variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions - handle_exceptions.append(error_result.error) - node_error_args = { + handle_exceptions.append(error_result.error or "") + node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": error_result.error, "inputs": error_result.inputs, diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index ed033e7f2..40213bd15 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -63,7 +63,7 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat self._remove_unreachable_nodes(event) # generate stream outputs - yield from self._generate_stream_outputs_when_node_finished(event) + yield from self._generate_stream_outputs_when_node_finished(cast(NodeRunSucceededEvent, event)) else: yield event @@ -130,7 +130,7 @@ def _generate_stream_outputs_when_node_finished( node_type=event.node_type, node_data=event.node_data, chunk_content=text, - from_variable_selector=value_selector, + from_variable_selector=list(value_selector), route_node_state=event.route_node_state, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index d785397e1..4759356ae 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -1,9 +1,10 @@ import logging from abc import ABC, abstractmethod from collections.abc import Generator +from typing import Optional from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.graph import Graph logger = logging.getLogger(__name__) @@ -19,7 +20,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: raise NotImplementedError - def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: finished_node_id = event.route_node_state.node_id if finished_node_id not in self.rest_node_ids: return @@ -32,8 +33,8 @@ def _remove_unreachable_nodes(self, 
event: NodeRunSucceededEvent) -> None: return if run_result.edge_source_handle: - reachable_node_ids = [] - unreachable_first_node_ids = [] + reachable_node_ids: list[str] = [] + unreachable_first_node_ids: list[str] = [] if finished_node_id not in self.graph.edge_mapping: logger.warning(f"node {finished_node_id} has no edge mapping") return @@ -48,23 +49,35 @@ def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: # we remove the node maybe shortcut the answer node, so comment this code for now # there is not effect on the answer node and the workflow, when we have a better solution # we can open this code. Issues: #11542 #9560 #10638 #10564 - - # reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) - continue + # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) + # if "answer" in ids: + # continue + # else: + # reachable_node_ids.extend(ids) + + # The branch_identify parameter is added to ensure that + # only nodes in the correct logical branch are included. + ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) + reachable_node_ids.extend(ids) else: unreachable_first_node_ids.append(edge.target_node_id) for node_id in unreachable_first_node_ids: self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) - def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: + def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: node_ids = [] for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id == self.graph.root_node_id: continue + # Only follow edges that match the branch_identify or have no run_condition + if edge.run_condition and edge.run_condition.branch_identify: + if not branch_identify or edge.run_condition.branch_identify != branch_identify: + continue + node_ids.append(edge.target_node_id) - node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify)) return node_ids def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 529fd7be7..6bf8899f5 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -38,7 +38,8 @@ def _parse_json(value: str) -> Any: @staticmethod def _validate_array(value: Any, element_type: DefaultValueType) -> bool: """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore @staticmethod def _convert_number(value: str) -> float: @@ -84,7 +85,7 @@ def validate_value_type(self) -> "DefaultValue": }, } - validator = type_validators.get(self.type) + validator: dict[str, Any] = type_validators.get(self.type, {}) if not validator: if self.type == DefaultValueType.ARRAY_FILES: # Handle files type diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 4e371ca43..2f82bf8c3 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -125,7 +125,7 @@ def _transform_result( if depth > 
dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") - transformed_result = {} + transformed_result: dict[str, Any] = {} if output_schema is None: # validate output thought instance type for output_name, output_value in result.items(): diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index e78183baf..a45403588 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -14,7 +14,7 @@ class CodeNodeData(BaseNodeData): class Output(BaseModel): type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] - children: Optional[dict[str, "Output"]] = None + children: Optional[dict[str, "CodeNodeData.Output"]] = None class Dependency(BaseModel): name: str diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 6d82dbe6d..c0d8c6409 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -2,13 +2,18 @@ import io import json import logging +import operator import os import tempfile +from collections.abc import Mapping, Sequence +from typing import Any, cast import docx import pandas as pd import pypdfium2 # type: ignore import yaml # type: ignore +from docx.table import Table +from docx.text.paragraph import Paragraph from configs import dify_config from core.file import File, FileTransferMethod, file_manager @@ -77,6 +82,23 @@ def _run(self): process_data=process_data, ) + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: DocumentExtractorNodeData, + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {node_id + ".files": node_data.variable_selector} + def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: """Extract text from a file based on its MIME type.""" @@ -159,7 +181,7 @@ def _extract_text_from_yaml(file_content: bytes) -> str: """Extract the content from yaml file""" try: yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) + return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) except (UnicodeDecodeError, yaml.YAMLError) as e: raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e @@ -188,35 +210,56 @@ def _extract_text_from_doc(file_content: bytes) -> str: doc_file = io.BytesIO(file_content) doc = docx.Document(doc_file) text = [] - # Process paragraphs - for paragraph in doc.paragraphs: - if paragraph.text.strip(): - text.append(paragraph.text) - # Process tables - for table in doc.tables: - # Table header - try: - # table maybe cause errors so ignore it. 
- if len(table.rows) > 0 and table.rows[0].cells is not None: + # Keep track of paragraph and table positions + content_items: list[tuple[int, str, Table | Paragraph]] = [] + + # Process paragraphs and tables + for i, paragraph in enumerate(doc.paragraphs): + if paragraph.text.strip(): + content_items.append((i, "paragraph", paragraph)) + + for i, table in enumerate(doc.tables): + content_items.append((i, "table", table)) + + # Sort content items based on their original position + content_items.sort(key=operator.itemgetter(0)) + + # Process sorted content + for _, item_type, item in content_items: + if item_type == "paragraph": + if isinstance(item, Table): + continue + text.append(item.text) + elif item_type == "table": + # Process tables + if not isinstance(item, Table): + continue + try: # Check if any cell in the table has text has_content = False - for row in table.rows: + for row in item.rows: if any(cell.text.strip() for cell in row.cells): has_content = True break if has_content: - markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" - markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" - for row in table.rows[1:]: - markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n" + cell_texts = [cell.text.replace("\n", "
") for cell in item.rows[0].cells] + markdown_table = f"| {' | '.join(cell_texts)} |\n" + markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n" + + for row in item.rows[1:]: + # Replace newlines with
in each cell + row_cells = [cell.text.replace("\n", "
") for cell in row.cells] + markdown_table += "| " + " | ".join(row_cells) + " |\n" + text.append(markdown_table) - except Exception as e: - logger.warning(f"Failed to extract table from DOC/DOCX: {e}") - continue + except Exception as e: + logger.warning(f"Failed to extract table from DOC/DOCX: {e}") + continue return "\n".join(text) + except Exception as e: raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e @@ -229,9 +272,9 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError("Missing URL for remote file") response = ssrf_proxy.get(file.remote_url) response.raise_for_status() - return response.content + return cast(bytes, response.content) else: - return file_manager.download(file) + return cast(bytes, file_manager.download(file)) except Exception as e: raise FileDownloadError(f"Error downloading file: {str(e)}") from e diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index 0db1ba9f0..b3678a82b 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -67,7 +67,7 @@ def extract_stream_variable_selector_from_node_data( and node_type == NodeType.LLM.value and variable_selector.value_selector[1] == "text" ): - value_selectors.append(variable_selector.value_selector) + value_selectors.append(list(variable_selector.value_selector)) return value_selectors @@ -119,8 +119,7 @@ def _recursive_fetch_end_dependencies( current_node_id: str, end_node_id: str, node_id_config_mapping: dict[str, dict], - reverse_edge_mapping: dict[str, list["GraphEdge"]], - # type: ignore[name-defined] + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] end_dependencies: dict[str, list[str]], ) -> None: """ diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py index 1aecf863a..a770eb951 100644 --- a/api/core/workflow/nodes/end/end_stream_processor.py +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -23,7 +23,7 @@ def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: self.route_position[end_node_id] = 0 self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} self.has_output = False - self.output_node_ids = set() + self.output_node_ids: set[str] = set() def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: for event in generator: diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index 137b47655..9fea3fbda 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -42,6 +42,6 @@ class RunRetryEvent(BaseModel): class SingleStepRetryEvent(NodeRunResult): """Single step retry event""" - status: str = WorkflowNodeExecutionStatus.RETRY.value + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY elapsed_time: float = Field(..., description="elapsed time") diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 5e39ef79d..5764ce725 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -68,7 +68,22 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. 
""" - method: Literal["get", "post", "put", "patch", "delete", "head"] + method: Literal[ + "get", + "post", + "put", + "patch", + "delete", + "head", + "options", + "GET", + "POST", + "PUT", + "PATCH", + "DELETE", + "HEAD", + "OPTIONS", + ] url: str authorization: HttpRequestNodeAuthorization headers: str diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py index a815f277b..46613c9e8 100644 --- a/api/core/workflow/nodes/http_request/exc.py +++ b/api/core/workflow/nodes/http_request/exc.py @@ -20,3 +20,7 @@ class ResponseSizeError(HttpRequestNodeError): class RequestBodyError(HttpRequestNodeError): """Raised when the request body is invalid.""" + + +class InvalidURLError(HttpRequestNodeError): + """Raised when the URL is invalid.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 575db15d3..87b71394e 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -23,6 +23,7 @@ FileFetchError, HttpRequestNodeError, InvalidHttpMethodError, + InvalidURLError, RequestBodyError, ResponseSizeError, ) @@ -36,7 +37,22 @@ class Executor: - method: Literal["get", "head", "post", "put", "delete", "patch"] + method: Literal[ + "get", + "head", + "post", + "put", + "delete", + "patch", + "options", + "GET", + "POST", + "PUT", + "PATCH", + "DELETE", + "HEAD", + "OPTIONS", + ] url: str params: list[tuple[str, str]] | None content: str | bytes | None @@ -92,6 +108,12 @@ def _initialize(self): def _init_url(self): self.url = self.variable_pool.convert_template(self.node_data.url).text + # check if url is a valid URL + if not self.url: + raise InvalidURLError("url is required") + if not self.url.startswith(("http://", "https://")): + raise InvalidURLError("url should start with http:// or https://") + def _init_params(self): """ Almost same as _init_headers(), difference: @@ -107,9 +129,9 @@ def _init_params(self): if not (key := key.strip()): continue - value = value[0].strip() if value else "" + value_str = value[0].strip() if value else "" result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value).text) + (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) ) self.params = result @@ -151,7 +173,10 @@ def _init_body(self): if len(data) != 1: raise RequestBodyError("json body type should have exactly one item") json_string = self.variable_pool.convert_template(data[0].value).text - json_object = json.loads(json_string, strict=False) + try: + json_object = json.loads(json_string, strict=False) + except json.JSONDecodeError as e: + raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e self.json = json_object # self.json = self._parse_object_contains_variables(json_object) case "binary": @@ -182,9 +207,10 @@ def _init_body(self): self.variable_pool.convert_template(item.key).text: item.file for item in filter(lambda item: item.type == "file", data) } + files: dict[str, Any] = {} files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} files = {k: v for k, v in files.items() if v is not None} - files = {k: variable.value for k, variable in files.items()} + files = {k: variable.value for k, variable in files.items() if variable is not None} files = { k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") for k, v in files.items() @@ -238,7 +264,22 @@ def 
_do_http_request(self, headers: dict[str, Any]) -> httpx.Response: """ do http request depending on api bundle """ - if self.method not in {"get", "head", "post", "put", "delete", "patch"}: + if self.method not in { + "get", + "head", + "post", + "put", + "delete", + "patch", + "options", + "GET", + "POST", + "PUT", + "PATCH", + "DELETE", + "HEAD", + "OPTIONS", + }: raise InvalidHttpMethodError(f"Invalid http method {self.method}") request_args = { @@ -255,10 +296,11 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: } # request_args = {k: v for k, v in request_args.items() if v is not None} try: - response = getattr(ssrf_proxy, self.method)(**request_args) + response = getattr(ssrf_proxy, self.method.lower())(**request_args) except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e: raise HttpRequestNodeError(str(e)) - return response + # FIXME: fix type ignore, this maybe httpx type issue + return response # type: ignore def invoke(self) -> Response: # assemble headers @@ -300,37 +342,37 @@ def to_log(self): continue raw += f"{k}: {v}\r\n" - body = "" + body_string = "" if self.files: for k, v in self.files.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' - body += f"{v[1]}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body_string += f"{v[1]}\r\n" + body_string += f"--{boundary}--\r\n" elif self.node_data.body: if self.content: if isinstance(self.content, str): - body = self.content + body_string = self.content elif isinstance(self.content, bytes): - body = self.content.decode("utf-8", errors="replace") + body_string = self.content.decode("utf-8", errors="replace") elif self.data and self.node_data.body.type == "x-www-form-urlencoded": - body = urlencode(self.data) + body_string = urlencode(self.data) elif self.data and self.node_data.body.type == "form-data": for key, value in self.data.items(): - body += f"--{boundary}\r\n" - body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - body += f"{value}\r\n" - body += f"--{boundary}--\r\n" + body_string += f"--{boundary}\r\n" + body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body_string += f"{value}\r\n" + body_string += f"--{boundary}--\r\n" elif self.json: - body = json.dumps(self.json) + body_string = json.dumps(self.json) elif self.node_data.body.type == "raw-text": if len(self.node_data.body.data) != 1: raise RequestBodyError("raw-text body type should have exactly one item") - body = self.node_data.body.data[0].value - if body: - raw += f"Content-Length: {len(body)}\r\n" + body_string = self.node_data.body.data[0].value + if body_string: + raw += f"Content-Length: {len(body_string)}\r\n" raw += "\r\n" # Empty line between headers and body - raw += body + raw += body_string return raw diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index ebed690f6..861119f26 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,7 +1,7 @@ import logging import mimetypes from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, Optional from configs import dify_config from core.file import File, FileTransferMethod @@ -36,7 +36,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): _node_type = NodeType.HTTP_REQUEST @classmethod - def get_default_config(cls, filters: dict | None 
= None) -> dict: + def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { "type": "http-request", "config": { @@ -160,8 +160,8 @@ def _extract_variable_selector_to_variable_mapping( ) mapping = {} - for selector in selectors: - mapping[node_id + "." + selector.variable] = selector.value_selector + for selector_iter in selectors: + mapping[node_id + "." + selector_iter.variable] = selector_iter.value_selector return mapping diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 6a89cbfad..f1289558f 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -361,13 +361,16 @@ def _handle_event_metadata( metadata = event.route_node_state.node_run_result.metadata if not metadata: metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - if self.node_data.is_parallel: - metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id - else: - metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + metadata = { + **metadata, + NodeRunMetadataKey.ITERATION_ID: self.node_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID + if self.node_data.is_parallel + else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id + if self.node_data.is_parallel + else iter_run_index, + } event.route_node_state.node_run_result.metadata = metadata return event diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 4f9e415f4..0f239af51 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -11,6 +11,7 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment @@ -18,7 +19,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus from .entities import KnowledgeRetrievalNodeData @@ -147,6 +148,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: + if node_data.multiple_retrieval_config is None: + raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": if node_data.multiple_retrieval_config.reranking_model: reranking_model = { @@ -157,6 +160,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: reranking_model = None weights = None elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": + if node_data.multiple_retrieval_config.weights is None: + raise ValueError("weights is required") reranking_model = None 
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting weights = { @@ -180,7 +185,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: available_datasets=available_datasets, query=query, top_k=node_data.multiple_retrieval_config.top_k, - score_threshold=node_data.multiple_retrieval_config.score_threshold, + score_threshold=node_data.multiple_retrieval_config.score_threshold + if node_data.multiple_retrieval_config.score_threshold is not None + else 0.0, reranking_mode=node_data.multiple_retrieval_config.reranking_mode, reranking_model=reranking_model, weights=weights, @@ -205,29 +212,12 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} # deal with dify documents if dify_documents: - document_score_list = {} - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = Document.query.filter( Document.id == segment.document_id, @@ -245,7 +235,7 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: "document_data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": "workflow", - "score": document_score_list.get(segment.index_node_id, None), + "score": record.score or 0.0, "segment_hit_count": segment.hit_count, "segment_word_count": segment.word_count, "segment_position": segment.position, @@ -260,12 +250,12 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: retrieval_resource_list.append(source) if retrieval_resource_list: retrieval_resource_list = sorted( - retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True + retrieval_resource_list, + key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + reverse=True, ) - position = 1 - for item in retrieval_resource_list: + for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position - position += 1 return retrieval_resource_list @classmethod @@ -295,6 +285,8 @@ def _fetch_model_config( :param node_data: node data :return: """ + if node_data.single_retrieval_config is None: + raise ValueError("single_retrieval_config is required") model_name = node_data.single_retrieval_config.model.name provider_name = node_data.single_retrieval_config.model.provider diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 79066cece..432c57294 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ 
b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal, Union +from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -17,9 +17,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): _node_type = NodeType.LIST_OPERATOR def _run(self): - inputs = {} - process_data = {} - outputs = {} + inputs: dict[str, list] = {} + process_data: dict[str, list] = {} + outputs: dict[str, Any] = {} variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) if variable is None: @@ -93,6 +93,8 @@ def _run(self): def _apply_filter( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + filter_func: Callable[[Any], bool] + result: list[Any] = [] for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): @@ -236,6 +238,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + extract_func: Callable[[File], Any] if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): extract_func = _get_file_extract_string_func(key=key) return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) @@ -249,47 +252,47 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str raise InvalidKeyError(f"Invalid key: {key}") -def _contains(value: str): +def _contains(value: str) -> Callable[[str], bool]: return lambda x: value in x -def _startswith(value: str): +def _startswith(value: str) -> Callable[[str], bool]: return lambda x: x.startswith(value) -def _endswith(value: str): +def _endswith(value: str) -> Callable[[str], bool]: return lambda x: x.endswith(value) -def _is(value: str): +def _is(value: str) -> Callable[[str], bool]: return lambda x: x is value -def _in(value: str | Sequence[str]): +def _in(value: str | Sequence[str]) -> Callable[[str], bool]: return lambda x: x in value -def _eq(value: int | float): +def _eq(value: int | float) -> Callable[[int | float], bool]: return lambda x: x == value -def _ne(value: int | float): +def _ne(value: int | float) -> Callable[[int | float], bool]: return lambda x: x != value -def _lt(value: int | float): +def _lt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x < value -def _le(value: int | float): +def _le(value: int | float) -> Callable[[int | float], bool]: return lambda x: x <= value -def _gt(value: int | float): +def _gt(value: int | float) -> Callable[[int | float], bool]: return lambda x: x > value -def _ge(value: int | float): +def _ge(value: int | float) -> Callable[[int | float], bool]: return lambda x: x >= value @@ -302,6 +305,7 @@ def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + extract_func: Callable[[File], Any] if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: extract_func = _get_file_extract_string_func(key=order_by) return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 
55fac4557..6909b30c9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -88,8 +88,8 @@ class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM - def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: - node_inputs = None + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + node_inputs: Optional[dict[str, Any]] = None process_data = None try: @@ -196,7 +196,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] error_type=type(e).__name__, ) ) - return except Exception as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -206,7 +205,6 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] process_data=process_data, ) ) - return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -302,7 +300,7 @@ def _transform_chat_messages( return messages def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables = {} + variables: dict[str, Any] = {} if not node_data.prompt_config: return variables @@ -319,7 +317,7 @@ def parse_dict(input_dict: Mapping[str, Any]) -> str: """ # check if it's a context structure if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return input_dict["content"] + return str(input_dict["content"]) # else, parse the dict try: @@ -557,7 +555,8 @@ def _fetch_prompt_messages( variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: - prompt_messages = [] + # FIXME: fix the type error caused by prompt_messages being reassigned with different types quite a few times + prompt_messages: list[Any] = [] if isinstance(prompt_template, list): # For chat model @@ -783,7 +782,7 @@ def _extract_variable_selector_to_variable_mapping( else: raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - variable_mapping = {} + variable_mapping: dict[str, Any] = {} for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector @@ -981,7 +980,7 @@ def _handle_memory_chat_mode( memory_config: MemoryConfig | None, model_config: ModelConfigWithCredentialsEntity, ) -> Sequence[PromptMessage]: - memory_messages = [] + memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model if memory and memory_config: rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 6fdff9660..a366c287c 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -14,8 +14,8 @@ class LoopNode(BaseNode[LoopNodeData]): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self) -> LoopState: - return super()._run() + def _run(self) -> LoopState: # type: ignore + return super()._run() # type: ignore @classmethod def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: @@ -28,7 +28,7 @@ def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: # TODO waiting for implementation return [ - Condition( + Condition( # type: ignore variable_selector=[node_id, "index"], comparison_operator="≤", value_type="value_selector", diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py 
index a001b44dc..369eb13b0 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -25,7 +25,7 @@ def validate_name(cls, value) -> str: raise ValueError("Parameter name is required") if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return value + return str(value) class ParameterExtractorNodeData(BaseNodeData): @@ -52,7 +52,7 @@ def get_parameter_json_schema(self) -> dict: :return: parameter json schema """ - parameters = {"type": "object", "properties": {}, "required": []} + parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: parameter_schema: dict[str, Any] = {"description": parameter.description} diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index c8c854a43..9c88047f2 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -63,7 +63,8 @@ class ParameterExtractorNode(LLMNode): Parameter Extractor Node. """ - _node_data_cls = ParameterExtractorNodeData + # FIXME: figure out why this differs from the super class + _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR _model_instance: Optional[ModelInstance] = None @@ -253,6 +254,9 @@ def _invoke( # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + if text is None: + text = "" + return text, usage, tool_call def _generate_function_call_prompt( @@ -605,9 +609,10 @@ def extract_json(text): json_str = extract_json(result[idx:]) if json_str: try: - return json.loads(json_str) + return cast(dict, json.loads(json_str)) except Exception: pass + return None def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: """ @@ -616,13 +621,13 @@ def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCal if not tool_call or not tool_call.function.arguments: return None - return json.loads(tool_call.function.arguments) + return cast(dict, json.loads(tool_call.function.arguments)) def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ Generate default result. 
""" - result = {} + result: dict[str, Any] = {} for parameter in data.parameters: if parameter.type == "number": result[parameter.name] = 0 @@ -772,7 +777,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, + node_data: ParameterExtractorNodeData, # type: ignore ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -781,6 +786,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + # FIXME: fix the type error later variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index e603add17..6c3155ac9 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,3 +1,5 @@ +from typing import Any + FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. @@ -35,7 +37,7 @@ """ # noqa: E501 -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ +FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ { "user": { "query": "What is the weather today in SF?", diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 31f8368d5..0ec44eefa 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -34,12 +34,9 @@ QUESTION_CLASSIFIER_USER_PROMPT_3, ) -if TYPE_CHECKING: - from core.file import File - class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData + _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER def _run(self): @@ -61,7 +58,7 @@ def _run(self): node_data.instruction = node_data.instruction or "" node_data.instruction = variable_pool.convert_template(node_data.instruction).text - files: Sequence[File] = ( + files = ( self._fetch_files( selector=node_data.vision.configs.variable_selector, ) @@ -168,7 +165,7 @@ def _extract_variable_selector_to_variable_mapping( *, graph_config: Mapping[str, Any], node_id: str, - node_data: QuestionClassifierNodeData, + node_data: Any, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -177,6 +174,7 @@ def _extract_variable_selector_to_variable_mapping( :param node_data: node data :return: """ + node_data = cast(QuestionClassifierNodeData, node_data) variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 983fa7e62..01d07e494 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -9,7 +9,6 @@ from core.file 
import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine -from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -46,6 +45,8 @@ def _run(self) -> NodeRunResult: # get tool runtime try: + from core.tools.tool_manager import ToolManager + tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) @@ -142,7 +143,7 @@ def _generate_parameters( """ tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - result = {} + result: dict[str, Any] = {} for parameter_name in node_data.tool_parameters: parameter = tool_parameters_dictionary.get(parameter_name) if not parameter: @@ -264,9 +265,9 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> """ return "\n".join( [ - f"{message.message}" + str(message.message) if message.type == ToolInvokeMessage.MessageType.TEXT - else f"Link: {message.message}" + else f"Link: {str(message.message)}" for message in tool_response if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK} ] diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 8eb4bd5c2..9acc76f32 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -36,6 +36,8 @@ def _run(self) -> NodeRunResult: case WriteMode.CLEAR: income_value = get_zero_value(original_variable.value_type) + if income_value is None: + raise VariableOperatorNodeError("income value not found") updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index d73c74420..afa5656f4 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,5 +1,6 @@ import json -from typing import Any +from collections.abc import Sequence +from typing import Any, cast from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID @@ -29,9 +30,9 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): def _run(self) -> NodeRunResult: inputs = self.node_data.model_dump() - process_data = {} + process_data: dict[str, Any] = {} # NOTE: This node has no outputs - updated_variables: list[Variable] = [] + updated_variable_selectors: list[Sequence[str]] = [] try: for item in self.node_data.items: @@ -98,7 +99,8 @@ def _run(self) -> NodeRunResult: value=item.value, ) variable = variable.model_copy(update={"value": updated_value}) - updated_variables.append(variable) + self.graph_runtime_state.variable_pool.add(variable.selector, variable) + updated_variable_selectors.append(variable.selector) except VariableOperatorNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -107,9 +109,15 @@ def _run(self) -> NodeRunResult: error=str(e), ) + # The `updated_variable_selectors` is a list of list[str], which is not hashable; + # remove the duplicated items first. 
+ updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + # Update variables - for variable in updated_variables: - self.graph_runtime_state.variable_pool.add(variable.selector, variable) + for selector in updated_variable_selectors: + variable = self.graph_runtime_state.variable_pool.get(selector) + if not isinstance(variable, Variable): + raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: @@ -119,7 +127,7 @@ def _run(self) -> NodeRunResult: else: conversation_id = conversation_id.value common_helpers.update_conversation_variable( - conversation_id=conversation_id, + conversation_id=cast(str, conversation_id), variable=variable, ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 811e40c11..f622d0b2d 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -129,11 +129,11 @@ def single_step_run( :return: """ # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: + workflow_graph = workflow.graph_dict + if not workflow_graph: raise ValueError("workflow graph not found") - nodes = graph.get("nodes") + nodes = workflow_graph.get("nodes") if not nodes: raise ValueError("nodes not found in workflow graph") @@ -196,7 +196,8 @@ def single_step_run( @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: - return WorkflowEntry._handle_special_values(value) + result = WorkflowEntry._handle_special_values(value) + return result if isinstance(result, Mapping) or result is None else dict(result) @staticmethod def _handle_special_values(value: Any) -> Any: @@ -208,10 +209,10 @@ def _handle_special_values(value: Any) -> Any: res[k] = WorkflowEntry._handle_special_values(v) return res if isinstance(value, list): - res = [] + res_list = [] for item in value: - res.append(WorkflowEntry._handle_special_values(item)) - return res + res_list.append(WorkflowEntry._handle_special_values(item)) + return res_list if isinstance(value, File): return value.to_dict() return value @@ -238,6 +239,10 @@ def mapping_user_inputs_to_variable_pool( ): raise ValueError(f"Variable key {node_variable} not found in user inputs.") + # environment variable already exist in variable pool, not from user inputs + if variable_pool.get(variable_selector): + continue + # fetch variable node id from variable selector variable_node_id = variable_selector[0] variable_key_list = variable_selector[1:] diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 5b07659ca..23f42c504 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -33,6 +33,7 @@ else --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ --timeout ${GUNICORN_TIMEOUT:-200} \ app:app fi diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 24fa01369..8a677f6b6 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -14,7 +14,7 @@ @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get("document_ids") + document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() for document_id in 
document_ids: diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 1515661b2..5e7caf8cb 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -8,18 +8,19 @@ def handle(sender, **kwargs): """Create site record when an app is created.""" app = sender account = kwargs.get("account") - site = Site( - app_id=app.id, - title=app.name, - icon_type=app.icon_type, - icon=app.icon, - icon_background=app.icon_background, - default_language=account.interface_language, - customize_token_strategy="not_allow", - code=Site.generate_code(16), - created_by=app.created_by, - updated_by=app.updated_by, - ) + if account is not None: + site = Site( + app_id=app.id, + title=app.name, + icon_type=app.icon_type, + icon=app.icon, + icon_background=app.icon_background, + default_language=account.interface_language, + customize_token_strategy="not_allow", + code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, + ) - db.session.add(site) - db.session.commit() + db.session.add(site) + db.session.commit() diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index 843a23209..1ed37efba 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -44,7 +44,7 @@ def handle(sender, **kwargs): else: used_quota = 1 - if used_quota is not None: + if used_quota is not None and system_configuration.current_quota_type is not None: db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == model_config.provider, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 9c5955c8c..f89fae24a 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -8,7 +8,10 @@ @app_draft_workflow_was_synced.connect def handle(sender, **kwargs): app = sender - for node_data in kwargs.get("synced_draft_workflow").graph_dict.get("nodes", []): + synced_draft_workflow = kwargs.get("synced_draft_workflow") + if synced_draft_workflow is None: + return + for node_data in synced_draft_workflow.graph_dict.get("nodes", []): if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: tool_entity = ToolEntity(**node_data["data"]) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index de7c0f4df..14396e992 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -8,16 +8,18 @@ def handle(sender, **kwargs): app = sender app_model_config = kwargs.get("app_model_config") + if app_model_config is None: + return dataset_ids = get_dataset_ids_from_model_config(app_model_config) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[str] = set() if not app_dataset_joins: 
added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[str] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -37,8 +39,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: - dataset_ids = set() +def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[str]: + dataset_ids: set[str] = set() if not app_model_config: return dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 453395e8d..dd2efed94 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -17,11 +17,11 @@ def handle(sender, **kwargs): dataset_ids = get_dataset_ids_from_workflow(published_workflow) app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() - removed_dataset_ids = [] + removed_dataset_ids: set[str] = set() if not app_dataset_joins: added_dataset_ids = dataset_ids else: - old_dataset_ids = set() + old_dataset_ids: set[str] = set() old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids @@ -41,8 +41,8 @@ def handle(sender, **kwargs): db.session.commit() -def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: - dataset_ids = set() +def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: + dataset_ids: set[str] = set() graph = published_workflow.graph_dict if not graph: return dataset_ids @@ -60,7 +60,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set: for node in knowledge_retrieval_nodes: try: node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) - dataset_ids.update(node_data.dataset_ids) + dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids) except Exception as e: continue diff --git a/api/extensions/__init__.py b/api/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index de1cdfeb9..b7d412d68 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -54,12 +54,14 @@ def pool_stat(): from extensions.ext_database import db engine = db.engine + # TODO: Fix the type error + # FIXME maybe its sqlalchemy issue return { "pid": os.getpid(), - "pool_size": engine.pool.size(), - "checked_in_connections": engine.pool.checkedin(), - "checked_out_connections": engine.pool.checkedout(), - "overflow_connections": engine.pool.overflow(), - "connection_timeout": engine.pool.timeout(), - "recycle_time": db.engine.pool._recycle, + "pool_size": engine.pool.size(), # type: ignore + "checked_in_connections": engine.pool.checkedin(), # type: ignore + "checked_out_connections": engine.pool.checkedout(), # type: ignore + "overflow_connections": engine.pool.overflow(), # type: ignore + "connection_timeout": engine.pool.timeout(), # type: ignore + "recycle_time": db.engine.pool._recycle, # type: ignore } diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index fcd1547a2..316be12f5 100644 --- a/api/extensions/ext_blueprints.py +++ 
b/api/extensions/ext_blueprints.py @@ -5,7 +5,7 @@ def init_app(app: DifyApp): # register blueprint routers - from flask_cors import CORS + from flask_cors import CORS # type: ignore from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 8dbadddb1..dc473c64d 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,8 +1,8 @@ from datetime import timedelta import pytz -from celery import Celery, Task -from celery.schedules import crontab +from celery import Celery, Task # type: ignore +from celery.schedules import crontab # type: ignore from configs import dify_config from dify_app import DifyApp @@ -47,7 +47,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: worker_log_format=dify_config.LOG_FORMAT, worker_task_log_format=dify_config.LOG_FORMAT, worker_hijack_root_logger=False, - timezone=pytz.timezone(dify_config.LOG_TZ), + timezone=pytz.timezone(dify_config.LOG_TZ or "UTC"), ) if dify_config.BROKER_USE_SSL: @@ -69,6 +69,7 @@ def __call__(self, *args: object, **kwargs: object) -> object: "schedule.create_tidb_serverless_task", "schedule.update_tidb_serverless_status_task", "schedule.clean_messages", + "schedule.mail_clean_document_notify_task", "schedule.update_account_used_quota_extend", # custom extension: reset account quota monthly "schedule.update_api_token_daily_used_quota_task_extend", # custom extension: reset API key daily quota "schedule.update_api_token_monthly_used_quota_task_extend", # custom extension: reset API key monthly quota @@ -96,6 +97,11 @@ def __call__(self, *args: object, **kwargs: object) -> object: "task": "schedule.clean_messages.clean_messages", "schedule": timedelta(days=day), }, + # every Monday + "mail_clean_document_notify_task": { + "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", + "schedule": crontab(minute="0", hour="10", day_of_week="1"), + }, # ---------------------------- Custom extensions Begin ---------------------------- # At 00:00 on the 1st of each month, reset the account quota "update_account_used_quota": { diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 9c3a663af..26ff6427b 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -7,7 +7,7 @@ def is_enabled() -> bool: def init_app(app: DifyApp): - from flask_compress import Compress + from flask_compress import Compress # type: ignore compress = Compress() compress.init_app(app) diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 9fc29b4eb..1b9e78828 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -11,7 +11,7 @@ def init_app(app: DifyApp): - log_handlers = [] + log_handlers: list[logging.Handler] = [] log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) @@ -46,10 +46,11 @@ def init_app(app: DifyApp): timezone = pytz.timezone(log_tz) def time_converter(seconds): - return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + return datetime.fromtimestamp(seconds, tz=timezone).timetuple() for handler in logging.root.handlers: - handler.formatter.converter = time_converter + if handler.formatter: + handler.formatter.converter = time_converter def get_request_id(): diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index b29553071..10fb89eb7 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,6 +1,6 @@ import json -import flask_login +import flask_login # type: ignore from flask import Response, request from flask_login import 
user_loaded_from_request, user_logged_in from werkzeug.exceptions import Unauthorized diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 468aedd47..9240ebe7f 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -26,7 +26,7 @@ def init_app(self, app: Flask): match mail_type: case "resend": - import resend + import resend # type: ignore api_key = dify_config.RESEND_API_KEY if not api_key: @@ -48,9 +48,9 @@ def init_app(self, app: Flask): self._client = SMTPClient( server=dify_config.SMTP_SERVER, port=dify_config.SMTP_PORT, - username=dify_config.SMTP_USERNAME, - password=dify_config.SMTP_PASSWORD, - _from=dify_config.MAIL_DEFAULT_SEND_FROM, + username=dify_config.SMTP_USERNAME or "", + password=dify_config.SMTP_PASSWORD or "", + _from=dify_config.MAIL_DEFAULT_SEND_FROM or "", use_tls=dify_config.SMTP_USE_TLS, opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index 6d8f35c30..5f862181f 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -2,7 +2,7 @@ def init_app(app: DifyApp): - import flask_migrate + import flask_migrate # type: ignore from extensions.ext_database import db diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index 3b895ac95..514e06582 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -6,4 +6,4 @@ def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix - app.wsgi_app = ProxyFix(app.wsgi_app) + app.wsgi_app = ProxyFix(app.wsgi_app) # type: ignore diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 8016356a3..3a74aace6 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -6,7 +6,7 @@ def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import openai import sentry_sdk - from langfuse import parse_error + from langfuse import parse_error # type: ignore from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException @@ -27,6 +27,7 @@ def before_send(event, hint): ignore_errors=[ HTTPException, ValueError, + FileNotFoundError, openai.APIStatusError, InvokeRateLimitError, parse_error.defaultErrorResponse, diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 42422263c..588bdb2d2 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable, Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask @@ -79,6 +79,12 @@ def save(self, filename, data): logger.exception(f"Failed to save file {filename}") raise e + @overload + def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... + + @overload + def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... 
+ def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: try: if stream: diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 58c917dbd..00bf5d4f9 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,7 +1,7 @@ import posixpath from collections.abc import Generator -import oss2 as aliyun_s3 +import oss2 as aliyun_s3 # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -33,7 +33,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: obj = self.client.get_object(self.__wrapper_folder_filename(filename)) - data = obj.read() + data: bytes = obj.read() return data def load_stream(self, filename: str) -> Generator: @@ -41,14 +41,14 @@ def load_stream(self, filename: str) -> Generator: while chunk := obj.read(4096): yield chunk - def download(self, filename, target_filepath): + def download(self, filename: str, target_filepath): self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) - def exists(self, filename): + def exists(self, filename: str): return self.client.object_exists(self.__wrapper_folder_filename(filename)) - def delete(self, filename): + def delete(self, filename: str): self.client.delete_object(self.__wrapper_folder_filename(filename)) - def __wrapper_folder_filename(self, filename) -> str: + def __wrapper_folder_filename(self, filename: str) -> str: return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index ce36c2e7d..7b6b2eedd 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,9 +1,9 @@ import logging from collections.abc import Generator -import boto3 -from botocore.client import Config -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.client import Config # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -53,7 +53,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index b26caa867..2f8532f4f 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -27,7 +27,7 @@ def load_once(self, filename: str) -> bytes: client = self._sync_client() blob = client.get_container_client(container=self.bucket_name) blob = blob.get_blob_client(blob=filename) - data = blob.download_blob().readall() + data: bytes = blob.download_blob().readall() return data def load_stream(self, filename: str) -> Generator: @@ -63,11 +63,11 @@ def _sync_client(self): sas_token = cache_result.decode("utf-8") else: sas_token = generate_account_sas( - account_name=self.account_name, - account_key=self.account_key, + account_name=self.account_name or "", + account_key=self.account_key or "", resource_types=ResourceTypes(service=True, container=True, 
object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) - return BlobServiceClient(account_url=self.account_url, credential=sas_token) + return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py index e0d2140e9..b94efa08b 100644 --- a/api/extensions/storage/baidu_obs_storage.py +++ b/api/extensions/storage/baidu_obs_storage.py @@ -2,9 +2,9 @@ import hashlib from collections.abc import Generator -from baidubce.auth.bce_credentials import BceCredentials -from baidubce.bce_client_configuration import BceClientConfiguration -from baidubce.services.bos.bos_client import BosClient +from baidubce.auth.bce_credentials import BceCredentials # type: ignore +from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore +from baidubce.services.bos.bos_client import BosClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -36,7 +36,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: response = self.client.get_object(bucket_name=self.bucket_name, key=filename) - return response.data.read() + data: bytes = response.data.read() + return data def load_stream(self, filename: str) -> Generator: response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index 26b662d2f..705639f42 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -3,7 +3,7 @@ import json from collections.abc import Generator -from google.cloud import storage as google_cloud_storage +from google.cloud import storage as google_cloud_storage # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -35,7 +35,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - data = blob.download_as_bytes() + data: bytes = blob.download_as_bytes() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py index 20be70ef8..07f1d1997 100644 --- a/api/extensions/storage/huawei_obs_storage.py +++ b/api/extensions/storage/huawei_obs_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from obs import ObsClient +from obs import ObsClient # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -23,7 +23,7 @@ def save(self, filename, data): self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) def load_once(self, filename: str) -> bytes: - data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + data: bytes = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index e671eff05..b78fc94da 100644 --- a/api/extensions/storage/opendal_storage.py +++ 
b/api/extensions/storage/opendal_storage.py @@ -3,7 +3,7 @@ from collections.abc import Generator from pathlib import Path -import opendal +import opendal # type: ignore[import] from dotenv import dotenv_values from extensions.storage.base_storage import BaseStorage @@ -18,7 +18,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str if key.startswith(config_prefix): kwargs[key[len(config_prefix) :].lower()] = value - file_env_vars = dotenv_values(env_file_path) + file_env_vars: dict = dotenv_values(env_file_path) or {} for key, value in file_env_vars.items(): if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: kwargs[key[len(config_prefix) :].lower()] = value @@ -48,7 +48,7 @@ def load_once(self, filename: str) -> bytes: if not self.exists(filename): raise FileNotFoundError("File not found") - content = self.op.read(path=filename) + content: bytes = self.op.read(path=filename) logger.debug(f"file {filename} loaded") return content @@ -75,7 +75,7 @@ def exists(self, filename: str) -> bool: # error handler here when opendal python-binding has a exists method, we should use it # more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs try: - res = self.op.stat(path=filename).mode.is_file() + res: bool = self.op.stat(path=filename).mode.is_file() logger.debug(f"file {filename} checked") return res except Exception: diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index b59f83b8d..82829f7fd 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,7 +1,7 @@ from collections.abc import Generator -import boto3 -from botocore.exceptions import ClientError +import boto3 # type: ignore +from botocore.exceptions import ClientError # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -27,7 +27,7 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: try: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py index 9f7c69a9a..711c3f721 100644 --- a/api/extensions/storage/supabase_storage.py +++ b/api/extensions/storage/supabase_storage.py @@ -32,7 +32,7 @@ def save(self, filename, data): self.client.storage.from_(self.bucket_name).upload(filename, data) def load_once(self, filename: str) -> bytes: - content = self.client.storage.from_(self.bucket_name).download(filename) + content: bytes = self.client.storage.from_(self.bucket_name).download(filename) return content def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py index 13a6c9239..9cdd3e67f 100644 --- a/api/extensions/storage/tencent_cos_storage.py +++ b/api/extensions/storage/tencent_cos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -from qcloud_cos import CosConfig, CosS3Client +from qcloud_cos import CosConfig, CosS3Client # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -25,7 +25,7 @@ def save(self, filename, data): 
self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() + data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() return data def load_stream(self, filename: str) -> Generator: diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py index de82be04e..55fe6545e 100644 --- a/api/extensions/storage/volcengine_tos_storage.py +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -1,6 +1,6 @@ from collections.abc import Generator -import tos +import tos # type: ignore from configs import dify_config from extensions.storage.base_storage import BaseStorage @@ -24,6 +24,8 @@ def save(self, filename, data): def load_once(self, filename: str) -> bytes: data = self.client.get_object(bucket=self.bucket_name, key=filename).read() + if not isinstance(data, bytes): + raise TypeError("Expected bytes, got {}".format(type(data).__name__)) return data def load_stream(self, filename: str) -> Generator: diff --git a/api/factories/__init__.py b/api/factories/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 13034f5cf..c6dc748e9 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,4 +1,5 @@ import mimetypes +import uuid from collections.abc import Callable, Mapping, Sequence from typing import Any, cast @@ -64,7 +65,7 @@ def build_from_mapping( if not build_func: raise ValueError(f"Invalid file transfer method: {transfer_method}") - file = build_func( + file: File = build_func( mapping=mapping, tenant_id=tenant_id, transfer_method=transfer_method, @@ -72,7 +73,7 @@ def build_from_mapping( if config and not _is_file_valid_with_config( input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension, + file_extension=file.extension or "", file_transfer_method=file.transfer_method, config=config, ): @@ -119,6 +120,11 @@ def _build_from_local_file( upload_file_id = mapping.get("upload_file_id") if not upload_file_id: raise ValueError("Invalid upload file id") + # check if upload_file_id is a valid uuid + try: + uuid.UUID(upload_file_id) + except ValueError: + raise ValueError("Invalid upload file id format") stmt = select(UploadFile).where( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, @@ -152,7 +158,7 @@ def _build_from_remote_url( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: - url = mapping.get("url") + url = mapping.get("url") or mapping.get("remote_url") if not url: raise ValueError("Invalid file url") @@ -281,6 +287,7 @@ def _get_file_type_by_extension(extension: str) -> FileType | None: return FileType.AUDIO elif extension in DOCUMENT_EXTENSIONS: return FileType.DOCUMENT + return None def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 16a578728..bbca8448e 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any +from typing import Any, cast from uuid import uuid4 from configs import dify_config @@ -84,6 +84,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError("missing value type") if 
(value := mapping.get("value")) is None: raise VariableError("missing value") + # FIXME: using Any here, fix it later + result: Any match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) @@ -109,7 +111,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") if not result.selector: result = result.model_copy(update={"selector": selector}) - return result + return cast(Variable, result) def build_segment(value: Any, /) -> Segment: @@ -164,10 +166,13 @@ def segment_to_variable( raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=selector, + return cast( + Variable, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=selector, + ), ) diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py index 379dcc6d1..1c58b3a25 100644 --- a/api/fields/annotation_fields.py +++ b/api/fields/annotation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py index a85d4a34d..d40407bfc 100644 --- a/api/fields/api_based_extension_fields.py +++ b/api/fields/api_based_extension_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 8682c81ee..d10611219 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.workflow_fields import workflow_partial_fields from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 6a9e347b1..c54554a6d 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 983e50e73..c6385efb5 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 071071376..608672121 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 533e3a083..bedab5750 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField @@ -73,6 +73,7 @@ "embedding_available": 
fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), "tags": fields.List(fields.Nested(tag_fields)), + "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), } diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index a83ec7bc9..f2250d964 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.dataset_fields import dataset_fields from libs.helper import TimestampField @@ -34,6 +34,7 @@ "data_source_info": fields.Raw(attribute="data_source_info_dict"), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), "dataset_process_rule_id": fields.String, + "process_rule_dict": fields.Raw(attribute="process_rule_dict"), "name": fields.String, "created_from": fields.String, "created_by": fields.String, diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py index 99e529f9d..aefa0b275 100644 --- a/api/fields/end_user_fields.py +++ b/api/fields/end_user_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore simple_end_user_fields = { "id": fields.String, diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py index 2281460fe..9cc4e14a0 100644 --- a/api/fields/external_dataset_fields.py +++ b/api/fields/external_dataset_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index afaacc056..f896c15f0 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index f36e80f8d..b9f7e78c1 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField @@ -34,8 +34,16 @@ "document": fields.Nested(document_fields), } +child_chunk_fields = { + "id": fields.String, + "content": fields.String, + "position": fields.Integer, + "score": fields.Float, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, } diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index e0b3e340f..16f265b9b 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import AppIconUrlField, TimestampField diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 1cf8e408d..0c854c640 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 5f6e7884a..0571faab0 100644 --- 
a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.conversation_fields import message_file_fields from libs.helper import TimestampField diff --git a/api/fields/raws.py b/api/fields/raws.py index 15ec16ab1..493d4b6cc 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.file import File diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2dd4cb45b..52f89859c 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -1,7 +1,18 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from libs.helper import TimestampField +child_chunk_fields = { + "id": fields.String, + "segment_id": fields.String, + "content": fields.String, + "position": fields.Integer, + "word_count": fields.Integer, + "type": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -20,10 +31,13 @@ "status": fields.String, "created_by": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, + "updated_by": fields.String, "indexing_at": TimestampField, "completed_at": TimestampField, "error": fields.String, "stopped_at": TimestampField, + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), } segment_list_response = { diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py index 9af4fc57d..986cd725f 100644 --- a/api/fields/tag_fields.py +++ b/api/fields/tag_fields.py @@ -1,3 +1,3 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py index a53b54624..c45b33597 100644 --- a/api/fields/workflow_app_log_fields.py +++ b/api/fields/workflow_app_log_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 0d860d6f4..32f979a5f 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from core.helper import encrypter from core.variables import SecretVariable, SegmentType, Variable @@ -45,6 +45,7 @@ def format(self, value): "graph": fields.Raw(attribute="graph_dict"), "features": fields.Raw(attribute="features_dict"), "hash": fields.String(attribute="unique_hash"), + "version": fields.String(attribute="version"), "created_by": fields.Nested(simple_account_fields, attribute="created_by_account"), "created_at": TimestampField, "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True), @@ -61,3 +62,10 @@ def format(self, value): "updated_by": fields.String, "updated_at": TimestampField, } + +workflow_pagination_fields = { + "items": fields.List(fields.Nested(workflow_fields), attribute="items"), + "page": fields.Integer, + "limit": fields.Integer(attribute="limit"), + "has_more": fields.Boolean(attribute="has_more"), +} diff --git 
a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 74fdf8bd9..ef59c57ec 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,4 +1,4 @@ -from flask_restful import fields +from flask_restful import fields # type: ignore from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 179617ac0..922d2d9cd 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -1,8 +1,9 @@ import re import sys +from typing import Any from flask import current_app, got_request_exception -from flask_restful import Api, http_status_message +from flask_restful import Api, http_status_message # type: ignore from werkzeug.datastructures import Headers from werkzeug.exceptions import HTTPException @@ -84,7 +85,7 @@ def handle_error(self, e): # record the exception in the logs when we have a server error of status code: 500 if status_code and status_code >= 500: - exc_info = sys.exc_info() + exc_info: Any = sys.exc_info() if exc_info[1] is None: exc_info = None current_app.log_exception(exc_info) @@ -100,7 +101,7 @@ def handle_error(self, e): resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype) elif status_code == 400: if isinstance(data.get("message"), dict): - param_key, param_value = list(data.get("message").items())[0] + param_key, param_value = list(data.get("message", {}).items())[0] data = {"code": "invalid_param", "message": param_value, "params": param_key} else: if "code" not in data: diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index 83f9c74e3..2dae87e17 100644 --- a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -23,7 +23,7 @@ import Crypto.Hash.SHA1 import Crypto.Util.number -import gmpy2 +import gmpy2 # type: ignore from Crypto import Random from Crypto.Signature.pss import MGF1 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes @@ -191,12 +191,12 @@ def decrypt(self, ciphertext): # Step 3g one_pos = hLen + db[hLen:].find(b"\x01") lHash1 = db[:hLen] - invalid = bord(y) | int(one_pos < hLen) + invalid = bord(y) | int(one_pos < hLen) # type: ignore hash_compare = strxor(lHash1, lHash) for x in hash_compare: - invalid |= bord(x) + invalid |= bord(x) # type: ignore for x in db[hLen:one_pos]: - invalid |= bord(x) + invalid |= bord(x) # type: ignore if invalid != 0: raise ValueError("Incorrect decryption.") # Step 4 diff --git a/api/libs/helper.py b/api/libs/helper.py index 91b1d1fe1..eaa4efdb7 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -13,7 +13,7 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context -from flask_restful import fields +from flask_restful import fields # type: ignore from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator @@ -248,13 +248,13 @@ def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]] if token_data_json is None: logging.warning(f"{token_type} token {token} not found with key {key}") return None - token_data = json.loads(token_data_json) + token_data: Optional[dict[str, Any]] = json.loads(token_data_json) return token_data @classmethod def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]: key = cls._get_account_token_key(account_id, token_type) - current_token = redis_client.get(key) + 
current_token: Optional[str] = redis_client.get(key) return current_token @classmethod diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 267af611f..9ab53b629 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,6 +10,7 @@ def parse_json_markdown(json_string: str) -> dict: ends = ["```", "``", "`", "}"] end_index = -1 start_index = 0 + parsed: dict = {} for s in starts: start_index = json_string.find(s) if start_index != -1: diff --git a/api/libs/login.py b/api/libs/login.py index 0ea191a18..5395534a6 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,8 +1,9 @@ from functools import wraps +from typing import Any from flask import current_app, g, has_request_context, request -from flask_login import user_logged_in -from flask_login.config import EXEMPT_METHODS +from flask_login import user_logged_in # type: ignore +from flask_login.config import EXEMPT_METHODS # type: ignore from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy @@ -12,7 +13,7 @@ #: A proxy for the current user. If no user is logged in, this will be an #: anonymous user -current_user = LocalProxy(lambda: _get_user()) +current_user: Any = LocalProxy(lambda: _get_user()) def login_required(func): @@ -79,12 +80,12 @@ def decorated_view(*args, **kwargs): # Login admin if account: account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) - user_logged_in.send(current_app._get_current_object(), user=_get_user()) + current_app.login_manager._update_request_context_with_user(account) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: - return current_app.login_manager.unauthorized() + return current_app.login_manager.unauthorized() # type: ignore # flask 1.x compatibility # current_app.ensure_sync is only available in Flask >= 2.0 @@ -98,7 +99,7 @@ def decorated_view(*args, **kwargs): def _get_user(): if has_request_context(): if "_login_user" not in g: - current_app.login_manager._load_user() + current_app.login_manager._load_user() # type: ignore return g._login_user diff --git a/api/libs/oauth.py b/api/libs/oauth.py index f2154ad84..521f8d622 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -79,9 +79,9 @@ def get_raw_user_info(self, token: str): email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email["primary"] == True), None) + primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) - return {**user_info, "email": primary_email["email"]} + return {**user_info, "email": primary_email.get("email", "")} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: email = raw_info.get("email") @@ -132,7 +132,7 @@ def get_raw_user_info(self, token: str): return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) + return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) class OaOAuth(OAuth): diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 1d39abd8f..a5ba08d35 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,8 +1,9 @@ import datetime import urllib.parse +from typing 
import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from extensions.ext_database import db from models.source import DataSourceOauthBinding @@ -226,7 +227,7 @@ def notion_page_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "page", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } @@ -254,7 +255,8 @@ def notion_block_parent_page_id(self, access_token: str, block_id: str): response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() if response.status_code != 200: - raise ValueError(f"Error fetching block parent page ID: {response_json.message}") + message = response_json.get("message", "unknown error") + raise ValueError(f"Error fetching block parent page ID: {message}") parent = response_json["parent"] parent_type = parent["type"] if parent_type == "block_id": @@ -281,7 +283,7 @@ def notion_database_search(self, access_token: str): has_more = True while has_more: - data = { + data: dict[str, Any] = { "filter": {"value": "database", "property": "object"}, **({"start_cursor": next_cursor} if next_cursor else {}), } diff --git a/api/libs/threadings_utils.py b/api/libs/threadings_utils.py deleted file mode 100644 index d356def41..000000000 --- a/api/libs/threadings_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -from configs import dify_config - - -def apply_gevent_threading_patch(): - """ - Run threading patch by gevent - to make standard library threading compatible. - Patching should be done as early as possible in the lifecycle of the program. - :return: - """ - if not dify_config.DEBUG: - from gevent import monkey - from grpc.experimental import gevent as grpc_gevent - - # gevent - monkey.patch_all() - - # grpc gevent - grpc_gevent.init_gevent() diff --git a/api/libs/version_utils.py b/api/libs/version_utils.py deleted file mode 100644 index 10edf8a05..000000000 --- a/api/libs/version_utils.py +++ /dev/null @@ -1,12 +0,0 @@ -import sys - - -def check_supported_python_version(): - python_version = sys.version_info - if not ((3, 11) <= python_version < (3, 13)): - print( - "Aborted to launch the service " - f" with unsupported Python version {python_version.major}.{python_version.minor}." - " Please ensure Python 3.11 or 3.12." - ) - raise SystemExit(1) diff --git a/api/migrations/README b/api/migrations/README index 220678df7..0e0484415 100644 --- a/api/migrations/README +++ b/api/migrations/README @@ -1,2 +1 @@ Single-database configuration for Flask. - diff --git a/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py new file mode 100644 index 000000000..9238e5a0a --- /dev/null +++ b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py @@ -0,0 +1,55 @@ +"""parent-child-index + +Revision ID: e19037032219 +Revises: 01d6889832f7 +Create Date: 2024-11-22 07:01:17.550037 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e19037032219' +down_revision = 'd7999dfa4aae' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('child_chunks', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('word_count', sa.Integer(), nullable=False), + sa.Column('index_node_id', sa.String(length=255), nullable=True), + sa.Column('index_node_hash', sa.String(length=255), nullable=True), + sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('indexing_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name='child_chunk_pkey') + ) + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.drop_index('child_chunk_dataset_id_idx') + + op.drop_table('child_chunks') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py new file mode 100644 index 000000000..6dadd4e4a --- /dev/null +++ b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py @@ -0,0 +1,47 @@ +"""add_auto_disabled_dataset_logs + +Revision ID: 923752d42eb6 +Revises: e19037032219 +Create Date: 2024-12-25 11:37:55.467101 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '923752d42eb6' +down_revision = 'e19037032219' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dataset_auto_disable_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey') + ) + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False) + batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False) + batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.drop_index('dataset_auto_disable_log_tenant_idx') + batch_op.drop_index('dataset_auto_disable_log_dataset_idx') + batch_op.drop_index('dataset_auto_disable_log_created_atx') + + op.drop_table('dataset_auto_disable_logs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py b/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py new file mode 100644 index 000000000..798c89586 --- /dev/null +++ b/api/migrations/versions/2025_01_01_2000-a91b476a53de_change_workflow_runs_total_tokens_to_.py @@ -0,0 +1,41 @@ +"""change workflow_runs.total_tokens to bigint + +Revision ID: a91b476a53de +Revises: 923752d42eb6 +Create Date: 2025-01-01 20:00:01.207369 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a91b476a53de' +down_revision = '923752d42eb6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.alter_column('total_tokens', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=False, + existing_server_default=sa.text('0')) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.alter_column('total_tokens', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=False, + existing_server_default=sa.text('0')) + + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index a8602d10a..35a28df75 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,8 +1,9 @@ import enum import json -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column from .engine import db from .types import StringUUID @@ -16,11 +17,11 @@ class AccountStatus(enum.StrEnum): CLOSED = "closed" -class Account(UserMixin, db.Model): +class Account(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "accounts" __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) @@ -43,7 +44,8 @@ def is_password_set(self): @property def current_tenant(self): - return self._current_tenant + # FIXME: fix the type error later; the type matters here, and ignoring it may cause bugs + return self._current_tenant # type: ignore @current_tenant.setter def current_tenant(self, value: "Tenant"): @@ -52,7 +54,8 @@ def current_tenant(self, value: "Tenant"): if ta: tenant.current_role = ta.role else: - tenant = None + # FIXME: fix the type error later; the type matters here, and ignoring it may cause bugs + tenant = None # type: ignore self._current_tenant = tenant @property @@ -89,7 +92,7 @@ def get_status(self) -> AccountStatus: return AccountStatus(status_str) @classmethod - def get_by_openid(cls, provider: str, open_id: str) -> db.Model: + def get_by_openid(cls, provider: str, open_id: str): account_integrate = ( db.session.query(AccountIntegrate) .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) @@ -134,7 +137,7 @@ class TenantAccountRole(enum.StrEnum): @staticmethod def is_valid_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, @@ -144,15 +147,15 @@ def is_valid_role(role: str) -> bool: @staticmethod def is_privileged_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} @staticmethod def is_admin_role(role: str) -> bool: - return role and role == TenantAccountRole.ADMIN + return role == TenantAccountRole.ADMIN @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, @@ -161,11 +164,11 @@ def is_non_owner_role(role: str) -> bool: @staticmethod def is_editing_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in { + return role in { TenantAccountRole.OWNER, TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR, @@ -173,7 +176,7 @@ def is_dataset_edit_role(role: str) -> bool: } -class Tenant(db.Model): +class Tenant(db.Model): # type: ignore[name-defined] __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) @@ -209,7 +212,7 @@ class TenantAccountJoinRole(enum.Enum): DATASET_OPERATOR = "dataset_operator" -class TenantAccountJoin(db.Model): +class TenantAccountJoin(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_account_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -228,7 +231,7 @@ class TenantAccountJoin(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class AccountIntegrate(db.Model): +class AccountIntegrate(db.Model): # type: ignore[name-defined] __tablename__ = "account_integrates" __table_args__ = ( db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -245,7 +248,7 @@ class AccountIntegrate(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class InvitationCode(db.Model): +class InvitationCode(db.Model): # type: ignore[name-defined] __tablename__ = "invitation_codes" __table_args__ = ( db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index fbffe7a3b..6b6d80871 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -13,7 +13,7 @@ class APIBasedExtensionPoint(enum.Enum): APP_MODERATION_OUTPUT = "app.moderation.output" -class APIBasedExtension(db.Model): +class APIBasedExtension(db.Model): # type: ignore[name-defined] __tablename__ = "api_based_extensions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), diff --git a/api/models/dataset.py b/api/models/dataset.py index 7279e8d5b..567f7db43 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,6 +9,7 @@ import re import time from json import JSONDecodeError +from typing import Any, cast from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB @@ -16,6 +17,7 @@ from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_storage import storage +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from .account import Account from .engine import db @@ -29,7 +31,7 @@ class DatasetPermissionEnum(enum.StrEnum): PARTIAL_TEAM = "partial_members" -class Dataset(db.Model): +class Dataset(db.Model): # type: ignore[name-defined] __tablename__ = "datasets" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_pkey"), @@ -200,7 +202,7 @@ def gen_collection_name_by_id(dataset_id: str) -> str: return f"Vector_index_{normalized_dataset_id}_Node" -class DatasetProcessRule(db.Model): +class DatasetProcessRule(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_process_rules" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), @@ -214,9 +216,9 @@ class DatasetProcessRule(db.Model): created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - MODES = ["automatic", "custom"] + MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES = { + AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ {"id": 
"remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -230,8 +232,6 @@ def to_dict(self): "dataset_id": self.dataset_id, "mode": self.mode, "rules": self.rules_dict, - "created_by": self.created_by, - "created_at": self.created_at, } @property @@ -242,7 +242,7 @@ def rules_dict(self): return None -class Document(db.Model): +class Document(db.Model): # type: ignore[name-defined] __tablename__ = "documents" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_pkey"), @@ -395,6 +395,12 @@ def hit_count(self): .scalar() ) + @property + def process_rule_dict(self): + if self.dataset_process_rule_id: + return self.dataset_process_rule.to_dict() + return None + def to_dict(self): return { "id": self.id, @@ -492,7 +498,7 @@ def from_dict(cls, data: dict): ) -class DocumentSegment(db.Model): +class DocumentSegment(db.Model): # type: ignore[name-defined] __tablename__ = "document_segments" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_segment_pkey"), @@ -559,6 +565,24 @@ def next_segment(self): .first() ) + @property + def child_chunks(self): + process_rule = self.document.dataset_process_rule + if process_rule.mode == "hierarchical": + rules = Rule(**process_rule.rules_dict) + if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: + child_chunks = ( + db.session.query(ChildChunk) + .filter(ChildChunk.segment_id == self.id) + .order_by(ChildChunk.position.asc()) + .all() + ) + return child_chunks or [] + else: + return [] + else: + return [] + def get_sign_content(self): signed_urls = [] text = self.content @@ -604,7 +628,48 @@ def get_sign_content(self): return text -class AppDatasetJoin(db.Model): +class ChildChunk(db.Model): # type: ignore[name-defined] + __tablename__ = "child_chunks" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + ) + + # initial fields + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + segment_id = db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + content = db.Column(db.Text, nullable=False) + word_count = db.Column(db.Integer, nullable=False) + # indexing fields + index_node_id = db.Column(db.String(255), nullable=True) + index_node_hash = db.Column(db.String(255), nullable=True) + type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + indexing_at = db.Column(db.DateTime, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + error = db.Column(db.Text, nullable=True) + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + + @property + def document(self): + return db.session.query(Document).filter(Document.id == self.document_id).first() + + @property + def segment(self): + return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() + + +class 
AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), @@ -621,7 +686,7 @@ def app(self): return db.session.get(App, self.app_id) -class DatasetQuery(db.Model): +class DatasetQuery(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_queries" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), @@ -638,7 +703,7 @@ class DatasetQuery(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class DatasetKeywordTable(db.Model): +class DatasetKeywordTable(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_keyword_tables" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), @@ -683,7 +748,7 @@ def object_hook(self, dct): return None -class Embedding(db.Model): +class Embedding(db.Model): # type: ignore[name-defined] __tablename__ = "embeddings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="embedding_pkey"), @@ -704,10 +769,10 @@ def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) def get_embedding(self) -> list[float]: - return pickle.loads(self.embedding) + return cast(list[float], pickle.loads(self.embedding)) -class DatasetCollectionBinding(db.Model): +class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_collection_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), @@ -722,7 +787,7 @@ class DatasetCollectionBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TidbAuthBinding(db.Model): +class TidbAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "tidb_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), @@ -742,7 +807,7 @@ class TidbAuthBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Whitelist(db.Model): +class Whitelist(db.Model): # type: ignore[name-defined] __tablename__ = "whitelists" __table_args__ = ( db.PrimaryKeyConstraint("id", name="whitelists_pkey"), @@ -754,7 +819,7 @@ class Whitelist(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class DatasetPermission(db.Model): +class DatasetPermission(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_permissions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), @@ -771,7 +836,7 @@ class DatasetPermission(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ExternalKnowledgeApis(db.Model): +class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_apis" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), @@ -824,7 +889,7 @@ def dataset_bindings(self): return dataset_bindings -class ExternalKnowledgeBindings(db.Model): +class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] __tablename__ = "external_knowledge_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), @@ -843,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) 
updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] + __tablename__ = "dataset_auto_disable_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + db.Index("dataset_auto_disable_log_created_atx", "created_at"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/model.py b/api/models/model.py index de478622d..3067d1145 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -4,11 +4,11 @@ from collections.abc import Mapping from datetime import datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional, cast import sqlalchemy as sa from flask import request -from flask_login import UserMixin +from flask_login import UserMixin # type: ignore from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column @@ -28,7 +28,7 @@ from .workflow import Workflow -class DifySetup(db.Model): +class DifySetup(db.Model): # type: ignore[name-defined] __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -63,7 +63,7 @@ class IconType(Enum): EMOJI = "emoji" -class App(db.Model): +class App(db.Model): # type: ignore[name-defined] __tablename__ = "apps" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) @@ -86,7 +86,7 @@ class App(db.Model): is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) - max_active_requests = db.Column(db.Integer, nullable=True) + max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True) created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) @@ -154,7 +154,7 @@ def mode_compatible_with_agent(self) -> str: if self.mode == AppMode.CHAT.value and self.is_agent: return AppMode.AGENT_CHAT.value - return self.mode + return str(self.mode) @property def deleted_tools(self) -> list: @@ -231,7 +231,7 @@ class AppStatisticsExtend(db.Model): number = db.Column(db.Integer, nullable=False, default=0) -class AppModelConfig(db.Model): +class AppModelConfig(db.Model): # type: ignore[name-defined] __tablename__ = "app_model_configs" __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) @@ -334,7 +334,7 @@ def external_data_tools_list(self) -> list[dict]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self) -> dict: + def user_input_form_list(self) -> list[dict]: return 
json.loads(self.user_input_form) if self.user_input_form else [] @property @@ -356,7 +356,7 @@ def completion_prompt_config_dict(self) -> dict: @property def dataset_configs_dict(self) -> dict: if self.dataset_configs: - dataset_configs = json.loads(self.dataset_configs) + dataset_configs: dict = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: @@ -504,7 +504,7 @@ class RecommendedAppsCategoryJoinExtend(db.Model): category_id = db.Column(StringUUID, nullable=False) -class RecommendedApp(db.Model): +class RecommendedApp(db.Model): # type: ignore[name-defined] __tablename__ = "recommended_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), @@ -532,7 +532,7 @@ def app(self): return app -class InstalledApp(db.Model): +class InstalledApp(db.Model): # type: ignore[name-defined] __tablename__ = "installed_apps" __table_args__ = ( db.PrimaryKeyConstraint("id", name="installed_app_pkey"), @@ -561,20 +561,20 @@ def tenant(self): return tenant -class Conversation(db.Model): +class Conversation(db.Model): # type: ignore[name-defined] __tablename__ = "conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="conversation_pkey"), db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) model_id = db.Column(db.String(255), nullable=True) - mode = db.Column(db.String(255), nullable=False) + mode: Mapped[str] = mapped_column(db.String(255)) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) @@ -640,6 +640,8 @@ def inputs(self, value: Mapping[str, Any]): @property def model_config(self): model_config = {} + app_model_config: Optional[AppModelConfig] = None + if self.mode == AppMode.ADVANCED_CHAT.value: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) @@ -651,6 +653,7 @@ def model_config(self): if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) + assert app_model_config is not None, "app model config not found" model_config = app_model_config.to_dict() else: model_config["configs"] = override_model_configs @@ -807,7 +810,7 @@ def in_debug_mode(self): return self.override_model_configs is not None -class Message(db.Model): +class Message(db.Model): # type: ignore[name-defined] __tablename__ = "messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_pkey"), @@ -819,7 +822,7 @@ class Message(db.Model): db.Index("message_created_at_idx", "created_at"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) @@ -846,7 +849,7 @@ class Message(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) 
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1047,7 +1050,7 @@ def message_files(self): if not current_app: raise ValueError(f"App {self.app_id} not found") - files: list[File] = [] + files = [] for message_file in message_files: if message_file.transfer_method == "local_file": if message_file.upload_file_id is None: @@ -1154,7 +1157,7 @@ def from_dict(cls, data: dict): ) -class MessageFeedback(db.Model): +class MessageFeedback(db.Model): # type: ignore[name-defined] __tablename__ = "message_feedbacks" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), @@ -1181,7 +1184,7 @@ def from_account(self): return account -class MessageFile(db.Model): +class MessageFile(db.Model): # type: ignore[name-defined] __tablename__ = "message_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_file_pkey"), @@ -1222,7 +1225,7 @@ def __init__( created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class MessageAnnotation(db.Model): +class MessageAnnotation(db.Model): # type: ignore[name-defined] __tablename__ = "message_annotations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), @@ -1253,7 +1256,7 @@ def annotation_create_account(self): return account -class AppAnnotationHitHistory(db.Model): +class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] __tablename__ = "app_annotation_hit_histories" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), @@ -1291,7 +1294,7 @@ def annotation_create_account(self): return account -class AppAnnotationSetting(db.Model): +class AppAnnotationSetting(db.Model): # type: ignore[name-defined] __tablename__ = "app_annotation_settings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), @@ -1339,7 +1342,7 @@ def collection_binding_detail(self): return collection_binding_detail -class OperationLog(db.Model): +class OperationLog(db.Model): # type: ignore[name-defined] __tablename__ = "operation_logs" __table_args__ = ( db.PrimaryKeyConstraint("id", name="operation_log_pkey"), @@ -1356,7 +1359,7 @@ class OperationLog(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class EndUser(UserMixin, db.Model): +class EndUser(UserMixin, db.Model): # type: ignore[name-defined] __tablename__ = "end_users" __table_args__ = ( db.PrimaryKeyConstraint("id", name="end_user_pkey"), @@ -1371,12 +1374,12 @@ class EndUser(UserMixin, db.Model): external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - session_id = db.Column(db.String(255), nullable=False) + session_id: Mapped[str] = mapped_column() created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Site(db.Model): +class Site(db.Model): # type: ignore[name-defined] __tablename__ = "sites" 
__table_args__ = ( db.PrimaryKeyConstraint("id", name="site_pkey"), @@ -1433,7 +1436,7 @@ def app_base_url(self): return dify_config.APP_WEB_URL or request.url_root.rstrip("/") -class ApiToken(db.Model): +class ApiToken(db.Model): # type: ignore[name-defined] __tablename__ = "api_tokens" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_token_pkey"), @@ -1454,13 +1457,12 @@ class ApiToken(db.Model): def generate_api_key(prefix, n): while True: result = prefix + generate_string(n) - while db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0: - result = prefix + generate_string(n) - + if db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0: + continue return result -class UploadFile(db.Model): +class UploadFile(db.Model): # type: ignore[name-defined] __tablename__ = "upload_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="upload_file_pkey"), @@ -1522,7 +1524,7 @@ def __init__( self.source_url = source_url -class ApiRequest(db.Model): +class ApiRequest(db.Model): # type: ignore[name-defined] __tablename__ = "api_requests" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_request_pkey"), @@ -1539,7 +1541,7 @@ class ApiRequest(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class MessageChain(db.Model): +class MessageChain(db.Model): # type: ignore[name-defined] __tablename__ = "message_chains" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_chain_pkey"), @@ -1554,7 +1556,7 @@ class MessageChain(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class MessageAgentThought(db.Model): +class MessageAgentThought(db.Model): # type: ignore[name-defined] __tablename__ = "message_agent_thoughts" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), @@ -1594,7 +1596,7 @@ class MessageAgentThought(db.Model): @property def files(self) -> list: if self.message_files: - return json.loads(self.message_files) + return cast(list[Any], json.loads(self.message_files)) else: return [] @@ -1606,7 +1608,7 @@ def tools(self) -> list[str]: def tool_labels(self) -> dict: try: if self.tool_labels_str: - return json.loads(self.tool_labels_str) + return cast(dict, json.loads(self.tool_labels_str)) else: return {} except Exception as e: @@ -1616,7 +1618,7 @@ def tool_labels(self) -> dict: def tool_meta(self) -> dict: try: if self.tool_meta_str: - return json.loads(self.tool_meta_str) + return cast(dict, json.loads(self.tool_meta_str)) else: return {} except Exception as e: @@ -1664,9 +1666,11 @@ def tool_outputs_dict(self) -> dict: except Exception as e: if self.observation: return dict.fromkeys(tools, self.observation) + else: + return {} -class DatasetRetrieverResource(db.Model): +class DatasetRetrieverResource(db.Model): # type: ignore[name-defined] __tablename__ = "dataset_retriever_resources" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), @@ -1693,7 +1697,7 @@ class DatasetRetrieverResource(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class Tag(db.Model): +class Tag(db.Model): # type: ignore[name-defined] __tablename__ = "tags" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_pkey"), @@ -1711,7 +1715,7 @@ class Tag(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TagBinding(db.Model): +class TagBinding(db.Model): # 
type: ignore[name-defined] __tablename__ = "tag_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), @@ -1727,7 +1731,7 @@ class TagBinding(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TraceAppConfig(db.Model): +class TraceAppConfig(db.Model): # type: ignore[name-defined] __tablename__ = "trace_app_config" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), diff --git a/api/models/provider.py b/api/models/provider.py index fdd3e802d..abe673975 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -36,7 +36,7 @@ def value_of(value): raise ValueError(f"No matching enum found for value '{value}'") -class Provider(db.Model): +class Provider(db.Model): # type: ignore[name-defined] """ Provider model representing the API providers and their configurations. """ @@ -89,7 +89,7 @@ def is_enabled(self): return self.is_valid and self.token_is_set -class ProviderModel(db.Model): +class ProviderModel(db.Model): # type: ignore[name-defined] """ Provider model representing the API provider_models and their configurations. """ @@ -114,7 +114,7 @@ class ProviderModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantDefaultModel(db.Model): +class TenantDefaultModel(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_default_models" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), @@ -130,7 +130,7 @@ class TenantDefaultModel(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TenantPreferredModelProvider(db.Model): +class TenantPreferredModelProvider(db.Model): # type: ignore[name-defined] __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), @@ -145,7 +145,7 @@ class TenantPreferredModelProvider(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderOrder(db.Model): +class ProviderOrder(db.Model): # type: ignore[name-defined] __tablename__ = "provider_orders" __table_args__ = ( db.PrimaryKeyConstraint("id", name="provider_order_pkey"), @@ -170,7 +170,7 @@ class ProviderOrder(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ProviderModelSetting(db.Model): +class ProviderModelSetting(db.Model): # type: ignore[name-defined] """ Provider model settings for record the model enabled status and load balancing status. """ @@ -192,7 +192,7 @@ class ProviderModelSetting(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class LoadBalancingModelConfig(db.Model): +class LoadBalancingModelConfig(db.Model): # type: ignore[name-defined] """ Configurations for load balancing models. 
""" diff --git a/api/models/source.py b/api/models/source.py index 114db8e11..881cfaac7 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -7,7 +7,7 @@ from .types import StringUUID -class DataSourceOauthBinding(db.Model): +class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_oauth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), @@ -25,7 +25,7 @@ class DataSourceOauthBinding(db.Model): disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) -class DataSourceApiKeyAuthBinding(db.Model): +class DataSourceApiKeyAuthBinding(db.Model): # type: ignore[name-defined] __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), diff --git a/api/models/task.py b/api/models/task.py index 27571e247..0db1c6322 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,11 +1,11 @@ from datetime import UTC, datetime -from celery import states +from celery import states # type: ignore from .engine import db -class CeleryTask(db.Model): +class CeleryTask(db.Model): # type: ignore[name-defined] """Task result/status.""" __tablename__ = "celery_taskmeta" @@ -29,7 +29,7 @@ class CeleryTask(db.Model): queue = db.Column(db.String(155), nullable=True) -class CeleryTaskSet(db.Model): +class CeleryTaskSet(db.Model): # type: ignore[name-defined] """TaskSet result.""" __tablename__ = "celery_tasksetmeta" diff --git a/api/models/tools.py b/api/models/tools.py index e90ab669c..13a112ee8 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional import sqlalchemy as sa from sqlalchemy import ForeignKey, func @@ -14,7 +14,7 @@ from .types import StringUUID -class BuiltinToolProvider(db.Model): +class BuiltinToolProvider(db.Model): # type: ignore[name-defined] """ This table stores the tool provider information for built-in tools for each tenant. """ @@ -41,10 +41,10 @@ class BuiltinToolProvider(db.Model): @property def credentials(self) -> dict: - return json.loads(self.encrypted_credentials) + return dict(json.loads(self.encrypted_credentials)) -class PublishedAppTool(db.Model): +class PublishedAppTool(db.Model): # type: ignore[name-defined] """ The table stores the apps published as a tool for each person. """ @@ -86,7 +86,7 @@ def app(self): return db.session.query(App).filter(App.id == self.app_id).first() -class ApiToolProvider(db.Model): +class ApiToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the api providers. """ @@ -133,7 +133,7 @@ def tools(self) -> list[ApiToolBundle]: @property def credentials(self) -> dict: - return json.loads(self.credentials_str) + return dict(json.loads(self.credentials_str)) @property def user(self) -> Account | None: @@ -144,7 +144,7 @@ def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() -class ToolLabelBinding(db.Model): +class ToolLabelBinding(db.Model): # type: ignore[name-defined] """ The table stores the labels for tools. """ @@ -164,7 +164,7 @@ class ToolLabelBinding(db.Model): label_name = db.Column(db.String(40), nullable=False) -class WorkflowToolProvider(db.Model): +class WorkflowToolProvider(db.Model): # type: ignore[name-defined] """ The table stores the workflow providers. 
""" @@ -218,7 +218,7 @@ def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() -class ToolModelInvoke(db.Model): +class ToolModelInvoke(db.Model): # type: ignore[name-defined] """ store the invoke logs from tool invoke """ @@ -255,7 +255,7 @@ class ToolModelInvoke(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ToolConversationVariables(db.Model): +class ToolConversationVariables(db.Model): # type: ignore[name-defined] """ store the conversation variables from tool invoke """ @@ -282,11 +282,11 @@ class ToolConversationVariables(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property - def variables(self) -> dict: + def variables(self) -> Any: return json.loads(self.variables_str) -class ToolFile(db.Model): +class ToolFile(db.Model): # type: ignore[name-defined] __tablename__ = "tool_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_file_pkey"), diff --git a/api/models/web.py b/api/models/web.py index 028a76851..864428fe0 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -6,7 +6,7 @@ from .types import StringUUID -class SavedMessage(db.Model): +class SavedMessage(db.Model): # type: ignore[name-defined] __tablename__ = "saved_messages" __table_args__ = ( db.PrimaryKeyConstraint("id", name="saved_message_pkey"), @@ -25,7 +25,7 @@ def message(self): return db.session.query(Message).filter(Message.id == self.message_id).first() -class PinnedConversation(db.Model): +class PinnedConversation(db.Model): # type: ignore[name-defined] __tablename__ = "pinned_conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), diff --git a/api/models/workflow.py b/api/models/workflow.py index a6b881e4d..4c7e273c5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import sqlalchemy as sa from sqlalchemy import func @@ -20,6 +20,9 @@ from .engine import db from .types import StringUUID +if TYPE_CHECKING: + from models.model import AppMode, Message + class WorkflowType(Enum): """ @@ -56,7 +59,7 @@ def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT -class Workflow(db.Model): +class Workflow(db.Model): # type: ignore[name-defined] """ Workflow, for `Workflow App` and `Chat App workflow mode`. 
@@ -182,7 +185,7 @@ def features(self, value: str) -> None: self._features = value @property - def features_dict(self) -> Mapping[str, Any]: + def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} def user_input_form(self, to_old_structure: bool = False) -> list: @@ -199,7 +202,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list: return [] # get user_input_form from start node - variables = start_node.get("data", {}).get("variables", []) + variables: list[Any] = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] @@ -344,7 +347,7 @@ def value_of(cls, value: str) -> "WorkflowRunStatus": raise ValueError(f"invalid workflow run status value {value}") -class WorkflowRun(db.Model): +class WorkflowRun(db.Model): # type: ignore[name-defined] """ Workflow Run @@ -389,23 +392,23 @@ class WorkflowRun(db.Model): db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - sequence_number = db.Column(db.Integer, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) + type: Mapped[str] = mapped_column(db.String(255)) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + version: Mapped[str] = mapped_column(db.String(255)) + graph: Mapped[Optional[str]] = mapped_column(db.Text) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") - error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + error: Mapped[Optional[str]] = mapped_column(db.Text) + elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) + total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) # account, end_user + created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) @@ -546,7 +549,7 @@ def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": raise ValueError(f"invalid workflow node execution status value {value}") -class WorkflowNodeExecution(db.Model): +class WorkflowNodeExecution(db.Model): # type: ignore[name-defined] """ Workflow Node Execution @@ -618,29 +621,29 @@ 
class WorkflowNodeExecution(db.Model): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - workflow_id = db.Column(StringUUID, nullable=False) - triggered_from = db.Column(db.String(255), nullable=False) - workflow_run_id = db.Column(StringUUID) - index = db.Column(db.Integer, nullable=False) - predecessor_node_id = db.Column(db.String(255)) - node_execution_id = db.Column(db.String(255), nullable=True) - node_id = db.Column(db.String(255), nullable=False) - node_type = db.Column(db.String(255), nullable=False) - title = db.Column(db.String(255), nullable=False) - inputs = db.Column(db.Text) - process_data = db.Column(db.Text) - outputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) - error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) - execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - finished_at = db.Column(db.DateTime) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) + workflow_id: Mapped[str] = mapped_column(StringUUID) + triggered_from: Mapped[str] = mapped_column(db.String(255)) + workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) + index: Mapped[int] = mapped_column(db.Integer) + predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255)) + node_id: Mapped[str] = mapped_column(db.String(255)) + node_type: Mapped[str] = mapped_column(db.String(255)) + title: Mapped[str] = mapped_column(db.String(255)) + inputs: Mapped[Optional[str]] = mapped_column(db.Text) + process_data: Mapped[Optional[str]] = mapped_column(db.Text) + outputs: Mapped[Optional[str]] = mapped_column(db.Text) + status: Mapped[str] = mapped_column(db.String(255)) + error: Mapped[Optional[str]] = mapped_column(db.Text) + elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0")) + execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text) + created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) + created_by_role: Mapped[str] = mapped_column(db.String(255)) + created_by: Mapped[str] = mapped_column(StringUUID) + finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) @property def created_by_account(self): @@ -712,7 +715,7 @@ def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": raise ValueError(f"invalid workflow app log created from value {value}") -class WorkflowAppLog(db.Model): +class WorkflowAppLog(db.Model): # type: ignore[name-defined] """ Workflow App execution log, excluding workflow debugging records. 
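The WorkflowRun and WorkflowNodeExecution hunks replace plain `db.Column(...)` attributes with SQLAlchemy 2.0-style `Mapped[...]` annotations backed by `mapped_column(...)`, so the type checker can see each column's Python type and nullability (`Mapped[str]` for NOT NULL, `Mapped[Optional[str]]` for nullable). A standalone sketch of that mapping style, using an illustrative table rather than the models in this diff:

# Minimal sketch of SQLAlchemy 2.0 typed columns; table and fields are illustrative.
from datetime import datetime
from typing import Optional

from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleRun(Base):
    __tablename__ = "example_runs"

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    status: Mapped[str] = mapped_column(String(255))            # non-Optional => NOT NULL by default
    error: Mapped[Optional[str]] = mapped_column(String(1024))  # Optional => nullable
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())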
@@ -747,11 +750,11 @@ class WorkflowAppLog(db.Model): db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID) + app_id: Mapped[str] = mapped_column(StringUUID) workflow_id = db.Column(StringUUID, nullable=False) - workflow_run_id = db.Column(StringUUID, nullable=False) + workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) @@ -815,7 +818,7 @@ def created_by_end_user(self): # Extend: stop /apps//workflow-app-logs update end_user to account -class ConversationVariable(db.Model): +class ConversationVariable(db.Model): # type: ignore[name-defined] __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) diff --git a/api/mypy.ini b/api/mypy.ini new file mode 100644 index 000000000..2c754f9fc --- /dev/null +++ b/api/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +warn_return_any = True +warn_unused_configs = True +check_untyped_defs = True +exclude = (?x)( + core/tools/provider/builtin/ + | core/model_runtime/model_providers/ + | tests/ + | migrations/ + ) \ No newline at end of file diff --git a/api/poetry.lock b/api/poetry.lock index bdff35b89..72b88a58b 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -34,102 +34,87 @@ reference = "aliyun" [[package]] name = "aiohttp" -version = "3.10.5" +version = "3.11.11" description = "Async http client/server framework (asyncio)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, - {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, - {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, - {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, - {file = 
"aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, - {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, - {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, - {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, - {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, - {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, - {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, - {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, - {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, - {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, - {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, - {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, - {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, - {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, - {file = 
"aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, - {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, - {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, - {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, - {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, - {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, - {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, - {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, - {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, - {file = 
"aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, - {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, - {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, - {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, - {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, - {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, - {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, - {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, - {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, - {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, - {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, - {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, - {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, - {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, - {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, - {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, - {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, - {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, - {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, + {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a60804bff28662cbcf340a4d61598891f12eea3a66af48ecfdc975ceec21e3c8"}, + {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4b4fa1cb5f270fb3eab079536b764ad740bb749ce69a94d4ec30ceee1b5940d5"}, + {file = "aiohttp-3.11.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:731468f555656767cda219ab42e033355fe48c85fbe3ba83a349631541715ba2"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb23d8bb86282b342481cad4370ea0853a39e4a32a0042bb52ca6bdde132df43"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f047569d655f81cb70ea5be942ee5d4421b6219c3f05d131f64088c73bb0917f"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd7659baae9ccf94ae5fe8bfaa2c7bc2e94d24611528395ce88d009107e00c6d"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af01e42ad87ae24932138f154105e88da13ce7d202a6de93fafdafb2883a00ef"}, + {file = "aiohttp-3.11.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5854be2f3e5a729800bac57a8d76af464e160f19676ab6aea74bde18ad19d438"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6526e5fb4e14f4bbf30411216780c9967c20c5a55f2f51d3abd6de68320cc2f3"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:85992ee30a31835fc482468637b3e5bd085fa8fe9392ba0bdcbdc1ef5e9e3c55"}, + {file = 
"aiohttp-3.11.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:88a12ad8ccf325a8a5ed80e6d7c3bdc247d66175afedbe104ee2aaca72960d8e"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:0a6d3fbf2232e3a08c41eca81ae4f1dff3d8f1a30bae415ebe0af2d2458b8a33"}, + {file = "aiohttp-3.11.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84a585799c58b795573c7fa9b84c455adf3e1d72f19a2bf498b54a95ae0d194c"}, + {file = "aiohttp-3.11.11-cp310-cp310-win32.whl", hash = "sha256:bfde76a8f430cf5c5584553adf9926534352251d379dcb266ad2b93c54a29745"}, + {file = "aiohttp-3.11.11-cp310-cp310-win_amd64.whl", hash = "sha256:0fd82b8e9c383af11d2b26f27a478640b6b83d669440c0a71481f7c865a51da9"}, + {file = "aiohttp-3.11.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ba74ec819177af1ef7f59063c6d35a214a8fde6f987f7661f4f0eecc468a8f76"}, + {file = "aiohttp-3.11.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4af57160800b7a815f3fe0eba9b46bf28aafc195555f1824555fa2cfab6c1538"}, + {file = "aiohttp-3.11.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ffa336210cf9cd8ed117011085817d00abe4c08f99968deef0013ea283547204"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81b8fe282183e4a3c7a1b72f5ade1094ed1c6345a8f153506d114af5bf8accd9"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3af41686ccec6a0f2bdc66686dc0f403c41ac2089f80e2214a0f82d001052c03"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70d1f9dde0e5dd9e292a6d4d00058737052b01f3532f69c0c65818dac26dc287"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:249cc6912405917344192b9f9ea5cd5b139d49e0d2f5c7f70bdfaf6b4dbf3a2e"}, + {file = "aiohttp-3.11.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0eb98d90b6690827dcc84c246811feeb4e1eea683c0eac6caed7549be9c84665"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec82bf1fda6cecce7f7b915f9196601a1bd1a3079796b76d16ae4cce6d0ef89b"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:9fd46ce0845cfe28f108888b3ab17abff84ff695e01e73657eec3f96d72eef34"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:bd176afcf8f5d2aed50c3647d4925d0db0579d96f75a31e77cbaf67d8a87742d"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:ec2aa89305006fba9ffb98970db6c8221541be7bee4c1d027421d6f6df7d1ce2"}, + {file = "aiohttp-3.11.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:92cde43018a2e17d48bb09c79e4d4cb0e236de5063ce897a5e40ac7cb4878773"}, + {file = "aiohttp-3.11.11-cp311-cp311-win32.whl", hash = "sha256:aba807f9569455cba566882c8938f1a549f205ee43c27b126e5450dc9f83cc62"}, + {file = "aiohttp-3.11.11-cp311-cp311-win_amd64.whl", hash = "sha256:ae545f31489548c87b0cced5755cfe5a5308d00407000e72c4fa30b19c3220ac"}, + {file = "aiohttp-3.11.11-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e595c591a48bbc295ebf47cb91aebf9bd32f3ff76749ecf282ea7f9f6bb73886"}, + {file = "aiohttp-3.11.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3ea1b59dc06396b0b424740a10a0a63974c725b1c64736ff788a3689d36c02d2"}, + {file = "aiohttp-3.11.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8811f3f098a78ffa16e0ea36dffd577eb031aea797cbdba81be039a4169e242c"}, + {file = 
"aiohttp-3.11.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7227b87a355ce1f4bf83bfae4399b1f5bb42e0259cb9405824bd03d2f4336a"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d40f9da8cabbf295d3a9dae1295c69975b86d941bc20f0a087f0477fa0a66231"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffb3dc385f6bb1568aa974fe65da84723210e5d9707e360e9ecb51f59406cd2e"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8f5f7515f3552d899c61202d99dcb17d6e3b0de777900405611cd747cecd1b8"}, + {file = "aiohttp-3.11.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3499c7ffbfd9c6a3d8d6a2b01c26639da7e43d47c7b4f788016226b1e711caa8"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8e2bf8029dbf0810c7bfbc3e594b51c4cc9101fbffb583a3923aea184724203c"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b6212a60e5c482ef90f2d788835387070a88d52cf6241d3916733c9176d39eab"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d119fafe7b634dbfa25a8c597718e69a930e4847f0b88e172744be24515140da"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:6fba278063559acc730abf49845d0e9a9e1ba74f85f0ee6efd5803f08b285853"}, + {file = "aiohttp-3.11.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:92fc484e34b733704ad77210c7957679c5c3877bd1e6b6d74b185e9320cc716e"}, + {file = "aiohttp-3.11.11-cp312-cp312-win32.whl", hash = "sha256:9f5b3c1ed63c8fa937a920b6c1bec78b74ee09593b3f5b979ab2ae5ef60d7600"}, + {file = "aiohttp-3.11.11-cp312-cp312-win_amd64.whl", hash = "sha256:1e69966ea6ef0c14ee53ef7a3d68b564cc408121ea56c0caa2dc918c1b2f553d"}, + {file = "aiohttp-3.11.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:541d823548ab69d13d23730a06f97460f4238ad2e5ed966aaf850d7c369782d9"}, + {file = "aiohttp-3.11.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:929f3ed33743a49ab127c58c3e0a827de0664bfcda566108989a14068f820194"}, + {file = "aiohttp-3.11.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0882c2820fd0132240edbb4a51eb8ceb6eef8181db9ad5291ab3332e0d71df5f"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b63de12e44935d5aca7ed7ed98a255a11e5cb47f83a9fded7a5e41c40277d104"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa54f8ef31d23c506910c21163f22b124facb573bff73930735cf9fe38bf7dff"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a344d5dc18074e3872777b62f5f7d584ae4344cd6006c17ba12103759d407af3"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b7fb429ab1aafa1f48578eb315ca45bd46e9c37de11fe45c7f5f4138091e2f1"}, + {file = "aiohttp-3.11.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c341c7d868750e31961d6d8e60ff040fb9d3d3a46d77fd85e1ab8e76c3e9a5c4"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ed9ee95614a71e87f1a70bc81603f6c6760128b140bc4030abe6abaa988f1c3d"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:de8d38f1c2810fa2a4f1d995a2e9c70bb8737b18da04ac2afbf3971f65781d87"}, + {file = 
"aiohttp-3.11.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a9b7371665d4f00deb8f32208c7c5e652059b0fda41cf6dbcac6114a041f1cc2"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:620598717fce1b3bd14dd09947ea53e1ad510317c85dda2c9c65b622edc96b12"}, + {file = "aiohttp-3.11.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:bf8d9bfee991d8acc72d060d53860f356e07a50f0e0d09a8dfedea1c554dd0d5"}, + {file = "aiohttp-3.11.11-cp313-cp313-win32.whl", hash = "sha256:9d73ee3725b7a737ad86c2eac5c57a4a97793d9f442599bea5ec67ac9f4bdc3d"}, + {file = "aiohttp-3.11.11-cp313-cp313-win_amd64.whl", hash = "sha256:c7a06301c2fb096bdb0bd25fe2011531c1453b9f2c163c8031600ec73af1cc99"}, + {file = "aiohttp-3.11.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3e23419d832d969f659c208557de4a123e30a10d26e1e14b73431d3c13444c2e"}, + {file = "aiohttp-3.11.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:21fef42317cf02e05d3b09c028712e1d73a9606f02467fd803f7c1f39cc59add"}, + {file = "aiohttp-3.11.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1f21bb8d0235fc10c09ce1d11ffbd40fc50d3f08a89e4cf3a0c503dc2562247a"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1642eceeaa5ab6c9b6dfeaaa626ae314d808188ab23ae196a34c9d97efb68350"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2170816e34e10f2fd120f603e951630f8a112e1be3b60963a1f159f5699059a6"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8be8508d110d93061197fd2d6a74f7401f73b6d12f8822bbcd6d74f2b55d71b1"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4eed954b161e6b9b65f6be446ed448ed3921763cc432053ceb606f89d793927e"}, + {file = "aiohttp-3.11.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6c9af134da4bc9b3bd3e6a70072509f295d10ee60c697826225b60b9959acdd"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:44167fc6a763d534a6908bdb2592269b4bf30a03239bcb1654781adf5e49caf1"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:479b8c6ebd12aedfe64563b85920525d05d394b85f166b7873c8bde6da612f9c"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:10b4ff0ad793d98605958089fabfa350e8e62bd5d40aa65cdc69d6785859f94e"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:b540bd67cfb54e6f0865ceccd9979687210d7ed1a1cc8c01f8e67e2f1e883d28"}, + {file = "aiohttp-3.11.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1dac54e8ce2ed83b1f6b1a54005c87dfed139cf3f777fdc8afc76e7841101226"}, + {file = "aiohttp-3.11.11-cp39-cp39-win32.whl", hash = "sha256:568c1236b2fde93b7720f95a890741854c1200fba4a3471ff48b2934d2d93fd3"}, + {file = "aiohttp-3.11.11-cp39-cp39-win_amd64.whl", hash = "sha256:943a8b052e54dfd6439fd7989f67fc6a7f2138d0a2cf0a7de5f18aa4fe7eb3b1"}, + {file = "aiohttp-3.11.11.tar.gz", hash = "sha256:bb49c7f1e6ebf3821a42d81d494f538107610c3a705987f53068546b0e90303e"}, ] [package.dependencies] @@ -138,7 +123,8 @@ aiosignal = ">=1.1.2" attrs = ">=17.3.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" -yarl = ">=1.0,<2.0" +propcache = ">=0.2.0" +yarl = ">=1.17.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] @@ -656,13 +642,13 @@ reference = "aliyun" [[package]] name = "anyio" -version = "4.7.0" +version = "4.8.0" description = "High level 
compatibility layer for multiple asynchronous event loop implementations" optional = false python-versions = ">=3.9" files = [ - {file = "anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352"}, - {file = "anyio-4.7.0.tar.gz", hash = "sha256:2f834749c602966b7d456a7567cafcb309f96482b5081d14ac93ccd457f9dd48"}, + {file = "anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a"}, + {file = "anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a"}, ] [package.dependencies] @@ -672,7 +658,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] trio = ["trio (>=0.26.1)"] [package.source] @@ -1153,13 +1139,13 @@ reference = "aliyun" [[package]] name = "botocore" -version = "1.35.87" +version = "1.35.94" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">= 3.8" files = [ - {file = "botocore-1.35.87-py3-none-any.whl", hash = "sha256:81cf84f12030d9ab3829484b04765d5641697ec53c2ac2b3987a99eefe501692"}, - {file = "botocore-1.35.87.tar.gz", hash = "sha256:3062d073ce4170a994099270f469864169dc1a1b8b3d4a21c14ce0ae995e0f89"}, + {file = "botocore-1.35.94-py3-none-any.whl", hash = "sha256:d784d944865d8279c79d2301fc09ac28b5221d4e7328fb4e23c642c253b9932c"}, + {file = "botocore-1.35.94.tar.gz", hash = "sha256:2b3309b356541faa4d88bb957dcac1d8004aa44953c0b7d4521a6cc5d3d5d6ba"}, ] [package.dependencies] @@ -2480,13 +2466,13 @@ reference = "aliyun" [[package]] name = "dataclass-wizard" -version = "0.33.0" -description = "Lightning-fast JSON wizardry for Python dataclasses — effortless serialization with no external tools required!" +version = "0.34.0" +description = "Lightning-fast JSON wizardry for Python dataclasses — effortless serialization right out of the box!" 
optional = false python-versions = "*" files = [ - {file = "dataclass-wizard-0.33.0.tar.gz", hash = "sha256:10edbac626c390ed1f3be4f972303bc2c25c162fd453d55adfe22b05b263ee6d"}, - {file = "dataclass_wizard-0.33.0-py2.py3-none-any.whl", hash = "sha256:d257d65a707f6ea76140ad5442189ffe6fc68e1d678dc6b3fa25754e3106160d"}, + {file = "dataclass-wizard-0.34.0.tar.gz", hash = "sha256:f917db2220e395806a852f7c57e9011dd783b7fe3eee763bb56ae2d48968ab03"}, + {file = "dataclass_wizard-0.34.0-py2.py3-none-any.whl", hash = "sha256:9c184edd3526c3523fec2de5b6d6cdfcdc97ed7b2c5ba8bc574284b793704f01"}, ] [package.dependencies] @@ -2829,13 +2815,13 @@ reference = "aliyun" [[package]] name = "elastic-transport" -version = "8.15.1" +version = "8.17.0" description = "Transport classes and utilities shared among Python Elastic client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "elastic_transport-8.15.1-py3-none-any.whl", hash = "sha256:b5e82ff1679d8c7705a03fd85c7f6ef85d6689721762d41228dd312e34f331fc"}, - {file = "elastic_transport-8.15.1.tar.gz", hash = "sha256:9cac4ab5cf9402668cf305ae0b7d93ddc0c7b61461d6d1027850db6da9cc5742"}, + {file = "elastic_transport-8.17.0-py3-none-any.whl", hash = "sha256:59f553300866750e67a38828fede000576562a0e66930c641adb75249e0c95af"}, + {file = "elastic_transport-8.17.0.tar.gz", hash = "sha256:e755f38f99fa6ec5456e236b8e58f0eb18873ac8fe710f74b91a16dd562de2a5"}, ] [package.dependencies] @@ -2894,32 +2880,6 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" -[[package]] -name = "environs" -version = "9.5.0" -description = "simplified environment variable parsing" -optional = false -python-versions = ">=3.6" -files = [ - {file = "environs-9.5.0-py2.py3-none-any.whl", hash = "sha256:1e549569a3de49c05f856f40bce86979e7d5ffbbc4398e7f338574c220189124"}, - {file = "environs-9.5.0.tar.gz", hash = "sha256:a76307b36fbe856bdca7ee9161e6c466fd7fcffc297109a118c59b54e27e30c9"}, -] - -[package.dependencies] -marshmallow = ">=3.0.0" -python-dotenv = "*" - -[package.extras] -dev = ["dj-database-url", "dj-email-url", "django-cache-url", "flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)", "pytest", "tox"] -django = ["dj-database-url", "dj-email-url", "django-cache-url"] -lint = ["flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)"] -tests = ["dj-database-url", "dj-email-url", "django-cache-url", "pytest"] - -[package.source] -type = "legacy" -url = "https://mirrors.aliyun.com/pypi/simple" -reference = "aliyun" - [[package]] name = "esdk-obs-python" version = "3.24.6.1" @@ -4297,17 +4257,17 @@ reference = "aliyun" [[package]] name = "gotrue" -version = "2.11.0" +version = "2.11.1" description = "Python Client Library for Supabase Auth" optional = false python-versions = ">=3.9,<4.0" files = [ - {file = "gotrue-2.11.0-py3-none-any.whl", hash = "sha256:62177ffd567448b352121bc7e9244ff018d59bb746dad476b51658f856d59cf8"}, - {file = "gotrue-2.11.0.tar.gz", hash = "sha256:a0a452748ef741337820c97b934327c25f796e7cd33c0bf4341346bcc5a837f5"}, + {file = "gotrue-2.11.1-py3-none-any.whl", hash = "sha256:1b2d915bdc65fd0ad608532759ce9c72fa2e910145c1e6901f2188519e7bcd2d"}, + {file = "gotrue-2.11.1.tar.gz", hash = "sha256:5594ceee60bd873e5f4fdd028b08dece3906f6013b6ed08e7786b71c0092fed0"}, ] [package.dependencies] -httpx = {version = ">=0.26,<0.28", extras = ["http2"]} +httpx = {version = ">=0.26,<0.29", extras = ["http2"]} pydantic = ">=1.10,<3" [package.source] @@ -4408,13 +4368,13 @@ 
reference = "aliyun" [[package]] name = "grpc-google-iam-v1" -version = "0.13.1" +version = "0.14.0" description = "IAM API client library" optional = false python-versions = ">=3.7" files = [ - {file = "grpc-google-iam-v1-0.13.1.tar.gz", hash = "sha256:3ff4b2fd9d990965e410965253c0da6f66205d5a8291c4c31c6ebecca18a9001"}, - {file = "grpc_google_iam_v1-0.13.1-py2.py3-none-any.whl", hash = "sha256:c3e86151a981811f30d5e7330f271cee53e73bb87755e88cc3b6f0c7b5fe374e"}, + {file = "grpc_google_iam_v1-0.14.0-py2.py3-none-any.whl", hash = "sha256:fb4a084b30099ba3ab07d61d620a0d4429570b13ff53bd37bac75235f98b7da4"}, + {file = "grpc_google_iam_v1-0.14.0.tar.gz", hash = "sha256:c66e07aa642e39bb37950f9e7f491f70dad150ac9801263b42b2814307c2df99"}, ] [package.dependencies] @@ -4429,70 +4389,70 @@ reference = "aliyun" [[package]] name = "grpcio" -version = "1.68.1" +version = "1.67.1" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.68.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:d35740e3f45f60f3c37b1e6f2f4702c23867b9ce21c6410254c9c682237da68d"}, - {file = "grpcio-1.68.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:d99abcd61760ebb34bdff37e5a3ba333c5cc09feda8c1ad42547bea0416ada78"}, - {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:f8261fa2a5f679abeb2a0a93ad056d765cdca1c47745eda3f2d87f874ff4b8c9"}, - {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0feb02205a27caca128627bd1df4ee7212db051019a9afa76f4bb6a1a80ca95e"}, - {file = "grpcio-1.68.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919d7f18f63bcad3a0f81146188e90274fde800a94e35d42ffe9eadf6a9a6330"}, - {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:963cc8d7d79b12c56008aabd8b457f400952dbea8997dd185f155e2f228db079"}, - {file = "grpcio-1.68.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ccf2ebd2de2d6661e2520dae293298a3803a98ebfc099275f113ce1f6c2a80f1"}, - {file = "grpcio-1.68.1-cp310-cp310-win32.whl", hash = "sha256:2cc1fd04af8399971bcd4f43bd98c22d01029ea2e56e69c34daf2bf8470e47f5"}, - {file = "grpcio-1.68.1-cp310-cp310-win_amd64.whl", hash = "sha256:ee2e743e51cb964b4975de572aa8fb95b633f496f9fcb5e257893df3be854746"}, - {file = "grpcio-1.68.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:55857c71641064f01ff0541a1776bfe04a59db5558e82897d35a7793e525774c"}, - {file = "grpcio-1.68.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4b177f5547f1b995826ef529d2eef89cca2f830dd8b2c99ffd5fde4da734ba73"}, - {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:3522c77d7e6606d6665ec8d50e867f13f946a4e00c7df46768f1c85089eae515"}, - {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d1fae6bbf0816415b81db1e82fb3bf56f7857273c84dcbe68cbe046e58e1ccd"}, - {file = "grpcio-1.68.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:298ee7f80e26f9483f0b6f94cc0a046caf54400a11b644713bb5b3d8eb387600"}, - {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb5780e2e740b6b4f2d208e90453591036ff80c02cc605fea1af8e6fc6b1bbe"}, - {file = "grpcio-1.68.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ddda1aa22495d8acd9dfbafff2866438d12faec4d024ebc2e656784d96328ad0"}, - {file = "grpcio-1.68.1-cp311-cp311-win32.whl", hash = "sha256:b33bd114fa5a83f03ec6b7b262ef9f5cac549d4126f1dc702078767b10c46ed9"}, - {file = "grpcio-1.68.1-cp311-cp311-win_amd64.whl", hash 
= "sha256:7f20ebec257af55694d8f993e162ddf0d36bd82d4e57f74b31c67b3c6d63d8b2"}, - {file = "grpcio-1.68.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:8829924fffb25386995a31998ccbbeaa7367223e647e0122043dfc485a87c666"}, - {file = "grpcio-1.68.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3aed6544e4d523cd6b3119b0916cef3d15ef2da51e088211e4d1eb91a6c7f4f1"}, - {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:4efac5481c696d5cb124ff1c119a78bddbfdd13fc499e3bc0ca81e95fc573684"}, - {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab2d912ca39c51f46baf2a0d92aa265aa96b2443266fc50d234fa88bf877d8e"}, - {file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c87ce2a97434dffe7327a4071839ab8e8bffd0054cc74cbe971fba98aedd60"}, - {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e4842e4872ae4ae0f5497bf60a0498fa778c192cc7a9e87877abd2814aca9475"}, - {file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:255b1635b0ed81e9f91da4fcc8d43b7ea5520090b9a9ad9340d147066d1d3613"}, - {file = "grpcio-1.68.1-cp312-cp312-win32.whl", hash = "sha256:7dfc914cc31c906297b30463dde0b9be48e36939575eaf2a0a22a8096e69afe5"}, - {file = "grpcio-1.68.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0c8ddabef9c8f41617f213e527254c41e8b96ea9d387c632af878d05db9229c"}, - {file = "grpcio-1.68.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:a47faedc9ea2e7a3b6569795c040aae5895a19dde0c728a48d3c5d7995fda385"}, - {file = "grpcio-1.68.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:390eee4225a661c5cd133c09f5da1ee3c84498dc265fd292a6912b65c421c78c"}, - {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:66a24f3d45c33550703f0abb8b656515b0ab777970fa275693a2f6dc8e35f1c1"}, - {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c08079b4934b0bf0a8847f42c197b1d12cba6495a3d43febd7e99ecd1cdc8d54"}, - {file = "grpcio-1.68.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8720c25cd9ac25dd04ee02b69256d0ce35bf8a0f29e20577427355272230965a"}, - {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:04cfd68bf4f38f5bb959ee2361a7546916bd9a50f78617a346b3aeb2b42e2161"}, - {file = "grpcio-1.68.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c28848761a6520c5c6071d2904a18d339a796ebe6b800adc8b3f474c5ce3c3ad"}, - {file = "grpcio-1.68.1-cp313-cp313-win32.whl", hash = "sha256:77d65165fc35cff6e954e7fd4229e05ec76102d4406d4576528d3a3635fc6172"}, - {file = "grpcio-1.68.1-cp313-cp313-win_amd64.whl", hash = "sha256:a8040f85dcb9830d8bbb033ae66d272614cec6faceee88d37a88a9bd1a7a704e"}, - {file = "grpcio-1.68.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:eeb38ff04ab6e5756a2aef6ad8d94e89bb4a51ef96e20f45c44ba190fa0bcaad"}, - {file = "grpcio-1.68.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8a3869a6661ec8f81d93f4597da50336718bde9eb13267a699ac7e0a1d6d0bea"}, - {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:2c4cec6177bf325eb6faa6bd834d2ff6aa8bb3b29012cceb4937b86f8b74323c"}, - {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:12941d533f3cd45d46f202e3667be8ebf6bcb3573629c7ec12c3e211d99cfccf"}, - {file = "grpcio-1.68.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80af6f1e69c5e68a2be529990684abdd31ed6622e988bf18850075c81bb1ad6e"}, - {file = 
"grpcio-1.68.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e8dbe3e00771bfe3d04feed8210fc6617006d06d9a2679b74605b9fed3e8362c"}, - {file = "grpcio-1.68.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:83bbf5807dc3ee94ce1de2dfe8a356e1d74101e4b9d7aa8c720cc4818a34aded"}, - {file = "grpcio-1.68.1-cp38-cp38-win32.whl", hash = "sha256:8cb620037a2fd9eeee97b4531880e439ebfcd6d7d78f2e7dcc3726428ab5ef63"}, - {file = "grpcio-1.68.1-cp38-cp38-win_amd64.whl", hash = "sha256:52fbf85aa71263380d330f4fce9f013c0798242e31ede05fcee7fbe40ccfc20d"}, - {file = "grpcio-1.68.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:cb400138e73969eb5e0535d1d06cae6a6f7a15f2cc74add320e2130b8179211a"}, - {file = "grpcio-1.68.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a1b988b40f2fd9de5c820f3a701a43339d8dcf2cb2f1ca137e2c02671cc83ac1"}, - {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:96f473cdacfdd506008a5d7579c9f6a7ff245a9ade92c3c0265eb76cc591914f"}, - {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:37ea3be171f3cf3e7b7e412a98b77685eba9d4fd67421f4a34686a63a65d99f9"}, - {file = "grpcio-1.68.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ceb56c4285754e33bb3c2fa777d055e96e6932351a3082ce3559be47f8024f0"}, - {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:dffd29a2961f3263a16d73945b57cd44a8fd0b235740cb14056f0612329b345e"}, - {file = "grpcio-1.68.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:025f790c056815b3bf53da850dd70ebb849fd755a4b1ac822cb65cd631e37d43"}, - {file = "grpcio-1.68.1-cp39-cp39-win32.whl", hash = "sha256:1098f03dedc3b9810810568060dea4ac0822b4062f537b0f53aa015269be0a76"}, - {file = "grpcio-1.68.1-cp39-cp39-win_amd64.whl", hash = "sha256:334ab917792904245a028f10e803fcd5b6f36a7b2173a820c0b5b076555825e1"}, - {file = "grpcio-1.68.1.tar.gz", hash = "sha256:44a8502dd5de653ae6a73e2de50a401d84184f0331d0ac3daeb044e66d5c5054"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.68.1)"] + {file = "grpcio-1.67.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:8b0341d66a57f8a3119b77ab32207072be60c9bf79760fa609c5609f2deb1f3f"}, + {file = "grpcio-1.67.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:f5a27dddefe0e2357d3e617b9079b4bfdc91341a91565111a21ed6ebbc51b22d"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:43112046864317498a33bdc4797ae6a268c36345a910de9b9c17159d8346602f"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9b929f13677b10f63124c1a410994a401cdd85214ad83ab67cc077fc7e480f0"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7d1797a8a3845437d327145959a2c0c47c05947c9eef5ff1a4c80e499dcc6fa"}, + {file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0489063974d1452436139501bf6b180f63d4977223ee87488fe36858c5725292"}, + {file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9fd042de4a82e3e7aca44008ee2fb5da01b3e5adb316348c21980f7f58adc311"}, + {file = "grpcio-1.67.1-cp310-cp310-win32.whl", hash = "sha256:638354e698fd0c6c76b04540a850bf1db27b4d2515a19fcd5cf645c48d3eb1ed"}, + {file = "grpcio-1.67.1-cp310-cp310-win_amd64.whl", hash = "sha256:608d87d1bdabf9e2868b12338cd38a79969eaf920c89d698ead08f48de9c0f9e"}, + {file = "grpcio-1.67.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:7818c0454027ae3384235a65210bbf5464bd715450e30a3d40385453a85a70cb"}, + {file = 
"grpcio-1.67.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ea33986b70f83844cd00814cee4451055cd8cab36f00ac64a31f5bb09b31919e"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:c7a01337407dd89005527623a4a72c5c8e2894d22bead0895306b23c6695698f"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80b866f73224b0634f4312a4674c1be21b2b4afa73cb20953cbbb73a6b36c3cc"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fff78ba10d4250bfc07a01bd6254a6d87dc67f9627adece85c0b2ed754fa96"}, + {file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8a23cbcc5bb11ea7dc6163078be36c065db68d915c24f5faa4f872c573bb400f"}, + {file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1a65b503d008f066e994f34f456e0647e5ceb34cfcec5ad180b1b44020ad4970"}, + {file = "grpcio-1.67.1-cp311-cp311-win32.whl", hash = "sha256:e29ca27bec8e163dca0c98084040edec3bc49afd10f18b412f483cc68c712744"}, + {file = "grpcio-1.67.1-cp311-cp311-win_amd64.whl", hash = "sha256:786a5b18544622bfb1e25cc08402bd44ea83edfb04b93798d85dca4d1a0b5be5"}, + {file = "grpcio-1.67.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:267d1745894200e4c604958da5f856da6293f063327cb049a51fe67348e4f953"}, + {file = "grpcio-1.67.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:85f69fdc1d28ce7cff8de3f9c67db2b0ca9ba4449644488c1e0303c146135ddb"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f26b0b547eb8d00e195274cdfc63ce64c8fc2d3e2d00b12bf468ece41a0423a0"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4422581cdc628f77302270ff839a44f4c24fdc57887dc2a45b7e53d8fc2376af"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d7616d2ded471231c701489190379e0c311ee0a6c756f3c03e6a62b95a7146e"}, + {file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8a00efecde9d6fcc3ab00c13f816313c040a28450e5e25739c24f432fc6d3c75"}, + {file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:699e964923b70f3101393710793289e42845791ea07565654ada0969522d0a38"}, + {file = "grpcio-1.67.1-cp312-cp312-win32.whl", hash = "sha256:4e7b904484a634a0fff132958dabdb10d63e0927398273917da3ee103e8d1f78"}, + {file = "grpcio-1.67.1-cp312-cp312-win_amd64.whl", hash = "sha256:5721e66a594a6c4204458004852719b38f3d5522082be9061d6510b455c90afc"}, + {file = "grpcio-1.67.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa0162e56fd10a5547fac8774c4899fc3e18c1aa4a4759d0ce2cd00d3696ea6b"}, + {file = "grpcio-1.67.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:beee96c8c0b1a75d556fe57b92b58b4347c77a65781ee2ac749d550f2a365dc1"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:a93deda571a1bf94ec1f6fcda2872dad3ae538700d94dc283c672a3b508ba3af"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e6f255980afef598a9e64a24efce87b625e3e3c80a45162d111a461a9f92955"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e838cad2176ebd5d4a8bb03955138d6589ce9e2ce5d51c3ada34396dbd2dba8"}, + {file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:a6703916c43b1d468d0756c8077b12017a9fcb6a1ef13faf49e67d20d7ebda62"}, + {file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = 
"sha256:917e8d8994eed1d86b907ba2a61b9f0aef27a2155bca6cbb322430fc7135b7bb"}, + {file = "grpcio-1.67.1-cp313-cp313-win32.whl", hash = "sha256:e279330bef1744040db8fc432becc8a727b84f456ab62b744d3fdb83f327e121"}, + {file = "grpcio-1.67.1-cp313-cp313-win_amd64.whl", hash = "sha256:fa0c739ad8b1996bd24823950e3cb5152ae91fca1c09cc791190bf1627ffefba"}, + {file = "grpcio-1.67.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:178f5db771c4f9a9facb2ab37a434c46cb9be1a75e820f187ee3d1e7805c4f65"}, + {file = "grpcio-1.67.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0f3e49c738396e93b7ba9016e153eb09e0778e776df6090c1b8c91877cc1c426"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:24e8a26dbfc5274d7474c27759b54486b8de23c709d76695237515bc8b5baeab"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b6c16489326d79ead41689c4b84bc40d522c9a7617219f4ad94bc7f448c5085"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60e6a4dcf5af7bbc36fd9f81c9f372e8ae580870a9e4b6eafe948cd334b81cf3"}, + {file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:95b5f2b857856ed78d72da93cd7d09b6db8ef30102e5e7fe0961fe4d9f7d48e8"}, + {file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b49359977c6ec9f5d0573ea4e0071ad278ef905aa74e420acc73fd28ce39e9ce"}, + {file = "grpcio-1.67.1-cp38-cp38-win32.whl", hash = "sha256:f5b76ff64aaac53fede0cc93abf57894ab2a7362986ba22243d06218b93efe46"}, + {file = "grpcio-1.67.1-cp38-cp38-win_amd64.whl", hash = "sha256:804c6457c3cd3ec04fe6006c739579b8d35c86ae3298ffca8de57b493524b771"}, + {file = "grpcio-1.67.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:a25bdea92b13ff4d7790962190bf6bf5c4639876e01c0f3dda70fc2769616335"}, + {file = "grpcio-1.67.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cdc491ae35a13535fd9196acb5afe1af37c8237df2e54427be3eecda3653127e"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:85f862069b86a305497e74d0dc43c02de3d1d184fc2c180993aa8aa86fbd19b8"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec74ef02010186185de82cc594058a3ccd8d86821842bbac9873fd4a2cf8be8d"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01f616a964e540638af5130469451cf580ba8c7329f45ca998ab66e0c7dcdb04"}, + {file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:299b3d8c4f790c6bcca485f9963b4846dd92cf6f1b65d3697145d005c80f9fe8"}, + {file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:60336bff760fbb47d7e86165408126f1dded184448e9a4c892189eb7c9d3f90f"}, + {file = "grpcio-1.67.1-cp39-cp39-win32.whl", hash = "sha256:5ed601c4c6008429e3d247ddb367fe8c7259c355757448d7c1ef7bd4a6739e8e"}, + {file = "grpcio-1.67.1-cp39-cp39-win_amd64.whl", hash = "sha256:5db70d32d6703b89912af16d6d45d78406374a8b8ef0d28140351dd0ec610e98"}, + {file = "grpcio-1.67.1.tar.gz", hash = "sha256:3dc2ed4cabea4dc14d5e708c2b426205956077cc5de419b4d4079315017e9732"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.67.1)"] [package.source] type = "legacy" @@ -5084,13 +5044,13 @@ reference = "aliyun" [[package]] name = "importlib-resources" -version = "6.4.5" +version = "6.5.2" description = "Read resources from Python packages" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = 
"sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, - {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, + {file = "importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec"}, + {file = "importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c"}, ] [package.extras] @@ -5708,6 +5668,144 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "levenshtein" +version = "0.26.1" +description = "Python extension for computing string edit distances and similarities." +optional = false +python-versions = ">=3.9" +files = [ + {file = "levenshtein-0.26.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8dc4a4aecad538d944a1264c12769c99e3c0bf8e741fc5e454cc954913befb2e"}, + {file = "levenshtein-0.26.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ec108f368c12b25787c8b1a4537a1452bc53861c3ee4abc810cc74098278edcd"}, + {file = "levenshtein-0.26.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69229d651c97ed5b55b7ce92481ed00635cdbb80fbfb282a22636e6945dc52d5"}, + {file = "levenshtein-0.26.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79dcd157046d62482a7719b08ba9e3ce9ed3fc5b015af8ea989c734c702aedd4"}, + {file = "levenshtein-0.26.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f53f9173ae21b650b4ed8aef1d0ad0c37821f367c221a982f4d2922b3044e0d"}, + {file = "levenshtein-0.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3956f3c5c229257dbeabe0b6aacd2c083ebcc1e335842a6ff2217fe6cc03b6b"}, + {file = "levenshtein-0.26.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1e83af732726987d2c4cd736f415dae8b966ba17b7a2239c8b7ffe70bfb5543"}, + {file = "levenshtein-0.26.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4f052c55046c2a9c9b5f742f39e02fa6e8db8039048b8c1c9e9fdd27c8a240a1"}, + {file = "levenshtein-0.26.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9895b3a98f6709e293615fde0dcd1bb0982364278fa2072361a1a31b3e388b7a"}, + {file = "levenshtein-0.26.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:a3777de1d8bfca054465229beed23994f926311ce666f5a392c8859bb2722f16"}, + {file = "levenshtein-0.26.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:81c57e1135c38c5e6e3675b5e2077d8a8d3be32bf0a46c57276c092b1dffc697"}, + {file = "levenshtein-0.26.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:91d5e7d984891df3eff7ea9fec8cf06fdfacc03cd074fd1a410435706f73b079"}, + {file = "levenshtein-0.26.1-cp310-cp310-win32.whl", hash = "sha256:f48abff54054b4142ad03b323e80aa89b1d15cabc48ff49eb7a6ff7621829a56"}, + {file = "levenshtein-0.26.1-cp310-cp310-win_amd64.whl", hash = "sha256:79dd6ad799784ea7b23edd56e3bf94b3ca866c4c6dee845658ee75bb4aefdabf"}, + {file = "levenshtein-0.26.1-cp310-cp310-win_arm64.whl", hash = "sha256:3351ddb105ef010cc2ce474894c5d213c83dddb7abb96400beaa4926b0b745bd"}, + {file = "levenshtein-0.26.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:44c51f5d33b3cfb9db518b36f1288437a509edd82da94c4400f6a681758e0cb6"}, + {file = "levenshtein-0.26.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:56b93203e725f9df660e2afe3d26ba07d71871b6d6e05b8b767e688e23dfb076"}, + {file = 
"levenshtein-0.26.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:270d36c5da04a0d89990660aea8542227cbd8f5bc34e9fdfadd34916ff904520"}, + {file = "levenshtein-0.26.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:480674c05077eeb0b0f748546d4fcbb386d7c737f9fff0010400da3e8b552942"}, + {file = "levenshtein-0.26.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13946e37323728695ba7a22f3345c2e907d23f4600bc700bf9b4352fb0c72a48"}, + {file = "levenshtein-0.26.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ceb673f572d1d0dc9b1cd75792bb8bad2ae8eb78a7c6721e23a3867d318cb6f2"}, + {file = "levenshtein-0.26.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:42d6fa242e3b310ce6bfd5af0c83e65ef10b608b885b3bb69863c01fb2fcff98"}, + {file = "levenshtein-0.26.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b8b68295808893a81e0a1dbc2274c30dd90880f14d23078e8eb4325ee615fc68"}, + {file = "levenshtein-0.26.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b01061d377d1944eb67bc40bef5d4d2f762c6ab01598efd9297ce5d0047eb1b5"}, + {file = "levenshtein-0.26.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9d12c8390f156745e533d01b30773b9753e41d8bbf8bf9dac4b97628cdf16314"}, + {file = "levenshtein-0.26.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:48825c9f967f922061329d1481b70e9fee937fc68322d6979bc623f69f75bc91"}, + {file = "levenshtein-0.26.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d8ec137170b95736842f99c0e7a9fd8f5641d0c1b63b08ce027198545d983e2b"}, + {file = "levenshtein-0.26.1-cp311-cp311-win32.whl", hash = "sha256:798f2b525a2e90562f1ba9da21010dde0d73730e277acaa5c52d2a6364fd3e2a"}, + {file = "levenshtein-0.26.1-cp311-cp311-win_amd64.whl", hash = "sha256:55b1024516c59df55f1cf1a8651659a568f2c5929d863d3da1ce8893753153bd"}, + {file = "levenshtein-0.26.1-cp311-cp311-win_arm64.whl", hash = "sha256:e52575cbc6b9764ea138a6f82d73d3b1bc685fe62e207ff46a963d4c773799f6"}, + {file = "levenshtein-0.26.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cc741ca406d3704dc331a69c04b061fc952509a069b79cab8287413f434684bd"}, + {file = "levenshtein-0.26.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:821ace3b4e1c2e02b43cf5dc61aac2ea43bdb39837ac890919c225a2c3f2fea4"}, + {file = "levenshtein-0.26.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92694c9396f55d4c91087efacf81297bef152893806fc54c289fc0254b45384"}, + {file = "levenshtein-0.26.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51ba374de7a1797d04a14a4f0ad3602d2d71fef4206bb20a6baaa6b6a502da58"}, + {file = "levenshtein-0.26.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f7aa5c3327dda4ef952769bacec09c09ff5bf426e07fdc94478c37955681885b"}, + {file = "levenshtein-0.26.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e2517e8d3c221de2d1183f400aed64211fcfc77077b291ed9f3bb64f141cdc"}, + {file = "levenshtein-0.26.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9092b622765c7649dd1d8af0f43354723dd6f4e570ac079ffd90b41033957438"}, + {file = "levenshtein-0.26.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fc16796c85d7d8b259881d59cc8b5e22e940901928c2ff6924b2c967924e8a0b"}, + {file = "levenshtein-0.26.1-cp312-cp312-musllinux_1_2_i686.whl", hash = 
"sha256:e4370733967f5994ceeed8dc211089bedd45832ee688cecea17bfd35a9eb22b9"}, + {file = "levenshtein-0.26.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:3535ecfd88c9b283976b5bc61265855f59bba361881e92ed2b5367b6990c93fe"}, + {file = "levenshtein-0.26.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:90236e93d98bdfd708883a6767826fafd976dac8af8fc4a0fb423d4fa08e1bf0"}, + {file = "levenshtein-0.26.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:04b7cabb82edf566b1579b3ed60aac0eec116655af75a3c551fee8754ffce2ea"}, + {file = "levenshtein-0.26.1-cp312-cp312-win32.whl", hash = "sha256:ae382af8c76f6d2a040c0d9ca978baf461702ceb3f79a0a3f6da8d596a484c5b"}, + {file = "levenshtein-0.26.1-cp312-cp312-win_amd64.whl", hash = "sha256:fd091209798cfdce53746f5769987b4108fe941c54fb2e058c016ffc47872918"}, + {file = "levenshtein-0.26.1-cp312-cp312-win_arm64.whl", hash = "sha256:7e82f2ea44a81ad6b30d92a110e04cd3c8c7c6034b629aca30a3067fa174ae89"}, + {file = "levenshtein-0.26.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:790374a9f5d2cbdb30ee780403a62e59bef51453ac020668c1564d1e43438f0e"}, + {file = "levenshtein-0.26.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7b05c0415c386d00efda83d48db9db68edd02878d6dbc6df01194f12062be1bb"}, + {file = "levenshtein-0.26.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3114586032361722ddededf28401ce5baf1cf617f9f49fb86b8766a45a423ff"}, + {file = "levenshtein-0.26.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2532f8a13b68bf09f152d906f118a88da2063da22f44c90e904b142b0a53d534"}, + {file = "levenshtein-0.26.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:219c30be6aa734bf927188d1208b7d78d202a3eb017b1c5f01ab2034d2d4ccca"}, + {file = "levenshtein-0.26.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397e245e77f87836308bd56305bba630010cd8298c34c4c44bd94990cdb3b7b1"}, + {file = "levenshtein-0.26.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aeff6ea3576f72e26901544c6c55c72a7b79b9983b6f913cba0e9edbf2f87a97"}, + {file = "levenshtein-0.26.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a19862e3539a697df722a08793994e334cd12791e8144851e8a1dee95a17ff63"}, + {file = "levenshtein-0.26.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:dc3b5a64f57c3c078d58b1e447f7d68cad7ae1b23abe689215d03fc434f8f176"}, + {file = "levenshtein-0.26.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:bb6c7347424a91317c5e1b68041677e4c8ed3e7823b5bbaedb95bffb3c3497ea"}, + {file = "levenshtein-0.26.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b817376de4195a207cc0e4ca37754c0e1e1078c2a2d35a6ae502afde87212f9e"}, + {file = "levenshtein-0.26.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7b50c3620ff47c9887debbb4c154aaaac3e46be7fc2e5789ee8dbe128bce6a17"}, + {file = "levenshtein-0.26.1-cp313-cp313-win32.whl", hash = "sha256:9fb859da90262eb474c190b3ca1e61dee83add022c676520f5c05fdd60df902a"}, + {file = "levenshtein-0.26.1-cp313-cp313-win_amd64.whl", hash = "sha256:8adcc90e3a5bfb0a463581d85e599d950fe3c2938ac6247b29388b64997f6e2d"}, + {file = "levenshtein-0.26.1-cp313-cp313-win_arm64.whl", hash = "sha256:c2599407e029865dc66d210b8804c7768cbdbf60f061d993bb488d5242b0b73e"}, + {file = "levenshtein-0.26.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dc54ced948fc3feafce8ad4ba4239d8ffc733a0d70e40c0363ac2a7ab2b7251e"}, + {file = "levenshtein-0.26.1-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:e6516f69213ae393a220e904332f1a6bfc299ba22cf27a6520a1663a08eba0fb"}, + {file = "levenshtein-0.26.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4cfea4eada1746d0c75a864bc7e9e63d4a6e987c852d6cec8d9cb0c83afe25b"}, + {file = "levenshtein-0.26.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a323161dfeeac6800eb13cfe76a8194aec589cd948bcf1cdc03f66cc3ec26b72"}, + {file = "levenshtein-0.26.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c23e749b68ebc9a20b9047317b5cd2053b5856315bc8636037a8adcbb98bed1"}, + {file = "levenshtein-0.26.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f80dd7432d4b6cf493d012d22148db7af769017deb31273e43406b1fb7f091c"}, + {file = "levenshtein-0.26.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0ae7cd6e4312c6ef34b2e273836d18f9fff518d84d823feff5ad7c49668256e0"}, + {file = "levenshtein-0.26.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dcdad740e841d791b805421c2b20e859b4ed556396d3063b3aa64cd055be648c"}, + {file = "levenshtein-0.26.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e07afb1613d6f5fd99abd4e53ad3b446b4efaa0f0d8e9dfb1d6d1b9f3f884d32"}, + {file = "levenshtein-0.26.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:f1add8f1d83099a98ae4ac472d896b7e36db48c39d3db25adf12b373823cdeff"}, + {file = "levenshtein-0.26.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:1010814b1d7a60833a951f2756dfc5c10b61d09976ce96a0edae8fecdfb0ea7c"}, + {file = "levenshtein-0.26.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:33fa329d1bb65ce85e83ceda281aea31cee9f2f6e167092cea54f922080bcc66"}, + {file = "levenshtein-0.26.1-cp39-cp39-win32.whl", hash = "sha256:488a945312f2f16460ab61df5b4beb1ea2254c521668fd142ce6298006296c98"}, + {file = "levenshtein-0.26.1-cp39-cp39-win_amd64.whl", hash = "sha256:9f942104adfddd4b336c3997050121328c39479f69de702d7d144abb69ea7ab9"}, + {file = "levenshtein-0.26.1-cp39-cp39-win_arm64.whl", hash = "sha256:c1d8f85b2672939f85086ed75effcf768f6077516a3e299c2ba1f91bc4644c22"}, + {file = "levenshtein-0.26.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:6cf8f1efaf90ca585640c5d418c30b7d66d9ac215cee114593957161f63acde0"}, + {file = "levenshtein-0.26.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d5b2953978b8c158dd5cd93af8216a5cfddbf9de66cf5481c2955f44bb20767a"}, + {file = "levenshtein-0.26.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b952b3732c4631c49917d4b15d78cb4a2aa006c1d5c12e2a23ba8e18a307a055"}, + {file = "levenshtein-0.26.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07227281e12071168e6ae59238918a56d2a0682e529f747b5431664f302c0b42"}, + {file = "levenshtein-0.26.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8191241cd8934feaf4d05d0cc0e5e72877cbb17c53bbf8c92af9f1aedaa247e9"}, + {file = "levenshtein-0.26.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9e70d7ee157a9b698c73014f6e2b160830e7d2d64d2e342fefc3079af3c356fc"}, + {file = "levenshtein-0.26.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0eb3059f826f6cb0a5bca4a85928070f01e8202e7ccafcba94453470f83e49d4"}, + {file = "levenshtein-0.26.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:6c389e44da12d6fb1d7ba0a709a32a96c9391e9be4160ccb9269f37e040599ee"}, + {file = 
"levenshtein-0.26.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e9de292f2c51a7d34a0ae23bec05391b8f61f35781cd3e4c6d0533e06250c55"}, + {file = "levenshtein-0.26.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d87215113259efdca8716e53b6d59ab6d6009e119d95d45eccc083148855f33"}, + {file = "levenshtein-0.26.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f00a3eebf68a82fb651d8d0e810c10bfaa60c555d21dde3ff81350c74fb4c2"}, + {file = "levenshtein-0.26.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b3554c1b59de63d05075577380340c185ff41b028e541c0888fddab3c259a2b4"}, + {file = "levenshtein-0.26.1.tar.gz", hash = "sha256:0d19ba22330d50609b2349021ec3cf7d905c6fe21195a2d0d876a146e7ed2575"}, +] + +[package.dependencies] +rapidfuzz = ">=3.9.0,<4.0.0" + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + +[[package]] +name = "litellm" +version = "1.51.3" +description = "Library to easily interface with LLM API providers" +optional = false +python-versions = ">=3.8, !=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*" +files = [ + {file = "litellm-1.51.3-py3-none-any.whl", hash = "sha256:440d3c7cc5ab8eeb12cee8f4d806bff05b7db834ebc11117d7fa070a1142ced5"}, + {file = "litellm-1.51.3.tar.gz", hash = "sha256:31eff9fcbf7b058bac0fd7432c4ea0487e8555f12446a1f30e5862e33716f44d"}, +] + +[package.dependencies] +aiohttp = "*" +click = "*" +importlib-metadata = ">=6.8.0" +jinja2 = ">=3.1.2,<4.0.0" +jsonschema = ">=4.22.0,<5.0.0" +openai = ">=1.52.0" +pydantic = ">=2.0.0,<3.0.0" +python-dotenv = ">=0.2.0" +requests = ">=2.31.0,<3.0.0" +tiktoken = ">=0.7.0" +tokenizers = "*" + +[package.extras] +extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "resend (>=0.8.0,<0.9.0)"] +proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"] + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "llvmlite" version = "0.43.0" @@ -6152,13 +6250,13 @@ reference = "aliyun" [[package]] name = "marshmallow" -version = "3.23.2" +version = "3.24.1" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." 
optional = false python-versions = ">=3.9" files = [ - {file = "marshmallow-3.23.2-py3-none-any.whl", hash = "sha256:bcaf2d6fd74fb1459f8450e85d994997ad3e70036452cbfa4ab685acb19479b3"}, - {file = "marshmallow-3.23.2.tar.gz", hash = "sha256:c448ac6455ca4d794773f00bae22c2f351d62d739929f761dce5eacb5c468d7f"}, + {file = "marshmallow-3.24.1-py3-none-any.whl", hash = "sha256:ddb5c9987017d37be351c184e4e867e7bf55f7331f4da730dedad6b7af662cdd"}, + {file = "marshmallow-3.24.1.tar.gz", hash = "sha256:efdcb656ac8788f0e3d1d938f8dc0f237bf1a99aff8f6dfbffa594981641cea0"}, ] [package.dependencies] @@ -6166,7 +6264,7 @@ packaging = ">=17.0" [package.extras] dev = ["marshmallow[tests]", "pre-commit (>=3.5,<5.0)", "tox"] -docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.14)", "sphinx (==8.1.3)", "sphinx-issues (==5.0.0)", "sphinx-version-warning (==1.1.2)"] +docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.14)", "sphinx (==8.1.3)", "sphinx-issues (==5.0.0)"] tests = ["pytest", "simplejson"] [package.source] @@ -6245,15 +6343,15 @@ reference = "aliyun" [[package]] name = "milvus-lite" -version = "2.4.10" +version = "2.4.11" description = "A lightweight version of Milvus wrapped with Python." optional = false python-versions = ">=3.7" files = [ - {file = "milvus_lite-2.4.10-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fc4246d3ed7d1910847afce0c9ba18212e93a6e9b8406048436940578dfad5cb"}, - {file = "milvus_lite-2.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:74a8e07c5e3b057df17fbb46913388e84df1dc403a200f4e423799a58184c800"}, - {file = "milvus_lite-2.4.10-py3-none-manylinux2014_aarch64.whl", hash = "sha256:240c7386b747bad696ecb5bd1f58d491e86b9d4b92dccee3315ed7256256eddc"}, - {file = "milvus_lite-2.4.10-py3-none-manylinux2014_x86_64.whl", hash = "sha256:211d2e334a043f9282bdd9755f76b9b2d93b23bffa7af240919ffce6a8dfe325"}, + {file = "milvus_lite-2.4.11-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9e563ae0dca1b41bfd76b90f06b2bcc474460fe4eba142c9bab18d2747ff843b"}, + {file = "milvus_lite-2.4.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d21472bd24eb327542817829ce7cb51878318e6173c4d62353c77421aecf98d6"}, + {file = "milvus_lite-2.4.11-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8e6ef27f7f84976f9fd0047b675ede746db2e0cc581c44a916ac9e71e0cef05d"}, + {file = "milvus_lite-2.4.11-py3-none-manylinux2014_x86_64.whl", hash = "sha256:551f56b49fcfbb330b658b4a3c56ed29ba9b692ec201edd1f2dade7f5e39957d"}, ] [package.dependencies] @@ -6266,13 +6364,13 @@ reference = "aliyun" [[package]] name = "mistune" -version = "3.0.2" +version = "3.1.0" description = "A sane and fast Markdown parser with useful plugins and renderers" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, - {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, + {file = "mistune-3.1.0-py3-none-any.whl", hash = "sha256:b05198cf6d671b3deba6c87ec6cf0d4eb7b72c524636eddb6dbf13823b52cee1"}, + {file = "mistune-3.1.0.tar.gz", hash = "sha256:dbcac2f78292b9dc066cd03b7a3a26b62d85f8159f2ea5fd28e55df79908d667"}, ] [package.source] @@ -6704,6 +6802,63 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "mypy" +version = "1.13.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = 
"mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -6720,6 +6875,22 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "ndjson" +version = "0.3.1" +description = "JsonDecoder for ndjson" +optional = false +python-versions = "*" +files = [ + {file = "ndjson-0.3.1-py2.py3-none-any.whl", hash = "sha256:839c22275e6baa3040077b83c005ac24199b94973309a8a1809be962c753a410"}, + {file = "ndjson-0.3.1.tar.gz", hash = "sha256:bf9746cb6bb1cb53d172cda7f154c07c786d665ff28341e4e689b796b229e5d6"}, +] + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "nest-asyncio" version = "1.6.0" @@ -7181,13 +7352,13 @@ reference = "aliyun" [[package]] name = "opencensus-ext-azure" -version = "1.1.13" +version = "1.1.14" description = "OpenCensus Azure Monitor Exporter" optional = false python-versions = "*" files = [ - {file = "opencensus-ext-azure-1.1.13.tar.gz", hash = "sha256:aec30472177005379ba56a702a097d618c5f57558e1bb6676ec75f948130692a"}, - {file = "opencensus_ext_azure-1.1.13-py2.py3-none-any.whl", hash = "sha256:06001fac6f8588ba00726a3a7c6c7f2fc88bc8ad12a65afdca657923085393dd"}, + {file = "opencensus-ext-azure-1.1.14.tar.gz", hash = "sha256:c9c6ebad542aeb61813322e627d5889a563e7b8c4e024bf58469d06db73ab148"}, + {file = "opencensus_ext_azure-1.1.14-py2.py3-none-any.whl", hash = "sha256:a1f6870d6e4e312832e6ebd95df28ed499ac637c36cbd77665fe06e24ddeb2f1"}, ] [package.dependencies] @@ -7497,6 +7668,36 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "opik" +version = "1.3.4" +description = "Comet tool for logging and evaluating LLM traces" +optional = false +python-versions = ">=3.8" +files = [ + {file = "opik-1.3.4-py3-none-any.whl", hash = "sha256:c5e10a9f1fb18188471cce2ae8b841e8b187d04ee3b1aed01c643102bae588fb"}, + {file = "opik-1.3.4.tar.gz", hash = 
"sha256:6013d3af4aea61f38b9e7121aa5d8cf4305a5ed3807b3f43d9ab91602b2a5785"}, +] + +[package.dependencies] +click = "*" +httpx = "<0.28.0" +levenshtein = "<1.0.0" +litellm = "*" +openai = "<2.0.0" +pydantic = ">=2.0.0,<3.0.0" +pydantic-settings = ">=2.0.0,<3.0.0" +pytest = "*" +rich = "*" +tenacity = "*" +tqdm = "*" +uuid6 = "*" + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "oracledb" version = "2.2.1" @@ -7547,86 +7748,86 @@ reference = "aliyun" [[package]] name = "orjson" -version = "3.10.12" +version = "3.10.13" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.12-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ece01a7ec71d9940cc654c482907a6b65df27251255097629d0dea781f255c6d"}, - {file = "orjson-3.10.12-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c34ec9aebc04f11f4b978dd6caf697a2df2dd9b47d35aa4cc606cabcb9df69d7"}, - {file = "orjson-3.10.12-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fd6ec8658da3480939c79b9e9e27e0db31dffcd4ba69c334e98c9976ac29140e"}, - {file = "orjson-3.10.12-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f17e6baf4cf01534c9de8a16c0c611f3d94925d1701bf5f4aff17003677d8ced"}, - {file = "orjson-3.10.12-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6402ebb74a14ef96f94a868569f5dccf70d791de49feb73180eb3c6fda2ade56"}, - {file = "orjson-3.10.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0000758ae7c7853e0a4a6063f534c61656ebff644391e1f81698c1b2d2fc8cd2"}, - {file = "orjson-3.10.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:888442dcee99fd1e5bd37a4abb94930915ca6af4db50e23e746cdf4d1e63db13"}, - {file = "orjson-3.10.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c1f7a3ce79246aa0e92f5458d86c54f257fb5dfdc14a192651ba7ec2c00f8a05"}, - {file = "orjson-3.10.12-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:802a3935f45605c66fb4a586488a38af63cb37aaad1c1d94c982c40dcc452e85"}, - {file = "orjson-3.10.12-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:1da1ef0113a2be19bb6c557fb0ec2d79c92ebd2fed4cfb1b26bab93f021fb885"}, - {file = "orjson-3.10.12-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a3273e99f367f137d5b3fecb5e9f45bcdbfac2a8b2f32fbc72129bbd48789c2"}, - {file = "orjson-3.10.12-cp310-none-win32.whl", hash = "sha256:475661bf249fd7907d9b0a2a2421b4e684355a77ceef85b8352439a9163418c3"}, - {file = "orjson-3.10.12-cp310-none-win_amd64.whl", hash = "sha256:87251dc1fb2b9e5ab91ce65d8f4caf21910d99ba8fb24b49fd0c118b2362d509"}, - {file = "orjson-3.10.12-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a734c62efa42e7df94926d70fe7d37621c783dea9f707a98cdea796964d4cf74"}, - {file = "orjson-3.10.12-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:750f8b27259d3409eda8350c2919a58b0cfcd2054ddc1bd317a643afc646ef23"}, - {file = "orjson-3.10.12-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb52c22bfffe2857e7aa13b4622afd0dd9d16ea7cc65fd2bf318d3223b1b6252"}, - {file = "orjson-3.10.12-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:440d9a337ac8c199ff8251e100c62e9488924c92852362cd27af0e67308c16ef"}, - {file = 
"orjson-3.10.12-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9e15c06491c69997dfa067369baab3bf094ecb74be9912bdc4339972323f252"}, - {file = "orjson-3.10.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:362d204ad4b0b8724cf370d0cd917bb2dc913c394030da748a3bb632445ce7c4"}, - {file = "orjson-3.10.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2b57cbb4031153db37b41622eac67329c7810e5f480fda4cfd30542186f006ae"}, - {file = "orjson-3.10.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:165c89b53ef03ce0d7c59ca5c82fa65fe13ddf52eeb22e859e58c237d4e33b9b"}, - {file = "orjson-3.10.12-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:5dee91b8dfd54557c1a1596eb90bcd47dbcd26b0baaed919e6861f076583e9da"}, - {file = "orjson-3.10.12-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:77a4e1cfb72de6f905bdff061172adfb3caf7a4578ebf481d8f0530879476c07"}, - {file = "orjson-3.10.12-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:038d42c7bc0606443459b8fe2d1f121db474c49067d8d14c6a075bbea8bf14dd"}, - {file = "orjson-3.10.12-cp311-none-win32.whl", hash = "sha256:03b553c02ab39bed249bedd4abe37b2118324d1674e639b33fab3d1dafdf4d79"}, - {file = "orjson-3.10.12-cp311-none-win_amd64.whl", hash = "sha256:8b8713b9e46a45b2af6b96f559bfb13b1e02006f4242c156cbadef27800a55a8"}, - {file = "orjson-3.10.12-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:53206d72eb656ca5ac7d3a7141e83c5bbd3ac30d5eccfe019409177a57634b0d"}, - {file = "orjson-3.10.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac8010afc2150d417ebda810e8df08dd3f544e0dd2acab5370cfa6bcc0662f8f"}, - {file = "orjson-3.10.12-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed459b46012ae950dd2e17150e838ab08215421487371fa79d0eced8d1461d70"}, - {file = "orjson-3.10.12-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dcb9673f108a93c1b52bfc51b0af422c2d08d4fc710ce9c839faad25020bb69"}, - {file = "orjson-3.10.12-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:22a51ae77680c5c4652ebc63a83d5255ac7d65582891d9424b566fb3b5375ee9"}, - {file = "orjson-3.10.12-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:910fdf2ac0637b9a77d1aad65f803bac414f0b06f720073438a7bd8906298192"}, - {file = "orjson-3.10.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:24ce85f7100160936bc2116c09d1a8492639418633119a2224114f67f63a4559"}, - {file = "orjson-3.10.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a76ba5fc8dd9c913640292df27bff80a685bed3a3c990d59aa6ce24c352f8fc"}, - {file = "orjson-3.10.12-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ff70ef093895fd53f4055ca75f93f047e088d1430888ca1229393a7c0521100f"}, - {file = "orjson-3.10.12-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f4244b7018b5753ecd10a6d324ec1f347da130c953a9c88432c7fbc8875d13be"}, - {file = "orjson-3.10.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:16135ccca03445f37921fa4b585cff9a58aa8d81ebcb27622e69bfadd220b32c"}, - {file = "orjson-3.10.12-cp312-none-win32.whl", hash = "sha256:2d879c81172d583e34153d524fcba5d4adafbab8349a7b9f16ae511c2cee8708"}, - {file = "orjson-3.10.12-cp312-none-win_amd64.whl", hash = "sha256:fc23f691fa0f5c140576b8c365bc942d577d861a9ee1142e4db468e4e17094fb"}, - {file = "orjson-3.10.12-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = 
"sha256:47962841b2a8aa9a258b377f5188db31ba49af47d4003a32f55d6f8b19006543"}, - {file = "orjson-3.10.12-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6334730e2532e77b6054e87ca84f3072bee308a45a452ea0bffbbbc40a67e296"}, - {file = "orjson-3.10.12-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:accfe93f42713c899fdac2747e8d0d5c659592df2792888c6c5f829472e4f85e"}, - {file = "orjson-3.10.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a7974c490c014c48810d1dede6c754c3cc46598da758c25ca3b4001ac45b703f"}, - {file = "orjson-3.10.12-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:3f250ce7727b0b2682f834a3facff88e310f52f07a5dcfd852d99637d386e79e"}, - {file = "orjson-3.10.12-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f31422ff9486ae484f10ffc51b5ab2a60359e92d0716fcce1b3593d7bb8a9af6"}, - {file = "orjson-3.10.12-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5f29c5d282bb2d577c2a6bbde88d8fdcc4919c593f806aac50133f01b733846e"}, - {file = "orjson-3.10.12-cp313-none-win32.whl", hash = "sha256:f45653775f38f63dc0e6cd4f14323984c3149c05d6007b58cb154dd080ddc0dc"}, - {file = "orjson-3.10.12-cp313-none-win_amd64.whl", hash = "sha256:229994d0c376d5bdc91d92b3c9e6be2f1fbabd4cc1b59daae1443a46ee5e9825"}, - {file = "orjson-3.10.12-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:7d69af5b54617a5fac5c8e5ed0859eb798e2ce8913262eb522590239db6c6763"}, - {file = "orjson-3.10.12-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ed119ea7d2953365724a7059231a44830eb6bbb0cfead33fcbc562f5fd8f935"}, - {file = "orjson-3.10.12-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9c5fc1238ef197e7cad5c91415f524aaa51e004be5a9b35a1b8a84ade196f73f"}, - {file = "orjson-3.10.12-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43509843990439b05f848539d6f6198d4ac86ff01dd024b2f9a795c0daeeab60"}, - {file = "orjson-3.10.12-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f72e27a62041cfb37a3de512247ece9f240a561e6c8662276beaf4d53d406db4"}, - {file = "orjson-3.10.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a904f9572092bb6742ab7c16c623f0cdccbad9eeb2d14d4aa06284867bddd31"}, - {file = "orjson-3.10.12-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:855c0833999ed5dc62f64552db26f9be767434917d8348d77bacaab84f787d7b"}, - {file = "orjson-3.10.12-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:897830244e2320f6184699f598df7fb9db9f5087d6f3f03666ae89d607e4f8ed"}, - {file = "orjson-3.10.12-cp38-cp38-musllinux_1_2_armv7l.whl", hash = "sha256:0b32652eaa4a7539f6f04abc6243619c56f8530c53bf9b023e1269df5f7816dd"}, - {file = "orjson-3.10.12-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:36b4aa31e0f6a1aeeb6f8377769ca5d125db000f05c20e54163aef1d3fe8e833"}, - {file = "orjson-3.10.12-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:5535163054d6cbf2796f93e4f0dbc800f61914c0e3c4ed8499cf6ece22b4a3da"}, - {file = "orjson-3.10.12-cp38-none-win32.whl", hash = "sha256:90a5551f6f5a5fa07010bf3d0b4ca2de21adafbbc0af6cb700b63cd767266cb9"}, - {file = "orjson-3.10.12-cp38-none-win_amd64.whl", hash = "sha256:703a2fb35a06cdd45adf5d733cf613cbc0cb3ae57643472b16bc22d325b5fb6c"}, - {file = "orjson-3.10.12-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f29de3ef71a42a5822765def1febfb36e0859d33abf5c2ad240acad5c6a1b78d"}, - {file = 
"orjson-3.10.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de365a42acc65d74953f05e4772c974dad6c51cfc13c3240899f534d611be967"}, - {file = "orjson-3.10.12-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:91a5a0158648a67ff0004cb0df5df7dcc55bfc9ca154d9c01597a23ad54c8d0c"}, - {file = "orjson-3.10.12-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c47ce6b8d90fe9646a25b6fb52284a14ff215c9595914af63a5933a49972ce36"}, - {file = "orjson-3.10.12-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0eee4c2c5bfb5c1b47a5db80d2ac7aaa7e938956ae88089f098aff2c0f35d5d8"}, - {file = "orjson-3.10.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35d3081bbe8b86587eb5c98a73b97f13d8f9fea685cf91a579beddacc0d10566"}, - {file = "orjson-3.10.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73c23a6e90383884068bc2dba83d5222c9fcc3b99a0ed2411d38150734236755"}, - {file = "orjson-3.10.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5472be7dc3269b4b52acba1433dac239215366f89dc1d8d0e64029abac4e714e"}, - {file = "orjson-3.10.12-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:7319cda750fca96ae5973efb31b17d97a5c5225ae0bc79bf5bf84df9e1ec2ab6"}, - {file = "orjson-3.10.12-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:74d5ca5a255bf20b8def6a2b96b1e18ad37b4a122d59b154c458ee9494377f80"}, - {file = "orjson-3.10.12-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ff31d22ecc5fb85ef62c7d4afe8301d10c558d00dd24274d4bbe464380d3cd69"}, - {file = "orjson-3.10.12-cp39-none-win32.whl", hash = "sha256:c22c3ea6fba91d84fcb4cda30e64aff548fcf0c44c876e681f47d61d24b12e6b"}, - {file = "orjson-3.10.12-cp39-none-win_amd64.whl", hash = "sha256:be604f60d45ace6b0b33dd990a66b4526f1a7a186ac411c942674625456ca548"}, - {file = "orjson-3.10.12.tar.gz", hash = "sha256:0a78bbda3aea0f9f079057ee1ee8a1ecf790d4f1af88dd67493c6b8ee52506ff"}, + {file = "orjson-3.10.13-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:1232c5e873a4d1638ef957c5564b4b0d6f2a6ab9e207a9b3de9de05a09d1d920"}, + {file = "orjson-3.10.13-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d26a0eca3035619fa366cbaf49af704c7cb1d4a0e6c79eced9f6a3f2437964b6"}, + {file = "orjson-3.10.13-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d4b6acd7c9c829895e50d385a357d4b8c3fafc19c5989da2bae11783b0fd4977"}, + {file = "orjson-3.10.13-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1884e53c6818686891cc6fc5a3a2540f2f35e8c76eac8dc3b40480fb59660b00"}, + {file = "orjson-3.10.13-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a428afb5720f12892f64920acd2eeb4d996595bf168a26dd9190115dbf1130d"}, + {file = "orjson-3.10.13-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba5b13b8739ce5b630c65cb1c85aedbd257bcc2b9c256b06ab2605209af75a2e"}, + {file = "orjson-3.10.13-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cab83e67f6aabda1b45882254b2598b48b80ecc112968fc6483fa6dae609e9f0"}, + {file = "orjson-3.10.13-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:62c3cc00c7e776c71c6b7b9c48c5d2701d4c04e7d1d7cdee3572998ee6dc57cc"}, + {file = "orjson-3.10.13-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:dc03db4922e75bbc870b03fc49734cefbd50fe975e0878327d200022210b82d8"}, + {file = "orjson-3.10.13-cp310-cp310-musllinux_1_2_i686.whl", hash = 
"sha256:22f1c9a30b43d14a041a6ea190d9eca8a6b80c4beb0e8b67602c82d30d6eec3e"}, + {file = "orjson-3.10.13-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b42f56821c29e697c68d7d421410d7c1d8f064ae288b525af6a50cf99a4b1200"}, + {file = "orjson-3.10.13-cp310-cp310-win32.whl", hash = "sha256:0dbf3b97e52e093d7c3e93eb5eb5b31dc7535b33c2ad56872c83f0160f943487"}, + {file = "orjson-3.10.13-cp310-cp310-win_amd64.whl", hash = "sha256:46c249b4e934453be4ff2e518cd1adcd90467da7391c7a79eaf2fbb79c51e8c7"}, + {file = "orjson-3.10.13-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a36c0d48d2f084c800763473020a12976996f1109e2fcb66cfea442fdf88047f"}, + {file = "orjson-3.10.13-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0065896f85d9497990731dfd4a9991a45b0a524baec42ef0a63c34630ee26fd6"}, + {file = "orjson-3.10.13-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:92b4ec30d6025a9dcdfe0df77063cbce238c08d0404471ed7a79f309364a3d19"}, + {file = "orjson-3.10.13-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a94542d12271c30044dadad1125ee060e7a2048b6c7034e432e116077e1d13d2"}, + {file = "orjson-3.10.13-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3723e137772639af8adb68230f2aa4bcb27c48b3335b1b1e2d49328fed5e244c"}, + {file = "orjson-3.10.13-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f00c7fb18843bad2ac42dc1ce6dd214a083c53f1e324a0fd1c8137c6436269b"}, + {file = "orjson-3.10.13-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0e2759d3172300b2f892dee85500b22fca5ac49e0c42cfff101aaf9c12ac9617"}, + {file = "orjson-3.10.13-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ee948c6c01f6b337589c88f8e0bb11e78d32a15848b8b53d3f3b6fea48842c12"}, + {file = "orjson-3.10.13-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:aa6fe68f0981fba0d4bf9cdc666d297a7cdba0f1b380dcd075a9a3dd5649a69e"}, + {file = "orjson-3.10.13-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:dbcd7aad6bcff258f6896abfbc177d54d9b18149c4c561114f47ebfe74ae6bfd"}, + {file = "orjson-3.10.13-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2149e2fcd084c3fd584881c7f9d7f9e5ad1e2e006609d8b80649655e0d52cd02"}, + {file = "orjson-3.10.13-cp311-cp311-win32.whl", hash = "sha256:89367767ed27b33c25c026696507c76e3d01958406f51d3a2239fe9e91959df2"}, + {file = "orjson-3.10.13-cp311-cp311-win_amd64.whl", hash = "sha256:dca1d20f1af0daff511f6e26a27354a424f0b5cf00e04280279316df0f604a6f"}, + {file = "orjson-3.10.13-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a3614b00621c77f3f6487792238f9ed1dd8a42f2ec0e6540ee34c2d4e6db813a"}, + {file = "orjson-3.10.13-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c976bad3996aa027cd3aef78aa57873f3c959b6c38719de9724b71bdc7bd14b"}, + {file = "orjson-3.10.13-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f74d878d1efb97a930b8a9f9898890067707d683eb5c7e20730030ecb3fb930"}, + {file = "orjson-3.10.13-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:33ef84f7e9513fb13b3999c2a64b9ca9c8143f3da9722fbf9c9ce51ce0d8076e"}, + {file = "orjson-3.10.13-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd2bcde107221bb9c2fa0c4aaba735a537225104173d7e19cf73f70b3126c993"}, + {file = "orjson-3.10.13-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:064b9dbb0217fd64a8d016a8929f2fae6f3312d55ab3036b00b1d17399ab2f3e"}, + {file = "orjson-3.10.13-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c0044b0b8c85a565e7c3ce0a72acc5d35cda60793edf871ed94711e712cb637d"}, + {file = "orjson-3.10.13-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7184f608ad563032e398f311910bc536e62b9fbdca2041be889afcbc39500de8"}, + {file = "orjson-3.10.13-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:d36f689e7e1b9b6fb39dbdebc16a6f07cbe994d3644fb1c22953020fc575935f"}, + {file = "orjson-3.10.13-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:54433e421618cd5873e51c0e9d0b9fb35f7bf76eb31c8eab20b3595bb713cd3d"}, + {file = "orjson-3.10.13-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e1ba0c5857dd743438acecc1cd0e1adf83f0a81fee558e32b2b36f89e40cee8b"}, + {file = "orjson-3.10.13-cp312-cp312-win32.whl", hash = "sha256:a42b9fe4b0114b51eb5cdf9887d8c94447bc59df6dbb9c5884434eab947888d8"}, + {file = "orjson-3.10.13-cp312-cp312-win_amd64.whl", hash = "sha256:3a7df63076435f39ec024bdfeb4c9767ebe7b49abc4949068d61cf4857fa6d6c"}, + {file = "orjson-3.10.13-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2cdaf8b028a976ebab837a2c27b82810f7fc76ed9fb243755ba650cc83d07730"}, + {file = "orjson-3.10.13-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a946796e390cbb803e069472de37f192b7a80f4ac82e16d6eb9909d9e39d56"}, + {file = "orjson-3.10.13-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7d64f1db5ecbc21eb83097e5236d6ab7e86092c1cd4c216c02533332951afc"}, + {file = "orjson-3.10.13-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:711878da48f89df194edd2ba603ad42e7afed74abcd2bac164685e7ec15f96de"}, + {file = "orjson-3.10.13-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:cf16f06cb77ce8baf844bc222dbcb03838f61d0abda2c3341400c2b7604e436e"}, + {file = "orjson-3.10.13-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:8257c3fb8dd7b0b446b5e87bf85a28e4071ac50f8c04b6ce2d38cb4abd7dff57"}, + {file = "orjson-3.10.13-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d9c3a87abe6f849a4a7ac8a8a1dede6320a4303d5304006b90da7a3cd2b70d2c"}, + {file = "orjson-3.10.13-cp313-cp313-win32.whl", hash = "sha256:527afb6ddb0fa3fe02f5d9fba4920d9d95da58917826a9be93e0242da8abe94a"}, + {file = "orjson-3.10.13-cp313-cp313-win_amd64.whl", hash = "sha256:b5f7c298d4b935b222f52d6c7f2ba5eafb59d690d9a3840b7b5c5cda97f6ec5c"}, + {file = "orjson-3.10.13-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e49333d1038bc03a25fdfe11c86360df9b890354bfe04215f1f54d030f33c342"}, + {file = "orjson-3.10.13-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:003721c72930dbb973f25c5d8e68d0f023d6ed138b14830cc94e57c6805a2eab"}, + {file = "orjson-3.10.13-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:63664bf12addb318dc8f032160e0f5dc17eb8471c93601e8f5e0d07f95003784"}, + {file = "orjson-3.10.13-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6066729cf9552d70de297b56556d14b4f49c8f638803ee3c90fd212fa43cc6af"}, + {file = "orjson-3.10.13-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8a1152e2761025c5d13b5e1908d4b1c57f3797ba662e485ae6f26e4e0c466388"}, + {file = "orjson-3.10.13-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69b21d91c5c5ef8a201036d207b1adf3aa596b930b6ca3c71484dd11386cf6c3"}, + {file = 
"orjson-3.10.13-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b12a63f48bb53dba8453d36ca2661f2330126d54e26c1661e550b32864b28ce3"}, + {file = "orjson-3.10.13-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a5a7624ab4d121c7e035708c8dd1f99c15ff155b69a1c0affc4d9d8b551281ba"}, + {file = "orjson-3.10.13-cp38-cp38-musllinux_1_2_armv7l.whl", hash = "sha256:0fee076134398d4e6cb827002468679ad402b22269510cf228301b787fdff5ae"}, + {file = "orjson-3.10.13-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ae537fcf330b3947e82c6ae4271e092e6cf16b9bc2cef68b14ffd0df1fa8832a"}, + {file = "orjson-3.10.13-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f81b26c03f5fb5f0d0ee48d83cea4d7bc5e67e420d209cc1a990f5d1c62f9be0"}, + {file = "orjson-3.10.13-cp38-cp38-win32.whl", hash = "sha256:0bc858086088b39dc622bc8219e73d3f246fb2bce70a6104abd04b3a080a66a8"}, + {file = "orjson-3.10.13-cp38-cp38-win_amd64.whl", hash = "sha256:3ca6f17467ebbd763f8862f1d89384a5051b461bb0e41074f583a0ebd7120e8e"}, + {file = "orjson-3.10.13-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4a11532cbfc2f5752c37e84863ef8435b68b0e6d459b329933294f65fa4bda1a"}, + {file = "orjson-3.10.13-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c96d2fb80467d1d0dfc4d037b4e1c0f84f1fe6229aa7fea3f070083acef7f3d7"}, + {file = "orjson-3.10.13-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dda4ba4d3e6f6c53b6b9c35266788053b61656a716a7fef5c884629c2a52e7aa"}, + {file = "orjson-3.10.13-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4f998bbf300690be881772ee9c5281eb9c0044e295bcd4722504f5b5c6092ff"}, + {file = "orjson-3.10.13-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1cc42ed75b585c0c4dc5eb53a90a34ccb493c09a10750d1a1f9b9eff2bd12"}, + {file = "orjson-3.10.13-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03b0f29d485411e3c13d79604b740b14e4e5fb58811743f6f4f9693ee6480a8f"}, + {file = "orjson-3.10.13-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:233aae4474078d82f425134bb6a10fb2b3fc5a1a1b3420c6463ddd1b6a97eda8"}, + {file = "orjson-3.10.13-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:e384e330a67cf52b3597ee2646de63407da6f8fc9e9beec3eaaaef5514c7a1c9"}, + {file = "orjson-3.10.13-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:4222881d0aab76224d7b003a8e5fdae4082e32c86768e0e8652de8afd6c4e2c1"}, + {file = "orjson-3.10.13-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e400436950ba42110a20c50c80dff4946c8e3ec09abc1c9cf5473467e83fd1c5"}, + {file = "orjson-3.10.13-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:f47c9e7d224b86ffb086059cdcf634f4b3f32480f9838864aa09022fe2617ce2"}, + {file = "orjson-3.10.13-cp39-cp39-win32.whl", hash = "sha256:a9ecea472f3eb653e1c0a3d68085f031f18fc501ea392b98dcca3e87c24f9ebe"}, + {file = "orjson-3.10.13-cp39-cp39-win_amd64.whl", hash = "sha256:5385935a73adce85cc7faac9d396683fd813566d3857fa95a0b521ef84a5b588"}, + {file = "orjson-3.10.13.tar.gz", hash = "sha256:eb9bfb14ab8f68d9d9492d4817ae497788a15fd7da72e14dfabc289c3bb088ec"}, ] [package.source] @@ -7788,6 +7989,26 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "pandas-stubs" +version = "2.2.3.241126" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +files = [ + {file = "pandas_stubs-2.2.3.241126-py3-none-any.whl", hash = 
"sha256:74aa79c167af374fe97068acc90776c0ebec5266a6e5c69fe11e9c2cf51f2267"}, + {file = "pandas_stubs-2.2.3.241126.tar.gz", hash = "sha256:cf819383c6d9ae7d4dabf34cd47e1e45525bb2f312e6ad2939c2c204cb708acd"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "pathos" version = "0.3.3" @@ -7872,93 +8093,89 @@ reference = "aliyun" [[package]] name = "pillow" -version = "11.0.0" +version = "11.1.0" description = "Python Imaging Library (Fork)" optional = false python-versions = ">=3.9" files = [ - {file = "pillow-11.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947"}, - {file = "pillow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba"}, - {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086"}, - {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9"}, - {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488"}, - {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f"}, - {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb"}, - {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97"}, - {file = "pillow-11.0.0-cp310-cp310-win32.whl", hash = "sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50"}, - {file = "pillow-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c"}, - {file = "pillow-11.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1"}, - {file = "pillow-11.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc"}, - {file = "pillow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a"}, - {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3"}, - {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5"}, - {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b"}, - {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa"}, - {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306"}, - {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9"}, - {file = "pillow-11.0.0-cp311-cp311-win32.whl", hash = 
"sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5"}, - {file = "pillow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291"}, - {file = "pillow-11.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9"}, - {file = "pillow-11.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923"}, - {file = "pillow-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903"}, - {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4"}, - {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f"}, - {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9"}, - {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7"}, - {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6"}, - {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc"}, - {file = "pillow-11.0.0-cp312-cp312-win32.whl", hash = "sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6"}, - {file = "pillow-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47"}, - {file = "pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25"}, - {file = "pillow-11.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699"}, - {file = "pillow-11.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38"}, - {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2"}, - {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2"}, - {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527"}, - {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa"}, - {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f"}, - {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb"}, - {file = "pillow-11.0.0-cp313-cp313-win32.whl", hash = "sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798"}, - {file = "pillow-11.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de"}, - {file = "pillow-11.0.0-cp313-cp313-win_arm64.whl", hash = 
"sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84"}, - {file = "pillow-11.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b"}, - {file = "pillow-11.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003"}, - {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2"}, - {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a"}, - {file = "pillow-11.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8"}, - {file = "pillow-11.0.0-cp313-cp313t-win32.whl", hash = "sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8"}, - {file = "pillow-11.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904"}, - {file = "pillow-11.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3"}, - {file = "pillow-11.0.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba"}, - {file = "pillow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a"}, - {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916"}, - {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d"}, - {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7"}, - {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e"}, - {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f"}, - {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae"}, - {file = "pillow-11.0.0-cp39-cp39-win32.whl", hash = "sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4"}, - {file = "pillow-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd"}, - {file = "pillow-11.0.0-cp39-cp39-win_arm64.whl", hash = "sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd"}, - {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2"}, - {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2"}, - {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b"}, - {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2"}, - {file = 
"pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830"}, - {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734"}, - {file = "pillow-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316"}, - {file = "pillow-11.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06"}, - {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273"}, - {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790"}, - {file = "pillow-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944"}, - {file = "pillow-11.0.0.tar.gz", hash = "sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739"}, + {file = "pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8"}, + {file = "pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07"}, + {file = "pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482"}, + {file = "pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e"}, + {file = "pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269"}, + {file = "pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49"}, + {file = "pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a"}, + {file = "pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65"}, + {file = "pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457"}, + {file = "pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = 
"sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6"}, + {file = "pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1"}, + {file = "pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2"}, + {file = "pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96"}, + {file = "pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f"}, + {file = "pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761"}, + {file = "pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71"}, + {file = "pillow-11.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2062ffb1d36544d42fcaa277b069c88b01bb7298f4efa06731a7fd6cc290b81a"}, + {file = "pillow-11.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a85b653980faad27e88b141348707ceeef8a1186f75ecc600c395dcac19f385b"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9409c080586d1f683df3f184f20e36fb647f2e0bc3988094d4fd8c9f4eb1b3b3"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fdadc077553621911f27ce206ffcbec7d3f8d7b50e0da39f10997e8e2bb7f6a"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:93a18841d09bcdd774dcdc308e4537e1f867b3dec059c131fde0327899734aa1"}, + {file = "pillow-11.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:9aa9aeddeed452b2f616ff5507459e7bab436916ccb10961c4a382cd3e03f47f"}, + {file = "pillow-11.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3cdcdb0b896e981678eee140d882b70092dac83ac1cdf6b3a60e2216a73f2b91"}, + {file = "pillow-11.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:36ba10b9cb413e7c7dfa3e189aba252deee0602c86c309799da5a74009ac7a1c"}, + {file = "pillow-11.1.0-cp312-cp312-win32.whl", hash = "sha256:cfd5cd998c2e36a862d0e27b2df63237e67273f2fc78f47445b14e73a810e7e6"}, + {file = "pillow-11.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a697cd8ba0383bba3d2d3ada02b34ed268cb548b369943cd349007730c92bddf"}, + {file = "pillow-11.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:4dd43a78897793f60766563969442020e90eb7847463eca901e41ba186a7d4a5"}, + {file = "pillow-11.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ae98e14432d458fc3de11a77ccb3ae65ddce70f730e7c76140653048c71bfcbc"}, + {file = "pillow-11.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cc1331b6d5a6e144aeb5e626f4375f5b7ae9934ba620c0ac6b3e43d5e683a0f0"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:758e9d4ef15d3560214cddbc97b8ef3ef86ce04d62ddac17ad39ba87e89bd3b1"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b523466b1a31d0dcef7c5be1f20b942919b62fd6e9a9be199d035509cbefc0ec"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:9044b5e4f7083f209c4e35aa5dd54b1dd5b112b108648f5c902ad586d4f945c5"}, + {file = "pillow-11.1.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:3764d53e09cdedd91bee65c2527815d315c6b90d7b8b79759cc48d7bf5d4f114"}, + {file = "pillow-11.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:31eba6bbdd27dde97b0174ddf0297d7a9c3a507a8a1480e1e60ef914fe23d352"}, + {file = "pillow-11.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b5d658fbd9f0d6eea113aea286b21d3cd4d3fd978157cbf2447a6035916506d3"}, + {file = "pillow-11.1.0-cp313-cp313-win32.whl", hash = "sha256:f86d3a7a9af5d826744fabf4afd15b9dfef44fe69a98541f666f66fbb8d3fef9"}, + {file = "pillow-11.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:593c5fd6be85da83656b93ffcccc2312d2d149d251e98588b14fbc288fd8909c"}, + {file = "pillow-11.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:11633d58b6ee5733bde153a8dafd25e505ea3d32e261accd388827ee987baf65"}, + {file = "pillow-11.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:70ca5ef3b3b1c4a0812b5c63c57c23b63e53bc38e758b37a951e5bc466449861"}, + {file = "pillow-11.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8000376f139d4d38d6851eb149b321a52bb8893a88dae8ee7d95840431977081"}, + {file = "pillow-11.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ee85f0696a17dd28fbcfceb59f9510aa71934b483d1f5601d1030c3c8304f3c"}, + {file = "pillow-11.1.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:dd0e081319328928531df7a0e63621caf67652c8464303fd102141b785ef9547"}, + {file = "pillow-11.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e63e4e5081de46517099dc30abe418122f54531a6ae2ebc8680bcd7096860eab"}, + {file = "pillow-11.1.0-cp313-cp313t-win32.whl", hash = "sha256:dda60aa465b861324e65a78c9f5cf0f4bc713e4309f83bc387be158b077963d9"}, + {file = "pillow-11.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ad5db5781c774ab9a9b2c4302bbf0c1014960a0a7be63278d13ae6fdf88126fe"}, + {file = "pillow-11.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:67cd427c68926108778a9005f2a04adbd5e67c442ed21d95389fe1d595458756"}, + {file = "pillow-11.1.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:bf902d7413c82a1bfa08b06a070876132a5ae6b2388e2712aab3a7cbc02205c6"}, + {file = "pillow-11.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c1eec9d950b6fe688edee07138993e54ee4ae634c51443cfb7c1e7613322718e"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e275ee4cb11c262bd108ab2081f750db2a1c0b8c12c1897f27b160c8bd57bbc"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4db853948ce4e718f2fc775b75c37ba2efb6aaea41a1a5fc57f0af59eee774b2"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:ab8a209b8485d3db694fa97a896d96dd6533d63c22829043fd9de627060beade"}, + {file = "pillow-11.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:54251ef02a2309b5eec99d151ebf5c9904b77976c8abdcbce7891ed22df53884"}, + {file = "pillow-11.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5bb94705aea800051a743aa4874bb1397d4695fb0583ba5e425ee0328757f196"}, + {file = "pillow-11.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:89dbdb3e6e9594d512780a5a1c42801879628b38e3efc7038094430844e271d8"}, + {file = "pillow-11.1.0-cp39-cp39-win32.whl", hash = "sha256:e5449ca63da169a2e6068dd0e2fcc8d91f9558aba89ff6d02121ca8ab11e79e5"}, + {file = "pillow-11.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:3362c6ca227e65c54bf71a5f88b3d4565ff1bcbc63ae72c34b07bbb1cc59a43f"}, + {file = "pillow-11.1.0-cp39-cp39-win_arm64.whl", hash = "sha256:b20be51b37a75cc54c2c55def3fa2c65bb94ba859dde241cd0a4fd302de5ae0a"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90"}, + 
{file = "pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73"}, + {file = "pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0"}, + {file = "pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20"}, ] [package.extras] docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] fpx = ["olefile"] mic = ["olefile"] -tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout", "trove-classifiers (>=2024.10.12)"] typing = ["typing-extensions"] xmp = ["defusedxml"] @@ -8091,13 +8308,13 @@ reference = "aliyun" [[package]] name = "posthog" -version = "3.7.4" +version = "3.7.5" description = "Integrate PostHog into any python application." 
optional = false python-versions = "*" files = [ - {file = "posthog-3.7.4-py2.py3-none-any.whl", hash = "sha256:21c18c6bf43b2de303ea4cd6e95804cc0f24c20cb2a96a8fd09da2ed50b62faa"}, - {file = "posthog-3.7.4.tar.gz", hash = "sha256:19384bd09d330f9787a7e2446aba14c8057ece56144970ea2791072d4e40cd36"}, + {file = "posthog-3.7.5-py2.py3-none-any.whl", hash = "sha256:022132c17069dde03c5c5904e2ae1b9bd68d5059cbc5a8dffc5c1537a1b71cb5"}, + {file = "posthog-3.7.5.tar.gz", hash = "sha256:8ba40ab623da35db72715fc87fe7dccb7fc272ced92581fe31db2d4dbe7ad761"}, ] [package.dependencies] @@ -8154,20 +8371,20 @@ reference = "aliyun" [[package]] name = "primp" -version = "0.9.2" +version = "0.10.0" description = "HTTP client that can impersonate web browsers, mimicking their headers and `TLS/JA3/JA4/HTTP2` fingerprints" optional = false python-versions = ">=3.8" files = [ - {file = "primp-0.9.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:a3179640e633be843ed5daba5c4e3086ad91f77c7bb40a9db06326f28d56b12b"}, - {file = "primp-0.9.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94a5da8ba25f74152b43bc16a7591dfb5d7d30a5827dc0a0f96a956f7d3616be"}, - {file = "primp-0.9.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0668c0abb6d56fc8b0a918179b1d0f68e7267c1dc632e2b683c618317e13143f"}, - {file = "primp-0.9.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:a9c29a4b8eabfc28a1746d2fe93d33b9fcf2e81e642dd0e3eaecede60cc36b7d"}, - {file = "primp-0.9.2-cp38-abi3-manylinux_2_34_armv7l.whl", hash = "sha256:04d499308a101b06b40f5fda1bdc795db5731cd0dfbb1a8873f4acd07c085b1d"}, - {file = "primp-0.9.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4cd5daf39034a0a8c96cdc0c4c306184c6f2b1b2a0b39dc3294d79ed28a6f7fe"}, - {file = "primp-0.9.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8d14653434837eb431b3cf7ca006647d7a196906e48bba96bb600ba2ba70bcdc"}, - {file = "primp-0.9.2-cp38-abi3-win_amd64.whl", hash = "sha256:80d9f07564dc9b25b1a9676df770561418557c124fedecae84f6491a1974b61d"}, - {file = "primp-0.9.2.tar.gz", hash = "sha256:5b95666c25b9107eab3c05a89cb7b1748d5122e57c57b25bfc3249d525c45300"}, + {file = "primp-0.10.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:7a91a089bf2962b5b56c8d83d09535eb81cf55b53c09d83208b9e5a715cf2c17"}, + {file = "primp-0.10.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0128453cce81552f7aa6ac2bf9b8741b7816cdb2d10536e62c77daaf6483b9af"}, + {file = "primp-0.10.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a959e9a83cff0ae7a85a02cc183e4db636f69ff41dddb7c4e32f997924923417"}, + {file = "primp-0.10.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8e711cfa019fa9bdc0cba4d5d596f319c884a4329e505bd73e92eee0b024061a"}, + {file = "primp-0.10.0-cp38-abi3-manylinux_2_34_armv7l.whl", hash = "sha256:b859336d9a35669b68a29c5d8f050e0dca380452dabf6c9667bb8599f010d164"}, + {file = "primp-0.10.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dc875cc9a733fe3e6344a37f2b5888e0a9605bb37807fc3009f3b03786408f34"}, + {file = "primp-0.10.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a27c5d997c37bf8237963c11e376eaa66e7eccee39164e3e259a1c3767c304d6"}, + {file = "primp-0.10.0-cp38-abi3-win_amd64.whl", hash = "sha256:7fe94c3164c2efffff08f7f54c018ac445112961b3ce4f4f499315ba0a9d1ef3"}, + {file = "primp-0.10.0.tar.gz", hash = "sha256:93142590a5a1958240ee5b74faaf2f55185ed499ccaabc622d71cb0cc8a47a0b"}, ] [package.extras] @@ -8197,6 +8414,102 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = 
"propcache" +version = "0.2.1" +description = "Accelerated property cache" +optional = false +python-versions = ">=3.9" +files = [ + {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6b3f39a85d671436ee3d12c017f8fdea38509e4f25b28eb25877293c98c243f6"}, + {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d51fbe4285d5db5d92a929e3e21536ea3dd43732c5b177c7ef03f918dff9f2"}, + {file = "propcache-0.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6445804cf4ec763dc70de65a3b0d9954e868609e83850a47ca4f0cb64bd79fea"}, + {file = "propcache-0.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9479aa06a793c5aeba49ce5c5692ffb51fcd9a7016e017d555d5e2b0045d212"}, + {file = "propcache-0.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9631c5e8b5b3a0fda99cb0d29c18133bca1e18aea9effe55adb3da1adef80d3"}, + {file = "propcache-0.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3156628250f46a0895f1f36e1d4fbe062a1af8718ec3ebeb746f1d23f0c5dc4d"}, + {file = "propcache-0.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b6fb63ae352e13748289f04f37868099e69dba4c2b3e271c46061e82c745634"}, + {file = "propcache-0.2.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:887d9b0a65404929641a9fabb6452b07fe4572b269d901d622d8a34a4e9043b2"}, + {file = "propcache-0.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a96dc1fa45bd8c407a0af03b2d5218392729e1822b0c32e62c5bf7eeb5fb3958"}, + {file = "propcache-0.2.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a7e65eb5c003a303b94aa2c3852ef130230ec79e349632d030e9571b87c4698c"}, + {file = "propcache-0.2.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:999779addc413181912e984b942fbcc951be1f5b3663cd80b2687758f434c583"}, + {file = "propcache-0.2.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:19a0f89a7bb9d8048d9c4370c9c543c396e894c76be5525f5e1ad287f1750ddf"}, + {file = "propcache-0.2.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1ac2f5fe02fa75f56e1ad473f1175e11f475606ec9bd0be2e78e4734ad575034"}, + {file = "propcache-0.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:574faa3b79e8ebac7cb1d7930f51184ba1ccf69adfdec53a12f319a06030a68b"}, + {file = "propcache-0.2.1-cp310-cp310-win32.whl", hash = "sha256:03ff9d3f665769b2a85e6157ac8b439644f2d7fd17615a82fa55739bc97863f4"}, + {file = "propcache-0.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:2d3af2e79991102678f53e0dbf4c35de99b6b8b58f29a27ca0325816364caaba"}, + {file = "propcache-0.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1ffc3cca89bb438fb9c95c13fc874012f7b9466b89328c3c8b1aa93cdcfadd16"}, + {file = "propcache-0.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f174bbd484294ed9fdf09437f889f95807e5f229d5d93588d34e92106fbf6717"}, + {file = "propcache-0.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:70693319e0b8fd35dd863e3e29513875eb15c51945bf32519ef52927ca883bc3"}, + {file = "propcache-0.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b480c6a4e1138e1aa137c0079b9b6305ec6dcc1098a8ca5196283e8a49df95a9"}, + {file = "propcache-0.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d27b84d5880f6d8aa9ae3edb253c59d9f6642ffbb2c889b78b60361eed449787"}, + {file = "propcache-0.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:857112b22acd417c40fa4595db2fe28ab900c8c5fe4670c7989b1c0230955465"}, + {file = "propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf6c4150f8c0e32d241436526f3c3f9cbd34429492abddbada2ffcff506c51af"}, + {file = "propcache-0.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66d4cfda1d8ed687daa4bc0274fcfd5267873db9a5bc0418c2da19273040eeb7"}, + {file = "propcache-0.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c2f992c07c0fca81655066705beae35fc95a2fa7366467366db627d9f2ee097f"}, + {file = "propcache-0.2.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:4a571d97dbe66ef38e472703067021b1467025ec85707d57e78711c085984e54"}, + {file = "propcache-0.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bb6178c241278d5fe853b3de743087be7f5f4c6f7d6d22a3b524d323eecec505"}, + {file = "propcache-0.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ad1af54a62ffe39cf34db1aa6ed1a1873bd548f6401db39d8e7cd060b9211f82"}, + {file = "propcache-0.2.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:e7048abd75fe40712005bcfc06bb44b9dfcd8e101dda2ecf2f5aa46115ad07ca"}, + {file = "propcache-0.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:160291c60081f23ee43d44b08a7e5fb76681221a8e10b3139618c5a9a291b84e"}, + {file = "propcache-0.2.1-cp311-cp311-win32.whl", hash = "sha256:819ce3b883b7576ca28da3861c7e1a88afd08cc8c96908e08a3f4dd64a228034"}, + {file = "propcache-0.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:edc9fc7051e3350643ad929df55c451899bb9ae6d24998a949d2e4c87fb596d3"}, + {file = "propcache-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:081a430aa8d5e8876c6909b67bd2d937bfd531b0382d3fdedb82612c618bc41a"}, + {file = "propcache-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2ccec9ac47cf4e04897619c0e0c1a48c54a71bdf045117d3a26f80d38ab1fb0"}, + {file = "propcache-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:14d86fe14b7e04fa306e0c43cdbeebe6b2c2156a0c9ce56b815faacc193e320d"}, + {file = "propcache-0.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:049324ee97bb67285b49632132db351b41e77833678432be52bdd0289c0e05e4"}, + {file = "propcache-0.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cd9a1d071158de1cc1c71a26014dcdfa7dd3d5f4f88c298c7f90ad6f27bb46d"}, + {file = "propcache-0.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98110aa363f1bb4c073e8dcfaefd3a5cea0f0834c2aab23dda657e4dab2f53b5"}, + {file = "propcache-0.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:647894f5ae99c4cf6bb82a1bb3a796f6e06af3caa3d32e26d2350d0e3e3faf24"}, + {file = "propcache-0.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfd3223c15bebe26518d58ccf9a39b93948d3dcb3e57a20480dfdd315356baff"}, + {file = "propcache-0.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d71264a80f3fcf512eb4f18f59423fe82d6e346ee97b90625f283df56aee103f"}, + {file = "propcache-0.2.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e73091191e4280403bde6c9a52a6999d69cdfde498f1fdf629105247599b57ec"}, + {file = "propcache-0.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3935bfa5fede35fb202c4b569bb9c042f337ca4ff7bd540a0aa5e37131659348"}, + {file = "propcache-0.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f508b0491767bb1f2b87fdfacaba5f7eddc2f867740ec69ece6d1946d29029a6"}, + {file = 
"propcache-0.2.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:1672137af7c46662a1c2be1e8dc78cb6d224319aaa40271c9257d886be4363a6"}, + {file = "propcache-0.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b74c261802d3d2b85c9df2dfb2fa81b6f90deeef63c2db9f0e029a3cac50b518"}, + {file = "propcache-0.2.1-cp312-cp312-win32.whl", hash = "sha256:d09c333d36c1409d56a9d29b3a1b800a42c76a57a5a8907eacdbce3f18768246"}, + {file = "propcache-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:c214999039d4f2a5b2073ac506bba279945233da8c786e490d411dfc30f855c1"}, + {file = "propcache-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:aca405706e0b0a44cc6bfd41fbe89919a6a56999157f6de7e182a990c36e37bc"}, + {file = "propcache-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:12d1083f001ace206fe34b6bdc2cb94be66d57a850866f0b908972f90996b3e9"}, + {file = "propcache-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d93f3307ad32a27bda2e88ec81134b823c240aa3abb55821a8da553eed8d9439"}, + {file = "propcache-0.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba278acf14471d36316159c94a802933d10b6a1e117b8554fe0d0d9b75c9d536"}, + {file = "propcache-0.2.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4e6281aedfca15301c41f74d7005e6e3f4ca143584ba696ac69df4f02f40d629"}, + {file = "propcache-0.2.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b750a8e5a1262434fb1517ddf64b5de58327f1adc3524a5e44c2ca43305eb0b"}, + {file = "propcache-0.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf72af5e0fb40e9babf594308911436c8efde3cb5e75b6f206c34ad18be5c052"}, + {file = "propcache-0.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2d0a12018b04f4cb820781ec0dffb5f7c7c1d2a5cd22bff7fb055a2cb19ebce"}, + {file = "propcache-0.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e800776a79a5aabdb17dcc2346a7d66d0777e942e4cd251defeb084762ecd17d"}, + {file = "propcache-0.2.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:4160d9283bd382fa6c0c2b5e017acc95bc183570cd70968b9202ad6d8fc48dce"}, + {file = "propcache-0.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:30b43e74f1359353341a7adb783c8f1b1c676367b011709f466f42fda2045e95"}, + {file = "propcache-0.2.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:58791550b27d5488b1bb52bc96328456095d96206a250d28d874fafe11b3dfaf"}, + {file = "propcache-0.2.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:0f022d381747f0dfe27e99d928e31bc51a18b65bb9e481ae0af1380a6725dd1f"}, + {file = "propcache-0.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:297878dc9d0a334358f9b608b56d02e72899f3b8499fc6044133f0d319e2ec30"}, + {file = "propcache-0.2.1-cp313-cp313-win32.whl", hash = "sha256:ddfab44e4489bd79bda09d84c430677fc7f0a4939a73d2bba3073036f487a0a6"}, + {file = "propcache-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:556fc6c10989f19a179e4321e5d678db8eb2924131e64652a51fe83e4c3db0e1"}, + {file = "propcache-0.2.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6a9a8c34fb7bb609419a211e59da8887eeca40d300b5ea8e56af98f6fbbb1541"}, + {file = "propcache-0.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ae1aa1cd222c6d205853b3013c69cd04515f9d6ab6de4b0603e2e1c33221303e"}, + {file = "propcache-0.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:accb6150ce61c9c4b7738d45550806aa2b71c7668c6942f17b0ac182b6142fd4"}, + {file = 
"propcache-0.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eee736daafa7af6d0a2dc15cc75e05c64f37fc37bafef2e00d77c14171c2097"}, + {file = "propcache-0.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7a31fc1e1bd362874863fdeed71aed92d348f5336fd84f2197ba40c59f061bd"}, + {file = "propcache-0.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba4cfa1052819d16699e1d55d18c92b6e094d4517c41dd231a8b9f87b6fa681"}, + {file = "propcache-0.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f089118d584e859c62b3da0892b88a83d611c2033ac410e929cb6754eec0ed16"}, + {file = "propcache-0.2.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:781e65134efaf88feb447e8c97a51772aa75e48b794352f94cb7ea717dedda0d"}, + {file = "propcache-0.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31f5af773530fd3c658b32b6bdc2d0838543de70eb9a2156c03e410f7b0d3aae"}, + {file = "propcache-0.2.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:a7a078f5d37bee6690959c813977da5291b24286e7b962e62a94cec31aa5188b"}, + {file = "propcache-0.2.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:cea7daf9fc7ae6687cf1e2c049752f19f146fdc37c2cc376e7d0032cf4f25347"}, + {file = "propcache-0.2.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:8b3489ff1ed1e8315674d0775dc7d2195fb13ca17b3808721b54dbe9fd020faf"}, + {file = "propcache-0.2.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9403db39be1393618dd80c746cb22ccda168efce239c73af13c3763ef56ffc04"}, + {file = "propcache-0.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5d97151bc92d2b2578ff7ce779cdb9174337390a535953cbb9452fb65164c587"}, + {file = "propcache-0.2.1-cp39-cp39-win32.whl", hash = "sha256:9caac6b54914bdf41bcc91e7eb9147d331d29235a7c967c150ef5df6464fd1bb"}, + {file = "propcache-0.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:92fc4500fcb33899b05ba73276dfb684a20d31caa567b7cb5252d48f896a91b1"}, + {file = "propcache-0.2.1-py3-none-any.whl", hash = "sha256:52277518d6aae65536e9cea52d4e7fd2f7a66f4aa2d30ed3f2fcea620ace3c54"}, + {file = "propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64"}, +] + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "proto-plus" version = "1.25.0" @@ -8279,6 +8592,21 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "psycogreen" +version = "1.0.2" +description = "psycopg2 integration with coroutine libraries" +optional = false +python-versions = "*" +files = [ + {file = "psycogreen-1.0.2.tar.gz", hash = "sha256:c429845a8a49cf2f76b71265008760bcd7c7c77d80b806db4dc81116dbcd130d"}, +] + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "psycopg2-binary" version = "2.9.10" @@ -8769,13 +9097,13 @@ reference = "aliyun" [[package]] name = "pygments" -version = "2.18.0" +version = "2.19.1" description = "Pygments is a syntax highlighting package written in Python." 
optional = false python-versions = ">=3.8" files = [ - {file = "pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a"}, - {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, + {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, + {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, ] [package.extras] @@ -8813,21 +9141,21 @@ reference = "aliyun" [[package]] name = "pymilvus" -version = "2.4.9" +version = "2.5.3" description = "Python Sdk for Milvus" optional = false python-versions = ">=3.8" files = [ - {file = "pymilvus-2.4.9-py3-none-any.whl", hash = "sha256:45313607d2c164064bdc44e0f933cb6d6afa92e9efcc7f357c5240c57db58fbe"}, - {file = "pymilvus-2.4.9.tar.gz", hash = "sha256:0937663700007c23a84cfc0656160b301f6ff9247aaec4c96d599a6b43572136"}, + {file = "pymilvus-2.5.3-py3-none-any.whl", hash = "sha256:64ca63594284586937274800be27a402f3be2d078130bf81d94ab8d7798ac9c8"}, + {file = "pymilvus-2.5.3.tar.gz", hash = "sha256:68bc3797b7a14c494caf116cee888894ffd6eba7b96a3ac841be85d60694cc5d"}, ] [package.dependencies] -environs = "<=9.5.0" -grpcio = ">=1.49.1" -milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""} +grpcio = ">=1.49.1,<=1.67.1" +milvus-lite = {version = ">=2.4.0", markers = "sys_platform != \"win32\""} pandas = ">=1.2.4" protobuf = ">=3.20.0" +python-dotenv = ">=1.0.1,<2.0.0" setuptools = ">69" ujson = ">=2.0.0" @@ -8945,13 +9273,13 @@ reference = "aliyun" [[package]] name = "pyparsing" -version = "3.2.0" +version = "3.2.1" description = "pyparsing module - Classes and methods to define and execute parsing grammars" optional = false python-versions = ">=3.9" files = [ - {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"}, - {file = "pyparsing-3.2.0.tar.gz", hash = "sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"}, + {file = "pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1"}, + {file = "pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a"}, ] [package.extras] @@ -9331,13 +9659,13 @@ reference = "aliyun" [[package]] name = "python-dotenv" -version = "1.0.0" +version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" files = [ - {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"}, - {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"}, + {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, + {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, ] [package.extras] @@ -9772,20 +10100,20 @@ reference = "aliyun" [[package]] name = "realtime" -version = "2.0.2" +version = "2.1.0" description = "" optional = false python-versions = ">=3.9,<4.0" files = [ - {file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"}, - {file = "realtime-2.0.2.tar.gz", 
hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"}, + {file = "realtime-2.1.0-py3-none-any.whl", hash = "sha256:e2d4f28bb2a08c1cf80e40fbf31e6116544ad29d67dd4093093e511ad738708c"}, + {file = "realtime-2.1.0.tar.gz", hash = "sha256:ca3ae6be47667a3cf3a307fec982ec1bf60313c38a8e29f016ab0380b76d7adb"}, ] [package.dependencies] -aiohttp = ">=3.10.2,<4.0.0" +aiohttp = ">=3.11.11,<4.0.0" python-dateutil = ">=2.8.1,<3.0.0" typing-extensions = ">=4.12.2,<5.0.0" -websockets = ">=11,<13" +websockets = ">=11,<14" [package.source] type = "legacy" @@ -10256,29 +10584,29 @@ reference = "aliyun" [[package]] name = "ruff" -version = "0.8.4" +version = "0.8.6" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.8.4-py3-none-linux_armv6l.whl", hash = "sha256:58072f0c06080276804c6a4e21a9045a706584a958e644353603d36ca1eb8a60"}, - {file = "ruff-0.8.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ffb60904651c00a1e0b8df594591770018a0f04587f7deeb3838344fe3adabac"}, - {file = "ruff-0.8.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ddf5d654ac0d44389f6bf05cee4caeefc3132a64b58ea46738111d687352296"}, - {file = "ruff-0.8.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e248b1f0fa2749edd3350a2a342b67b43a2627434c059a063418e3d375cfe643"}, - {file = "ruff-0.8.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bf197b98ed86e417412ee3b6c893f44c8864f816451441483253d5ff22c0e81e"}, - {file = "ruff-0.8.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c41319b85faa3aadd4d30cb1cffdd9ac6b89704ff79f7664b853785b48eccdf3"}, - {file = "ruff-0.8.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:9f8402b7c4f96463f135e936d9ab77b65711fcd5d72e5d67597b543bbb43cf3f"}, - {file = "ruff-0.8.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4e56b3baa9c23d324ead112a4fdf20db9a3f8f29eeabff1355114dd96014604"}, - {file = "ruff-0.8.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:736272574e97157f7edbbb43b1d046125fce9e7d8d583d5d65d0c9bf2c15addf"}, - {file = "ruff-0.8.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5fe710ab6061592521f902fca7ebcb9fabd27bc7c57c764298b1c1f15fff720"}, - {file = "ruff-0.8.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:13e9ec6d6b55f6da412d59953d65d66e760d583dd3c1c72bf1f26435b5bfdbae"}, - {file = "ruff-0.8.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:97d9aefef725348ad77d6db98b726cfdb075a40b936c7984088804dfd38268a7"}, - {file = "ruff-0.8.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:ab78e33325a6f5374e04c2ab924a3367d69a0da36f8c9cb6b894a62017506111"}, - {file = "ruff-0.8.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:8ef06f66f4a05c3ddbc9121a8b0cecccd92c5bf3dd43b5472ffe40b8ca10f0f8"}, - {file = "ruff-0.8.4-py3-none-win32.whl", hash = "sha256:552fb6d861320958ca5e15f28b20a3d071aa83b93caee33a87b471f99a6c0835"}, - {file = "ruff-0.8.4-py3-none-win_amd64.whl", hash = "sha256:f21a1143776f8656d7f364bd264a9d60f01b7f52243fbe90e7670c0dfe0cf65d"}, - {file = "ruff-0.8.4-py3-none-win_arm64.whl", hash = "sha256:9183dd615d8df50defa8b1d9a074053891ba39025cf5ae88e8bcb52edcc4bf08"}, - {file = "ruff-0.8.4.tar.gz", hash = "sha256:0d5f89f254836799af1615798caa5f80b7f935d7a670fad66c5007928e57ace8"}, + {file = "ruff-0.8.6-py3-none-linux_armv6l.whl", hash = 
"sha256:defed167955d42c68b407e8f2e6f56ba52520e790aba4ca707a9c88619e580e3"}, + {file = "ruff-0.8.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:54799ca3d67ae5e0b7a7ac234baa657a9c1784b48ec954a094da7c206e0365b1"}, + {file = "ruff-0.8.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e88b8f6d901477c41559ba540beeb5a671e14cd29ebd5683903572f4b40a9807"}, + {file = "ruff-0.8.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0509e8da430228236a18a677fcdb0c1f102dd26d5520f71f79b094963322ed25"}, + {file = "ruff-0.8.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:91a7ddb221779871cf226100e677b5ea38c2d54e9e2c8ed847450ebbdf99b32d"}, + {file = "ruff-0.8.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:248b1fb3f739d01d528cc50b35ee9c4812aa58cc5935998e776bf8ed5b251e75"}, + {file = "ruff-0.8.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:bc3c083c50390cf69e7e1b5a5a7303898966be973664ec0c4a4acea82c1d4315"}, + {file = "ruff-0.8.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52d587092ab8df308635762386f45f4638badb0866355b2b86760f6d3c076188"}, + {file = "ruff-0.8.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:61323159cf21bc3897674e5adb27cd9e7700bab6b84de40d7be28c3d46dc67cf"}, + {file = "ruff-0.8.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ae4478b1471fc0c44ed52a6fb787e641a2ac58b1c1f91763bafbc2faddc5117"}, + {file = "ruff-0.8.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0c000a471d519b3e6cfc9c6680025d923b4ca140ce3e4612d1a2ef58e11f11fe"}, + {file = "ruff-0.8.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9257aa841e9e8d9b727423086f0fa9a86b6b420fbf4bf9e1465d1250ce8e4d8d"}, + {file = "ruff-0.8.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45a56f61b24682f6f6709636949ae8cc82ae229d8d773b4c76c09ec83964a95a"}, + {file = "ruff-0.8.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:496dd38a53aa173481a7d8866bcd6451bd934d06976a2505028a50583e001b76"}, + {file = "ruff-0.8.6-py3-none-win32.whl", hash = "sha256:e169ea1b9eae61c99b257dc83b9ee6c76f89042752cb2d83486a7d6e48e8f764"}, + {file = "ruff-0.8.6-py3-none-win_amd64.whl", hash = "sha256:f1d70bef3d16fdc897ee290d7d20da3cbe4e26349f62e8a0274e7a3f4ce7a905"}, + {file = "ruff-0.8.6-py3-none-win_arm64.whl", hash = "sha256:7d7fc2377a04b6e04ffe588caad613d0c460eb2ecba4c0ccbbfe2bc973cbc162"}, + {file = "ruff-0.8.6.tar.gz", hash = "sha256:dcad24b81b62650b0eb8814f576fc65cfee8674772a6e24c9b747911801eeaa5"}, ] [package.source] @@ -10590,53 +10918,60 @@ reference = "aliyun" [[package]] name = "scipy" -version = "1.14.1" +version = "1.15.0" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" files = [ - {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, - {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, - {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, - {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, - {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, - {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, - {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, - {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, - {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, - {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, - {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, - {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, - {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, - {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, - {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, - {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, - {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, - {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, - {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, - {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, - {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, - {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, - {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, - {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, - {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, - {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, - {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, - {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, - {file = 
"scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, - {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, - {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, - {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, - {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, -] - -[package.dependencies] -numpy = ">=1.23.5,<2.3" + {file = "scipy-1.15.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:aeac60d3562a7bf2f35549bdfdb6b1751c50590f55ce7322b4b2fc821dc27fca"}, + {file = "scipy-1.15.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5abbdc6ede5c5fed7910cf406a948e2c0869231c0db091593a6b2fa78be77e5d"}, + {file = "scipy-1.15.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:eb1533c59f0ec6c55871206f15a5c72d1fae7ad3c0a8ca33ca88f7c309bbbf8c"}, + {file = "scipy-1.15.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:de112c2dae53107cfeaf65101419662ac0a54e9a088c17958b51c95dac5de56d"}, + {file = "scipy-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2240e1fd0782e62e1aacdc7234212ee271d810f67e9cd3b8d521003a82603ef8"}, + {file = "scipy-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d35aef233b098e4de88b1eac29f0df378278e7e250a915766786b773309137c4"}, + {file = "scipy-1.15.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1b29e4fc02e155a5fd1165f1e6a73edfdd110470736b0f48bcbe48083f0eee37"}, + {file = "scipy-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:0e5b34f8894f9904cc578008d1a9467829c1817e9f9cb45e6d6eeb61d2ab7731"}, + {file = "scipy-1.15.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:46e91b5b16909ff79224b56e19cbad65ca500b3afda69225820aa3afbf9ec020"}, + {file = "scipy-1.15.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:82bff2eb01ccf7cea8b6ee5274c2dbeadfdac97919da308ee6d8e5bcbe846443"}, + {file = "scipy-1.15.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:9c8254fe21dd2c6c8f7757035ec0c31daecf3bb3cffd93bc1ca661b731d28136"}, + {file = "scipy-1.15.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:c9624eeae79b18cab1a31944b5ef87aa14b125d6ab69b71db22f0dbd962caf1e"}, + {file = "scipy-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d13bbc0658c11f3d19df4138336e4bce2c4fbd78c2755be4bf7b8e235481557f"}, + {file = "scipy-1.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdca4c7bb8dc41307e5f39e9e5d19c707d8e20a29845e7533b3bb20a9d4ccba0"}, + {file = "scipy-1.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6f376d7c767731477bac25a85d0118efdc94a572c6b60decb1ee48bf2391a73b"}, + {file = "scipy-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:61513b989ee8d5218fbeb178b2d51534ecaddba050db949ae99eeb3d12f6825d"}, + {file = "scipy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5beb0a2200372b7416ec73fdae94fe81a6e85e44eb49c35a11ac356d2b8eccc6"}, + {file = "scipy-1.15.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fde0f3104dfa1dfbc1f230f65506532d0558d43188789eaf68f97e106249a913"}, + {file = "scipy-1.15.0-cp312-cp312-macosx_14_0_arm64.whl", hash = 
"sha256:35c68f7044b4e7ad73a3e68e513dda946989e523df9b062bd3cf401a1a882192"}, + {file = "scipy-1.15.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:52475011be29dfcbecc3dfe3060e471ac5155d72e9233e8d5616b84e2b542054"}, + {file = "scipy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5972e3f96f7dda4fd3bb85906a17338e65eaddfe47f750e240f22b331c08858e"}, + {file = "scipy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe00169cf875bed0b3c40e4da45b57037dc21d7c7bf0c85ed75f210c281488f1"}, + {file = "scipy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:161f80a98047c219c257bf5ce1777c574bde36b9d962a46b20d0d7e531f86863"}, + {file = "scipy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:327163ad73e54541a675240708244644294cb0a65cca420c9c79baeb9648e479"}, + {file = "scipy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0fcb16eb04d84670722ce8d93b05257df471704c913cb0ff9dc5a1c31d1e9422"}, + {file = "scipy-1.15.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:767e8cf6562931f8312f4faa7ddea412cb783d8df49e62c44d00d89f41f9bbe8"}, + {file = "scipy-1.15.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:37ce9394cdcd7c5f437583fc6ef91bd290014993900643fdfc7af9b052d1613b"}, + {file = "scipy-1.15.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:6d26f17c64abd6c6c2dfb39920f61518cc9e213d034b45b2380e32ba78fde4c0"}, + {file = "scipy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e2448acd79c6374583581a1ded32ac71a00c2b9c62dfa87a40e1dd2520be111"}, + {file = "scipy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36be480e512d38db67f377add5b759fb117edd987f4791cdf58e59b26962bee4"}, + {file = "scipy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ccb6248a9987193fe74363a2d73b93bc2c546e0728bd786050b7aef6e17db03c"}, + {file = "scipy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:952d2e9eaa787f0a9e95b6e85da3654791b57a156c3e6609e65cc5176ccfe6f2"}, + {file = "scipy-1.15.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:b1432102254b6dc7766d081fa92df87832ac25ff0b3d3a940f37276e63eb74ff"}, + {file = "scipy-1.15.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:4e08c6a36f46abaedf765dd2dfcd3698fa4bd7e311a9abb2d80e33d9b2d72c34"}, + {file = "scipy-1.15.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:ec915cd26d76f6fc7ae8522f74f5b2accf39546f341c771bb2297f3871934a52"}, + {file = "scipy-1.15.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:351899dd2a801edd3691622172bc8ea01064b1cada794f8641b89a7dc5418db6"}, + {file = "scipy-1.15.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9baff912ea4f78a543d183ed6f5b3bea9784509b948227daaf6f10727a0e2e5"}, + {file = "scipy-1.15.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cd9d9198a7fd9a77f0eb5105ea9734df26f41faeb2a88a0e62e5245506f7b6df"}, + {file = "scipy-1.15.0-cp313-cp313t-win_amd64.whl", hash = "sha256:129f899ed275c0515d553b8d31696924e2ca87d1972421e46c376b9eb87de3d2"}, + {file = "scipy-1.15.0.tar.gz", hash = "sha256:300742e2cc94e36a2880ebe464a1c8b4352a7b0f3e36ec3d2ac006cdbe0219ac"}, +] + +[package.dependencies] +numpy = ">=1.23.5,<2.5" [package.extras] dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] -doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", 
"pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] -test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.16.5)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.0.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [package.source] type = "legacy" @@ -10700,23 +11035,23 @@ reference = "aliyun" [[package]] name = "setuptools" -version = "75.6.0" +version = "75.7.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" files = [ - {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"}, - {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"}, + {file = "setuptools-75.7.0-py3-none-any.whl", hash = "sha256:84fb203f278ebcf5cd08f97d3fb96d3fbed4b629d500b29ad60d11e00769b183"}, + {file = "setuptools-75.7.0.tar.gz", hash = "sha256:886ff7b16cd342f1d1defc16fc98c9ce3fde69e087a4e1983d7ab634e5f41f4f"}, ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] [package.source] type = "legacy" @@ 
-11237,13 +11572,13 @@ reference = "aliyun" [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1289" +version = "3.0.1298" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1289.tar.gz", hash = "sha256:3761b577ee295b02f28672468f1984c6c08e39bd86c3757ed8f270b180d985c6"}, - {file = "tencentcloud_sdk_python_common-3.0.1289-py2.py3-none-any.whl", hash = "sha256:3a6cde93c2bace9a783250f0fb17313d857c4d861e706eed9218dee67366b2f9"}, + {file = "tencentcloud-sdk-python-common-3.0.1298.tar.gz", hash = "sha256:0f0f182410c1ceda5764ff8bcbef27aa6139caf1c5f5985d94ec731a41c8a59f"}, + {file = "tencentcloud_sdk_python_common-3.0.1298-py2.py3-none-any.whl", hash = "sha256:c80929a0ff57ebee4ceec749dc82d5f2d1105b888e55175a7e9c722afc3a5d7a"}, ] [package.dependencies] @@ -11256,17 +11591,17 @@ reference = "aliyun" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1289" +version = "3.0.1298" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1289.tar.gz", hash = "sha256:40d5a611bb4a858dce59593246dec3b0acf03ac76fb8bd0f7d009a33bf55f638"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1289-py2.py3-none-any.whl", hash = "sha256:fdb1ff887e4a8429bb2df98dea1ad4af06ca2ff723eb47982ba8e70731b72356"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1298.tar.gz", hash = "sha256:c3d86a577de02046d25682a3804955453555fa641082bb8765238460bded3f03"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1298-py2.py3-none-any.whl", hash = "sha256:f01e33318b6a4152ac88c500fda77f2cda1864eeca000cdd29c41e4f92f8de65"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1289" +tencentcloud-sdk-python-common = "3.0.1298" [package.source] type = "legacy" @@ -11736,6 +12071,22 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "types-pytz" +version = "2024.2.0.20241221" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types_pytz-2024.2.0.20241221-py3-none-any.whl", hash = "sha256:8fc03195329c43637ed4f593663df721fef919b60a969066e22606edf0b53ad5"}, + {file = "types_pytz-2024.2.0.20241221.tar.gz", hash = "sha256:06d7cde9613e9f7504766a0554a270c369434b50e00975b3a4a0f6eed0f2c1a9"}, +] + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "types-requests" version = "2.32.0.20241016" @@ -11901,13 +12252,13 @@ reference = "aliyun" [[package]] name = "unstructured" -version = "0.16.11" +version = "0.16.12" description = "A library that prepares raw documents for downstream ML tasks." 
optional = false python-versions = ">=3.9.0,<3.13" files = [ - {file = "unstructured-0.16.11-py3-none-any.whl", hash = "sha256:a92d5bc2c2b7bb23369641fb7a7f0daba1775639199306ce4cd83ca564a03763"}, - {file = "unstructured-0.16.11.tar.gz", hash = "sha256:33ebf68aae11ce33c8a96335296557b5abd8ba96eaba3e5a1554c0b9eee40bb5"}, + {file = "unstructured-0.16.12-py3-none-any.whl", hash = "sha256:bcac29ac1b38fba4228c5a1a7721d1aa7c48220f7c1dd43b563645c56e978c49"}, + {file = "unstructured-0.16.12.tar.gz", hash = "sha256:c3133731c6edb9c2f474e62cb2b560cd0a8d578c4532ec14d8c0941e401770b0"}, ] [package.dependencies] @@ -11921,6 +12272,7 @@ html5lib = "*" langdetect = "*" lxml = "*" markdown = {version = "*", optional = true, markers = "extra == \"md\""} +ndjson = "*" nltk = "*" numpy = "<2" psutil = "*" @@ -12049,6 +12401,22 @@ type = "legacy" url = "https://mirrors.aliyun.com/pypi/simple" reference = "aliyun" +[[package]] +name = "uuid6" +version = "2024.7.10" +description = "New time-based UUID formats which are suited for use as a database key" +optional = false +python-versions = ">=3.8" +files = [ + {file = "uuid6-2024.7.10-py3-none-any.whl", hash = "sha256:93432c00ba403751f722829ad21759ff9db051dea140bf81493271e8e4dd18b7"}, + {file = "uuid6-2024.7.10.tar.gz", hash = "sha256:2d29d7f63f593caaeea0e0d0dd0ad8129c9c663b29e19bdf882e864bedf18fb0"}, +] + +[package.source] +type = "legacy" +url = "https://mirrors.aliyun.com/pypi/simple" +reference = "aliyun" + [[package]] name = "uvicorn" version = "0.34.0" @@ -12453,83 +12821,97 @@ reference = "aliyun" [[package]] name = "websockets" -version = "12.0" +version = "13.1" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" optional = false python-versions = ">=3.8" files = [ - {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"}, - {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"}, - {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"}, - {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"}, - {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"}, - {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"}, - {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"}, - {file = 
"websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"}, - {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"}, - {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"}, - {file = "websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"}, - {file = "websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"}, - {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"}, - {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"}, - {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"}, - {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"}, - {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"}, - {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"}, - {file = "websockets-12.0-cp312-cp312-win32.whl", hash = "sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"}, - {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash 
= "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"}, - {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"}, - {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"}, - {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"}, - {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"}, - {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"}, - {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"}, - {file = "websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"}, - {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"}, - {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"}, - {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"}, - {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"}, - {file = "websockets-12.0-cp39-cp39-win32.whl", hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"}, - {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"}, - {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"}, - {file = 
"websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"}, - {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"}, - {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"}, - {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"}, - {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"}, - {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"}, - {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"}, - {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"}, - {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"}, - {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"}, - {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"}, + {file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"}, + {file = "websockets-13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7e72ce6bda6fb9409cc1e8164dd41d7c91466fb599eb047cfda72fe758a34a7"}, + {file = "websockets-13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f779498eeec470295a2b1a5d97aa1bc9814ecd25e1eb637bd9d1c73a327387f6"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4676df3fe46956fbb0437d8800cd5f2b6d41143b6e7e842e60554398432cf29b"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:a7affedeb43a70351bb811dadf49493c9cfd1ed94c9c70095fd177e9cc1541fa"}, + {file = "websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1971e62d2caa443e57588e1d82d15f663b29ff9dfe7446d9964a4b6f12c1e700"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5f2e75431f8dc4a47f31565a6e1355fb4f2ecaa99d6b89737527ea917066e26c"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:58cf7e75dbf7e566088b07e36ea2e3e2bd5676e22216e4cad108d4df4a7402a0"}, + {file = "websockets-13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c90d6dec6be2c7d03378a574de87af9b1efea77d0c52a8301dd831ece938452f"}, + {file = "websockets-13.1-cp310-cp310-win32.whl", hash = "sha256:730f42125ccb14602f455155084f978bd9e8e57e89b569b4d7f0f0c17a448ffe"}, + {file = "websockets-13.1-cp310-cp310-win_amd64.whl", hash = "sha256:5993260f483d05a9737073be197371940c01b257cc45ae3f1d5d7adb371b266a"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61fc0dfcda609cda0fc9fe7977694c0c59cf9d749fbb17f4e9483929e3c48a19"}, + {file = "websockets-13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ceec59f59d092c5007e815def4ebb80c2de330e9588e101cf8bd94c143ec78a5"}, + {file = "websockets-13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1dca61c6db1166c48b95198c0b7d9c990b30c756fc2923cc66f68d17dc558fd"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:308e20f22c2c77f3f39caca508e765f8725020b84aa963474e18c59accbf4c02"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62d516c325e6540e8a57b94abefc3459d7dab8ce52ac75c96cad5549e187e3a7"}, + {file = "websockets-13.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c6e35319b46b99e168eb98472d6c7d8634ee37750d7693656dc766395df096"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5f9fee94ebafbc3117c30be1844ed01a3b177bb6e39088bc6b2fa1dc15572084"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7c1e90228c2f5cdde263253fa5db63e6653f1c00e7ec64108065a0b9713fa1b3"}, + {file = "websockets-13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6548f29b0e401eea2b967b2fdc1c7c7b5ebb3eeb470ed23a54cd45ef078a0db9"}, + {file = "websockets-13.1-cp311-cp311-win32.whl", hash = "sha256:c11d4d16e133f6df8916cc5b7e3e96ee4c44c936717d684a94f48f82edb7c92f"}, + {file = "websockets-13.1-cp311-cp311-win_amd64.whl", hash = "sha256:d04f13a1d75cb2b8382bdc16ae6fa58c97337253826dfe136195b7f89f661557"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:9d75baf00138f80b48f1eac72ad1535aac0b6461265a0bcad391fc5aba875cfc"}, + {file = "websockets-13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b6f347deb3dcfbfde1c20baa21c2ac0751afaa73e64e5b693bb2b848efeaa49"}, + {file = "websockets-13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:de58647e3f9c42f13f90ac7e5f58900c80a39019848c5547bc691693098ae1bd"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1b54689e38d1279a51d11e3467dd2f3a50f5f2e879012ce8f2d6943f00e83f0"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:cf1781ef73c073e6b0f90af841aaf98501f975d306bbf6221683dd594ccc52b6"}, + {file = "websockets-13.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d23b88b9388ed85c6faf0e74d8dec4f4d3baf3ecf20a65a47b836d56260d4b9"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3c78383585f47ccb0fcf186dcb8a43f5438bd7d8f47d69e0b56f71bf431a0a68"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d6d300f8ec35c24025ceb9b9019ae9040c1ab2f01cddc2bcc0b518af31c75c14"}, + {file = "websockets-13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9dcaf8b0cc72a392760bb8755922c03e17a5a54e08cca58e8b74f6902b433cf"}, + {file = "websockets-13.1-cp312-cp312-win32.whl", hash = "sha256:2f85cf4f2a1ba8f602298a853cec8526c2ca42a9a4b947ec236eaedb8f2dc80c"}, + {file = "websockets-13.1-cp312-cp312-win_amd64.whl", hash = "sha256:38377f8b0cdeee97c552d20cf1865695fcd56aba155ad1b4ca8779a5b6ef4ac3"}, + {file = "websockets-13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a9ab1e71d3d2e54a0aa646ab6d4eebfaa5f416fe78dfe4da2839525dc5d765c6"}, + {file = "websockets-13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b9d7439d7fab4dce00570bb906875734df13d9faa4b48e261c440a5fec6d9708"}, + {file = "websockets-13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327b74e915cf13c5931334c61e1a41040e365d380f812513a255aa804b183418"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:325b1ccdbf5e5725fdcb1b0e9ad4d2545056479d0eee392c291c1bf76206435a"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:346bee67a65f189e0e33f520f253d5147ab76ae42493804319b5716e46dddf0f"}, + {file = "websockets-13.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a0fa841646320ec0d3accdff5b757b06e2e5c86ba32af2e0815c96c7a603c5"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:18503d2c5f3943e93819238bf20df71982d193f73dcecd26c94514f417f6b135"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9cd1af7e18e5221d2878378fbc287a14cd527fdd5939ed56a18df8a31136bb2"}, + {file = "websockets-13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:70c5be9f416aa72aab7a2a76c90ae0a4fe2755c1816c153c1a2bcc3333ce4ce6"}, + {file = "websockets-13.1-cp313-cp313-win32.whl", hash = "sha256:624459daabeb310d3815b276c1adef475b3e6804abaf2d9d2c061c319f7f187d"}, + {file = "websockets-13.1-cp313-cp313-win_amd64.whl", hash = "sha256:c518e84bb59c2baae725accd355c8dc517b4a3ed8db88b4bc93c78dae2974bf2"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c7934fd0e920e70468e676fe7f1b7261c1efa0d6c037c6722278ca0228ad9d0d"}, + {file = "websockets-13.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:149e622dc48c10ccc3d2760e5f36753db9cacf3ad7bc7bbbfd7d9c819e286f23"}, + {file = "websockets-13.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a569eb1b05d72f9bce2ebd28a1ce2054311b66677fcd46cf36204ad23acead8c"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95df24ca1e1bd93bbca51d94dd049a984609687cb2fb08a7f2c56ac84e9816ea"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:d8dbb1bf0c0a4ae8b40bdc9be7f644e2f3fb4e8a9aca7145bfa510d4a374eeb7"}, + {file = "websockets-13.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:035233b7531fb92a76beefcbf479504db8c72eb3bff41da55aecce3a0f729e54"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e4450fc83a3df53dec45922b576e91e94f5578d06436871dce3a6be38e40f5db"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:463e1c6ec853202dd3657f156123d6b4dad0c546ea2e2e38be2b3f7c5b8e7295"}, + {file = "websockets-13.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6d6855bbe70119872c05107e38fbc7f96b1d8cb047d95c2c50869a46c65a8e96"}, + {file = "websockets-13.1-cp38-cp38-win32.whl", hash = "sha256:204e5107f43095012b00f1451374693267adbb832d29966a01ecc4ce1db26faf"}, + {file = "websockets-13.1-cp38-cp38-win_amd64.whl", hash = "sha256:485307243237328c022bc908b90e4457d0daa8b5cf4b3723fd3c4a8012fce4c6"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9b37c184f8b976f0c0a231a5f3d6efe10807d41ccbe4488df8c74174805eea7d"}, + {file = "websockets-13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:163e7277e1a0bd9fb3c8842a71661ad19c6aa7bb3d6678dc7f89b17fbcc4aeb7"}, + {file = "websockets-13.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4b889dbd1342820cc210ba44307cf75ae5f2f96226c0038094455a96e64fb07a"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:586a356928692c1fed0eca68b4d1c2cbbd1ca2acf2ac7e7ebd3b9052582deefa"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7bd6abf1e070a6b72bfeb71049d6ad286852e285f146682bf30d0296f5fbadfa"}, + {file = "websockets-13.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2aad13a200e5934f5a6767492fb07151e1de1d6079c003ab31e1823733ae79"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:df01aea34b6e9e33572c35cd16bae5a47785e7d5c8cb2b54b2acdb9678315a17"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:e54affdeb21026329fb0744ad187cf812f7d3c2aa702a5edb562b325191fcab6"}, + {file = "websockets-13.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ef8aa8bdbac47f4968a5d66462a2a0935d044bf35c0e5a8af152d58516dbeb5"}, + {file = "websockets-13.1-cp39-cp39-win32.whl", hash = "sha256:deeb929efe52bed518f6eb2ddc00cc496366a14c726005726ad62c2dd9017a3c"}, + {file = "websockets-13.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c65ffa900e7cc958cd088b9a9157a8141c991f8c53d11087e6fb7277a03f81d"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5dd6da9bec02735931fccec99d97c29f47cc61f644264eb995ad6c0c27667238"}, + {file = "websockets-13.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:2510c09d8e8df777177ee3d40cd35450dc169a81e747455cc4197e63f7e7bfe5"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1c3cf67185543730888b20682fb186fc8d0fa6f07ccc3ef4390831ab4b388d9"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bcc03c8b72267e97b49149e4863d57c2d77f13fae12066622dc78fe322490fe6"}, + {file = "websockets-13.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:004280a140f220c812e65f36944a9ca92d766b6cc4560be652a0a3883a79ed8a"}, + {file = "websockets-13.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e2620453c075abeb0daa949a292e19f56de518988e079c36478bacf9546ced23"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9156c45750b37337f7b0b00e6248991a047be4aa44554c9886fe6bdd605aab3b"}, + {file = "websockets-13.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:80c421e07973a89fbdd93e6f2003c17d20b69010458d3a8e37fb47874bd67d51"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82d0ba76371769d6a4e56f7e83bb8e81846d17a6190971e38b5de108bde9b0d7"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9875a0143f07d74dc5e1ded1c4581f0d9f7ab86c78994e2ed9e95050073c94d"}, + {file = "websockets-13.1-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a11e38ad8922c7961447f35c7b17bffa15de4d17c70abd07bfbe12d6faa3e027"}, + {file = "websockets-13.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4059f790b6ae8768471cddb65d3c4fe4792b0ab48e154c9f0a04cefaabcd5978"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:25c35bf84bf7c7369d247f0b8cfa157f989862c49104c5cf85cb5436a641d93e"}, + {file = "websockets-13.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:83f91d8a9bb404b8c2c41a707ac7f7f75b9442a0a876df295de27251a856ad09"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a43cfdcddd07f4ca2b1afb459824dd3c6d53a51410636a2c7fc97b9a8cf4842"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:48a2ef1381632a2f0cb4efeff34efa97901c9fbc118e01951ad7cfc10601a9bb"}, + {file = "websockets-13.1-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:459bf774c754c35dbb487360b12c5727adab887f1622b8aed5755880a21c4a20"}, + {file = "websockets-13.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:95858ca14a9f6fa8413d29e0a585b31b278388aa775b8a81fa24830123874678"}, + {file = "websockets-13.1-py3-none-any.whl", hash = "sha256:a9a396a6ad26130cdae92ae10c36af09d9bfe6cafe69670fd3b6da9b07b4044f"}, + {file = "websockets-13.1.tar.gz", hash = "sha256:a3b3366087c1bc0a2795111edcadddb8b3b59509d5db5d7ea3fdd69f954a8878"}, ] [package.source] @@ -12774,108 +13156,99 @@ reference = "aliyun" [[package]] name = "yarl" -version = "1.9.11" +version = "1.18.3" description = "Yet another URL library" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "yarl-1.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:79e08c691deae6fcac2fdde2e0515ac561dd3630d7c8adf7b1e786e22f1e193b"}, - {file = "yarl-1.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:752f4b5cf93268dc73c2ae994cc6d684b0dad5118bc87fbd965fd5d6dca20f45"}, - {file = "yarl-1.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:441049d3a449fb8756b0535be72c6a1a532938a33e1cf03523076700a5f87a01"}, - {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3dfe17b4aed832c627319da22a33f27f282bd32633d6b145c726d519c89fbaf"}, - {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:67abcb7df27952864440c9c85f1c549a4ad94afe44e2655f77d74b0d25895454"}, - {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6de3fa29e76fd1518a80e6af4902c44f3b1b4d7fed28eb06913bba4727443de3"}, - {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fee45b3bd4d8d5786472e056aa1359cc4dc9da68aded95a10cd7929a0ec661fe"}, - {file = "yarl-1.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c59b23886234abeba62087fd97d10fb6b905d9e36e2f3465d1886ce5c0ca30df"}, - {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d93c612b2024ac25a3dc01341fd98fdd19c8c5e2011f3dcd084b3743cba8d756"}, - {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4d368e3b9ecd50fa22017a20c49e356471af6ae91c4d788c6e9297e25ddf5a62"}, - {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5b593acd45cdd4cf6664d342ceacedf25cd95263b83b964fddd6c78930ea5211"}, - {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:224f8186c220ff00079e64bf193909829144d4e5174bb58665ef0da8bf6955c4"}, - {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:91c478741d7563a12162f7a2db96c0d23d93b0521563f1f1f0ece46ea1702d33"}, - {file = "yarl-1.9.11-cp310-cp310-win32.whl", hash = "sha256:1cdb8f5bb0534986776a43df84031da7ff04ac0cf87cb22ae8a6368231949c40"}, - {file = "yarl-1.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:498439af143b43a2b2314451ffd0295410aa0dcbdac5ee18fc8633da4670b605"}, - {file = "yarl-1.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9e290de5db4fd4859b4ed57cddfe793fcb218504e65781854a8ac283ab8d5518"}, - {file = "yarl-1.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e5f50a2e26cc2b89186f04c97e0ec0ba107ae41f1262ad16832d46849864f914"}, - {file = "yarl-1.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b4a0e724a28d7447e4d549c8f40779f90e20147e94bf949d490402eee09845c6"}, - {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85333d38a4fa5997fa2ff6fd169be66626d814b34fa35ec669e8c914ca50a097"}, - {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ff184002ee72e4b247240e35d5dce4c2d9a0e81fdbef715dde79ab4718aa541"}, - {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:675004040f847c0284827f44a1fa92d8baf425632cc93e7e0aa38408774b07c1"}, - {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b30703a7ade2b53f02e09a30685b70cd54f65ed314a8d9af08670c9a5391af1b"}, - {file = "yarl-1.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7230007ab67d43cf19200ec15bc6b654e6b85c402f545a6fc565d254d34ff754"}, - {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8c2cf0c7ad745e1c6530fe6521dfb19ca43338239dfcc7da165d0ef2332c0882"}, - {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4567cc08f479ad80fb07ed0c9e1bcb363a4f6e3483a490a39d57d1419bf1c4c7"}, - {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:95adc179a02949c4560ef40f8f650a008380766eb253d74232eb9c024747c111"}, - {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:755ae9cff06c429632d750aa8206f08df2e3d422ca67be79567aadbe74ae64cc"}, - {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:94f71d54c5faf715e92c8434b4a0b968c4d1043469954d228fc031d51086f143"}, - {file = "yarl-1.9.11-cp311-cp311-win32.whl", hash = "sha256:4ae079573efeaa54e5978ce86b77f4175cd32f42afcaf9bfb8a0677e91f84e4e"}, - {file = "yarl-1.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:9fae7ec5c9a4fe22abb995804e6ce87067dfaf7e940272b79328ce37c8f22097"}, - {file = "yarl-1.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:614fa50fd0db41b79f426939a413d216cdc7bab8d8c8a25844798d286a999c5a"}, - {file = "yarl-1.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ff64f575d71eacb5a4d6f0696bfe991993d979423ea2241f23ab19ff63f0f9d1"}, - {file = "yarl-1.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c23f6dc3d7126b4c64b80aa186ac2bb65ab104a8372c4454e462fb074197bc6"}, - {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8f847cc092c2b85d22e527f91ea83a6cf51533e727e2461557a47a859f96734"}, - {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63a5dc2866791236779d99d7a422611d22bb3a3d50935bafa4e017ea13e51469"}, - {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c335342d482e66254ae94b1231b1532790afb754f89e2e0c646f7f19d09740aa"}, - {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4a8c3dedd081cca134a21179aebe58b6e426e8d1e0202da9d1cafa56e01af3c"}, - {file = "yarl-1.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:504d19320c92532cabc3495fb7ed6bb599f3c2bfb45fed432049bf4693dbd6d0"}, - {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b2a8e5eb18181060197e3d5db7e78f818432725c0759bc1e5a9d603d9246389"}, - {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f568d70b7187f4002b6b500c0996c37674a25ce44b20716faebe5fdb8bd356e7"}, - {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:735b285ea46ca7e86ad261a462a071d0968aade44e1a3ea2b7d4f3d63b5aab12"}, - {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:2d1c81c3b92bef0c1c180048e43a5a85754a61b4f69d6f84df8e4bd615bef25d"}, - {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8d6e1c1562b53bd26efd38e886fc13863b8d904d559426777990171020c478a9"}, - {file = "yarl-1.9.11-cp312-cp312-win32.whl", hash = "sha256:aeba4aaa59cb709edb824fa88a27cbbff4e0095aaf77212b652989276c493c00"}, - {file = "yarl-1.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:569309a3efb8369ff5d32edb2a0520ebaf810c3059f11d34477418c90aa878fd"}, - {file = "yarl-1.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4915818ac850c3b0413e953af34398775b7a337babe1e4d15f68c8f5c4872553"}, - {file = "yarl-1.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef9610b2f5a73707d4d8bac040f0115ca848e510e3b1f45ca53e97f609b54130"}, - {file = "yarl-1.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:47c0a3dc8076a8dd159de10628dea04215bc7ddaa46c5775bf96066a0a18f82b"}, - {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:545f2fbfa0c723b446e9298b5beba0999ff82ce2c126110759e8dac29b5deaf4"}, - {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9137975a4ccc163ad5d7a75aad966e6e4e95dedee08d7995eab896a639a0bce2"}, - {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:0b0c70c451d2a86f8408abced5b7498423e2487543acf6fcf618b03f6e669b0a"}, - {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce2bd986b1e44528677c237b74d59f215c8bfcdf2d69442aa10f62fd6ab2951c"}, - {file = "yarl-1.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d7b717f77846a9631046899c6cc730ea469c0e2fb252ccff1cc119950dbc296"}, - {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3a26a24bbd19241283d601173cea1e5b93dec361a223394e18a1e8e5b0ef20bd"}, - {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c189bf01af155ac9882e128d9f3b3ad68a1f2c2f51404afad7201305df4e12b1"}, - {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0cbcc2c54084b2bda4109415631db017cf2960f74f9e8fd1698e1400e4f8aae2"}, - {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:30f201bc65941a4aa59c1236783efe89049ec5549dafc8cd2b63cc179d3767b0"}, - {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:922ba3b74f0958a0b5b9c14ff1ef12714a381760c08018f2b9827632783a590c"}, - {file = "yarl-1.9.11-cp313-cp313-win32.whl", hash = "sha256:17107b4b8c43e66befdcbe543fff2f9c93f7a3a9f8e3a9c9ac42bffeba0e8828"}, - {file = "yarl-1.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:0324506afab4f2e176a93cb08b8abcb8b009e1f324e6cbced999a8f5dd9ddb76"}, - {file = "yarl-1.9.11-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4e4f820fde9437bb47297194f43d29086433e6467fa28fe9876366ad357bd7bb"}, - {file = "yarl-1.9.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dfa9b9d5c9c0dbe69670f5695264452f5e40947590ec3a38cfddc9640ae8ff89"}, - {file = "yarl-1.9.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e700eb26635ce665c018c8cfea058baff9b843ed0cc77aa61849d807bb82a64c"}, - {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c305c1bdf10869b5e51facf50bd5b15892884aeae81962ae4ba061fc11217103"}, - {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5b7b307140231ea4f7aad5b69355aba2a67f2d7bc34271cffa3c9c324d35b27"}, - {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a744bdeda6c86cf3025c94eb0e01ccabe949cf385cd75b6576a3ac9669404b68"}, - {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e8ed183c7a8f75e40068333fc185566472a8f6c77a750cf7541e11810576ea5"}, - {file = "yarl-1.9.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1db9a4384694b5d20bdd9cb53f033b0831ac816416ab176c8d0997835015d22"}, - {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:70194da6e99713250aa3f335a7fa246b36adf53672a2bcd0ddaa375d04e53dc0"}, - {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ddad5cfcda729e22422bb1c85520bdf2770ce6d975600573ac9017fe882f4b7e"}, - {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ca35996e0a4bed28fa0640d9512d37952f6b50dea583bcc167d4f0b1e112ac7f"}, - {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:61ec0e80970b21a8f3c4b97fa6c6d181c6c6a135dbc7b4a601a78add3feeb209"}, - {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9636e4519f6c7558fdccf8f91e6e3b98df2340dc505c4cc3286986d33f2096c2"}, - {file = "yarl-1.9.11-cp38-cp38-win32.whl", hash = "sha256:58081cea14b8feda57c7ce447520e9d0a96c4d010cce54373d789c13242d7083"}, 
- {file = "yarl-1.9.11-cp38-cp38-win_amd64.whl", hash = "sha256:7d2dee7d6485807c0f64dd5eab9262b7c0b34f760e502243dd83ec09d647d5e1"}, - {file = "yarl-1.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d65ad67f981e93ea11f87815f67d086c4f33da4800cf2106d650dd8a0b79dda4"}, - {file = "yarl-1.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:752c0d33b4aacdb147871d0754b88f53922c6dc2aff033096516b3d5f0c02a0f"}, - {file = "yarl-1.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:54cc24be98d7f4ff355ca2e725a577e19909788c0db6beead67a0dda70bd3f82"}, - {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c82126817492bb2ebc946e74af1ffa10aacaca81bee360858477f96124be39a"}, - {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8503989860d7ac10c85cb5b607fec003a45049cf7a5b4b72451e87893c6bb990"}, - {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:475e09a67f8b09720192a170ad9021b7abf7827ffd4f3a83826317a705be06b7"}, - {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afcac5bda602b74ff701e1f683feccd8cce0d5a21dbc68db81bf9bd8fd93ba56"}, - {file = "yarl-1.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaeffcb84faceb2923a94a8a9aaa972745d3c728ab54dd011530cc30a3d5d0c1"}, - {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:51a6f770ac86477cd5c553f88a77a06fe1f6f3b643b053fcc7902ab55d6cbe14"}, - {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3fcd056cb7dff3aea5b1ee1b425b0fbaa2fbf6a1c6003e88caf524f01de5f395"}, - {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:21e56c30e39a1833e4e3fd0112dde98c2abcbc4c39b077e6105c76bb63d2aa04"}, - {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:0a205ec6349879f5e75dddfb63e069a24f726df5330b92ce76c4752a436aac01"}, - {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a5706821e1cf3c70dfea223e4e0958ea354f4e2af9420a1bd45c6b547297fb97"}, - {file = "yarl-1.9.11-cp39-cp39-win32.whl", hash = "sha256:cc295969f8c2172b5d013c0871dccfec7a0e1186cf961e7ea575d47b4d5cbd32"}, - {file = "yarl-1.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:55a67dd29367ce7c08a0541bb602ec0a2c10d46c86b94830a1a665f7fd093dfa"}, - {file = "yarl-1.9.11-py3-none-any.whl", hash = "sha256:c6f6c87665a9e18a635f0545ea541d9640617832af2317d4f5ad389686b4ed3d"}, - {file = "yarl-1.9.11.tar.gz", hash = "sha256:c7548a90cb72b67652e2cd6ae80e2683ee08fde663104528ac7df12d8ef271d2"}, + {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7df647e8edd71f000a5208fe6ff8c382a1de8edfbccdbbfe649d263de07d8c34"}, + {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c69697d3adff5aa4f874b19c0e4ed65180ceed6318ec856ebc423aa5850d84f7"}, + {file = "yarl-1.18.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:602d98f2c2d929f8e697ed274fbadc09902c4025c5a9963bf4e9edfc3ab6f7ed"}, + {file = "yarl-1.18.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c654d5207c78e0bd6d749f6dae1dcbbfde3403ad3a4b11f3c5544d9906969dde"}, + {file = "yarl-1.18.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5094d9206c64181d0f6e76ebd8fb2f8fe274950a63890ee9e0ebfd58bf9d787b"}, + {file = "yarl-1.18.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:35098b24e0327fc4ebdc8ffe336cee0a87a700c24ffed13161af80124b7dc8e5"}, + {file = "yarl-1.18.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3236da9272872443f81fedc389bace88408f64f89f75d1bdb2256069a8730ccc"}, + {file = "yarl-1.18.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2c08cc9b16f4f4bc522771d96734c7901e7ebef70c6c5c35dd0f10845270bcd"}, + {file = "yarl-1.18.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:80316a8bd5109320d38eef8833ccf5f89608c9107d02d2a7f985f98ed6876990"}, + {file = "yarl-1.18.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:c1e1cc06da1491e6734f0ea1e6294ce00792193c463350626571c287c9a704db"}, + {file = "yarl-1.18.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fea09ca13323376a2fdfb353a5fa2e59f90cd18d7ca4eaa1fd31f0a8b4f91e62"}, + {file = "yarl-1.18.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e3b9fd71836999aad54084906f8663dffcd2a7fb5cdafd6c37713b2e72be1760"}, + {file = "yarl-1.18.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:757e81cae69244257d125ff31663249b3013b5dc0a8520d73694aed497fb195b"}, + {file = "yarl-1.18.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b1771de9944d875f1b98a745bc547e684b863abf8f8287da8466cf470ef52690"}, + {file = "yarl-1.18.3-cp310-cp310-win32.whl", hash = "sha256:8874027a53e3aea659a6d62751800cf6e63314c160fd607489ba5c2edd753cf6"}, + {file = "yarl-1.18.3-cp310-cp310-win_amd64.whl", hash = "sha256:93b2e109287f93db79210f86deb6b9bbb81ac32fc97236b16f7433db7fc437d8"}, + {file = "yarl-1.18.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8503ad47387b8ebd39cbbbdf0bf113e17330ffd339ba1144074da24c545f0069"}, + {file = "yarl-1.18.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02ddb6756f8f4517a2d5e99d8b2f272488e18dd0bfbc802f31c16c6c20f22193"}, + {file = "yarl-1.18.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:67a283dd2882ac98cc6318384f565bffc751ab564605959df4752d42483ad889"}, + {file = "yarl-1.18.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d980e0325b6eddc81331d3f4551e2a333999fb176fd153e075c6d1c2530aa8a8"}, + {file = "yarl-1.18.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b643562c12680b01e17239be267bc306bbc6aac1f34f6444d1bded0c5ce438ca"}, + {file = "yarl-1.18.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c017a3b6df3a1bd45b9fa49a0f54005e53fbcad16633870104b66fa1a30a29d8"}, + {file = "yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75674776d96d7b851b6498f17824ba17849d790a44d282929c42dbb77d4f17ae"}, + {file = "yarl-1.18.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ccaa3a4b521b780a7e771cc336a2dba389a0861592bbce09a476190bb0c8b4b3"}, + {file = "yarl-1.18.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d06d3005e668744e11ed80812e61efd77d70bb7f03e33c1598c301eea20efbb"}, + {file = "yarl-1.18.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:9d41beda9dc97ca9ab0b9888cb71f7539124bc05df02c0cff6e5acc5a19dcc6e"}, + {file = "yarl-1.18.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ba23302c0c61a9999784e73809427c9dbedd79f66a13d84ad1b1943802eaaf59"}, + {file = "yarl-1.18.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6748dbf9bfa5ba1afcc7556b71cda0d7ce5f24768043a02a58846e4a443d808d"}, + {file = "yarl-1.18.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = 
"sha256:0b0cad37311123211dc91eadcb322ef4d4a66008d3e1bdc404808992260e1a0e"}, + {file = "yarl-1.18.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0fb2171a4486bb075316ee754c6d8382ea6eb8b399d4ec62fde2b591f879778a"}, + {file = "yarl-1.18.3-cp311-cp311-win32.whl", hash = "sha256:61b1a825a13bef4a5f10b1885245377d3cd0bf87cba068e1d9a88c2ae36880e1"}, + {file = "yarl-1.18.3-cp311-cp311-win_amd64.whl", hash = "sha256:b9d60031cf568c627d028239693fd718025719c02c9f55df0a53e587aab951b5"}, + {file = "yarl-1.18.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1dd4bdd05407ced96fed3d7f25dbbf88d2ffb045a0db60dbc247f5b3c5c25d50"}, + {file = "yarl-1.18.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7c33dd1931a95e5d9a772d0ac5e44cac8957eaf58e3c8da8c1414de7dd27c576"}, + {file = "yarl-1.18.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b411eddcfd56a2f0cd6a384e9f4f7aa3efee14b188de13048c25b5e91f1640"}, + {file = "yarl-1.18.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:436c4fc0a4d66b2badc6c5fc5ef4e47bb10e4fd9bf0c79524ac719a01f3607c2"}, + {file = "yarl-1.18.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e35ef8683211db69ffe129a25d5634319a677570ab6b2eba4afa860f54eeaf75"}, + {file = "yarl-1.18.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84b2deecba4a3f1a398df819151eb72d29bfeb3b69abb145a00ddc8d30094512"}, + {file = "yarl-1.18.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00e5a1fea0fd4f5bfa7440a47eff01d9822a65b4488f7cff83155a0f31a2ecba"}, + {file = "yarl-1.18.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0e883008013c0e4aef84dcfe2a0b172c4d23c2669412cf5b3371003941f72bb"}, + {file = "yarl-1.18.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5a3f356548e34a70b0172d8890006c37be92995f62d95a07b4a42e90fba54272"}, + {file = "yarl-1.18.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ccd17349166b1bee6e529b4add61727d3f55edb7babbe4069b5764c9587a8cc6"}, + {file = "yarl-1.18.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b958ddd075ddba5b09bb0be8a6d9906d2ce933aee81100db289badbeb966f54e"}, + {file = "yarl-1.18.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c7d79f7d9aabd6011004e33b22bc13056a3e3fb54794d138af57f5ee9d9032cb"}, + {file = "yarl-1.18.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4891ed92157e5430874dad17b15eb1fda57627710756c27422200c52d8a4e393"}, + {file = "yarl-1.18.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ce1af883b94304f493698b00d0f006d56aea98aeb49d75ec7d98cd4a777e9285"}, + {file = "yarl-1.18.3-cp312-cp312-win32.whl", hash = "sha256:f91c4803173928a25e1a55b943c81f55b8872f0018be83e3ad4938adffb77dd2"}, + {file = "yarl-1.18.3-cp312-cp312-win_amd64.whl", hash = "sha256:7e2ee16578af3b52ac2f334c3b1f92262f47e02cc6193c598502bd46f5cd1477"}, + {file = "yarl-1.18.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:90adb47ad432332d4f0bc28f83a5963f426ce9a1a8809f5e584e704b82685dcb"}, + {file = "yarl-1.18.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:913829534200eb0f789d45349e55203a091f45c37a2674678744ae52fae23efa"}, + {file = "yarl-1.18.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ef9f7768395923c3039055c14334ba4d926f3baf7b776c923c93d80195624782"}, + {file = "yarl-1.18.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a19f62ff30117e706ebc9090b8ecc79aeb77d0b1f5ec10d2d27a12bc9f66d0"}, + {file = 
"yarl-1.18.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e17c9361d46a4d5addf777c6dd5eab0715a7684c2f11b88c67ac37edfba6c482"}, + {file = "yarl-1.18.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a74a13a4c857a84a845505fd2d68e54826a2cd01935a96efb1e9d86c728e186"}, + {file = "yarl-1.18.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41f7ce59d6ee7741af71d82020346af364949314ed3d87553763a2df1829cc58"}, + {file = "yarl-1.18.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f52a265001d830bc425f82ca9eabda94a64a4d753b07d623a9f2863fde532b53"}, + {file = "yarl-1.18.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:82123d0c954dc58db301f5021a01854a85bf1f3bb7d12ae0c01afc414a882ca2"}, + {file = "yarl-1.18.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:2ec9bbba33b2d00999af4631a3397d1fd78290c48e2a3e52d8dd72db3a067ac8"}, + {file = "yarl-1.18.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:fbd6748e8ab9b41171bb95c6142faf068f5ef1511935a0aa07025438dd9a9bc1"}, + {file = "yarl-1.18.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:877d209b6aebeb5b16c42cbb377f5f94d9e556626b1bfff66d7b0d115be88d0a"}, + {file = "yarl-1.18.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b464c4ab4bfcb41e3bfd3f1c26600d038376c2de3297760dfe064d2cb7ea8e10"}, + {file = "yarl-1.18.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8d39d351e7faf01483cc7ff7c0213c412e38e5a340238826be7e0e4da450fdc8"}, + {file = "yarl-1.18.3-cp313-cp313-win32.whl", hash = "sha256:61ee62ead9b68b9123ec24bc866cbef297dd266175d53296e2db5e7f797f902d"}, + {file = "yarl-1.18.3-cp313-cp313-win_amd64.whl", hash = "sha256:578e281c393af575879990861823ef19d66e2b1d0098414855dd367e234f5b3c"}, + {file = "yarl-1.18.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:61e5e68cb65ac8f547f6b5ef933f510134a6bf31bb178be428994b0cb46c2a04"}, + {file = "yarl-1.18.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:fe57328fbc1bfd0bd0514470ac692630f3901c0ee39052ae47acd1d90a436719"}, + {file = "yarl-1.18.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a440a2a624683108a1b454705ecd7afc1c3438a08e890a1513d468671d90a04e"}, + {file = "yarl-1.18.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c7907c8548bcd6ab860e5f513e727c53b4a714f459b084f6580b49fa1b9cee"}, + {file = "yarl-1.18.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4f6450109834af88cb4cc5ecddfc5380ebb9c228695afc11915a0bf82116789"}, + {file = "yarl-1.18.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9ca04806f3be0ac6d558fffc2fdf8fcef767e0489d2684a21912cc4ed0cd1b8"}, + {file = "yarl-1.18.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77a6e85b90a7641d2e07184df5557132a337f136250caafc9ccaa4a2a998ca2c"}, + {file = "yarl-1.18.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6333c5a377c8e2f5fae35e7b8f145c617b02c939d04110c76f29ee3676b5f9a5"}, + {file = "yarl-1.18.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0b3c92fa08759dbf12b3a59579a4096ba9af8dd344d9a813fc7f5070d86bbab1"}, + {file = "yarl-1.18.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:4ac515b860c36becb81bb84b667466885096b5fc85596948548b667da3bf9f24"}, + {file = "yarl-1.18.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:045b8482ce9483ada4f3f23b3774f4e1bf4f23a2d5c912ed5170f68efb053318"}, + {file = 
"yarl-1.18.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:a4bb030cf46a434ec0225bddbebd4b89e6471814ca851abb8696170adb163985"}, + {file = "yarl-1.18.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:54d6921f07555713b9300bee9c50fb46e57e2e639027089b1d795ecd9f7fa910"}, + {file = "yarl-1.18.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1d407181cfa6e70077df3377938c08012d18893f9f20e92f7d2f314a437c30b1"}, + {file = "yarl-1.18.3-cp39-cp39-win32.whl", hash = "sha256:ac36703a585e0929b032fbaab0707b75dc12703766d0b53486eabd5139ebadd5"}, + {file = "yarl-1.18.3-cp39-cp39-win_amd64.whl", hash = "sha256:ba87babd629f8af77f557b61e49e7c7cac36f22f871156b91e10a6e9d4f829e9"}, + {file = "yarl-1.18.3-py3-none-any.whl", hash = "sha256:b57f4f58099328dfb26c6a771d09fb20dbbae81d20cfb66141251ea063bd101b"}, + {file = "yarl-1.18.3.tar.gz", hash = "sha256:ac1801c45cbf77b6c99242eeff4fffb5e4e73a800b5c4ad4fc0be5def634d2e1"}, ] [package.dependencies] idna = ">=2.0" multidict = ">=4.0" +propcache = ">=0.2.0" [package.source] type = "legacy" @@ -12937,13 +13310,13 @@ reference = "aliyun" [[package]] name = "zhipuai" -version = "2.1.5.20241204" +version = "2.1.5.20250106" description = "A SDK library for accessing big model apis from ZhipuAI" optional = false python-versions = ">=3.8, !=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*" files = [ - {file = "zhipuai-2.1.5.20241204-py3-none-any.whl", hash = "sha256:063c7527d6741ced82eedb19d53fd24ce61cf43ab835ee3c0262843f59503a7c"}, - {file = "zhipuai-2.1.5.20241204.tar.gz", hash = "sha256:888b42a83c8f1daf07375b84e560219eedab96b9f9e59542f0329928291db635"}, + {file = "zhipuai-2.1.5.20250106-py3-none-any.whl", hash = "sha256:ca76095f32db501e36038fc1ac4b287b88ed90c4cdd28902d3b1a9365fff879b"}, + {file = "zhipuai-2.1.5.20250106.tar.gz", hash = "sha256:45d391be336a210b360f126443f07882fa6d8184a148c46a8c7d0b7607d6d1f8"}, ] [package.dependencies] @@ -13184,4 +13557,4 @@ reference = "aliyun" [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "48c5ffa65c830817f094a5e7f34b0ffe59cc289a3d01231f8eae26b9b50725e3" +content-hash = "f6df71353c050214fd6efd3ec3005b376fa433d229f4107235cfc03fba0050e3" diff --git a/api/pyproject.toml b/api/pyproject.toml index db5b55c67..79822d57e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -59,7 +59,10 @@ numpy = "~1.26.4" oci = "~2.135.1" openai = "~1.52.0" openpyxl = "~3.1.5" +opik = "~1.3.4" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } +pandas-stubs = "~2.2.3.241009" +psycogreen = "~1.0.2" psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" pydantic = "~2.9.2" @@ -69,7 +72,7 @@ pyjwt = "~2.8.0" pypdfium2 = "~4.30.0" python = ">=3.11,<3.13" python-docx = "~1.1.0" -python-dotenv = "1.0.0" +python-dotenv = "1.0.1" pyyaml = "~6.0.1" readabilipy = "0.2.0" redis = { version = "~5.0.3", extras = ["hiredis"] } @@ -80,16 +83,17 @@ scikit-learn = "~1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" starlette = "0.41.0" -tencentcloud-sdk-python-hunyuan = "~3.0.1158" +tencentcloud-sdk-python-hunyuan = "~3.0.1294" tiktoken = "~0.8.0" tokenizers = "~0.15.0" transformers = "~4.35.0" +types-pytz = "~2024.2.0.20241003" unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } validators = "0.21.0" volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} websocket-client = "~1.7.0" xinference-client = "0.15.2" -yarl = "~1.9.4" +yarl = "~1.18.3" youtube-transcript-api = "~0.6.2" zhipuai = "~2.1.5" 
##### start extend ###### @@ -158,7 +162,7 @@ opensearch-py = "2.4.0" oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" -pymilvus = "~2.4.4" +pymilvus = "~2.5.0" pymochow = "1.3.1" pyobvector = "~0.1.6" qdrant-client = "1.7.3" @@ -177,6 +181,7 @@ optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" faker = "~32.1.0" +mypy = "~1.13.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 97e5c77e9..5e4d3ec32 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -28,12 +28,12 @@ def clean_messages(): plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING ) - page = 1 while True: try: # Main query with join and filter + # FIXME:for mypy no paginate method error messages = ( - db.session.query(Message) + db.session.query(Message) # type: ignore .filter(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) .limit(100) @@ -78,4 +78,4 @@ def clean_messages(): db.session.query(Message).filter(Message.id == message.id).delete() db.session.commit() end_at = time.perf_counter() - click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) + click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index e12be649e..4e7e443c2 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -10,7 +10,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DatasetQuery, Document +from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document from services.feature_service import FeatureService @@ -52,8 +52,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_sandbox_clean_day, @@ -76,6 +75,23 @@ def clean_unused_datasets_task(): ) if not dataset_query or len(dataset_query) == 0: try: + # add auto disable log + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) @@ -120,8 +136,7 @@ def clean_unused_datasets_task(): # Main query with join and filter datasets = ( - db.session.query(Dataset) - .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) 
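Note on the `clean_unused_datasets_task` change above: before the vector index is cleaned, the task now records one `DatasetAutoDisableLog` row per still-enabled, non-archived document, which is what the new notification task later reads. A minimal sketch of that bookkeeping, assuming only the model columns used in this hunk (`tenant_id`, `dataset_id`, `document_id`) and the shared SQLAlchemy session:

```python
# Sketch of the auto-disable bookkeeping added to clean_unused_datasets_task;
# the caller commits these rows together with the rest of the cleanup.
from extensions.ext_database import db
from models.dataset import Dataset, DatasetAutoDisableLog, Document


def record_auto_disable_logs(dataset: Dataset) -> int:
    """Write one DatasetAutoDisableLog row per enabled, non-archived document."""
    documents = (
        db.session.query(Document)
        .filter(
            Document.dataset_id == dataset.id,
            Document.enabled == True,  # noqa: E712 - SQLAlchemy needs the comparison expression
            Document.archived == False,  # noqa: E712
        )
        .all()
    )
    for document in documents:
        db.session.add(
            DatasetAutoDisableLog(
                tenant_id=dataset.tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
            )
        )
    return len(documents)
```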
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_pro_clean_day, diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index a20b50030..1c985461c 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -36,14 +36,15 @@ def create_tidb_serverless_task(): def create_clusters(batch_size): try: + # TODO: maybe we can set the default value for the following parameters in the config file new_clusters = TidbService.batch_create_tidb_serverless_cluster( - batch_size, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, - dify_config.TIDB_REGION, + batch_size=batch_size, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", + region=dify_config.TIDB_REGION or "", ) for new_cluster in new_clusters: tidb_auth_binding = TidbAuthBinding( diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py new file mode 100644 index 000000000..fe6839288 --- /dev/null +++ b/api/schedule/mail_clean_document_notify_task.py @@ -0,0 +1,90 @@ +import logging +import time +from collections import defaultdict + +import click +from flask import render_template # type: ignore + +import app +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_mail import mail +from models.account import Account, Tenant, TenantAccountJoin +from models.dataset import Dataset, DatasetAutoDisableLog +from services.feature_service import FeatureService + + +@app.celery.task(queue="dataset") +def send_document_clean_notify_task(): + """ + Async Send document clean notify mail + + Usage: send_document_clean_notify_task.delay() + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start send document clean notify mail", fg="green")) + start_at = time.perf_counter() + + # send document clean notify mail + try: + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != "sandbox": + knowledge_details = [] + # check tenant + tenant = Tenant.query.filter(Tenant.id == tenant_id).first() + if not tenant: + continue + # check current owner + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + if not current_owner_join: + continue + account = Account.query.filter(Account.id == current_owner_join.account_id).first() + if not account: + continue + + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in 
tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = Dataset.query.filter(Dataset.id == dataset_id).first() + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + html_content = render_template( + "clean_document_job_mail_template-US.html", + userName=account.email, + knowledge_details=knowledge_details, + url=url, + ) + mail.send( + to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content + ) + + # update notified to True + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + dataset_auto_disable_log.notified = True + db.session.commit() + end_at = time.perf_counter() + logging.info( + click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send document clean notify mail failed") diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index b2d8746f9..11a39e60e 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -36,13 +36,14 @@ def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): # batch 20 for i in range(0, len(tidb_serverless_list), 20): items = tidb_serverless_list[i : i + 20] + # TODO: maybe we can set the default value for the following parameters in the config file TidbService.batch_update_tidb_serverless_cluster_status( - items, - dify_config.TIDB_PROJECT_ID, - dify_config.TIDB_API_URL, - dify_config.TIDB_IAM_API_URL, - dify_config.TIDB_PUBLIC_KEY, - dify_config.TIDB_PRIVATE_KEY, + tidb_serverless_list=items, + project_id=dify_config.TIDB_PROJECT_ID or "", + api_url=dify_config.TIDB_API_URL or "", + iam_url=dify_config.TIDB_IAM_API_URL or "", + public_key=dify_config.TIDB_PUBLIC_KEY or "", + private_key=dify_config.TIDB_PRIVATE_KEY or "", ) except Exception as e: click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/services/account_service.py b/api/services/account_service.py index 22b54a3ab..dd1cc5f94 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -6,7 +6,7 @@ import uuid from datetime import UTC, datetime, timedelta from hashlib import sha256 -from typing import Any, Optional +from typing import Any, Optional, cast from pydantic import BaseModel from sqlalchemy import func @@ -32,6 +32,7 @@ TenantStatus, ) from models.model import DifySetup +from services.billing_service import BillingService from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, @@ -50,6 +51,8 @@ ) from services.errors.workspace import WorkSpaceNotAllowedCreateError from services.feature_service import FeatureService +from tasks.delete_account_task import delete_account_task +from tasks.mail_account_deletion_task import send_account_deletion_verification_code from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -62,7 +65,7 @@ class TokenPair(BaseModel): 
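Side note on the new `send_document_clean_notify_task`: the pending logs are grouped per tenant with `defaultdict(list)`, so the explicit `if tenant_id not in ...` initialization in the task body is redundant. A small sketch of the grouping on its own, assuming only the `tenant_id` attribute shown above:

```python
from collections import defaultdict


def group_logs_by_tenant(dataset_auto_disable_logs):
    # defaultdict(list) creates the bucket on first access, so no explicit
    # "if tenant_id not in map" check is needed before appending.
    logs_by_tenant = defaultdict(list)
    for log in dataset_auto_disable_logs:
        logs_by_tenant[log.tenant_id].append(log)
    return logs_by_tenant
```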
REFRESH_TOKEN_PREFIX = "refresh_token:" ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" -REFRESH_TOKEN_EXPIRY = timedelta(days=30) +REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: @@ -70,6 +73,9 @@ class AccountService: email_code_login_rate_limiter = RateLimiter( prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 ) + email_code_account_deletion_rate_limiter = RateLimiter( + prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 + ) LOGIN_MAX_ERROR_LIMITS = 5 @staticmethod @@ -119,7 +125,7 @@ def load_user(user_id: str) -> None | Account: account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() - return account + return cast(Account, account) @staticmethod def get_account_jwt_token(account: Account) -> str: @@ -132,7 +138,7 @@ def get_account_jwt_token(account: Account) -> str: "sub": "Console API Passport", } - token = PassportService().issue(payload) + token: str = PassportService().issue(payload) return token @staticmethod @@ -164,7 +170,7 @@ def authenticate(email: str, password: str, invite_token: Optional[str] = None) db.session.commit() - return account + return cast(Account, account) @staticmethod def update_account_password(account, password, new_password): @@ -201,6 +207,15 @@ def create_account( from controllers.console.error import AccountNotFound raise AccountNotFound() + + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): + raise AccountRegisterError( + description=( + "This email account has been deleted within the past " + "30 days and is temporarily unavailable for new account registration" + ) + ) + account = Account() account.email = email account.name = name @@ -240,6 +255,42 @@ def create_account_and_tenant( return account + @staticmethod + def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]: + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, token_type="account_deletion", additional_data={"code": code} + ) + return token, code + + @classmethod + def send_account_deletion_verification_email(cls, account: Account, code: str): + email = account.email + if cls.email_code_account_deletion_rate_limiter.is_rate_limited(email): + from controllers.console.auth.error import EmailCodeAccountDeletionRateLimitExceededError + + raise EmailCodeAccountDeletionRateLimitExceededError() + + send_account_deletion_verification_code.delay(to=email, code=code) + + cls.email_code_account_deletion_rate_limiter.increment_rate_limit(email) + + @staticmethod + def verify_account_deletion_code(token: str, code: str) -> bool: + token_data = TokenManager.get_token_data(token, "account_deletion") + if token_data is None: + return False + + if token_data["code"] != code: + return False + + return True + + @staticmethod + def delete_account(account: Account) -> None: + """Delete account. 
This method only adds a task to the queue for deletion.""" + delete_account_task.delay(account.id) + @staticmethod def link_account_integrate(provider: str, open_id: str, account: Account) -> None: """Link account integrate""" @@ -347,6 +398,8 @@ def send_reset_password_email( language: Optional[str] = "en-US", ): account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") if cls.reset_password_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import PasswordResetRateLimitExceededError @@ -377,6 +430,9 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" ): + email = account.email if account else email + if email is None: + raise ValueError("Email must be provided.") if cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError @@ -404,6 +460,14 @@ def revoke_email_code_login_token(cls, token: str): @classmethod def get_user_through_email(cls, email: str): + if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): + raise AccountRegisterError( + description=( + "This email account has been deleted within the past " + "30 days and is temporarily unavailable for new account registration" + ) + ) + account = db.session.query(Account).filter(Account.email == email).first() if not account: return None @@ -669,7 +733,7 @@ def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoi @staticmethod def get_tenant_count() -> int: """Get tenant count""" - return db.session.query(func.count(Tenant.id)).scalar() + return cast(int, db.session.query(func.count(Tenant.id)).scalar()) @staticmethod def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: @@ -733,10 +797,10 @@ def dissolve_tenant(tenant: Tenant, operator: Account) -> None: db.session.commit() @staticmethod - def get_custom_config(tenant_id: str) -> None: - tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_404() + def get_custom_config(tenant_id: str) -> dict: + tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() - return tenant.custom_config_dict + return cast(dict, tenant.custom_config_dict) class RegisterService: @@ -793,6 +857,7 @@ def register( language: Optional[str] = None, status: Optional[AccountStatus] = None, is_setup: Optional[bool] = False, + create_workspace_required: Optional[bool] = True, ) -> Account: db.session.begin_nested() """Register account""" @@ -807,10 +872,10 @@ def register( account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(UTC).replace(tzinfo=None) - if open_id is not None or provider is not None: + if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if FeatureService.get_system_features().is_allow_create_workspace: + if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant @@ -819,6 +884,10 @@ def register( db.session.commit() except WorkSpaceNotAllowedCreateError: db.session.rollback() + except AccountRegisterError as are: + 
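The new `AccountService` helpers above split account deletion into a verification step and an asynchronous delete. The console controllers that call them are not part of this hunk, so the wiring below is purely illustrative; it uses only the methods added in the diff:

```python
# Hypothetical caller-side wiring of the new account-deletion helpers;
# function names here are illustrative, not part of the codebase.
from services.account_service import AccountService


def request_account_deletion(account):
    # Returns a 6-digit code plus a token that carries the code for later verification.
    token, code = AccountService.generate_account_deletion_verification_code(account)
    # Rate-limited to one mail per minute per address; delivery goes through Celery.
    AccountService.send_account_deletion_verification_email(account, code)
    return token


def confirm_account_deletion(account, token: str, code: str) -> bool:
    if not AccountService.verify_account_deletion_code(token, code):
        return False
    # Only enqueues delete_account_task; the actual deletion runs asynchronously.
    AccountService.delete_account(account)
    return True
```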
db.session.rollback() + logging.exception("Register failed") + raise are except Exception as e: db.session.rollback() logging.exception("Register failed") @@ -828,10 +897,11 @@ def register( @classmethod def invite_new_member( - cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Optional[Account] = None ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() + assert inviter is not None, "Inviter must be provided." if not account: TenantService.check_member_permission(tenant, inviter, None, "add") @@ -894,7 +964,9 @@ def revoke_token(cls, workspace_id: str, email: str, token: str): redis_client.delete(cls._get_invitation_token_key(token)) @classmethod - def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str) -> Optional[dict[str, Any]]: + def get_invitation_if_token_valid( + cls, workspace_id: Optional[str], email: str, token: str + ) -> Optional[dict[str, Any]]: invitation_data = cls._get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None @@ -953,7 +1025,7 @@ def _get_invitation_by_token( if not data: return None - invitation = json.loads(data) + invitation: dict = json.loads(data) return invitation diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d2cd7bea6..6dc1affa1 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -48,6 +48,8 @@ def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> return cls.get_chat_prompt( copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt ) + # default return empty dict + return {} @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: @@ -91,3 +93,5 @@ def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) - return cls.get_chat_prompt( copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt ) + # default return empty dict + return {} diff --git a/api/services/agent_service.py b/api/services/agent_service.py index c8819535f..b02f762ad 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,5 +1,7 @@ +from typing import Optional + import pytz -from flask_login import current_user +from flask_login import current_user # type: ignore from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.tools.tool_manager import ToolManager @@ -14,7 +16,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - """ Service to get agent logs """ - conversation: Conversation = ( + conversation: Optional[Conversation] = ( db.session.query(Conversation) .filter( Conversation.id == conversation_id, @@ -26,7 +28,7 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = ( + message: Optional[Message] = ( db.session.query(Message) .filter( Message.id == message_id, @@ -72,7 +74,10 @@ def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) - } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) - agent_tools = agent_config.tools + if not agent_config: + return result + + agent_tools = agent_config.tools 
or [] def find_agent_tool(tool_name: str): for agent_tool in agent_tools: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index f45c21cb1..45ec1e9b5 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,8 +1,9 @@ import datetime import uuid +from typing import cast import pandas as pd -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import or_ from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -71,7 +72,7 @@ def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> Messa app_id, annotation_setting.collection_binding_id, ) - return annotation + return cast(MessageAnnotation, annotation) @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: @@ -124,8 +125,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo raise NotFound("App not found") if keyword: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .filter( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), @@ -137,8 +137,7 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo ) else: annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) + MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) @@ -287,7 +286,7 @@ def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - content = {"question": row[0], "answer": row[1]} + content = {"question": row.iloc[0], "answer": row.iloc[1]} result.append(content) if len(result) == 0: raise ValueError("The CSV file is empty.") @@ -327,8 +326,7 @@ def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, lim raise NotFound("Annotation not found") annotation_hit_histories = ( - db.session.query(AppAnnotationHitHistory) - .filter( + AppAnnotationHitHistory.query.filter( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 7c1a17598..15119247f 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -2,9 +2,10 @@ import uuid from enum import StrEnum from typing import Optional +from urllib.parse import urlparse from uuid import uuid4 -import yaml +import yaml # type: ignore from packaging import version from pydantic import BaseModel from sqlalchemy import select @@ -103,7 +104,7 @@ def import_app( raise ValueError(f"Invalid import_mode: {import_mode}") # Get YAML content - content = "" + content: str = "" if mode == ImportMode.YAML_URL: if not yaml_url: return Import( @@ -113,13 +114,17 @@ def import_app( ) try: max_size = 10 * 1024 * 1024 # 10MB - # tricky way to handle url from github to github raw url - if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")): + parsed_url = urlparse(yaml_url) + if ( + parsed_url.scheme == "https" + and parsed_url.netloc == "github.com" + and parsed_url.path.endswith((".yml", ".yaml")) + ): yaml_url = yaml_url.replace("https://github.com", 
"https://raw.githubusercontent.com") yaml_url = yaml_url.replace("/blob/", "/") response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response.raise_for_status() - content = response.content + content = response.content.decode() if len(content) > max_size: return Import( @@ -134,15 +139,6 @@ def import_app( status=ImportStatus.FAILED, error="Empty content from url", ) - - try: - content = content.decode("utf-8") - except UnicodeDecodeError as e: - return Import( - id=import_id, - status=ImportStatus.FAILED, - error=f"Error decoding content: {e}", - ) except Exception as e: return Import( id=import_id, @@ -176,6 +172,9 @@ def import_app( data["kind"] = "app" imported_version = data.get("version", "0.1.0") + # check if imported_version is a float-like string + if not isinstance(imported_version, str): + raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}") status = _check_version_compatibility(imported_version) # Extract app data @@ -362,6 +361,9 @@ def _create_or_update_app( app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id else: + if account.current_tenant_id is None: + raise ValueError("Current tenant is not set") + # Create new app app = App() app.id = str(uuid4()) @@ -462,7 +464,7 @@ def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: else: cls._append_model_config_export_data(export_data, app_model) - return yaml.dump(export_data, allow_unicode=True) + return yaml.dump(export_data, allow_unicode=True) # type: ignore @classmethod def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 9def7d15e..51aef7cca 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -118,7 +118,7 @@ def generate( @staticmethod def _get_max_active_requests(app_model: App) -> int: max_active_requests = app_model.max_active_requests - if app_model.max_active_requests is None: + if max_active_requests is None: max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS) return max_active_requests @@ -150,7 +150,7 @@ def generate_more_like_this( message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[dict, Generator]: + ) -> Union[Mapping, Generator]: """ Generate more like this :param app_model: app model diff --git a/api/services/app_service.py b/api/services/app_service.py index adcc40bc9..675b623fa 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,9 +1,9 @@ import json import logging from datetime import UTC, datetime -from typing import cast +from typing import Optional, cast -from flask_login import current_user +from flask_login import current_user # type: ignore from flask_sqlalchemy.pagination import Pagination from sqlalchemy.sql import text # Extend: App Center - Recommended list sorted by usage frequency @@ -33,9 +33,10 @@ class AppService: - def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: + def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None: """ Get app list with pagination + :param user_id: user id :param tenant_id: tenant id :param args: request args :return: @@ -60,6 +61,8 @@ def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: elif args["mode"] == "channel": filters.append(App.mode == AppMode.CHANNEL.value) + if 
args.get("is_created_by_me", False): + filters.append(App.created_by == user_id) if args.get("name"): name = args["name"][:30] filters.append(App.name.ilike(f"%{name}%")) @@ -107,7 +110,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -124,6 +127,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for model {model_instance.model}") default_model_dict = { "provider": model_instance.provider, @@ -133,7 +138,7 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: } else: provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id, model_type=ModelType.LLM + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM ) default_model_config["model"]["provider"] = provider default_model_config["model"]["name"] = model @@ -352,7 +357,7 @@ def get_app_meta(self, app_model: App) -> dict: """ app_mode = AppMode.value_of(app_model.mode) - meta = {"tool_icons": {}} + meta: dict = {"tool_icons": {}} if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow @@ -374,7 +379,7 @@ def get_app_meta(self, app_model: App) -> dict: } ) else: - app_model_config: AppModelConfig = app_model.app_model_config + app_model_config: Optional[AppModelConfig] = app_model.app_model_config if not app_model_config: return meta @@ -390,16 +395,18 @@ def get_app_meta(self, app_model: App) -> dict: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get("provider_type") - provider_id = tool.get("provider_id") - tool_name = tool.get("tool_name") + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + tool_name = tool.get("tool_name", "") if provider_type == "builtin": meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": try: - provider: ApiToolProvider = ( + provider: Optional[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() ) + if provider is None: + raise ValueError(f"provider not found for tool {tool_name}") meta["tool_icons"][tool_name] = json.loads(provider.icon) except: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 7a0cd5725..294dfe4c8 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -1,5 +1,6 @@ import io import logging +import uuid from typing import Optional from werkzeug.datastructures import FileStorage @@ -81,7 +82,7 @@ def transcript_tts( from app import app from extensions.ext_database import db - def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): + def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None): with app.app_context(): if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow @@ -94,6 
+95,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): voice = features_dict["text_to_speech"].get("voice") if voice is None else voice else: + if app_model.app_model_config is None: + raise ValueError("AppModelConfig not found") text_to_speech_dict = app_model.app_model_config.text_to_speech_dict if not text_to_speech_dict.get("enabled"): @@ -110,6 +113,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): voices = model_instance.get_tts_voices() if voices: voice = voices[0].get("value") + if not voice: + raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") @@ -120,7 +125,13 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): raise e if message_id: + try: + uuid.UUID(message_id) + except ValueError: + return None message = db.session.query(Message).filter(Message.id == message_id).first() + if message is None: + return None if message.answer == "" and message.status == "normal": return None @@ -130,6 +141,8 @@ def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): return Response(stream_with_context(response), content_type="audio/mpeg") return response else: + if text is None: + raise ValueError("Text is required") response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): return Response(stream_with_context(response), content_type="audio/mpeg") diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py index afc491398..50e4edff1 100644 --- a/api/services/auth/firecrawl/firecrawl.py +++ b/api/services/auth/firecrawl/firecrawl.py @@ -11,8 +11,8 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) - self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") + self.api_key = credentials.get("config", {}).get("api_key", None) + self.base_url = credentials.get("config", {}).get("base_url", "https://api.firecrawl.dev") if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py index de898a1f9..6100e9afc 100644 --- a/api/services/auth/jina.py +++ b/api/services/auth/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py index de898a1f9..6100e9afc 100644 --- a/api/services/auth/jina/jina.py +++ b/api/services/auth/jina/jina.py @@ -11,7 +11,7 @@ def __init__(self, credentials: dict): auth_type = credentials.get("auth_type") if auth_type != "bearer": raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") - self.api_key = credentials.get("config").get("api_key", None) + self.api_key = credentials.get("config", {}).get("api_key", None) if not self.api_key: raise ValueError("No API key provided") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index edc516821..0d50a2aa8 100644 --- a/api/services/billing_service.py +++ 
b/api/services/billing_service.py @@ -1,7 +1,8 @@ import os +from typing import Literal, Optional import httpx -from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed +from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed from extensions.ext_database import db from models.account import TenantAccountJoin, TenantAccountRole @@ -16,7 +17,6 @@ def get_info(cls, tenant_id: str): params = {"tenant_id": tenant_id} billing_info = cls._send_request("GET", "/subscription/info", params=params) - return billing_info @classmethod @@ -43,26 +43,51 @@ def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): @retry( wait=wait_fixed(2), stop=stop_before_delay(10), - retry=retry_if_not_exception_type(httpx.RequestError), + retry=retry_if_exception_type(httpx.RequestError), reraise=True, ) - def _send_request(cls, method, endpoint, json=None, params=None): + def _send_request(cls, method: Literal["GET", "POST", "DELETE"], endpoint: str, json=None, params=None): headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = httpx.request(method, url, json=json, params=params, headers=headers) - + if method == "GET" and response.status_code != httpx.codes.OK: + raise ValueError("Unable to retrieve billing information. Please try again later or contact support.") return response.json() @staticmethod def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = ( + join: Optional[TenantAccountJoin] = ( db.session.query(TenantAccountJoin) .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) .first() ) + if not join: + raise ValueError("Tenant account join not found") + if not TenantAccountRole.is_privileged_role(join.role): raise ValueError("Only team owner or team admin can perform this action") + + @classmethod + def delete_account(cls, account_id: str): + """Delete account.""" + params = {"account_id": account_id} + return cls._send_request("DELETE", "/account/", params=params) + + @classmethod + def is_email_in_freeze(cls, email: str) -> bool: + params = {"email": email} + try: + response = cls._send_request("GET", "/account/in-freeze", params=params) + return bool(response.get("data", False)) + except Exception: + return False + + @classmethod + def update_account_deletion_feedback(cls, email: str, feedback: str): + """Update account deletion feedback.""" + json = {"email": email, "feedback": feedback} + return cls._send_request("POST", "/account/delete-feedback", json=json) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 456dc3ebe..6485cbf37 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -72,8 +72,7 @@ def pagination_by_last_id( sort_direction=sort_direction, reference_conversation=current_page_last_conversation, ) - count_stmt = stmt.where(rest_filter_condition) - count_stmt = select(func.count()).select_from(count_stmt.subquery()) + count_stmt = select(func.count()).select_from(stmt.where(rest_filter_condition).subquery()) rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4e99c73ad..dac0a6a77 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -6,7 +6,7 @@ import uuid from typing import Any, Optional -from flask_login import current_user 
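The `BillingService._send_request` fix above inverts the retry predicate: the old `retry_if_not_exception_type(httpx.RequestError)` retried on everything except network failures, while the new `retry_if_exception_type` retries only transient `httpx.RequestError`s and re-raises after the delay window. A minimal sketch of the corrected decorator, using the same tenacity primitives as the diff (the function name is illustrative):

```python
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed


@retry(
    wait=wait_fixed(2),
    stop=stop_before_delay(10),
    retry=retry_if_exception_type(httpx.RequestError),  # retry only transient network errors
    reraise=True,  # surface the last exception instead of a RetryError
)
def fetch_billing(url: str, headers: dict, params: dict | None = None) -> dict:
    response = httpx.request("GET", url, params=params, headers=headers)
    if response.status_code != httpx.codes.OK:
        raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
    return response.json()
```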
+from flask_login import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -14,6 +14,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexType from core.rag.retrieval.retrieval_methods import RetrievalMethod from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted @@ -23,7 +24,9 @@ from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, + ChildChunk, Dataset, + DatasetAutoDisableLog, DatasetCollectionBinding, DatasetPermission, DatasetPermissionEnum, @@ -35,8 +38,15 @@ ) from models.model import UploadFile from models.source import DataSourceOauthBinding -from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity -from services.errors.account import NoPermissionError +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + KnowledgeConfig, + RerankingModel, + RetrievalModel, + SegmentUpdateArgs, +) +from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError @@ -44,13 +54,16 @@ from services.feature_service import FeatureModel, FeatureService from services.tag_service import TagService from services.vector_service import VectorService +from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task +from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task +from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -58,7 +71,7 @@ class DatasetService: @staticmethod - def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None): + def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: @@ -73,25 +86,30 @@ def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids else: return [], 0 else: - # show all datasets that the user has permission to access - if permitted_dataset_ids: - query = query.filter( - db.or_( - Dataset.permission == DatasetPermissionEnum.ALL_TEAM, - db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), - db.and_( - Dataset.permission == 
DatasetPermissionEnum.PARTIAL_TEAM, - Dataset.id.in_(permitted_dataset_ids), - ), + if user.current_role != TenantAccountRole.OWNER or not include_all: + # show all datasets that the user has permission to access + if permitted_dataset_ids: + query = query.filter( + db.or_( + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_( + Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id + ), + db.and_( + Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, + Dataset.id.in_(permitted_dataset_ids), + ), + ) ) - ) - else: - query = query.filter( - db.or_( - Dataset.permission == DatasetPermissionEnum.ALL_TEAM, - db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), + else: + query = query.filter( + db.or_( + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_( + Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id + ), + ) ) - ) else: # if no user, only show datasets that are shared with all team members query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) @@ -186,8 +204,9 @@ def create_empty_dataset( return dataset @staticmethod - def get_dataset(dataset_id) -> Dataset: - return Dataset.query.filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Optional[Dataset]: + dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + return dataset @staticmethod def check_dataset_model_setting(dataset): @@ -228,6 +247,8 @@ def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, @staticmethod def update_dataset(dataset_id, data, user): dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise ValueError("Dataset not found") DatasetService.check_dataset_permission(dataset, user) if dataset.provider == "external": @@ -361,26 +382,38 @@ def check_dataset_permission(dataset, user): if dataset.tenant_id != user.current_tenant_id: logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") raise NoPermissionError("You do not have permission to access this dataset.") - if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: - logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") - raise NoPermissionError("You do not have permission to access this dataset.") - if dataset.permission == "partial_members": - user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() - if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id: + if user.current_role != TenantAccountRole.OWNER: + if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == "partial_members": + user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() + if ( + not user_permission + and dataset.tenant_id != user.current_tenant_id + and dataset.created_by != user.id + ): + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): - if dataset.permission == 
DatasetPermissionEnum.ONLY_ME: - if dataset.created_by != user.id: - raise NoPermissionError("You do not have permission to access this dataset.") + def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): + if not dataset: + raise ValueError("Dataset not found") - elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: - if not any( - dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() - ): - raise NoPermissionError("You do not have permission to access this dataset.") + if not user: + raise ValueError("User not found") + + if user.current_role != TenantAccountRole.OWNER: + if dataset.permission == DatasetPermissionEnum.ONLY_ME: + if dataset.created_by != user.id: + raise NoPermissionError("You do not have permission to access this dataset.") + + elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: + if not any( + dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() + ): + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): @@ -399,9 +432,33 @@ def get_related_apps(dataset_id: str): .all() ) + @staticmethod + def get_dataset_auto_disable_logs(dataset_id: str) -> dict: + features = FeatureService.get_features(current_user.current_tenant_id) + if not features.billing.enabled or features.billing.subscription.plan == "sandbox": + return { + "document_ids": [], + "count": 0, + } + # get recent 30 days auto disable logs + start_date = datetime.datetime.now() - datetime.timedelta(days=30) + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( + DatasetAutoDisableLog.dataset_id == dataset_id, + DatasetAutoDisableLog.created_at >= start_date, + ).all() + if dataset_auto_disable_logs: + return { + "document_ids": [log.document_id for log in dataset_auto_disable_logs], + "count": len(dataset_auto_disable_logs), + } + return { + "document_ids": [], + "count": 0, + } + class DocumentService: - DEFAULT_RULES = { + DEFAULT_RULES: dict[str, Any] = { "mode": "custom", "rules": { "pre_processing_rules": [ @@ -415,7 +472,7 @@ class DocumentService: }, } - DOCUMENT_METADATA_SCHEMA = { + DOCUMENT_METADATA_SCHEMA: dict[str, Any] = { "book": { "title": str, "language": str, @@ -509,12 +566,14 @@ class DocumentService: } @staticmethod - def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = ( - db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) - - return document + def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: + if document_id: + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) + return document + else: + return None @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: @@ -579,6 +638,20 @@ def delete_document(document): db.session.delete(document) db.session.commit() + @staticmethod + def delete_documents(dataset: Dataset, document_ids: list[str]): + documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all() + file_ids = [ + document.data_source_info_dict["upload_file_id"] + for document in documents + if document.data_source_type == "upload_file" + ] + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + + for document in documents: + 
db.session.delete(document) + db.session.commit() + @staticmethod def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) @@ -680,7 +753,7 @@ def get_documents_position(dataset_id): @staticmethod def save_document_with_dataset_id( dataset: Dataset, - document_data: dict, + knowledge_config: KnowledgeConfig, account: Account | Any, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", @@ -689,45 +762,49 @@ def save_document_with_dataset_id( features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if "original_document_id" not in document_data or not document_data["original_document_id"]: + if not knowledge_config.original_document_id: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] - count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) - batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) - if count > batch_upload_limit: - raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") - - DocumentService.check_documents_upload_quota(count, features) + if knowledge_config.data_source: + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore + count = len(upload_file_list) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + for notion_info in notion_info_list: # type: ignore + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) # type: ignore + batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) + if count > batch_upload_limit: + raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") + + DocumentService.check_documents_upload_quota(count, features) # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = document_data["data_source"]["type"] + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore if not dataset.indexing_technique: - if ( - "indexing_technique" not in document_data - or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST - ): - raise ValueError("Indexing technique is required") + if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == "high_quality": + dataset.indexing_technique = knowledge_config.indexing_technique + if knowledge_config.indexing_technique == "high_quality": model_manager = ModelManager() - embedding_model = model_manager.get_default_model_instance( - 
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - dataset.embedding_model = embedding_model.model - dataset.embedding_model_provider = embedding_model.provider + if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: + dataset_embedding_model = knowledge_config.embedding_model + dataset_embedding_model_provider = knowledge_config.embedding_model_provider + else: + embedding_model = model_manager.get_default_model_instance( + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING + ) + dataset_embedding_model = embedding_model.model + dataset_embedding_model_provider = embedding_model.provider + dataset.embedding_model = dataset_embedding_model + dataset.embedding_model_provider = dataset_embedding_model_provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + dataset_embedding_model_provider, dataset_embedding_model ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: @@ -739,41 +816,51 @@ def save_document_with_dataset_id( "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model + dataset.retrieval_model = ( + knowledge_config.retrieval_model.model_dump() + if knowledge_config.retrieval_model + else default_retrieval_model + ) # type: ignore documents = [] - if document_data.get("original_document_id"): - document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) documents.append(document) batch = document.batch else: batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) # save process rule if not dataset_process_rule: - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), - created_by=account.id, - ) - elif process_rule["mode"] == "automatic": - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id, - ) - db.session.add(dataset_process_rule) - db.session.commit() + process_rule = knowledge_config.process_rule + if process_rule: + if process_rule.mode in ("custom", "hierarchical"): + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + elif process_rule.mode == "automatic": + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + created_by=account.id, + ) + else: + logging.warning( + f"Invalid process rule mode: {process_rule.mode}, cannot find dataset process rule" + ) + return + db.session.add(dataset_process_rule) + db.session.commit() lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) with redis_client.lock(lock_name, timeout=600): position = DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] - if document_data["data_source"]["type"] == "upload_file": - upload_file_list =
document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -790,7 +877,7 @@ def save_document_with_dataset_id( "upload_file_id": file_id, } # check duplicate - if document_data.get("duplicate", False): + if knowledge_config.duplicate: document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, @@ -799,11 +886,11 @@ def save_document_with_dataset_id( name=file_name, ).first() if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = datetime.datetime.utcnow() + document.dataset_process_rule_id = dataset_process_rule.id # type: ignore + document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] - document.doc_language = document_data["doc_language"] + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch document.indexing_status = "waiting" @@ -813,10 +900,10 @@ def save_document_with_dataset_id( continue document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -829,8 +916,10 @@ def save_document_with_dataset_id( document_ids.append(document.id) documents.append(document) position += 1 - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + if not notion_info_list: + raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( @@ -845,7 +934,7 @@ def save_document_with_dataset_id( exist_page_ids.append(data_source_info["notion_page_id"]) exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -856,25 +945,25 @@ def save_document_with_dataset_id( ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: - if page["page_id"] not in exist_page_ids: + for page in notion_info.pages: + if page.page_id not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, } document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - 
document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, account, - page["page_name"], + page.page_name, batch, ) db.session.add(document) @@ -883,19 +972,21 @@ def save_document_with_dataset_id( documents.append(document) position += 1 else: - exist_document.pop(page["page_id"]) + exist_document.pop(page.page_id) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if not website_info: + raise ValueError("No website info list found.") + urls = website_info.urls for url in urls: data_source_info = { "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, "mode": "crawl", } if len(url) > 255: @@ -904,10 +995,10 @@ def save_document_with_dataset_id( document_name = url document = DocumentService.build_document( dataset, - dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + dataset_process_rule.id, # type: ignore + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -981,43 +1072,46 @@ def get_tenant_documents_count(): @staticmethod def update_document_with_dataset_id( dataset: Dataset, - document_data: dict, + document_data: KnowledgeConfig, account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) - document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) + document = DocumentService.get_document(dataset.id, document_data.original_document_id) if document is None: raise NotFound("Document not found") if document.display_status != "available": raise ValueError("Document is not available") # save process rule - if document_data.get("process_rule"): - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": + if document_data.process_rule: + process_rule = document_data.process_rule + if process_rule.mode in {"custom", "hierarchical"}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) - elif process_rule["mode"] == "automatic": + elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], + mode=process_rule.mode, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) - db.session.add(dataset_process_rule) - db.session.commit() - document.dataset_process_rule_id = dataset_process_rule.id + if 
dataset_process_rule is not None: + db.session.add(dataset_process_rule) + db.session.commit() + document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get("data_source"): + if document_data.data_source: file_name = "" data_source_info = {} - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if document_data.data_source.info_list.data_source_type == "upload_file": + if not document_data.data_source.info_list.file_info_list: + raise ValueError("No file info list found.") + upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -1033,10 +1127,12 @@ def update_document_with_dataset_id( data_source_info = { "upload_file_id": file_id, } - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif document_data.data_source.info_list.data_source_type == "notion_import": + if not document_data.data_source.info_list.notion_info_list: + raise ValueError("No notion info list found.") + notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -1047,31 +1143,32 @@ def update_document_with_dataset_id( ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: + for page in notion_info.pages: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore + "type": page.type, } - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] - for url in urls: - data_source_info = { - "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), - "mode": "crawl", - } - document.data_source_type = document_data["data_source"]["type"] + elif document_data.data_source.info_list.data_source_type == "website_crawl": + website_info = document_data.data_source.info_list.website_info_list + if website_info: + urls = website_info.urls + for url in urls: + data_source_info = { + "url": url, + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, # type: ignore + "mode": "crawl", + } + document.data_source_type = document_data.data_source.info_list.data_source_type document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document name - if document_data.get("name"): - document.name = document_data["name"] + if document_data.name: + document.name = document_data.name # update document to be waiting document.indexing_status = "waiting" document.completed_at = None @@ -1081,7 +1178,7 @@ def update_document_with_dataset_id( document.splitting_completed_at = None document.updated_at = 
datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] + document.doc_form = document_data.doc_form db.session.add(document) db.session.commit() # update document segment @@ -1093,21 +1190,27 @@ def update_document_with_dataset_id( return document @staticmethod - def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account): + def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = ( + knowledge_config.data_source.info_list.file_info_list.file_ids + if knowledge_config.data_source.info_list.file_info_list + else [] + ) count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list + if notion_info_list: + for notion_info in notion_info_list: + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + if website_info: + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1116,39 +1219,39 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun dataset_collection_binding_id = None retrieval_model = None - if document_data["indexing_technique"] == "high_quality": + if knowledge_config.indexing_technique == "high_quality": dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - document_data["embedding_model_provider"], document_data["embedding_model"] + knowledge_config.embedding_model_provider, # type: ignore + knowledge_config.embedding_model, # type: ignore ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get("retrieval_model"): - retrieval_model = document_data["retrieval_model"] + if knowledge_config.retrieval_model: + retrieval_model = knowledge_config.retrieval_model else: - default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, - } - retrieval_model = default_retrieval_model + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH.value, + reranking_enable=False, + reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), + top_k=2, + score_threshold_enabled=False, + ) # save dataset dataset = Dataset( 
tenant_id=tenant_id, name="", - data_source_type=document_data["data_source"]["type"], - indexing_technique=document_data.get("indexing_technique", "high_quality"), + data_source_type=knowledge_config.data_source.info_list.data_source_type, + indexing_technique=knowledge_config.indexing_technique, created_by=account.id, - embedding_model=document_data.get("embedding_model"), - embedding_model_provider=document_data.get("embedding_model_provider"), + embedding_model=knowledge_config.embedding_model, + embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model, + retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) + db.session.add(dataset) # type: ignore db.session.flush() - documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account) + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) cut_length = 18 cut_name = documents[0].name[:cut_length] @@ -1159,133 +1262,86 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun return dataset, documents, batch @classmethod - def document_create_args_validate(cls, args: dict): - if "original_document_id" not in args or not args["original_document_id"]: - DocumentService.data_source_args_validate(args) - DocumentService.process_rule_args_validate(args) + def document_create_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source and not knowledge_config.process_rule: + raise ValueError("Data source or Process rule is required") else: - if ("data_source" not in args or not args["data_source"]) and ( - "process_rule" not in args or not args["process_rule"] - ): - raise ValueError("Data source or Process rule is required") - else: - if args.get("data_source"): - DocumentService.data_source_args_validate(args) - if args.get("process_rule"): - DocumentService.process_rule_args_validate(args) + if knowledge_config.data_source: + DocumentService.data_source_args_validate(knowledge_config) + if knowledge_config.process_rule: + DocumentService.process_rule_args_validate(knowledge_config) @classmethod - def data_source_args_validate(cls, args: dict): - if "data_source" not in args or not args["data_source"]: + def data_source_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source: raise ValueError("Data source is required") - if not isinstance(args["data_source"], dict): - raise ValueError("Data source is invalid") - - if "type" not in args["data_source"] or not args["data_source"]["type"]: - raise ValueError("Data source type is required") - - if args["data_source"]["type"] not in Document.DATA_SOURCES: + if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: + if not knowledge_config.data_source.info_list: raise ValueError("Data source info is required") - if args["data_source"]["type"] == "upload_file": - if ( - "file_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["file_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") - if args["data_source"]["type"] == 
"notion_import": - if ( - "notion_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["notion_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "notion_import": + if not knowledge_config.data_source.info_list.notion_info_list: raise ValueError("Notion source info is required") - if args["data_source"]["type"] == "website_crawl": - if ( - "website_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["website_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "website_crawl": + if not knowledge_config.data_source.info_list.website_info_list: raise ValueError("Website source info is required") @classmethod - def process_rule_args_validate(cls, args: dict): - if "process_rule" not in args or not args["process_rule"]: + def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.process_rule: raise ValueError("Process rule is required") - if not isinstance(args["process_rule"], dict): - raise ValueError("Process rule is invalid") - - if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: + if not knowledge_config.process_rule.mode: raise ValueError("Process rule mode is required") - if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: + if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": - args["process_rule"]["rules"] = {} + if knowledge_config.process_rule.mode == "automatic": + knowledge_config.process_rule.rules = None else: - if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: + if not knowledge_config.process_rule.rules: raise ValueError("Process rule rules is required") - if not isinstance(args["process_rule"]["rules"], dict): - raise ValueError("Process rule rules is invalid") - - if ( - "pre_processing_rules" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["pre_processing_rules"] is None - ): + if knowledge_config.process_rule.rules.pre_processing_rules is None: raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): - raise ValueError("Process rule pre_processing_rules is invalid") - unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: - if "id" not in pre_processing_rule or not pre_processing_rule["id"]: + for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules: + if not pre_processing_rule.id: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: - raise ValueError("Process rule pre_processing_rules id is invalid") - - if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: - raise ValueError("Process rule pre_processing_rules enabled is required") - - if not isinstance(pre_processing_rule["enabled"], bool): + if not isinstance(pre_processing_rule.enabled, bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule - args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) + 
knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) - if ( - "segmentation" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["segmentation"] is None - ): + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): - raise ValueError("Process rule segmentation is invalid") - - if ( - "separator" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["separator"] - ): + if not knowledge_config.process_rule.rules.segmentation.separator: raise ValueError("Process rule segmentation separator is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): + if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str): raise ValueError("Process rule segmentation separator is invalid") - if ( - "max_tokens" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + if not ( + knowledge_config.process_rule.mode == "hierarchical" + and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): - raise ValueError("Process rule segmentation max_tokens is required") + if not knowledge_config.process_rule.rules.segmentation.max_tokens: + raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): - raise ValueError("Process rule segmentation max_tokens is invalid") + if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int): + raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): @@ -1432,7 +1488,7 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): # save vector index try: - VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False @@ -1510,7 +1566,7 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas db.session.add(document) try: # save vector index - VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") for segment_document in segment_data_list: @@ -1522,14 +1578,13 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas return segment_data_list @classmethod - def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - segment_update_entity = SegmentUpdateEntity(**args) + def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if segment_update_entity.enabled is not None: - action = segment_update_entity.enabled + if args.enabled is not None: + action = args.enabled if segment.enabled != action: if 
not action: segment.enabled = action @@ -1542,22 +1597,22 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if segment_update_entity.enabled is not None: - if not segment_update_entity.enabled: + if args.enabled is not None: + if not args.enabled: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: word_count_change = segment.word_count - content = segment_update_entity.content + content = args.content or segment.content if segment.content == content: segment.word_count = len(content) if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change - if segment_update_entity.keywords: - segment.keywords = segment_update_entity.keywords + if args.keywords: + segment.keywords = args.keywords segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -1568,8 +1623,45 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if segment_update_entity.enabled: - VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset) + if args.enabled: + VectorService.create_segments_vector( + [args.keywords] if args.keywords else None, + [segment], + dataset, + document.doc_form, + ) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # regenerate child chunks + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) else: segment_hash = helper.generate_text_hash(content) tokens = 0 @@ -1600,8 +1692,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.disabled_at = None segment.disabled_by = None if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer) + segment.answer = args.answer + segment.word_count += len(args.answer) if args.answer else 0 word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1609,8 +1701,40 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document db.session.add(document) 
db.session.add(segment) db.session.commit() - # update segment vector index - VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + # update segment vector index + VectorService.update_segment_vector(args.keywords, segment, dataset) except Exception as e: logging.exception("update segment index failed") @@ -1619,8 +1743,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.status = "error" segment.error = str(e) db.session.commit() - segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() - return segment + new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() + return new_segment @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): @@ -1633,13 +1757,265 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D if segment.enabled: # send delete segment index task redis_client.setex(indexing_cache_key, 600, 1) - delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id) + delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count document.word_count -= segment.word_count db.session.add(document) db.session.commit() + @classmethod + def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): + index_node_ids = ( + DocumentSegment.query.with_entities(DocumentSegment.index_node_id) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.tenant_id == current_user.current_tenant_id, + ) + .all() + ) + index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) + db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() + db.session.commit() + + @classmethod + def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + if action == "enable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + 
DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + .all() + ) + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() + + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + elif action == "disable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + .all() + ) + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() + + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + else: + raise InvalidActionError() + + @classmethod + def create_child_chunk( + cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset + ) -> ChildChunk: + lock_name = "add_child_lock_{}".format(segment.id) + with redis_client.lock(lock_name, timeout=20): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(content) + child_chunk_count = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .count() + ) + max_position = ( + db.session.query(func.max(ChildChunk.position)) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .scalar() + ) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=max_position + 1, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=content, + word_count=len(content), + type="customized", + created_by=current_user.id, + ) + db.session.add(child_chunk) + # save vector index + try: + VectorService.create_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("create child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + db.session.commit() + + return child_chunk + + @classmethod + def update_child_chunks( + cls, + child_chunks_update_args: list[ChildChunkUpdateArgs], + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> list[ChildChunk]: + child_chunks = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .all() + ) + child_chunks_map = {chunk.id: chunk for chunk in child_chunks} + + new_child_chunks, update_child_chunks,
delete_child_chunks, new_child_chunks_args = [], [], [], [] + + for child_chunk_update_args in child_chunks_update_args: + if child_chunk_update_args.id: + child_chunk = child_chunks_map.pop(child_chunk_update_args.id, None) + if child_chunk: + if child_chunk.content != child_chunk_update_args.content: + child_chunk.content = child_chunk_update_args.content + child_chunk.word_count = len(child_chunk.content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + update_child_chunks.append(child_chunk) + else: + new_child_chunks_args.append(child_chunk_update_args) + if child_chunks_map: + delete_child_chunks = list(child_chunks_map.values()) + try: + if update_child_chunks: + db.session.bulk_save_objects(update_child_chunks) + + if delete_child_chunks: + for child_chunk in delete_child_chunks: + db.session.delete(child_chunk) + if new_child_chunks_args: + child_chunk_count = len(child_chunks) + for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(args.content) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=position, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=args.content, + word_count=len(args.content), + type="customized", + created_by=current_user.id, + ) + + db.session.add(child_chunk) + db.session.flush() + new_child_chunks.append(child_chunk) + VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position) + + @classmethod + def update_child_chunk( + cls, + content: str, + child_chunk: ChildChunk, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> ChildChunk: + try: + child_chunk.content = content + child_chunk.word_count = len(content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + db.session.add(child_chunk) + VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return child_chunk + + @classmethod + def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset): + db.session.delete(child_chunk) + try: + VectorService.delete_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("delete child chunk index failed") + db.session.rollback() + raise ChildChunkDeleteIndexError(str(e)) + db.session.commit() + + @classmethod + def get_child_chunks( + cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None + ): + query = ChildChunk.query.filter_by( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + ).order_by(ChildChunk.position.asc()) + if keyword: + query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) + return 
query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + class DatasetCollectionBindingService: @classmethod @@ -1680,6 +2056,8 @@ def get_dataset_collection_binding_by_id_and_type( .order_by(DatasetCollectionBinding.created_at) .first() ) + if not dataset_collection_binding: + raise ValueError("Dataset collection binding not found") return dataset_collection_binding diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 92098f06c..3c3f97044 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -8,8 +8,8 @@ class EnterpriseRequest: secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") proxies = { - "http": None, - "https": None, + "http": "", + "https": "", } @classmethod diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 449b79f33..76d9c2881 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,4 +1,5 @@ -from typing import Optional +from enum import Enum +from typing import Literal, Optional from pydantic import BaseModel @@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel): answer: Optional[str] = None keywords: Optional[list[str]] = None enabled: Optional[bool] = None + + +class ParentMode(str, Enum): + FULL_DOC = "full-doc" + PARAGRAPH = "paragraph" + + +class NotionIcon(BaseModel): + type: str + url: Optional[str] = None + emoji: Optional[str] = None + + +class NotionPage(BaseModel): + page_id: str + page_name: str + page_icon: Optional[NotionIcon] = None + type: str + + +class NotionInfo(BaseModel): + workspace_id: str + pages: list[NotionPage] + + +class WebsiteInfo(BaseModel): + provider: str + job_id: str + urls: list[str] + only_main_content: bool = True + + +class FileInfo(BaseModel): + file_ids: list[str] + + +class InfoList(BaseModel): + data_source_type: Literal["upload_file", "notion_import", "website_crawl"] + notion_info_list: Optional[list[NotionInfo]] = None + file_info_list: Optional[FileInfo] = None + website_info_list: Optional[WebsiteInfo] = None + + +class DataSource(BaseModel): + info_list: InfoList + + +class PreProcessingRule(BaseModel): + id: str + enabled: bool + + +class Segmentation(BaseModel): + separator: str = "\n" + max_tokens: int + chunk_overlap: int = 0 + + +class Rule(BaseModel): + pre_processing_rules: Optional[list[PreProcessingRule]] = None + segmentation: Optional[Segmentation] = None + parent_mode: Optional[Literal["full-doc", "paragraph"]] = None + subchunk_segmentation: Optional[Segmentation] = None + + +class ProcessRule(BaseModel): + mode: Literal["automatic", "custom", "hierarchical"] + rules: Optional[Rule] = None + + +class RerankingModel(BaseModel): + reranking_provider_name: Optional[str] = None + reranking_model_name: Optional[str] = None + + +class RetrievalModel(BaseModel): + search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] + reranking_enable: bool + reranking_model: Optional[RerankingModel] = None + top_k: int + score_threshold_enabled: bool + score_threshold: Optional[float] = None + + +class KnowledgeConfig(BaseModel): + original_document_id: Optional[str] = None + duplicate: bool = True + indexing_technique: Literal["high_quality", "economy"] + data_source: DataSource + process_rule: Optional[ProcessRule] = None + retrieval_model: Optional[RetrievalModel] = None + doc_form: str = "text_model" + 
doc_language: str = "English" + embedding_model: Optional[str] = None + embedding_model_provider: Optional[str] = None + name: Optional[str] = None + + +class SegmentUpdateArgs(BaseModel): + content: Optional[str] = None + answer: Optional[str] = None + keywords: Optional[list[str]] = None + regenerate_child_chunks: bool = False + enabled: Optional[bool] = None + + +class ChildChunkUpdateArgs(BaseModel): + id: Optional[str] = None + content: str diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index c519f0b0e..f1417c6cb 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -4,7 +4,10 @@ from pydantic import BaseModel, ConfigDict from configs import dify_config -from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.entities.model_entities import ( + ModelWithProviderEntity, + ProviderModelWithStatusEntity, +) from core.entities.provider_entities import QuotaConfiguration from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType @@ -148,7 +151,8 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): Model with provider entity. """ - provider: SimpleProviderEntityResponse + # FIXME type error ignore here + provider: SimpleProviderEntityResponse # type: ignore def __init__(self, model: ModelWithProviderEntity) -> None: super().__init__(**model.model_dump()) diff --git a/api/services/errors/base.py b/api/services/errors/base.py index 4d39f956b..35ea28468 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,6 +1,6 @@ from typing import Optional -class BaseServiceError(Exception): +class BaseServiceError(ValueError): def __init__(self, description: Optional[str] = None): self.description = description diff --git a/api/services/errors/chunk.py b/api/services/errors/chunk.py new file mode 100644 index 000000000..75bf4d5d5 --- /dev/null +++ b/api/services/errors/chunk.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class ChildChunkIndexingError(BaseServiceError): + description = "{message}" + + +class ChildChunkDeleteIndexError(BaseServiceError): + description = "{message}" diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 7be20301a..898624066 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,7 +1,7 @@ import json from copy import deepcopy from datetime import UTC, datetime -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import httpx import validators @@ -45,7 +45,10 @@ def validate_api_list(cls, api_settings: dict): @staticmethod def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: - ExternalDatasetService.check_endpoint_and_api_key(args.get("settings")) + settings = args.get("settings") + if settings is None: + raise ValueError("settings is required") + ExternalDatasetService.check_endpoint_and_api_key(settings) external_knowledge_api = ExternalKnowledgeApis( tenant_id=tenant_id, created_by=user_id, @@ -86,11 +89,16 @@ def check_endpoint_and_api_key(settings: dict): @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first() + external_knowledge_api: 
Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + return external_knowledge_api @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( id=external_knowledge_api_id, tenant_id=tenant_id ).first() if external_knowledge_api is None: @@ -127,7 +135,7 @@ def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bo @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id ).first() if not external_knowledge_binding: @@ -163,8 +171,9 @@ def process_external_api( "follow_redirects": True, } - response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs) - + response: httpx.Response = getattr(ssrf_proxy, settings.request_method)( + data=json.dumps(settings.params), files=files, **kwargs + ) return response @staticmethod @@ -265,15 +274,15 @@ def fetch_external_knowledge_retrieval( "knowledge_id": external_knowledge_binding.external_knowledge_id, } - external_knowledge_api_setting = { - "url": f"{settings.get('endpoint')}/retrieval", - "request_method": "post", - "headers": headers, - "params": request_params, - } response = ExternalDatasetService.process_external_api( - ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None + ExternalKnowledgeApiSetting( + url=f"{settings.get('endpoint')}/retrieval", + request_method="post", + headers=headers, + params=request_params, + ), + None, ) if response.status_code == 200: - return response.json().get("records", []) + return cast(list[Any], response.json().get("records", [])) return [] diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 67a592126..ffdef0708 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -75,7 +75,7 @@ def get_features(cls, tenant_id: str) -> FeatureModel: cls._fulfill_params_from_env(features) - if dify_config.BILLING_ENABLED: + if dify_config.BILLING_ENABLED and tenant_id: cls._fulfill_params_from_billing_api(features, tenant_id) return features diff --git a/api/services/file_service.py b/api/services/file_service.py index b12b95ca1..d417e8173 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -3,7 +3,7 @@ import uuid from typing import Any, Literal, Union -from flask_login import current_user +from flask_login import current_user # type: ignore from werkzeug.exceptions import NotFound from configs import dify_config @@ -61,14 +61,14 @@ def upload_file( # end_user current_tenant_id = user.tenant_id - file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension + file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." 
+ extension # save file to storage storage.save(file_key, content) # save file to db upload_file = UploadFile( - tenant_id=current_tenant_id, + tenant_id=current_tenant_id or "", storage_type=dify_config.STORAGE_TYPE, key=file_key, name=filename, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 7957b4dc8..e9176fc1c 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,12 +1,13 @@ import logging import time +from typing import Any from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.account import Account -from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import Dataset, DatasetQuery default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, @@ -24,7 +25,7 @@ def retrieve( dataset: Dataset, query: str, account: Account, - retrieval_model: dict, + retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, limit: int = 10, ) -> dict: @@ -68,7 +69,7 @@ def retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_retrieve_response(dataset, query, all_documents) + return cls.compact_retrieve_response(query, all_documents) # type: ignore @classmethod def external_retrieve( @@ -102,45 +103,21 @@ def external_retrieve( db.session.add(dataset_query) db.session.commit() - return cls.compact_external_retrieve_response(dataset, query, all_documents) + return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): - records = [] - - for document in documents: - index_node_id = document.metadata["doc_id"] - - segment = ( - db.session.query(DocumentSegment) - .filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - .first() - ) - - if not segment: - continue - - record = { - "segment": segment, - "score": document.metadata.get("score", None), - } - - records.append(record) + def compact_retrieve_response(cls, query: str, documents: list[Document]): + records = RetrievalService.format_retrieval_documents(documents) return { "query": { "content": query, }, - "records": records, + "records": [record.model_dump() for record in records], } @classmethod - def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list): + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list) -> dict[Any, Any]: records = [] if dataset.provider == "external": for document in documents: @@ -152,11 +129,10 @@ def compact_external_retrieve_response(cls, dataset: Dataset, query: str, docume } records.append(record) return { - "query": { - "content": query, - }, + "query": {"content": query}, "records": records, } + return {"query": {"content": query}, "records": []} @classmethod def hit_testing_args_check(cls, args): diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19b..8df1a6ba1 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,4 +1,4 @@ -import boto3 +import boto3 # type: ignore from configs import dify_config diff --git a/api/services/message_service.py 
b/api/services/message_service.py index be2922f4c..c17122ef6 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -152,12 +152,13 @@ def pagination_by_last_id( @classmethod def create_feedback( cls, + *, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str], content: Optional[str], - ) -> MessageFeedback: + ): if not user: raise ValueError("user cannot be None") @@ -264,6 +265,8 @@ def get_suggested_questions_after_answer( ) app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs) + if not app_model_config: + raise ValueError("did not find app model config") suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict if suggested_questions_after_answer.get("enabled", False) is False: @@ -285,7 +288,7 @@ def get_suggested_questions_after_answer( ) with measure_time() as timer: - questions = LLMGenerator.generate_suggested_questions_after_answer( + questions: list[Message] = LLMGenerator.generate_suggested_questions_after_answer( tenant_id=app_model.tenant_id, histories=histories ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index b20bda875..bacd3a8ec 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -2,7 +2,7 @@ import json import logging from json import JSONDecodeError -from typing import Optional +from typing import Optional, Union from constants import HIDDEN_VALUE from core.entities.provider_configuration import ProviderConfiguration @@ -88,11 +88,11 @@ def get_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get provider model setting provider_model_setting = provider_configuration.get_provider_model_setting( - model_type=model_type, + model_type=model_type_enum, model=model, ) @@ -106,7 +106,7 @@ def get_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .order_by(LoadBalancingModelConfig.created_at) @@ -124,7 +124,7 @@ def get_load_balancing_configs( if not inherit_config_exists: # Initialize the inherit configuration - inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type) + inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type_enum) # prepend the inherit configuration load_balancing_configs.insert(0, inherit_config) @@ -148,7 +148,7 @@ def get_load_balancing_configs( tenant_id=tenant_id, provider=provider, model=model, - model_type=model_type, + model_type=model_type_enum, config_id=load_balancing_config.id, ) @@ -214,7 +214,7 @@ def get_load_balancing_config( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) # Get load balancing configurations load_balancing_model_config = ( @@ -222,7 +222,7 @@ def get_load_balancing_config( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == 
provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -300,7 +300,7 @@ def update_load_balancing_configs( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) if not isinstance(configs, list): raise ValueError("Invalid load balancing configs") @@ -310,7 +310,7 @@ def update_load_balancing_configs( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, ) .all() @@ -359,7 +359,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, @@ -395,7 +395,7 @@ def update_load_balancing_configs( credentials = self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, validate=False, @@ -405,7 +405,7 @@ def update_load_balancing_configs( load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, provider_name=provider_configuration.provider.provider, - model_type=model_type.to_origin_model_type(), + model_type=model_type_enum.to_origin_model_type(), model_name=model, name=name, encrypted_config=json.dumps(credentials), @@ -450,7 +450,7 @@ def validate_load_balancing_credentials( raise ValueError(f"Provider {provider} does not exist.") # Convert model type to ModelType - model_type = ModelType.value_of(model_type) + model_type_enum = ModelType.value_of(model_type) load_balancing_model_config = None if config_id: @@ -460,7 +460,7 @@ def validate_load_balancing_credentials( .filter( LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(), LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.id == config_id, ) @@ -474,7 +474,7 @@ def validate_load_balancing_credentials( self._custom_credentials_validate( tenant_id=tenant_id, provider_configuration=provider_configuration, - model_type=model_type, + model_type=model_type_enum, model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, @@ -547,19 +547,14 @@ def _custom_credentials_validate( def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> ModelCredentialSchema | ProviderCredentialSchema: - """ - Get form schemas. 
- :param provider_configuration: provider configuration - :return: - """ - # Get credential form schemas from model credential schema or provider credential schema + ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + """Get form schemas.""" if provider_configuration.provider.model_credential_schema: - credential_schema = provider_configuration.provider.model_credential_schema + return provider_configuration.provider.model_credential_schema + elif provider_configuration.provider.provider_credential_schema: + return provider_configuration.provider.provider_credential_schema else: - credential_schema = provider_configuration.provider.provider_credential_schema - - return credential_schema + raise ValueError("No credential schema found") def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: """ diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 384a072b3..b10c5ad2d 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -7,7 +7,7 @@ import requests from flask import current_app -from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity +from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -100,23 +100,15 @@ def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWit ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) ] - def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: + def get_provider_credentials(self, tenant_id: str, provider: str): """ get provider credentials. - - :param tenant_id: - :param provider: - :return: """ - # Get all provider configurations of the current workspace provider_configurations = self.provider_manager.get_configurations(tenant_id) - - # Get provider configuration provider_configuration = provider_configurations.get(provider) if not provider_configuration: raise ValueError(f"Provider {provider} does not exist.") - # Get provider custom credentials from workspace return provider_configuration.get_custom_credentials(obfuscated=True) def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None: @@ -176,7 +168,7 @@ def remove_provider_credentials(self, tenant_id: str, provider: str) -> None: # Remove custom provider credentials. provider_configuration.delete_custom_credentials() - def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict: + def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str): """ get model credentials. 
@@ -287,7 +279,7 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider - provider_models = {} + provider_models: dict[str, list[ModelWithProviderEntity]] = {} for model in models: if model.provider.provider not in provider_models: provider_models[model.provider.provider] = [] @@ -362,7 +354,7 @@ def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) - return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules(model=model, credentials=credentials) + return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials)) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -422,6 +414,7 @@ def get_model_provider_icon( """ provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() + file_name: str | None = None if icon_type.lower() == "icon_small": if not provider_schema.icon_small: @@ -439,6 +432,8 @@ def get_model_provider_icon( file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US + if not file_name: + return None, None root_path = current_app.root_path provider_instance_path = os.path.dirname( @@ -524,7 +519,7 @@ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: s def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/apply" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} @@ -545,7 +540,7 @@ def free_quota_submit(self, tenant_id: str, provider: str): def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") - api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") + api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "") api_url = api_base_url + "/api/v1/providers/qualification-verify" headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index dfb21e767..082afeed8 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.moderation.factory import ModerationFactory, ModerationOutputsResult from extensions.ext_database import db from models.model import App, AppModelConfig @@ -5,7 +7,7 @@ class ModerationService: def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: - app_model_config: AppModelConfig = None + app_model_config: Optional[AppModelConfig] = None app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 1160a1f27..78340d2bc 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database import db from 
models.model import App, TraceAppConfig @@ -12,7 +14,7 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -22,7 +24,10 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): return None # decrypt_token and obfuscated_token - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( tenant_id, tracing_provider, trace_config_data.tracing_config ) @@ -54,6 +59,15 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): except Exception: new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"}) + if tracing_provider == "opik" and ( + "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") + ): + try: + project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_url": project_url}) + except Exception: + new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"}) + trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() @@ -73,8 +87,9 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["other_keys"], ) - default_config_instance = config_class(**tracing_config) - for key in other_keys: + # FIXME: ignore type error + default_config_instance = config_class(**tracing_config) # type: ignore + for key in other_keys: # type: ignore if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) @@ -86,13 +101,13 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c if tracing_provider == "langfuse": project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key) - elif tracing_provider == "langsmith": + elif tracing_provider in ("langsmith", "opik"): project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) else: project_url = None # check if trace config already exists - trace_config_data: TraceAppConfig = ( + trace_config_data: Optional[TraceAppConfig] = ( db.session.query(TraceAppConfig) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .first() @@ -102,7 +117,10 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None # get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) if project_url: tracing_config["project_url"] = project_url @@ -139,7 +157,10 @@ def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c return None 
# get tenant id - tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id + tenant = db.session.query(App).filter(App.id == app_id).first() + if not tenant: + return None + tenant_id = tenant.tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config( tenant_id, tracing_provider, tracing_config, current_trace_config.tracing_config ) diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index 4704d533a..523aebeed 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -41,7 +41,7 @@ def _get_builtin_data(cls) -> dict: Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") ) - return cls.builtin_data + return cls.builtin_data or {} @classmethod def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: @@ -50,8 +50,8 @@ def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: :param language: language :return: """ - builtin_data = cls._get_builtin_data() - return builtin_data.get("recommended_apps", {}).get(language) + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + return builtin_data.get("recommended_apps", {}).get(language, {}) @classmethod def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: @@ -60,5 +60,5 @@ def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict :param app_id: App ID :return: """ - builtin_data = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() return builtin_data.get("app_details", {}).get(app_id) diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index b0607a213..80e1aefc0 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -47,8 +47,8 @@ def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optiona response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None - - return response.json() + data: dict = response.json() + return data @classmethod def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -63,7 +63,7 @@ def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") - result = response.json() + result: dict = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 4660316fc..54c584551 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -33,5 +33,5 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() - result = retrieval_instance.get_recommend_app_detail(app_id) + result: dict = retrieval_instance.get_recommend_app_detail(app_id) return result diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 9fe3cecce..4cb870011 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -13,6 +13,8 @@ class SavedMessageService: def pagination_by_last_id( 
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") saved_messages = ( db.session.query(SavedMessage) .filter( @@ -31,6 +33,8 @@ def pagination_by_last_id( @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( @@ -59,6 +63,8 @@ def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_i @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): + if not user: + return saved_message = ( db.session.query(SavedMessage) .filter( diff --git a/api/services/tag_service.py b/api/services/tag_service.py index a374bdcf0..960060163 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,7 +1,7 @@ import uuid from typing import Optional -from flask_login import current_user +from flask_login import current_user # type: ignore from sqlalchemy import func from werkzeug.exceptions import NotFound @@ -21,7 +21,7 @@ def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = Non if keyword: query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) query = query.group_by(Tag.id) - results = query.order_by(Tag.created_at.desc()).all() + results: list = query.order_by(Tag.created_at.desc()).all() return results @staticmethod diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 78a80f70a..988f9df92 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -1,6 +1,7 @@ import json import logging -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional, cast from httpx import get @@ -28,12 +29,12 @@ class ApiToolManageService: @staticmethod - def parser_api_schema(schema: str) -> list[ApiToolBundle]: + def parser_api_schema(schema: str) -> Mapping[str, Any]: """ parse api schema to tool bundle """ try: - warnings = {} + warnings: dict[str, str] = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: @@ -68,13 +69,16 @@ def parser_api_schema(schema: str) -> list[ApiToolBundle]: ), ] - return jsonable_encoder( - { - "schema_type": schema_type, - "parameters_schema": tool_bundles, - "credentials_schema": credentials_schema, - "warning": warnings, - } + return cast( + Mapping, + jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ), ) except Exception as e: raise ValueError(f"invalid schema: {str(e)}") @@ -129,7 +133,7 @@ def create_api_tool_provider( raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ -262,9 +266,8 @@ def update_api_tool_provider( if provider is None: raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info = {} + extra_info: dict[str, str] = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) @@ 
-416,13 +419,13 @@ def test_api_tool_preview( provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime( + runtime_tool = tool.fork_tool_runtime( runtime={ "credentials": credentials, "tenant_id": tenant_id, } ) - result = tool.validate_credentials(credentials, parameters) + result = runtime_tool.validate_credentials(credentials, parameters) except Exception as e: return {"error": str(e)} @@ -454,7 +457,7 @@ def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) - for tool in tools: + for tool in tools or []: user_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index fada881fd..21adbb007 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -50,8 +50,8 @@ def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str credentials = builtin_provider.credentials credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) - result = [] - for tool in tools: + result: list[UserTool] = [] + for tool in tools or []: result.append( ToolTransformService.tool_to_user_tool( tool=tool, @@ -217,6 +217,8 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: name_func=lambda x: x.identity.name, ): continue + if provider_controller.identity is None: + continue # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( @@ -229,7 +231,7 @@ def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: ToolTransformService.repack_provider(user_builtin_provider) tools = provider_controller.get_tools() - for tool in tools: + for tool in tools or []: user_builtin_provider.tools.append( ToolTransformService.tool_to_user_tool( tenant_id=tenant_id, diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index a4aa870dc..6e3a45be0 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, Union +from typing import Optional, Union, cast from configs import dify_config from core.tools.entities.api_entities import UserTool, UserToolProvider @@ -35,7 +35,7 @@ def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: - return json.loads(icon) + return cast(dict, json.loads(icon)) except: return {"background": "#252525", "content": "\ud83d\ude01"} @@ -53,8 +53,11 @@ def repack_provider(provider: Union[dict, UserToolProvider]): provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): - provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon + provider.icon = cast( + str, + ToolTransformService.get_tool_provider_icon_url( + provider_type=provider.type.value, provider_name=provider.name, 
icon=provider.icon + ), ) @staticmethod @@ -66,6 +69,9 @@ def builtin_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + result = UserToolProvider( id=provider_controller.identity.name, author=provider_controller.identity.author, @@ -93,7 +99,8 @@ def builtin_provider_to_user_provider( # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) + assert result.masked_credentials is not None, "masked credentials is None" + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(str(value.type)) # check if the provider need credentials if not provider_controller.need_credentials: @@ -149,6 +156,9 @@ def workflow_provider_to_user_provider( """ convert provider controller to user provider """ + if provider_controller.identity is None: + raise ValueError("provider identity is None") + return UserToolProvider( id=provider_controller.provider_id, author=provider_controller.identity.author, @@ -180,6 +190,8 @@ def api_provider_to_user_provider( convert provider controller to user provider """ username = "Anonymous" + if db_provider.user is None: + raise ValueError(f"user is None for api provider {db_provider.id}") try: username = db_provider.user.name except Exception as e: @@ -256,19 +268,22 @@ def tool_to_user_tool( if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: current_parameters.append(runtime_parameter) + if tool.identity is None: + raise ValueError("tool identity is None") + return UserTool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, - description=tool.description.human, + description=tool.description.human if tool.description else "", # type: ignore parameters=current_parameters, labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, - name=tool.operation_id, - label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + name=tool.operation_id or "", + label=I18nObject(en_US=tool.operation_id or "", zh_Hans=tool.operation_id or ""), description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, labels=labels, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 318107beb..69430de43 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -6,8 +6,10 @@ from sqlalchemy import or_ from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.entities.api_entities import UserToolProvider +from core.tools.entities.api_entities import UserTool, UserToolProvider +from core.tools.provider.tool_provider import ToolProviderController from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController +from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from extensions.ext_database import db @@ -32,7 +34,7 @@ def create_workflow_tool( label: str, icon: dict, description: str, - parameters: Mapping[str, Any], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -97,7 +99,7 @@ def 
update_workflow_tool( label: str, icon: dict, description: str, - parameters: list[dict], + parameters: list[Mapping[str, Any]], privacy_policy: str = "", labels: Optional[list[str]] = None, ) -> dict: @@ -131,7 +133,7 @@ def update_workflow_tool( if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider = ( + workflow_tool_provider: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -140,14 +142,14 @@ def update_workflow_tool( if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App = ( + app: Optional[App] = ( db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() ) if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") - workflow: Workflow = app.workflow + workflow: Optional[Workflow] = app.workflow if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") @@ -193,7 +195,7 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo # skip deleted tools pass - labels = ToolLabelManager.get_tools_labels(tools) + labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) result = [] @@ -202,10 +204,11 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + continue user_tool_provider.tools = [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) - ) + ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=labels.get(tool.provider_id, [])) ] result.append(user_tool_provider) @@ -236,7 +239,7 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -245,13 +248,19 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too if db_tool is None: raise ValueError(f"Tool {workflow_tool_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") + return { "name": db_tool.name, "label": db_tool.label, @@ -261,9 +270,9 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": 
ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @@ -276,7 +285,7 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) .first() @@ -285,12 +294,17 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ if db_tool is None: raise ValueError(f"Tool {workflow_app_id} not found") - workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + workflow_app: Optional[App] = ( + db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_app_id} not found") return { "name": db_tool.name, @@ -301,14 +315,14 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ "description": db_tool.description, "parameters": jsonable_encoder(db_tool.parameter_configurations), "tool": ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) + to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool) ), - "synced": workflow_app.workflow.version == db_tool.version, + "synced": workflow_app.workflow.version == db_tool.version if workflow_app.workflow else False, "privacy_policy": db_tool.privacy_policy, } @classmethod - def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: + def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]: """ List workflow tool provider tools. 
:param user_id: the user id @@ -316,7 +330,7 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = ( + db_tool: Optional[WorkflowToolProvider] = ( db.session.query(WorkflowToolProvider) .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) .first() @@ -326,9 +340,8 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) + to_user_tool: Optional[list[Tool]] = tool.get_tools(user_id, tenant_id) + if to_user_tool is None or len(to_user_tool) == 0: + raise ValueError(f"Tool {workflow_tool_id} not found") - return [ - ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) - ) - ] + return [ToolTransformService.tool_to_user_tool(to_user_tool[0], labels=ToolLabelManager.get_tool_labels(tool))] diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 3c6735133..92422bf29 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,40 +1,70 @@ from typing import Optional +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document -from models.dataset import Dataset, DocumentSegment +from extensions.ext_database import db +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.entities.knowledge_entities.knowledge_entities import ParentMode class VectorService: @classmethod def create_segments_vector( - cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - documents.append(document) - if dataset.indexing_technique == "high_quality": - # save vector index - vector = Vector(dataset=dataset) - vector.add_texts(documents, duplicate_check=True) - # save keyword index - keyword = Keyword(dataset) + for segment in segments: + if doc_form == IndexType.PARENT_CHILD_INDEX: + document = DatasetDocument.query.filter_by(id=segment.document_id).first() + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + if not processing_rule: + raise ValueError("No processing rule found.") + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() - if keywords_list and len(keywords_list) > 0: - keyword.add_texts(documents, keywords_list=keywords_list) - else: - 
keyword.add_texts(documents) + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) + else: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + documents.append(document) + if len(documents) > 0: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) @classmethod def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): @@ -65,3 +95,123 @@ def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentS keyword.add_texts([document], keywords_list=[keywords]) else: keyword.add_texts([document]) + + @classmethod + def generate_child_chunks( + cls, + segment: DocumentSegment, + dataset_document: DatasetDocument, + dataset: Dataset, + embedding_model_instance: ModelInstance, + processing_rule: DatasetProcessRule, + regenerate: bool = False, + ): + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + if regenerate: + # delete child chunks + index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) + + # generate child chunks + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + # use full doc mode to generate segment's child chunk + processing_rule_dict = processing_rule.to_dict() + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + documents = index_processor.transform( + [document], + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule_dict, + tenant_id=dataset.tenant_id, + doc_language=dataset_document.doc_language, + ) + # save child chunks + if documents and documents[0].children: + index_processor.load(dataset, documents) + + for position, child_chunk in enumerate(documents[0].children, start=1): + child_segment = ChildChunk( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=dataset_document.id, + segment_id=segment.id, + position=position, + index_node_id=child_chunk.metadata["doc_id"], + index_node_hash=child_chunk.metadata["doc_hash"], + content=child_chunk.page_content, + word_count=len(child_chunk.page_content), + type="automatic", + created_by=dataset_document.created_by, + ) + db.session.add(child_segment) + db.session.commit() + + @classmethod + def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset): + child_document = Document( + page_content=child_segment.content, + metadata={ + "doc_id": child_segment.index_node_id, + "doc_hash": child_segment.index_node_hash, + "document_id": child_segment.document_id, + "dataset_id": child_segment.dataset_id, + }, + ) + if 
dataset.indexing_technique == "high_quality": + # save vector index + vector = Vector(dataset=dataset) + vector.add_texts([child_document], duplicate_check=True) + + @classmethod + def update_child_chunk_vector( + cls, + new_child_chunks: list[ChildChunk], + update_child_chunks: list[ChildChunk], + delete_child_chunks: list[ChildChunk], + dataset: Dataset, + ): + documents = [] + delete_node_ids = [] + for new_child_chunk in new_child_chunks: + new_child_document = Document( + page_content=new_child_chunk.content, + metadata={ + "doc_id": new_child_chunk.index_node_id, + "doc_hash": new_child_chunk.index_node_hash, + "document_id": new_child_chunk.document_id, + "dataset_id": new_child_chunk.dataset_id, + }, + ) + documents.append(new_child_document) + for update_child_chunk in update_child_chunks: + child_document = Document( + page_content=update_child_chunk.content, + metadata={ + "doc_id": update_child_chunk.index_node_id, + "doc_hash": update_child_chunk.index_node_hash, + "document_id": update_child_chunk.document_id, + "dataset_id": update_child_chunk.dataset_id, + }, + ) + documents.append(child_document) + delete_node_ids.append(update_child_chunk.index_node_id) + for delete_child_chunk in delete_child_chunks: + delete_node_ids.append(delete_child_chunk.index_node_id) + if dataset.indexing_technique == "high_quality": + # update vector index + vector = Vector(dataset=dataset) + if delete_node_ids: + vector.delete_by_ids(delete_node_ids) + if documents: + vector.add_texts(documents, duplicate_check=True) + + @classmethod + def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): + vector = Vector(dataset=dataset) + vector.delete_by_ids([child_chunk.index_node_id]) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 508fe2097..f698ed308 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -26,6 +26,8 @@ def pagination_by_last_id( pinned: Optional[bool] = None, sort_by="-updated_at", ) -> InfiniteScrollPagination: + if not user: + raise ValueError("User is required") include_ids = None exclude_ids = None if pinned is not None and user: @@ -59,6 +61,8 @@ def pagination_by_last_id( @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( @@ -89,6 +93,8 @@ def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): + if not user: + return pinned_conversation = ( db.session.query(PinnedConversation) .filter( diff --git a/api/services/website_service.py b/api/services/website_service.py index 230f5d781..1ad7d0399 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,8 +1,9 @@ import datetime import json +from typing import Any import requests -from flask_login import current_user +from flask_login import current_user # type: ignore from core.helper import encrypter from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp @@ -23,9 +24,9 @@ def document_create_args_validate(cls, args: dict): @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get("provider") + provider = args.get("provider", "") url = args.get("url") - options = args.get("options") + options = args.get("options", "") credentials = 
ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key @@ -164,16 +165,18 @@ def get_crawl_status(cls, job_id: str, provider: str) -> dict: return crawl_status_data @classmethod - def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) # decrypt api_key api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + # FIXME data is redefine too many times here, use Any to ease the type checking, fix it later + data: Any if provider == "firecrawl": file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) + d = storage.load_once(file_key) + if d: + data = json.loads(d.decode("utf-8")) else: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) @@ -183,22 +186,17 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str if data: for item in data: if item.get("source_url") == url: - return item + return dict(item) return None elif provider == "jinareader": - file_key = "website_files/" + job_id + ".txt" - if storage.exists(file_key): - data = storage.load_once(file_key) - if data: - data = json.loads(data.decode("utf-8")) - elif not job_id: + if not job_id: response = requests.get( f"https://r.jina.ai/{url}", headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) if response.json().get("code") != 200: raise ValueError("Failed to crawl") - return response.json().get("data") + return dict(response.json().get("data", {})) else: api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) response = requests.post( @@ -218,12 +216,13 @@ def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str data = response.json().get("data", {}) for item in data.get("processed", {}).values(): if item.get("data", {}).get("url") == url: - return item.get("data", {}) + return dict(item.get("data", {})) + return None else: raise ValueError("Invalid provider") @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) if provider == "firecrawl": # decrypt api_key diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 90b5cc483..2b0d57bdf 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,5 @@ import json -from typing import Optional +from typing import Any, Optional from core.app.app_config.entities import ( DatasetEntity, @@ -101,7 +101,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = {"nodes": [], "edges": []} + graph: dict[str, Any] = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ 
-118,7 +118,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: graph["nodes"].append(start_node) # convert to http request node - external_data_variable_node_mapping = {} + external_data_variable_node_mapping: dict[str, str] = {} if app_config.external_data_variables: http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, @@ -199,15 +199,16 @@ def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: return workflow def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: - app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: + app_mode_enum = AppMode.value_of(app_model.mode) + app_config: EasyUIBasedAppConfig + if app_mode_enum == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) - elif app_mode == AppMode.CHAT: + elif app_mode_enum == AppMode.CHAT: app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) - elif app_mode == AppMode.COMPLETION: + elif app_mode_enum == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( app_model=app_model, app_model_config=app_model_config ) @@ -302,7 +303,7 @@ def _convert_to_http_request_node( nodes.append(http_request_node) # append code node for response body parsing - code_node = { + code_node: dict[str, Any] = { "id": f"code_{index}", "position": None, "data": { @@ -401,6 +402,7 @@ def _convert_to_llm_node( ) role_prefix = None + prompts: Any = None # Chat Model if model_config.mode == LLMMode.CHAT.value: diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index f89487415..7eab0ac1d 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -27,7 +27,7 @@ def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Paginati query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) if keyword: - keyword_like_val = f"%{args['keyword'][:30]}%" + keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u") keyword_conditions = [ WorkflowRun.inputs.ilike(keyword_like_val), WorkflowRun.outputs.ilike(keyword_like_val), diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index d8ee32390..4343596a2 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -92,7 +94,7 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: + def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 84768d5af..9f7a9c770 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,10 @@ import time from collections.abc import Sequence from datetime import UTC, datetime -from 
typing import Optional, cast +from typing import Any, Optional, cast +from uuid import uuid4 + +from sqlalchemy import desc from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -75,6 +78,28 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]: return workflow + def get_all_published_workflow(self, app_model: App, page: int, limit: int) -> tuple[list[Workflow], bool]: + """ + Get published workflow with pagination + """ + if not app_model.workflow_id: + return [], False + + workflows = ( + db.session.query(Workflow) + .filter(Workflow.app_id == app_model.id) + .order_by(desc(Workflow.version)) + .offset((page - 1) * limit) + .limit(limit + 1) + .all() + ) + + has_more = len(workflows) > limit + if has_more: + workflows = workflows[:-1] + + return workflows, has_more + def sync_draft_workflow( self, *, @@ -242,7 +267,7 @@ def run_draft_workflow_node( raise ValueError("Node run failed with no run result") # single step debug mode error handling return if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: - node_error_args = { + node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, @@ -277,6 +302,7 @@ def run_draft_workflow_node( error = e.error workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.id = str(uuid4()) workflow_node_execution.tenant_id = app_model.tenant_id workflow_node_execution.app_id = app_model.id workflow_node_execution.workflow_id = draft_workflow.id @@ -338,7 +364,7 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow - new_app = workflow_converter.convert_to_workflow( + new_app: App = workflow_converter.convert_to_workflow( app_model=app_model, account=account, name=args.get("name", "Default Name"), diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index a49628cf6..f20a4640b 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,6 +1,6 @@ import logging -from flask_login import current_user +from flask_login import current_user # type: ignore from configs import dify_config from extensions.ext_database import db @@ -32,6 +32,7 @@ def get_tenant_info(cls, tenant: Tenant): .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) .first() ) + assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role # ----------------------- 二开部分Start 添加用户权限 - ---------------------- diff --git a/api/tasks/__init__.py b/api/tasks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 09be66121..bd7fcdade 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -3,15 +3,16 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from 
core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment @shared_task(queue="dataset") @@ -37,7 +38,11 @@ def add_document_to_index_task(dataset_document_id: str): try: segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.enabled == False, + DocumentSegment.status == "completed", + ) .order_by(DocumentSegment.position.asc()) .all() ) @@ -53,7 +58,22 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) dataset = dataset_document.dataset @@ -65,6 +85,22 @@ def add_document_to_index_task(dataset_document_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) + # delete auto disable log + db.session.query(DatasetAutoDisableLog).filter( + DatasetAutoDisableLog.document_id == dataset_document.id + ).delete() + + # update segment to enable + db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).update( + { + DocumentSegment.enabled: True, + DocumentSegment.disabled_at: None, + DocumentSegment.disabled_by: None, + DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) + db.session.commit() + end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 25c55bcfa..aab21a441 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index fa7e5ac91..06162b02d 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index f0f6b32b0..a6a598ce4 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from 
celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from models.dataset import Dataset diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index a2f491351..26bf1c7c9 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 0bdcd0ecc..b42af0c7f 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.datasource.vdb.vector_factory import Vector diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index b685d84d0..8c675feaa 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py new file mode 100644 index 000000000..3bae82a5e --- /dev/null +++ b/api/tasks/batch_clean_document_task.py @@ -0,0 +1,76 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.tools.utils.web_reader_tool import get_image_upload_file_ids +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.dataset import Dataset, DocumentSegment +from models.model import UploadFile + + +@shared_task(queue="dataset") +def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): + """ + Clean document when document deleted. 
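+    Removes the given documents' segments and their index entries, deletes any image files referenced in segment content, and then deletes the uploaded source files.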
+ :param document_ids: document ids + :param dataset_id: dataset id + :param doc_form: doc_form + :param file_ids: file ids + + Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids) + """ + logging.info(click.style("Start batch clean documents when documents deleted", fg="green")) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + if not dataset: + raise Exception("Document has no dataset") + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() + # check if segments exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + for upload_file_id in image_upload_file_ids: + image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + try: + if image_file and image_file.key: + storage.delete(image_file.key) + except Exception: + logging.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_id: {}".format(upload_file_id) + ) + db.session.delete(image_file) + db.session.delete(segment) + + db.session.commit() + if file_ids: + files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() + for file in files: + try: + storage.delete(file.key) + except Exception: + logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id)) + db.session.delete(file) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Cleaned documents when documents deleted latency: {}".format(end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("Cleaned documents when documents deleted failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index dcb7009e4..dbef6b708 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -4,16 +4,16 @@ import uuid import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import func -from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper from models.dataset import Dataset, Document, DocumentSegment +from services.vector_service import VectorService @shared_task(queue="dataset") @@ -58,12 +58,13 @@ def batch_create_segment_to_index_task( model=dataset.embedding_model, ) word_count_change = 0 + segments_to_insert: list[str] = [] # Explicitly type hint the list as List[str] for segment in content: - content = segment["content"] + content_str = segment["content"] doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) + segment_hash = helper.generate_text_hash(content_str) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0 max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id 
== dataset_document.id) @@ -76,8 +77,8 @@ def batch_create_segment_to_index_task( index_node_id=doc_id, index_node_hash=segment_hash, position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), + content=content_str, + word_count=len(content_str), tokens=tokens, created_by=user_id, indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), @@ -90,12 +91,12 @@ def batch_create_segment_to_index_task( word_count_change += segment_document.word_count db.session.add(segment_document) document_segments.append(segment_document) + segments_to_insert.append(str(segment)) # Cast to string if needed # update document word count dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db - indexing_runner = IndexingRunner() - indexing_runner.batch_add_segments(document_segments, dataset) + VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) db.session.commit() redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index a555fb287..dfc7a896f 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -62,7 +62,7 @@ def clean_dataset_task( if doc_form is None: raise ValueError("Index type must be specified.") index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None) + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) for document in documents: db.session.delete(document) @@ -71,6 +71,8 @@ def clean_dataset_task( image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 4d328643b..7a536f742 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -3,7 +3,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.tools.utils.web_reader_tool import get_image_upload_file_ids @@ -38,12 +38,14 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) for upload_file_id in image_upload_file_ids: image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + if image_file is None: + continue try: storage.delete(image_file.key) except Exception: diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 
75d9e0313..5a6eb00a6 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 315b01f15..dfa053a43 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -4,7 +4,7 @@ from typing import Optional import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index cfc54920e..a9b5ab91a 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -2,10 +2,11 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -27,7 +28,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form + index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) @@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): db.session.commit() # clean index - index_processor.clean(dataset, None, with_keywords=False) + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) for dataset_document in dataset_documents: # update from vector index @@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) @@ -141,6 +157,9 
@@ def deal_dataset_vector_index_task(dataset_id: str, action: str): {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() + else: + # clean collection + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py new file mode 100644 index 000000000..52c884ca2 --- /dev/null +++ b/api/tasks/delete_account_task.py @@ -0,0 +1,26 @@ +import logging + +from celery import shared_task # type: ignore + +from extensions.ext_database import db +from models.account import Account +from services.billing_service import BillingService +from tasks.mail_account_deletion_task import send_deletion_success_task + +logger = logging.getLogger(__name__) + + +@shared_task(queue="dataset") +def delete_account_task(account_id): + account = db.session.query(Account).filter(Account.id == account_id).first() + try: + BillingService.delete_account(account_id) + except Exception as e: + logger.exception(f"Failed to delete account {account_id} from billing service.") + raise + + if not account: + logger.error(f"Account {account_id} not found.") + return + # send success email + send_deletion_success_task.delay(account.email) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index c3e0ea5d9..3b04143dd 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -2,52 +2,42 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from extensions.ext_redis import redis_client from models.dataset import Dataset, Document @shared_task(queue="dataset") -def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): +def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): """ Async Remove segment from index - :param segment_id: - :param index_node_id: + :param index_node_ids: :param dataset_id: :param document_id: - Usage: delete_segment_from_index_task.delay(segment_id) + Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id) """ - logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) + logging.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() - indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, [index_node_id]) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
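+        # Note: with_keywords=True and delete_child_chunks=True are assumed to also purge the keyword-index entries and any parent-child chunks for these index nodes, so the whole batch of segments is removed in a single call.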
end_at = time.perf_counter() - logging.info( - click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") - ) + logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")) except Exception: logging.exception("delete segment from index failed") - finally: - redis_client.delete(indexing_cache_key) diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 15e1e5007..f30a1cc7a 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py new file mode 100644 index 000000000..67112666e --- /dev/null +++ b/api/tasks/disable_segments_from_index_task.py @@ -0,0 +1,76 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async disable segments from index + :param segment_ids: + + Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id) + """ + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + + if not segments: + return + + try: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + + end_at = time.perf_counter() + logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) + except Exception: + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "disabled_at": None, + "disabled_by": None, + "enabled": True, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = 
"segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 183169139..d686698b9 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 734dd2478..21b571b6c 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 1a52a6636..d8f14830c 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index f4c3dbd2e..8e1d2b6b5 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -26,6 +26,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") # check document limit features = FeatureService.get_features(dataset.tenant_id) @@ -49,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() return @@ -71,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - 
index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 12639db93..76522f472 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -3,11 +3,12 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents # save vector index index_processor.load(dataset, [document]) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py new file mode 100644 index 000000000..0864e05e2 --- /dev/null +++ b/api/tasks/enable_segments_to_index_task.py @@ -0,0 +1,108 @@ +import datetime +import logging +import time + +import click +from celery import shared_task # type: ignore + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async enable segments to index + :param segment_ids: + + Usage: enable_segments_to_index_task.delay(segment_ids) + """ + start_at = time.perf_counter() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or 
dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + if not segments: + return + + try: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents) + + end_at = time.perf_counter() + logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) + except Exception as e: + logging.exception("enable segments to index failed") + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "error": str(e), + "status": "error", + "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + "enabled": False, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py new file mode 100644 index 000000000..49a3a6d28 --- /dev/null +++ b/api/tasks/mail_account_deletion_task.py @@ -0,0 +1,70 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_deletion_success_task(to): + """Send email to user regarding account deletion. + + Args: + to (str): Recipient email address + """ + if not mail.is_inited(): + return + + logging.info(click.style(f"Start send account deletion success email to {to}", fg="green")) + start_at = time.perf_counter() + + try: + html_content = render_template( + "delete_account_success_template_en-US.html", + to=to, + email=to, + ) + mail.send(to=to, subject="Your Dify.AI Account Has Been Successfully Deleted", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send account deletion success email to {}: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send account deletion success email to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_account_deletion_verification_code(to, code): + """Send email to user regarding account deletion verification code. 
+ + Args: + to (str): Recipient email address + code (str): Verification code + """ + if not mail.is_inited(): + return + + logging.info(click.style(f"Start send account deletion verification code email to {to}", fg="green")) + start_at = time.perf_counter() + + try: + html_content = render_template("delete_account_code_email_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Dify.AI Account Deletion and Verification", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send account deletion verification code email to {} succeeded: latency: {}".format( + to, end_at - start_at + ), + fg="green", + ) + ) + except Exception: + logging.exception("Send account deletion verification code email to {} failed".format(to)) diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index d78fc2b89..5dc935548 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index c7dfb9bf6..3094527fd 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from configs import dify_config diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 8596ca07c..d5be94431 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 34c62dc92..bb3b9e17e 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,7 +1,7 @@ import json import logging -from celery import shared_task +from celery import shared_task # type: ignore from flask import current_app from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 934eb7430..b603d689b 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -2,7 +2,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 66f78636e..c3910e2be 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -3,7 +3,7 @@ from collections.abc import Callable import click -from celery import shared_task +from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 1909eaf34..d0c4382f5 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -1,8 
+1,9 @@ +import datetime import logging import time import click -from celery import shared_task +from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -43,9 +44,19 @@ def remove_document_from_index_task(document_id: str): index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) except Exception: logging.exception(f"clean dataset {dataset.id} from index failed") + # update segment to disable + db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).update( + { + DocumentSegment.enabled: False, + DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + DocumentSegment.disabled_by: document.disabled_by, + DocumentSegment.updated_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) + db.session.commit() end_at = time.perf_counter() logging.info( diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 73471fd6e..74fd542f6 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -22,10 +22,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): Usage: retry_document_indexing_task.delay(dataset_id, document_id) """ - documents = [] + documents: list[Document] = [] start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("Dataset not found") + for document_id in document_ids: retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit @@ -45,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(retry_indexing_cache_key) @@ -55,33 +58,35 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): document = ( db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() ) + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index 
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) - db.session.commit() + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.add(document) + db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(retry_indexing_cache_key) + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(retry_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 1d2a338c8..8da050d0d 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -3,7 +3,7 @@ import time import click -from celery import shared_task +from celery import shared_task # type: ignore from core.indexing_runner import IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -25,6 +25,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if dataset is None: + raise ValueError("Dataset not found") sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit @@ -44,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(sync_indexing_cache_key) @@ -52,33 +54,35 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + if not document: + logging.info(click.style("Document not found: {}".format(document_id), fg="yellow")) + return try: - if document: - # clean old data - index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() + # clean old data + index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() - segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - # delete from vector index - index_processor.clean(dataset, index_node_ids) + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + # delete from vector index + 
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() - document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() - db.session.add(document) - db.session.commit() + document.indexing_status = "parsing" + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + db.session.add(document) + db.session.commit() - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - redis_client.delete(sync_indexing_cache_key) + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + redis_client.delete(sync_indexing_cache_key) except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html new file mode 100644 index 000000000..88e78f41c --- /dev/null +++ b/api/templates/clean_document_job_mail_template-US.html @@ -0,0 +1,100 @@ + + + + + + Documents Disabled Notification + + + + + + \ No newline at end of file diff --git a/api/templates/delete_account_code_email_template_en-US.html b/api/templates/delete_account_code_email_template_en-US.html new file mode 100644 index 000000000..770738533 --- /dev/null +++ b/api/templates/delete_account_code_email_template_en-US.html @@ -0,0 +1,125 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Dify.AI Account Deletion and Verification

+

We received a request to delete your Dify account. To ensure the security of your account and + confirm this action, please use the verification code below:

+
+ {{code}} +
+
+

To complete the account deletion process:

+

1. Return to the account deletion page on our website

+

2. Enter the verification code above

+

3. Click "Confirm Deletion"

+
+

Please note:

+
    +
  • This code is valid for 5 minutes
  • As the Owner of any Workspaces, your workspaces will be scheduled in a queue for permanent deletion.
  • All your user data will be queued for permanent deletion.
+
+ + + \ No newline at end of file diff --git a/api/templates/delete_account_success_template_en-US.html b/api/templates/delete_account_success_template_en-US.html new file mode 100644 index 000000000..c5df75cab --- /dev/null +++ b/api/templates/delete_account_success_template_en-US.html @@ -0,0 +1,105 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Your Dify.AI Account Has Been Successfully Deleted

+

We're writing to confirm that your Dify.AI account has been successfully deleted as per your request. Your + account is no longer accessible, and you can't log in using your previous credentials. If you decide to use + Dify.AI services in the future, you'll need to create a new account after 30 days. We appreciate the time you + spent with Dify.AI and are sorry to see you go. If you have any questions or concerns about the deletion process, + please don't hesitate to reach out to our support team.

+

Thank you for being a part of the Dify.AI community.

+

Best regards,

+

Dify.AI Team

+
+ + + \ No newline at end of file diff --git a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py index 64f2884c4..57fba3176 100644 --- a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py +++ b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py @@ -1,6 +1,6 @@ from typing import Any -import toml +import toml # type: ignore def load_api_poetry_configs() -> dict[str, Any]: @@ -38,7 +38,7 @@ def test_group_dependencies_version_operator(): ) -def test_duplicated_dependency_crossing_groups(): +def test_duplicated_dependency_crossing_groups() -> None: all_dependency_names: list[str] = [] for dependencies in load_all_dependency_groups().values(): dependency_names = list(dependencies.keys()) diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py index 637169469..276ad3a7e 100644 --- a/api/tests/integration_tests/controllers/test_controllers.py +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -1,10 +1,9 @@ from unittest.mock import patch -from app_fixture import app, mock_user +from app_fixture import mock_user # type: ignore def test_post_requires_login(app): - with app.test_client() as client: - with patch("flask_login.utils._get_user", mock_user): - response = client.get("/console/api/data-source/integrates") - assert response.status_code == 200 + with app.test_client() as client, patch("flask_login.utils._get_user", mock_user): + response = client.get("/console/api/data-source/integrates") + assert response.status_code == 200 diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index 5ea86baa8..3a26b99e3 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,7 +1,6 @@ -from collections.abc import Generator from unittest.mock import MagicMock -import google.generativeai.types.generation_types as generation_config_types +import google.generativeai.types.generation_types as generation_config_types # type: ignore import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm @@ -45,7 +44,7 @@ def generate_content_sync() -> GenerateContentResponse: return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]) @staticmethod - def generate_content_stream() -> Generator[GenerateContentResponse, None, None]: + def generate_content_stream() -> MockGoogleResponseClass: return MockGoogleResponseClass() def generate_content( diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index 97038ef59..4de525144 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -2,7 +2,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient +from huggingface_hub import InferenceClient # type: ignore from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 9ee76c935..77c7e7f5e 100644 --- 
a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -3,15 +3,15 @@ from typing import Any, Literal, Optional, Union from _pytest.monkeypatch import MonkeyPatch -from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import ( +from huggingface_hub import InferenceClient # type: ignore +from huggingface_hub.inference._text_generation import ( # type: ignore Details, StreamDetails, TextGenerationResponse, TextGenerationStreamResponse, Token, ) -from huggingface_hub.utils import BadRequestError +from huggingface_hub.utils import BadRequestError # type: ignore class MockHuggingfaceChatClass: diff --git a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py index 6a25398cb..4e00660a2 100644 --- a/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/nomic_embeddings.py @@ -6,7 +6,7 @@ # import monkeypatch from _pytest.monkeypatch import MonkeyPatch -from nomic import embed +from nomic import embed # type: ignore def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 794f4b058..e2abaa52b 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -6,14 +6,14 @@ from _pytest.monkeypatch import MonkeyPatch from requests import Response from requests.sessions import Session -from xinference_client.client.restful.restful_client import ( +from xinference_client.client.restful.restful_client import ( # type: ignore Client, RESTfulChatModelHandle, RESTfulEmbeddingModelHandle, RESTfulGenerateModelHandle, RESTfulRerankModelHandle, ) -from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage +from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage # type: ignore class MockXinferenceClass: diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py b/api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py new file mode 100644 index 000000000..c215e9b73 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_speech2text.py @@ -0,0 +1,55 @@ +import os +from pathlib import Path + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.speech2text.speech2text import GPUStackSpeech2TextModel + + +def test_validate_credentials(): + model = GPUStackSpeech2TextModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="faster-whisper-medium", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="faster-whisper-medium", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GPUStackSpeech2TextModel() + + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Get assets directory + assets_dir = os.path.join(os.path.dirname(current_dir), "assets") + + # Construct the path to the audio file + audio_file_path = 
os.path.join(assets_dir, "audio.mp3") + + file = Path(audio_file_path).read_bytes() + + result = model.invoke( + model="faster-whisper-medium", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + file=file, + ) + + assert isinstance(result, str) + assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10" diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_tts.py b/api/tests/integration_tests/model_runtime/gpustack/test_tts.py new file mode 100644 index 000000000..8997ad074 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_tts.py @@ -0,0 +1,24 @@ +import os + +from core.model_runtime.model_providers.gpustack.tts.tts import GPUStackText2SpeechModel + + +def test_invoke_model(): + model = GPUStackText2SpeechModel() + + result = model.invoke( + model="cosyvoice-300m-sft", + tenant_id="test", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + content_text="Hello world", + voice="Chinese Female", + ) + + content = b"" + for chunk in result: + content += chunk + + assert content != b"" diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py index 2dcfb92c6..d37fcf897 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py @@ -1,6 +1,6 @@ import os -import dashscope +import dashscope # type: ignore import pytest from core.model_runtime.entities.rerank_entities import RerankResult diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index 83f4d70ce..2860739f0 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -1,5 +1,5 @@ from flask import Flask, request -from flask_restful import Api, Resource +from flask_restful import Api, Resource # type: ignore app = Flask(__name__) api = Api(app) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 0ea61369c..4af35a8be 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -4,11 +4,11 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from pymochow import MochowClient -from pymochow.model.database import Database -from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState -from pymochow.model.schema import HNSWParams, VectorIndex -from pymochow.model.table import Table +from pymochow import MochowClient # type: ignore +from pymochow.model.database import Database # type: ignore +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore +from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore +from pymochow.model.table import Table # type: ignore from requests.adapters import HTTPAdapter diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 61d6ed165..68a1e290a 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -4,12 +4,12 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from 
requests.adapters import HTTPAdapter -from tcvectordb import VectorDBClient -from tcvectordb.model.database import Collection, Database -from tcvectordb.model.document import Document, Filter -from tcvectordb.model.enum import ReadConsistency -from tcvectordb.model.index import Index -from xinference_client.types import Embedding +from tcvectordb import VectorDBClient # type: ignore +from tcvectordb.model.database import Collection, Database # type: ignore +from tcvectordb.model.document import Document, Filter # type: ignore +from tcvectordb.model.enum import ReadConsistency # type: ignore +from tcvectordb.model.index import Index # type: ignore +from xinference_client.types import Embedding # type: ignore class MockTcvectordbClass: diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/tests/integration_tests/vdb/__mock/vikingdb.py index 0f40337fe..3ad72e555 100644 --- a/api/tests/integration_tests/vdb/__mock/vikingdb.py +++ b/api/tests/integration_tests/vdb/__mock/vikingdb.py @@ -4,7 +4,7 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from volcengine.viking_db import ( +from volcengine.viking_db import ( # type: ignore Collection, Data, DistanceType, diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/tests/integration_tests/vdb/baidu/test_baidu.py index 5dc2ce4f8..25989958d 100644 --- a/api/tests/integration_tests/vdb/baidu/test_baidu.py +++ b/api/tests/integration_tests/vdb/baidu/test_baidu.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock - from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/tests/integration_tests/vdb/milvus/test_milvus.py index c99739a86..0e13f9369 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/tests/integration_tests/vdb/milvus/test_milvus.py @@ -19,9 +19,9 @@ def __init__(self): ) def search_by_full_text(self): - # milvus dos not support full text searching yet in < 2.3.x + # milvus support BM25 full text search after version 2.5.0-beta hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) - assert len(hits_by_full_text) == 0 + assert len(hits_by_full_text) >= 0 def get_ids_by_metadata_field(self): ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py index 4c83c66bf..df0bb3f81 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py +++ b/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock, patch - import pytest from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig diff --git a/api/tests/unit_tests/core/app/segments/test_variables.py b/api/tests/unit_tests/core/app/segments/test_variables.py index 0c264c15a..426557c71 100644 --- a/api/tests/unit_tests/core/app/segments/test_variables.py +++ b/api/tests/unit_tests/core/app/segments/test_variables.py @@ -2,6 +2,8 @@ from pydantic import ValidationError from core.variables import ( + ArrayFileVariable, + ArrayVariable, FloatVariable, IntegerVariable, ObjectVariable, @@ -81,3 +83,8 @@ def test_variable_to_object(): assert 
var.to_object() == 3.14 var = SecretVariable(name="secret", value="secret_value") assert var.to_object() == "secret_value" + + +def test_array_file_variable_is_array_variable(): + var = ArrayFileVariable(name="files", value=[]) + assert isinstance(var, ArrayVariable) diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index ee0f7672f..f6d22690d 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -4,7 +4,7 @@ from configs import dify_config from core.app.app_config.entities import ModelConfigEntity -from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig +from core.file import File, FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index f6555cfdd..c3a381865 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator -from datetime import UTC, datetime, timezone +from datetime import UTC, datetime from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 76db42ef1..7e979bcaa 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -21,8 +21,7 @@ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment -from core.workflow.entities.variable_entities import VariableSelector +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute diff --git a/api/tests/unit_tests/core/workflow/nodes/test_retry.py b/api/tests/unit_tests/core/workflow/nodes/test_retry.py index c232875ce..4ac79d7ac 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_retry.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_retry.py @@ -1,7 +1,6 @@ from core.workflow.graph_engine.entities.event import ( GraphRunFailedEvent, GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, NodeRunRetryEvent, ) from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py index 16c137001..1501722b8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -1,5 +1,3 @@ -import pytest - from core.variables import SegmentType from core.workflow.nodes.variable_assigner.v2.enums import Operation from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid diff --git a/api/tests/unit_tests/oss/__mock/aliyun_oss.py b/api/tests/unit_tests/oss/__mock/aliyun_oss.py index 27e1c0ad8..4f6d8a2f5 100644 --- a/api/tests/unit_tests/oss/__mock/aliyun_oss.py +++ b/api/tests/unit_tests/oss/__mock/aliyun_oss.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from oss2 import Bucket -from oss2.models import GetObjectResult, PutObjectResult +from oss2 import Bucket # type: ignore +from oss2.models import GetObjectResult, PutObjectResult # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/tencent_cos.py b/api/tests/unit_tests/oss/__mock/tencent_cos.py index 5189b68e8..c77c5b08f 100644 --- a/api/tests/unit_tests/oss/__mock/tencent_cos.py +++ b/api/tests/unit_tests/oss/__mock/tencent_cos.py @@ -3,8 +3,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from qcloud_cos import CosS3Client -from qcloud_cos.streambody import StreamBody +from qcloud_cos import CosS3Client # type: ignore +from qcloud_cos.streambody import StreamBody # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/__mock/volcengine_tos.py b/api/tests/unit_tests/oss/__mock/volcengine_tos.py index 649d93a20..88df59f91 100644 --- a/api/tests/unit_tests/oss/__mock/volcengine_tos.py +++ b/api/tests/unit_tests/oss/__mock/volcengine_tos.py @@ -4,8 +4,8 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from tos import TosClientV2 -from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput +from tos import TosClientV2 # type: ignore +from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore from tests.unit_tests.oss.__mock.base import ( get_example_bucket, diff --git a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py index 65d31352b..f87a38569 100644 --- a/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py +++ b/api/tests/unit_tests/oss/aliyun_oss/aliyun_oss/test_aliyun_oss.py @@ -1,7 +1,7 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest -from oss2 import Auth +from oss2 import Auth # type: ignore from extensions.storage.aliyun_oss_storage import AliyunOssStorage from tests.unit_tests.oss.__mock.aliyun_oss import setup_aliyun_oss_mock diff --git a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py index 303f0493b..d28975180 100644 --- a/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py +++ b/api/tests/unit_tests/oss/tencent_cos/test_tencent_cos.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from qcloud_cos import CosConfig +from qcloud_cos import CosConfig # type: ignore from extensions.storage.tencent_cos_storage import TencentCosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py index 5afbc9e8b..04988e85d 100644 --- 
a/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py +++ b/api/tests/unit_tests/oss/volcengine_tos/test_volcengine_tos.py @@ -1,5 +1,5 @@ import pytest -from tos import TosClientV2 +from tos import TosClientV2 # type: ignore from extensions.storage.volcengine_tos_storage import VolcengineTosStorage from tests.unit_tests.oss.__mock.base import ( diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py index f9d00d0b3..8bfc97ae6 100644 --- a/api/tests/unit_tests/utils/test_text_processing.py +++ b/api/tests/unit_tests/utils/test_text_processing.py @@ -1,5 +1,3 @@ -from textwrap import dedent - import pytest from core.tools.utils.text_processing_utils import remove_leading_symbols diff --git a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py index 95b93651d..8d6454872 100644 --- a/api/tests/unit_tests/utils/yaml/test_yaml_utils.py +++ b/api/tests/unit_tests/utils/yaml/test_yaml_utils.py @@ -1,7 +1,7 @@ from textwrap import dedent import pytest -from yaml import YAMLError +from yaml import YAMLError # type: ignore from core.tools.utils.yaml_utils import load_yaml_file diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 1cff58be7..6e4c8a748 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.14.2 + image: langgenius/dify-api:0.15.1 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.14.2 + image: langgenius/dify-api:0.15.1 restart: always environment: CONSOLE_WEB_URL: '' @@ -397,7 +397,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.14.2 + image: langgenius/dify-web:0.15.1 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index 43e67a8db..b21bdc708 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -105,6 +105,9 @@ FILES_ACCESS_TIMEOUT=300 # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 +# Refresh token expiration time in days +REFRESH_TOKEN_EXPIRE_DAYS=30 + # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_EXECUTION_TIME=1200 @@ -123,10 +126,13 @@ DIFY_PORT=5001 # The number of API server workers, i.e., the number of workers. # Formula: number of cpu cores x 2 + 1 for sync, 1 for Gevent # Reference: https://docs.gunicorn.org/en/stable/design.html#how-many-workers -SERVER_WORKER_AMOUNT= +SERVER_WORKER_AMOUNT=1 # Defaults to gevent. If using windows, it can be switched to sync or solo. -SERVER_WORKER_CLASS= +SERVER_WORKER_CLASS=gevent + +# Default number of worker connections, the default is 10. +SERVER_WORKER_CONNECTIONS=10 # Similar to SERVER_WORKER_CLASS. # If using windows, it can be switched to sync or solo. 
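Side note on the SERVER_WORKER_* settings in the hunk above: the env comments state the gunicorn sizing rule, CPU cores x 2 + 1 for the sync worker class and a single worker for gevent. A minimal Python sketch of that rule, not part of the diff and with a made-up helper name, for picking SERVER_WORKER_AMOUNT:

    # Sketch only: applies the worker-count rule quoted in docker/.env.example.
    import os

    def suggested_worker_amount(worker_class: str = "gevent") -> int:
        cores = os.cpu_count() or 1
        # gevent handles concurrency inside one worker; sync scales with cores.
        return 1 if worker_class == "gevent" else cores * 2 + 1

    print(suggested_worker_amount("sync"))    # 9 on a 4-core machine
    print(suggested_worker_amount("gevent"))  # 1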
@@ -315,7 +321,7 @@ AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net # Google Storage Configuration # GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name -GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string +GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64= # The Alibaba Cloud OSS configurations, # @@ -377,7 +383,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. VECTOR_STORE=weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. @@ -397,6 +403,7 @@ MILVUS_URI=http://127.0.0.1:19530 MILVUS_TOKEN= MILVUS_USER=root MILVUS_PASSWORD=Milvus +MILVUS_ENABLE_HYBRID_SEARCH=False # MyScale configuration, only available when VECTOR_STORE is `myscale` # For multi-language support, please set MYSCALE_FTS_PARAMS with referring to: @@ -923,3 +930,5 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false # Maximum number of submitted thread count in a ThreadPool for parallel node execution MAX_SUBMIT_COUNT=100 +# The maximum number of top-k value for RAG. +TOP_K_MAX_VALUE=10 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index d4e0ba49d..e2daead92 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.14.2 + image: langgenius/dify-api:0.15.1 restart: always environment: # Use the shared environment variables. @@ -25,7 +25,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.14.2 + image: langgenius/dify-api:0.15.1 restart: always environment: # Use the shared environment variables. @@ -47,7 +47,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.14.2 + image: langgenius/dify-web:0.15.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -56,6 +56,7 @@ services: NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} CSP_WHITELIST: ${CSP_WHITELIST:-} + TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} # The postgres database. db: @@ -75,7 +76,7 @@ services: volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: - test: ['CMD', 'pg_isready'] + test: [ 'CMD', 'pg_isready' ] interval: 1s timeout: 3s retries: 30 @@ -92,7 +93,7 @@ services: # Set the redis password when startup redis server. 
command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} healthcheck: - test: ['CMD', 'redis-cli', 'ping'] + test: [ 'CMD', 'redis-cli', 'ping' ] # The DifySandbox sandbox: @@ -112,7 +113,7 @@ services: volumes: - ./volumes/sandbox/dependencies:/dependencies healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] + test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ] networks: - ssrf_proxy_network @@ -125,12 +126,7 @@ services: volumes: - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - entrypoint: - [ - 'sh', - '-c', - "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", - ] + entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] environment: # pls clearly modify the squid env vars to fit your network environment. HTTP_PORT: ${SSRF_HTTP_PORT:-3128} @@ -159,8 +155,8 @@ services: - CERTBOT_EMAIL=${CERTBOT_EMAIL} - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} - entrypoint: ['/docker-entrypoint.sh'] - command: ['tail', '-f', '/dev/null'] + entrypoint: [ '/docker-entrypoint.sh' ] + command: [ 'tail', '-f', '/dev/null' ] # The nginx reverse proxy. # used for reverse proxying the API service and Web service. @@ -177,12 +173,7 @@ services: - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) - ./volumes/certbot/conf:/etc/letsencrypt - ./volumes/certbot/www:/var/www/html - entrypoint: - [ - 'sh', - '-c', - "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", - ] + entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] environment: NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} @@ -274,7 +265,7 @@ services: working_dir: /opt/couchbase stdin_open: true tty: true - entrypoint: [""] + entrypoint: [ "" ] command: sh -c "/opt/couchbase/init/init-cbserver.sh" volumes: - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data @@ -303,7 +294,7 @@ services: volumes: - ./volumes/pgvector/data:/var/lib/postgresql/data healthcheck: - test: ['CMD', 'pg_isready'] + test: [ 'CMD', 'pg_isready' ] interval: 1s timeout: 3s retries: 30 @@ -325,7 +316,7 @@ services: volumes: - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data healthcheck: - test: ['CMD', 'pg_isready'] + test: [ 'CMD', 'pg_isready' ] interval: 1s timeout: 3s retries: 30 @@ -390,7 +381,7 @@ services: - ./volumes/milvus/etcd:/etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd healthcheck: - test: ['CMD', 'etcdctl', 'endpoint', 'health'] + test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ] interval: 30s timeout: 20s retries: 3 @@ -409,7 +400,7 @@ services: - ./volumes/milvus/minio:/minio_data command: minio server /minio_data --console-address ":9001" healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] + test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ] interval: 30s timeout: 20s retries: 3 @@ -418,10 +409,10 @@ services: 
milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.3.1 + image: milvusdb/milvus:v2.5.0-beta profiles: - milvus - command: ['milvus', 'run', 'standalone'] + command: [ 'milvus', 'run', 'standalone' ] environment: ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} @@ -429,7 +420,7 @@ services: volumes: - ./volumes/milvus/milvus:/var/lib/milvus healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] + test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ] interval: 30s start_period: 90s timeout: 20s @@ -502,22 +493,30 @@ services: container_name: elasticsearch profiles: - elasticsearch + - elasticsearch-ja restart: always volumes: + - ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh - dify_es01_data:/usr/share/elasticsearch/data environment: ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + VECTOR_STORE: ${VECTOR_STORE:-} cluster.name: dify-es-cluster node.name: dify-es0 discovery.type: single-node - xpack.license.self_generated.type: trial + xpack.license.self_generated.type: basic xpack.security.enabled: 'true' xpack.security.enrollment.enabled: 'false' xpack.security.http.ssl.enabled: 'false' ports: - ${ELASTICSEARCH_PORT:-9200}:9200 + deploy: + resources: + limits: + memory: 2g + entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ] healthcheck: - test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] + test: [ 'CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty' ] interval: 30s timeout: 10s retries: 50 @@ -545,7 +544,7 @@ services: ports: - ${KIBANA_PORT:-5601}:5601 healthcheck: - test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] + test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ] interval: 30s timeout: 10s retries: 3 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7122f4a6d..f60fcdbcf 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -15,24 +15,26 @@ x-shared-env: &shared-api-worker-env LOG_FILE: ${LOG_FILE:-/app/logs/server.log} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} LOG_FILE_BACKUP_COUNT: ${LOG_FILE_BACKUP_COUNT:-5} - LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"} + LOG_DATEFORMAT: ${LOG_DATEFORMAT:-%Y-%m-%d %H:%M:%S} LOG_TZ: ${LOG_TZ:-UTC} DEBUG: ${DEBUG:-false} FLASK_DEBUG: ${FLASK_DEBUG:-false} SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U} INIT_PASSWORD: ${INIT_PASSWORD:-} DEPLOY_ENV: ${DEPLOY_ENV:-PRODUCTION} - CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-"https://updates.dify.ai"} - OPENAI_API_BASE: ${OPENAI_API_BASE:-"https://api.openai.com/v1"} + CHECK_UPDATE_URL: ${CHECK_UPDATE_URL:-https://updates.dify.ai} + OPENAI_API_BASE: ${OPENAI_API_BASE:-https://api.openai.com/v1} MIGRATION_ENABLED: ${MIGRATION_ENABLED:-true} FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} + REFRESH_TOKEN_EXPIRE_DAYS: ${REFRESH_TOKEN_EXPIRE_DAYS:-30} APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0} APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200} DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0} DIFY_PORT: ${DIFY_PORT:-5001} - SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-} - SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-} + SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-1} + SERVER_WORKER_CLASS: ${SERVER_WORKER_CLASS:-gevent} + SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10} CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-} 
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360} CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-} @@ -69,7 +71,7 @@ x-shared-env: &shared-api-worker-env REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} - CELERY_BROKER_URL: ${CELERY_BROKER_URL:-"redis://:difyai123456@redis:6379/1"} + CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} CELERY_USE_SENTINEL: ${CELERY_USE_SENTINEL:-false} CELERY_SENTINEL_MASTER_NAME: ${CELERY_SENTINEL_MASTER_NAME:-} @@ -88,13 +90,13 @@ x-shared-env: &shared-api-worker-env AZURE_BLOB_ACCOUNT_NAME: ${AZURE_BLOB_ACCOUNT_NAME:-difyai} AZURE_BLOB_ACCOUNT_KEY: ${AZURE_BLOB_ACCOUNT_KEY:-difyai} AZURE_BLOB_CONTAINER_NAME: ${AZURE_BLOB_CONTAINER_NAME:-difyai-container} - AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-"https://.blob.core.windows.net"} + AZURE_BLOB_ACCOUNT_URL: ${AZURE_BLOB_ACCOUNT_URL:-https://.blob.core.windows.net} GOOGLE_STORAGE_BUCKET_NAME: ${GOOGLE_STORAGE_BUCKET_NAME:-your-bucket-name} - GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-your-google-service-account-json-base64-string} + GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: ${GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64:-} ALIYUN_OSS_BUCKET_NAME: ${ALIYUN_OSS_BUCKET_NAME:-your-bucket-name} ALIYUN_OSS_ACCESS_KEY: ${ALIYUN_OSS_ACCESS_KEY:-your-access-key} ALIYUN_OSS_SECRET_KEY: ${ALIYUN_OSS_SECRET_KEY:-your-secret-key} - ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-"https://oss-ap-southeast-1-internal.aliyuncs.com"} + ALIYUN_OSS_ENDPOINT: ${ALIYUN_OSS_ENDPOINT:-https://oss-ap-southeast-1-internal.aliyuncs.com} ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} @@ -103,7 +105,7 @@ x-shared-env: &shared-api-worker-env TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} TENCENT_COS_REGION: ${TENCENT_COS_REGION:-your-region} TENCENT_COS_SCHEME: ${TENCENT_COS_SCHEME:-your-scheme} - OCI_ENDPOINT: ${OCI_ENDPOINT:-"https://objectstorage.us-ashburn-1.oraclecloud.com"} + OCI_ENDPOINT: ${OCI_ENDPOINT:-https://objectstorage.us-ashburn-1.oraclecloud.com} OCI_BUCKET_NAME: ${OCI_BUCKET_NAME:-your-bucket-name} OCI_ACCESS_KEY: ${OCI_ACCESS_KEY:-your-access-key} OCI_SECRET_KEY: ${OCI_SECRET_KEY:-your-secret-key} @@ -125,24 +127,25 @@ x-shared-env: &shared-api-worker-env SUPABASE_API_KEY: ${SUPABASE_API_KEY:-your-access-key} SUPABASE_URL: ${SUPABASE_URL:-your-server-url} VECTOR_STORE: ${VECTOR_STORE:-weaviate} - WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-"http://weaviate:8080"} + WEAVIATE_ENDPOINT: ${WEAVIATE_ENDPOINT:-http://weaviate:8080} WEAVIATE_API_KEY: ${WEAVIATE_API_KEY:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih} - QDRANT_URL: ${QDRANT_URL:-"http://qdrant:6333"} + QDRANT_URL: ${QDRANT_URL:-http://qdrant:6333} QDRANT_API_KEY: ${QDRANT_API_KEY:-difyai123456} QDRANT_CLIENT_TIMEOUT: ${QDRANT_CLIENT_TIMEOUT:-20} QDRANT_GRPC_ENABLED: ${QDRANT_GRPC_ENABLED:-false} QDRANT_GRPC_PORT: ${QDRANT_GRPC_PORT:-6334} - MILVUS_URI: ${MILVUS_URI:-"http://127.0.0.1:19530"} + MILVUS_URI: ${MILVUS_URI:-http://127.0.0.1:19530} MILVUS_TOKEN: ${MILVUS_TOKEN:-} MILVUS_USER: ${MILVUS_USER:-root} MILVUS_PASSWORD: ${MILVUS_PASSWORD:-Milvus} + MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False} MYSCALE_HOST: ${MYSCALE_HOST:-myscale} MYSCALE_PORT: ${MYSCALE_PORT:-8123} MYSCALE_USER: ${MYSCALE_USER:-default} MYSCALE_PASSWORD: 
${MYSCALE_PASSWORD:-} MYSCALE_DATABASE: ${MYSCALE_DATABASE:-dify} MYSCALE_FTS_PARAMS: ${MYSCALE_FTS_PARAMS:-} - COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-"couchbase://couchbase-server"} + COUCHBASE_CONNECTION_STRING: ${COUCHBASE_CONNECTION_STRING:-couchbase://couchbase-server} COUCHBASE_USER: ${COUCHBASE_USER:-Administrator} COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} @@ -176,15 +179,15 @@ x-shared-env: &shared-api-worker-env TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-} TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-} TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify} - TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-"http://127.0.0.1"} + TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1} TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify} TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20} TIDB_ON_QDRANT_GRPC_ENABLED: ${TIDB_ON_QDRANT_GRPC_ENABLED:-false} TIDB_ON_QDRANT_GRPC_PORT: ${TIDB_ON_QDRANT_GRPC_PORT:-6334} TIDB_PUBLIC_KEY: ${TIDB_PUBLIC_KEY:-dify} TIDB_PRIVATE_KEY: ${TIDB_PRIVATE_KEY:-dify} - TIDB_API_URL: ${TIDB_API_URL:-"http://127.0.0.1"} - TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-"http://127.0.0.1"} + TIDB_API_URL: ${TIDB_API_URL:-http://127.0.0.1} + TIDB_IAM_API_URL: ${TIDB_IAM_API_URL:-http://127.0.0.1} TIDB_REGION: ${TIDB_REGION:-regions/aws-us-east-1} TIDB_PROJECT_ID: ${TIDB_PROJECT_ID:-dify} TIDB_SPEND_LIMIT: ${TIDB_SPEND_LIMIT:-100} @@ -209,7 +212,7 @@ x-shared-env: &shared-api-worker-env OPENSEARCH_USER: ${OPENSEARCH_USER:-admin} OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin} OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true} - TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-"http://127.0.0.1"} + TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1} TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify} TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30} TENCENT_VECTOR_DB_USERNAME: ${TENCENT_VECTOR_DB_USERNAME:-dify} @@ -221,7 +224,7 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} KIBANA_PORT: ${KIBANA_PORT:-5601} - BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-"http://127.0.0.1:5287"} + BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify} @@ -235,7 +238,7 @@ x-shared-env: &shared-api-worker-env VIKINGDB_SCHEMA: ${VIKINGDB_SCHEMA:-http} VIKINGDB_CONNECTION_TIMEOUT: ${VIKINGDB_CONNECTION_TIMEOUT:-30} VIKINGDB_SOCKET_TIMEOUT: ${VIKINGDB_SOCKET_TIMEOUT:-30} - LINDORM_URL: ${LINDORM_URL:-"http://lindorm:30070"} + LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm} OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-oceanbase} @@ -245,7 +248,7 @@ x-shared-env: &shared-api-worker-env OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test} OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} - UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-"https://xxx-vector.upstash.io"} + UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io} UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} 
UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} @@ -270,7 +273,7 @@ x-shared-env: &shared-api-worker-env NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-} MAIL_TYPE: ${MAIL_TYPE:-resend} MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-} - RESEND_API_URL: ${RESEND_API_URL:-"https://api.resend.com"} + RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} SMTP_SERVER: ${SMTP_SERVER:-} SMTP_PORT: ${SMTP_PORT:-465} @@ -281,7 +284,7 @@ x-shared-env: &shared-api-worker-env INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} - CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-"http://sandbox:8194"} + CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} CODE_EXECUTION_API_KEY: ${CODE_EXECUTION_API_KEY:-dify-sandbox} CODE_MAX_NUMBER: ${CODE_MAX_NUMBER:-9223372036854775807} CODE_MIN_NUMBER: ${CODE_MIN_NUMBER:--9223372036854775808} @@ -303,8 +306,8 @@ x-shared-env: &shared-api-worker-env WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} - SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-"http://ssrf_proxy:3128"} - SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-"http://ssrf_proxy:3128"} + SSRF_PROXY_HTTP_URL: ${SSRF_PROXY_HTTP_URL:-http://ssrf_proxy:3128} + SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} PGUSER: ${PGUSER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} @@ -314,8 +317,8 @@ x-shared-env: &shared-api-worker-env SANDBOX_GIN_MODE: ${SANDBOX_GIN_MODE:-release} SANDBOX_WORKER_TIMEOUT: ${SANDBOX_WORKER_TIMEOUT:-15} SANDBOX_ENABLE_NETWORK: ${SANDBOX_ENABLE_NETWORK:-true} - SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-"http://ssrf_proxy:3128"} - SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-"http://ssrf_proxy:3128"} + SANDBOX_HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} + SANDBOX_HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} WEAVIATE_PERSISTENCE_DATA_PATH: ${WEAVIATE_PERSISTENCE_DATA_PATH:-/var/lib/weaviate} WEAVIATE_QUERY_DEFAULTS_LIMIT: ${WEAVIATE_QUERY_DEFAULTS_LIMIT:-25} @@ -338,8 +341,8 @@ x-shared-env: &shared-api-worker-env ETCD_SNAPSHOT_COUNT: ${ETCD_SNAPSHOT_COUNT:-50000} MINIO_ACCESS_KEY: ${MINIO_ACCESS_KEY:-minioadmin} MINIO_SECRET_KEY: ${MINIO_SECRET_KEY:-minioadmin} - ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-"etcd:2379"} - MINIO_ADDRESS: ${MINIO_ADDRESS:-"minio:9000"} + ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} + MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} MILVUS_AUTHORIZATION_ENABLED: ${MILVUS_AUTHORIZATION_ENABLED:-true} PGVECTOR_PGUSER: ${PGVECTOR_PGUSER:-postgres} PGVECTOR_POSTGRES_PASSWORD: ${PGVECTOR_POSTGRES_PASSWORD:-difyai123456} @@ -360,7 +363,7 @@ x-shared-env: &shared-api-worker-env NGINX_SSL_PORT: ${NGINX_SSL_PORT:-443} NGINX_SSL_CERT_FILENAME: ${NGINX_SSL_CERT_FILENAME:-dify.crt} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} - NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-"TLSv1.1 TLSv1.2 TLSv1.3"} + NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_CLIENT_MAX_BODY_SIZE: 
${NGINX_CLIENT_MAX_BODY_SIZE:-15M} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} @@ -374,7 +377,6 @@ x-shared-env: &shared-api-worker-env SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} SSRF_REVERSE_PROXY_PORT: ${SSRF_REVERSE_PROXY_PORT:-8194} SSRF_SANDBOX_HOST: ${SSRF_SANDBOX_HOST:-sandbox} - COMPOSE_PROFILES: ${COMPOSE_PROFILES:-"${VECTOR_STORE:-weaviate}"} EXPOSE_NGINX_PORT: ${EXPOSE_NGINX_PORT:-80} EXPOSE_NGINX_SSL_PORT: ${EXPOSE_NGINX_SSL_PORT:-443} POSITION_TOOL_PINS: ${POSITION_TOOL_PINS:-} @@ -386,11 +388,12 @@ x-shared-env: &shared-api-worker-env CSP_WHITELIST: ${CSP_WHITELIST:-} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100} + TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10} services: # API service api: - image: langgenius/dify-api:0.14.2 + image: langgenius/dify-api:0.15.1 restart: always environment: # Use the shared environment variables. @@ -413,7 +416,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.14.2 + image: langgenius/dify-api:0.15.1 restart: always environment: # Use the shared environment variables. @@ -435,7 +438,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.14.2 + image: langgenius/dify-web:0.15.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -444,6 +447,7 @@ services: NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} CSP_WHITELIST: ${CSP_WHITELIST:-} + TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} # The postgres database. db: @@ -463,7 +467,7 @@ services: volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: - test: ['CMD', 'pg_isready'] + test: [ 'CMD', 'pg_isready' ] interval: 1s timeout: 3s retries: 30 @@ -480,7 +484,7 @@ services: # Set the redis password when startup redis server. command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} healthcheck: - test: ['CMD', 'redis-cli', 'ping'] + test: [ 'CMD', 'redis-cli', 'ping' ] # The DifySandbox sandbox: @@ -500,7 +504,7 @@ services: volumes: - ./volumes/sandbox/dependencies:/dependencies healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] + test: [ 'CMD', 'curl', '-f', 'http://localhost:8194/health' ] networks: - ssrf_proxy_network @@ -513,12 +517,7 @@ services: volumes: - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - entrypoint: - [ - 'sh', - '-c', - "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", - ] + entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] environment: # pls clearly modify the squid env vars to fit your network environment. HTTP_PORT: ${SSRF_HTTP_PORT:-3128} @@ -547,8 +546,8 @@ services: - CERTBOT_EMAIL=${CERTBOT_EMAIL} - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} - entrypoint: ['/docker-entrypoint.sh'] - command: ['tail', '-f', '/dev/null'] + entrypoint: [ '/docker-entrypoint.sh' ] + command: [ 'tail', '-f', '/dev/null' ] # The nginx reverse proxy. # used for reverse proxying the API service and Web service. 
@@ -565,12 +564,7 @@ services: - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) - ./volumes/certbot/conf:/etc/letsencrypt - ./volumes/certbot/www:/var/www/html - entrypoint: - [ - 'sh', - '-c', - "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", - ] + entrypoint: [ 'sh', '-c', "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] environment: NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} @@ -662,7 +656,7 @@ services: working_dir: /opt/couchbase stdin_open: true tty: true - entrypoint: [""] + entrypoint: [ "" ] command: sh -c "/opt/couchbase/init/init-cbserver.sh" volumes: - ./volumes/couchbase/data:/opt/couchbase/var/lib/couchbase/data @@ -691,7 +685,7 @@ services: volumes: - ./volumes/pgvector/data:/var/lib/postgresql/data healthcheck: - test: ['CMD', 'pg_isready'] + test: [ 'CMD', 'pg_isready' ] interval: 1s timeout: 3s retries: 30 @@ -713,7 +707,7 @@ services: volumes: - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data healthcheck: - test: ['CMD', 'pg_isready'] + test: [ 'CMD', 'pg_isready' ] interval: 1s timeout: 3s retries: 30 @@ -778,7 +772,7 @@ services: - ./volumes/milvus/etcd:/etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd healthcheck: - test: ['CMD', 'etcdctl', 'endpoint', 'health'] + test: [ 'CMD', 'etcdctl', 'endpoint', 'health' ] interval: 30s timeout: 20s retries: 3 @@ -797,7 +791,7 @@ services: - ./volumes/milvus/minio:/minio_data command: minio server /minio_data --console-address ":9001" healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] + test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ] interval: 30s timeout: 20s retries: 3 @@ -806,10 +800,10 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.3.1 + image: milvusdb/milvus:v2.5.0-beta profiles: - milvus - command: ['milvus', 'run', 'standalone'] + command: [ 'milvus', 'run', 'standalone' ] environment: ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} @@ -817,7 +811,7 @@ services: volumes: - ./volumes/milvus/milvus:/var/lib/milvus healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] + test: [ 'CMD', 'curl', '-f', 'http://localhost:9091/healthz' ] interval: 30s start_period: 90s timeout: 20s @@ -890,22 +884,30 @@ services: container_name: elasticsearch profiles: - elasticsearch + - elasticsearch-ja restart: always volumes: + - ./elasticsearch/docker-entrypoint.sh:/docker-entrypoint-mount.sh - dify_es01_data:/usr/share/elasticsearch/data environment: ELASTIC_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + VECTOR_STORE: ${VECTOR_STORE:-} cluster.name: dify-es-cluster node.name: dify-es0 discovery.type: single-node - xpack.license.self_generated.type: trial + xpack.license.self_generated.type: basic xpack.security.enabled: 'true' xpack.security.enrollment.enabled: 'false' xpack.security.http.ssl.enabled: 'false' ports: - ${ELASTICSEARCH_PORT:-9200}:9200 + deploy: + resources: + limits: + memory: 2g + entrypoint: [ 'sh', '-c', "sh /docker-entrypoint-mount.sh" ] healthcheck: - test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] + test: [ 'CMD', 'curl', '-s', 
'http://localhost:9200/_cluster/health?pretty' ] interval: 30s timeout: 10s retries: 50 @@ -933,7 +935,7 @@ services: ports: - ${KIBANA_PORT:-5601}:5601 healthcheck: - test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] + test: [ 'CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1' ] interval: 30s timeout: 10s retries: 3 diff --git a/docker/elasticsearch/docker-entrypoint.sh b/docker/elasticsearch/docker-entrypoint.sh new file mode 100755 index 000000000..6669aec5a --- /dev/null +++ b/docker/elasticsearch/docker-entrypoint.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +set -e + +if [ "${VECTOR_STORE}" = "elasticsearch-ja" ]; then + # Check if the ICU tokenizer plugin is installed + if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-icu; then + printf '%s\n' "Installing the ICU tokenizer plugin" + if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-icu; then + printf '%s\n' "Failed to install the ICU tokenizer plugin" + exit 1 + fi + fi + # Check if the Japanese language analyzer plugin is installed + if ! /usr/share/elasticsearch/bin/elasticsearch-plugin list | grep -q analysis-kuromoji; then + printf '%s\n' "Installing the Japanese language analyzer plugin" + if ! /usr/share/elasticsearch/bin/elasticsearch-plugin install analysis-kuromoji; then + printf '%s\n' "Failed to install the Japanese language analyzer plugin" + exit 1 + fi + fi +fi + +# Run the original entrypoint script +exec /bin/tini -- /usr/local/bin/docker-entrypoint.sh diff --git a/docker/generate_docker_compose b/docker/generate_docker_compose index 54b6d5521..b5c0acefb 100755 --- a/docker/generate_docker_compose +++ b/docker/generate_docker_compose @@ -37,13 +37,15 @@ def generate_shared_env_block(env_vars, anchor_name="shared-api-worker-env"): """ lines = [f"x-shared-env: &{anchor_name}"] for key, default in env_vars.items(): + if key == "COMPOSE_PROFILES": + continue # If default value is empty, use ${KEY:-} if default == "": lines.append(f" {key}: ${{{key}:-}}") else: # If default value contains special characters, wrap it in quotes if re.search(r"[:\s]", default): - default = f'"{default}"' + default = f"{default}" lines.append(f" {key}: ${{{key}:-{default}}}") return "\n".join(lines) diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py index e66448830..ee1b5c57e 100644 --- a/sdks/python-client/dify_client/client.py +++ b/sdks/python-client/dify_client/client.py @@ -160,7 +160,10 @@ def get_result(self, workflow_run_id): class KnowledgeBaseClient(DifyClient): def __init__( - self, api_key, base_url: str = "https://api.dify.ai/v1", dataset_id: str = None + self, + api_key, + base_url: str = "https://api.dify.ai/v1", + dataset_id: str | None = None, ): """ Construct a KnowledgeBaseClient object. @@ -187,7 +190,9 @@ def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs): "GET", f"/datasets?page={page}&limit={page_size}", **kwargs ) - def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs): + def create_document_by_text( + self, name, text, extra_params: dict | None = None, **kwargs + ): """ Create a document by text. 
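The python-client hunks above and below only tighten optional-argument annotations (dict = None becomes dict | None, str = None becomes str | None); runtime behavior is unchanged, though the X | None syntax needs Python 3.10+ unless postponed evaluation of annotations is enabled. A hedged usage sketch against the signatures shown here, with a placeholder API key:

    # Illustrative only; the key is a placeholder and the call just echoes the response.
    from dify_client.client import KnowledgeBaseClient

    client = KnowledgeBaseClient(
        api_key="your-api-key",             # placeholder
        base_url="https://api.dify.ai/v1",  # default from the signature above
        dataset_id=None,                    # now annotated as str | None
    )

    # list_datasets does not require a dataset_id; defaults match the diff.
    response = client.list_datasets(page=1, page_size=20)
    print(response)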
@@ -225,7 +230,7 @@ def create_document_by_text(self, name, text, extra_params: dict = None, **kwarg return self._send_request("POST", url, json=data, **kwargs) def update_document_by_text( - self, document_id, name, text, extra_params: dict = None, **kwargs + self, document_id, name, text, extra_params: dict | None = None, **kwargs ): """ Update a document by text. @@ -262,7 +267,7 @@ def update_document_by_text( return self._send_request("POST", url, json=data, **kwargs) def create_document_by_file( - self, file_path, original_document_id=None, extra_params: dict = None + self, file_path, original_document_id=None, extra_params: dict | None = None ): """ Create a document by file. @@ -304,7 +309,7 @@ def create_document_by_file( ) def update_document_by_file( - self, document_id, file_path, extra_params: dict = None + self, document_id, file_path, extra_params: dict | None = None ): """ Update a document by file. @@ -372,7 +377,11 @@ def delete_document(self, document_id): return self._send_request("DELETE", url) def list_documents( - self, page: int = None, page_size: int = None, keyword: str = None, **kwargs + self, + page: int | None = None, + page_size: int | None = None, + keyword: str | None = None, + **kwargs, ): """ Get a list of documents in this dataset. @@ -402,7 +411,11 @@ def add_segments(self, document_id, segments, **kwargs): return self._send_request("POST", url, json=data, **kwargs) def query_segments( - self, document_id, keyword: str = None, status: str = None, **kwargs + self, + document_id, + keyword: str | None = None, + status: str | None = None, + **kwargs, ): """ Query segments in this document. diff --git a/web/.env.example b/web/.env.example index 5a7b5c090..87eb982cf 100644 --- a/web/.env.example +++ b/web/.env.example @@ -26,6 +26,9 @@ NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP NEXT_PUBLIC_CSP_WHITELIST= +# The maximum number of top-k value for RAG. 
+NEXT_PUBLIC_TOP_K_MAX_VALUE=10 + # Default Domain Extend(二开新增配置) NEXT_PUBLIC_DEFAULT_DOMAIN= diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx index 8e3d8f9ec..17f46c258 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx @@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import TracingIcon from './tracing-icon' import ProviderPanel from './provider-panel' -import type { LangFuseConfig, LangSmithConfig } from './type' +import type { LangFuseConfig, LangSmithConfig, OpikConfig } from './type' import { TracingProvider } from './type' import ProviderConfigModal from './provider-config-modal' import Indicator from '@/app/components/header/indicator' @@ -23,7 +23,8 @@ export type PopupProps = { onChooseProvider: (provider: TracingProvider) => void langSmithConfig: LangSmithConfig | null langFuseConfig: LangFuseConfig | null - onConfigUpdated: (provider: TracingProvider, payload: LangSmithConfig | LangFuseConfig) => void + opikConfig: OpikConfig | null + onConfigUpdated: (provider: TracingProvider, payload: LangSmithConfig | LangFuseConfig | OpikConfig) => void onConfigRemoved: (provider: TracingProvider) => void } @@ -36,6 +37,7 @@ const ConfigPopup: FC = ({ onChooseProvider, langSmithConfig, langFuseConfig, + opikConfig, onConfigUpdated, onConfigRemoved, }) => { @@ -59,7 +61,7 @@ const ConfigPopup: FC = ({ } }, [onChooseProvider]) - const handleConfigUpdated = useCallback((payload: LangSmithConfig | LangFuseConfig) => { + const handleConfigUpdated = useCallback((payload: LangSmithConfig | LangFuseConfig | OpikConfig) => { onConfigUpdated(currentProvider!, payload) hideConfigModal() }, [currentProvider, hideConfigModal, onConfigUpdated]) @@ -69,8 +71,8 @@ const ConfigPopup: FC = ({ hideConfigModal() }, [currentProvider, hideConfigModal, onConfigRemoved]) - const providerAllConfigured = langSmithConfig && langFuseConfig - const providerAllNotConfigured = !langSmithConfig && !langFuseConfig + const providerAllConfigured = langSmithConfig && langFuseConfig && opikConfig + const providerAllNotConfigured = !langSmithConfig && !langFuseConfig && !opikConfig const switchContent = ( = ({ onConfig={handleOnConfig(TracingProvider.langSmith)} isChosen={chosenProvider === TracingProvider.langSmith} onChoose={handleOnChoose(TracingProvider.langSmith)} + key="langSmith-provider-panel" /> ) @@ -102,9 +105,61 @@ const ConfigPopup: FC = ({ onConfig={handleOnConfig(TracingProvider.langfuse)} isChosen={chosenProvider === TracingProvider.langfuse} onChoose={handleOnChoose(TracingProvider.langfuse)} + key="langfuse-provider-panel" /> ) + const opikPanel = ( + + ) + + const configuredProviderPanel = () => { + const configuredPanels: ProviderPanel[] = [] + + if (langSmithConfig) + configuredPanels.push(langSmithPanel) + + if (langFuseConfig) + configuredPanels.push(langfusePanel) + + if (opikConfig) + configuredPanels.push(opikPanel) + + return configuredPanels + } + + const moreProviderPanel = () => { + const notConfiguredPanels: ProviderPanel[] = [] + + if (!langSmithConfig) + notConfiguredPanels.push(langSmithPanel) + + if (!langFuseConfig) + notConfiguredPanels.push(langfusePanel) + + if (!opikConfig) + notConfiguredPanels.push(opikPanel) + + return notConfiguredPanels + } + + const 
configuredProviderConfig = () => { + if (currentProvider === TracingProvider.langSmith) + return langSmithConfig + if (currentProvider === TracingProvider.langfuse) + return langFuseConfig + return opikConfig + } + return (
@@ -146,18 +201,19 @@ const ConfigPopup: FC<PopupProps> = ({
{langSmithPanel}
{langfusePanel}
+ {opikPanel}
) : ( <>
{t(`${I18N_PREFIX}.configProviderTitle.configured`)}
-
- {langSmithConfig ? langSmithPanel : langfusePanel}
+
+ {configuredProviderPanel()}
{t(`${I18N_PREFIX}.configProviderTitle.moreProvider`)}
-
- {!langSmithConfig ? langSmithPanel : langfusePanel}
+
+ {moreProviderPanel()}
)} @@ -167,7 +223,7 @@ const ConfigPopup: FC = ({ { }) } const inUseTracingProvider: TracingProvider | null = tracingStatus?.tracing_provider || null - const InUseProviderIcon = inUseTracingProvider === TracingProvider.langSmith ? LangsmithIcon : LangfuseIcon + + const InUseProviderIcon + = inUseTracingProvider === TracingProvider.langSmith + ? LangsmithIcon + : inUseTracingProvider === TracingProvider.langfuse + ? LangfuseIcon + : inUseTracingProvider === TracingProvider.opik + ? OpikIcon + : null const [langSmithConfig, setLangSmithConfig] = useState(null) const [langFuseConfig, setLangFuseConfig] = useState(null) - const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig) + const [opikConfig, setOpikConfig] = useState(null) + const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig) const fetchTracingConfig = async () => { const { tracing_config: langSmithConfig, has_not_configured: langSmithHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.langSmith }) @@ -83,6 +92,9 @@ const Panel: FC = () => { const { tracing_config: langFuseConfig, has_not_configured: langFuseHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.langfuse }) if (!langFuseHasNotConfig) setLangFuseConfig(langFuseConfig as LangFuseConfig) + const { tracing_config: opikConfig, has_not_configured: OpikHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.opik }) + if (!OpikHasNotConfig) + setOpikConfig(opikConfig as OpikConfig) } const handleTracingConfigUpdated = async (provider: TracingProvider) => { @@ -90,15 +102,19 @@ const Panel: FC = () => { const { tracing_config } = await doFetchTracingConfig({ appId, provider }) if (provider === TracingProvider.langSmith) setLangSmithConfig(tracing_config as LangSmithConfig) - else + else if (provider === TracingProvider.langSmith) setLangFuseConfig(tracing_config as LangFuseConfig) + else if (provider === TracingProvider.opik) + setOpikConfig(tracing_config as OpikConfig) } const handleTracingConfigRemoved = (provider: TracingProvider) => { if (provider === TracingProvider.langSmith) setLangSmithConfig(null) - else + else if (provider === TracingProvider.langSmith) setLangFuseConfig(null) + else if (provider === TracingProvider.opik) + setOpikConfig(null) if (provider === inUseTracingProvider) { handleTracingStatusChange({ enabled: false, @@ -167,6 +183,7 @@ const Panel: FC = () => { onChooseProvider={handleChooseProvider} langSmithConfig={langSmithConfig} langFuseConfig={langFuseConfig} + opikConfig={opikConfig} onConfigUpdated={handleTracingConfigUpdated} onConfigRemoved={handleTracingConfigRemoved} controlShowPopup={controlShowPopup} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx index e7ecd2f4c..b813e9b13 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-config-modal.tsx @@ -4,7 +4,7 @@ import React, { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import Field from './field' -import type { LangFuseConfig, LangSmithConfig } from './type' +import type { LangFuseConfig, LangSmithConfig, OpikConfig } from './type' import { TracingProvider } from './type' import { docURL } from './config' import { 
@@ -21,10 +21,10 @@ import Toast from '@/app/components/base/toast' type Props = { appId: string type: TracingProvider - payload?: LangSmithConfig | LangFuseConfig | null + payload?: LangSmithConfig | LangFuseConfig | OpikConfig | null onRemoved: () => void onCancel: () => void - onSaved: (payload: LangSmithConfig | LangFuseConfig) => void + onSaved: (payload: LangSmithConfig | LangFuseConfig | OpikConfig) => void onChosen: (provider: TracingProvider) => void } @@ -42,6 +42,13 @@ const langFuseConfigTemplate = { host: '', } +const opikConfigTemplate = { + api_key: '', + project: '', + url: '', + workspace: '', +} + const ProviderConfigModal: FC = ({ appId, type, @@ -55,14 +62,17 @@ const ProviderConfigModal: FC = ({ const isEdit = !!payload const isAdd = !isEdit const [isSaving, setIsSaving] = useState(false) - const [config, setConfig] = useState((() => { + const [config, setConfig] = useState((() => { if (isEdit) return payload if (type === TracingProvider.langSmith) return langSmithConfigTemplate - return langFuseConfigTemplate + else if (type === TracingProvider.langfuse) + return langFuseConfigTemplate + + return opikConfigTemplate })()) const [isShowRemoveConfirm, { setTrue: showRemoveConfirm, @@ -111,6 +121,10 @@ const ProviderConfigModal: FC = ({ errorMessage = t('common.errorMsg.fieldRequired', { field: 'Host' }) } + if (type === TracingProvider.opik) { + const postData = config as OpikConfig + } + return errorMessage }, [config, t, type]) const handleSave = useCallback(async () => { @@ -215,6 +229,38 @@ const ProviderConfigModal: FC = ({ /> )} + {type === TracingProvider.opik && ( + <> + + + + + + )}
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index 6e5046ecf..34e5bbeb0 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -4,7 +4,7 @@ import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' import { TracingProvider } from './type' import cn from '@/utils/classnames' -import { LangfuseIconBig, LangsmithIconBig } from '@/app/components/base/icons/src/public/tracing' +import { LangfuseIconBig, LangsmithIconBig, OpikIconBig } from '@/app/components/base/icons/src/public/tracing' import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' @@ -24,6 +24,7 @@ const getIcon = (type: TracingProvider) => { return ({ [TracingProvider.langSmith]: LangsmithIconBig, [TracingProvider.langfuse]: LangfuseIconBig, + [TracingProvider.opik]: OpikIconBig, })[type] } diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts index e07cf37c9..982d01ffb 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type.ts @@ -1,6 +1,7 @@ export enum TracingProvider { langSmith = 'langsmith', langfuse = 'langfuse', + opik = 'opik', } export type LangSmithConfig = { @@ -14,3 +15,10 @@ export type LangFuseConfig = { secret_key: string host: string } + +export type OpikConfig = { + api_key: string + project: string + workspace: string + url: string +} diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index 49ec5a1cd..0ddc1df6a 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -25,16 +25,18 @@ import Input from '@/app/components/base/input' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' +import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' const getKey = ( pageIndex: number, previousPageData: AppListResponse, activeTab: string, + isCreatedByMe: boolean, tags: string[], keywords: string, ) => { if (!pageIndex || previousPageData.has_more) { - const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords } } + const params: any = { url: 'apps', params: { page: pageIndex + 1, limit: 30, name: keywords, is_created_by_me: isCreatedByMe } } if (activeTab !== 'all') params.params.mode = activeTab @@ -58,6 +60,7 @@ const Apps = () => { defaultTab: 'all', }) const { query: { tagIDs = [], keywords = '' }, setQuery } = useAppsQueryState() + const [isCreatedByMe, setIsCreatedByMe] = useState(false) const [tagFilterValue, setTagFilterValue] = useState(tagIDs) const [searchKeywords, setSearchKeywords] = useState(keywords) const setKeywords = useCallback((keywords: string) => { @@ -68,7 +71,7 @@ const Apps = () => { }, [setQuery]) const { data, isLoading, setSize, mutate } = useSWRInfinite( - (pageIndex: number, previousPageData: AppListResponse) => 
getKey(pageIndex, previousPageData, activeTab, tagIDs, searchKeywords), + (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords), fetchAppList, { revalidateFirstPage: true }, ) @@ -133,6 +136,12 @@ const Apps = () => { options={options} />
+ setIsCreatedByMe(!isCreatedByMe)} + /> void } -// eslint-disable-next-line react/display-name const CreateAppCard = forwardRef(({ className, onSuccess }, ref) => { const { t } = useTranslation() const { onPlanInfoChanged } = useProviderContext() @@ -44,24 +43,22 @@ const CreateAppCard = forwardRef(({ classNam >
{t('app.createApp')}
-
setShowNewAppModal(true)}> +
-
setShowNewAppTemplateDialog(true)}> + +
-
-
setShowCreateFromDSLModal(true)} - > -
+ +
+
+ setShowNewAppModal(false)} @@ -108,4 +105,6 @@ const CreateAppCard = forwardRef(({ classNam ) }) +CreateAppCard.displayName = 'CreateAppCard' export default CreateAppCard +export { CreateAppCard } diff --git a/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts b/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts index fae5357bf..7f1f4ba65 100644 --- a/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts +++ b/web/app/(commonLayout)/apps/hooks/useAppsQueryState.ts @@ -37,7 +37,7 @@ function useAppsQueryState() { const syncSearchParams = useCallback((params: URLSearchParams) => { const search = params.toString() const query = search ? `?${search}` : '' - router.push(`${pathname}${query}`) + router.push(`${pathname}${query}`, { scroll: false }) }, [router, pathname]) // Update the URL search string whenever the query changes. diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx index b416659a6..a6fb116fa 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -7,85 +7,36 @@ import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' import { Cog8ToothIcon, - // CommandLineIcon, - Squares2X2Icon, - // eslint-disable-next-line sort-imports - PuzzlePieceIcon, DocumentTextIcon, PaperClipIcon, - QuestionMarkCircleIcon, } from '@heroicons/react/24/outline' import { Cog8ToothIcon as Cog8ToothSolidIcon, // CommandLineIcon as CommandLineSolidIcon, DocumentTextIcon as DocumentTextSolidIcon, } from '@heroicons/react/24/solid' -import Link from 'next/link' +import { RiApps2AddLine, RiInformation2Line } from '@remixicon/react' import s from './style.module.css' import classNames from '@/utils/classnames' import { fetchDatasetDetail, fetchDatasetRelatedApps } from '@/service/datasets' -import type { RelatedApp, RelatedAppResponse } from '@/models/datasets' +import type { RelatedAppResponse } from '@/models/datasets' import AppSideBar from '@/app/components/app-sidebar' -import Divider from '@/app/components/base/divider' -import AppIcon from '@/app/components/base/app-icon' import Loading from '@/app/components/base/loading' -import FloatPopoverContainer from '@/app/components/base/float-popover-container' import DatasetDetailContext from '@/context/dataset-detail' import { DataSourceType } from '@/models/datasets' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { LanguagesSupported } from '@/i18n/language' import { useStore } from '@/app/components/app/store' -import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' -import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import { getLocaleOnClient } from '@/i18n' import { useAppContext } from '@/context/app-context' +import Tooltip from '@/app/components/base/tooltip' +import LinkedAppsPanel from '@/app/components/base/linked-apps-panel' export type IAppDetailLayoutProps = { children: React.ReactNode params: { datasetId: string } } -type ILikedItemProps = { - type?: 'plugin' | 'app' - appStatus?: boolean - detail: RelatedApp - isMobile: boolean -} - -const LikedItem = ({ - type = 'app', - detail, - isMobile, -}: ILikedItemProps) => { - return ( - -
- - {type === 'app' && ( - - {detail.mode === 'advanced-chat' && ( - - )} - {detail.mode === 'agent-chat' && ( - - )} - {detail.mode === 'chat' && ( - - )} - {detail.mode === 'completion' && ( - - )} - {detail.mode === 'workflow' && ( - - )} - - )} -
- {!isMobile &&
{detail?.name || '--'}
} - - ) -} - const TargetIcon = ({ className }: SVGProps) => { return @@ -117,65 +68,80 @@ const BookOpenIcon = ({ className }: SVGProps) => { type IExtraInfoProps = { isMobile: boolean relatedApps?: RelatedAppResponse + expand: boolean } -const ExtraInfo = ({ isMobile, relatedApps }: IExtraInfoProps) => { +const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { const locale = getLocaleOnClient() const [isShowTips, { toggle: toggleTips, set: setShowTips }] = useBoolean(!isMobile) const { t } = useTranslation() + const hasRelatedApps = relatedApps?.data && relatedApps?.data?.length > 0 + const relatedAppsTotal = relatedApps?.data?.length || 0 + useEffect(() => { setShowTips(!isMobile) }, [isMobile, setShowTips]) - return
- - {(relatedApps?.data && relatedApps?.data?.length > 0) && ( + return
+ {hasRelatedApps && ( <> - {!isMobile &&
{relatedApps?.total || '--'} {t('common.datasetMenus.relatedApp')}
} + {!isMobile && ( + + } + > +
+ {relatedAppsTotal || '--'} {t('common.datasetMenus.relatedApp')} + +
+
+ )} + {isMobile &&
- {relatedApps?.total || '--'} + {relatedAppsTotal || '--'}
} - {relatedApps?.data?.map((item, index) => ())} )} - {!relatedApps?.data?.length && ( - - + {!hasRelatedApps && !expand && ( + +
+ +
+
{t('common.datasetMenus.emptyTip')}
+ + + {t('common.datasetMenus.viewDoc')} +
} > -
-
-
- -
-
- -
-
-
{t('common.datasetMenus.emptyTip')}
- - - {t('common.datasetMenus.viewDoc')} - +
+ {t('common.datasetMenus.noRelatedApp')} +
- + )}
} @@ -235,7 +201,7 @@ const DatasetDetailLayout: FC = (props) => { }, [isMobile, setAppSiderbarExpand]) if (!datasetRes && !error) - return + return return (
@@ -246,7 +212,7 @@ const DatasetDetailLayout: FC = (props) => { desc={datasetRes?.description || '--'} isExternal={datasetRes?.provider === 'external'} navigation={navigation} - extraInfo={!isCurrentWorkspaceDatasetOperator ? mode => : undefined} + extraInfo={!isCurrentWorkspaceDatasetOperator ? mode => : undefined} iconType={datasetRes?.data_source_type === DataSourceType.NOTION ? 'notion' : 'dataset'} />} = (props) => { dataset: datasetRes, mutateDatasetRes: () => mutateDatasetRes(), }}> -
{children}
+
{children}
) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx index df314ddaf..3a65f1d30 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -7,10 +7,10 @@ const Settings = async () => { const { t } = await translate(locale, 'dataset-settings') return ( -
+
-
{t('title')}
-
{t('desc')}
+
{t('title')}
+
{t('desc')}
diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css index 0ee64b4fc..516b12480 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css @@ -1,12 +1,3 @@ -.itemWrapper { - @apply flex items-center w-full h-10 rounded-lg hover:bg-gray-50 cursor-pointer; -} -.appInfo { - @apply truncate text-gray-700 text-sm font-normal; -} -.iconWrapper { - @apply relative w-6 h-6 rounded-lg; -} .statusPoint { @apply flex justify-center items-center absolute -right-0.5 -bottom-0.5 w-2.5 h-2.5 bg-white rounded; } diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx index a30521d99..f484d30a3 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/Container.tsx @@ -4,7 +4,8 @@ import { useEffect, useMemo, useRef, useState } from 'react' import { useRouter } from 'next/navigation' import { useTranslation } from 'react-i18next' -import { useDebounceFn } from 'ahooks' +import { useBoolean, useDebounceFn } from 'ahooks' +import { useQuery } from '@tanstack/react-query' // Components import ExternalAPIPanel from '../../components/datasets/external-api/external-api-panel' @@ -16,8 +17,9 @@ import TabSliderNew from '@/app/components/base/tab-slider-new' import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' -import SearchInput from '@/app/components/base/search-input' +import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' // Services import { fetchDatasetApiBaseUrl } from '@/service/datasets' @@ -27,15 +29,14 @@ import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import { useAppContext } from '@/context/app-context' import { useExternalApiPanel } from '@/context/external-api-panel-context' -// eslint-disable-next-line import/order -import { useQuery } from '@tanstack/react-query' const Container = () => { const { t } = useTranslation() const router = useRouter() - const { currentWorkspace } = useAppContext() + const { currentWorkspace, isCurrentWorkspaceOwner } = useAppContext() const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const { showExternalApiPanel, setShowExternalApiPanel } = useExternalApiPanel() + const [includeAll, { toggle: toggleIncludeAll }] = useBoolean(false) const options = useMemo(() => { return [ @@ -81,17 +82,32 @@ const Container = () => { }, [currentWorkspace, router]) return ( -
-
+
+
setActiveTab(newActiveTab)} options={options} /> {activeTab === 'dataset' && ( -
+
+ {isCurrentWorkspaceOwner && } - + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
{activeTab === 'dataset' && ( <> - + {showTagManagementModal && ( diff --git a/web/app/(commonLayout)/datasets/DatasetCard.tsx b/web/app/(commonLayout)/datasets/DatasetCard.tsx index e8ccddbcb..ad83a41df 100644 --- a/web/app/(commonLayout)/datasets/DatasetCard.tsx +++ b/web/app/(commonLayout)/datasets/DatasetCard.tsx @@ -111,7 +111,7 @@ const DatasetCard = ({ return ( <>
{ e.preventDefault() @@ -129,8 +129,8 @@ const DatasetCard = ({
-
-
{dataset.name}
+
+
{dataset.name}
{!dataset.embedding_available && ( )}
-
+
{ if (!pageIndex || previousPageData.has_more) { - const params: any = { + const params: FetchDatasetsParams = { url: 'datasets', params: { page: pageIndex + 1, limit: 30, + include_all: includeAll, }, } if (tags.length) @@ -37,16 +39,18 @@ type Props = { containerRef: React.RefObject tags: string[] keywords: string + includeAll: boolean } const Datasets = ({ containerRef, tags, keywords, + includeAll, }: Props) => { const { isCurrentWorkspaceEditor } = useAppContext() const { data, isLoading, setSize, mutate } = useSWRInfinite( - (pageIndex: number, previousPageData: DataSetListResponse) => getKey(pageIndex, previousPageData, tags, keywords), + (pageIndex: number, previousPageData: DataSetListResponse) => getKey(pageIndex, previousPageData, tags, keywords, includeAll), fetchDatasets, { revalidateFirstPage: false, revalidateAll: true }, ) diff --git a/web/app/(commonLayout)/datasets/Doc.tsx b/web/app/(commonLayout)/datasets/Doc.tsx index 553dca500..f7a7e8b06 100644 --- a/web/app/(commonLayout)/datasets/Doc.tsx +++ b/web/app/(commonLayout)/datasets/Doc.tsx @@ -1,7 +1,9 @@ 'use client' -import { type FC, useEffect } from 'react' +import { useEffect, useState } from 'react' import { useContext } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import { RiListUnordered } from '@remixicon/react' import TemplateEn from './template/template.en.mdx' import TemplateZh from './template/template.zh.mdx' import I18n from '@/context/i18n' @@ -10,25 +12,106 @@ import { LanguagesSupported } from '@/i18n/language' type DocProps = { apiBaseUrl: string } -const Doc: FC = ({ - apiBaseUrl, -}) => { + +const Doc = ({ apiBaseUrl }: DocProps) => { const { locale } = useContext(I18n) + const { t } = useTranslation() + const [toc, setToc] = useState>([]) + const [isTocExpanded, setIsTocExpanded] = useState(false) + // Set initial TOC expanded state based on screen width useEffect(() => { - const hash = location.hash - if (hash) - document.querySelector(hash)?.scrollIntoView() + const mediaQuery = window.matchMedia('(min-width: 1280px)') + setIsTocExpanded(mediaQuery.matches) }, []) + // Extract TOC from article content + useEffect(() => { + const extractTOC = () => { + const article = document.querySelector('article') + if (article) { + const headings = article.querySelectorAll('h2') + const tocItems = Array.from(headings).map((heading) => { + const anchor = heading.querySelector('a') + if (anchor) { + return { + href: anchor.getAttribute('href') || '', + text: anchor.textContent || '', + } + } + return null + }).filter((item): item is { href: string; text: string } => item !== null) + setToc(tocItems) + } + } + + setTimeout(extractTOC, 0) + }, [locale]) + + // Handle TOC item click + const handleTocClick = (e: React.MouseEvent, item: { href: string; text: string }) => { + e.preventDefault() + const targetId = item.href.replace('#', '') + const element = document.getElementById(targetId) + if (element) { + const scrollContainer = document.querySelector('.scroll-container') + if (scrollContainer) { + const headerOffset = -40 + const elementTop = element.offsetTop - headerOffset + scrollContainer.scrollTo({ + top: elementTop, + behavior: 'smooth', + }) + } + } + } + return ( -
- { - locale !== LanguagesSupported[1] +
+
+ {isTocExpanded + ? ( + + ) + : ( + + )} +
+
+ {locale !== LanguagesSupported[1] ? : - } -
+ } +
+
) } diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index d3dcfc4b2..3fa22a162 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -1,5 +1,5 @@ import { CodeGroup } from '@/app/components/develop/code.tsx' -import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '@/app/components/develop/md.tsx' +import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstruction, Paragraph } from '@/app/components/develop/md.tsx' # Knowledge API @@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - high_quality High quality: embedding using embedding model, built as vector database index - economy Economy: Build using inverted index of keyword table index + + Format of indexed content + - text_model Text documents are directly embedded; `economy` mode defaults to using this form + - hierarchical_model Parent-child mode + - qa_model Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions + + + In Q&A mode, specify the language of the document, for example: English, Chinese + Processing rules - mode (string) Cleaning, segmentation mode, automatic / custom @@ -65,6 +74,32 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 + - parent_mode Retrieval mode of parent chunks: full-doc full text retrieval / paragraph paragraph retrieval + - subchunk_segmentation (object) Child chunk rules + - separator Segmentation identifier. Currently, only one delimiter is allowed. The default is *** + - max_tokens The maximum length (tokens) must be validated to be shorter than the length of the parent chunk + - chunk_overlap Define the overlap between adjacent chunks (optional) + + When no parameters are set for the knowledge base, the first upload requires the following parameters to be provided; if not provided, the default parameters will be used. 
+ + Retrieval model + - search_method (string) Search method + - hybrid_search Hybrid search + - semantic_search Semantic search + - full_text_search Full-text search + - reranking_enable (bool) Whether to enable reranking + - reranking_mode (object) Rerank model configuration + - reranking_provider_name (string) Rerank model provider + - reranking_model_name (string) Rerank model name + - top_k (int) Number of results to return + - score_threshold_enabled (bool) Whether to enable score threshold + - score_threshold (float) Score threshold + + + Embedding model name + + + Embedding model provider @@ -155,6 +190,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - high_quality High quality: embedding using embedding model, built as vector database index - economy Economy: Build using inverted index of keyword table index + - doc_form Format of indexed content + - text_model Text documents are directly embedded; `economy` mode defaults to using this form + - hierarchical_model Parent-child mode + - qa_model Q&A Mode: Generates Q&A pairs for segmented documents and then embeds the questions + + - doc_language In Q&A mode, specify the language of the document, for example: English, Chinese + - process_rule Processing rules - mode (string) Cleaning, segmentation mode, automatic / custom - rules (object) Custom rules (in automatic mode, this field is empty) @@ -167,10 +209,36 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 + - parent_mode Retrieval mode of parent chunks: full-doc full text retrieval / paragraph paragraph retrieval + - subchunk_segmentation (object) Child chunk rules + - separator Segmentation identifier. Currently, only one delimiter is allowed. The default is *** + - max_tokens The maximum length (tokens) must be validated to be shorter than the length of the parent chunk + - chunk_overlap Define the overlap between adjacent chunks (optional) Files that need to be uploaded. + When no parameters are set for the knowledge base, the first upload requires the following parameters to be provided; if not provided, the default parameters will be used. + + Retrieval model + - search_method (string) Search method + - hybrid_search Hybrid search + - semantic_search Semantic search + - full_text_search Full-text search + - reranking_enable (bool) Whether to enable reranking + - reranking_mode (object) Rerank model configuration + - reranking_provider_name (string) Rerank model provider + - reranking_model_name (string) Rerank model name + - top_k (int) Number of results to return + - score_threshold_enabled (bool) Whether to enable score threshold + - score_threshold (float) Score threshold + + + Embedding model name + + + Embedding model provider + @@ -449,6 +517,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 + - parent_mode Retrieval mode of parent chunks: full-doc full text retrieval / paragraph paragraph retrieval + - subchunk_segmentation (object) Child chunk rules + - separator Segmentation identifier. Currently, only one delimiter is allowed. 
The default is *** + - max_tokens The maximum length (tokens) must be validated to be shorter than the length of the parent chunk + - chunk_overlap Define the overlap between adjacent chunks (optional) @@ -546,6 +619,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 + - parent_mode Retrieval mode of parent chunks: full-doc full text retrieval / paragraph paragraph retrieval + - subchunk_segmentation (object) Child chunk rules + - separator Segmentation identifier. Currently, only one delimiter is allowed. The default is *** + - max_tokens The maximum length (tokens) must be validated to be shorter than the length of the parent chunk + - chunk_overlap Define the overlap between adjacent chunks (optional) @@ -984,7 +1062,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from @@ -1009,6 +1087,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - answer (text) Answer content, passed if the knowledge is in Q&A mode (optional) - keywords (list) Keyword (optional) - enabled (bool) False / true (optional) + - regenerate_child_chunks (bool) Whether to regenerate child chunks (optional) @@ -1069,6 +1148,57 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
+ + + + ### Path + + + Knowledge ID + + + Document ID + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' + ``` + + + ```json {{ title: 'Response' }} + { + "id": "file_id", + "name": "file_name", + "size": 1024, + "extension": "txt", + "url": "preview_url", + "download_url": "download_url", + "mime_type": "text/plain", + "created_by": "user_id", + "created_at": 1728734540, + } + ``` + + + + +
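To complement the cURL example above, here is a minimal TypeScript sketch of a client call to the new upload-file endpoint. The helper name, base URL, and IDs are placeholders; the request headers and the response shape mirror the example shown in the hunk above.

```ts
// Hypothetical client helper for the new GET upload-file endpoint.
// apiBaseUrl, apiKey, datasetId and documentId are placeholders.
type UploadFileInfo = {
  id: string
  name: string
  size: number
  extension: string
  url: string
  download_url: string
  mime_type: string
  created_by: string
  created_at: number
}

async function fetchUploadFileInfo(
  apiBaseUrl: string,
  apiKey: string,
  datasetId: string,
  documentId: string,
): Promise<UploadFileInfo> {
  const res = await fetch(
    `${apiBaseUrl}/datasets/${datasetId}/documents/${documentId}/upload-file`,
    {
      method: 'GET',
      headers: {
        'Authorization': `Bearer ${apiKey}`,
        'Content-Type': 'application/json',
      },
    },
  )
  if (!res.ok)
    throw new Error(`Request failed with status ${res.status}`)
  return res.json() as Promise<UploadFileInfo>
}

// Usage (values are placeholders):
// const info = await fetchUploadFileInfo('https://api.dify.ai/v1', '{api_key}', '{dataset_id}', '{document_id}')
// console.log(info.download_url)
```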
+ reranking_mode (object) Rerank model configuration, required if reranking is enabled - reranking_provider_name (string) Rerank model provider - reranking_model_name (string) Rerank model name - - weights (double) Semantic search weight setting in hybrid search mode + - weights (float) Semantic search weight setting in hybrid search mode - top_k (integer) Number of results to return (optional) - score_threshold_enabled (bool) Whether to enable score threshold - - score_threshold (double) Score threshold + - score_threshold (float) Score threshold Unused field diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index db15ede9f..334591743 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -1,5 +1,5 @@ import { CodeGroup } from '@/app/components/develop/code.tsx' -import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '@/app/components/develop/md.tsx' +import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstruction, Paragraph } from '@/app/components/develop/md.tsx' # 知识库 API @@ -52,6 +52,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - economy 经济:使用 keyword table index 的倒排索引进行构建 + + 索引内容的形式 + - text_model text 文档直接 embedding,经济模式默认为该模式 + - hierarchical_model parent-child 模式 + - qa_model Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding + + + 在 Q&A 模式下,指定文档的语言,例如:EnglishChinese + 处理规则 - mode (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 @@ -63,8 +72,34 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_urls_emails 删除 URL、电子邮件地址 - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n + - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - max_tokens 最大长度(token)默认为 1000 + - parent_mode 父分段的召回模式 full-doc 全文召回 / paragraph 段落召回 + - subchunk_segmentation (object) 子分段规则 + - separator 分段标识符,目前仅允许设置一个分隔符。默认为 *** + - max_tokens 最大长度 (token) 需要校验小于父级的长度 + - chunk_overlap 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) + + 当知识库未设置任何参数的时候,首次上传需要提供以下参数,未提供则使用默认选项: + + 检索模式 + - search_method (string) 检索方法 + - hybrid_search 混合检索 + - semantic_search 语义检索 + - full_text_search 全文检索 + - reranking_enable (bool) 是否开启rerank + - reranking_model (object) Rerank 模型配置 + - reranking_provider_name (string) Rerank 模型的提供商 + - reranking_model_name (string) Rerank 模型的名称 + - top_k (int) 召回条数 + - score_threshold_enabled (bool)是否开启召回分数限制 + - score_threshold (float) 召回分数限制 + + + Embedding 模型名称 + + + Embedding 模型供应商 @@ -155,6 +190,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - economy 经济:使用 keyword table index 的倒排索引进行构建 + - doc_form 索引内容的形式 + - text_model text 文档直接 embedding,经济模式默认为该模式 + - hierarchical_model parent-child 模式 + - qa_model Q&A 模式:为分片文档生成 Q&A 对,然后对问题进行 embedding + + - doc_language 在 Q&A 模式下,指定文档的语言,例如:EnglishChinese + - process_rule 处理规则 - mode (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 - rules (object) 自定义规则(自动模式下,该字段为空) @@ -167,10 +209,36 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - max_tokens 最大长度(token)默认为 1000 + - parent_mode 父分段的召回模式 full-doc 全文召回 / paragraph 段落召回 + - subchunk_segmentation (object) 子分段规则 + - separator 分段标识符,目前仅允许设置一个分隔符。默认为 *** + - max_tokens 
最大长度 (token) 需要校验小于父级的长度 + - chunk_overlap 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) 需要上传的文件。 + 当知识库未设置任何参数的时候,首次上传需要提供以下参数,未提供则使用默认选项: + + 检索模式 + - search_method (string) 检索方法 + - hybrid_search 混合检索 + - semantic_search 语义检索 + - full_text_search 全文检索 + - reranking_enable (bool) 是否开启rerank + - reranking_model (object) Rerank 模型配置 + - reranking_provider_name (string) Rerank 模型的提供商 + - reranking_model_name (string) Rerank 模型的名称 + - top_k (int) 召回条数 + - score_threshold_enabled (bool)是否开启召回分数限制 + - score_threshold (float) 召回分数限制 + + + Embedding 模型名称 + + + Embedding 模型供应商 + @@ -411,7 +479,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from @@ -449,6 +517,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - max_tokens 最大长度(token)默认为 1000 + - parent_mode 父分段的召回模式 full-doc 全文召回 / paragraph 段落召回 + - subchunk_segmentation (object) 子分段规则 + - separator 分段标识符,目前仅允许设置一个分隔符。默认为 *** + - max_tokens 最大长度 (token) 需要校验小于父级的长度 + - chunk_overlap 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) @@ -508,7 +581,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from @@ -546,6 +619,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - max_tokens 最大长度(token)默认为 1000 + - parent_mode 父分段的召回模式 full-doc 全文召回 / paragraph 段落召回 + - subchunk_segmentation (object) 子分段规则 + - separator 分段标识符,目前仅允许设置一个分隔符。默认为 *** + - max_tokens 最大长度 (token) 需要校验小于父级的长度 + - chunk_overlap 分段重叠指的是在对数据进行分段时,段与段之间存在一定的重叠部分(选填) @@ -1009,6 +1087,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - answer (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - keywords (list) 关键字,非必填 - enabled (bool) false/true,非必填 + - regenerate_child_chunks (bool) 是否重新生成子分段,非必填 @@ -1070,6 +1149,57 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
+ + + + ### Path + + + 知识库 ID + + + 文档 ID + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/upload-file' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' + ``` + + + ```json {{ title: 'Response' }} + { + "id": "file_id", + "name": "file_name", + "size": 1024, + "extension": "txt", + "url": "preview_url", + "download_url": "download_url", + "mime_type": "text/plain", + "created_by": "user_id", + "created_at": 1728734540, + } + ``` + + + + +
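Because the parent-child chunking and retrieval parameters documented above are easy to misread in flattened form, a rough TypeScript sketch of a document-creation payload follows. It is illustrative only: the values are examples, the nesting of parent_mode and subchunk_segmentation is assumed from the parameter lists, and the reranking object is named reranking_mode in the English list but reranking_model in the Chinese one, so adjust to whichever the API actually accepts.

```ts
// Illustrative payload combining the newly documented fields
// (doc_form, parent-child process_rule, retrieval_model, embedding model).
// Values and nesting are assumptions drawn from the parameter lists above.
const createDocumentPayload = {
  name: 'example.txt',
  text: 'Document content to be indexed...',
  indexing_technique: 'high_quality',
  doc_form: 'hierarchical_model', // text_model | hierarchical_model | qa_model
  doc_language: 'English', // only meaningful in Q&A mode
  process_rule: {
    mode: 'custom',
    rules: {
      pre_processing_rules: [
        { id: 'remove_extra_spaces', enabled: true },
        { id: 'remove_urls_emails', enabled: false },
      ],
      segmentation: {
        separator: '\n',
        max_tokens: 1000,
      },
      parent_mode: 'paragraph', // full-doc | paragraph
      subchunk_segmentation: {
        separator: '***',
        max_tokens: 256, // must stay below the parent chunk length
        chunk_overlap: 20,
      },
    },
  },
  retrieval_model: {
    search_method: 'hybrid_search', // semantic_search | full_text_search | hybrid_search
    reranking_enable: true,
    reranking_model: { // spelled reranking_mode in the English list above
      reranking_provider_name: 'your-rerank-provider',
      reranking_model_name: 'your-rerank-model',
    },
    top_k: 3,
    score_threshold_enabled: false,
    score_threshold: 0.5,
  },
  embedding_model: 'your-embedding-model-name',
  embedding_model_provider: 'your-embedding-provider',
}

export default createDocumentPayload
```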
+ full_text_search 全文检索 - hybrid_search 混合检索 - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为 semantic_search 模式或者 hybrid_search 则传值 - - reranking_mode (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 + - reranking_mode (object) Rerank 模型配置,非必填,如果启用了 reranking 则传值 - reranking_provider_name (string) Rerank 模型提供商 - reranking_model_name (string) Rerank 模型名称 - - weights (double) 混合检索模式下语意检索的权重设置 + - weights (float) 混合检索模式下语意检索的权重设置 - top_k (integer) 返回结果数量,非必填 - score_threshold_enabled (bool) 是否开启 score 阈值 - - score_threshold (double) Score 阈值 + - score_threshold (float) Score 阈值 未启用字段 diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index f0f7e0321..af36d4d96 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -8,27 +8,24 @@ import Header from '@/app/components/header' import { EventEmitterContextProvider } from '@/context/event-emitter' import { ProviderContextProvider } from '@/context/provider-context' import { ModalContextProvider } from '@/context/modal-context' -import { TanstackQueryIniter } from '@/context/query-client' const Layout = ({ children }: { children: ReactNode }) => { return ( <> - - - - - - -
- - {children} - - - - - + + + + + +
+ + {children} + + + + ) diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx index c7af05793..443501956 100644 --- a/web/app/account/account-page/index.tsx +++ b/web/app/account/account-page/index.tsx @@ -3,11 +3,11 @@ import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' +import DeleteAccount from '../delete-account' import s from './index.module.css' import Collapse from '@/app/components/header/account-setting/collapse' import type { IItem } from '@/app/components/header/account-setting/collapse' import Modal from '@/app/components/base/modal' -import Confirm from '@/app/components/base/confirm' import Button from '@/app/components/base/button' import { updateUserProfile } from '@/service/common' import { useAppContext } from '@/context/app-context' @@ -296,37 +296,9 @@ export default function AccountPage() { } { showDeleteAccountModal && ( - setShowDeleteAccountModal(false)} onConfirm={() => setShowDeleteAccountModal(false)} - showCancel={false} - type='warning' - title={t('common.account.delete')} - content={ - <> -
- {t('common.account.deleteTip')} -
- -
{`${t('common.account.delete')}: ${userProfile.email}`}
- - } - confirmText={t('common.operation.ok') as string} /> ) } diff --git a/web/app/account/delete-account/components/check-email.tsx b/web/app/account/delete-account/components/check-email.tsx new file mode 100644 index 000000000..84ea8a4c2 --- /dev/null +++ b/web/app/account/delete-account/components/check-email.tsx @@ -0,0 +1,48 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { useCallback, useState } from 'react' +import Link from 'next/link' +import { useSendDeleteAccountEmail } from '../state' +import { useAppContext } from '@/context/app-context' +import Input from '@/app/components/base/input' +import Button from '@/app/components/base/button' + +type DeleteAccountProps = { + onCancel: () => void + onConfirm: () => void +} + +export default function CheckEmail(props: DeleteAccountProps) { + const { t } = useTranslation() + const { userProfile } = useAppContext() + const [userInputEmail, setUserInputEmail] = useState('') + + const { isPending: isSendingEmail, mutateAsync: getDeleteEmailVerifyCode } = useSendDeleteAccountEmail() + + const handleConfirm = useCallback(async () => { + try { + const ret = await getDeleteEmailVerifyCode() + if (ret.result === 'success') + props.onConfirm() + } + catch (error) { console.error(error) } + }, [getDeleteEmailVerifyCode, props]) + + return <> +
+ {t('common.account.deleteTip')} +
+
+ {t('common.account.deletePrivacyLinkTip')} + {t('common.account.deletePrivacyLink')} +
+ + { + setUserInputEmail(e.target.value) + }} /> +
+ + +
+ +} diff --git a/web/app/account/delete-account/components/feed-back.tsx b/web/app/account/delete-account/components/feed-back.tsx new file mode 100644 index 000000000..1d01c69d9 --- /dev/null +++ b/web/app/account/delete-account/components/feed-back.tsx @@ -0,0 +1,68 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { useCallback, useState } from 'react' +import { useRouter } from 'next/navigation' +import { useDeleteAccountFeedback } from '../state' +import { useAppContext } from '@/context/app-context' +import Button from '@/app/components/base/button' +import CustomDialog from '@/app/components/base/dialog' +import Textarea from '@/app/components/base/textarea' +import Toast from '@/app/components/base/toast' +import { logout } from '@/service/common' + +type DeleteAccountProps = { + onCancel: () => void + onConfirm: () => void +} + +export default function FeedBack(props: DeleteAccountProps) { + const { t } = useTranslation() + const { userProfile } = useAppContext() + const router = useRouter() + const [userFeedback, setUserFeedback] = useState('') + const { isPending, mutateAsync: sendFeedback } = useDeleteAccountFeedback() + + const handleSuccess = useCallback(async () => { + try { + await logout({ + url: '/logout', + params: {}, + }) + localStorage.removeItem('refresh_token') + localStorage.removeItem('console_token') + router.push('/signin') + Toast.notify({ type: 'info', message: t('common.account.deleteSuccessTip') }) + } + catch (error) { console.error(error) } + }, [router, t]) + + const handleSubmit = useCallback(async () => { + try { + await sendFeedback({ feedback: userFeedback, email: userProfile.email }) + props.onConfirm() + await handleSuccess() + } + catch (error) { console.error(error) } + }, [handleSuccess, userFeedback, sendFeedback, userProfile, props]) + + const handleSkip = useCallback(() => { + props.onCancel() + handleSuccess() + }, [handleSuccess, props]) + return + +