Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify cattrs._compat.is_typeddict #384

Merged
merged 9 commits into from
Jun 14, 2023
36 changes: 5 additions & 31 deletions src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
from attr import fields as attrs_fields
from attr import resolve_types

__all__ = ["ExtensionsTypedDict", "is_typeddict", "TypedDict"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like __all__ was deleted from this file in #382, but I'm not entirely sure why. I've added it back here, as otherwise ruff complains (reasonably) that is_typeddict is imported but unused. The alternative would be to do from typing_extensions import is_typeddict as is_typeddict, but I find that syntax really ugly personally :p

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops, my bad. You might wanna add ExceptionGroup too which uses the x as x syntax currently.


try:
from typing_extensions import TypedDict as ExtensionsTypedDict
except ImportError:
ExtensionsTypedDict = None

try:
from typing_extensions import _TypedDictMeta as ExtensionsTypedDictMeta
from typing_extensions import is_typeddict
except ImportError:
ExtensionsTypedDictMeta = None
assert sys.version_info >= (3, 10)
from typing import is_typeddict

if sys.version_info >= (3, 8):
from typing import Final, Protocol, get_args, get_origin
Expand Down Expand Up @@ -157,7 +160,6 @@ def get_final_base(type) -> Optional[type]:
_AnnotatedAlias,
_GenericAlias,
_SpecialGenericAlias,
_TypedDictMeta,
_UnionGenericAlias,
)

Expand Down Expand Up @@ -234,20 +236,6 @@ def get_newtype_base(typ: Any) -> Optional[type]:
return supertype
return None

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down Expand Up @@ -462,20 +450,6 @@ def copy_with(type, args):
"""Replace a generic type's arguments."""
return type.copy_with(args)

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down