Skip to content

Commit

Permalink
squashed
Browse files Browse the repository at this point in the history
  • Loading branch information
ikonst committed Apr 13, 2023
1 parent 5005428 commit b32252f
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 25 deletions.
110 changes: 87 additions & 23 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

from __future__ import annotations

from typing import Iterable, List, cast
from collections import defaultdict
from functools import reduce
from typing import Iterable, List, Mapping, cast
from typing_extensions import Final, Literal

import mypy.plugin # To avoid circular imports.
from mypy.applytype import apply_generic_arguments
from mypy.checker import TypeChecker
from mypy.errorcodes import LITERAL_REQ
from mypy.expandtype import expand_type
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
from mypy.meet import meet_types
from mypy.messages import format_type_bare
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -67,6 +70,7 @@
Type,
TypeOfAny,
TypeVarType,
UninhabitedType,
UnionType,
get_proper_type,
)
Expand Down Expand Up @@ -943,12 +947,81 @@ def _get_attrs_init_type(typ: Instance) -> CallableType | None:
return init_method.type


def _get_attrs_cls_and_init(typ: ProperType) -> tuple[Instance | None, CallableType | None]:
def _format_not_attrs_class_failure(t: Type, parent_t: Type) -> str:
t_name = format_type_bare(t)
if parent_t is t:
return (
f'Argument 1 to "evolve" has a variable type "{t_name}" not bound to an attrs class'
if isinstance(t, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{t_name}"; expected an attrs class'
)
else:
pt_name = format_type_bare(parent_t)
return (
f'Argument 1 to "evolve" has type "{pt_name}" whose item "{t_name}" is not bound to an attrs class'
if isinstance(t, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{pt_name}" whose item "{t_name}" is not an attrs class'
)


def _get_expanded_attr_types(
ctx: mypy.plugin.FunctionSigContext,
typ: ProperType,
display_typ: ProperType,
parent_typ: ProperType,
) -> list[Mapping[str, Type]] | None:
"""
For a given type, determine what attrs classes it can be, and returns the field types for each class.
For generic classes, the field types are expanded.
If the type contains Any or a non-attrs type, returns None; in the latter case, also reports an error.
"""
if isinstance(typ, AnyType):
return None
if isinstance(typ, UnionType):
types = []
had_errors = False
for item in typ.relevant_items():
item = get_proper_type(item)
item_types = _get_expanded_attr_types(ctx, item, item, parent_typ)
if isinstance(item_types, list):
types += item_types
else:
had_errors = True
if had_errors:
return None
return types
if isinstance(typ, TypeVarType):
typ = get_proper_type(typ.upper_bound)
return _get_expanded_attr_types(
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
)
if not isinstance(typ, Instance):
return None, None
return typ, _get_attrs_init_type(typ)
ctx.api.fail(_format_not_attrs_class_failure(display_typ, parent_typ), ctx.context)
return None
init_func = _get_attrs_init_type(typ)
if init_func is None:
ctx.api.fail(_format_not_attrs_class_failure(display_typ, parent_typ), ctx.context)
return None
init_func = expand_type_by_instance(init_func, typ)
field_names = cast(List[str], init_func.arg_names[1:])
field_types = init_func.arg_types[1:]
return [dict(zip(field_names, field_types))]


def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]:
"""
"Meets" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound.
"""
field_to_types = defaultdict(list)
for fields in types:
for name, typ in fields.items():
field_to_types[name].append(typ)

return {
name: get_proper_type(reduce(meet_types, f_types))
if len(f_types) == len(types)
else UninhabitedType()
for name, f_types in field_to_types.items()
}


def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
Expand All @@ -972,27 +1045,18 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
# </hack>

inst_type = get_proper_type(inst_type)
if isinstance(inst_type, AnyType):
return ctx.default_signature # evolve(Any, ....) -> Any
inst_type_str = format_type_bare(inst_type)

attrs_type, attrs_init_type = _get_attrs_cls_and_init(inst_type)
if attrs_type is None or attrs_init_type is None:
ctx.api.fail(
f'Argument 1 to "evolve" has a variable type "{inst_type_str}" not bound to an attrs class'
if isinstance(inst_type, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class',
ctx.context,
)
attr_types = _get_expanded_attr_types(ctx, inst_type, inst_type, inst_type)
if attr_types is None:
return ctx.default_signature
fields = _meet_fields(attr_types)

# AttrClass.__init__ has the following signature (or similar, if having kw-only & defaults):
# def __init__(self, attr1: Type1, attr2: Type2) -> None:
# We want to generate a signature for evolve that looks like this:
# def evolve(inst: AttrClass, *, attr1: Type1 = ..., attr2: Type2 = ...) -> AttrClass:
return attrs_init_type.copy_modified(
arg_names=["inst"] + attrs_init_type.arg_names[1:],
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT for _ in attrs_init_type.arg_kinds[1:]],
return CallableType(
arg_names=["inst", *fields.keys()],
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT] * len(fields),
arg_types=[inst_type, *fields.values()],
ret_type=inst_type,
fallback=ctx.default_signature.fallback,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)
81 changes: 80 additions & 1 deletion test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,81 @@ reveal_type(ret) # N: Revealed type is "Any"

[typing fixtures/typing-medium.pyi]

[case testEvolveGeneric]
import attrs
from typing import Generic, TypeVar

T = TypeVar('T')

@attrs.define
class A(Generic[T]):
x: T


a = A(x=42)
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
a2 = attrs.evolve(a, x=42)
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
a2 = attrs.evolve(a, x='42') # E: Argument "x" to "evolve" of "A[int]" has incompatible type "str"; expected "int"
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"

[builtins fixtures/attr.pyi]

[case testEvolveUnion]
# flags: --python-version 3.10
from typing import Generic, TypeVar
import attrs

T = TypeVar('T')


@attrs.define
class A(Generic[T]):
x: T # exercises meet(T=int, int) = int
y: bool # exercises meet(bool, int) = bool
z: str # exercises meet(str, bytes) = <nothing>
w: dict # exercises meet(dict, <nothing>) = <nothing>


@attrs.define
class B:
x: int
y: bool
z: bytes


a_or_b: A[int] | B
a2 = attrs.evolve(a_or_b, x=42, y=True)
a2 = attrs.evolve(a_or_b, x=42, y=True, z='42') # E: Argument "z" to "evolve" of "Union[A[int], B]" has incompatible type "str"; expected <nothing>
a2 = attrs.evolve(a_or_b, x=42, y=True, w={}) # E: Argument "w" to "evolve" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>

[builtins fixtures/attr.pyi]

[case testEvolveUnionOfTypeVar]
# flags: --python-version 3.10
import attrs
from typing import TypeVar

@attrs.define
class A:
x: int
y: int
z: str
w: dict


class B:
pass

TA = TypeVar('TA', bound=A)
TB = TypeVar('TB', bound=B)

def f(b_or_t: TA | TB | int) -> None:
a2 = attrs.evolve(b_or_t) # E: Argument 1 to "evolve" has type "Union[TA, TB, int]" whose item "TB" is not bound to an attrs class # E: Argument 1 to "evolve" has incompatible type "Union[TA, TB, int]" whose item "int" is not an attrs class


[builtins fixtures/attr.pyi]

[case testEvolveTypeVarBound]
import attrs
from typing import TypeVar
Expand Down Expand Up @@ -1997,11 +2072,12 @@ f(B(x=42))

[case testEvolveTypeVarBoundNonAttrs]
import attrs
from typing import TypeVar
from typing import Union, TypeVar

TInt = TypeVar('TInt', bound=int)
TAny = TypeVar('TAny')
TNone = TypeVar('TNone', bound=None)
TUnion = TypeVar('TUnion', bound=Union[str, int])

def f(t: TInt) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TInt" not bound to an attrs class
Expand All @@ -2012,6 +2088,9 @@ def g(t: TAny) -> None:
def h(t: TNone) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TNone" not bound to an attrs class

def x(t: TUnion) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "str" is not an attrs class # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "int" is not an attrs class

[builtins fixtures/attr.pyi]

[case testEvolveTypeVarConstrained]
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/fixtures/attr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ class object:
class type: pass
class bytes: pass
class function: pass
class bool: pass
class float: pass
class int:
@overload
def __init__(self, x: Union[str, bytes, int] = ...) -> None: ...
@overload
def __init__(self, x: Union[str, bytes], base: int) -> None: ...
class bool(int): pass
class complex:
@overload
def __init__(self, real: float = ..., im: float = ...) -> None: ...
Expand Down

0 comments on commit b32252f

Please sign in to comment.