Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-44796: Unify TypeVar and ParamSpec substitution #31143

Merged
merged 4 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 9 additions & 54 deletions Lib/_collections_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,25 +430,13 @@ def __new__(cls, origin, args):
raise TypeError(
"Callable must be used as Callable[[arg, ...], result].")
t_args, t_result = args
if isinstance(t_args, list):
if isinstance(t_args, (tuple, list)):
args = (*t_args, t_result)
elif not _is_param_expr(t_args):
raise TypeError(f"Expected a list of types, an ellipsis, "
f"ParamSpec, or Concatenate. Got {t_args}")
return super().__new__(cls, origin, args)

@property
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have to keep __parameters__, it's de facto a public API by now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is inherited from GenericAlias now.

def __parameters__(self):
params = []
for arg in self.__args__:
# Looks like a genericalias
if hasattr(arg, "__parameters__") and isinstance(arg.__parameters__, tuple):
params.extend(arg.__parameters__)
else:
if _is_typevarlike(arg):
params.append(arg)
return tuple(dict.fromkeys(params))

def __repr__(self):
if len(self.__args__) == 2 and _is_param_expr(self.__args__[0]):
return super().__repr__()
Expand All @@ -468,57 +456,24 @@ def __getitem__(self, item):
# code is copied from typing's _GenericAlias and the builtin
# types.GenericAlias.

# A special case in PEP 612 where if X = Callable[P, int],
# then X[int, str] == X[[int, str]].
param_len = len(self.__parameters__)
if param_len == 0:
raise TypeError(f'{self} is not a generic class')
if not isinstance(item, tuple):
item = (item,)
if (param_len == 1 and _is_param_expr(self.__parameters__[0])
# A special case in PEP 612 where if X = Callable[P, int],
# then X[int, str] == X[[int, str]].
if (len(self.__parameters__) == 1
and _is_param_expr(self.__parameters__[0])
and item and not _is_param_expr(item[0])):
item = (list(item),)
item_len = len(item)
if item_len != param_len:
raise TypeError(f'Too {"many" if item_len > param_len else "few"}'
f' arguments for {self};'
f' actual {item_len}, expected {param_len}')
subst = dict(zip(self.__parameters__, item))
new_args = []
for arg in self.__args__:
if _is_typevarlike(arg):
if _is_param_expr(arg):
arg = subst[arg]
if not _is_param_expr(arg):
raise TypeError(f"Expected a list of types, an ellipsis, "
f"ParamSpec, or Concatenate. Got {arg}")
else:
arg = subst[arg]
# Looks like a GenericAlias
elif hasattr(arg, '__parameters__') and isinstance(arg.__parameters__, tuple):
subparams = arg.__parameters__
if subparams:
subargs = tuple(subst[x] for x in subparams)
arg = arg[subargs]
if isinstance(arg, tuple):
new_args.extend(arg)
else:
new_args.append(arg)
item = (item,)

new_args = super().__getitem__(item).__args__

# args[0] occurs due to things like Z[[int, str, bool]] from PEP 612
if not isinstance(new_args[0], list):
if not isinstance(new_args[0], (tuple, list)):
t_result = new_args[-1]
t_args = new_args[:-1]
new_args = (t_args, t_result)
return _CallableGenericAlias(Callable, tuple(new_args))


def _is_typevarlike(arg):
obj = type(arg)
# looks like a TypeVar/ParamSpec
return (obj.__module__ == 'typing'
and obj.__name__ in {'ParamSpec', 'TypeVar'})

def _is_param_expr(obj):
"""Checks if obj matches either a list of types, ``...``, ``ParamSpec`` or
``_ConcatenateGenericAlias`` from typing.py
Expand Down
75 changes: 61 additions & 14 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,36 @@ def test_no_bivariant(self):
with self.assertRaises(ValueError):
TypeVar('T', covariant=True, contravariant=True)

def test_var_substitution(self):
T = TypeVar('T')
subst = T.__typing_subst__
self.assertIs(subst(int), int)
self.assertEqual(subst(list[int]), list[int])
self.assertEqual(subst(List[int]), List[int])
self.assertEqual(subst(List), List)
self.assertIs(subst(Any), Any)
self.assertIs(subst(None), type(None))
self.assertIs(subst(T), T)
self.assertEqual(subst(int|str), int|str)
self.assertEqual(subst(Union[int, str]), Union[int, str])

def test_bad_var_substitution(self):
T = TypeVar('T')
P = ParamSpec("P")
bad_args = (
42, ..., [int], (), (int, str), P, Union,
Generic, Generic[T], Protocol, Protocol[T],
Final, Final[int], ClassVar, ClassVar[int],
)
for arg in bad_args:
with self.subTest(arg=arg):
with self.assertRaises(TypeError):
T.__typing_subst__(arg)
with self.assertRaises(TypeError):
List[T][arg]
with self.assertRaises(TypeError):
list[T][arg]


class UnionTests(BaseTestCase):

Expand Down Expand Up @@ -568,8 +598,10 @@ def test_var_substitution(self):
C2 = Callable[[KT, T], VT]
C3 = Callable[..., T]
self.assertEqual(C1[str], Callable[[int, str], str])
self.assertEqual(C1[None], Callable[[int, type(None)], type(None)])
self.assertEqual(C2[int, float, str], Callable[[int, float], str])
self.assertEqual(C3[int], Callable[..., int])
self.assertEqual(C3[NoReturn], Callable[..., NoReturn])

# multi chaining
C4 = C2[int, VT, str]
Expand Down Expand Up @@ -2107,7 +2139,10 @@ def test_all_repr_eq_any(self):
for obj in objs:
self.assertNotEqual(repr(obj), '')
self.assertEqual(obj, obj)
if getattr(obj, '__parameters__', None) and len(obj.__parameters__) == 1:
if (getattr(obj, '__parameters__', None)
and not isinstance(obj, typing.TypeVar)
and isinstance(obj.__parameters__, tuple)
and len(obj.__parameters__) == 1):
self.assertEqual(obj[Any].__args__, (Any,))
if isinstance(obj, type):
for base in obj.__mro__:
Expand Down Expand Up @@ -4981,21 +5016,29 @@ class X(Generic[P, P2]):
self.assertEqual(G1.__args__, ((int, str), (bytes,)))
self.assertEqual(G2.__args__, ((int,), (str, bytes)))

def test_no_paramspec_in__parameters__(self):
# ParamSpec should not be found in __parameters__
# of generics. Usages outside Callable, Concatenate
# and Generic are invalid.
def test_var_substitution(self):
T = TypeVar("T")
P = ParamSpec("P")
self.assertNotIn(P, List[P].__parameters__)
self.assertIn(T, Tuple[T, P].__parameters__)

# Test for consistency with builtin generics.
self.assertNotIn(P, list[P].__parameters__)
self.assertIn(T, tuple[T, P].__parameters__)

self.assertNotIn(P, (list[P] | int).__parameters__)
self.assertIn(T, (tuple[T, P] | int).__parameters__)
subst = P.__typing_subst__
self.assertEqual(subst((int, str)), (int, str))
self.assertEqual(subst([int, str]), (int, str))
self.assertEqual(subst([None]), (type(None),))
self.assertIs(subst(...), ...)
self.assertIs(subst(P), P)
self.assertEqual(subst(Concatenate[int, P]), Concatenate[int, P])

def test_bad_var_substitution(self):
T = TypeVar('T')
P = ParamSpec('P')
bad_args = (42, int, None, T, int|str, Union[int, str])
for arg in bad_args:
with self.subTest(arg=arg):
with self.assertRaises(TypeError):
P.__typing_subst__(arg)
with self.assertRaises(TypeError):
typing.Callable[P, T][arg, str]
with self.assertRaises(TypeError):
collections.abc.Callable[P, T][arg, str]

def test_paramspec_in_nested_generics(self):
# Although ParamSpec should not be found in __parameters__ of most
Expand All @@ -5010,6 +5053,10 @@ def test_paramspec_in_nested_generics(self):
self.assertEqual(G1.__parameters__, (P, T))
self.assertEqual(G2.__parameters__, (P, T))
self.assertEqual(G3.__parameters__, (P, T))
C = Callable[[int, str], float]
self.assertEqual(G1[[int, str], float], List[C])
self.assertEqual(G2[[int, str], float], list[C])
self.assertEqual(G3[[int, str], float], list[C] | int)


class ConcatenateTests(BaseTestCase):
Expand Down
72 changes: 39 additions & 33 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,13 @@ def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=
if (isinstance(arg, _GenericAlias) and
arg.__origin__ in invalid_generic_forms):
raise TypeError(f"{arg} is not valid as type argument")
if arg in (Any, NoReturn, ClassVar, Final):
if arg in (Any, NoReturn):
return arg
if allow_special_forms and arg in (ClassVar, Final):
return arg
if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol):
raise TypeError(f"Plain {arg} is not valid as type argument")
if isinstance(arg, (type, TypeVar, ForwardRef, types.UnionType, ParamSpec)):
if isinstance(arg, (type, TypeVar, ForwardRef, types.UnionType)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this affect putting ParamSpec inside Annotated? Looks like we don't have tests for that yet.

Copy link
Contributor

@GBeauregard GBeauregard Feb 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked into this.

This check is needed in order to allow instances of ParamSpec (i.e. P in P = ParamSpec("P")) to pass typing._type_check because callable(P) is False. This can show up for instance like Callable[Annotated[P, ""], T]. This line's change would regress the behavior to a runtime error.

n.b. this is moot if #31151 gets merged since this line is removed entirely

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Callable[Annotated[P, ""], T] does not conform the PEP 612 specification.

callable ::= Callable "[" parameters_expression, type_expression "]"

parameters_expression ::=
  | "..."
  | "[" [ type_expression ("," type_expression)* ] "]"
  | parameter_specification_variable
  | concatenate "["
                   type_expression ("," type_expression)* ","
                   parameter_specification_variable
                "]"

Copy link
Contributor

@GBeauregard GBeauregard Feb 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced it was the intent of the spec here to disallow Annotated, but you can also reach this code here:

from typing import Callable, TypeVar, ParamSpec, get_type_hints
T = TypeVar("T")
P = ParamSpec("P")
def add_logging(f: Callable["P", T]):
    pass
get_type_hints(add_logging)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I re-added ParamSpec. Now some errors will not be caught at runtime. We will return to this in future.

return arg
if not callable(arg):
raise TypeError(f"{msg} Got {arg!r:.100}.")
Expand Down Expand Up @@ -211,21 +213,22 @@ def _type_repr(obj):
return repr(obj)


def _collect_type_vars(types_, typevar_types=None):
"""Collect all type variable contained
in types in order of first appearance (lexicographic order). For example::
def _collect_parameters(args):
"""Collect all type variables and parameter specifications in args
in order of first appearance (lexicographic order). For example::

_collect_type_vars((T, List[S, T])) == (T, S)
_collect_parameters((T, Callable[P, T])) == (T, P)
"""
if typevar_types is None:
typevar_types = TypeVar
tvars = []
for t in types_:
if isinstance(t, typevar_types) and t not in tvars:
tvars.append(t)
if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)
parameters = []
for t in args:
if hasattr(t, '__typing_subst__'):
if t not in parameters:
parameters.append(t)
else:
for x in getattr(t, '__parameters__', ()):
if x not in parameters:
parameters.append(x)
return tuple(parameters)


def _check_generic(cls, parameters, elen):
Expand Down Expand Up @@ -818,6 +821,11 @@ def __init__(self, name, *constraints, bound=None,
if def_mod != 'typing':
self.__module__ = def_mod

def __typing_subst__(self, arg):
msg = "Parameters to generic types must be types."
arg = _type_check(arg, msg, is_argument=True)
return arg


class ParamSpecArgs(_Final, _Immutable, _root=True):
"""The args for a ParamSpec object.
Expand Down Expand Up @@ -918,6 +926,14 @@ def __init__(self, name, *, bound=None, covariant=False, contravariant=False):
if def_mod != 'typing':
self.__module__ = def_mod

def __typing_subst__(self, arg):
if isinstance(arg, (list, tuple)):
arg = tuple(_type_check(a, "Expected a type.") for a in arg)
elif not _is_param_expr(arg):
raise TypeError(f"Expected a list of types, an ellipsis, "
f"ParamSpec, or Concatenate. Got {arg}")
return arg


def _is_dunder(attr):
return attr.startswith('__') and attr.endswith('__')
Expand Down Expand Up @@ -972,7 +988,7 @@ def __getattr__(self, attr):

def __setattr__(self, attr, val):
if _is_dunder(attr) or attr in {'_name', '_inst', '_nparams',
'_typevar_types', '_paramspec_tvars'}:
'_paramspec_tvars'}:
super().__setattr__(attr, val)
else:
setattr(self.__origin__, attr, val)
Expand Down Expand Up @@ -1001,16 +1017,14 @@ def __dir__(self):

class _GenericAlias(_BaseGenericAlias, _root=True):
def __init__(self, origin, params, *, inst=True, name=None,
_typevar_types=TypeVar,
_paramspec_tvars=False):
super().__init__(origin, inst=inst, name=name)
if not isinstance(params, tuple):
params = (params,)
self.__args__ = tuple(... if a is _TypingEllipsis else
() if a is _TypingEmpty else
a for a in params)
self.__parameters__ = _collect_type_vars(params, typevar_types=_typevar_types)
self._typevar_types = _typevar_types
self.__parameters__ = _collect_parameters(params)
self._paramspec_tvars = _paramspec_tvars
if not name:
self.__module__ = origin.__module__
Expand Down Expand Up @@ -1047,16 +1061,11 @@ def __getitem__(self, params):
subst = dict(zip(self.__parameters__, params))
new_args = []
for arg in self.__args__:
if isinstance(arg, self._typevar_types):
if isinstance(arg, ParamSpec):
arg = subst[arg]
if not _is_param_expr(arg):
raise TypeError(f"Expected a list of types, an ellipsis, "
f"ParamSpec, or Concatenate. Got {arg}")
else:
arg = subst[arg]
elif isinstance(arg, (_GenericAlias, GenericAlias, types.UnionType)):
subparams = arg.__parameters__
substfunc = getattr(arg, '__typing_subst__', None)
if substfunc:
arg = substfunc(subst[arg])
else:
subparams = getattr(arg, '__parameters__', ())
if subparams:
subargs = tuple(subst[x] for x in subparams)
arg = arg[subargs]
Expand Down Expand Up @@ -1172,7 +1181,6 @@ class _CallableType(_SpecialGenericAlias, _root=True):
def copy_with(self, params):
return _CallableGenericAlias(self.__origin__, params,
name=self._name, inst=self._inst,
_typevar_types=(TypeVar, ParamSpec),
_paramspec_tvars=True)

def __getitem__(self, params):
Expand Down Expand Up @@ -1272,7 +1280,6 @@ def __hash__(self):
class _ConcatenateGenericAlias(_GenericAlias, _root=True):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs,
_typevar_types=(TypeVar, ParamSpec),
_paramspec_tvars=True)

def copy_with(self, params):
Expand Down Expand Up @@ -1333,7 +1340,6 @@ def __class_getitem__(cls, params):
else:
_check_generic(cls, params, len(cls.__parameters__))
return _GenericAlias(cls, params,
_typevar_types=(TypeVar, ParamSpec),
_paramspec_tvars=True)

def __init_subclass__(cls, *args, **kwargs):
Expand All @@ -1346,7 +1352,7 @@ def __init_subclass__(cls, *args, **kwargs):
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
tvars = _collect_type_vars(cls.__orig_bases__, (TypeVar, ParamSpec))
tvars = _collect_parameters(cls.__orig_bases__)
# Look for Generic[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
Expand Down
Loading