Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mapped_type_vars to TypeInfo #18274

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion mypy/mro.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Callable

from mypy.nodes import TypeInfo
from mypy.types import Instance
from mypy.types import Instance, ProperType, TypeVarLikeType
from mypy.typestate import type_state


Expand All @@ -15,11 +15,61 @@ def calculate_mro(info: TypeInfo, obj_type: Callable[[], Instance] | None = None
mro = linearize_hierarchy(info, obj_type)
assert mro, f"Could not produce a MRO at all for {info}"
info.mro = mro
fill_mapped_type_vars(info)
# The property of falling back to Any is inherited.
info.fallback_to_any = any(baseinfo.fallback_to_any for baseinfo in info.mro)
type_state.reset_all_subtype_caches_for(info)


def fill_mapped_type_vars(info: TypeInfo) -> None:
"""Calculates the final TypeVar value from inheritor to parent.

class A[T1]:
# mapped_type_vars = {T1: str}

class B[T2]:
# mapped_type_vars = {T2: T4}

class C[T3](B[T3]):
# mapped_type_vars = {T3: T4}

class D[T4](C[T4], A[str]):
# mapped_type_vars = {}
"""
bases = {b.type: b for b in info.bases}

for subinfo in filter(lambda x: x.is_generic, info.mro):
if base_info := bases.get(subinfo):
subinfo.mapped_type_vars = {
tv: actual_type for tv, actual_type in zip(subinfo.defn.type_vars, base_info.args)
}
info.mapped_type_vars |= subinfo.mapped_type_vars

final_mapped_type_vars: dict[TypeVarLikeType, ProperType] = {}
for k, v in info.mapped_type_vars.items():
final_mapped_type_vars[k] = _resolve_mappped_vars(info.mapped_type_vars, v)

for subinfo in filter(lambda x: x.is_generic, info.mro):
_resolve_info_type_vars(subinfo, final_mapped_type_vars)


def _resolve_info_type_vars(
info: TypeInfo, mapped_type_vars: dict[TypeVarLikeType, ProperType]
) -> None:
final_mapped_type_vars = {}
for tv in info.defn.type_vars:
final_mapped_type_vars[tv] = _resolve_mappped_vars(mapped_type_vars, tv)
info.mapped_type_vars = final_mapped_type_vars


def _resolve_mappped_vars(
mapped_type_vars: dict[TypeVarLikeType, ProperType], key: ProperType
) -> ProperType:
if key in mapped_type_vars:
return _resolve_mappped_vars(mapped_type_vars, mapped_type_vars[key])
return key


class MroError(Exception):
"""Raised if a consistent mro cannot be determined for a class."""

Expand Down
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2944,6 +2944,7 @@ class is generic then it will be a type constructor of higher kind.
"fallback_to_any",
"meta_fallback_to_any",
"type_vars",
"mapped_type_vars",
"has_param_spec_type",
"bases",
"_promote",
Expand Down Expand Up @@ -3048,6 +3049,8 @@ class is generic then it will be a type constructor of higher kind.

# Generic type variable names (full names)
type_vars: list[str]
# Map of current class TypeVars and Inheritor specified type to calculate real type in MRO
mapped_type_vars: dict[mypy.types.TypeVarLikeType, mypy.types.ProperType]

# Whether this class has a ParamSpec type variable
has_param_spec_type: bool
Expand Down Expand Up @@ -3139,6 +3142,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
self.defn = defn
self.module_name = module_name
self.type_vars = []
self.mapped_type_vars = {}
self.has_param_spec_type = False
self.has_type_var_tuple_type = False
self.bases = []
Expand Down
Loading