Polymorphic inference: support for parameter specifications and lambdas (#15837)

This is a third follow-up for #15287
(likely there will be just one more PR, for `TypeVarTuple`s; I will
leave a few less important items I mentioned in the original PR for the
more distant future).

In the end, this PR turned out to be larger than I wanted. The problem is
that `Concatenate` support for `ParamSpec` was quite broken, and this
caused many of my tests to fail. So I decided to include some major cleanup
in this PR (I tried splitting it into a separate PR but it turned out to
be tricky). Still, if one ignores the added tests, the line count is
almost net zero.

The main problems that I encountered are:
* First, valid substitutions for a `ParamSpecType` were: another
`ParamSpecType`, `Parameters`, and `CallableType` (and also `AnyType`
and `UninhabitedType` but those seem to be handled trivially). Having
`CallableType` in this list caused various missed cases, bogus
`get_proper_type()`s, and was generally counter-intuitive.
* The second (and probably bigger) issue is that it is possible to represent
`Concatenate` in two different forms: as a prefix for `ParamSpecType`
(used mostly for instances), and as separate argument types (used mostly
for callables). The problem is that some parts of the code were
implicitly relying on it being in one or the other form, while some
other code uncontrollably switched between the two.

I propose to fix this by introducing some simplifications and rules
(some of which I enforce by asserts):
* The only valid non-trivial substitutions (and consequently upper/lower
bounds in constraints) for `ParamSpecType` are `ParamSpecType` and
`Parameters`.
* When `ParamSpecType` appears in a callable it must have an empty
`prefix`.
* `Parameters` cannot contain other `Parameters` (and ideally also
`ParamSpecType`s) among argument types.
* For inference we bring `Concatenate` to a common representation (because
both callables and instances may appear in the same expression). Using
the `ParamSpecType` representation with `prefix` looks significantly
simpler (especially in the solver).
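
The relationship between the two representations can be sketched with a toy model. This is only an illustration under simplifying assumptions: `ToyParamSpec`, `ToyCallable`, `to_prefix_form`, and `to_callable_form` are hypothetical names, types are plain strings, and none of this mirrors mypy's actual classes or helpers.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyParamSpec:
    """Toy ParamSpec reference with an optional Concatenate prefix."""

    name: str
    prefix: tuple[str, ...] = ()


@dataclass(frozen=True)
class ToyCallable:
    """Toy callable: leading argument types, then an optional ParamSpec."""

    arg_types: tuple[str, ...]
    param_spec: ToyParamSpec | None = None
    ret_type: str = "None"

    def __post_init__(self) -> None:
        # Rule from the list above: inside a callable the ParamSpec has an
        # empty prefix; the Concatenate arguments live in arg_types instead.
        assert self.param_spec is None or not self.param_spec.prefix


def to_prefix_form(c: ToyCallable) -> ToyParamSpec:
    """Move the leading argument types into the prefix (common form for inference)."""
    assert c.param_spec is not None
    return ToyParamSpec(c.param_spec.name, prefix=c.arg_types)


def to_callable_form(p: ToyParamSpec, ret_type: str) -> ToyCallable:
    """Inverse direction: spread the prefix back out as argument types."""
    return ToyCallable(p.prefix, ToyParamSpec(p.name), ret_type)
```

In this toy model the two conversions are inverses of each other, which is the property that lets code pick one form at a well-defined boundary instead of switching back and forth.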

Apart from this, the actual implementation of polymorphic inference is
straightforward: I just handle the additional `ParamSpecType` cases (in
addition to `TypeVarType`) for inference, for the solver, and for
application. I also enabled polymorphic inference for lambda
expressions, since they are handled by similar code paths.

Some minor comments:
* I fixed a couple of minor bugs uncovered by this PR (see e.g. the test
case for an accidental `TypeVar` id clash).
* I switched a few tests to `--new-type-inference` because the error
messages there are slightly different, and so it is easier for me to test
a global flip to `True` locally.
* I may tweak some of the "ground rules" if the `mypy_primer` output is
particularly bad.

---------

Co-authored-by: Ivan Levkivskyi <ilevkivskyi@hopper.com>
ilevkivskyi and Ivan Levkivskyi committed Aug 15, 2023
1 parent fda7a46 commit 14418bc
Showing 20 changed files with 639 additions and 234 deletions.
11 changes: 7 additions & 4 deletions mypy/applytype.py
@@ -9,7 +9,6 @@
     AnyType,
     CallableType,
     Instance,
-    Parameters,
     ParamSpecType,
     PartialType,
     TupleType,
@@ -112,9 +111,13 @@ def apply_generic_arguments(
     if param_spec is not None:
         nt = id_to_type.get(param_spec.id)
         if nt is not None:
-            nt = get_proper_type(nt)
-            if isinstance(nt, (CallableType, Parameters)):
-                callable = callable.expand_param_spec(nt)
+            # ParamSpec expansion is special-cased, so we need to always expand callable
+            # as a whole, not expanding arguments individually.
+            callable = expand_type(callable, id_to_type)
+            assert isinstance(callable, CallableType)
+            return callable.copy_modified(
+                variables=[tv for tv in tvars if tv.id not in id_to_type]
+            )

     # Apply arguments to argument types.
     var_arg = callable.var_arg()
13 changes: 10 additions & 3 deletions mypy/checker.py
@@ -4280,12 +4280,14 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
                 return_type = self.return_types[-1]
             return_type = get_proper_type(return_type)

+            is_lambda = isinstance(self.scope.top_function(), LambdaExpr)
             if isinstance(return_type, UninhabitedType):
-                self.fail(message_registry.NO_RETURN_EXPECTED, s)
-                return
+                # Avoid extra error messages for failed inference in lambdas
+                if not is_lambda or not return_type.ambiguous:
+                    self.fail(message_registry.NO_RETURN_EXPECTED, s)
+                    return

             if s.expr:
-                is_lambda = isinstance(self.scope.top_function(), LambdaExpr)
                 declared_none_return = isinstance(return_type, NoneType)
                 declared_any_return = isinstance(return_type, AnyType)

@@ -7376,6 +7378,11 @@ def visit_erased_type(self, t: ErasedType) -> bool:
         # This can happen inside a lambda.
         return True

+    def visit_type_var(self, t: TypeVarType) -> bool:
+        # This is needed to prevent leaking into partial types during
+        # multi-step type inference.
+        return t.id.is_meta_var()
+


 class SetNothingToAny(TypeTranslator):
     """Replace all ambiguous <nothing> types with Any (to avoid spurious extra errors)."""
123 changes: 109 additions & 14 deletions mypy/checkexpr.py
@@ -17,7 +17,12 @@
 from mypy.checkstrformat import StringFormatterChecker
 from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
 from mypy.errors import ErrorWatcher, report_internal_error
-from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
+from mypy.expandtype import (
+    expand_type,
+    expand_type_by_instance,
+    freshen_all_functions_type_vars,
+    freshen_function_type_vars,
+)
 from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments
 from mypy.literals import literal
 from mypy.maptype import map_instance_to_supertype
@@ -122,6 +127,7 @@
     false_only,
     fixup_partial_type,
     function_type,
+    get_all_type_vars,
     get_type_vars,
     is_literal_type_like,
     make_simplified_union,
@@ -145,6 +151,7 @@
     LiteralValue,
     NoneType,
     Overloaded,
+    Parameters,
     ParamSpecFlavor,
     ParamSpecType,
     PartialType,
@@ -167,6 +174,7 @@
     get_proper_types,
     has_recursive_types,
     is_named_instance,
+    remove_dups,
     split_with_prefix_and_suffix,
 )
 from mypy.types_utils import (
@@ -1579,6 +1587,16 @@ def check_callable_call(
             lambda i: self.accept(args[i]),
         )

+        # This is tricky: return type may contain its own type variables, like in
+        # def [S] (S) -> def [T] (T) -> tuple[S, T], so we need to update their ids
+        # to avoid possible id clashes if this call itself appears in a generic
+        # function body.
+        ret_type = get_proper_type(callee.ret_type)
+        if isinstance(ret_type, CallableType) and ret_type.variables:
+            fresh_ret_type = freshen_all_functions_type_vars(callee.ret_type)
+            freeze_all_type_vars(fresh_ret_type)
+            callee = callee.copy_modified(ret_type=fresh_ret_type)
+
         if callee.is_generic():
             need_refresh = any(
                 isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
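
The shape `def [S] (S) -> def [T] (T) -> tuple[S, T]` mentioned in the comment above can arise from ordinary user code. The following is a hypothetical sketch (the names `pair_with` and `pair_twice` are not from this PR) of a function whose returned callable has its own type variable, called from inside another generic function that reuses the same `TypeVar` objects, which is the situation where an id clash could plausibly occur.

```python
from __future__ import annotations

from typing import Callable, TypeVar

S = TypeVar("S")
T = TypeVar("T")


def pair_with(first: S) -> Callable[[T], tuple[S, T]]:
    """T appears only in the returned callable, so the result is itself generic."""

    def inner(second: T) -> tuple[S, T]:
        return (first, second)

    return inner


def pair_twice(a: S, b: T) -> tuple[tuple[S, T], tuple[S, S]]:
    # The calls below sit inside a generic function body that uses the same
    # S and T, so the checker must keep the caller's variables and the
    # callee's return-type variables apart.
    return pair_with(a)(b), pair_with(a)(a)
```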
@@ -1597,7 +1615,7 @@ def check_callable_call(
                     lambda i: self.accept(args[i]),
                 )
             callee = self.infer_function_type_arguments(
-                callee, args, arg_kinds, formal_to_actual, context
+                callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context
             )
             if need_refresh:
                 formal_to_actual = map_actuals_to_formals(
@@ -1864,6 +1882,8 @@ def infer_function_type_arguments_using_context(
         # def identity(x: T) -> T: return x
         #
         # expects_literal(identity(3)) # Should type-check
+        # TODO: we may want to add similar exception if all arguments are lambdas, since
+        # in this case external context is almost everything we have.
         if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
             return callable.copy_modified()
         args = infer_type_arguments(callable.variables, ret_type, erased_ctx)
@@ -1885,7 +1905,9 @@ def infer_function_type_arguments(
         callee_type: CallableType,
         args: list[Expression],
         arg_kinds: list[ArgKind],
+        arg_names: Sequence[str | None] | None,
         formal_to_actual: list[list[int]],
+        need_refresh: bool,
         context: Context,
     ) -> CallableType:
         """Infer the type arguments for a generic callee type.
@@ -1927,7 +1949,14 @@ def infer_function_type_arguments(
             if 2 in arg_pass_nums:
                 # Second pass of type inference.
                 (callee_type, inferred_args) = self.infer_function_type_arguments_pass2(
-                    callee_type, args, arg_kinds, formal_to_actual, inferred_args, context
+                    callee_type,
+                    args,
+                    arg_kinds,
+                    arg_names,
+                    formal_to_actual,
+                    inferred_args,
+                    need_refresh,
+                    context,
                 )

             if (
@@ -1953,6 +1982,17 @@ def infer_function_type_arguments(
                 or set(get_type_vars(a)) & set(callee_type.variables)
                 for a in inferred_args
             ):
+                if need_refresh:
+                    # Technically we need to refresh formal_to_actual after *each* inference pass,
+                    # since each pass can expand ParamSpec or TypeVarTuple. Although such situations
+                    # are very rare, not doing this can cause crashes.
+                    formal_to_actual = map_actuals_to_formals(
+                        arg_kinds,
+                        arg_names,
+                        callee_type.arg_kinds,
+                        callee_type.arg_names,
+                        lambda a: self.accept(args[a]),
+                    )
                 # If the regular two-phase inference didn't work, try inferring type
                 # variables while allowing for polymorphic solutions, i.e. for solutions
                 # potentially involving free variables.
@@ -2000,8 +2040,10 @@ def infer_function_type_arguments_pass2(
         callee_type: CallableType,
         args: list[Expression],
         arg_kinds: list[ArgKind],
+        arg_names: Sequence[str | None] | None,
         formal_to_actual: list[list[int]],
         old_inferred_args: Sequence[Type | None],
+        need_refresh: bool,
         context: Context,
     ) -> tuple[CallableType, list[Type | None]]:
         """Perform second pass of generic function type argument inference.
@@ -2023,6 +2065,14 @@ def infer_function_type_arguments_pass2(
             if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg):
                 inferred_args[i] = None
         callee_type = self.apply_generic_arguments(callee_type, inferred_args, context)
+        if need_refresh:
+            formal_to_actual = map_actuals_to_formals(
+                arg_kinds,
+                arg_names,
+                callee_type.arg_kinds,
+                callee_type.arg_names,
+                lambda a: self.accept(args[a]),
+            )

         arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual)

@@ -4735,8 +4785,22 @@ def infer_lambda_type_using_context(
         # they must be considered as indeterminate. We use ErasedType since it
         # does not affect type inference results (it is for purposes like this
         # only).
-        callable_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType()))
-        assert isinstance(callable_ctx, CallableType)
+        if self.chk.options.new_type_inference:
+            # With new type inference we can preserve argument types even if they
+            # are generic, since new inference algorithm can handle constraints
+            # like S <: T (we still erase return type since it's ultimately unknown).
+            extra_vars = []
+            for arg in ctx.arg_types:
+                meta_vars = [tv for tv in get_all_type_vars(arg) if tv.id.is_meta_var()]
+                extra_vars.extend([tv for tv in meta_vars if tv not in extra_vars])
+            callable_ctx = ctx.copy_modified(
+                ret_type=replace_meta_vars(ctx.ret_type, ErasedType()),
+                variables=list(ctx.variables) + extra_vars,
+            )
+        else:
+            erased_ctx = replace_meta_vars(ctx, ErasedType())
+            assert isinstance(erased_ctx, ProperType) and isinstance(erased_ctx, CallableType)
+            callable_ctx = erased_ctx

         # The callable_ctx may have a fallback of builtins.type if the context
         # is a constructor -- but this fallback doesn't make sense for lambdas.
@@ -5693,18 +5757,28 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
         self.bound_tvars: set[TypeVarLikeType] = set()
         self.seen_aliases: set[TypeInfo] = set()

-    def visit_callable_type(self, t: CallableType) -> Type:
-        found_vars = set()
+    def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
+        found_vars = []
         for arg in t.arg_types:
-            found_vars |= set(get_type_vars(arg)) & self.poly_tvars
+            for tv in get_all_type_vars(arg):
+                if isinstance(tv, ParamSpecType):
+                    normalized: TypeVarLikeType = tv.copy_modified(
+                        flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
+                    )
+                else:
+                    normalized = tv
+                if normalized in self.poly_tvars and normalized not in self.bound_tvars:
+                    found_vars.append(normalized)
+        return remove_dups(found_vars)

-        found_vars -= self.bound_tvars
-        self.bound_tvars |= found_vars
+    def visit_callable_type(self, t: CallableType) -> Type:
+        found_vars = self.collect_vars(t)
+        self.bound_tvars |= set(found_vars)
         result = super().visit_callable_type(t)
-        self.bound_tvars -= found_vars
+        self.bound_tvars -= set(found_vars)

         assert isinstance(result, ProperType) and isinstance(result, CallableType)
-        result.variables = list(result.variables) + list(found_vars)
+        result.variables = list(result.variables) + found_vars
         return result

     def visit_type_var(self, t: TypeVarType) -> Type:
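
In the hunk above, `collect_vars` builds a list and passes it through `remove_dups` where the old code used a set. A minimal sketch of order-preserving deduplication is shown below; `dedupe_preserving_order` is a hypothetical name, and the assumption that `remove_dups` behaves this way (and that stable ordering was the motivation) is mine, not something the commit message states.

```python
from typing import Hashable, Iterable, TypeVar

T = TypeVar("T", bound=Hashable)


def dedupe_preserving_order(items: Iterable[T]) -> list[T]:
    """Return the unique items in order of first appearance."""
    seen: set[T] = set()
    result: list[T] = []
    for item in items:
        if item not in seen:
            seen.add(item)
            result.append(item)
    return result
```

Unlike iterating over a plain set, this keeps the resulting variable list in the order in which the variables were first encountered.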
@@ -5713,8 +5787,9 @@ def visit_type_var(self, t: TypeVarType) -> Type:
         return super().visit_type_var(t)

     def visit_param_spec(self, t: ParamSpecType) -> Type:
-        # TODO: Support polymorphic apply for ParamSpec.
-        raise PolyTranslationError()
+        if t in self.poly_tvars and t not in self.bound_tvars:
+            raise PolyTranslationError()
+        return super().visit_param_spec(t)

     def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
         # TODO: Support polymorphic apply for TypeVarTuple.
@@ -5730,6 +5805,26 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
         raise PolyTranslationError()

     def visit_instance(self, t: Instance) -> Type:
+        if t.type.has_param_spec_type:
+            # We need this special-casing to preserve the possibility to store a
+            # generic function in an instance type. Things like
+            # forall T . Foo[[x: T], T]
+            # are not really expressible in current type system, but this looks like
+            # a useful feature, so let's keep it.
+            param_spec_index = next(
+                i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
+            )
+            p = get_proper_type(t.args[param_spec_index])
+            if isinstance(p, Parameters):
+                found_vars = self.collect_vars(p)
+                self.bound_tvars |= set(found_vars)
+                new_args = [a.accept(self) for a in t.args]
+                self.bound_tvars -= set(found_vars)
+
+                repl = new_args[param_spec_index]
+                assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
+                repl.variables = list(repl.variables) + list(found_vars)
+                return t.copy_modified(args=new_args)
         # There is the same problem with callback protocols as with aliases
         # (callback protocols are essentially more flexible aliases to callables).
         # Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].