Skip to content

Commit

Permalink
refactor(model_runtime/zhipuai_sdk): syntax style format
Browse files Browse the repository at this point in the history
  • Loading branch information
ox01024 committed Sep 13, 2024
1 parent e16ab1b commit 1a0f708
Show file tree
Hide file tree
Showing 13 changed files with 383 additions and 460 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
"ModelT",
"Query",
"FileTypes",

"PYDANTIC_V2",
"ConfigDict",
"GenericModel",
Expand All @@ -81,9 +80,7 @@
"get_model_config",
"get_model_fields",
"field_get_default",

"is_file_content",

"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
Expand All @@ -94,23 +91,18 @@
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",

"make_request_options",
"HttpClient",
"ZHIPUAI_DEFAULT_TIMEOUT",
"ZHIPUAI_DEFAULT_MAX_RETRIES",
"ZHIPUAI_DEFAULT_LIMITS",

"is_list",
"is_mapping",
"parse_date",
"parse_datetime",
"is_given",
"maybe_transform",

"deepcopy_minimal",
"extract_files",

"StreamResponse",

]
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,19 @@
# v1 re-exports
if TYPE_CHECKING:

def parse_date(value: date | StrBytesIntFloat) -> date:
...
def parse_date(value: date | StrBytesIntFloat) -> date: ...

def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...

def get_args(t: type[Any]) -> tuple[Any, ...]:
...
def get_args(t: type[Any]) -> tuple[Any, ...]: ...

def is_union(tp: type[Any] | None) -> bool:
...
def is_union(tp: type[Any] | None) -> bool: ...

def get_origin(t: type[Any]) -> type[Any] | None:
...
def get_origin(t: type[Any]) -> type[Any] | None: ...

def is_literal_type(type_: type[Any]) -> bool:
...
def is_literal_type(type_: type[Any]) -> bool: ...

def is_typeddict(type_: type[Any]) -> bool:
...
def is_typeddict(type_: type[Any]) -> bool: ...

else:
if PYDANTIC_V2:
Expand Down Expand Up @@ -178,22 +171,19 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
# generic models
if TYPE_CHECKING:

class GenericModel(pydantic.BaseModel):
...
class GenericModel(pydantic.BaseModel): ...

else:
if PYDANTIC_V2:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel):
...
class GenericModel(pydantic.BaseModel): ...

else:
import pydantic.generics

class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):
...
class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...


# cached properties
Expand All @@ -212,26 +202,21 @@ class TypedCachedProperty(Generic[_T]):
func: Callable[[Any], _T]
attrname: str | None

def __init__(self, func: Callable[[Any], _T]) -> None:
...
def __init__(self, func: Callable[[Any], _T]) -> None: ...

@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self:
...
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...

@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
...
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...

def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()

def __set_name__(self, owner: type[Any], name: str) -> None:
...
def __set_name__(self, owner: type[Any], name: str) -> None: ...

# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None:
...
def __set__(self, instance: object, value: _T) -> None: ...
else:
try:
from functools import cached_property as cached_property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore

def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
Expand Down Expand Up @@ -117,14 +117,14 @@ def to_dict(
)

def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
Expand Down Expand Up @@ -161,9 +161,9 @@ def __str__(self) -> str:
@classmethod
@override
def construct(
cls: type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
cls: type[ModelT],
_fields_set: set[str] | None = None,
**values: object,
) -> ModelT:
m = cls.__new__(cls)
fields_values: dict[str, object] = {}
Expand Down Expand Up @@ -229,19 +229,19 @@ def construct(

@override
def model_dump(
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
self,
*,
mode: Literal["json", "python"] | str = "python",
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> dict[str, Any]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump
Expand Down Expand Up @@ -284,19 +284,19 @@ def model_dump(

@override
def model_dump_json(
self,
*,
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
self,
*,
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal["none", "warn", "error"] = True,
context: dict[str, Any] | None = None,
serialize_as_any: bool = False,
) -> str:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json
Expand Down Expand Up @@ -364,9 +364,9 @@ def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericMo


def build(
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
base_model_cls: Callable[P, _BaseModelT],
*args: P.args,
**kwargs: P.kwargs,
) -> _BaseModelT:
"""Construct a BaseModel class without validation.
Expand Down Expand Up @@ -534,11 +534,11 @@ class Foo(BaseModel):
"""

def __init__(
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
self,
*,
mapping: dict[str, type],
discriminator_field: str,
discriminator_alias: str | None,
) -> None:
self.mapping = mapping
self.field_name = discriminator_field
Expand Down Expand Up @@ -580,8 +580,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
if isinstance(entry, str):
mapping[entry] = variant
else:
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(
discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
if not field_info:
continue

Expand Down Expand Up @@ -640,6 +639,7 @@ def validate_type(*, type_: type[_T], value: object) -> _T:
class GenericModel(BaseGenericModel, BaseModel):
pass


if PYDANTIC_V2:
from pydantic import TypeAdapter

Expand All @@ -650,14 +650,14 @@ def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:

class TypeAdapter(Generic[_T]):
"""Used as a placeholder to easily convert runtime types to a Pydantic format
to provide validation.
For example:
```py
validated = RootModel[int](__root__="5").__root__
# validated: 5
```
"""
to provide validation.
For example:
```py
validated = RootModel[int](__root__="5").__root__
# validated: 5
```
"""

def __init__(self, type_: type[_T]):
self.type_ = type_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,18 @@ def __bool__(self) -> Literal[False]:
class ModelBuilderProtocol(Protocol):
@classmethod
def build(
cls: type[_T],
*,
response: Response,
data: object,
) -> _T:
...
cls: type[_T],
*,
response: Response,
data: object,
) -> _T: ...


Headers = Mapping[str, Union[str, Omit]]


class HeadersLikeProtocol(Protocol):
def get(self, __key: str) -> str | None:
...
def get(self, __key: str) -> str | None: ...


HeadersLike = Union[Headers, HeadersLikeProtocol]
Expand Down Expand Up @@ -147,11 +145,11 @@ class HttpxSendArgs(TypedDict, total=False):
FileTypes = Union[
# file (or bytes)
FileContent,
# (filename, file (or bytes))
# (filename, file (or bytes))
tuple[Optional[str], FileContent],
# (filename, file (or bytes), content_type)
# (filename, file (or bytes), content_type)
tuple[Optional[str], FileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]]
Expand All @@ -161,11 +159,11 @@ class HttpxSendArgs(TypedDict, total=False):
HttpxFileTypes = Union[
# file (or bytes)
HttpxFileContent,
# (filename, file (or bytes))
# (filename, file (or bytes))
tuple[Optional[str], HttpxFileContent],
# (filename, file (or bytes), content_type)
# (filename, file (or bytes), content_type)
tuple[Optional[str], HttpxFileContent, Optional[str]],
# (filename, file (or bytes), content_type, headers)
# (filename, file (or bytes), content_type, headers)
tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]],
]

Expand Down
Loading

0 comments on commit 1a0f708

Please sign in to comment.