Skip to content

Commit

Permalink
Add improved union support from python#15050
Browse files Browse the repository at this point in the history
  • Loading branch information
ikonst committed Apr 21, 2023
1 parent 9b491f5 commit 283fe3d
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 43 deletions.
128 changes: 100 additions & 28 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.meet import meet_types
from mypy.messages import format_type_bare
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -57,10 +58,13 @@
Instance,
LiteralType,
NoneType,
ProperType,
TupleType,
Type,
TypeOfAny,
TypeVarType,
UninhabitedType,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars
Expand Down Expand Up @@ -372,7 +376,6 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
name=f"replace of {self._cls.info.name}",
)

self._cls.info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
Expand Down Expand Up @@ -923,6 +926,91 @@ def _has_direct_dataclass_transform_metaclass(info: TypeInfo) -> bool:
)


def _fail_not_dataclass(ctx: FunctionSigContext, t: Type, parent_t: Type) -> None:
t_name = format_type_bare(t, ctx.api.options)
if parent_t is t:
msg = (
f'Argument 1 to "replace" has a variable type "{t_name}" not bound to a dataclass'
if isinstance(t, TypeVarType)
else f'Argument 1 to "replace" has incompatible type "{t_name}"; expected a dataclass'
)
else:
pt_name = format_type_bare(parent_t, ctx.api.options)
msg = (
f'Argument 1 to "replace" has type "{pt_name}" whose item "{t_name}" is not bound to a dataclass'
if isinstance(t, TypeVarType)
else f'Argument 1 to "replace" has incompatible type "{pt_name}" whose item "{t_name}" is not a dataclass'
)

ctx.api.fail(msg, ctx.context)


def _get_expanded_dataclasses_fields(
ctx: FunctionSigContext, typ: ProperType, display_typ: ProperType, parent_typ: ProperType
) -> list[CallableType] | None:
"""
For a given type, determine what dataclasses it can be: for each class, return the field types.
For generic classes, the field types are expanded.
If the type contains Any or a non-dataclass, returns None; in the latter case, also reports an error.
"""
if isinstance(typ, AnyType):
return None
elif isinstance(typ, UnionType):
ret: list[CallableType] | None = []
for item in typ.relevant_items():
item = get_proper_type(item)
item_types = _get_expanded_dataclasses_fields(ctx, item, item, parent_typ)
if ret is not None and item_types is not None:
ret += item_types
else:
ret = None # but keep iterating to emit all errors
return ret
elif isinstance(typ, TypeVarType):
return _get_expanded_dataclasses_fields(
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
)
elif isinstance(typ, Instance):
replace_sym = typ.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
if replace_sym is None:
_fail_not_dataclass(ctx, display_typ, parent_typ)
return None
replace_sig = get_proper_type(replace_sym.type)
assert isinstance(replace_sig, CallableType)
return [expand_type_by_instance(replace_sig, typ)]
else:
_fail_not_dataclass(ctx, display_typ, parent_typ)
return None


def _meet_replace_sigs(sigs: list[CallableType]) -> CallableType:
"""
Produces the lowest bound of the 'replace' signatures of multiple dataclasses.
"""
args = {
name: (typ, kind)
for name, typ, kind in zip(sigs[0].arg_names, sigs[0].arg_types, sigs[0].arg_kinds)
}

for sig in sigs[1:]:
sig_args = {
name: (typ, kind)
for name, typ, kind in zip(sig.arg_names, sig.arg_types, sig.arg_kinds)
}
for name in (*args.keys(), *sig_args.keys()):
sig_typ, sig_kind = args.get(name, (UninhabitedType(), ARG_NAMED_OPT))
sig2_typ, sig2_kind = sig_args.get(name, (UninhabitedType(), ARG_NAMED_OPT))
args[name] = (
meet_types(sig_typ, sig2_typ),
ARG_NAMED_OPT if sig_kind == sig2_kind == ARG_NAMED_OPT else ARG_NAMED,
)

return sigs[0].copy_modified(
arg_names=list(args.keys()),
arg_types=[typ for typ, _ in args.values()],
arg_kinds=[kind for _, kind in args.values()],
)


def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
"""
Returns a signature for the 'dataclasses.replace' function that's dependent on the type
Expand All @@ -946,34 +1034,18 @@ def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
# </hack>

obj_type = get_proper_type(obj_type)
obj_type_str = format_type_bare(obj_type)
if isinstance(obj_type, AnyType):
return ctx.default_signature # replace(Any, ...) -> Any
inst_type_str = format_type_bare(obj_type, ctx.api.options)

dataclass_type = get_proper_type(
obj_type.upper_bound if isinstance(obj_type, TypeVarType) else obj_type
)
replace_func = None
if isinstance(dataclass_type, Instance):
replace_func = dataclass_type.type.get_method(_INTERNAL_REPLACE_SYM_NAME)
if replace_func is None:
ctx.api.fail(
f'Argument 1 to "replace" has variable type "{obj_type_str}" not bound to a dataclass'
if isinstance(obj_type, TypeVarType)
else f'Argument 1 to "replace" has incompatible type "{obj_type_str}"; expected a dataclass',
ctx.context,
)
replace_sigs = _get_expanded_dataclasses_fields(ctx, obj_type, obj_type, obj_type)
if replace_sigs is None:
return ctx.default_signature
assert isinstance(dataclass_type, Instance)

signature = get_proper_type(replace_func.type)
assert isinstance(signature, CallableType)
signature = expand_type_by_instance(signature, dataclass_type)
# re-add the instance type
return signature.copy_modified(
arg_types=[obj_type, *signature.arg_types],
arg_kinds=[ARG_POS, *signature.arg_kinds],
arg_names=[None, *signature.arg_names],
replace_sig = _meet_replace_sigs(replace_sigs)

return replace_sig.copy_modified(
arg_names=[None, *replace_sig.arg_names],
arg_kinds=[ARG_POS, *replace_sig.arg_kinds],
arg_types=[obj_type, *replace_sig.arg_types],
ret_type=obj_type,
name=f"{ctx.default_signature.name} of {obj_type_str}",
fallback=ctx.default_signature.fallback,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)
89 changes: 74 additions & 15 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -2059,39 +2059,101 @@ a2 = replace(a, x='42', q=42) # E: Argument "x" to "replace" of "A" has incompa
a2 = replace(a, q='42') # E: Argument "q" to "replace" of "A" has incompatible type "str"; expected "int"
reveal_type(a2) # N: Revealed type is "__main__.A"

[case testReplaceUnion]
# flags: --strict-optional
from typing import Generic, Union, TypeVar
from dataclasses import dataclass, replace, InitVar

T = TypeVar('T')

@dataclass
class A(Generic[T]):
x: T # exercises meet(T=int, int) = int
y: bool # exercises meet(bool, int) = bool
z: str # exercises meet(str, bytes) = <nothing>
w: dict # exercises meet(dict, <nothing>) = <nothing>
a: InitVar[int] # exercises (non-optional, optional) = non-optional

@dataclass
class B:
x: int
y: bool
z: bytes
a: int


a_or_b: Union[A[int], B]
_ = replace(a_or_b, x=42, y=True, a=42)
_ = replace(a_or_b, x=42, y=True) # E: Missing named argument "a" for "replace" of "Union[A[int], B]"
_ = replace(a_or_b, x=42, y=True, z='42', a=42) # E: Argument "z" to "replace" of "Union[A[int], B]" has incompatible type "str"; expected <nothing>
_ = replace(a_or_b, x=42, y=True, w={}, a=42) # E: Argument "w" to "replace" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>

[builtins fixtures/dataclasses.pyi]

[case testReplaceTypeVar]
[case testReplaceUnionOfTypeVar]
# flags: --strict-optional
from typing import Generic, Union, TypeVar
from dataclasses import dataclass, replace
from typing import TypeVar

@dataclass
class A:
x: int
y: int
z: str
w: dict

class B:
pass

TA = TypeVar('TA', bound=A)
TB = TypeVar('TB', bound=B)

def f(b_or_t: Union[TA, TB, int]) -> None:
a2 = replace(b_or_t) # E: Argument 1 to "replace" has type "Union[TA, TB, int]" whose item "TB" is not bound to a dataclass # E: Argument 1 to "replace" has incompatible type "Union[TA, TB, int]" whose item "int" is not a dataclass

[case testReplaceTypeVarBoundNotDataclass]
from dataclasses import dataclass, replace
from typing import Union, TypeVar

TInt = TypeVar('TInt', bound=int)
TAny = TypeVar('TAny')
TNone = TypeVar('TNone', bound=None)
TUnion = TypeVar('TUnion', bound=Union[str, int])

def f1(t: TInt) -> None:
_ = replace(t, x=42) # E: Argument 1 to "replace" has a variable type "TInt" not bound to a dataclass

def f(t: TA) -> TA:
_ = replace(t, x='spam') # E: Argument "x" to "replace" of "TA" has incompatible type "str"; expected "int"
return replace(t, x=42)
def f2(t: TAny) -> TAny:
return replace(t, x='spam') # E: Argument 1 to "replace" has a variable type "TAny" not bound to a dataclass

def f3(t: TNone) -> TNone:
return replace(t, x='spam') # E: Argument 1 to "replace" has a variable type "TNone" not bound to a dataclass

def g(t: TInt) -> None:
_ = replace(t, x=42) # E: Argument 1 to "replace" has variable type "TInt" not bound to a dataclass
def f4(t: TUnion) -> TUnion:
return replace(t, x='spam') # E: Argument 1 to "replace" has incompatible type "TUnion" whose item "str" is not a dataclass # E: Argument 1 to "replace" has incompatible type "TUnion" whose item "int" is not a dataclass

[case testReplaceTypeVarBound]
from dataclasses import dataclass, replace
from typing import TypeVar

def h(t: TAny) -> TAny:
return replace(t, x='spam') # E: Argument 1 to "replace" has variable type "TAny" not bound to a dataclass
@dataclass
class A:
x: int

@dataclass
class B(A):
pass

def q(t: TNone) -> TNone:
return replace(t, x='spam') # E: Argument 1 to "replace" has variable type "TNone" not bound to a dataclass
TA = TypeVar('TA', bound=A)

[builtins fixtures/dataclasses.pyi]
def f(t: TA) -> TA:
t2 = replace(t, x=42)
reveal_type(t2) # N: Revealed type is "TA`-1"
_ = replace(t, x='42') # E: Argument "x" to "replace" of "TA" has incompatible type "str"; expected "int"
return t2

f(A(x=42))
f(B(x=42))

[case testReplaceAny]
from dataclasses import replace
Expand All @@ -2101,8 +2163,6 @@ a: Any
a2 = replace(a)
reveal_type(a2) # N: Revealed type is "Any"

[builtins fixtures/dataclasses.pyi]

[case testReplaceNotDataclass]
from dataclasses import replace

Expand All @@ -2125,7 +2185,6 @@ T = TypeVar('T')
class A(Generic[T]):
x: T


a = A(x=42)
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
a2 = replace(a, x=42)
Expand Down

0 comments on commit 283fe3d

Please sign in to comment.