diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index bf9b093cb3f02b..fc71d64714bd96 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,7 +1,8 @@ from .__version__ import __version__ from ._client import ZhipuAI -from .core._errors import ( +from .core import ( APIAuthenticationError, + APIConnectionError, APIInternalError, APIReachLimitError, APIRequestFailedError, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py index 659f38d7ff32d2..51f8c49ecb827d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py @@ -1 +1 @@ -__version__ = "v2.0.1" +__version__ = "v2.1.0" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index df9e506095fab9..705d371e628f08 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -9,15 +9,13 @@ from typing_extensions import override from . import api_resource -from .core import _jwt_token -from .core._base_type import NOT_GIVEN, NotGiven -from .core._errors import ZhipuAIError -from .core._http_client import ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient +from .core import NOT_GIVEN, ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient, NotGiven, ZhipuAIError, _jwt_token class ZhipuAI(HttpClient): - chat: api_resource.chat + chat: api_resource.chat.Chat api_key: str + _disable_token_cache: bool = True def __init__( self, @@ -28,10 +26,15 @@ def __init__( max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, http_client: httpx.Client | None = None, custom_headers: Mapping[str, str] | None = None, + disable_token_cache: bool = True, + _strict_response_validation: bool = False, ) -> None: if api_key is None: - raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables") + api_key = os.environ.get("ZHIPUAI_API_KEY") + if api_key is None: + raise ZhipuAIError("未提供api_key,请通过参数或环境变量提供") self.api_key = api_key + self._disable_token_cache = disable_token_cache if base_url is None: base_url = os.environ.get("ZHIPUAI_BASE_URL") @@ -42,21 +45,31 @@ def __init__( super().__init__( version=__version__, base_url=base_url, + max_retries=max_retries, timeout=timeout, custom_httpx_client=http_client, custom_headers=custom_headers, + _strict_response_validation=_strict_response_validation, ) self.chat = api_resource.chat.Chat(self) self.images = api_resource.images.Images(self) self.embeddings = api_resource.embeddings.Embeddings(self) self.files = api_resource.files.Files(self) self.fine_tuning = api_resource.fine_tuning.FineTuning(self) + self.batches = api_resource.Batches(self) + self.knowledge = api_resource.Knowledge(self) + self.tools = api_resource.Tools(self) + self.videos = api_resource.Videos(self) + self.assistant = api_resource.Assistant(self) @property @override - def _auth_headers(self) -> dict[str, str]: + def auth_headers(self) -> dict[str, str]: api_key = self.api_key - return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} + if self._disable_token_cache: + return {"Authorization": f"Bearer {api_key}"} + else: + return {"Authorization": f"Bearer {_jwt_token.generate_token(api_key)}"} def __del__(self) -> None: if not hasattr(self, "_has_custom_http_client") or not hasattr(self, "close") or not hasattr(self, "_client"): diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py index 0a90e21e48bcca..4fe0719dde3e0b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py @@ -1,5 +1,34 @@ -from .chat import chat +from .assistant import ( + Assistant, +) +from .batches import Batches +from .chat import ( + AsyncCompletions, + Chat, + Completions, +) from .embeddings import Embeddings -from .files import Files -from .fine_tuning import fine_tuning +from .files import Files, FilesWithRawResponse +from .fine_tuning import FineTuning from .images import Images +from .knowledge import Knowledge +from .tools import Tools +from .videos import ( + Videos, +) + +__all__ = [ + "Videos", + "AsyncCompletions", + "Chat", + "Completions", + "Images", + "Embeddings", + "Files", + "FilesWithRawResponse", + "FineTuning", + "Batches", + "Knowledge", + "Tools", + "Assistant", +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py new file mode 100644 index 00000000000000..ce619aa7f09222 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/__init__.py @@ -0,0 +1,3 @@ +from .assistant import Assistant + +__all__ = ["Assistant"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py new file mode 100644 index 00000000000000..f772340a82c4be --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/assistant/assistant.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import httpx + +from ...core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + StreamResponse, + deepcopy_minimal, + make_request_options, + maybe_transform, +) +from ...types.assistant import AssistantCompletion +from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp +from ...types.assistant.assistant_support_resp import AssistantSupportResp + +if TYPE_CHECKING: + from ..._client import ZhipuAI + +from ...types.assistant import assistant_conversation_params, assistant_create_params + +__all__ = ["Assistant"] + + +class Assistant(BaseAPI): + def __init__(self, client: ZhipuAI) -> None: + super().__init__(client) + + def conversation( + self, + assistant_id: str, + model: str, + messages: list[assistant_create_params.ConversationMessage], + *, + stream: bool = True, + conversation_id: Optional[str] = None, + attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None, + metadata: dict | None = None, + request_id: str = None, + user_id: str = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> StreamResponse[AssistantCompletion]: + body = deepcopy_minimal( + { + "assistant_id": assistant_id, + "model": model, + "messages": messages, + "stream": stream, + "conversation_id": conversation_id, + "attachments": attachments, + "metadata": metadata, + "request_id": request_id, + "user_id": user_id, + } + ) + return self._post( + "/assistant", + body=maybe_transform(body, assistant_create_params.AssistantParameters), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=AssistantCompletion, + stream=stream or True, + stream_cls=StreamResponse[AssistantCompletion], + ) + + def query_support( + self, + *, + assistant_id_list: list[str] = None, + request_id: str = None, + user_id: str = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AssistantSupportResp: + body = deepcopy_minimal( + { + "assistant_id_list": assistant_id_list, + "request_id": request_id, + "user_id": user_id, + } + ) + return self._post( + "/assistant/list", + body=body, + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=AssistantSupportResp, + ) + + def query_conversation_usage( + self, + assistant_id: str, + page: int = 1, + page_size: int = 10, + *, + request_id: str = None, + user_id: str = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ConversationUsageListResp: + body = deepcopy_minimal( + { + "assistant_id": assistant_id, + "page": page, + "page_size": page_size, + "request_id": request_id, + "user_id": user_id, + } + ) + return self._post( + "/assistant/conversation/list", + body=maybe_transform(body, assistant_conversation_params.ConversationParameters), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=ConversationUsageListResp, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py new file mode 100644 index 00000000000000..ae2f2be85eb9b4 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/batches.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Optional + +import httpx + +from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform +from ..core.pagination import SyncCursorPage +from ..types import batch_create_params, batch_list_params +from ..types.batch import Batch + +if TYPE_CHECKING: + from .._client import ZhipuAI + + +class Batches(BaseAPI): + def __init__(self, client: ZhipuAI) -> None: + super().__init__(client) + + def create( + self, + *, + completion_window: str | None = None, + endpoint: Literal["/v1/chat/completions", "/v1/embeddings"], + input_file_id: str, + metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, + auto_delete_input_file: bool = True, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Batch: + return self._post( + "/batches", + body=maybe_transform( + { + "completion_window": completion_window, + "endpoint": endpoint, + "input_file_id": input_file_id, + "metadata": metadata, + "auto_delete_input_file": auto_delete_input_file, + }, + batch_create_params.BatchCreateParams, + ), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=Batch, + ) + + def retrieve( + self, + batch_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Batch: + """ + Retrieves a batch. + + Args: + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not batch_id: + raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}") + return self._get( + f"/batches/{batch_id}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=Batch, + ) + + def list( + self, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> SyncCursorPage[Batch]: + """List your organization's batches. + + Args: + after: A cursor for use in pagination. + + `after` is an object ID that defines your place + in the list. For instance, if you make a list request and receive 100 objects, + ending with obj_foo, your subsequent call can include after=obj_foo in order to + fetch the next page of the list. + + limit: A limit on the number of objects to be returned. Limit can range between 1 and + 100, and the default is 20. + + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get_api_list( + "/batches", + page=SyncCursorPage[Batch], + options=make_request_options( + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "after": after, + "limit": limit, + }, + batch_list_params.BatchListParams, + ), + ), + model=Batch, + ) + + def cancel( + self, + batch_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Batch: + """ + Cancels an in-progress batch. + + Args: + batch_id: The ID of the batch to cancel. + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + + """ + if not batch_id: + raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}") + return self._post( + f"/batches/{batch_id}/cancel", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=Batch, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py index e69de29bb2d1d6..5cd8dc6f339a60 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py @@ -0,0 +1,5 @@ +from .async_completions import AsyncCompletions +from .chat import Chat +from .completions import Completions + +__all__ = ["AsyncCompletions", "Chat", "Completions"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index 1f8011973951b3..d8ecc310644d17 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -1,13 +1,25 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from ...core._base_api import BaseAPI -from ...core._base_type import NOT_GIVEN, Headers, NotGiven -from ...core._http_client import make_user_request_input +from ...core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + drop_prefix_image_data, + make_request_options, + maybe_transform, +) from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus +from ...types.chat.code_geex import code_geex_params +from ...types.sensitive_word_check import SensitiveWordCheckRequest + +logger = logging.getLogger(__name__) if TYPE_CHECKING: from ..._client import ZhipuAI @@ -22,6 +34,7 @@ def create( *, model: str, request_id: Optional[str] | NotGiven = NOT_GIVEN, + user_id: Optional[str] | NotGiven = NOT_GIVEN, do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, top_p: Optional[float] | NotGiven = NOT_GIVEN, @@ -29,50 +42,74 @@ def create( seed: int | NotGiven = NOT_GIVEN, messages: Union[str, list[str], list[int], list[list[int]], None], stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = NOT_GIVEN, + meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, + extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> AsyncTaskStatus: _cast_type = AsyncTaskStatus + logger.debug(f"temperature:{temperature}, top_p:{top_p}") + if temperature is not None and temperature != NOT_GIVEN: + if temperature <= 0: + do_sample = False + temperature = 0.01 + # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperture不生效)") # noqa: E501 + if temperature >= 1: + temperature = 0.99 + # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间") + if top_p is not None and top_p != NOT_GIVEN: + if top_p >= 1: + top_p = 0.99 + # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") + if top_p <= 0: + top_p = 0.01 + # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") + + logger.debug(f"temperature:{temperature}, top_p:{top_p}") + if isinstance(messages, list): + for item in messages: + if item.get("content"): + item["content"] = drop_prefix_image_data(item["content"]) - if disable_strict_validation: - _cast_type = object + body = { + "model": model, + "request_id": request_id, + "user_id": user_id, + "temperature": temperature, + "top_p": top_p, + "do_sample": do_sample, + "max_tokens": max_tokens, + "seed": seed, + "messages": messages, + "stop": stop, + "sensitive_word_check": sensitive_word_check, + "tools": tools, + "tool_choice": tool_choice, + "meta": meta, + "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra), + } return self._post( "/async/chat/completions", - body={ - "model": model, - "request_id": request_id, - "temperature": temperature, - "top_p": top_p, - "do_sample": do_sample, - "max_tokens": max_tokens, - "seed": seed, - "messages": messages, - "stop": stop, - "sensitive_word_check": sensitive_word_check, - "tools": tools, - "tool_choice": tool_choice, - }, - options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), + body=body, + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=_cast_type, - enable_stream=False, + stream=False, ) def retrieve_completion_result( self, id: str, extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Union[AsyncCompletion, AsyncTaskStatus]: _cast_type = Union[AsyncCompletion, AsyncTaskStatus] - if disable_strict_validation: - _cast_type = object return self._get( path=f"/async-result/{id}", cast_type=_cast_type, - options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py index 92362fc50a7252..b3cc46566c7bf3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py @@ -1,17 +1,18 @@ from typing import TYPE_CHECKING -from ...core._base_api import BaseAPI +from ...core import BaseAPI, cached_property from .async_completions import AsyncCompletions from .completions import Completions if TYPE_CHECKING: - from ..._client import ZhipuAI + pass class Chat(BaseAPI): - completions: Completions + @cached_property + def completions(self) -> Completions: + return Completions(self._client) - def __init__(self, client: "ZhipuAI") -> None: - super().__init__(client) - self.completions = Completions(client) - self.asyncCompletions = AsyncCompletions(client) + @cached_property + def asyncCompletions(self) -> AsyncCompletions: # noqa: N802 + return AsyncCompletions(self._client) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index ec29f33864203c..1c23473a03ae32 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -1,15 +1,28 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from ...core._base_api import BaseAPI -from ...core._base_type import NOT_GIVEN, Headers, NotGiven -from ...core._http_client import make_user_request_input -from ...core._sse_client import StreamResponse +from ...core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + StreamResponse, + deepcopy_minimal, + drop_prefix_image_data, + make_request_options, + maybe_transform, +) from ...types.chat.chat_completion import Completion from ...types.chat.chat_completion_chunk import ChatCompletionChunk +from ...types.chat.code_geex import code_geex_params +from ...types.sensitive_word_check import SensitiveWordCheckRequest + +logger = logging.getLogger(__name__) if TYPE_CHECKING: from ..._client import ZhipuAI @@ -24,6 +37,7 @@ def create( *, model: str, request_id: Optional[str] | NotGiven = NOT_GIVEN, + user_id: Optional[str] | NotGiven = NOT_GIVEN, do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, @@ -32,23 +46,43 @@ def create( seed: int | NotGiven = NOT_GIVEN, messages: Union[str, list[str], list[int], object, None], stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = NOT_GIVEN, + meta: Optional[dict[str, str]] | NotGiven = NOT_GIVEN, + extra: Optional[code_geex_params.CodeGeexExtra] | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Completion | StreamResponse[ChatCompletionChunk]: - _cast_type = Completion - _stream_cls = StreamResponse[ChatCompletionChunk] - if disable_strict_validation: - _cast_type = object - _stream_cls = StreamResponse[object] - return self._post( - "/chat/completions", - body={ + logger.debug(f"temperature:{temperature}, top_p:{top_p}") + if temperature is not None and temperature != NOT_GIVEN: + if temperature <= 0: + do_sample = False + temperature = 0.01 + # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间,do_sample重写为:false(参数top_p temperture不生效)") # noqa: E501 + if temperature >= 1: + temperature = 0.99 + # logger.warning("temperature:取值范围是:(0.0, 1.0) 开区间") + if top_p is not None and top_p != NOT_GIVEN: + if top_p >= 1: + top_p = 0.99 + # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") + if top_p <= 0: + top_p = 0.01 + # logger.warning("top_p:取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1") + + logger.debug(f"temperature:{temperature}, top_p:{top_p}") + if isinstance(messages, list): + for item in messages: + if item.get("content"): + item["content"] = drop_prefix_image_data(item["content"]) + + body = deepcopy_minimal( + { "model": model, "request_id": request_id, + "user_id": user_id, "temperature": temperature, "top_p": top_p, "do_sample": do_sample, @@ -60,11 +94,15 @@ def create( "stream": stream, "tools": tools, "tool_choice": tool_choice, - }, - options=make_user_request_input( - extra_headers=extra_headers, - ), - cast_type=_cast_type, - enable_stream=stream or False, - stream_cls=_stream_cls, + "meta": meta, + "extra": maybe_transform(extra, code_geex_params.CodeGeexExtra), + } + ) + return self._post( + "/chat/completions", + body=body, + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=Completion, + stream=stream or False, + stream_cls=StreamResponse[ChatCompletionChunk], ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index 2308a204514e17..4b4baef9421ba6 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -4,9 +4,7 @@ import httpx -from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, Headers, NotGiven -from ..core._http_client import make_user_request_input +from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options from ..types.embeddings import EmbeddingsResponded if TYPE_CHECKING: @@ -22,10 +20,13 @@ def create( *, input: Union[str, list[str], list[int], list[list[int]]], model: Union[str], + dimensions: Union[int] | NotGiven = NOT_GIVEN, encoding_format: str | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, + extra_body: Body | None = None, disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> EmbeddingsResponded: @@ -37,11 +38,13 @@ def create( body={ "input": input, "model": model, + "dimensions": dimensions, "encoding_format": encoding_format, "user": user, + "request_id": request_id, "sensitive_word_check": sensitive_word_check, }, - options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=_cast_type, - enable_stream=False, + stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index f2ac74bffa8439..ba9de75b7ef092 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -1,19 +1,30 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Mapping +from typing import TYPE_CHECKING, Literal, cast import httpx -from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, FileTypes, Headers, NotGiven -from ..core._files import is_file_content -from ..core._http_client import make_user_request_input -from ..types.file_object import FileObject, ListOfFileObject +from ..core import ( + NOT_GIVEN, + BaseAPI, + Body, + FileTypes, + Headers, + NotGiven, + _legacy_binary_response, + _legacy_response, + deepcopy_minimal, + extract_files, + make_request_options, + maybe_transform, +) +from ..types.files import FileDeleted, FileObject, ListOfFileObject, UploadDetail, file_create_params if TYPE_CHECKING: from .._client import ZhipuAI -__all__ = ["Files"] +__all__ = ["Files", "FilesWithRawResponse"] class Files(BaseAPI): @@ -23,30 +34,69 @@ def __init__(self, client: ZhipuAI) -> None: def create( self, *, - file: FileTypes, - purpose: str, + file: FileTypes = None, + upload_detail: list[UploadDetail] = None, + purpose: Literal["fine-tune", "retrieval", "batch"], + knowledge_id: str = None, + sentence_size: int = None, extra_headers: Headers | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FileObject: - if not is_file_content(file): - prefix = f"Expected file input `{file!r}`" - raise RuntimeError( - f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead." - ) from None - files = [("file", file)] - - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - + if not file and not upload_detail: + raise ValueError("At least one of `file` and `upload_detail` must be provided.") + body = deepcopy_minimal( + { + "file": file, + "upload_detail": upload_detail, + "purpose": purpose, + "knowledge_id": knowledge_id, + "sentence_size": sentence_size, + } + ) + files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) + if files: + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return self._post( "/files", - body={ - "purpose": purpose, - }, + body=maybe_transform(body, file_create_params.FileCreateParams), files=files, - options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=FileObject, ) + # def retrieve( + # self, + # file_id: str, + # *, + # extra_headers: Headers | None = None, + # extra_body: Body | None = None, + # timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + # ) -> FileObject: + # """ + # Returns information about a specific file. + # + # Args: + # file_id: The ID of the file to retrieve information about + # extra_headers: Send extra headers + # + # extra_body: Add additional JSON properties to the request + # + # timeout: Override the client-level default timeout for this request, in seconds + # """ + # if not file_id: + # raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + # return self._get( + # f"/files/{file_id}", + # options=make_request_options( + # extra_headers=extra_headers, extra_body=extra_body, timeout=timeout + # ), + # cast_type=FileObject, + # ) + def list( self, *, @@ -55,13 +105,15 @@ def list( after: str | NotGiven = NOT_GIVEN, order: str | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFileObject: return self._get( "/files", cast_type=ListOfFileObject, - options=make_user_request_input( + options=make_request_options( extra_headers=extra_headers, + extra_body=extra_body, timeout=timeout, query={ "purpose": purpose, @@ -71,3 +123,72 @@ def list( }, ), ) + + def delete( + self, + file_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FileDeleted: + """ + Delete a file. + + Args: + file_id: The ID of the file to delete + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return self._delete( + f"/files/{file_id}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=FileDeleted, + ) + + def content( + self, + file_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> _legacy_response.HttpxBinaryResponseContent: + """ + Returns the contents of the specified file. + + Args: + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + extra_headers = {"Accept": "application/binary", **(extra_headers or {})} + return self._get( + f"/files/{file_id}/content", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=_legacy_binary_response.HttpxBinaryResponseContent, + ) + + +class FilesWithRawResponse: + def __init__(self, files: Files) -> None: + self._files = files + + self.create = _legacy_response.to_raw_response_wrapper( + files.create, + ) + self.list = _legacy_response.to_raw_response_wrapper( + files.list, + ) + self.content = _legacy_response.to_raw_response_wrapper( + files.content, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py index e69de29bb2d1d6..7c309b83416803 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py @@ -0,0 +1,5 @@ +from .fine_tuning import FineTuning +from .jobs import Jobs +from .models import FineTunedModels + +__all__ = ["Jobs", "FineTunedModels", "FineTuning"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py index dc30bd33edfbbc..8670f7de00df84 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -1,15 +1,18 @@ from typing import TYPE_CHECKING -from ...core._base_api import BaseAPI +from ...core import BaseAPI, cached_property from .jobs import Jobs +from .models import FineTunedModels if TYPE_CHECKING: - from ..._client import ZhipuAI + pass class FineTuning(BaseAPI): - jobs: Jobs + @cached_property + def jobs(self) -> Jobs: + return Jobs(self._client) - def __init__(self, client: "ZhipuAI") -> None: - super().__init__(client) - self.jobs = Jobs(client) + @cached_property + def models(self) -> FineTunedModels: + return FineTunedModels(self._client) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py new file mode 100644 index 00000000000000..40777a153f272a --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/__init__.py @@ -0,0 +1,3 @@ +from .jobs import Jobs + +__all__ = ["Jobs"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py similarity index 53% rename from api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py rename to api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py index 3d2e9208a11f17..8b038cadc06407 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs/jobs.py @@ -4,13 +4,23 @@ import httpx -from ...core._base_api import BaseAPI -from ...core._base_type import NOT_GIVEN, Headers, NotGiven -from ...core._http_client import make_user_request_input -from ...types.fine_tuning import FineTuningJob, FineTuningJobEvent, ListOfFineTuningJob, job_create_params +from ....core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + make_request_options, +) +from ....types.fine_tuning import ( + FineTuningJob, + FineTuningJobEvent, + ListOfFineTuningJob, + job_create_params, +) if TYPE_CHECKING: - from ..._client import ZhipuAI + from ...._client import ZhipuAI __all__ = ["Jobs"] @@ -29,6 +39,7 @@ def create( request_id: Optional[str] | NotGiven = NOT_GIVEN, validation_file: Optional[str] | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._post( @@ -41,7 +52,7 @@ def create( "validation_file": validation_file, "request_id": request_id, }, - options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=FineTuningJob, ) @@ -50,11 +61,12 @@ def retrieve( fine_tuning_job_id: str, *, extra_headers: Headers | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJob: return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_user_request_input(extra_headers=extra_headers, timeout=timeout), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=FineTuningJob, ) @@ -64,13 +76,15 @@ def list( after: str | NotGiven = NOT_GIVEN, limit: int | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ListOfFineTuningJob: return self._get( "/fine_tuning/jobs", cast_type=ListOfFineTuningJob, - options=make_user_request_input( + options=make_request_options( extra_headers=extra_headers, + extra_body=extra_body, timeout=timeout, query={ "after": after, @@ -79,6 +93,24 @@ def list( ), ) + def cancel( + self, + fine_tuning_job_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # noqa: E501 + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FineTuningJob: + if not fine_tuning_job_id: + raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") + return self._post( + f"/fine_tuning/jobs/{fine_tuning_job_id}/cancel", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=FineTuningJob, + ) + def list_events( self, fine_tuning_job_id: str, @@ -86,13 +118,15 @@ def list_events( after: str | NotGiven = NOT_GIVEN, limit: int | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, + extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> FineTuningJobEvent: return self._get( f"/fine_tuning/jobs/{fine_tuning_job_id}/events", cast_type=FineTuningJobEvent, - options=make_user_request_input( + options=make_request_options( extra_headers=extra_headers, + extra_body=extra_body, timeout=timeout, query={ "after": after, @@ -100,3 +134,19 @@ def list_events( }, ), ) + + def delete( + self, + fine_tuning_job_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FineTuningJob: + if not fine_tuning_job_id: + raise ValueError(f"Expected a non-empty value for `fine_tuning_job_id` but received {fine_tuning_job_id!r}") + return self._delete( + f"/fine_tuning/jobs/{fine_tuning_job_id}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=FineTuningJob, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py new file mode 100644 index 00000000000000..d832635bafbc6f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/__init__.py @@ -0,0 +1,3 @@ +from .fine_tuned_models import FineTunedModels + +__all__ = ["FineTunedModels"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py new file mode 100644 index 00000000000000..29c023e3b1cd5a --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/models/fine_tuned_models.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import httpx + +from ....core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + make_request_options, +) +from ....types.fine_tuning.models import FineTunedModelsStatus + +if TYPE_CHECKING: + from ...._client import ZhipuAI + +__all__ = ["FineTunedModels"] + + +class FineTunedModels(BaseAPI): + def __init__(self, client: ZhipuAI) -> None: + super().__init__(client) + + def delete( + self, + fine_tuned_model: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FineTunedModelsStatus: + if not fine_tuned_model: + raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}") + return self._delete( + f"fine_tuning/fine_tuned_models/{fine_tuned_model}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=FineTunedModelsStatus, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 2692b093af8b43..8ad411913fa115 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -4,10 +4,9 @@ import httpx -from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven -from ..core._http_client import make_user_request_input +from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options from ..types.image import ImagesResponded +from ..types.sensitive_word_check import SensitiveWordCheckRequest if TYPE_CHECKING: from .._client import ZhipuAI @@ -27,8 +26,10 @@ def generations( response_format: Optional[str] | NotGiven = NOT_GIVEN, size: Optional[str] | NotGiven = NOT_GIVEN, style: Optional[str] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN, request_id: Optional[str] | NotGiven = NOT_GIVEN, + user_id: Optional[str] | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, extra_body: Body | None = None, disable_strict_validation: Optional[bool] | None = None, @@ -45,12 +46,14 @@ def generations( "n": n, "quality": quality, "response_format": response_format, + "sensitive_word_check": sensitive_word_check, "size": size, "style": style, "user": user, + "user_id": user_id, "request_id": request_id, }, - options=make_user_request_input(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), cast_type=_cast_type, - enable_stream=False, + stream=False, ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py new file mode 100644 index 00000000000000..5a67d743c35b9b --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/__init__.py @@ -0,0 +1,3 @@ +from .knowledge import Knowledge + +__all__ = ["Knowledge"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py new file mode 100644 index 00000000000000..fd289e2232b955 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/__init__.py @@ -0,0 +1,3 @@ +from .document import Document + +__all__ = ["Document"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py new file mode 100644 index 00000000000000..2c4066d8930342 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/document/document.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING, Literal, Optional, cast + +import httpx + +from ....core import ( + NOT_GIVEN, + BaseAPI, + Body, + FileTypes, + Headers, + NotGiven, + deepcopy_minimal, + extract_files, + make_request_options, + maybe_transform, +) +from ....types.files import UploadDetail, file_create_params +from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params +from ....types.knowledge.document.document_list_resp import DocumentPage + +if TYPE_CHECKING: + from ...._client import ZhipuAI + +__all__ = ["Document"] + + +class Document(BaseAPI): + def __init__(self, client: ZhipuAI) -> None: + super().__init__(client) + + def create( + self, + *, + file: FileTypes = None, + custom_separator: Optional[list[str]] = None, + upload_detail: list[UploadDetail] = None, + purpose: Literal["retrieval"], + knowledge_id: str = None, + sentence_size: int = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DocumentObject: + if not file and not upload_detail: + raise ValueError("At least one of `file` and `upload_detail` must be provided.") + body = deepcopy_minimal( + { + "file": file, + "upload_detail": upload_detail, + "purpose": purpose, + "custom_separator": custom_separator, + "knowledge_id": knowledge_id, + "sentence_size": sentence_size, + } + ) + files = extract_files(cast(Mapping[str, object], body), paths=[["file"]]) + if files: + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + return self._post( + "/files", + body=maybe_transform(body, file_create_params.FileCreateParams), + files=files, + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=DocumentObject, + ) + + def edit( + self, + document_id: str, + knowledge_type: str, + *, + custom_separator: Optional[list[str]] = None, + sentence_size: Optional[int] = None, + callback_url: Optional[str] = None, + callback_header: Optional[dict[str, str]] = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> httpx.Response: + """ + + Args: + document_id: 知识id + knowledge_type: 知识类型: + 1:文章知识: 支持pdf,url,docx + 2.问答知识-文档: 支持pdf,url,docx + 3.问答知识-表格: 支持xlsx + 4.商品库-表格: 支持xlsx + 5.自定义: 支持pdf,url,docx + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + :param knowledge_type: + :param document_id: + :param timeout: + :param extra_body: + :param callback_header: + :param sentence_size: + :param extra_headers: + :param callback_url: + :param custom_separator: + """ + if not document_id: + raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") + + body = deepcopy_minimal( + { + "id": document_id, + "knowledge_type": knowledge_type, + "custom_separator": custom_separator, + "sentence_size": sentence_size, + "callback_url": callback_url, + "callback_header": callback_header, + } + ) + + return self._put( + f"/document/{document_id}", + body=maybe_transform(body, document_edit_params.DocumentEditParams), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=httpx.Response, + ) + + def list( + self, + knowledge_id: str, + *, + purpose: str | NotGiven = NOT_GIVEN, + page: str | NotGiven = NOT_GIVEN, + limit: str | NotGiven = NOT_GIVEN, + order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DocumentPage: + return self._get( + "/files", + options=make_request_options( + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "knowledge_id": knowledge_id, + "purpose": purpose, + "page": page, + "limit": limit, + "order": order, + }, + document_list_params.DocumentListParams, + ), + ), + cast_type=DocumentPage, + ) + + def delete( + self, + document_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> httpx.Response: + """ + Delete a file. + + Args: + + document_id: 知识id + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not document_id: + raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") + + return self._delete( + f"/document/{document_id}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=httpx.Response, + ) + + def retrieve( + self, + document_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DocumentData: + """ + + Args: + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not document_id: + raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}") + + return self._get( + f"/document/{document_id}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=DocumentData, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/knowledge.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/knowledge.py new file mode 100644 index 00000000000000..fea4c73ac997c3 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/knowledge/knowledge.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, Optional + +import httpx + +from ...core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + cached_property, + deepcopy_minimal, + make_request_options, + maybe_transform, +) +from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params +from ...types.knowledge.knowledge_list_resp import KnowledgePage +from .document import Document + +if TYPE_CHECKING: + from ..._client import ZhipuAI + +__all__ = ["Knowledge"] + + +class Knowledge(BaseAPI): + def __init__(self, client: ZhipuAI) -> None: + super().__init__(client) + + @cached_property + def document(self) -> Document: + return Document(self._client) + + def create( + self, + embedding_id: int, + name: str, + *, + customer_identifier: Optional[str] = None, + description: Optional[str] = None, + background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None, + icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None, + bucket_id: Optional[str] = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> KnowledgeInfo: + body = deepcopy_minimal( + { + "embedding_id": embedding_id, + "name": name, + "customer_identifier": customer_identifier, + "description": description, + "background": background, + "icon": icon, + "bucket_id": bucket_id, + } + ) + return self._post( + "/knowledge", + body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=KnowledgeInfo, + ) + + def modify( + self, + knowledge_id: str, + embedding_id: int, + *, + name: str, + description: Optional[str] = None, + background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None, + icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> httpx.Response: + body = deepcopy_minimal( + { + "id": knowledge_id, + "embedding_id": embedding_id, + "name": name, + "description": description, + "background": background, + "icon": icon, + } + ) + return self._put( + f"/knowledge/{knowledge_id}", + body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=httpx.Response, + ) + + def query( + self, + *, + page: int | NotGiven = 1, + size: int | NotGiven = 10, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> KnowledgePage: + return self._get( + "/knowledge", + options=make_request_options( + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "page": page, + "size": size, + }, + knowledge_list_params.KnowledgeListParams, + ), + ), + cast_type=KnowledgePage, + ) + + def delete( + self, + knowledge_id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> httpx.Response: + """ + Delete a file. + + Args: + knowledge_id: 知识库ID + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not knowledge_id: + raise ValueError("Expected a non-empty value for `knowledge_id`") + + return self._delete( + f"/knowledge/{knowledge_id}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=httpx.Response, + ) + + def used( + self, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> KnowledgeUsed: + """ + Returns the contents of the specified file. + + Args: + extra_headers: Send extra headers + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get( + "/knowledge/capacity", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=KnowledgeUsed, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/__init__.py new file mode 100644 index 00000000000000..43e4e37da1779f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/__init__.py @@ -0,0 +1,3 @@ +from .tools import Tools + +__all__ = ["Tools"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/tools.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/tools.py new file mode 100644 index 00000000000000..3c3a630aff47d7 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/tools/tools.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Literal, Optional, Union + +import httpx + +from ...core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + StreamResponse, + deepcopy_minimal, + make_request_options, + maybe_transform, +) +from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from ..._client import ZhipuAI + +__all__ = ["Tools"] + + +class Tools(BaseAPI): + def __init__(self, client: ZhipuAI) -> None: + super().__init__(client) + + def web_search( + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + scope: Optional[str] | NotGiven = NOT_GIVEN, + location: Optional[str] | NotGiven = NOT_GIVEN, + recent_days: Optional[int] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> WebSearch | StreamResponse[WebSearchChunk]: + body = deepcopy_minimal( + { + "model": model, + "request_id": request_id, + "messages": messages, + "stream": stream, + "scope": scope, + "location": location, + "recent_days": recent_days, + } + ) + return self._post( + "/tools", + body=maybe_transform(body, tools_web_search_params.WebSearchParams), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=WebSearch, + stream=stream or False, + stream_cls=StreamResponse[WebSearchChunk], + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/__init__.py new file mode 100644 index 00000000000000..6b0f99ed09efe3 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/__init__.py @@ -0,0 +1,7 @@ +from .videos import ( + Videos, +) + +__all__ = [ + "Videos", +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/videos.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/videos.py new file mode 100644 index 00000000000000..f1f1c08036a660 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/videos/videos.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import httpx + +from ...core import ( + NOT_GIVEN, + BaseAPI, + Body, + Headers, + NotGiven, + deepcopy_minimal, + make_request_options, + maybe_transform, +) +from ...types.sensitive_word_check import SensitiveWordCheckRequest +from ...types.video import VideoObject, video_create_params + +if TYPE_CHECKING: + from ..._client import ZhipuAI + +__all__ = ["Videos"] + + +class Videos(BaseAPI): + def __init__(self, client: ZhipuAI) -> None: + super().__init__(client) + + def generations( + self, + model: str, + *, + prompt: str = None, + image_url: str = None, + sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN, + request_id: str = None, + user_id: str = None, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> VideoObject: + if not model and not model: + raise ValueError("At least one of `model` and `prompt` must be provided.") + body = deepcopy_minimal( + { + "model": model, + "prompt": prompt, + "image_url": image_url, + "sensitive_word_check": sensitive_word_check, + "request_id": request_id, + "user_id": user_id, + } + ) + return self._post( + "/videos/generations", + body=maybe_transform(body, video_create_params.VideoCreateParams), + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=VideoObject, + ) + + def retrieve_videos_result( + self, + id: str, + *, + extra_headers: Headers | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> VideoObject: + if not id: + raise ValueError("At least one of `id` must be provided.") + + return self._get( + f"/async-result/{id}", + options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout), + cast_type=VideoObject, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py index e69de29bb2d1d6..3d6466d279861a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py @@ -0,0 +1,108 @@ +from ._base_api import BaseAPI +from ._base_compat import ( + PYDANTIC_V2, + ConfigDict, + GenericModel, + cached_property, + field_get_default, + get_args, + get_model_config, + get_model_fields, + get_origin, + is_literal_type, + is_union, + parse_obj, +) +from ._base_models import BaseModel, construct_type +from ._base_type import ( + NOT_GIVEN, + Body, + FileTypes, + Headers, + IncEx, + ModelT, + NotGiven, + Query, +) +from ._constants import ( + ZHIPUAI_DEFAULT_LIMITS, + ZHIPUAI_DEFAULT_MAX_RETRIES, + ZHIPUAI_DEFAULT_TIMEOUT, +) +from ._errors import ( + APIAuthenticationError, + APIConnectionError, + APIInternalError, + APIReachLimitError, + APIRequestFailedError, + APIResponseError, + APIResponseValidationError, + APIServerFlowExceedError, + APIStatusError, + APITimeoutError, + ZhipuAIError, +) +from ._files import is_file_content +from ._http_client import HttpClient, make_request_options +from ._sse_client import StreamResponse +from ._utils import ( + deepcopy_minimal, + drop_prefix_image_data, + extract_files, + is_given, + is_list, + is_mapping, + maybe_transform, + parse_date, + parse_datetime, +) + +__all__ = [ + "BaseModel", + "construct_type", + "BaseAPI", + "NOT_GIVEN", + "Headers", + "NotGiven", + "Body", + "IncEx", + "ModelT", + "Query", + "FileTypes", + "PYDANTIC_V2", + "ConfigDict", + "GenericModel", + "get_args", + "is_union", + "parse_obj", + "get_origin", + "is_literal_type", + "get_model_config", + "get_model_fields", + "field_get_default", + "is_file_content", + "ZhipuAIError", + "APIStatusError", + "APIRequestFailedError", + "APIAuthenticationError", + "APIReachLimitError", + "APIInternalError", + "APIServerFlowExceedError", + "APIResponseError", + "APIResponseValidationError", + "APITimeoutError", + "make_request_options", + "HttpClient", + "ZHIPUAI_DEFAULT_TIMEOUT", + "ZHIPUAI_DEFAULT_MAX_RETRIES", + "ZHIPUAI_DEFAULT_LIMITS", + "is_list", + "is_mapping", + "parse_date", + "parse_datetime", + "is_given", + "maybe_transform", + "deepcopy_minimal", + "extract_files", + "StreamResponse", +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py index 10b46ff8e381a3..3592ea6bacd170 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py @@ -16,3 +16,4 @@ def __init__(self, client: ZhipuAI) -> None: self._post = client.post self._put = client.put self._patch = client.patch + self._get_api_list = client.get_api_list diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_compat.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_compat.py new file mode 100644 index 00000000000000..92a5d683be6732 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_compat.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from collections.abc import Callable +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload + +import pydantic +from pydantic.fields import FieldInfo +from typing_extensions import Self + +from ._base_type import StrBytesIntFloat + +_T = TypeVar("_T") +_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) + +# --------------- Pydantic v2 compatibility --------------- + +# Pyright incorrectly reports some of our functions as overriding a method when they don't +# pyright: reportIncompatibleMethodOverride=false + +PYDANTIC_V2 = pydantic.VERSION.startswith("2.") + +# v1 re-exports +if TYPE_CHECKING: + + def parse_date(value: date | StrBytesIntFloat) -> date: ... + + def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ... + + def get_args(t: type[Any]) -> tuple[Any, ...]: ... + + def is_union(tp: type[Any] | None) -> bool: ... + + def get_origin(t: type[Any]) -> type[Any] | None: ... + + def is_literal_type(type_: type[Any]) -> bool: ... + + def is_typeddict(type_: type[Any]) -> bool: ... + +else: + if PYDANTIC_V2: + from pydantic.v1.typing import ( # noqa: I001 + get_args as get_args, # noqa: PLC0414 + is_union as is_union, # noqa: PLC0414 + get_origin as get_origin, # noqa: PLC0414 + is_typeddict as is_typeddict, # noqa: PLC0414 + is_literal_type as is_literal_type, # noqa: PLC0414 + ) + from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414 + else: + from pydantic.typing import ( # noqa: I001 + get_args as get_args, # noqa: PLC0414 + is_union as is_union, # noqa: PLC0414 + get_origin as get_origin, # noqa: PLC0414 + is_typeddict as is_typeddict, # noqa: PLC0414 + is_literal_type as is_literal_type, # noqa: PLC0414 + ) + from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime # noqa: PLC0414 + + +# refactored config +if TYPE_CHECKING: + from pydantic import ConfigDict +else: + if PYDANTIC_V2: + from pydantic import ConfigDict + else: + # TODO: provide an error message here? + ConfigDict = None + + +# renamed methods / properties +def parse_obj(model: type[_ModelT], value: object) -> _ModelT: + if PYDANTIC_V2: + return model.model_validate(value) + else: + # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + return cast(_ModelT, model.parse_obj(value)) + + +def field_is_required(field: FieldInfo) -> bool: + if PYDANTIC_V2: + return field.is_required() + return field.required # type: ignore + + +def field_get_default(field: FieldInfo) -> Any: + value = field.get_default() + if PYDANTIC_V2: + from pydantic_core import PydanticUndefined + + if value == PydanticUndefined: + return None + return value + return value + + +def field_outer_type(field: FieldInfo) -> Any: + if PYDANTIC_V2: + return field.annotation + return field.outer_type_ # type: ignore + + +def get_model_config(model: type[pydantic.BaseModel]) -> Any: + if PYDANTIC_V2: + return model.model_config + return model.__config__ # type: ignore + + +def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: + if PYDANTIC_V2: + return model.model_fields + return model.__fields__ # type: ignore + + +def model_copy(model: _ModelT) -> _ModelT: + if PYDANTIC_V2: + return model.model_copy() + return model.copy() # type: ignore + + +def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: + if PYDANTIC_V2: + return model.model_dump_json(indent=indent) + return model.json(indent=indent) # type: ignore + + +def model_dump( + model: pydantic.BaseModel, + *, + exclude_unset: bool = False, + exclude_defaults: bool = False, +) -> dict[str, Any]: + if PYDANTIC_V2: + return model.model_dump( + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ) + return cast( + "dict[str, Any]", + model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ), + ) + + +def model_parse(model: type[_ModelT], data: Any) -> _ModelT: + if PYDANTIC_V2: + return model.model_validate(data) + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + + +# generic models +if TYPE_CHECKING: + + class GenericModel(pydantic.BaseModel): ... + +else: + if PYDANTIC_V2: + # there no longer needs to be a distinction in v2 but + # we still have to create our own subclass to avoid + # inconsistent MRO ordering errors + class GenericModel(pydantic.BaseModel): ... + + else: + import pydantic.generics + + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... + + +# cached properties +if TYPE_CHECKING: + cached_property = property + + # we define a separate type (copied from typeshed) + # that represents that `cached_property` is `set`able + # at runtime, which differs from `@property`. + # + # this is a separate type as editors likely special case + # `@property` and we don't want to cause issues just to have + # more helpful internal types. + + class typed_cached_property(Generic[_T]): # noqa: N801 + func: Callable[[Any], _T] + attrname: str | None + + def __init__(self, func: Callable[[Any], _T]) -> None: ... + + @overload + def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... + + @overload + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... + + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: + raise NotImplementedError() + + def __set_name__(self, owner: type[Any], name: str) -> None: ... + + # __set__ is not defined at runtime, but @cached_property is designed to be settable + def __set__(self, instance: object, value: _T) -> None: ... +else: + try: + from functools import cached_property + except ImportError: + from cached_property import cached_property + + typed_cached_property = cached_property diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_models.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_models.py new file mode 100644 index 00000000000000..5e9a7e0a987e28 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_models.py @@ -0,0 +1,671 @@ +from __future__ import annotations + +import inspect +import os +from collections.abc import Callable +from datetime import date, datetime +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast + +import pydantic +import pydantic.generics +from pydantic.fields import FieldInfo +from typing_extensions import ( + ParamSpec, + Protocol, + override, + runtime_checkable, +) + +from ._base_compat import ( + PYDANTIC_V2, + ConfigDict, + field_get_default, + get_args, + get_model_config, + get_model_fields, + get_origin, + is_literal_type, + is_union, + parse_obj, +) +from ._base_compat import ( + GenericModel as BaseGenericModel, +) +from ._base_type import ( + IncEx, + ModelT, +) +from ._utils import ( + PropertyInfo, + coerce_boolean, + extract_type_arg, + is_annotated_type, + is_list, + is_mapping, + parse_date, + parse_datetime, + strip_annotated_type, +) + +if TYPE_CHECKING: + from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema + +__all__ = ["BaseModel", "GenericModel"] +_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel") + +_T = TypeVar("_T") +P = ParamSpec("P") + + +@runtime_checkable +class _ConfigProtocol(Protocol): + allow_population_by_field_name: bool + + +class BaseModel(pydantic.BaseModel): + if PYDANTIC_V2: + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + ) + else: + + @property + @override + def model_fields_set(self) -> set[str]: + # a forwards-compat shim for pydantic v2 + return self.__fields_set__ # type: ignore + + class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] + extra: Any = pydantic.Extra.allow # type: ignore + + def to_dict( + self, + *, + mode: Literal["json", "python"] = "python", + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> dict[str, object]: + """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + mode: + If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`. + If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)` + + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value from the output. + exclude_none: Whether to exclude fields that have a value of `None` from the output. + warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2. + """ # noqa: E501 + return self.model_dump( + mode=mode, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + + def to_json( + self, + *, + indent: int | None = 2, + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> str: + """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation). + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2` + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that have the default value. + exclude_none: Whether to exclude fields that have a value of `None`. + warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2. + """ # noqa: E501 + return self.model_dump_json( + indent=indent, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + + @override + def __str__(self) -> str: + # mypy complains about an invalid self arg + return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc] + + # Override the 'construct' method in a way that supports recursive parsing without validation. + # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. + @classmethod + @override + def construct( + cls: type[ModelT], + _fields_set: set[str] | None = None, + **values: object, + ) -> ModelT: + m = cls.__new__(cls) + fields_values: dict[str, object] = {} + + config = get_model_config(cls) + populate_by_name = ( + config.allow_population_by_field_name + if isinstance(config, _ConfigProtocol) + else config.get("populate_by_name") + ) + + if _fields_set is None: + _fields_set = set() + + model_fields = get_model_fields(cls) + for name, field in model_fields.items(): + key = field.alias + if key is None or (key not in values and populate_by_name): + key = name + + if key in values: + fields_values[name] = _construct_field(value=values[key], field=field, key=key) + _fields_set.add(name) + else: + fields_values[name] = field_get_default(field) + + _extra = {} + for key, value in values.items(): + if key not in model_fields: + if PYDANTIC_V2: + _extra[key] = value + else: + _fields_set.add(key) + fields_values[key] = value + + object.__setattr__(m, "__dict__", fields_values) # noqa: PLC2801 + + if PYDANTIC_V2: + # these properties are copied from Pydantic's `model_construct()` method + object.__setattr__(m, "__pydantic_private__", None) # noqa: PLC2801 + object.__setattr__(m, "__pydantic_extra__", _extra) # noqa: PLC2801 + object.__setattr__(m, "__pydantic_fields_set__", _fields_set) # noqa: PLC2801 + else: + # init_private_attributes() does not exist in v2 + m._init_private_attributes() # type: ignore + + # copied from Pydantic v1's `construct()` method + object.__setattr__(m, "__fields_set__", _fields_set) # noqa: PLC2801 + + return m + + if not TYPE_CHECKING: + # type checkers incorrectly complain about this assignment + # because the type signatures are technically different + # although not in practice + model_construct = construct + + if not PYDANTIC_V2: + # we define aliases for some of the new pydantic v2 methods so + # that we can just document these methods without having to specify + # a specific pydantic version as some users may not know which + # pydantic version they are currently using + + @override + def model_dump( + self, + *, + mode: Literal["json", "python"] | str = "python", + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, + ) -> dict[str, Any]: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump + + Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + Args: + mode: The mode in which `to_python` should run. + If mode is 'json', the dictionary will only contain JSON serializable types. + If mode is 'python', the dictionary may contain any Python objects. + include: A list of fields to include in the output. + exclude: A list of fields to exclude from the output. + by_alias: Whether to use the field's alias in the dictionary key if defined. + exclude_unset: Whether to exclude fields that are unset or None from the output. + exclude_defaults: Whether to exclude fields that are set to their default value from the output. + exclude_none: Whether to exclude fields that have a value of `None` from the output. + round_trip: Whether to enable serialization and deserialization round-trip support. + warnings: Whether to log warnings when invalid fields are encountered. + + Returns: + A dictionary representation of the model. + """ + if mode != "python": + raise ValueError("mode is only supported in Pydantic v2") + if round_trip != False: + raise ValueError("round_trip is only supported in Pydantic v2") + if warnings != True: + raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") + return super().dict( # pyright: ignore[reportDeprecated] + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + @override + def model_dump_json( + self, + *, + indent: int | None = None, + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + context: dict[str, Any] | None = None, + serialize_as_any: bool = False, + ) -> str: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json + + Generates a JSON representation of the model using Pydantic's `to_json` method. + + Args: + indent: Indentation to use in the JSON output. If None is passed, the output will be compact. + include: Field(s) to include in the JSON output. Can take either a string or set of strings. + exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. + by_alias: Whether to serialize using field aliases. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that have the default value. + exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: Whether to use serialization/deserialization between JSON and class instance. + warnings: Whether to show any warnings that occurred during serialization. + + Returns: + A JSON string representation of the model. + """ + if round_trip != False: + raise ValueError("round_trip is only supported in Pydantic v2") + if warnings != True: + raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") + return super().json( # type: ignore[reportDeprecated] + indent=indent, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + +def _construct_field(value: object, field: FieldInfo, key: str) -> object: + if value is None: + return field_get_default(field) + + if PYDANTIC_V2: + type_ = field.annotation + else: + type_ = cast(type, field.outer_type_) # type: ignore + + if type_ is None: + raise RuntimeError(f"Unexpected field type is None for {key}") + + return construct_type(value=value, type_=type_) + + +def is_basemodel(type_: type) -> bool: + """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" + if is_union(type_): + return any(is_basemodel(variant) for variant in get_args(type_)) + + return is_basemodel_type(type_) + + +def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: + origin = get_origin(type_) or type_ + return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) + + +def build( + base_model_cls: Callable[P, _BaseModelT], + *args: P.args, + **kwargs: P.kwargs, +) -> _BaseModelT: + """Construct a BaseModel class without validation. + + This is useful for cases where you need to instantiate a `BaseModel` + from an API response as this provides type-safe params which isn't supported + by helpers like `construct_type()`. + + ```py + build(MyModel, my_field_a="foo", my_field_b=123) + ``` + """ + if args: + raise TypeError( + "Received positional arguments which are not supported; Keyword arguments must be used instead", + ) + + return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs)) + + +def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: + """Loose coercion to the expected type with construction of nested values. + + Note: the returned value from this function is not guaranteed to match the + given type. + """ + return cast(_T, construct_type(value=value, type_=type_)) + + +def construct_type(*, value: object, type_: type) -> object: + """Loose coercion to the expected type with construction of nested values. + + If the given value does not match the expected type then it is returned as-is. + """ + # we allow `object` as the input type because otherwise, passing things like + # `Literal['value']` will be reported as a type error by type checkers + type_ = cast("type[object]", type_) + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + meta: tuple[Any, ...] = get_args(type_)[1:] + type_ = extract_type_arg(type_, 0) + else: + meta = () + # we need to use the origin class for any types that are subscripted generics + # e.g. Dict[str, object] + origin = get_origin(type_) or type_ + args = get_args(type_) + + if is_union(origin): + try: + return validate_type(type_=cast("type[object]", type_), value=value) + except Exception: + pass + + # if the type is a discriminated union then we want to construct the right variant + # in the union, even if the data doesn't match exactly, otherwise we'd break code + # that relies on the constructed class types, e.g. + # + # class FooType: + # kind: Literal['foo'] + # value: str + # + # class BarType: + # kind: Literal['bar'] + # value: int + # + # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then + # we'd end up constructing `FooType` when it should be `BarType`. + discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) + if discriminator and is_mapping(value): + variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) + if variant_value and isinstance(variant_value, str): + variant_type = discriminator.mapping.get(variant_value) + if variant_type: + return construct_type(type_=variant_type, value=value) + + # if the data is not valid, use the first variant that doesn't fail while deserializing + for variant in args: + try: + return construct_type(value=value, type_=variant) + except Exception: + continue + + raise RuntimeError(f"Could not convert data into a valid instance of {type_}") + if origin == dict: + if not is_mapping(value): + return value + + _, items_type = get_args(type_) # Dict[_, items_type] + return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} + + if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): + if is_list(value): + return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] + + if is_mapping(value): + if issubclass(type_, BaseModel): + return type_.construct(**value) # type: ignore[arg-type] + + return cast(Any, type_).construct(**value) + + if origin == list: + if not is_list(value): + return value + + inner_type = args[0] # List[inner_type] + return [construct_type(value=entry, type_=inner_type) for entry in value] + + if origin == float: + if isinstance(value, int): + coerced = float(value) + if coerced != value: + return value + return coerced + + return value + + if type_ == datetime: + try: + return parse_datetime(value) # type: ignore + except Exception: + return value + + if type_ == date: + try: + return parse_date(value) # type: ignore + except Exception: + return value + + return value + + +@runtime_checkable +class CachedDiscriminatorType(Protocol): + __discriminator__: DiscriminatorDetails + + +class DiscriminatorDetails: + field_name: str + """The name of the discriminator field in the variant class, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] + ``` + + Will result in field_name='type' + """ + + field_alias_from: str | None + """The name of the discriminator field in the API response, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] = Field(alias='type_from_api') + ``` + + Will result in field_alias_from='type_from_api' + """ + + mapping: dict[str, type] + """Mapping of discriminator value to variant type, e.g. + + {'foo': FooVariant, 'bar': BarVariant} + """ + + def __init__( + self, + *, + mapping: dict[str, type], + discriminator_field: str, + discriminator_alias: str | None, + ) -> None: + self.mapping = mapping + self.field_name = discriminator_field + self.field_alias_from = discriminator_alias + + +def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: + if isinstance(union, CachedDiscriminatorType): + return union.__discriminator__ + + discriminator_field_name: str | None = None + + for annotation in meta_annotations: + if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: + discriminator_field_name = annotation.discriminator + break + + if not discriminator_field_name: + return None + + mapping: dict[str, type] = {} + discriminator_alias: str | None = None + + for variant in get_args(union): + variant = strip_annotated_type(variant) + if is_basemodel_type(variant): + if PYDANTIC_V2: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field.get("serialization_alias") + + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in cast("LiteralSchema", field_schema)["expected"]: + if isinstance(entry, str): + mapping[entry] = variant + else: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field_info.alias + + if field_info.annotation and is_literal_type(field_info.annotation): + for entry in get_args(field_info.annotation): + if isinstance(entry, str): + mapping[entry] = variant + + if not mapping: + return None + + details = DiscriminatorDetails( + mapping=mapping, + discriminator_field=discriminator_field_name, + discriminator_alias=discriminator_alias, + ) + cast(CachedDiscriminatorType, union).__discriminator__ = details + return details + + +def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: + schema = model.__pydantic_core_schema__ + if schema["type"] != "model": + return None + + fields_schema = schema["schema"] + if fields_schema["type"] != "model-fields": + return None + + fields_schema = cast("ModelFieldsSchema", fields_schema) + + field = fields_schema["fields"].get(field_name) + if not field: + return None + + return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] + + +def validate_type(*, type_: type[_T], value: object) -> _T: + """Strict validation that the given value matches the expected type""" + if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): + return cast(_T, parse_obj(type_, value)) + + return cast(_T, _validate_non_model_type(type_=type_, value=value)) + + +# our use of subclasssing here causes weirdness for type checkers, +# so we just pretend that we don't subclass +if TYPE_CHECKING: + GenericModel = BaseModel +else: + + class GenericModel(BaseGenericModel, BaseModel): + pass + + +if PYDANTIC_V2: + from pydantic import TypeAdapter + + def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: + return TypeAdapter(type_).validate_python(value) + +elif not TYPE_CHECKING: + + class TypeAdapter(Generic[_T]): + """Used as a placeholder to easily convert runtime types to a Pydantic format + to provide validation. + + For example: + ```py + validated = RootModel[int](__root__="5").__root__ + # validated: 5 + ``` + """ + + def __init__(self, type_: type[_T]): + self.type_ = type_ + + def validate_python(self, value: Any) -> _T: + if not isinstance(value, self.type_): + raise ValueError(f"Invalid type: {value} is not of type {self.type_}") + return value + + def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: + return TypeAdapter(type_).validate_python(value) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py index 7a91f9b79627c4..ea1d3f09dc42ea 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py @@ -1,11 +1,21 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from os import PathLike -from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, Union +from typing import ( + IO, + TYPE_CHECKING, + Any, + Literal, + Optional, + TypeAlias, + TypeVar, + Union, +) import pydantic -from typing_extensions import override +from httpx import Response +from typing_extensions import Protocol, TypedDict, override, runtime_checkable Query = Mapping[str, object] Body = object @@ -22,7 +32,7 @@ # Sentinel class used until PEP 0661 is accepted -class NotGiven(pydantic.BaseModel): +class NotGiven: """ A sentinel singleton class used to distinguish omitted keyword arguments from those passed in with the value None (which may have different behavior). @@ -50,7 +60,7 @@ def __repr__(self) -> str: NOT_GIVEN = NotGiven() -class Omit(pydantic.BaseModel): +class Omit: """In certain situations you need to be able to represent a case where a default value has to be explicitly removed and `None` is not an appropriate substitute, for example: @@ -71,37 +81,90 @@ def __bool__(self) -> Literal[False]: return False +@runtime_checkable +class ModelBuilderProtocol(Protocol): + @classmethod + def build( + cls: type[_T], + *, + response: Response, + data: object, + ) -> _T: ... + + Headers = Mapping[str, Union[str, Omit]] + +class HeadersLikeProtocol(Protocol): + def get(self, __key: str) -> str | None: ... + + +HeadersLike = Union[Headers, HeadersLikeProtocol] + ResponseT = TypeVar( "ResponseT", - bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol," - " BinaryResponseContent]", + bound="Union[str, None, BaseModel, list[Any], dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", # noqa: E501 ) +StrBytesIntFloat = Union[str, bytes, int, float] + +# Note: copied from Pydantic +# https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 +IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" + +PostParser = Callable[[Any], Any] + + +@runtime_checkable +class InheritsGeneric(Protocol): + """Represents a type that has inherited from `Generic` + + The `__orig_bases__` property can be used to determine the resolved + type variable for a given base class. + """ + + __orig_bases__: tuple[_GenericAlias] + + +class _GenericAlias(Protocol): + __origin__: type[object] + + +class HttpxSendArgs(TypedDict, total=False): + auth: httpx.Auth + + # for user input files if TYPE_CHECKING: + Base64FileInput = Union[IO[bytes], PathLike[str]] FileContent = Union[IO[bytes], bytes, PathLike[str]] else: + Base64FileInput = Union[IO[bytes], PathLike] FileContent = Union[IO[bytes], bytes, PathLike] FileTypes = Union[ - FileContent, # file content - tuple[str, FileContent], # (filename, file) - tuple[str, FileContent, str], # (filename, file , content_type) - tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], ] - RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] -# for httpx client supported files - +# duplicate of the above but without our custom file support HttpxFileContent = Union[bytes, IO[bytes]] HttpxFileTypes = Union[ - FileContent, # file content - tuple[str, HttpxFileContent], # (filename, file) - tuple[str, HttpxFileContent, str], # (filename, file , content_type) - tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + # file (or bytes) + HttpxFileContent, + # (filename, file (or bytes)) + tuple[Optional[str], HttpxFileContent], + # (filename, file (or bytes), content_type) + tuple[Optional[str], HttpxFileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], ] HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_constants.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_constants.py new file mode 100644 index 00000000000000..8e43bdebecb61f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_constants.py @@ -0,0 +1,12 @@ +import httpx + +RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" +# 通过 `Timeout` 控制接口`connect` 和 `read` 超时时间,默认为`timeout=300.0, connect=8.0` +ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) +# 通过 `retry` 参数控制重试次数,默认为3次 +ZHIPUAI_DEFAULT_MAX_RETRIES = 3 +# 通过 `Limits` 控制最大连接数和保持连接数,默认为`max_connections=50, max_keepalive_connections=10` +ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10) + +INITIAL_RETRY_DELAY = 0.5 +MAX_RETRY_DELAY = 8.0 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py index 1027c1bc5b1e55..e2c9d24c6c0d24 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py @@ -13,6 +13,7 @@ "APIResponseError", "APIResponseValidationError", "APITimeoutError", + "APIConnectionError", ] @@ -24,7 +25,7 @@ def __init__( super().__init__(message) -class APIStatusError(Exception): +class APIStatusError(ZhipuAIError): response: httpx.Response status_code: int @@ -49,7 +50,7 @@ class APIInternalError(APIStatusError): ... class APIServerFlowExceedError(APIStatusError): ... -class APIResponseError(Exception): +class APIResponseError(ZhipuAIError): message: str request: httpx.Request json_data: object @@ -75,9 +76,11 @@ def __init__(self, response: httpx.Response, json_data: object | None, *, messag self.status_code = response.status_code -class APITimeoutError(Exception): - request: httpx.Request +class APIConnectionError(APIResponseError): + def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: + super().__init__(message, request, json_data=None) - def __init__(self, request: httpx.Request): - self.request = request - super().__init__("Request Timeout") + +class APITimeoutError(APIConnectionError): + def __init__(self, request: httpx.Request) -> None: + super().__init__(message="Request timed out.", request=request) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py index 0796bfe11cc658..f9d2e14d9ecb93 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py @@ -2,40 +2,74 @@ import io import os -from collections.abc import Mapping, Sequence -from pathlib import Path +import pathlib +from typing import TypeGuard, overload -from ._base_type import FileTypes, HttpxFileTypes, HttpxRequestFiles, RequestFiles +from ._base_type import ( + Base64FileInput, + FileContent, + FileTypes, + HttpxFileContent, + HttpxFileTypes, + HttpxRequestFiles, + RequestFiles, +) +from ._utils import is_mapping_t, is_sequence_t, is_tuple_t -def is_file_content(obj: object) -> bool: +def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: + return isinstance(obj, io.IOBase | os.PathLike) + + +def is_file_content(obj: object) -> TypeGuard[FileContent]: return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike) -def _transform_file(file: FileTypes) -> HttpxFileTypes: - if is_file_content(file): - if isinstance(file, os.PathLike): - path = Path(file) - return path.name, path.read_bytes() - else: - return file - if isinstance(file, tuple): - if isinstance(file[1], os.PathLike): - return (file[0], Path(file[1]).read_bytes(), *file[2:]) - else: - return (file[0], file[1], *file[2:]) - else: - raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type") +def assert_is_file_content(obj: object, *, key: str | None = None) -> None: + if not is_file_content(obj): + prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" + raise RuntimeError( + f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads" + ) from None -def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: +@overload +def to_httpx_files(files: None) -> None: ... + + +@overload +def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... + + +def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: if files is None: return None - if isinstance(files, Mapping): + if is_mapping_t(files): files = {key: _transform_file(file) for key, file in files.items()} - elif isinstance(files, Sequence): + elif is_sequence_t(files): files = [(key, _transform_file(file)) for key, file in files] else: - raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence") + raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") + return files + + +def _transform_file(file: FileTypes) -> HttpxFileTypes: + if is_file_content(file): + if isinstance(file, os.PathLike): + path = pathlib.Path(file) + return (path.name, path.read_bytes()) + + return file + + if is_tuple_t(file): + return (file[0], _read_file_content(file[1]), *file[2:]) + + raise TypeError("Expected file types input to be a FileContent type or to be a tuple") + + +def _read_file_content(file: FileContent) -> HttpxFileContent: + if isinstance(file, os.PathLike): + return pathlib.Path(file).read_bytes() + return file diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 5f7f6d04f20d98..d0f933d8141389 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -1,23 +1,70 @@ from __future__ import annotations import inspect -from collections.abc import Mapping -from typing import Any, Union, cast +import logging +import time +import warnings +from collections.abc import Iterator, Mapping +from itertools import starmap +from random import random +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, cast, overload import httpx import pydantic from httpx import URL, Timeout -from tenacity import retry -from tenacity.stop import stop_after_attempt - -from . import _errors -from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT -from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError -from ._files import make_httpx_files -from ._request_opt import ClientRequestParam, UserRequestInput -from ._response import HttpResponse + +from . import _errors, get_origin +from ._base_compat import model_copy +from ._base_models import GenericModel, construct_type, validate_type +from ._base_type import ( + NOT_GIVEN, + AnyMapping, + Body, + Data, + Headers, + HttpxSendArgs, + ModelBuilderProtocol, + NotGiven, + Omit, + PostParser, + Query, + RequestFiles, + ResponseT, +) +from ._constants import ( + INITIAL_RETRY_DELAY, + MAX_RETRY_DELAY, + RAW_RESPONSE_HEADER, + ZHIPUAI_DEFAULT_LIMITS, + ZHIPUAI_DEFAULT_MAX_RETRIES, + ZHIPUAI_DEFAULT_TIMEOUT, +) +from ._errors import APIConnectionError, APIResponseValidationError, APIStatusError, APITimeoutError +from ._files import to_httpx_files +from ._legacy_response import LegacyAPIResponse +from ._request_opt import FinalRequestOptions, UserRequestInput +from ._response import APIResponse, BaseAPIResponse, extract_response_type from ._sse_client import StreamResponse -from ._utils import flatten +from ._utils import flatten, is_given, is_mapping + +log: logging.Logger = logging.getLogger(__name__) + +# TODO: make base page type vars covariant +SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]") +# AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]") + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +if TYPE_CHECKING: + from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT +else: + try: + from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT + except ImportError: + # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366 + HTTPX_DEFAULT_TIMEOUT = Timeout(5.0) + headers = { "Accept": "application/json", @@ -25,50 +72,180 @@ } -def _merge_map(map1: Mapping, map2: Mapping) -> Mapping: - merged = {**map1, **map2} - return {key: val for key, val in merged.items() if val is not None} +class PageInfo: + """Stores the necessary information to build the request to retrieve the next page. + Either `url` or `params` must be set. + """ -from itertools import starmap + url: URL | NotGiven + params: Query | NotGiven + + @overload + def __init__( + self, + *, + url: URL, + ) -> None: ... + + @overload + def __init__( + self, + *, + params: Query, + ) -> None: ... + + def __init__( + self, + *, + url: URL | NotGiven = NOT_GIVEN, + params: Query | NotGiven = NOT_GIVEN, + ) -> None: + self.url = url + self.params = params + + +class BasePage(GenericModel, Generic[_T]): + """ + Defines the core interface for pagination. + + Type Args: + ModelT: The pydantic model that represents an item in the response. + + Methods: + has_next_page(): Check if there is another page available + next_page_info(): Get the necessary information to make a request for the next page + """ + + _options: FinalRequestOptions = pydantic.PrivateAttr() + _model: type[_T] = pydantic.PrivateAttr() -from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT + def has_next_page(self) -> bool: + items = self._get_page_items() + if not items: + return False + return self.next_page_info() is not None -ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) -ZHIPUAI_DEFAULT_MAX_RETRIES = 3 -ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=5, max_keepalive_connections=5) + def next_page_info(self) -> Optional[PageInfo]: ... + + def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body] + ... + + def _params_from_url(self, url: URL) -> httpx.QueryParams: + # TODO: do we have to preprocess params here? + return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params) + + def _info_to_options(self, info: PageInfo) -> FinalRequestOptions: + options = model_copy(self._options) + options._strip_raw_response_header() + + if not isinstance(info.params, NotGiven): + options.params = {**options.params, **info.params} + return options + + if not isinstance(info.url, NotGiven): + params = self._params_from_url(info.url) + url = info.url.copy_with(params=params) + options.params = dict(url.params) + options.url = str(url) + return options + + raise ValueError("Unexpected PageInfo state") + + +class BaseSyncPage(BasePage[_T], Generic[_T]): + _client: HttpClient = pydantic.PrivateAttr() + + def _set_private_attributes( + self, + client: HttpClient, + model: type[_T], + options: FinalRequestOptions, + ) -> None: + self._model = model + self._client = client + self._options = options + + # Pydantic uses a custom `__iter__` method to support casting BaseModels + # to dictionaries. e.g. dict(model). + # As we want to support `for item in page`, this is inherently incompatible + # with the default pydantic behaviour. It is not possible to support both + # use cases at once. Fortunately, this is not a big deal as all other pydantic + # methods should continue to work as expected as there is an alternative method + # to cast a model to a dictionary, model.dict(), which is used internally + # by pydantic. + def __iter__(self) -> Iterator[_T]: # type: ignore + for page in self.iter_pages(): + yield from page._get_page_items() + + def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]: + page = self + while True: + yield page + if page.has_next_page(): + page = page.get_next_page() + else: + return + + def get_next_page(self: SyncPageT) -> SyncPageT: + info = self.next_page_info() + if not info: + raise RuntimeError( + "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`." + ) + + options = self._info_to_options(info) + return self._client._request_api_list(self._model, page=self.__class__, options=options) class HttpClient: _client: httpx.Client _version: str _base_url: URL - + max_retries: int timeout: Union[float, Timeout, None] _limits: httpx.Limits _has_custom_http_client: bool _default_stream_cls: type[StreamResponse[Any]] | None = None + _strict_response_validation: bool + def __init__( self, *, version: str, base_url: URL, + _strict_response_validation: bool, + max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, timeout: Union[float, Timeout, None], + limits: httpx.Limits | None = None, custom_httpx_client: httpx.Client | None = None, custom_headers: Mapping[str, str] | None = None, ) -> None: - if timeout is None or isinstance(timeout, NotGiven): + if limits is not None: + warnings.warn( + "The `connection_pool_limits` argument is deprecated. The `http_client` argument should be passed instead", # noqa: E501 + category=DeprecationWarning, + stacklevel=3, + ) + if custom_httpx_client is not None: + raise ValueError("The `http_client` argument is mutually exclusive with `connection_pool_limits`") + else: + limits = ZHIPUAI_DEFAULT_LIMITS + + if not is_given(timeout): if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: timeout = custom_httpx_client.timeout else: timeout = ZHIPUAI_DEFAULT_TIMEOUT - self.timeout = cast(Timeout, timeout) + self.max_retries = max_retries + self.timeout = timeout + self._limits = limits self._has_custom_http_client = bool(custom_httpx_client) self._client = custom_httpx_client or httpx.Client( base_url=base_url, timeout=self.timeout, - limits=ZHIPUAI_DEFAULT_LIMITS, + limits=limits, ) self._version = version url = URL(url=base_url) @@ -76,6 +253,7 @@ def __init__( url = url.copy_with(raw_path=url.raw_path + b"/") self._base_url = url self._custom_headers = custom_headers or {} + self._strict_response_validation = _strict_response_validation def _prepare_url(self, url: str) -> URL: sub_url = URL(url) @@ -93,55 +271,101 @@ def _default_headers(self): "ZhipuAI-SDK-Ver": self._version, "source_type": "zhipu-sdk-python", "x-request-sdk": "zhipu-sdk-python", - **self._auth_headers, + **self.auth_headers, **self._custom_headers, } @property - def _auth_headers(self): + def custom_auth(self) -> httpx.Auth | None: + return None + + @property + def auth_headers(self): return {} - def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers: - custom_headers = request_param.headers or {} - headers_dict = _merge_map(self._default_headers, custom_headers) + def _prepare_headers(self, options: FinalRequestOptions) -> httpx.Headers: + custom_headers = options.headers or {} + headers_dict = _merge_mappings(self._default_headers, custom_headers) httpx_headers = httpx.Headers(headers_dict) return httpx_headers - def _prepare_request(self, request_param: ClientRequestParam) -> httpx.Request: + def _remaining_retries( + self, + remaining_retries: Optional[int], + options: FinalRequestOptions, + ) -> int: + return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries) + + def _calculate_retry_timeout( + self, + remaining_retries: int, + options: FinalRequestOptions, + response_headers: Optional[httpx.Headers] = None, + ) -> float: + max_retries = options.get_max_retries(self.max_retries) + + # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. + # retry_after = self._parse_retry_after_header(response_headers) + # if retry_after is not None and 0 < retry_after <= 60: + # return retry_after + + nb_retries = max_retries - remaining_retries + + # Apply exponential backoff, but not more than the max. + sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY) + + # Apply some jitter, plus-or-minus half a second. + jitter = 1 - 0.25 * random() + timeout = sleep_seconds * jitter + return max(timeout, 0) + + def _build_request(self, options: FinalRequestOptions) -> httpx.Request: kwargs: dict[str, Any] = {} - json_data = request_param.json_data - headers = self._prepare_headers(request_param) - url = self._prepare_url(request_param.url) - json_data = request_param.json_data + headers = self._prepare_headers(options) + url = self._prepare_url(options.url) + json_data = options.json_data + if options.extra_json is not None: + if json_data is None: + json_data = cast(Body, options.extra_json) + elif is_mapping(json_data): + json_data = _merge_mappings(json_data, options.extra_json) + else: + raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`") + + content_type = headers.get("Content-Type") + # multipart/form-data; boundary=---abc-- if headers.get("Content-Type") == "multipart/form-data": - headers.pop("Content-Type") + if "boundary" not in content_type: + # only remove the header if the boundary hasn't been explicitly set + # as the caller doesn't want httpx to come up with their own boundary + headers.pop("Content-Type") if json_data: kwargs["data"] = self._make_multipartform(json_data) return self._client.build_request( headers=headers, - timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout, - method=request_param.method, + timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, + method=options.method, url=url, json=json_data, - files=request_param.files, - params=request_param.params, + files=options.files, + params=options.params, **kwargs, ) - def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: + def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: items = [] if isinstance(value, Mapping): for k, v in value.items(): - items.extend(self._object_to_formdata(f"{key}[{k}]", v)) + items.extend(self._object_to_formfata(f"{key}[{k}]", v)) return items if isinstance(value, list | tuple): for v in value: - items.extend(self._object_to_formdata(key + "[]", v)) + items.extend(self._object_to_formfata(key + "[]", v)) return items def _primitive_value_to_str(val) -> str: @@ -161,7 +385,7 @@ def _primitive_value_to_str(val) -> str: return [(key, str_data)] def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - items = flatten(list(starmap(self._object_to_formdata, data.items()))) + items = flatten(list(starmap(self._object_to_formfata, data.items()))) serialized: dict[str, object] = {} for key, value in items: @@ -170,20 +394,6 @@ def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object serialized[key] = value return serialized - def _parse_response( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - enable_stream: bool, - request_param: ClientRequestParam, - stream_cls: type[StreamResponse[Any]] | None = None, - ) -> HttpResponse: - http_response = HttpResponse( - raw_response=response, cast_type=cast_type, client=self, enable_stream=enable_stream, stream_cls=stream_cls - ) - return http_response.parse() - def _process_response_data( self, *, @@ -194,14 +404,58 @@ def _process_response_data( if data is None: return cast(ResponseT, None) + if cast_type is object: + return cast(ResponseT, data) + try: - if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel): - return cast(ResponseT, cast_type.validate(data)) + if inspect.isclass(cast_type) and issubclass(cast_type, ModelBuilderProtocol): + return cast(ResponseT, cast_type.build(response=response, data=data)) + + if self._strict_response_validation: + return cast(ResponseT, validate_type(type_=cast_type, value=data)) - return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data)) + return cast(ResponseT, construct_type(type_=cast_type, value=data)) except pydantic.ValidationError as err: raise APIResponseValidationError(response=response, json_data=data) from err + def _should_stream_response_body(self, request: httpx.Request) -> bool: + return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return] + + def _should_retry(self, response: httpx.Response) -> bool: + # Note: this is not a standard header + should_retry_header = response.headers.get("x-should-retry") + + # If the server explicitly says whether or not to retry, obey. + if should_retry_header == "true": + log.debug("Retrying as header `x-should-retry` is set to `true`") + return True + if should_retry_header == "false": + log.debug("Not retrying as header `x-should-retry` is set to `false`") + return False + + # Retry on request timeouts. + if response.status_code == 408: + log.debug("Retrying due to status code %i", response.status_code) + return True + + # Retry on lock timeouts. + if response.status_code == 409: + log.debug("Retrying due to status code %i", response.status_code) + return True + + # Retry on rate limits. + if response.status_code == 429: + log.debug("Retrying due to status code %i", response.status_code) + return True + + # Retry internal errors. + if response.status_code >= 500: + log.debug("Retrying due to status code %i", response.status_code) + return True + + log.debug("Not retrying") + return False + def is_closed(self) -> bool: return self._client.is_closed @@ -214,117 +468,385 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - @retry(stop=stop_after_attempt(ZHIPUAI_DEFAULT_MAX_RETRIES)) def request( + self, + cast_type: type[ResponseT], + options: FinalRequestOptions, + remaining_retries: Optional[int] = None, + *, + stream: bool = False, + stream_cls: type[StreamResponse] | None = None, + ) -> ResponseT | StreamResponse: + return self._request( + cast_type=cast_type, + options=options, + stream=stream, + stream_cls=stream_cls, + remaining_retries=remaining_retries, + ) + + def _request( self, *, cast_type: type[ResponseT], - params: ClientRequestParam, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, + options: FinalRequestOptions, + remaining_retries: int | None, + stream: bool, + stream_cls: type[StreamResponse] | None, ) -> ResponseT | StreamResponse: - request = self._prepare_request(params) + retries = self._remaining_retries(remaining_retries, options) + request = self._build_request(options) + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth try: response = self._client.send( request, - stream=enable_stream, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, ) - response.raise_for_status() except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if retries > 0: + return self._retry_request( + options, + cast_type, + retries, + stream=stream, + stream_cls=stream_cls, + response_headers=None, + ) + + log.debug("Raising timeout error") raise APITimeoutError(request=request) from err - except httpx.HTTPStatusError as err: - err.response.read() - # raise err - raise self._make_status_error(err.response) from None - except Exception as err: - raise err + log.debug("Encountered Exception", exc_info=True) + + if retries > 0: + return self._retry_request( + options, + cast_type, + retries, + stream=stream, + stream_cls=stream_cls, + response_headers=None, + ) + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase + ) - return self._parse_response( + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if retries > 0 and self._should_retry(err.response): + err.response.close() + return self._retry_request( + options, + cast_type, + retries, + err.response.headers, + stream=stream, + stream_cls=stream_cls, + ) + + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + err.response.read() + + log.debug("Re-raising status error") + raise self._make_status_error(err.response) from None + + # return self._parse_response( + # cast_type=cast_type, + # options=options, + # response=response, + # stream=stream, + # stream_cls=stream_cls, + # ) + return self._process_response( cast_type=cast_type, - request_param=params, + options=options, response=response, - enable_stream=enable_stream, + stream=stream, + stream_cls=stream_cls, + ) + + def _retry_request( + self, + options: FinalRequestOptions, + cast_type: type[ResponseT], + remaining_retries: int, + response_headers: httpx.Headers | None, + *, + stream: bool, + stream_cls: type[StreamResponse] | None, + ) -> ResponseT | StreamResponse: + remaining = remaining_retries - 1 + if remaining == 1: + log.debug("1 retry left") + else: + log.debug("%i retries left", remaining) + + timeout = self._calculate_retry_timeout(remaining, options, response_headers) + log.info("Retrying request to %s in %f seconds", options.url, timeout) + + # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a + # different thread if necessary. + time.sleep(timeout) + + return self._request( + options=options, + cast_type=cast_type, + remaining_retries=remaining, + stream=stream, + stream_cls=stream_cls, + ) + + def _process_response( + self, + *, + cast_type: type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[StreamResponse] | None, + ) -> ResponseT: + # _legacy_response with raw_response_header to paser method + if response.request.headers.get(RAW_RESPONSE_HEADER) == "true": + return cast( + ResponseT, + LegacyAPIResponse( + raw=response, + client=self, + cast_type=cast_type, + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + origin = get_origin(cast_type) or cast_type + + if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): + if not issubclass(origin, APIResponse): + raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}") + + response_cls = cast("type[BaseAPIResponse[Any]]", cast_type) + return cast( + ResponseT, + response_cls( + raw=response, + client=self, + cast_type=extract_response_type(response_cls), + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + if cast_type == httpx.Response: + return cast(ResponseT, response) + + api_response = APIResponse( + raw=response, + client=self, + cast_type=cast("type[ResponseT]", cast_type), # pyright: ignore[reportUnnecessaryCast] + stream=stream, stream_cls=stream_cls, + options=options, ) + if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): + return cast(ResponseT, api_response) + + return api_response.parse() + + def _request_api_list( + self, + model: type[object], + page: type[SyncPageT], + options: FinalRequestOptions, + ) -> SyncPageT: + def _parser(resp: SyncPageT) -> SyncPageT: + resp._set_private_attributes( + client=self, + model=model, + options=options, + ) + return resp + + options.post_parser = _parser + + return self.request(page, options, stream=False) + @overload def get( self, path: str, *, cast_type: type[ResponseT], options: UserRequestInput = {}, - enable_stream: bool = False, - ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="get", url=path, **options) - return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream) + stream: Literal[False] = False, + ) -> ResponseT: ... + @overload + def get( + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + stream: Literal[True], + stream_cls: type[StreamResponse], + ) -> StreamResponse: ... + + @overload + def get( + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + stream: bool, + stream_cls: type[StreamResponse] | None = None, + ) -> ResponseT | StreamResponse: ... + + def get( + self, + path: str, + *, + cast_type: type[ResponseT], + options: UserRequestInput = {}, + stream: bool = False, + stream_cls: type[StreamResponse] | None = None, + ) -> ResponseT: + opts = FinalRequestOptions.construct(method="get", url=path, **options) + return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls)) + + @overload def post( self, path: str, *, + cast_type: type[ResponseT], body: Body | None = None, + options: UserRequestInput = {}, + files: RequestFiles | None = None, + stream: Literal[False] = False, + ) -> ResponseT: ... + + @overload + def post( + self, + path: str, + *, cast_type: type[ResponseT], + body: Body | None = None, options: UserRequestInput = {}, files: RequestFiles | None = None, - enable_stream: bool = False, + stream: Literal[True], + stream_cls: type[StreamResponse], + ) -> StreamResponse: ... + + @overload + def post( + self, + path: str, + *, + cast_type: type[ResponseT], + body: Body | None = None, + options: UserRequestInput = {}, + files: RequestFiles | None = None, + stream: bool, + stream_cls: type[StreamResponse] | None = None, + ) -> ResponseT | StreamResponse: ... + + def post( + self, + path: str, + *, + cast_type: type[ResponseT], + body: Body | None = None, + options: UserRequestInput = {}, + files: RequestFiles | None = None, + stream: bool = False, stream_cls: type[StreamResponse[Any]] | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct( - method="post", json_data=body, files=make_httpx_files(files), url=path, **options + opts = FinalRequestOptions.construct( + method="post", url=path, json_data=body, files=to_httpx_files(files), **options ) - return self.request(cast_type=cast_type, params=opts, enable_stream=enable_stream, stream_cls=stream_cls) + return cast(ResponseT, self.request(cast_type, opts, stream=stream, stream_cls=stream_cls)) def patch( self, path: str, *, - body: Body | None = None, cast_type: type[ResponseT], + body: Body | None = None, options: UserRequestInput = {}, ) -> ResponseT: - opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct(method="patch", url=path, json_data=body, **options) return self.request( cast_type=cast_type, - params=opts, + options=opts, ) def put( self, path: str, *, - body: Body | None = None, cast_type: type[ResponseT], + body: Body | None = None, options: UserRequestInput = {}, files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct( - method="put", url=path, json_data=body, files=make_httpx_files(files), **options + opts = FinalRequestOptions.construct( + method="put", url=path, json_data=body, files=to_httpx_files(files), **options ) return self.request( cast_type=cast_type, - params=opts, + options=opts, ) def delete( self, path: str, *, - body: Body | None = None, cast_type: type[ResponseT], + body: Body | None = None, options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) + opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, **options) return self.request( cast_type=cast_type, - params=opts, + options=opts, ) + def get_api_list( + self, + path: str, + *, + model: type[object], + page: type[SyncPageT], + body: Body | None = None, + options: UserRequestInput = {}, + method: str = "get", + ) -> SyncPageT: + opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options) + return self._request_api_list(model, page, opts) + def _make_status_error(self, response) -> APIStatusError: response_text = response.text.strip() status_code = response.status_code @@ -343,24 +865,46 @@ def _make_status_error(self, response) -> APIStatusError: return APIStatusError(message=error_msg, response=response) -def make_user_request_input( - max_retries: int | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - extra_headers: Headers = None, - extra_body: Body | None = None, +def make_request_options( + *, query: Query | None = None, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + post_parser: PostParser | NotGiven = NOT_GIVEN, ) -> UserRequestInput: + """Create a dict of type RequestOptions without keys of NotGiven values.""" options: UserRequestInput = {} - if extra_headers is not None: options["headers"] = extra_headers - if max_retries is not None: - options["max_retries"] = max_retries - if not isinstance(timeout, NotGiven): - options["timeout"] = timeout - if query is not None: - options["params"] = query + if extra_body is not None: options["extra_json"] = cast(AnyMapping, extra_body) + if query is not None: + options["params"] = query + + if extra_query is not None: + options["params"] = {**options.get("params", {}), **extra_query} + + if not isinstance(timeout, NotGiven): + options["timeout"] = timeout + + if is_given(post_parser): + # internal + options["post_parser"] = post_parser # type: ignore + return options + + +def _merge_mappings( + obj1: Mapping[_T_co, Union[_T, Omit]], + obj2: Mapping[_T_co, Union[_T, Omit]], +) -> dict[_T_co, _T]: + """Merge two mappings of the same type, removing any values that are instances of `Omit`. + + In cases with duplicate keys the second mapping takes precedence. + """ + merged = {**obj1, **obj2} + return {key: value for key, value in merged.items() if not isinstance(value, Omit)} diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py index b0a91d04a99447..21f158a5f45251 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py @@ -3,9 +3,11 @@ import cachetools.func import jwt -API_TOKEN_TTL_SECONDS = 3 * 60 +# 缓存时间 3分钟 +CACHE_TTL_SECONDS = 3 * 60 -CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 +# token 有效期比缓存时间 多30秒 +API_TOKEN_TTL_SECONDS = CACHE_TTL_SECONDS + 30 @cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_binary_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_binary_response.py new file mode 100644 index 00000000000000..51623bd860951f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_binary_response.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import os +from collections.abc import AsyncIterator, Iterator +from typing import Any + +import httpx + + +class HttpxResponseContent: + @property + def content(self) -> bytes: + raise NotImplementedError("This method is not implemented for this class.") + + @property + def text(self) -> str: + raise NotImplementedError("This method is not implemented for this class.") + + @property + def encoding(self) -> str | None: + raise NotImplementedError("This method is not implemented for this class.") + + @property + def charset_encoding(self) -> str | None: + raise NotImplementedError("This method is not implemented for this class.") + + def json(self, **kwargs: Any) -> Any: + raise NotImplementedError("This method is not implemented for this class.") + + def read(self) -> bytes: + raise NotImplementedError("This method is not implemented for this class.") + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: + raise NotImplementedError("This method is not implemented for this class.") + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + raise NotImplementedError("This method is not implemented for this class.") + + def iter_lines(self) -> Iterator[str]: + raise NotImplementedError("This method is not implemented for this class.") + + def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]: + raise NotImplementedError("This method is not implemented for this class.") + + def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + raise NotImplementedError("This method is not implemented for this class.") + + def stream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + raise NotImplementedError("This method is not implemented for this class.") + + def close(self) -> None: + raise NotImplementedError("This method is not implemented for this class.") + + async def aread(self) -> bytes: + raise NotImplementedError("This method is not implemented for this class.") + + async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + raise NotImplementedError("This method is not implemented for this class.") + + async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: + raise NotImplementedError("This method is not implemented for this class.") + + async def aiter_lines(self) -> AsyncIterator[str]: + raise NotImplementedError("This method is not implemented for this class.") + + async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + raise NotImplementedError("This method is not implemented for this class.") + + async def astream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + raise NotImplementedError("This method is not implemented for this class.") + + async def aclose(self) -> None: + raise NotImplementedError("This method is not implemented for this class.") + + +class HttpxBinaryResponseContent(HttpxResponseContent): + response: httpx.Response + + def __init__(self, response: httpx.Response) -> None: + self.response = response + + @property + def content(self) -> bytes: + return self.response.content + + @property + def encoding(self) -> str | None: + return self.response.encoding + + @property + def charset_encoding(self) -> str | None: + return self.response.charset_encoding + + def read(self) -> bytes: + return self.response.read() + + def text(self) -> str: + raise NotImplementedError("Not implemented for binary response content") + + def json(self, **kwargs: Any) -> Any: + raise NotImplementedError("Not implemented for binary response content") + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + raise NotImplementedError("Not implemented for binary response content") + + def iter_lines(self) -> Iterator[str]: + raise NotImplementedError("Not implemented for binary response content") + + async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: + raise NotImplementedError("Not implemented for binary response content") + + async def aiter_lines(self) -> AsyncIterator[str]: + raise NotImplementedError("Not implemented for binary response content") + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: + return self.response.iter_bytes(chunk_size) + + def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]: + return self.response.iter_raw(chunk_size) + + def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + """Write the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + + Note: if you want to stream the data to the file instead of writing + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')` + """ + with open(file, mode="wb") as f: + for data in self.response.iter_bytes(): + f.write(data) + + def stream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + with open(file, mode="wb") as f: + for data in self.response.iter_bytes(chunk_size): + f.write(data) + + def close(self) -> None: + return self.response.close() + + async def aread(self) -> bytes: + return await self.response.aread() + + async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + return self.response.aiter_bytes(chunk_size) + + async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + return self.response.aiter_raw(chunk_size) + + async def astream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + path = anyio.Path(file) + async with await path.open(mode="wb") as f: + async for data in self.response.aiter_bytes(chunk_size): + await f.write(data) + + async def aclose(self) -> None: + return await self.response.aclose() + + +class HttpxTextBinaryResponseContent(HttpxBinaryResponseContent): + response: httpx.Response + + @property + def text(self) -> str: + return self.response.text + + def json(self, **kwargs: Any) -> Any: + return self.response.json(**kwargs) + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + return self.response.iter_text(chunk_size) + + def iter_lines(self) -> Iterator[str]: + return self.response.iter_lines() + + async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: + return self.response.aiter_text(chunk_size) + + async def aiter_lines(self) -> AsyncIterator[str]: + return self.response.aiter_lines() diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_response.py new file mode 100644 index 00000000000000..47183b9eee9c0d --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_legacy_response.py @@ -0,0 +1,341 @@ +from __future__ import annotations + +import datetime +import functools +import inspect +import logging +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload + +import httpx +import pydantic +from typing_extensions import ParamSpec, override + +from ._base_models import BaseModel, is_basemodel +from ._base_type import NoneType +from ._constants import RAW_RESPONSE_HEADER +from ._errors import APIResponseValidationError +from ._legacy_binary_response import HttpxResponseContent, HttpxTextBinaryResponseContent +from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type +from ._utils import extract_type_arg, is_annotated_type, is_given + +if TYPE_CHECKING: + from ._http_client import HttpClient + from ._request_opt import FinalRequestOptions + +P = ParamSpec("P") +R = TypeVar("R") +_T = TypeVar("_T") + +log: logging.Logger = logging.getLogger(__name__) + + +class LegacyAPIResponse(Generic[R]): + """This is a legacy class as it will be replaced by `APIResponse` + and `AsyncAPIResponse` in the `_response.py` file in the next major + release. + + For the sync client this will mostly be the same with the exception + of `content` & `text` will be methods instead of properties. In the + async client, all methods will be async. + + A migration script will be provided & the migration in general should + be smooth. + """ + + _cast_type: type[R] + _client: HttpClient + _parsed_by_type: dict[type[Any], Any] + _stream: bool + _stream_cls: type[StreamResponse[Any]] | None + _options: FinalRequestOptions + + http_response: httpx.Response + + def __init__( + self, + *, + raw: httpx.Response, + cast_type: type[R], + client: HttpClient, + stream: bool, + stream_cls: type[StreamResponse[Any]] | None, + options: FinalRequestOptions, + ) -> None: + self._cast_type = cast_type + self._client = client + self._parsed_by_type = {} + self._stream = stream + self._stream_cls = stream_cls + self._options = options + self.http_response = raw + + @property + def request_id(self) -> str | None: + return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return] + + @overload + def parse(self, *, to: type[_T]) -> _T: ... + + @overload + def parse(self) -> R: ... + + def parse(self, *, to: type[_T] | None = None) -> R | _T: + """Returns the rich python representation of this response's data. + + NOTE: For the async client: this will become a coroutine in the next major version. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. + + ```py + from zhipuai import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `int` + - `float` + - `httpx.Response` + """ + cache_key = to if to is not None else self._cast_type + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] + + parsed = self._parse(to=to) + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed_by_type[cache_key] = parsed + return parsed + + @property + def headers(self) -> httpx.Headers: + return self.http_response.headers + + @property + def http_request(self) -> httpx.Request: + return self.http_response.request + + @property + def status_code(self) -> int: + return self.http_response.status_code + + @property + def url(self) -> httpx.URL: + return self.http_response.url + + @property + def method(self) -> str: + return self.http_request.method + + @property + def content(self) -> bytes: + """Return the binary response content. + + NOTE: this will be removed in favour of `.read()` in the + next major version. + """ + return self.http_response.content + + @property + def text(self) -> str: + """Return the decoded response content. + + NOTE: this will be turned into a method in the next major version. + """ + return self.http_response.text + + @property + def http_version(self) -> str: + return self.http_response.http_version + + @property + def is_closed(self) -> bool: + return self.http_response.is_closed + + @property + def elapsed(self) -> datetime.timedelta: + """The time taken for the complete request/response cycle to complete.""" + return self.http_response.elapsed + + def _parse(self, *, to: type[_T] | None = None) -> R | _T: + # unwrap `Annotated[T, ...]` -> `T` + if to and is_annotated_type(to): + to = extract_type_arg(to, 0) + + if self._stream: + if to: + if not is_stream_class_type(to): + raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}") + + return cast( + _T, + to( + cast_type=extract_stream_chunk_type( + to, + failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501 + ), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + + if self._stream_cls: + return cast( + R, + self._stream_cls( + cast_type=extract_stream_chunk_type(self._stream_cls), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + + stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls) + if stream_cls is None: + raise MissingStreamClassError() + + return cast( + R, + stream_cls( + cast_type=self._cast_type, + response=self.http_response, + client=cast(Any, self._client), + ), + ) + + cast_type = to if to is not None else self._cast_type + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(cast_type): + cast_type = extract_type_arg(cast_type, 0) + + if cast_type is NoneType: + return cast(R, None) + + response = self.http_response + if cast_type == str: + return cast(R, response.text) + + if cast_type == int: + return cast(R, int(response.text)) + + if cast_type == float: + return cast(R, float(response.text)) + + origin = get_origin(cast_type) or cast_type + + if inspect.isclass(origin) and issubclass(origin, HttpxResponseContent): + # in the response, e.g. mime file + *_, filename = response.headers.get("content-disposition", "").split("filename=") + # 判断文件类型是jsonl类型的使用HttpxTextBinaryResponseContent + if filename and filename.endswith(".jsonl") or filename and filename.endswith(".xlsx"): + return cast(R, HttpxTextBinaryResponseContent(response)) + else: + return cast(R, cast_type(response)) # type: ignore + + if origin == LegacyAPIResponse: + raise RuntimeError("Unexpected state - cast_type is `APIResponse`") + + if inspect.isclass(origin) and issubclass(origin, httpx.Response): + # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response + # and pass that class to our request functions. We cannot change the variance to be either + # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct + # the response class ourselves but that is something that should be supported directly in httpx + # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. + if cast_type != httpx.Response: + raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`") + return cast(R, response) + + if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): + raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") + + if ( + cast_type is not object + and origin is not list + and origin is not dict + and origin is not Union + and not issubclass(origin, BaseModel) + ): + raise RuntimeError( + f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501 + ) + + # split is required to handle cases where additional information is included + # in the response, e.g. application/json; charset=utf-8 + content_type, *_ = response.headers.get("content-type", "*").split(";") + if content_type != "application/json": + if is_basemodel(cast_type): + try: + data = response.json() + except Exception as exc: + log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) + else: + return self._client._process_response_data( + data=data, + cast_type=cast_type, # type: ignore + response=response, + ) + + if self._client._strict_response_validation: + raise APIResponseValidationError( + response=response, + message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501 + json_data=response.text, + ) + + # If the API responds with content that isn't JSON then we just return + # the (decoded) text without performing any parsing so that you can still + # handle the response however you need to. + return response.text # type: ignore + + data = response.json() + + return self._client._process_response_data( + data=data, + cast_type=cast_type, # type: ignore + response=response, + ) + + @override + def __repr__(self) -> str: + return f"" + + +class MissingStreamClassError(TypeError): + def __init__(self) -> None: + super().__init__( + "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501 + ) + + +def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]: + """Higher order function that takes one of our bound API methods and wraps it + to support returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]: + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "true" + + kwargs["extra_headers"] = extra_headers + + return cast(LegacyAPIResponse[R], func(*args, **kwargs)) + + return wrapped diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index ac459151fc3a42..c3b894b3a3d88f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -1,48 +1,97 @@ from __future__ import annotations -from typing import Any, ClassVar, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar, Union, cast +import pydantic.generics from httpx import Timeout -from pydantic import ConfigDict -from typing_extensions import TypedDict, Unpack +from typing_extensions import Required, TypedDict, Unpack, final -from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query -from ._utils import remove_notgiven_indict +from ._base_compat import PYDANTIC_V2, ConfigDict +from ._base_type import AnyMapping, Body, Headers, HttpxRequestFiles, NotGiven, Query +from ._constants import RAW_RESPONSE_HEADER +from ._utils import is_given, strip_not_given class UserRequestInput(TypedDict, total=False): + headers: Headers max_retries: int timeout: float | Timeout | None + params: Query + extra_json: AnyMapping + + +class FinalRequestOptionsInput(TypedDict, total=False): + method: Required[str] + url: Required[str] + params: Query headers: Headers - params: Query | None + max_retries: int + timeout: float | Timeout | None + files: HttpxRequestFiles | None + json_data: Body + extra_json: AnyMapping -class ClientRequestParam: +@final +class FinalRequestOptions(pydantic.BaseModel): method: str url: str - max_retries: Union[int, NotGiven] = NotGiven() - timeout: Union[float, NotGiven] = NotGiven() + params: Query = {} headers: Union[Headers, NotGiven] = NotGiven() - json_data: Union[Body, None] = None + max_retries: Union[int, NotGiven] = NotGiven() + timeout: Union[float, Timeout, None, NotGiven] = NotGiven() files: Union[HttpxRequestFiles, None] = None - params: Query = {} - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + idempotency_key: Union[str, None] = None + post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() + + # It should be noted that we cannot use `json` here as that would override + # a BaseModel method in an incompatible fashion. + json_data: Union[Body, None] = None + extra_json: Union[AnyMapping, None] = None + + if PYDANTIC_V2: + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + else: - def get_max_retries(self, max_retries) -> int: + class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] + arbitrary_types_allowed: bool = True + + def get_max_retries(self, max_retries: int) -> int: if isinstance(self.max_retries, NotGiven): return max_retries return self.max_retries + def _strip_raw_response_header(self) -> None: + if not is_given(self.headers): + return + + if self.headers.get(RAW_RESPONSE_HEADER): + self.headers = {**self.headers} + self.headers.pop(RAW_RESPONSE_HEADER) + + # override the `construct` method so that we can run custom transformations. + # this is necessary as we don't want to do any actual runtime type checking + # (which means we can't use validators) but we do want to ensure that `NotGiven` + # values are not present + # + # type ignore required because we're adding explicit types to `**values` @classmethod def construct( # type: ignore cls, _fields_set: set[str] | None = None, **values: Unpack[UserRequestInput], - ) -> ClientRequestParam: - kwargs: dict[str, Any] = {key: remove_notgiven_indict(value) for key, value in values.items()} - client = cls() - client.__dict__.update(kwargs) - - return client + ) -> FinalRequestOptions: + kwargs: dict[str, Any] = { + # we unconditionally call `strip_not_given` on any value + # as it will just ignore any non-mapping types + key: strip_not_given(value) + for key, value in values.items() + } + if PYDANTIC_V2: + return super().model_construct(_fields_set, **kwargs) + return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] - model_construct = construct + if not TYPE_CHECKING: + # type checkers incorrectly complain about this assignment + model_construct = construct diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 56e60a793407cd..45443da662d57e 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -1,87 +1,193 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin +import inspect +import logging +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload import httpx import pydantic -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, override +from ._base_models import BaseModel, is_basemodel from ._base_type import NoneType -from ._sse_client import StreamResponse +from ._errors import APIResponseValidationError, ZhipuAIError +from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type +from ._utils import extract_type_arg, extract_type_var_from_base, is_annotated_type, is_given if TYPE_CHECKING: from ._http_client import HttpClient + from ._request_opt import FinalRequestOptions P = ParamSpec("P") R = TypeVar("R") +_T = TypeVar("_T") +_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") +log: logging.Logger = logging.getLogger(__name__) -class HttpResponse(Generic[R]): +class BaseAPIResponse(Generic[R]): _cast_type: type[R] _client: HttpClient - _parsed: R | None - _enable_stream: bool + _parsed_by_type: dict[type[Any], Any] + _is_sse_stream: bool _stream_cls: type[StreamResponse[Any]] + _options: FinalRequestOptions http_response: httpx.Response def __init__( self, *, - raw_response: httpx.Response, + raw: httpx.Response, cast_type: type[R], client: HttpClient, - enable_stream: bool = False, + stream: bool, stream_cls: type[StreamResponse[Any]] | None = None, + options: FinalRequestOptions, ) -> None: self._cast_type = cast_type self._client = client - self._parsed = None + self._parsed_by_type = {} + self._is_sse_stream = stream self._stream_cls = stream_cls - self._enable_stream = enable_stream - self.http_response = raw_response + self._options = options + self.http_response = raw - def parse(self) -> R: - self._parsed = self._parse() - return self._parsed + def _parse(self, *, to: type[_T] | None = None) -> R | _T: + # unwrap `Annotated[T, ...]` -> `T` + if to and is_annotated_type(to): + to = extract_type_arg(to, 0) - def _parse(self) -> R: - if self._enable_stream: - self._parsed = cast( + if self._is_sse_stream: + if to: + if not is_stream_class_type(to): + raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}") + + return cast( + _T, + to( + cast_type=extract_stream_chunk_type( + to, + failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]", # noqa: E501 + ), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + + if self._stream_cls: + return cast( + R, + self._stream_cls( + cast_type=extract_stream_chunk_type(self._stream_cls), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + + stream_cls = cast("type[Stream[Any]] | None", self._client._default_stream_cls) + if stream_cls is None: + raise MissingStreamClassError() + + return cast( R, - self._stream_cls( - cast_type=cast(type, get_args(self._stream_cls)[0]), + stream_cls( + cast_type=self._cast_type, response=self.http_response, - client=self._client, + client=cast(Any, self._client), ), ) - return self._parsed - cast_type = self._cast_type + + cast_type = to if to is not None else self._cast_type + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(cast_type): + cast_type = extract_type_arg(cast_type, 0) + if cast_type is NoneType: return cast(R, None) - http_response = self.http_response + + response = self.http_response if cast_type == str: - return cast(R, http_response.text) + return cast(R, response.text) + + if cast_type == bytes: + return cast(R, response.content) + + if cast_type == int: + return cast(R, int(response.text)) + + if cast_type == float: + return cast(R, float(response.text)) - content_type, *_ = http_response.headers.get("content-type", "application/json").split(";") origin = get_origin(cast_type) or cast_type + + # handle the legacy binary response case + if inspect.isclass(cast_type) and cast_type.__name__ == "HttpxBinaryResponseContent": + return cast(R, cast_type(response)) # type: ignore + + if origin == APIResponse: + raise RuntimeError("Unexpected state - cast_type is `APIResponse`") + + if inspect.isclass(origin) and issubclass(origin, httpx.Response): + # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response + # and pass that class to our request functions. We cannot change the variance to be either + # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct + # the response class ourselves but that is something that should be supported directly in httpx + # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. + if cast_type != httpx.Response: + raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`") + return cast(R, response) + + if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): + raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") + + if ( + cast_type is not object + and origin is not list + and origin is not dict + and origin is not Union + and not issubclass(origin, BaseModel) + ): + raise RuntimeError( + f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." # noqa: E501 + ) + + # split is required to handle cases where additional information is included + # in the response, e.g. application/json; charset=utf-8 + content_type, *_ = response.headers.get("content-type", "*").split(";") if content_type != "application/json": - if issubclass(origin, pydantic.BaseModel): - data = http_response.json() - return self._client._process_response_data( - data=data, - cast_type=cast_type, # type: ignore - response=http_response, + if is_basemodel(cast_type): + try: + data = response.json() + except Exception as exc: + log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) + else: + return self._client._process_response_data( + data=data, + cast_type=cast_type, # type: ignore + response=response, + ) + + if self._client._strict_response_validation: + raise APIResponseValidationError( + response=response, + message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", # noqa: E501 + json_data=response.text, ) - return http_response.text + # If the API responds with content that isn't JSON then we just return + # the (decoded) text without performing any parsing so that you can still + # handle the response however you need to. + return response.text # type: ignore - data = http_response.json() + data = response.json() return self._client._process_response_data( data=data, cast_type=cast_type, # type: ignore - response=http_response, + response=response, ) @property @@ -90,6 +196,7 @@ def headers(self) -> httpx.Headers: @property def http_request(self) -> httpx.Request: + """Returns the httpx Request instance associated with the current response.""" return self.http_response.request @property @@ -98,24 +205,194 @@ def status_code(self) -> int: @property def url(self) -> httpx.URL: + """Returns the URL for which the request was made.""" return self.http_response.url @property def method(self) -> str: return self.http_request.method - @property - def content(self) -> bytes: - return self.http_response.content - - @property - def text(self) -> str: - return self.http_response.text - @property def http_version(self) -> str: return self.http_response.http_version @property def elapsed(self) -> datetime.timedelta: + """The time taken for the complete request/response cycle to complete.""" return self.http_response.elapsed + + @property + def is_closed(self) -> bool: + """Whether or not the response body has been closed. + + If this is False then there is response data that has not been read yet. + You must either fully consume the response body or call `.close()` + before discarding the response to prevent resource leaks. + """ + return self.http_response.is_closed + + @override + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>" # noqa: E501 + + +class APIResponse(BaseAPIResponse[R]): + @property + def request_id(self) -> str | None: + return self.http_response.headers.get("x-request-id") # type: ignore[no-any-return] + + @overload + def parse(self, *, to: type[_T]) -> _T: ... + + @overload + def parse(self) -> R: ... + + def parse(self, *, to: type[_T] | None = None) -> R | _T: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. + + ```py + from openai import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `int` + - `float` + - `httpx.Response` + """ + cache_key = to if to is not None else self._cast_type + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] + + if not self._is_sse_stream: + self.read() + + parsed = self._parse(to=to) + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed_by_type[cache_key] = parsed + return parsed + + def read(self) -> bytes: + """Read and return the binary response content.""" + try: + return self.http_response.read() + except httpx.StreamConsumed as exc: + # The default error raised by httpx isn't very + # helpful in our case so we re-raise it with + # a different error message. + raise StreamAlreadyConsumed() from exc + + def text(self) -> str: + """Read and decode the response content into a string.""" + self.read() + return self.http_response.text + + def json(self) -> object: + """Read and decode the JSON response content.""" + self.read() + return self.http_response.json() + + def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + self.http_response.close() + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + + This automatically handles gzip, deflate and brotli encoded responses. + """ + yield from self.http_response.iter_bytes(chunk_size) + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + """A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + yield from self.http_response.iter_text(chunk_size) + + def iter_lines(self) -> Iterator[str]: + """Like `iter_text()` but will only yield chunks for each line""" + yield from self.http_response.iter_lines() + + +class MissingStreamClassError(TypeError): + def __init__(self) -> None: + super().__init__( + "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference", # noqa: E501 + ) + + +class StreamAlreadyConsumed(ZhipuAIError): # noqa: N818 + """ + Attempted to read or stream content, but the content has already + been streamed. + + This can happen if you use a method like `.iter_lines()` and then attempt + to read th entire response body afterwards, e.g. + + ```py + response = await client.post(...) + async for line in response.iter_lines(): + ... # do something with `line` + + content = await response.read() + # ^ error + ``` + + If you want this behaviour you'll need to either manually accumulate the response + content or call `await response.read()` before iterating over the stream. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream some content, but the content has " + "already been streamed. " + "This could be due to attempting to stream the response " + "content more than once." + "\n\n" + "You can fix this by manually accumulating the response content while streaming " + "or by calling `.read()` before starting to stream." + ) + super().__init__(message) + + +def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: + """Given a type like `APIResponse[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(APIResponse[bytes]): + ... + + extract_response_type(MyResponse) -> bytes + ``` + """ + return extract_type_var_from_base( + typ, + generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse)), + index=0, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index ec2745d05912de..cbc449d24421d0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -1,13 +1,16 @@ from __future__ import annotations +import inspect import json from collections.abc import Iterator, Mapping -from typing import TYPE_CHECKING, Generic +from typing import TYPE_CHECKING, Generic, TypeGuard, cast import httpx +from . import get_origin from ._base_type import ResponseT from ._errors import APIResponseError +from ._utils import extract_type_var_from_base, is_mapping _FIELD_SEPARATOR = ":" @@ -53,8 +56,41 @@ def __stream__(self) -> Iterator[ResponseT]: request=self.response.request, json_data=data["error"], ) + if sse.event is None: + data = sse.json_data() + if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + raise APIResponseError( + message=message, + request=self.response.request, + json_data=data["error"], + ) yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) + + else: + data = sse.json_data() + + if sse.event == "error" and is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + + raise APIResponseError( + message=message, + request=self.response.request, + json_data=data["error"], + ) + yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) + for sse in iterator: pass @@ -138,3 +174,33 @@ def decode_line(self, line: str): except (TypeError, ValueError): pass return + + +def is_stream_class_type(typ: type) -> TypeGuard[type[StreamResponse[object]]]: + """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" + origin = get_origin(typ) or typ + return inspect.isclass(origin) and issubclass(origin, StreamResponse) + + +def extract_stream_chunk_type( + stream_cls: type, + *, + failure_message: str | None = None, +) -> type: + """Given a type like `StreamResponse[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyStream(StreamResponse[bytes]): + ... + + extract_stream_chunk_type(MyStream) -> bytes + ``` + """ + + return extract_type_var_from_base( + stream_cls, + index=0, + generic_bases=cast("tuple[type, ...]", (StreamResponse,)), + failure_message=failure_message, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py deleted file mode 100644 index 6b610567daa099..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, Mapping -from typing import TypeVar - -from ._base_type import NotGiven - - -def remove_notgiven_indict(obj): - if obj is None or (not isinstance(obj, Mapping)): - return obj - return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} - - -_T = TypeVar("_T") - - -def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: - return [item for sublist in t for item in sublist] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/__init__.py new file mode 100644 index 00000000000000..a66b095816b8b0 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/__init__.py @@ -0,0 +1,52 @@ +from ._utils import ( # noqa: I001 + remove_notgiven_indict as remove_notgiven_indict, # noqa: PLC0414 + flatten as flatten, # noqa: PLC0414 + is_dict as is_dict, # noqa: PLC0414 + is_list as is_list, # noqa: PLC0414 + is_given as is_given, # noqa: PLC0414 + is_tuple as is_tuple, # noqa: PLC0414 + is_mapping as is_mapping, # noqa: PLC0414 + is_tuple_t as is_tuple_t, # noqa: PLC0414 + parse_date as parse_date, # noqa: PLC0414 + is_iterable as is_iterable, # noqa: PLC0414 + is_sequence as is_sequence, # noqa: PLC0414 + coerce_float as coerce_float, # noqa: PLC0414 + is_mapping_t as is_mapping_t, # noqa: PLC0414 + removeprefix as removeprefix, # noqa: PLC0414 + removesuffix as removesuffix, # noqa: PLC0414 + extract_files as extract_files, # noqa: PLC0414 + is_sequence_t as is_sequence_t, # noqa: PLC0414 + required_args as required_args, # noqa: PLC0414 + coerce_boolean as coerce_boolean, # noqa: PLC0414 + coerce_integer as coerce_integer, # noqa: PLC0414 + file_from_path as file_from_path, # noqa: PLC0414 + parse_datetime as parse_datetime, # noqa: PLC0414 + strip_not_given as strip_not_given, # noqa: PLC0414 + deepcopy_minimal as deepcopy_minimal, # noqa: PLC0414 + get_async_library as get_async_library, # noqa: PLC0414 + maybe_coerce_float as maybe_coerce_float, # noqa: PLC0414 + get_required_header as get_required_header, # noqa: PLC0414 + maybe_coerce_boolean as maybe_coerce_boolean, # noqa: PLC0414 + maybe_coerce_integer as maybe_coerce_integer, # noqa: PLC0414 + drop_prefix_image_data as drop_prefix_image_data, # noqa: PLC0414 +) + + +from ._typing import ( + is_list_type as is_list_type, # noqa: PLC0414 + is_union_type as is_union_type, # noqa: PLC0414 + extract_type_arg as extract_type_arg, # noqa: PLC0414 + is_iterable_type as is_iterable_type, # noqa: PLC0414 + is_required_type as is_required_type, # noqa: PLC0414 + is_annotated_type as is_annotated_type, # noqa: PLC0414 + strip_annotated_type as strip_annotated_type, # noqa: PLC0414 + extract_type_var_from_base as extract_type_var_from_base, # noqa: PLC0414 +) + +from ._transform import ( + PropertyInfo as PropertyInfo, # noqa: PLC0414 + transform as transform, # noqa: PLC0414 + async_transform as async_transform, # noqa: PLC0414 + maybe_transform as maybe_transform, # noqa: PLC0414 + async_maybe_transform as async_maybe_transform, # noqa: PLC0414 +) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_transform.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_transform.py new file mode 100644 index 00000000000000..e8ef1f79358a96 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_transform.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import base64 +import io +import pathlib +from collections.abc import Mapping +from datetime import date, datetime +from typing import Any, Literal, TypeVar, cast, get_args, get_type_hints + +import anyio +import pydantic +from typing_extensions import override + +from .._base_compat import is_typeddict, model_dump +from .._files import is_base64_file_input +from ._typing import ( + extract_type_arg, + is_annotated_type, + is_iterable_type, + is_list_type, + is_required_type, + is_union_type, + strip_annotated_type, +) +from ._utils import ( + is_iterable, + is_list, + is_mapping, +) + +_T = TypeVar("_T") + + +# TODO: support for drilling globals() and locals() +# TODO: ensure works correctly with forward references in all cases + + +PropertyFormat = Literal["iso8601", "base64", "custom"] + + +class PropertyInfo: + """Metadata class to be used in Annotated types to provide information about a given type. + + For example: + + class MyParams(TypedDict): + account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')] + + This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API. + """ # noqa: E501 + + alias: str | None + format: PropertyFormat | None + format_template: str | None + discriminator: str | None + + def __init__( + self, + *, + alias: str | None = None, + format: PropertyFormat | None = None, + format_template: str | None = None, + discriminator: str | None = None, + ) -> None: + self.alias = alias + self.format = format + self.format_template = format_template + self.discriminator = discriminator + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" # noqa: E501 + + +def maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `transform()` that allows `None` to be passed. + + See `transform()` for more details. + """ + if data is None: + return None + return transform(data, expected_type) + + +# Wrapper over _transform_recursive providing fake types +def transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. + """ + transformed = _transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +def _get_annotated_type(type_: type) -> type | None: + """If the given type is an `Annotated` type then it is returned, if not `None` is returned. + + This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]` + """ + if is_required_type(type_): + # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]` + type_ = get_args(type_)[0] + + if is_annotated_type(type_): + return type_ + + return None + + +def _maybe_transform_key(key: str, type_: type) -> str: + """Transform the given `data` based on the annotations provided in `type_`. + + Note: this function only looks at `Annotated` types that contain `PropertInfo` metadata. + """ + annotated_type = _get_annotated_type(type_) + if annotated_type is None: + # no `Annotated` definition for this type, no transformation needed + return key + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.alias is not None: + return annotation.alias + + return key + + +def _transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. + + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. + """ + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + if is_typeddict(stripped_type) and is_mapping(data): + return _transform_typeddict(data, stripped_type) + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): + inner_type = extract_type_arg(stripped_type, 0) + return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. + for subtype in get_args(stripped_type): + data = _transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True) + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return _format_data(data, annotation.format, annotation.format_template) + + return data + + +def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, date | datetime): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = data.read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +def _transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) + return result + + +async def async_maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `async_transform()` that allows `None` to be passed. + + See `async_transform()` for more details. + """ + if data is None: + return None + return await async_transform(data, expected_type) + + +async def async_transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. + """ + transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +async def _async_transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. + + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. + """ + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + if is_typeddict(stripped_type) and is_mapping(data): + return await _async_transform_typeddict(data, stripped_type) + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): + inner_type = extract_type_arg(stripped_type, 0) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. + for subtype in get_args(stripped_type): + data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True) + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return await _async_format_data(data, annotation.format, annotation.format_template) + + return data + + +async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, date | datetime): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = await anyio.Path(data).read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +async def _async_transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) + return result diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_typing.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_typing.py new file mode 100644 index 00000000000000..c7c54dcc37458d --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_typing.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from collections import abc as _c_abc +from collections.abc import Iterable +from typing import Annotated, Any, TypeVar, cast, get_args, get_origin + +from typing_extensions import Required + +from .._base_compat import is_union as _is_union +from .._base_type import InheritsGeneric + + +def is_annotated_type(typ: type) -> bool: + return get_origin(typ) == Annotated + + +def is_list_type(typ: type) -> bool: + return (get_origin(typ) or typ) == list + + +def is_iterable_type(typ: type) -> bool: + """If the given type is `typing.Iterable[T]`""" + origin = get_origin(typ) or typ + return origin in {Iterable, _c_abc.Iterable} + + +def is_union_type(typ: type) -> bool: + return _is_union(get_origin(typ)) + + +def is_required_type(typ: type) -> bool: + return get_origin(typ) == Required + + +def is_typevar(typ: type) -> bool: + # type ignore is required because type checkers + # think this expression will always return False + return type(typ) == TypeVar # type: ignore + + +# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] +def strip_annotated_type(typ: type) -> type: + if is_required_type(typ) or is_annotated_type(typ): + return strip_annotated_type(cast(type, get_args(typ)[0])) + + return typ + + +def extract_type_arg(typ: type, index: int) -> type: + args = get_args(typ) + try: + return cast(type, args[index]) + except IndexError as err: + raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err + + +def extract_type_var_from_base( + typ: type, + *, + generic_bases: tuple[type, ...], + index: int, + failure_message: str | None = None, +) -> type: + """Given a type like `Foo[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(Foo[bytes]): + ... + + extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes + ``` + + And where a generic subclass is given: + ```py + _T = TypeVar('_T') + class MyResponse(Foo[_T]): + ... + + extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes + ``` + """ + cls = cast(object, get_origin(typ) or typ) + if cls in generic_bases: + # we're given the class directly + return extract_type_arg(typ, index) + + # if a subclass is given + # --- + # this is needed as __orig_bases__ is not present in the typeshed stubs + # because it is intended to be for internal use only, however there does + # not seem to be a way to resolve generic TypeVars for inherited subclasses + # without using it. + if isinstance(cls, InheritsGeneric): + target_base_class: Any | None = None + for base in cls.__orig_bases__: + if base.__origin__ in generic_bases: + target_base_class = base + break + + if target_base_class is None: + raise RuntimeError( + "Could not find the generic base class;\n" + "This should never happen;\n" + f"Does {cls} inherit from one of {generic_bases} ?" + ) + + extracted = extract_type_arg(target_base_class, index) + if is_typevar(extracted): + # If the extracted type argument is itself a type variable + # then that means the subclass itself is generic, so we have + # to resolve the type argument from the class itself, not + # the base class. + # + # Note: if there is more than 1 type argument, the subclass could + # change the ordering of the type arguments, this is not currently + # supported. + return extract_type_arg(typ, index) + + return extracted + + raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_utils.py new file mode 100644 index 00000000000000..ce5e7786aa2937 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils/_utils.py @@ -0,0 +1,409 @@ +from __future__ import annotations + +import functools +import inspect +import os +import re +from collections.abc import Callable, Iterable, Mapping, Sequence +from pathlib import Path +from typing import ( + Any, + TypeGuard, + TypeVar, + Union, + cast, + overload, +) + +import sniffio + +from .._base_compat import parse_date as parse_date # noqa: PLC0414 +from .._base_compat import parse_datetime as parse_datetime # noqa: PLC0414 +from .._base_type import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr + + +def remove_notgiven_indict(obj): + if obj is None or (not isinstance(obj, Mapping)): + return obj + return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} + + +_T = TypeVar("_T") +_TupleT = TypeVar("_TupleT", bound=tuple[object, ...]) +_MappingT = TypeVar("_MappingT", bound=Mapping[str, object]) +_SequenceT = TypeVar("_SequenceT", bound=Sequence[object]) +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: + return [item for sublist in t for item in sublist] + + +def extract_files( + # TODO: this needs to take Dict but variance issues..... + # create protocol type ? + query: Mapping[str, object], + *, + paths: Sequence[Sequence[str]], +) -> list[tuple[str, FileTypes]]: + """Recursively extract files from the given dictionary based on specified paths. + + A path may look like this ['foo', 'files', '', 'data']. + + Note: this mutates the given dictionary. + """ + files: list[tuple[str, FileTypes]] = [] + for path in paths: + files.extend(_extract_items(query, path, index=0, flattened_key=None)) + return files + + +def _extract_items( + obj: object, + path: Sequence[str], + *, + index: int, + flattened_key: str | None, +) -> list[tuple[str, FileTypes]]: + try: + key = path[index] + except IndexError: + if isinstance(obj, NotGiven): + # no value was provided - we can safely ignore + return [] + + # cyclical import + from .._files import assert_is_file_content + + # We have exhausted the path, return the entry we found. + assert_is_file_content(obj, key=flattened_key) + assert flattened_key is not None + return [(flattened_key, cast(FileTypes, obj))] + + index += 1 + if is_dict(obj): + try: + # We are at the last entry in the path so we must remove the field + if (len(path)) == index: + item = obj.pop(key) + else: + item = obj[key] + except KeyError: + # Key was not present in the dictionary, this is not indicative of an error + # as the given path may not point to a required field. We also do not want + # to enforce required fields as the API may differ from the spec in some cases. + return [] + if flattened_key is None: + flattened_key = key + else: + flattened_key += f"[{key}]" + return _extract_items( + item, + path, + index=index, + flattened_key=flattened_key, + ) + elif is_list(obj): + if key != "": + return [] + + return flatten( + [ + _extract_items( + item, + path, + index=index, + flattened_key=flattened_key + "[]" if flattened_key is not None else "[]", + ) + for item in obj + ] + ) + + # Something unexpected was passed, just ignore it. + return [] + + +def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]: + return not isinstance(obj, NotGiven) + + +# Type safe methods for narrowing types with TypeVars. +# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], +# however this cause Pyright to rightfully report errors. As we know we don't +# care about the contained types we can safely use `object` in it's place. +# +# There are two separate functions defined, `is_*` and `is_*_t` for different use cases. +# `is_*` is for when you're dealing with an unknown input +# `is_*_t` is for when you're narrowing a known union type to a specific subset + + +def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]: + return isinstance(obj, tuple) + + +def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]: + return isinstance(obj, tuple) + + +def is_sequence(obj: object) -> TypeGuard[Sequence[object]]: + return isinstance(obj, Sequence) + + +def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]: + return isinstance(obj, Sequence) + + +def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]: + return isinstance(obj, Mapping) + + +def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]: + return isinstance(obj, Mapping) + + +def is_dict(obj: object) -> TypeGuard[dict[object, object]]: + return isinstance(obj, dict) + + +def is_list(obj: object) -> TypeGuard[list[object]]: + return isinstance(obj, list) + + +def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: + return isinstance(obj, Iterable) + + +def deepcopy_minimal(item: _T) -> _T: + """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: + + - mappings, e.g. `dict` + - list + + This is done for performance reasons. + """ + if is_mapping(item): + return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()}) + if is_list(item): + return cast(_T, [deepcopy_minimal(entry) for entry in item]) + return item + + +# copied from https://github.com/Rapptz/RoboDanny +def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: + size = len(seq) + if size == 0: + return "" + + if size == 1: + return seq[0] + + if size == 2: + return f"{seq[0]} {final} {seq[1]}" + + return delim.join(seq[:-1]) + f" {final} {seq[-1]}" + + +def quote(string: str) -> str: + """Add single quotation marks around the given string. Does *not* do any escaping.""" + return f"'{string}'" + + +def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: + """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. + + Useful for enforcing runtime validation of overloaded functions. + + Example usage: + ```py + @overload + def foo(*, a: str) -> str: + ... + + + @overload + def foo(*, b: bool) -> str: + ... + + + # This enforces the same constraints that a static type checker would + # i.e. that either a or b must be passed to the function + @required_args(["a"], ["b"]) + def foo(*, a: str | None = None, b: bool | None = None) -> str: + ... + ``` + """ + + def inner(func: CallableT) -> CallableT: + params = inspect.signature(func).parameters + positional = [ + name + for name, param in params.items() + if param.kind + in { + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + } + ] + + @functools.wraps(func) + def wrapper(*args: object, **kwargs: object) -> object: + given_params: set[str] = set() + for i, _ in enumerate(args): + try: + given_params.add(positional[i]) + except IndexError: + raise TypeError( + f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" + ) from None + + given_params.update(kwargs.keys()) + + for variant in variants: + matches = all(param in given_params for param in variant) + if matches: + break + else: # no break + if len(variants) > 1: + variations = human_join( + ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] + ) + msg = f"Missing required arguments; Expected either {variations} arguments to be given" + else: + # TODO: this error message is not deterministic + missing = list(set(variants[0]) - given_params) + if len(missing) > 1: + msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" + else: + msg = f"Missing required argument: {quote(missing[0])}" + raise TypeError(msg) + return func(*args, **kwargs) + + return wrapper # type: ignore + + return inner + + +_K = TypeVar("_K") +_V = TypeVar("_V") + + +@overload +def strip_not_given(obj: None) -> None: ... + + +@overload +def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ... + + +@overload +def strip_not_given(obj: object) -> object: ... + + +def strip_not_given(obj: object | None) -> object: + """Remove all top-level keys where their values are instances of `NotGiven`""" + if obj is None: + return None + + if not is_mapping(obj): + return obj + + return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} + + +def coerce_integer(val: str) -> int: + return int(val, base=10) + + +def coerce_float(val: str) -> float: + return float(val) + + +def coerce_boolean(val: str) -> bool: + return val in {"true", "1", "on"} + + +def maybe_coerce_integer(val: str | None) -> int | None: + if val is None: + return None + return coerce_integer(val) + + +def maybe_coerce_float(val: str | None) -> float | None: + if val is None: + return None + return coerce_float(val) + + +def maybe_coerce_boolean(val: str | None) -> bool | None: + if val is None: + return None + return coerce_boolean(val) + + +def removeprefix(string: str, prefix: str) -> str: + """Remove a prefix from a string. + + Backport of `str.removeprefix` for Python < 3.9 + """ + if string.startswith(prefix): + return string[len(prefix) :] + return string + + +def removesuffix(string: str, suffix: str) -> str: + """Remove a suffix from a string. + + Backport of `str.removesuffix` for Python < 3.9 + """ + if string.endswith(suffix): + return string[: -len(suffix)] + return string + + +def file_from_path(path: str) -> FileTypes: + contents = Path(path).read_bytes() + file_name = os.path.basename(path) + return (file_name, contents) + + +def get_required_header(headers: HeadersLike, header: str) -> str: + lower_header = header.lower() + if isinstance(headers, Mapping): + headers = cast(Headers, headers) + for k, v in headers.items(): + if k.lower() == lower_header and isinstance(v, str): + return v + + """ to deal with the case where the header looks like Stainless-Event-Id """ + intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) + + for normalized_header in [header, lower_header, header.upper(), intercaps_header]: + value = headers.get(normalized_header) + if value: + return value + + raise ValueError(f"Could not find {header} header") + + +def get_async_library() -> str: + try: + return sniffio.current_async_library() + except Exception: + return "false" + + +def drop_prefix_image_data(content: Union[str, list[dict]]) -> Union[str, list[dict]]: + """ + 删除 ;base64, 前缀 + :param image_data: + :return: + """ + if isinstance(content, list): + for data in content: + if data.get("type") == "image_url": + image_data = data.get("image_url").get("url") + if image_data.startswith("data:image/"): + image_data = image_data.split("base64,")[-1] + data["image_url"]["url"] = image_data + + return content diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/logs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/logs.py new file mode 100644 index 00000000000000..e5fce94c00e9e0 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/logs.py @@ -0,0 +1,78 @@ +import logging +import os +import time + +logger = logging.getLogger(__name__) + + +class LoggerNameFilter(logging.Filter): + def filter(self, record): + # return record.name.startswith("loom_core") or record.name in "ERROR" or ( + # record.name.startswith("uvicorn.error") + # and record.getMessage().startswith("Uvicorn running on") + # ) + return True + + +def get_log_file(log_path: str, sub_dir: str): + """ + sub_dir should contain a timestamp. + """ + log_dir = os.path.join(log_path, sub_dir) + # Here should be creating a new directory each time, so `exist_ok=False` + os.makedirs(log_dir, exist_ok=False) + return os.path.join(log_dir, "zhipuai.log") + + +def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict: + # for windows, the path should be a raw string. + log_file_path = log_file_path.encode("unicode-escape").decode() if os.name == "nt" else log_file_path + log_level = log_level.upper() + config_dict = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "formatter": {"format": ("%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s")}, + }, + "filters": { + "logger_name_filter": { + "()": __name__ + ".LoggerNameFilter", + }, + }, + "handlers": { + "stream_handler": { + "class": "logging.StreamHandler", + "formatter": "formatter", + "level": log_level, + # "stream": "ext://sys.stdout", + # "filters": ["logger_name_filter"], + }, + "file_handler": { + "class": "logging.handlers.RotatingFileHandler", + "formatter": "formatter", + "level": log_level, + "filename": log_file_path, + "mode": "a", + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf8", + }, + }, + "loggers": { + "loom_core": { + "handlers": ["stream_handler", "file_handler"], + "level": log_level, + "propagate": False, + } + }, + "root": { + "level": log_level, + "handlers": ["stream_handler", "file_handler"], + }, + } + return config_dict + + +def get_timestamp_ms(): + t = time.time() + return int(round(t * 1000)) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/pagination.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/pagination.py new file mode 100644 index 00000000000000..7f0b1b91d98556 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/pagination.py @@ -0,0 +1,62 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Any, Generic, Optional, TypeVar, cast + +from typing_extensions import Protocol, override, runtime_checkable + +from ._http_client import BasePage, BaseSyncPage, PageInfo + +__all__ = ["SyncPage", "SyncCursorPage"] + +_T = TypeVar("_T") + + +@runtime_checkable +class CursorPageItem(Protocol): + id: Optional[str] + + +class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]): + """Note: no pagination actually occurs yet, this is for forwards-compatibility.""" + + data: list[_T] + object: str + + @override + def _get_page_items(self) -> list[_T]: + data = self.data + if not data: + return [] + return data + + @override + def next_page_info(self) -> None: + """ + This page represents a response that isn't actually paginated at the API level + so there will never be a next page. + """ + return None + + +class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]): + data: list[_T] + + @override + def _get_page_items(self) -> list[_T]: + data = self.data + if not data: + return [] + return data + + @override + def next_page_info(self) -> Optional[PageInfo]: + data = self.data + if not data: + return None + + item = cast(Any, data[-1]) + if not isinstance(item, CursorPageItem) or item.id is None: + # TODO emit warning log + return None + + return PageInfo(params={"after": item.id}) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/__init__.py new file mode 100644 index 00000000000000..9f941fb91c8776 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/__init__.py @@ -0,0 +1,5 @@ +from .assistant_completion import AssistantCompletion + +__all__ = [ + "AssistantCompletion", +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_completion.py new file mode 100644 index 00000000000000..cbfb6edaeb1f19 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_completion.py @@ -0,0 +1,40 @@ +from typing import Any, Optional + +from ...core import BaseModel +from .message import MessageContent + +__all__ = ["AssistantCompletion", "CompletionUsage"] + + +class ErrorInfo(BaseModel): + code: str # 错误码 + message: str # 错误信息 + + +class AssistantChoice(BaseModel): + index: int # 结果下标 + delta: MessageContent # 当前会话输出消息体 + finish_reason: str + """ + # 推理结束原因 stop代表推理自然结束或触发停止词。 sensitive 代表模型推理内容被安全审核接口拦截。请注意,针对此类内容,请用户自行判断并决定是否撤回已公开的内容。 + # network_error 代表模型推理服务异常。 + """ # noqa: E501 + metadata: dict # 元信息,拓展字段 + + +class CompletionUsage(BaseModel): + prompt_tokens: int # 输入的 tokens 数量 + completion_tokens: int # 输出的 tokens 数量 + total_tokens: int # 总 tokens 数量 + + +class AssistantCompletion(BaseModel): + id: str # 请求 ID + conversation_id: str # 会话 ID + assistant_id: str # 智能体 ID + created: int # 请求创建时间,Unix 时间戳 + status: str # 返回状态,包括:`completed` 表示生成结束`in_progress`表示生成中 `failed` 表示生成异常 + last_error: Optional[ErrorInfo] # 异常信息 + choices: list[AssistantChoice] # 增量返回的信息 + metadata: Optional[dict[str, Any]] # 元信息,拓展字段 + usage: Optional[CompletionUsage] # tokens 数量统计 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_params.py new file mode 100644 index 00000000000000..03f14f4238f37f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_params.py @@ -0,0 +1,7 @@ +from typing import TypedDict + + +class ConversationParameters(TypedDict, total=False): + assistant_id: str # 智能体 ID + page: int # 当前分页 + page_size: int # 分页数量 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_resp.py new file mode 100644 index 00000000000000..d1833d220a2e3b --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_conversation_resp.py @@ -0,0 +1,29 @@ +from ...core import BaseModel + +__all__ = ["ConversationUsageListResp"] + + +class Usage(BaseModel): + prompt_tokens: int # 用户输入的 tokens 数量 + completion_tokens: int # 模型输入的 tokens 数量 + total_tokens: int # 总 tokens 数量 + + +class ConversationUsage(BaseModel): + id: str # 会话 id + assistant_id: str # 智能体Assistant id + create_time: int # 创建时间 + update_time: int # 更新时间 + usage: Usage # 会话中 tokens 数量统计 + + +class ConversationUsageList(BaseModel): + assistant_id: str # 智能体id + has_more: bool # 是否还有更多页 + conversation_list: list[ConversationUsage] # 返回的 + + +class ConversationUsageListResp(BaseModel): + code: int + msg: str + data: ConversationUsageList diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_create_params.py new file mode 100644 index 00000000000000..2def1025cd2b33 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_create_params.py @@ -0,0 +1,32 @@ +from typing import Optional, TypedDict, Union + + +class AssistantAttachments: + file_id: str + + +class MessageTextContent: + type: str # 目前支持 type = text + text: str + + +MessageContent = Union[MessageTextContent] + + +class ConversationMessage(TypedDict): + """会话消息体""" + + role: str # 用户的输入角色,例如 'user' + content: list[MessageContent] # 会话消息体的内容 + + +class AssistantParameters(TypedDict, total=False): + """智能体参数类""" + + assistant_id: str # 智能体 ID + conversation_id: Optional[str] # 会话 ID,不传则创建新会话 + model: str # 模型名称,默认为 'GLM-4-Assistant' + stream: bool # 是否支持流式 SSE,需要传入 True + messages: list[ConversationMessage] # 会话消息体 + attachments: Optional[list[AssistantAttachments]] # 会话指定的文件,非必填 + metadata: Optional[dict] # 元信息,拓展字段,非必填 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_support_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_support_resp.py new file mode 100644 index 00000000000000..0709cdbcad25e1 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/assistant_support_resp.py @@ -0,0 +1,21 @@ +from ...core import BaseModel + +__all__ = ["AssistantSupportResp"] + + +class AssistantSupport(BaseModel): + assistant_id: str # 智能体的 Assistant id,用于智能体会话 + created_at: int # 创建时间 + updated_at: int # 更新时间 + name: str # 智能体名称 + avatar: str # 智能体头像 + description: str # 智能体描述 + status: str # 智能体状态,目前只有 publish + tools: list[str] # 智能体支持的工具名 + starter_prompts: list[str] # 智能体启动推荐的 prompt + + +class AssistantSupportResp(BaseModel): + code: int + msg: str + data: list[AssistantSupport] # 智能体列表 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/__init__.py new file mode 100644 index 00000000000000..562e0151e53b48 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/__init__.py @@ -0,0 +1,3 @@ +from .message_content import MessageContent + +__all__ = ["MessageContent"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/message_content.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/message_content.py new file mode 100644 index 00000000000000..6a1a438a6fe03d --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/message_content.py @@ -0,0 +1,13 @@ +from typing import Annotated, TypeAlias, Union + +from ....core._utils import PropertyInfo +from .text_content_block import TextContentBlock +from .tools_delta_block import ToolsDeltaBlock + +__all__ = ["MessageContent"] + + +MessageContent: TypeAlias = Annotated[ + Union[ToolsDeltaBlock, TextContentBlock], + PropertyInfo(discriminator="type"), +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/text_content_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/text_content_block.py new file mode 100644 index 00000000000000..865fb1139e2f75 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/text_content_block.py @@ -0,0 +1,14 @@ +from typing import Literal + +from ....core import BaseModel + +__all__ = ["TextContentBlock"] + + +class TextContentBlock(BaseModel): + content: str + + role: str = "assistant" + + type: Literal["content"] = "content" + """Always `content`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/code_interpreter_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/code_interpreter_delta_block.py new file mode 100644 index 00000000000000..9d569b282ef9f7 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/code_interpreter_delta_block.py @@ -0,0 +1,27 @@ +from typing import Literal + +__all__ = ["CodeInterpreterToolBlock"] + +from .....core import BaseModel + + +class CodeInterpreterToolOutput(BaseModel): + """代码工具输出结果""" + + type: str # 代码执行日志,目前只有 logs + logs: str # 代码执行的日志结果 + error_msg: str # 错误信息 + + +class CodeInterpreter(BaseModel): + """代码解释器""" + + input: str # 生成的代码片段,输入给代码沙盒 + outputs: list[CodeInterpreterToolOutput] # 代码执行后的输出结果 + + +class CodeInterpreterToolBlock(BaseModel): + """代码工具块""" + + code_interpreter: CodeInterpreter # 代码解释器对象 + type: Literal["code_interpreter"] # 调用工具的类型,始终为 `code_interpreter` diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/drawing_tool_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/drawing_tool_delta_block.py new file mode 100644 index 00000000000000..0b6895556b6164 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/drawing_tool_delta_block.py @@ -0,0 +1,21 @@ +from typing import Literal + +from .....core import BaseModel + +__all__ = ["DrawingToolBlock"] + + +class DrawingToolOutput(BaseModel): + image: str + + +class DrawingTool(BaseModel): + input: str + outputs: list[DrawingToolOutput] + + +class DrawingToolBlock(BaseModel): + drawing_tool: DrawingTool + + type: Literal["drawing_tool"] + """Always `drawing_tool`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/function_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/function_delta_block.py new file mode 100644 index 00000000000000..c439bc4b3fbbb8 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/function_delta_block.py @@ -0,0 +1,22 @@ +from typing import Literal, Union + +__all__ = ["FunctionToolBlock"] + +from .....core import BaseModel + + +class FunctionToolOutput(BaseModel): + content: str + + +class FunctionTool(BaseModel): + name: str + arguments: Union[str, dict] + outputs: list[FunctionToolOutput] + + +class FunctionToolBlock(BaseModel): + function: FunctionTool + + type: Literal["function"] + """Always `drawing_tool`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/retrieval_delta_black.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/retrieval_delta_black.py new file mode 100644 index 00000000000000..4789e9378a8a39 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/retrieval_delta_black.py @@ -0,0 +1,41 @@ +from typing import Literal + +from .....core import BaseModel + + +class RetrievalToolOutput(BaseModel): + """ + This class represents the output of a retrieval tool. + + Attributes: + - text (str): The text snippet retrieved from the knowledge base. + - document (str): The name of the document from which the text snippet was retrieved, returned only in intelligent configuration. + """ # noqa: E501 + + text: str + document: str + + +class RetrievalTool(BaseModel): + """ + This class represents the outputs of a retrieval tool. + + Attributes: + - outputs (List[RetrievalToolOutput]): A list of text snippets and their respective document names retrieved from the knowledge base. + """ # noqa: E501 + + outputs: list[RetrievalToolOutput] + + +class RetrievalToolBlock(BaseModel): + """ + This class represents a block for invoking the retrieval tool. + + Attributes: + - retrieval (RetrievalTool): An instance of the RetrievalTool class containing the retrieval outputs. + - type (Literal["retrieval"]): The type of tool being used, always set to "retrieval". + """ + + retrieval: RetrievalTool + type: Literal["retrieval"] + """Always `retrieval`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/tools_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/tools_type.py new file mode 100644 index 00000000000000..98544053d4c83a --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/tools_type.py @@ -0,0 +1,16 @@ +from typing import Annotated, TypeAlias, Union + +from .....core._utils import PropertyInfo +from .code_interpreter_delta_block import CodeInterpreterToolBlock +from .drawing_tool_delta_block import DrawingToolBlock +from .function_delta_block import FunctionToolBlock +from .retrieval_delta_black import RetrievalToolBlock +from .web_browser_delta_block import WebBrowserToolBlock + +__all__ = ["ToolsType"] + + +ToolsType: TypeAlias = Annotated[ + Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock], + PropertyInfo(discriminator="type"), +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/web_browser_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/web_browser_delta_block.py new file mode 100644 index 00000000000000..966e6fe0c84fef --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools/web_browser_delta_block.py @@ -0,0 +1,48 @@ +from typing import Literal + +from .....core import BaseModel + +__all__ = ["WebBrowserToolBlock"] + + +class WebBrowserOutput(BaseModel): + """ + This class represents the output of a web browser search result. + + Attributes: + - title (str): The title of the search result. + - link (str): The URL link to the search result's webpage. + - content (str): The textual content extracted from the search result. + - error_msg (str): Any error message encountered during the search or retrieval process. + """ + + title: str + link: str + content: str + error_msg: str + + +class WebBrowser(BaseModel): + """ + This class represents the input and outputs of a web browser search. + + Attributes: + - input (str): The input query for the web browser search. + - outputs (List[WebBrowserOutput]): A list of search results returned by the web browser. + """ + + input: str + outputs: list[WebBrowserOutput] + + +class WebBrowserToolBlock(BaseModel): + """ + This class represents a block for invoking the web browser tool. + + Attributes: + - web_browser (WebBrowser): An instance of the WebBrowser class containing the search input and outputs. + - type (Literal["web_browser"]): The type of tool being used, always set to "web_browser". + """ + + web_browser: WebBrowser + type: Literal["web_browser"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools_delta_block.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools_delta_block.py new file mode 100644 index 00000000000000..781a1ab819c286 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/assistant/message/tools_delta_block.py @@ -0,0 +1,16 @@ +from typing import Literal + +from ....core import BaseModel +from .tools.tools_type import ToolsType + +__all__ = ["ToolsDeltaBlock"] + + +class ToolsDeltaBlock(BaseModel): + tool_calls: list[ToolsType] + """The index of the content part in the message.""" + + role: str = "tool" + + type: Literal["tool_calls"] = "tool_calls" + """Always `tool_calls`.""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch.py new file mode 100644 index 00000000000000..560562915c9d32 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch.py @@ -0,0 +1,82 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +import builtins +from typing import Literal, Optional + +from ..core import BaseModel +from .batch_error import BatchError +from .batch_request_counts import BatchRequestCounts + +__all__ = ["Batch", "Errors"] + + +class Errors(BaseModel): + data: Optional[list[BatchError]] = None + + object: Optional[str] = None + """这个类型,一直是`list`。""" + + +class Batch(BaseModel): + id: str + + completion_window: str + """用于执行请求的地址信息。""" + + created_at: int + """这是 Unix timestamp (in seconds) 表示的创建时间。""" + + endpoint: str + """这是ZhipuAI endpoint的地址。""" + + input_file_id: str + """标记为batch的输入文件的ID。""" + + object: Literal["batch"] + """这个类型,一直是`batch`.""" + + status: Literal[ + "validating", "failed", "in_progress", "finalizing", "completed", "expired", "cancelling", "cancelled" + ] + """batch 的状态。""" + + cancelled_at: Optional[int] = None + """Unix timestamp (in seconds) 表示的取消时间。""" + + cancelling_at: Optional[int] = None + """Unix timestamp (in seconds) 表示发起取消的请求时间 """ + + completed_at: Optional[int] = None + """Unix timestamp (in seconds) 表示的完成时间。""" + + error_file_id: Optional[str] = None + """这个文件id包含了执行请求失败的请求的输出。""" + + errors: Optional[Errors] = None + + expired_at: Optional[int] = None + """Unix timestamp (in seconds) 表示的将在过期时间。""" + + expires_at: Optional[int] = None + """Unix timestamp (in seconds) 触发过期""" + + failed_at: Optional[int] = None + """Unix timestamp (in seconds) 表示的失败时间。""" + + finalizing_at: Optional[int] = None + """Unix timestamp (in seconds) 表示的最终时间。""" + + in_progress_at: Optional[int] = None + """Unix timestamp (in seconds) 表示的开始处理时间。""" + + metadata: Optional[builtins.object] = None + """ + key:value形式的元数据,以便将信息存储 + 结构化格式。键的长度是64个字符,值最长512个字符 + """ + + output_file_id: Optional[str] = None + """完成请求的输出文件的ID。""" + + request_counts: Optional[BatchRequestCounts] = None + """批次中不同状态的请求计数""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_create_params.py new file mode 100644 index 00000000000000..3dae65ea46fcbe --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_create_params.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Literal, Optional + +from typing_extensions import Required, TypedDict + +__all__ = ["BatchCreateParams"] + + +class BatchCreateParams(TypedDict, total=False): + completion_window: Required[str] + """The time frame within which the batch should be processed. + + Currently only `24h` is supported. + """ + + endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]] + """The endpoint to be used for all requests in the batch. + + Currently `/v1/chat/completions` and `/v1/embeddings` are supported. + """ + + input_file_id: Required[str] + """The ID of an uploaded file that contains requests for the new batch. + + See [upload file](https://platform.openai.com/docs/api-reference/files/create) + for how to upload a file. + + Your input file must be formatted as a + [JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput), + and must be uploaded with the purpose `batch`. + """ + + metadata: Optional[dict[str, str]] + """Optional custom metadata for the batch.""" + + auto_delete_input_file: Optional[bool] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_error.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_error.py new file mode 100644 index 00000000000000..f934db19781e41 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_error.py @@ -0,0 +1,21 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional + +from ..core import BaseModel + +__all__ = ["BatchError"] + + +class BatchError(BaseModel): + code: Optional[str] = None + """定义的业务错误码""" + + line: Optional[int] = None + """文件中的行号""" + + message: Optional[str] = None + """关于对话文件中的错误的描述""" + + param: Optional[str] = None + """参数名称,如果有的话""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_list_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_list_params.py new file mode 100644 index 00000000000000..1a681671320eca --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_list_params.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["BatchListParams"] + + +class BatchListParams(TypedDict, total=False): + after: str + """分页的游标,用于获取下一页的数据。 + + `after` 是一个指向当前页面的游标,用于获取下一页的数据。如果没有提供 `after`,则返回第一页的数据。 + list. + """ + + limit: int + """这个参数用于限制返回的结果数量。 + + Limit 用于限制返回的结果数量。默认值为 10 + """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_request_counts.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_request_counts.py new file mode 100644 index 00000000000000..ca3ccae625052b --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/batch_request_counts.py @@ -0,0 +1,14 @@ +from ..core import BaseModel + +__all__ = ["BatchRequestCounts"] + + +class BatchRequestCounts(BaseModel): + completed: int + """这个数字表示已经完成的请求。""" + + failed: int + """这个数字表示失败的请求。""" + + total: int + """这个数字表示总的请求。""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index a0645b09168821..c1eed070f32d9f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -1,10 +1,9 @@ from typing import Optional -from pydantic import BaseModel - +from ...core import BaseModel from .chat_completion import CompletionChoice, CompletionUsage -__all__ = ["AsyncTaskStatus"] +__all__ = ["AsyncTaskStatus", "AsyncCompletion"] class AsyncTaskStatus(BaseModel): diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index 4b3a929a2b816d..1945a826cda2d0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel +from ...core import BaseModel __all__ = ["Completion", "CompletionUsage"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py index c2506997419815..27fad0008a1dd4 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import Any, Optional -from pydantic import BaseModel +from ...core import BaseModel __all__ = [ + "CompletionUsage", "ChatCompletionChunk", "Choice", "ChoiceDelta", @@ -53,3 +54,4 @@ class ChatCompletionChunk(BaseModel): created: Optional[int] = None model: Optional[str] = None usage: Optional[CompletionUsage] = None + extra_json: dict[str, Any] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/code_geex/code_geex_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/code_geex/code_geex_params.py new file mode 100644 index 00000000000000..666b38855cd637 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/code_geex/code_geex_params.py @@ -0,0 +1,146 @@ +from typing import Literal, Optional + +from typing_extensions import Required, TypedDict + +__all__ = [ + "CodeGeexTarget", + "CodeGeexContext", + "CodeGeexExtra", +] + + +class CodeGeexTarget(TypedDict, total=False): + """补全的内容参数""" + + path: Optional[str] + """文件路径""" + language: Required[ + Literal[ + "c", + "c++", + "cpp", + "c#", + "csharp", + "c-sharp", + "css", + "cuda", + "dart", + "lua", + "objectivec", + "objective-c", + "objective-c++", + "python", + "perl", + "prolog", + "swift", + "lisp", + "java", + "scala", + "tex", + "jsx", + "tsx", + "vue", + "markdown", + "html", + "php", + "js", + "javascript", + "typescript", + "go", + "shell", + "rust", + "sql", + "kotlin", + "vb", + "ruby", + "pascal", + "r", + "fortran", + "lean", + "matlab", + "delphi", + "scheme", + "basic", + "assembly", + "groovy", + "abap", + "gdscript", + "haskell", + "julia", + "elixir", + "excel", + "clojure", + "actionscript", + "solidity", + "powershell", + "erlang", + "cobol", + "alloy", + "awk", + "thrift", + "sparql", + "augeas", + "cmake", + "f-sharp", + "stan", + "isabelle", + "dockerfile", + "rmarkdown", + "literate-agda", + "tcl", + "glsl", + "antlr", + "verilog", + "racket", + "standard-ml", + "elm", + "yaml", + "smalltalk", + "ocaml", + "idris", + "visual-basic", + "protocol-buffer", + "bluespec", + "applescript", + "makefile", + "tcsh", + "maple", + "systemverilog", + "literate-coffeescript", + "vhdl", + "restructuredtext", + "sas", + "literate-haskell", + "java-server-pages", + "coffeescript", + "emacs-lisp", + "mathematica", + "xslt", + "zig", + "common-lisp", + "stata", + "agda", + "ada", + ] + ] + """代码语言类型,如python""" + code_prefix: Required[str] + """补全位置的前文""" + code_suffix: Required[str] + """补全位置的后文""" + + +class CodeGeexContext(TypedDict, total=False): + """附加代码""" + + path: Required[str] + """附加代码文件的路径""" + code: Required[str] + """附加的代码内容""" + + +class CodeGeexExtra(TypedDict, total=False): + target: Required[CodeGeexTarget] + """补全的内容参数""" + contexts: Optional[list[CodeGeexContext]] + """附加代码""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py index e01f2c815fb382..8425b5c86688dd 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py @@ -2,8 +2,7 @@ from typing import Optional -from pydantic import BaseModel - +from ..core import BaseModel from .chat.chat_completion import CompletionUsage __all__ = ["Embedding", "EmbeddingsResponded"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/__init__.py new file mode 100644 index 00000000000000..bbaf59e4d7d17a --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/__init__.py @@ -0,0 +1,5 @@ +from .file_deleted import FileDeleted +from .file_object import FileObject, ListOfFileObject +from .upload_detail import UploadDetail + +__all__ = ["FileObject", "ListOfFileObject", "UploadDetail", "FileDeleted"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_create_params.py new file mode 100644 index 00000000000000..4ef93b1c05acae --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_create_params.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import Literal, Optional + +from typing_extensions import Required, TypedDict + +__all__ = ["FileCreateParams"] + +from ...core import FileTypes +from . import UploadDetail + + +class FileCreateParams(TypedDict, total=False): + file: FileTypes + """file和 upload_detail二选一必填""" + + upload_detail: list[UploadDetail] + """file和 upload_detail二选一必填""" + + purpose: Required[Literal["fine-tune", "retrieval", "batch"]] + """ + 上传文件的用途,支持 "fine-tune和 "retrieval" + retrieval支持上传Doc、Docx、PDF、Xlsx、URL类型文件,且单个文件的大小不超过 5MB。 + fine-tune支持上传.jsonl文件且当前单个文件的大小最大可为 100 MB ,文件中语料格式需满足微调指南中所描述的格式。 + """ + custom_separator: Optional[list[str]] + """ + 当 purpose 为 retrieval 且文件类型为 pdf, url, docx 时上传,切片规则默认为 `\n`。 + """ + knowledge_id: str + """ + 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。 + """ + + sentence_size: int + """ + 当文件上传目的为 retrieval 时,需要指定知识库ID进行上传。 + """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_deleted.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_deleted.py new file mode 100644 index 00000000000000..a384b1a69a5735 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_deleted.py @@ -0,0 +1,13 @@ +from typing import Literal + +from ...core import BaseModel + +__all__ = ["FileDeleted"] + + +class FileDeleted(BaseModel): + id: str + + deleted: bool + + object: Literal["file"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_object.py similarity index 86% rename from api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py rename to api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_object.py index 75f76fe969faf7..8f9d0fbb8e6ce3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/file_object.py @@ -1,8 +1,8 @@ from typing import Optional -from pydantic import BaseModel +from ...core import BaseModel -__all__ = ["FileObject"] +__all__ = ["FileObject", "ListOfFileObject"] class FileObject(BaseModel): diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/upload_detail.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/upload_detail.py new file mode 100644 index 00000000000000..8f1ca5ce5756aa --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/files/upload_detail.py @@ -0,0 +1,13 @@ +from typing import Optional + +from ...core import BaseModel + + +class UploadDetail(BaseModel): + url: str + knowledge_type: int + file_name: Optional[str] = None + sentence_size: Optional[int] = None + custom_separator: Optional[list[str]] = None + callback_url: Optional[str] = None + callback_header: Optional[dict[str, str]] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 1d3930286b89d3..75c7553dbe35c6 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -1,6 +1,6 @@ from typing import Optional, Union -from pydantic import BaseModel +from ...core import BaseModel __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py index e26b448534246f..f996cff11430b0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py @@ -1,6 +1,6 @@ from typing import Optional, Union -from pydantic import BaseModel +from ...core import BaseModel __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/__init__.py new file mode 100644 index 00000000000000..57d0d2511dbc14 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/__init__.py @@ -0,0 +1 @@ +from .fine_tuned_models import FineTunedModelsStatus diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/fine_tuned_models.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/fine_tuned_models.py new file mode 100644 index 00000000000000..b286a5b5774d3d --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/models/fine_tuned_models.py @@ -0,0 +1,13 @@ +from typing import ClassVar + +from ....core import PYDANTIC_V2, BaseModel, ConfigDict + +__all__ = ["FineTunedModelsStatus"] + + +class FineTunedModelsStatus(BaseModel): + if PYDANTIC_V2: + model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow", protected_namespaces=()) + request_id: str # 请求id + model_name: str # 模型名称 + delete_status: str # 删除状态 deleting(删除中), deleted (已删除) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py index b352ce0954ad55..3bcad0acabd215 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py @@ -2,7 +2,7 @@ from typing import Optional -from pydantic import BaseModel +from ..core import BaseModel __all__ = ["GeneratedImage", "ImagesResponded"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/__init__.py new file mode 100644 index 00000000000000..8c81d703e214a3 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/__init__.py @@ -0,0 +1,8 @@ +from .knowledge import KnowledgeInfo +from .knowledge_used import KnowledgeStatistics, KnowledgeUsed + +__all__ = [ + "KnowledgeInfo", + "KnowledgeStatistics", + "KnowledgeUsed", +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/__init__.py new file mode 100644 index 00000000000000..32e23e6dab3076 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/__init__.py @@ -0,0 +1,8 @@ +from .document import DocumentData, DocumentFailedInfo, DocumentObject, DocumentSuccessinfo + +__all__ = [ + "DocumentData", + "DocumentObject", + "DocumentSuccessinfo", + "DocumentFailedInfo", +] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document.py new file mode 100644 index 00000000000000..b9a1646391ece8 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document.py @@ -0,0 +1,51 @@ +from typing import Optional + +from ....core import BaseModel + +__all__ = ["DocumentData", "DocumentObject", "DocumentSuccessinfo", "DocumentFailedInfo"] + + +class DocumentSuccessinfo(BaseModel): + documentId: Optional[str] = None + """文件id""" + filename: Optional[str] = None + """文件名称""" + + +class DocumentFailedInfo(BaseModel): + failReason: Optional[str] = None + """上传失败的原因,包括:文件格式不支持、文件大小超出限制、知识库容量已满、容量上限为 50 万字。""" + filename: Optional[str] = None + """文件名称""" + documentId: Optional[str] = None + """知识库id""" + + +class DocumentObject(BaseModel): + """文档信息""" + + successInfos: Optional[list[DocumentSuccessinfo]] = None + """上传成功的文件信息""" + failedInfos: Optional[list[DocumentFailedInfo]] = None + """上传失败的文件信息""" + + +class DocumentDataFailInfo(BaseModel): + """失败原因""" + + embedding_code: Optional[int] = ( + None # 失败码 10001:知识不可用,知识库空间已达上限 10002:知识不可用,知识库空间已达上限(字数超出限制) + ) + embedding_msg: Optional[str] = None # 失败原因 + + +class DocumentData(BaseModel): + id: str = None # 知识唯一id + custom_separator: list[str] = None # 切片规则 + sentence_size: str = None # 切片大小 + length: int = None # 文件大小(字节) + word_num: int = None # 文件字数 + name: str = None # 文件名 + url: str = None # 文件下载链接 + embedding_stat: int = None # 0:向量化中 1:向量化完成 2:向量化失败 + failInfo: Optional[DocumentDataFailInfo] = None # 失败原因 向量化失败embedding_stat=2的时候 会有此值 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_edit_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_edit_params.py new file mode 100644 index 00000000000000..509cb3a451af5f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_edit_params.py @@ -0,0 +1,29 @@ +from typing import Optional, TypedDict + +__all__ = ["DocumentEditParams"] + + +class DocumentEditParams(TypedDict): + """ + 知识参数类型定义 + + Attributes: + id (str): 知识ID + knowledge_type (int): 知识类型: + 1:文章知识: 支持pdf,url,docx + 2.问答知识-文档: 支持pdf,url,docx + 3.问答知识-表格: 支持xlsx + 4.商品库-表格: 支持xlsx + 5.自定义: 支持pdf,url,docx + custom_separator (Optional[List[str]]): 当前知识类型为自定义(knowledge_type=5)时的切片规则,默认\n + sentence_size (Optional[int]): 当前知识类型为自定义(knowledge_type=5)时的切片字数,取值范围: 20-2000,默认300 + callback_url (Optional[str]): 回调地址 + callback_header (Optional[dict]): 回调时携带的header + """ + + id: str + knowledge_type: int + custom_separator: Optional[list[str]] + sentence_size: Optional[int] + callback_url: Optional[str] + callback_header: Optional[dict[str, str]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_params.py new file mode 100644 index 00000000000000..910c8c045e1b97 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_params.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Optional + +from typing_extensions import TypedDict + + +class DocumentListParams(TypedDict, total=False): + """ + 文件查询参数类型定义 + + Attributes: + purpose (Optional[str]): 文件用途 + knowledge_id (Optional[str]): 当文件用途为 retrieval 时,需要提供查询的知识库ID + page (Optional[int]): 页,默认1 + limit (Optional[int]): 查询文件列表数,默认10 + after (Optional[str]): 查询指定fileID之后的文件列表(当文件用途为 fine-tune 时需要) + order (Optional[str]): 排序规则,可选值['desc', 'asc'],默认desc(当文件用途为 fine-tune 时需要) + """ + + purpose: Optional[str] + knowledge_id: Optional[str] + page: Optional[int] + limit: Optional[int] + after: Optional[str] + order: Optional[str] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_resp.py new file mode 100644 index 00000000000000..acae4fad9ff36b --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/document/document_list_resp.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from ....core import BaseModel +from . import DocumentData + +__all__ = ["DocumentPage"] + + +class DocumentPage(BaseModel): + list: list[DocumentData] + object: str diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge.py new file mode 100644 index 00000000000000..bc6f159eb211e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge.py @@ -0,0 +1,21 @@ +from typing import Optional + +from ...core import BaseModel + +__all__ = ["KnowledgeInfo"] + + +class KnowledgeInfo(BaseModel): + id: Optional[str] = None + """知识库唯一 id""" + embedding_id: Optional[str] = ( + None # 知识库绑定的向量化模型 见模型列表 [内部服务开放接口文档](https://lslfd0slxc.feishu.cn/docx/YauWdbBiMopV0FxB7KncPWCEn8f#H15NduiQZo3ugmxnWQFcfAHpnQ4) + ) + name: Optional[str] = None # 知识库名称 100字限制 + customer_identifier: Optional[str] = None # 用户标识 长度32位以内 + description: Optional[str] = None # 知识库描述 500字限制 + background: Optional[str] = None # 背景颜色(给枚举)'blue', 'red', 'orange', 'purple', 'sky' + icon: Optional[str] = ( + None # 知识库图标(给枚举) question: 问号、book: 书籍、seal: 印章、wrench: 扳手、tag: 标签、horn: 喇叭、house: 房子 # noqa: E501 + ) + bucket_id: Optional[str] = None # 桶id 限制32位 diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_create_params.py new file mode 100644 index 00000000000000..c3da201727c34a --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_create_params.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Literal, Optional + +from typing_extensions import TypedDict + +__all__ = ["KnowledgeBaseParams"] + + +class KnowledgeBaseParams(TypedDict): + """ + 知识库参数类型定义 + + Attributes: + embedding_id (int): 知识库绑定的向量化模型ID + name (str): 知识库名称,限制100字 + customer_identifier (Optional[str]): 用户标识,长度32位以内 + description (Optional[str]): 知识库描述,限制500字 + background (Optional[Literal['blue', 'red', 'orange', 'purple', 'sky']]): 背景颜色 + icon (Optional[Literal['question', 'book', 'seal', 'wrench', 'tag', 'horn', 'house']]): 知识库图标 + bucket_id (Optional[str]): 桶ID,限制32位 + """ + + embedding_id: int + name: str + customer_identifier: Optional[str] + description: Optional[str] + background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None + icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None + bucket_id: Optional[str] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_params.py new file mode 100644 index 00000000000000..a221b28e4603be --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_params.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["KnowledgeListParams"] + + +class KnowledgeListParams(TypedDict, total=False): + page: int = 1 + """ 页码,默认 1,第一页 + """ + + size: int = 10 + """每页数量 默认10 + """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_resp.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_resp.py new file mode 100644 index 00000000000000..e462eddc550d61 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_list_resp.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from ...core import BaseModel +from . import KnowledgeInfo + +__all__ = ["KnowledgePage"] + + +class KnowledgePage(BaseModel): + list: list[KnowledgeInfo] + object: str diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_used.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_used.py new file mode 100644 index 00000000000000..cfda7097026c59 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/knowledge/knowledge_used.py @@ -0,0 +1,21 @@ +from typing import Optional + +from ...core import BaseModel + +__all__ = ["KnowledgeStatistics", "KnowledgeUsed"] + + +class KnowledgeStatistics(BaseModel): + """ + 使用量统计 + """ + + word_num: Optional[int] = None + length: Optional[int] = None + + +class KnowledgeUsed(BaseModel): + used: Optional[KnowledgeStatistics] = None + """已使用量""" + total: Optional[KnowledgeStatistics] = None + """知识库总量""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/__init__.py new file mode 100644 index 00000000000000..c9bd60419ce606 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/__init__.py @@ -0,0 +1,3 @@ +from .sensitive_word_check import SensitiveWordCheckRequest + +__all__ = ["SensitiveWordCheckRequest"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/sensitive_word_check.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/sensitive_word_check.py new file mode 100644 index 00000000000000..0c37d99e653292 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/sensitive_word_check/sensitive_word_check.py @@ -0,0 +1,14 @@ +from typing import Optional + +from typing_extensions import TypedDict + + +class SensitiveWordCheckRequest(TypedDict, total=False): + type: Optional[str] + """敏感词类型,当前仅支持ALL""" + status: Optional[str] + """敏感词启用禁用状态 + 启用:ENABLE + 禁用:DISABLE + 备注:默认开启敏感词校验,如果要关闭敏感词校验,需联系商务获取对应权限,否则敏感词禁用不生效。 + """ diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/__init__.py new file mode 100644 index 00000000000000..62f77344eee56b --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/__init__.py @@ -0,0 +1,9 @@ +from .web_search import ( + SearchIntent, + SearchRecommend, + SearchResult, + WebSearch, +) +from .web_search_chunk import WebSearchChunk + +__all__ = ["WebSearch", "SearchIntent", "SearchResult", "SearchRecommend", "WebSearchChunk"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/tools_web_search_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/tools_web_search_params.py new file mode 100644 index 00000000000000..b3a3b26f07ee58 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/tools_web_search_params.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Optional, Union + +from typing_extensions import TypedDict + +__all__ = ["WebSearchParams"] + + +class WebSearchParams(TypedDict): + """ + 工具名:web-search-pro参数类型定义 + + Attributes: + :param model: str, 模型名称 + :param request_id: Optional[str], 请求ID + :param stream: Optional[bool], 是否流式 + :param messages: Union[str, List[str], List[int], object, None], + 包含历史对话上下文的内容,按照 {"role": "user", "content": "你好"} 的json 数组形式进行传参 + 当前版本仅支持 User Message 单轮对话,工具会理解User Message并进行搜索, + 请尽可能传入不带指令格式的用户原始提问,以提高搜索准确率。 + :param scope: Optional[str], 指定搜索范围,全网、学术等,默认全网 + :param location: Optional[str], 指定搜索用户地区 location 提高相关性 + :param recent_days: Optional[int],支持指定返回 N 天(1-30)更新的搜索结果 + + + """ + + model: str + request_id: Optional[str] + stream: Optional[bool] + messages: Union[str, list[str], list[int], object, None] + scope: Optional[str] = None + location: Optional[str] = None + recent_days: Optional[int] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search.py new file mode 100644 index 00000000000000..ac9fa3821e979b --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search.py @@ -0,0 +1,71 @@ +from typing import Optional + +from ...core import BaseModel + +__all__ = [ + "WebSearch", + "SearchIntent", + "SearchResult", + "SearchRecommend", +] + + +class SearchIntent(BaseModel): + index: int + # 搜索轮次,默认为 0 + query: str + # 搜索优化 query + intent: str + # 判断的意图类型 + keywords: str + # 搜索关键词 + + +class SearchResult(BaseModel): + index: int + # 搜索轮次,默认为 0 + title: str + # 标题 + link: str + # 链接 + content: str + # 内容 + icon: str + # 图标 + media: str + # 来源媒体 + refer: str + # 角标序号 [ref_1] + + +class SearchRecommend(BaseModel): + index: int + # 搜索轮次,默认为 0 + query: str + # 推荐query + + +class WebSearchMessageToolCall(BaseModel): + id: str + search_intent: Optional[SearchIntent] + search_result: Optional[SearchResult] + search_recommend: Optional[SearchRecommend] + type: str + + +class WebSearchMessage(BaseModel): + role: str + tool_calls: Optional[list[WebSearchMessageToolCall]] = None + + +class WebSearchChoice(BaseModel): + index: int + finish_reason: str + message: WebSearchMessage + + +class WebSearch(BaseModel): + created: Optional[int] = None + choices: list[WebSearchChoice] + request_id: Optional[str] = None + id: Optional[str] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search_chunk.py new file mode 100644 index 00000000000000..7fb0e02bb58719 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/tools/web_search_chunk.py @@ -0,0 +1,33 @@ +from typing import Optional + +from ...core import BaseModel +from .web_search import SearchIntent, SearchRecommend, SearchResult + +__all__ = ["WebSearchChunk"] + + +class ChoiceDeltaToolCall(BaseModel): + index: int + id: Optional[str] = None + + search_intent: Optional[SearchIntent] = None + search_result: Optional[SearchResult] = None + search_recommend: Optional[SearchRecommend] = None + type: Optional[str] = None + + +class ChoiceDelta(BaseModel): + role: Optional[str] = None + tool_calls: Optional[list[ChoiceDeltaToolCall]] = None + + +class Choice(BaseModel): + delta: ChoiceDelta + finish_reason: Optional[str] = None + index: int + + +class WebSearchChunk(BaseModel): + id: Optional[str] = None + choices: list[Choice] + created: Optional[int] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/__init__.py new file mode 100644 index 00000000000000..b14072b1a771af --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/__init__.py @@ -0,0 +1,3 @@ +from .video_object import VideoObject, VideoResult + +__all__ = ["VideoObject", "VideoResult"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_create_params.py new file mode 100644 index 00000000000000..f5489d708e7227 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_create_params.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import Optional + +from typing_extensions import TypedDict + +__all__ = ["VideoCreateParams"] + +from ..sensitive_word_check import SensitiveWordCheckRequest + + +class VideoCreateParams(TypedDict, total=False): + model: str + """模型编码""" + prompt: str + """所需视频的文本描述""" + image_url: str + """所需视频的文本描述""" + sensitive_word_check: Optional[SensitiveWordCheckRequest] + """支持 URL 或者 Base64、传入 image 奖进行图生视频 + * 图片格式: + * 图片大小:""" + request_id: str + """由用户端传参,需保证唯一性;用于区分每次请求的唯一标识,用户端不传时平台会默认生成。""" + + user_id: str + """用户端。""" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_object.py new file mode 100644 index 00000000000000..85c3844d8a791c --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/video/video_object.py @@ -0,0 +1,30 @@ +from typing import Optional + +from ...core import BaseModel + +__all__ = ["VideoObject", "VideoResult"] + + +class VideoResult(BaseModel): + url: str + """视频url""" + cover_image_url: str + """预览图""" + + +class VideoObject(BaseModel): + id: Optional[str] = None + """智谱 AI 开放平台生成的任务订单号,调用请求结果接口时请使用此订单号""" + + model: str + """模型名称""" + + video_result: list[VideoResult] + """视频生成结果""" + + task_status: str + """处理状态,PROCESSING(处理中),SUCCESS(成功),FAIL(失败) + 注:处理中状态需通过查询获取结果""" + + request_id: str + """用户在客户端请求时提交的任务编号或者平台生成的任务编号""" diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 9039708588df16..085084ca383552 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -21,15 +21,22 @@ def _invoke( ) size_mapping = { "square": "1024x1024", - "vertical": "1024x1792", - "horizontal": "1792x1024", + "vertical_768": "768x1344", + "vertical_864": "864x1152", + "horizontal_1344": "1344x768", + "horizontal_1152": "1152x864", + "widescreen_1440": "1440x720", + "tallscreen_720": "720x1440", } # prompt prompt = tool_parameters.get("prompt", "") if not prompt: return self.create_text_message("Please input prompt") - # get size - size = size_mapping[tool_parameters.get("size", "square")] + # get size key + size_key = tool_parameters.get("size", "square") + # cogview-3-plus get size + if size_key != "cogview_3": + size = size_mapping[size_key] # get n n = tool_parameters.get("n", 1) # get quality @@ -43,16 +50,29 @@ def _invoke( # set extra body seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) extra_body = {"seed": seed_id} - response = client.images.generations( - prompt=prompt, - model="cogview-3", - size=size, - n=n, - extra_body=extra_body, - style=style, - quality=quality, - response_format="b64_json", - ) + # cogview-3-plus + if size_key != "cogview_3": + response = client.images.generations( + prompt=prompt, + model="cogview-3-plus", + size=size, + n=n, + extra_body=extra_body, + style=style, + quality=quality, + response_format="b64_json", + ) + # cogview-3 + else: + response = client.images.generations( + prompt=prompt, + model="cogview-3", + n=n, + extra_body=extra_body, + style=style, + quality=quality, + response_format="b64_json", + ) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml index 1de3f599b6ac02..9ab5c2729bf7a9 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml @@ -42,21 +42,46 @@ parameters: pt_BR: Image size form: form options: + - value: cogview_3 + label: + en_US: Square_cogview_3(1024x1024) + zh_Hans: 方_cogview_3(1024x1024) + pt_BR: Square_cogview_3(1024x1024) - value: square label: - en_US: Squre(1024x1024) + en_US: Square(1024x1024) zh_Hans: 方(1024x1024) - pt_BR: Squre(1024x1024) - - value: vertical + pt_BR: Square(1024x1024) + - value: vertical_768 + label: + en_US: Vertical(768x1344) + zh_Hans: 竖屏(768x1344) + pt_BR: Vertical(768x1344) + - value: vertical_864 + label: + en_US: Vertical(864x1152) + zh_Hans: 竖屏(864x1152) + pt_BR: Vertical(864x1152) + - value: horizontal_1344 + label: + en_US: Horizontal(1344x768) + zh_Hans: 横屏(1344x768) + pt_BR: Horizontal(1344x768) + - value: horizontal_1152 + label: + en_US: Horizontal(1152x864) + zh_Hans: 横屏(1152x864) + pt_BR: Horizontal(1152x864) + - value: widescreen_1440 label: - en_US: Vertical(1024x1792) - zh_Hans: 竖屏(1024x1792) - pt_BR: Vertical(1024x1792) - - value: horizontal + en_US: Widescreen(1440x720) + zh_Hans: 宽屏(1440x720) + pt_BR: Widescreen(1440x720) + - value: tallscreen_720 label: - en_US: Horizontal(1792x1024) - zh_Hans: 横屏(1792x1024) - pt_BR: Horizontal(1792x1024) + en_US: Tallscreen(720x1440) + zh_Hans: 高屏(720x1440) + pt_BR: Tallscreen(720x1440) default: square - name: n type: number