Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify cattrs._compat.is_typeddict #384

Merged
merged 9 commits into from
Jun 14, 2023
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Optimize and improve unstructuring of `Optional` (unions of one type and `None`).
([#380](https://github.com/python-attrs/cattrs/issues/380) [#381](https://github.com/python-attrs/cattrs/pull/381))
- Fix `format_exception` and `transform_error` type annotations.
- Improve the implementation of `cattrs._compat.is_typeddict`. The implementation is now simpler, and relies on fewer private implementation details from `typing` and typing_extensions. ([#384](https://github.com/python-attrs/cattrs/pull/384))

## 23.1.2 (2023-06-02)

Expand Down
54 changes: 16 additions & 38 deletions src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@
from attr import fields as attrs_fields
from attr import resolve_types

__all__ = ["ExceptionGroup", "ExtensionsTypedDict", "TypedDict", "is_typeddict"]

try:
from typing_extensions import TypedDict as ExtensionsTypedDict
except ImportError:
ExtensionsTypedDict = None

try:
from typing_extensions import _TypedDictMeta as ExtensionsTypedDictMeta
except ImportError:
ExtensionsTypedDictMeta = None

if sys.version_info >= (3, 8):
from typing import Final, Protocol, get_args, get_origin

Expand All @@ -44,9 +41,20 @@ def get_origin(cl):
from typing_extensions import Final, Protocol

if sys.version_info >= (3, 11):
ExceptionGroup = ExceptionGroup
from builtins import ExceptionGroup
else:
from exceptiongroup import ExceptionGroup as ExceptionGroup # noqa: PLC0414
from exceptiongroup import ExceptionGroup

try:
from typing_extensions import is_typeddict as _is_typeddict
except ImportError:
assert sys.version_info >= (3, 10)
from typing import is_typeddict as _is_typeddict


def is_typeddict(cls):
"""Thin wrapper around typing(_extensions).is_typeddict"""
return _is_typeddict(getattr(cls, "__origin__", cls))


def has(cls):
Expand Down Expand Up @@ -157,7 +165,6 @@ def get_final_base(type) -> Optional[type]:
_AnnotatedAlias,
_GenericAlias,
_SpecialGenericAlias,
_TypedDictMeta,
_UnionGenericAlias,
)

Expand Down Expand Up @@ -234,20 +241,6 @@ def get_newtype_base(typ: Any) -> Optional[type]:
return supertype
return None

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down Expand Up @@ -364,9 +357,8 @@ def copy_with(type, args):
from typing_extensions import get_origin as te_get_origin

if sys.version_info >= (3, 8):
from typing import TypedDict, _TypedDictMeta
from typing import TypedDict
else:
_TypedDictMeta = None
TypedDict = ExtensionsTypedDict

def is_annotated(type) -> bool:
Expand Down Expand Up @@ -462,20 +454,6 @@ def copy_with(type, args):
"""Replace a generic type's arguments."""
return type.copy_with(args)

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down