Skip to content

Commit

Permalink
Switched from __dataclass_transform__() to typing.dataclass_transform()
Browse files Browse the repository at this point in the history
Fixes #1157
  • Loading branch information
superbobry committed Jun 28, 2023
1 parent 261d26e commit 401be1c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 26 deletions.
35 changes: 11 additions & 24 deletions src/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ if sys.version_info >= (3, 10):
else:
from typing_extensions import TypeGuard

if sys.version_info >= (3, 11):
from typing import dataclass_transform
else:
from typing_extensions import dataclass_transform

__version__: str
__version_info__: VersionInfo
__title__: str
Expand Down Expand Up @@ -93,7 +98,6 @@ if sys.version_info >= (3, 8):
factory: Callable[[], _T],
takes_self: Literal[False],
) -> _T: ...

else:
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
Expand All @@ -103,23 +107,6 @@ else:
takes_self: bool = ...,
) -> _T: ...

# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "define" and "attrs"
#
# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support.
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
frozen_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...

class Attribute(Generic[_T]):
name: str
default: Optional[_T]
Expand Down Expand Up @@ -322,7 +309,7 @@ def field(
type: Optional[type] = ...,
) -> Any: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
@dataclass_transform(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: _C,
these: Optional[Dict[str, Any]] = ...,
Expand Down Expand Up @@ -350,7 +337,7 @@ def attrs(
unsafe_hash: Optional[bool] = ...,
) -> _C: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
@dataclass_transform(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: None = ...,
these: Optional[Dict[str, Any]] = ...,
Expand Down Expand Up @@ -378,7 +365,7 @@ def attrs(
unsafe_hash: Optional[bool] = ...,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
@dataclass_transform(field_descriptors=(attrib, field))
def define(
maybe_cls: _C,
*,
Expand All @@ -404,7 +391,7 @@ def define(
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
@dataclass_transform(field_descriptors=(attrib, field))
def define(
maybe_cls: None = ...,
*,
Expand Down Expand Up @@ -433,7 +420,7 @@ def define(
mutable = define

@overload
@__dataclass_transform__(
@dataclass_transform(
frozen_default=True, field_descriptors=(attrib, field)
)
def frozen(
Expand Down Expand Up @@ -461,7 +448,7 @@ def frozen(
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(
@dataclass_transform(
frozen_default=True, field_descriptors=(attrib, field)
)
def frozen(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def parse_pyright_output(test_file: Path) -> set[PyrightDiagnostic]:

def test_pyright_baseline():
"""
The __dataclass_transform__ decorator allows pyright to determine attrs
decorated class types.
The typing.dataclass_transform() decorator allows pyright to determine
attrs decorated class types.
"""

test_file = Path(__file__).parent / "dataclass_transform_example.py"
Expand Down

0 comments on commit 401be1c

Please sign in to comment.