Skip to content

Commit

Permalink
Fix crash on ParamSpec unification (#16251)
Browse files Browse the repository at this point in the history
Fixes #16245
Fixes #16248

Unfortunately I was a bit reckless with parentheses, but in my defense
`unify_generic_callable()` is kind of broken for long time, as it can
return "solutions" like ```{1: T`1}```. We need a more principled
approach there (IIRC there is already an issue about this in the scope
of `--new-type-inference`).

(The fix is quite trivial so I am not going to wait for review too long
to save time, unless there will be some issues in `mypy_primer` etc.)
  • Loading branch information
ilevkivskyi authored Oct 12, 2023
1 parent 2c1009e commit 72605dc
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 14 deletions.
10 changes: 6 additions & 4 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,15 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
return repl.copy_modified(
flavor=t.flavor,
prefix=t.prefix.copy_modified(
arg_types=self.expand_types(t.prefix.arg_types + repl.prefix.arg_types),
arg_types=self.expand_types(t.prefix.arg_types) + repl.prefix.arg_types,
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
),
)
elif isinstance(repl, Parameters):
assert t.flavor == ParamSpecFlavor.BARE
return Parameters(
self.expand_types(t.prefix.arg_types + repl.arg_types),
self.expand_types(t.prefix.arg_types) + repl.arg_types,
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
Expand Down Expand Up @@ -333,12 +333,14 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
# the replacement is ignored.
if isinstance(repl, Parameters):
# We need to expand both the types in the prefix and the ParamSpec itself
t = t.expand_param_spec(repl)
return t.copy_modified(
arg_types=self.expand_types(t.arg_types),
arg_types=self.expand_types(t.arg_types[:-2]) + repl.arg_types,
arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds,
arg_names=t.arg_names[:-2] + repl.arg_names,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
variables=[*repl.variables, *t.variables],
)
elif isinstance(repl, ParamSpecType):
# We're substituting one ParamSpec for another; this can mean that the prefix
Expand Down
10 changes: 0 additions & 10 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2069,16 +2069,6 @@ def param_spec(self) -> ParamSpecType | None:
prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2])
return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix)

def expand_param_spec(self, c: Parameters) -> CallableType:
variables = c.variables
return self.copy_modified(
arg_types=self.arg_types[:-2] + c.arg_types,
arg_kinds=self.arg_kinds[:-2] + c.arg_kinds,
arg_names=self.arg_names[:-2] + c.arg_names,
is_ellipsis_args=c.is_ellipsis_args,
variables=[*variables, *self.variables],
)

def with_unpacked_kwargs(self) -> NormalizedCallableType:
if not self.unpack_kwargs:
return cast(NormalizedCallableType, self)
Expand Down
37 changes: 37 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1976,3 +1976,40 @@ g(cb, y=0, x='a') # OK
g(cb, y='a', x=0) # E: Argument "y" to "g" has incompatible type "str"; expected "int" \
# E: Argument "x" to "g" has incompatible type "int"; expected "str"
[builtins fixtures/paramspec.pyi]

[case testParamSpecNoCrashOnUnificationAlias]
import mod
[file mod.pyi]
from typing import Callable, Protocol, TypeVar, overload
from typing_extensions import ParamSpec

P = ParamSpec("P")
R_co = TypeVar("R_co", covariant=True)
Handler = Callable[P, R_co]

class HandlerDecorator(Protocol):
def __call__(self, handler: Handler[P, R_co]) -> Handler[P, R_co]: ...

@overload
def event(event_handler: Handler[P, R_co]) -> Handler[P, R_co]: ...
@overload
def event(namespace: str, *args, **kwargs) -> HandlerDecorator: ...
[builtins fixtures/paramspec.pyi]

[case testParamSpecNoCrashOnUnificationCallable]
import mod
[file mod.pyi]
from typing import Callable, Protocol, TypeVar, overload
from typing_extensions import ParamSpec

P = ParamSpec("P")
R_co = TypeVar("R_co", covariant=True)

class HandlerDecorator(Protocol):
def __call__(self, handler: Callable[P, R_co]) -> Callable[P, R_co]: ...

@overload
def event(event_handler: Callable[P, R_co]) -> Callable[P, R_co]: ...
@overload
def event(namespace: str, *args, **kwargs) -> HandlerDecorator: ...
[builtins fixtures/paramspec.pyi]

0 comments on commit 72605dc

Please sign in to comment.