diff --git a/src/attr/__init__.pyi b/src/attr/__init__.pyi index 51d1af3a0..303ea7402 100644 --- a/src/attr/__init__.pyi +++ b/src/attr/__init__.pyi @@ -33,6 +33,11 @@ if sys.version_info >= (3, 10): else: from typing_extensions import TypeGuard +if sys.version_info >= (3, 11): + from typing import dataclass_transform +else: + from typing_extensions import dataclass_transform + __version__: str __version_info__: VersionInfo __title__: str @@ -93,7 +98,6 @@ if sys.version_info >= (3, 8): factory: Callable[[], _T], takes_self: Literal[False], ) -> _T: ... - else: @overload def Factory(factory: Callable[[], _T]) -> _T: ... @@ -103,23 +107,6 @@ else: takes_self: bool = ..., ) -> _T: ... -# Static type inference support via __dataclass_transform__ implemented as per: -# https://github.com/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md -# This annotation must be applied to all overloads of "define" and "attrs" -# -# NOTE: This is a typing construct and does not exist at runtime. Extensions -# wrapping attrs decorators should declare a separate __dataclass_transform__ -# signature in the extension module using the specification linked above to -# provide pyright support. -def __dataclass_transform__( - *, - eq_default: bool = True, - order_default: bool = False, - kw_only_default: bool = False, - frozen_default: bool = False, - field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()), -) -> Callable[[_T], _T]: ... - class Attribute(Generic[_T]): name: str default: Optional[_T] @@ -322,7 +309,7 @@ def field( type: Optional[type] = ..., ) -> Any: ... @overload -@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field)) +@dataclass_transform(order_default=True, field_descriptors=(attrib, field)) def attrs( maybe_cls: _C, these: Optional[Dict[str, Any]] = ..., @@ -350,7 +337,7 @@ def attrs( unsafe_hash: Optional[bool] = ..., ) -> _C: ... @overload -@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field)) +@dataclass_transform(order_default=True, field_descriptors=(attrib, field)) def attrs( maybe_cls: None = ..., these: Optional[Dict[str, Any]] = ..., @@ -378,7 +365,7 @@ def attrs( unsafe_hash: Optional[bool] = ..., ) -> Callable[[_C], _C]: ... @overload -@__dataclass_transform__(field_descriptors=(attrib, field)) +@dataclass_transform(field_descriptors=(attrib, field)) def define( maybe_cls: _C, *, @@ -404,7 +391,7 @@ def define( match_args: bool = ..., ) -> _C: ... @overload -@__dataclass_transform__(field_descriptors=(attrib, field)) +@dataclass_transform(field_descriptors=(attrib, field)) def define( maybe_cls: None = ..., *, @@ -433,7 +420,7 @@ def define( mutable = define @overload -@__dataclass_transform__( +@dataclass_transform( frozen_default=True, field_descriptors=(attrib, field) ) def frozen( @@ -461,7 +448,7 @@ def frozen( match_args: bool = ..., ) -> _C: ... @overload -@__dataclass_transform__( +@dataclass_transform( frozen_default=True, field_descriptors=(attrib, field) ) def frozen( diff --git a/tests/test_pyright.py b/tests/test_pyright.py index fc2d31350..1463a0f89 100644 --- a/tests/test_pyright.py +++ b/tests/test_pyright.py @@ -41,8 +41,8 @@ def parse_pyright_output(test_file: Path) -> set[PyrightDiagnostic]: def test_pyright_baseline(): """ - The __dataclass_transform__ decorator allows pyright to determine attrs - decorated class types. + The typing.dataclass_transform() decorator allows pyright to determine + attrs decorated class types. """ test_file = Path(__file__).parent / "dataclass_transform_example.py"