Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Foundations for non-linear solver and polymorphic application #15287

Merged
merged 19 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 2 additions & 104 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@
Callable,
ClassVar,
Dict,
Iterable,
Iterator,
Mapping,
NamedTuple,
NoReturn,
Sequence,
TextIO,
TypeVar,
)
from typing_extensions import Final, TypeAlias as _TypeAlias

Expand All @@ -47,6 +45,7 @@
import mypy.semanal_main
from mypy.checker import TypeChecker
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.indirection import TypeIndirectionVisitor
from mypy.messages import MessageBuilder
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable, TypeInfo
Expand Down Expand Up @@ -3466,15 +3465,8 @@ def sorted_components(
edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices}
sccs = list(strongly_connected_components(vertices, edges))
# Topsort.
sccsmap = {id: frozenset(scc) for scc in sccs for id in scc}
data: dict[AbstractSet[str], set[AbstractSet[str]]] = {}
for scc in sccs:
deps: set[AbstractSet[str]] = set()
for id in scc:
deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max))
data[frozenset(scc)] = deps
res = []
for ready in topsort(data):
for ready in topsort(prepare_sccs(sccs, edges)):
# Sort the sets in ready by reversed smallest State.order. Examples:
#
# - If ready is [{x}, {y}], x.order == 1, y.order == 2, we get
Expand All @@ -3499,100 +3491,6 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in
]


def strongly_connected_components(
vertices: AbstractSet[str], edges: dict[str, list[str]]
) -> Iterator[set[str]]:
"""Compute Strongly Connected Components of a directed graph.

Args:
vertices: the labels for the vertices
edges: for each vertex, gives the target vertices of its outgoing edges

Returns:
An iterator yielding strongly connected components, each
represented as a set of vertices. Each input vertex will occur
exactly once; vertices not part of a SCC are returned as
singleton sets.

From https://code.activestate.com/recipes/578507/.
"""
identified: set[str] = set()
stack: list[str] = []
index: dict[str, int] = {}
boundaries: list[int] = []

def dfs(v: str) -> Iterator[set[str]]:
index[v] = len(stack)
stack.append(v)
boundaries.append(index[v])

for w in edges[v]:
if w not in index:
yield from dfs(w)
elif w not in identified:
while index[w] < boundaries[-1]:
boundaries.pop()

if boundaries[-1] == index[v]:
boundaries.pop()
scc = set(stack[index[v] :])
del stack[index[v] :]
identified.update(scc)
yield scc

for v in vertices:
if v not in index:
yield from dfs(v)


T = TypeVar("T")


def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]:
"""Topological sort.

Args:
data: A map from vertices to all vertices that it has an edge
connecting it to. NOTE: This data structure
is modified in place -- for normalization purposes,
self-dependencies are removed and entries representing
orphans are added.

Returns:
An iterator yielding sets of vertices that have an equivalent
ordering.

Example:
Suppose the input has the following structure:

{A: {B, C}, B: {D}, C: {D}}

This is normalized to:

{A: {B, C}, B: {D}, C: {D}, D: {}}

The algorithm will yield the following values:

{D}
{B, C}
{A}

From https://code.activestate.com/recipes/577413/.
"""
# TODO: Use a faster algorithm?
for k, v in data.items():
v.discard(k) # Ignore self dependencies.
for item in set.union(*data.values()) - set(data.keys()):
data[item] = set()
while True:
ready = {item for item, dep in data.items() if not dep}
if not ready:
break
yield ready
data = {item: (dep - ready) for item, dep in data.items() if item not in ready}
assert not data, f"A cyclic dependency exists amongst {data!r}"


def missing_stubs_file(cache_dir: str) -> str:
return os.path.join(cache_dir, "missing_stubs")

Expand Down
145 changes: 143 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import mypy.errorcodes as codes
from mypy import applytype, erasetype, join, message_registry, nodes, operators, types
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
from mypy.checkmember import analyze_member_access, type_object_type
from mypy.checkmember import analyze_member_access, freeze_all_type_vars, type_object_type
from mypy.checkstrformat import StringFormatterChecker
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
from mypy.errors import ErrorWatcher, report_internal_error
Expand Down Expand Up @@ -98,8 +98,15 @@
)
from mypy.semanal_enum import ENUM_BASES
from mypy.state import state
from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members
from mypy.subtypes import (
find_member,
is_equivalent,
is_same_type,
is_subtype,
non_method_protocol_members,
)
from mypy.traverser import has_await_expression
from mypy.type_visitor import TypeTranslator
from mypy.typeanal import (
check_for_explicit_any,
has_any_from_unimported_type,
Expand All @@ -114,6 +121,7 @@
false_only,
fixup_partial_type,
function_type,
get_type_vars,
is_literal_type_like,
make_simplified_union,
simple_literal_type,
Expand Down Expand Up @@ -146,6 +154,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
Expand Down Expand Up @@ -300,6 +309,7 @@ def __init__(
# on whether current expression is a callee, to give better error messages
# related to type context.
self.is_callee = False
type_state.infer_polymorphic = self.chk.options.new_type_inference

def reset(self) -> None:
self.resolved_type = {}
Expand Down Expand Up @@ -1791,6 +1801,51 @@ def infer_function_type_arguments(
inferred_args[0] = self.named_type("builtins.str")
elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg):
self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context)

if self.chk.options.new_type_inference and any(
a is None
or isinstance(get_proper_type(a), UninhabitedType)
or set(get_type_vars(a)) & set(callee_type.variables)
for a in inferred_args
):
# If the regular two-phase inference didn't work, try inferring type
# variables while allowing for polymorphic solutions, i.e. for solutions
# potentially involving free variables.
# TODO: support the similar inference for return type context.
poly_inferred_args = infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
allow_polymorphic=True,
)
for i, pa in enumerate(get_proper_types(poly_inferred_args)):
if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa):
# Indicate that free variables should not be applied in the call below.
poly_inferred_args[i] = None
poly_callee_type = self.apply_generic_arguments(
callee_type, poly_inferred_args, context
)
yes_vars = poly_callee_type.variables
no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
if not set(get_type_vars(poly_callee_type)) & no_vars:
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
applied = apply_poly(poly_callee_type, yes_vars)
if applied is not None and poly_inferred_args != [UninhabitedType()] * len(
poly_inferred_args
):
freeze_all_type_vars(applied)
return applied
# If it didn't work, erase free variables as <nothing>, to avoid confusing errors.
inferred_args = [
expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables})
if a is not None
else None
for a in inferred_args
]
else:
# In dynamically typed functions use implicit 'Any' types for
# type variables.
Expand Down Expand Up @@ -5393,6 +5448,92 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_type)


def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]:
"""Make free type variables generic in the type if possible.

This will translate the type `tp` while trying to create valid bindings for
type variables `poly_tvars` while traversing the type. This follows the same rules
as we do during semantic analysis phase, examples:
* Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
* Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
* List[T] -> None (not possible)
"""
try:
return tp.copy_modified(
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
variables=[],
)
except PolyTranslationError:
return None


class PolyTranslationError(Exception):
pass


class PolyTranslator(TypeTranslator):
"""Make free type variables generic in the type if possible.

See docstring for apply_poly() for details.
"""

def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars: set[TypeVarLikeType] = set()
self.seen_aliases: set[TypeInfo] = set()

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = set()
for arg in t.arg_types:
found_vars |= set(get_type_vars(arg)) & self.poly_tvars

found_vars -= self.bound_tvars
self.bound_tvars |= found_vars
result = super().visit_callable_type(t)
self.bound_tvars -= found_vars

assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.variables = list(result.variables) + list(found_vars)
return result

def visit_type_var(self, t: TypeVarType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> Type:
# TODO: Support polymorphic apply for ParamSpec.
raise PolyTranslationError()

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# TODO: Support polymorphic apply for TypeVarTuple.
raise PolyTranslationError()

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
if not t.args:
return t.copy_modified()
if not t.is_recursive:
return get_proper_type(t).accept(self)
# We can't handle polymorphic application for recursive generic aliases
# without risking an infinite recursion, just give up for now.
raise PolyTranslationError()

def visit_instance(self, t: Instance) -> Type:
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
# Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
if t.type in self.seen_aliases:
raise PolyTranslationError()
self.seen_aliases.add(t.type)
call = find_member("__call__", t, t, is_operator=True)
assert call is not None
return call.accept(self)
return super().visit_instance(t)


class ArgInferSecondPassQuery(types.BoolTypeQuery):
"""Query whether an argument type should be inferred in the second pass.

Expand Down
25 changes: 24 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,30 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
# FIX what if one of the functions is generic
# TODO: Erase template variables if it is generic?
if (
type_state.infer_polymorphic
and cactual.variables
and cactual.param_spec() is None
# Technically, the correct inferred type for application of e.g.
# Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic
# like U -> U, should be Callable[..., Any], but if U is a self-type, we can
# allow it to leak, to be later bound to self. A bunch of existing code
# depends on this old behaviour.
and not any(tv.id.raw_id == 0 for tv in cactual.variables)
):
# If actual is generic, unify it with template. Note: this is
# not an ideal solution (which would be adding the generic variables
# to the constraint inference set), but it's a good first approximation,
# and this will prevent leaking these variables in the solutions.
# Note: this may infer constraints like T <: S or T <: List[S]
# that contain variables in the target.
unified = mypy.subtypes.unify_generic_callable(
cactual, template, ignore_return=True
)
if unified is not None:
cactual = unified
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))

# We can't infer constraints from arguments if the template is Callable[..., T]
# (with literal '...').
Expand Down
Loading
Loading