Skip to content

Commit

Permalink
Add some minimal support for PEP 612 in type stubs.
Browse files Browse the repository at this point in the history
pytype will gracefully fall back to Any when typing.ParamSpec or
typing.Concatenate is used as the first argument to a Callable in a type stub,
allowing typeshed to start using most of PEP 612. Custom generic classes
parameterized with a ParamSpec are still not supported.

pytype was reporting weird [bad-unpacking] errors in parser.py, so I kept
adding ast3.AST annotations until they went away.

I also fixed a bug in PrintVisitor that caused it to print typing_extensions
imports as aliases rather than imports.

For #786.

PiperOrigin-RevId: 366342708
  • Loading branch information
rchen152 committed Apr 2, 2021
1 parent 22150dd commit 0f1033e
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 33 deletions.
11 changes: 4 additions & 7 deletions pytype/load_pytd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,12 @@ def f() -> List[int]: ...
loader = load_pytd.Loader(None, self.python_version, pythonpath=[d.path])
foo = loader.import_name("foo")
bar = loader.import_name("bar")
self.assertEqual(pytd_utils.Print(foo), "foo.List = list")
self.assertEqual(pytd_utils.Print(foo),
"from builtins import list as List")
self.assertEqual(pytd_utils.Print(bar), textwrap.dedent("""
from typing import List
bar.List = list
from builtins import list as List
def bar.f() -> List[int]: ...
""").strip())
Expand Down Expand Up @@ -569,11 +570,7 @@ def test_star_import(self):
self._pickle_modules(loader, d, foo, bar)
loaded_ast = self._load_pickled_module(d, bar)
loaded_ast.Visit(visitors.VerifyLookup())
self.assertMultiLineEqual(pytd_utils.Print(loaded_ast),
textwrap.dedent("""
import foo
bar.A = foo.A""").lstrip())
self.assertEqual(pytd_utils.Print(loaded_ast), "from foo import A")

def test_function_alias(self):
with file_utils.Tempdir() as d:
Expand Down
24 changes: 23 additions & 1 deletion pytype/pyi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def __init__(self, module_info):
self.constants = []
self.aliases = collections.OrderedDict()
self.type_params = []
self.param_specs = []
self.generated_classes = collections.defaultdict(list)
self.module_path_map = {}

Expand Down Expand Up @@ -342,6 +343,16 @@ def add_type_var(self, name, typevar):
self.type_params.append(pytd.TypeParameter(
name=name, constraints=constraints, bound=bound))

def add_param_spec(self, name, paramspec):
if name != paramspec.name:
raise ParseError("ParamSpec name needs to be %r (not %r)" % (
paramspec.name, name))
# ParamSpec should probably be represented with its own pytd class, like
# TypeVar. This is just a quick, hacky way for us to keep track of which
# names refer to ParamSpecs so we can replace them with Any in
# _parameterized_type().
self.param_specs.append(pytd.NamedType(name))

def add_import(self, from_package, import_list):
"""Add an import.
Expand Down Expand Up @@ -419,7 +430,18 @@ def _parameterized_type(self, base_type, parameters):
if self._is_tuple_base_type(base_type):
return pytdgen.heterogeneous_tuple(base_type, parameters)
elif self._is_callable_base_type(base_type):
return pytdgen.pytd_callable(base_type, parameters)
callable_parameters = []
for p in parameters:
# We do not yet support PEP 612, Parameter Specification Variables.
# To avoid blocking typeshed from adopting this PEP, we convert new
# features to Any.
if p in self.param_specs or (
isinstance(p, pytd.GenericType) and self._matches_full_name(
p, ("typing.Concatenate", "typing_extensions.Concatenate"))):
callable_parameters.append(pytd.AnythingType())
else:
callable_parameters.append(p)
return pytdgen.pytd_callable(base_type, tuple(callable_parameters))
else:
assert parameters
return pytd.GenericType(base_type=base_type, parameters=parameters)
Expand Down
64 changes: 47 additions & 17 deletions pytype/pyi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@
ParseError = types.ParseError


_TYPEVAR_IDS = ("TypeVar", "typing.TypeVar")
_PARAMSPEC_IDS = (
"ParamSpec", "typing.ParamSpec", "typing_extensions.ParamSpec")
_TYPING_NAMEDTUPLE_IDS = ("NamedTuple", "typing.NamedTuple")
_COLL_NAMEDTUPLE_IDS = ("namedtuple", "collections.namedtuple")
_TYPEDDICT_IDS = (
"TypedDict", "typing.TypedDict", "typing_extensions.TypedDict")
_NEWTYPE_IDS = ("NewType", "typing.NewType")


#------------------------------------------------------
# imports

Expand Down Expand Up @@ -86,6 +96,18 @@ def from_call(cls, node: ast3.AST) -> "_TypeVar":
return cls(name, bound, constraints)


@dataclasses.dataclass
class _ParamSpec:
"""Internal representation of ParamSpecs."""

name: str

@classmethod
def from_call(cls, node: ast3.AST) -> "_ParamSpec":
name, = node.args
return cls(name)


#------------------------------------------------------
# pytd utils

Expand Down Expand Up @@ -340,10 +362,13 @@ def visit_Assign(self, node):
target = targets[0]
name = target.id

# Record and erase typevar definitions.
# Record and erase TypeVar and ParamSpec definitions.
if isinstance(node.value, _TypeVar):
self.defs.add_type_var(name, node.value)
return Splice([])
elif isinstance(node.value, _ParamSpec):
self.defs.add_param_spec(name, node.value)
return Splice([])

if node.type_comment:
# TODO(mdemello): can pyi files have aliases with typecomments?
Expand Down Expand Up @@ -412,15 +437,15 @@ def visit_ImportFrom(self, node):
self.defs.add_import(module, imports)
return Splice([])

def _convert_newtype_args(self, node):
def _convert_newtype_args(self, node: ast3.AST):
if len(node.args) != 2:
msg = "Wrong args: expected NewType(name, [(field, type), ...])"
raise ParseError(msg)
name, typ = node.args
typ = self.convert_node(typ)
node.args = [name.s, typ]

def _convert_typing_namedtuple_args(self, node):
def _convert_typing_namedtuple_args(self, node: ast3.AST):
# TODO(mdemello): handle NamedTuple("X", a=int, b=str, ...)
if len(node.args) != 2:
msg = "Wrong args: expected NamedTuple(name, [(field, type), ...])"
Expand All @@ -430,7 +455,7 @@ def _convert_typing_namedtuple_args(self, node):
fields = [(types.string_value(n), t) for (n, t) in fields]
node.args = [name.s, fields]

def _convert_collections_namedtuple_args(self, node):
def _convert_collections_namedtuple_args(self, node: ast3.AST):
if len(node.args) != 2:
msg = "Wrong args: expected namedtuple(name, [field, ...])"
raise ParseError(msg)
Expand All @@ -454,7 +479,11 @@ def _convert_typevar_args(self, node):
val = types.string_value(kw.value, context="TypeVar bound")
kw.value = self.annotation_visitor.convert_late_annotation(val)

def _convert_typed_dict_args(self, node):
def _convert_paramspec_args(self, node):
name, = node.args
node.args = [name.s]

def _convert_typed_dict_args(self, node: ast3.AST):
# TODO(b/157603915): new_typed_dict currently doesn't do anything with the
# args, so we don't bother converting them fully.
msg = "Wrong args: expected TypedDict(name, {field: type, ...})"
Expand All @@ -473,30 +502,31 @@ def enter_Call(self, node):
# passing them to internal functions directly in visit_Call.
if isinstance(node.func, ast3.Attribute):
node.func = _attribute_to_name(node.func)
if node.func.id in ("TypeVar", "typing.TypeVar"):
if node.func.id in _TYPEVAR_IDS:
self._convert_typevar_args(node)
elif node.func.id in ("NamedTuple", "typing.NamedTuple"):
elif node.func.id in _PARAMSPEC_IDS:
self._convert_paramspec_args(node)
elif node.func.id in _TYPING_NAMEDTUPLE_IDS:
self._convert_typing_namedtuple_args(node)
elif node.func.id in ("namedtuple", "collections.namedtuple"):
elif node.func.id in _COLL_NAMEDTUPLE_IDS:
self._convert_collections_namedtuple_args(node)
elif node.func.id in ("TypedDict", "typing.TypedDict",
"typing_extensions.TypedDict"):
elif node.func.id in _TYPEDDICT_IDS:
self._convert_typed_dict_args(node)
elif node.func.id in ("NewType", "typing.NewType"):
elif node.func.id in _NEWTYPE_IDS:
return self._convert_newtype_args(node)

def visit_Call(self, node):
if node.func.id in ("TypeVar", "typing.TypeVar"):
if node.func.id in _TYPEVAR_IDS:
if self.level > 0:
raise ParseError("TypeVars need to be defined at module level")
return _TypeVar.from_call(node)
elif node.func.id in ("NamedTuple", "typing.NamedTuple",
"namedtuple", "collections.namedtuple"):
elif node.func.id in _PARAMSPEC_IDS:
return _ParamSpec.from_call(node)
elif node.func.id in _TYPING_NAMEDTUPLE_IDS + _COLL_NAMEDTUPLE_IDS:
return self.defs.new_named_tuple(*node.args)
elif node.func.id in ("TypedDict", "typing.TypedDict",
"typing_extensions.TypedDict"):
elif node.func.id in _TYPEDDICT_IDS:
return self.defs.new_typed_dict(*node.args, total=False)
elif node.func.id in ("NewType", "typing.NewType"):
elif node.func.id in _NEWTYPE_IDS:
return self.defs.new_new_type(*node.args)
# Convert all other calls to NamedTypes; for example:
# * typing.pyi uses things like
Expand Down
124 changes: 124 additions & 0 deletions pytype/pyi/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,5 +2715,129 @@ def test_feature_version(self):
self.assertEqual(actual, expected)


class ParamSpecTest(_ParserTestBase):

def test_from_typing(self):
self.check("""
from typing import Awaitable, Callable, ParamSpec, TypeVar
P = ParamSpec('P')
R = TypeVar('R')
def f(x: Callable[P, R]) -> Callable[P, Awaitable[R]]: ...
""", """
from typing import Awaitable, Callable, TypeVar
R = TypeVar('R')
def f(x: Callable[..., R]) -> Callable[..., Awaitable[R]]: ...
""")

def test_from_typing_extensions(self):
self.check("""
from typing import Awaitable, Callable, TypeVar
from typing_extensions import ParamSpec
P = ParamSpec('P')
R = TypeVar('R')
def f(x: Callable[P, R]) -> Callable[P, Awaitable[R]]: ...
""", """
from typing import Awaitable, Callable, TypeVar
from typing_extensions import ParamSpec
R = TypeVar('R')
def f(x: Callable[..., R]) -> Callable[..., Awaitable[R]]: ...
""")

@test_base.skip("ParamSpec in custom generic classes not supported yet")
def test_custom_generic(self):
self.check("""
from typing import Callable, Generic, ParamSpec, TypeVar
P = ParamSpec('P')
T = TypeVar('T')
class X(Generic[T, P]):
f: Callable[P, int]
x: T
""")

@test_base.skip("ParamSpec in custom generic classes not supported yet")
def test_double_brackets(self):
# Double brackets can be omitted when instantiating a class parameterized
# with only a single ParamSpec.
self.check("""
from typing import Generic, ParamSpec
P = ParamSpec('P')
class X(Generic[P]): ...
def f1(x: X[int, str]) -> None: ...
def f2(x: X[[int, str]]) -> None: ...
""", """
from typing import Generic, ParamSpec
P = ParamSpec('P')
class X(Generic[P]): ...
def f1(x: X[int, str]) -> None: ...
def f2(x: X[int, str]) -> None: ...
""")


class ConcatenateTest(_ParserTestBase):

def test_from_typing(self):
self.check("""
from typing import Callable, Concatenate, ParamSpec, TypeVar
P = ParamSpec('P')
R = TypeVar('R')
class X: ...
def f(x: Callable[Concatenate[X, P], R]) -> Callable[P, R]: ...
""", """
from typing import Callable, TypeVar
R = TypeVar('R')
class X: ...
def f(x: Callable[..., R]) -> Callable[..., R]: ...
""")

def test_from_typing_extensions(self):
self.check("""
from typing import Callable, TypeVar
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec('P')
R = TypeVar('R')
class X: ...
def f(x: Callable[Concatenate[X, P], R]) -> Callable[P, R]: ...
""", """
from typing import Callable, TypeVar
from typing_extensions import Concatenate
from typing_extensions import ParamSpec
R = TypeVar('R')
class X: ...
def f(x: Callable[..., R]) -> Callable[..., R]: ...
""")


if __name__ == "__main__":
unittest.main()
9 changes: 6 additions & 3 deletions pytype/pytd/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,16 @@ def EnterAlias(self, _):

def VisitAlias(self, node):
"""Convert an import or alias to a string."""
if isinstance(self.old_node.type, pytd.NamedType):
if isinstance(self.old_node.type, (pytd.NamedType, pytd.ClassType)):
full_name = self.old_node.type.name
suffix = ""
module, _, name = full_name.rpartition(".")
if module:
if name not in ("*", self.old_node.name):
suffix += " as " + self.old_node.name
alias_name = self.old_node.name
if alias_name.startswith(f"{self._unit_name}."):
alias_name = alias_name[len(self._unit_name)+1:]
if name not in ("*", alias_name):
suffix += " as " + alias_name
self.imports = self.old_imports # undo unnecessary imports change
return "from " + module + " import " + name + suffix
elif isinstance(self.old_node.type, (pytd.Constant, pytd.Function)):
Expand Down
6 changes: 1 addition & 5 deletions pytype/pytd/visitors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,7 @@ class A: ...
ast2 = ast2.Visit(visitors.LookupExternalTypes(
{"foo": ast1}, self_name=None))
self.assertEqual(name, ast2.name)
self.assertMultiLineEqual(pytd_utils.Print(ast2), textwrap.dedent("""
import foo
A = foo.A
""").strip())
self.assertEqual(pytd_utils.Print(ast2), "from foo import A")

def test_lookup_two_star_aliases(self):
src1 = "class A: ..."
Expand Down

0 comments on commit 0f1033e

Please sign in to comment.