Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some final touches for variadic types support #16334

Merged
merged 14 commits into from
Oct 28, 2023
7 changes: 7 additions & 0 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, Sequence

import mypy.subtypes
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type
from mypy.nodes import Context
from mypy.types import (
Expand Down Expand Up @@ -62,6 +63,11 @@ def get_target_type(
report_incompatible_typevar_value(callable, type, tvar.name, context)
else:
upper_bound = tvar.upper_bound
if tvar.name == "Self":
# Internally constructed Self-types contain class type variables in upper bound,
# so we need to erase them to avoid false positives. This is safe because we do
# not support type variables in upper bounds of user defined types.
upper_bound = erase_typevars(upper_bound)
if not mypy.subtypes.is_subtype(type, upper_bound):
if skip_unsatisfied:
return None
Expand Down Expand Up @@ -121,6 +127,7 @@ def apply_generic_arguments(
# Apply arguments to argument types.
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
# Same as for ParamSpec, callable with variadic types needs to be expanded as a whole.
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])
Expand Down
39 changes: 20 additions & 19 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,6 @@ def expand_typevars(
if defn.info:
# Class type variables
tvars += defn.info.defn.type_vars or []
# TODO(PEP612): audit for paramspec
for tvar in tvars:
if isinstance(tvar, TypeVarType) and tvar.values:
subst.append([(tvar.id, value) for value in tvar.values])
Expand Down Expand Up @@ -2538,6 +2537,9 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
object_type = Instance(info.mro[-1], [])
tvars = info.defn.type_vars
for i, tvar in enumerate(tvars):
if not isinstance(tvar, TypeVarType):
# Variance of TypeVarTuple and ParamSpec is underspecified by PEPs.
continue
up_args: list[Type] = [
object_type if i == j else AnyType(TypeOfAny.special_form)
for j, _ in enumerate(tvars)
Expand All @@ -2554,7 +2556,7 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
expected = CONTRAVARIANT
else:
expected = INVARIANT
if isinstance(tvar, TypeVarType) and expected != tvar.variance:
if expected != tvar.variance:
self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)

def check_multiple_inheritance(self, typ: TypeInfo) -> None:
Expand Down Expand Up @@ -6695,19 +6697,6 @@ def check_possible_missing_await(
return
self.msg.possible_missing_await(context, code)

def contains_none(self, t: Type) -> bool:
t = get_proper_type(t)
return (
isinstance(t, NoneType)
or (isinstance(t, UnionType) and any(self.contains_none(ut) for ut in t.items))
or (isinstance(t, TupleType) and any(self.contains_none(tt) for tt in t.items))
or (
isinstance(t, Instance)
and bool(t.args)
and any(self.contains_none(it) for it in t.args)
)
)

def named_type(self, name: str) -> Instance:
"""Return an instance type with given name and implicit Any type args.

Expand Down Expand Up @@ -7471,10 +7460,22 @@ def builtin_item_type(tp: Type) -> Type | None:
return None
if not isinstance(get_proper_type(tp.args[0]), AnyType):
return tp.args[0]
elif isinstance(tp, TupleType) and all(
not isinstance(it, AnyType) for it in get_proper_types(tp.items)
):
return make_simplified_union(tp.items) # this type is not externally visible
elif isinstance(tp, TupleType):
normalized_items = []
for it in tp.items:
# This use case is probably rare, but not handling unpacks here can cause crashes.
if isinstance(it, UnpackType):
unpacked = get_proper_type(it.type)
if isinstance(unpacked, TypeVarTupleType):
unpacked = get_proper_type(unpacked.upper_bound)
assert (
isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
)
normalized_items.append(unpacked.args[0])
else:
normalized_items.append(it)
if all(not isinstance(it, AnyType) for it in get_proper_types(normalized_items)):
return make_simplified_union(normalized_items) # this type is not externally visible
elif isinstance(tp, TypedDictType):
# TypedDict always has non-optional string keys. Find the key type from the Mapping
# base class.
Expand Down
7 changes: 4 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
result = self.alias_type_in_runtime_context(
node, ctx=e, alias_definition=e.is_alias_rvalue or lvalue
)
elif isinstance(node, (TypeVarExpr, ParamSpecExpr)):
elif isinstance(node, (TypeVarExpr, ParamSpecExpr, TypeVarTupleExpr)):
result = self.object_type()
else:
if isinstance(node, PlaceholderNode):
Expand Down Expand Up @@ -3312,6 +3312,7 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty

def concat_tuples(self, left: TupleType, right: TupleType) -> TupleType:
"""Concatenate two fixed length tuples."""
assert not (find_unpack_in_list(left.items) and find_unpack_in_list(right.items))
return TupleType(
items=left.items + right.items, fallback=self.named_type("builtins.tuple")
)
Expand Down Expand Up @@ -6503,8 +6504,8 @@ def merge_typevars_in_callables_by_name(
for tv in target.variables:
name = tv.fullname
if name not in unique_typevars:
# TODO(PEP612): fix for ParamSpecType
if isinstance(tv, ParamSpecType):
# TODO: support ParamSpecType and TypeVarTuple.
if isinstance(tv, (ParamSpecType, TypeVarTupleType)):
continue
assert isinstance(tv, TypeVarType)
unique_typevars[name] = tv
Expand Down
111 changes: 93 additions & 18 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@
Type,
TypedDictType,
TypeOfAny,
TypeVarTupleType,
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
get_proper_type,
split_with_prefix_and_suffix,
)
from mypy.typevars import fill_typevars
from mypy.visitor import PatternVisitor
Expand Down Expand Up @@ -239,13 +243,29 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
#
# get inner types of original type
#
unpack_index = None
if isinstance(current_type, TupleType):
inner_types = current_type.items
size_diff = len(inner_types) - required_patterns
if size_diff < 0:
return self.early_non_match()
elif size_diff > 0 and star_position is None:
return self.early_non_match()
unpack_index = find_unpack_in_list(inner_types)
if unpack_index is None:
size_diff = len(inner_types) - required_patterns
if size_diff < 0:
return self.early_non_match()
elif size_diff > 0 and star_position is None:
return self.early_non_match()
else:
normalized_inner_types = []
for it in inner_types:
# Unfortunately, it is not possible to "split" the TypeVarTuple
# into individual items, so we just use its upper bound for the whole
# analysis instead.
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
it = UnpackType(it.type.upper_bound)
normalized_inner_types.append(it)
inner_types = normalized_inner_types
current_type = current_type.copy_modified(items=normalized_inner_types)
if len(inner_types) - 1 > required_patterns and star_position is None:
return self.early_non_match()
else:
inner_type = self.get_sequence_type(current_type, o)
if inner_type is None:
Expand All @@ -270,18 +290,18 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
self.update_type_map(captures, type_map)

new_inner_types = self.expand_starred_pattern_types(
contracted_new_inner_types, star_position, len(inner_types)
contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
)
rest_inner_types = self.expand_starred_pattern_types(
contracted_rest_inner_types, star_position, len(inner_types)
contracted_rest_inner_types, star_position, len(inner_types), unpack_index is not None
)

#
# Calculate new type
#
new_type: Type
rest_type: Type = current_type
if isinstance(current_type, TupleType):
if isinstance(current_type, TupleType) and unpack_index is None:
narrowed_inner_types = []
inner_rest_types = []
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
Expand All @@ -301,6 +321,14 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
if all(is_uninhabited(typ) for typ in inner_rest_types):
# All subpatterns always match, so we can apply negative narrowing
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
elif isinstance(current_type, TupleType):
# For variadic tuples it is too tricky to match individual items like for fixed
# tuples, so we instead try to narrow the entire type.
# TODO: use more precise narrowing when possible (e.g. for identical shapes).
new_tuple_type = TupleType(new_inner_types, current_type.partial_fallback)
new_type, rest_type = self.chk.conditional_types_with_intersection(
new_tuple_type, [get_type_range(current_type)], o, default=new_tuple_type
)
else:
new_inner_type = UninhabitedType()
for typ in new_inner_types:
Expand Down Expand Up @@ -345,17 +373,45 @@ def contract_starred_pattern_types(

If star_pos in None the types are returned unchanged.
"""
if star_pos is None:
return types
new_types = types[:star_pos]
star_length = len(types) - num_patterns
new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
new_types += types[star_pos + star_length :]

return new_types
unpack_index = find_unpack_in_list(types)
if unpack_index is not None:
# Variadic tuples require "re-shaping" to match the requested pattern.
unpack = types[unpack_index]
assert isinstance(unpack, UnpackType)
unpacked = get_proper_type(unpack.type)
# This should be guaranteed by the normalization in the caller.
assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
if star_pos is None:
missing = num_patterns - len(types) + 1
new_types = types[:unpack_index]
new_types += [unpacked.args[0]] * missing
new_types += types[unpack_index + 1 :]
return new_types
prefix, middle, suffix = split_with_prefix_and_suffix(
tuple([UnpackType(unpacked) if isinstance(t, UnpackType) else t for t in types]),
star_pos,
num_patterns - star_pos,
)
new_middle = []
for m in middle:
# The existing code expects the star item type, rather than the type of
# the whole tuple "slice".
if isinstance(m, UnpackType):
new_middle.append(unpacked.args[0])
else:
new_middle.append(m)
return list(prefix) + [make_simplified_union(new_middle)] + list(suffix)
else:
if star_pos is None:
return types
new_types = types[:star_pos]
star_length = len(types) - num_patterns
new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
new_types += types[star_pos + star_length :]
return new_types

def expand_starred_pattern_types(
self, types: list[Type], star_pos: int | None, num_types: int
self, types: list[Type], star_pos: int | None, num_types: int, original_unpack: bool
) -> list[Type]:
"""Undoes the contraction done by contract_starred_pattern_types.

Expand All @@ -364,6 +420,17 @@ def expand_starred_pattern_types(
"""
if star_pos is None:
return types
if original_unpack:
# In the case where original tuple type has an unpack item, it is not practical
# to coerce pattern type back to the original shape (and may not even be possible),
# so we only restore the type of the star item.
res = []
for i, t in enumerate(types):
if i != star_pos:
res.append(t)
else:
res.append(UnpackType(self.chk.named_generic_type("builtins.tuple", [t])))
return res
new_types = types[:star_pos]
star_length = num_types - len(types) + 1
new_types += [types[star_pos]] * star_length
Expand Down Expand Up @@ -459,7 +526,15 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
return self.early_non_match()
if isinstance(type_info, TypeInfo):
any_type = AnyType(TypeOfAny.implementation_artifact)
typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
args: list[Type] = []
for tv in type_info.defn.type_vars:
if isinstance(tv, TypeVarTupleType):
args.append(
UnpackType(self.chk.named_generic_type("builtins.tuple", [any_type]))
)
else:
args.append(any_type)
typ: Type = Instance(type_info, args)
elif isinstance(type_info, TypeAlias):
typ = type_info.target
elif (
Expand Down
19 changes: 9 additions & 10 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Instance,
LiteralType,
NoneType,
NormalizedCallableType,
Overloaded,
Parameters,
ParamSpecType,
Expand Down Expand Up @@ -1388,7 +1389,7 @@ def find_matching_overload_items(
return res


def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo | None:
def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo:
"""Get builtins.tuple type from available types to construct homogeneous tuples."""
tp = get_proper_type(unpack.type)
if isinstance(tp, Instance) and tp.type.fullname == "builtins.tuple":
Expand All @@ -1399,10 +1400,10 @@ def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo | None:
for base in tp.partial_fallback.type.mro:
if base.fullname == "builtins.tuple":
return base
return None
assert False, "Invalid unpack type"


def repack_callable_args(callable: CallableType, tuple_type: TypeInfo | None) -> list[Type]:
def repack_callable_args(callable: CallableType, tuple_type: TypeInfo) -> list[Type]:
"""Present callable with star unpack in a normalized form.

Since positional arguments cannot follow star argument, they are packed in a suffix,
Expand All @@ -1417,12 +1418,8 @@ def repack_callable_args(callable: CallableType, tuple_type: TypeInfo | None) ->
star_type = callable.arg_types[star_index]
suffix_types = []
if not isinstance(star_type, UnpackType):
if tuple_type is not None:
# Re-normalize *args: X -> *args: *tuple[X, ...]
star_type = UnpackType(Instance(tuple_type, [star_type]))
else:
# This is unfortunate, something like tuple[Any, ...] would be better.
star_type = UnpackType(AnyType(TypeOfAny.from_error))
# Re-normalize *args: X -> *args: *tuple[X, ...]
star_type = UnpackType(Instance(tuple_type, [star_type]))
else:
tp = get_proper_type(star_type.type)
if isinstance(tp, TupleType):
Expand Down Expand Up @@ -1544,7 +1541,9 @@ def infer_directed_arg_constraints(left: Type, right: Type, direction: int) -> l


def infer_callable_arguments_constraints(
template: CallableType | Parameters, actual: CallableType | Parameters, direction: int
template: NormalizedCallableType | Parameters,
actual: NormalizedCallableType | Parameters,
direction: int,
) -> list[Constraint]:
"""Infer constraints between argument types of two callables.

Expand Down
4 changes: 3 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
raise RuntimeError("Parameters should have been bound to a class")

def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
return AnyType(TypeOfAny.special_form)
# Likely, we can never get here because of aggressive erasure of types that
# can contain this, but better still return a valid replacement.
return t.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])

def visit_unpack_type(self, t: UnpackType) -> ProperType:
return AnyType(TypeOfAny.special_form)
Expand Down
Loading