From fbcff3a5b559acb834bca23215a49ded3386f5bd Mon Sep 17 00:00:00 2001 From: Jacob Walls Date: Wed, 21 Jun 2023 21:33:18 -0400 Subject: [PATCH] Add TypeAlias and TypeVar nodes (Python 3.12) --- astroid/nodes/__init__.py | 6 ++ astroid/nodes/as_string.py | 10 +++ astroid/nodes/node_classes.py | 95 +++++++++++++++++++++- astroid/nodes/scoped_nodes/scoped_nodes.py | 32 +++++++- astroid/rebuilder.py | 45 +++++++++- doc/api/astroid.nodes.rst | 6 ++ tests/test_type_params.py | 36 ++++++++ 7 files changed, 223 insertions(+), 7 deletions(-) create mode 100644 tests/test_type_params.py diff --git a/astroid/nodes/__init__.py b/astroid/nodes/__init__.py index f677ff509b..17c8f32f6b 100644 --- a/astroid/nodes/__init__.py +++ b/astroid/nodes/__init__.py @@ -83,6 +83,8 @@ TryFinally, TryStar, Tuple, + TypeAlias, + TypeVar, UnaryOp, Unknown, While, @@ -193,6 +195,8 @@ TryFinally, TryStar, Tuple, + TypeAlias, + TypeVar, UnaryOp, Unknown, While, @@ -285,6 +289,8 @@ "TryFinally", "TryStar", "Tuple", + "TypeAlias", + "TypeVar", "UnaryOp", "Unknown", "unpack_infer", diff --git a/astroid/nodes/as_string.py b/astroid/nodes/as_string.py index 49ef1b77e3..657913d615 100644 --- a/astroid/nodes/as_string.py +++ b/astroid/nodes/as_string.py @@ -178,6 +178,7 @@ def visit_classdef(self, node) -> str: args += [n.accept(self) for n in node.keywords] args_str = f"({', '.join(args)})" if args else "" docs = self._docs_dedent(node.doc_node) + # TODO: handle type_params return "\n\n{}class {}{}:{}\n{}\n".format( decorate, node.name, args_str, docs, self._stmt_list(node.body) ) @@ -330,6 +331,7 @@ def handle_functiondef(self, node, keyword) -> str: if node.returns: return_annotation = " -> " + node.returns.as_string() trailer = return_annotation + ":" + # TODO: handle type_params def_format = "\n%s%s %s(%s)%s%s\n%s" return def_format % ( decorate, @@ -517,6 +519,14 @@ def visit_tuple(self, node) -> str: return f"({node.elts[0].accept(self)}, )" return f"({', '.join(child.accept(self) for child in node.elts)})" + def visit_typealias(self, node: nodes.TypeAlias) -> str: + """return an astroid.TypeAlias node as string""" + return f"{node.value}{node.type_params or ''}" + + def visit_typevar(self, node: nodes.TypeVar) -> str: + """return an astroid.TypeVar node as string""" + return node.name + def visit_unaryop(self, node) -> str: """return an astroid.UnaryOp node as string""" if node.op == "not": diff --git a/astroid/nodes/node_classes.py b/astroid/nodes/node_classes.py index 5afb36594c..1d7b541055 100644 --- a/astroid/nodes/node_classes.py +++ b/astroid/nodes/node_classes.py @@ -19,7 +19,6 @@ ClassVar, Literal, Optional, - TypeVar, Union, ) @@ -62,8 +61,8 @@ def _is_const(value) -> bool: return isinstance(value, tuple(CONST_CLS)) -_NodesT = TypeVar("_NodesT", bound=NodeNG) -_BadOpMessageT = TypeVar("_BadOpMessageT", bound=util.BadOperationMessage) +_NodesT = typing.TypeVar("_NodesT", bound=NodeNG) +_BadOpMessageT = typing.TypeVar("_BadOpMessageT", bound=util.BadOperationMessage) AssignedStmtsPossibleNode = Union["List", "Tuple", "AssignName", "AssignAttr", None] AssignedStmtsCall = Callable[ @@ -3310,6 +3309,96 @@ def getitem(self, index, context: InferenceContext | None = None): return _container_getitem(self, self.elts, index, context=context) +class TypeAlias(_base_nodes.AssignTypeNode): + """Class representing a :class:`ast.TypeAlias` node. + + >>> import astroid + >>> node = astroid.extract_node('type Point = tuple[float, float]') + >>> node + + """ + + _astroid_fields = ("type_params", "value") + + def __init__( + self, + lineno: int | None = None, + col_offset: int | None = None, + parent: NodeNG | None = None, + *, + end_lineno: int | None = None, + end_col_offset: int | None = None, + ) -> None: + self.type_params: list[TypeVar] + self.value: NodeNG + super().__init__( + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + parent=parent, + ) + + def postinit( + self, + *, + type_params: list[TypeVar], + value: NodeNG, + ) -> None: + self.type_params = type_params + self.value = value + + assigned_stmts: ClassVar[ + Callable[ + [ + TypeAlias, + AssignName, + InferenceContext | None, + None, + ], + Generator[NodeNG, None, None], + ] + ] + """Returns the assigned statement (non inferred) according to the assignment type. + See astroid/protocols.py for actual implementation. + """ + + +class TypeVar(_base_nodes.AssignTypeNode): + """Class representing a :class:`ast.TypeVar` node. + + >>> import astroid + >>> node = astroid.extract_node('type Point[T] = tuple[float, float]') + >>> node.type_params[0] + + """ + + _astroid_fields = ("bound",) + + def __init__( + self, + lineno: int | None = None, + col_offset: int | None = None, + parent: NodeNG | None = None, + *, + end_lineno: int | None = None, + end_col_offset: int | None = None, + ) -> None: + self.name: str + self.bound: NodeNG | None + super().__init__( + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + parent=parent, + ) + + def postinit(self, *, name: str, bound: NodeNG | None) -> None: + self.name = name + self.bound = bound + + class UnaryOp(NodeNG): """Class representing an :class:`ast.UnaryOp` node. diff --git a/astroid/nodes/scoped_nodes/scoped_nodes.py b/astroid/nodes/scoped_nodes/scoped_nodes.py index bfe1462fd3..bd85948921 100644 --- a/astroid/nodes/scoped_nodes/scoped_nodes.py +++ b/astroid/nodes/scoped_nodes/scoped_nodes.py @@ -1055,7 +1055,14 @@ class FunctionDef( """ - _astroid_fields = ("decorators", "args", "returns", "doc_node", "body") + _astroid_fields = ( + "decorators", + "args", + "returns", + "type_params", + "doc_node", + "body", + ) _multi_line_block_fields = ("body",) returns = None @@ -1123,6 +1130,9 @@ def __init__( self.body: list[NodeNG] = [] """The contents of the function body.""" + self.type_params: list[nodes.TypeVar] = [] + """PEP 695 (Python 3.12+) type params, e.g. first 'T' in def func[T]() -> T: ...""" + self.instance_attrs: dict[str, list[NodeNG]] = {} super().__init__( @@ -1147,6 +1157,7 @@ def postinit( *, position: Position | None = None, doc_node: Const | None = None, + type_params: list[nodes.TypeVar] | None = None, ): """Do some setup after initialisation. @@ -1164,6 +1175,8 @@ def postinit( Position of function keyword(s) and name. :param doc_node: The doc node associated with this node. + :param type_params: + The type_params associated with this node. """ self.args = args self.body = body @@ -1173,6 +1186,7 @@ def postinit( self.type_comment_args = type_comment_args self.position = position self.doc_node = doc_node + self.type_params = type_params or [] @cached_property def extra_decorators(self) -> list[node_classes.Call]: @@ -1739,7 +1753,7 @@ def get_wrapping_class(node): return klass -class ClassDef( +class ClassDef( # pylint: disable=too-many-instance-attributes _base_nodes.FilterStmtsBaseNode, LocalsDictNodeNG, _base_nodes.Statement ): """Class representing an :class:`ast.ClassDef` node. @@ -1758,7 +1772,14 @@ def my_meth(self, arg): # by a raw factories # a dictionary of class instances attributes - _astroid_fields = ("decorators", "bases", "keywords", "doc_node", "body") # name + _astroid_fields = ( + "decorators", + "bases", + "keywords", + "doc_node", + "body", + "type_params", + ) # name decorators = None """The decorators that are applied to this class. @@ -1825,6 +1846,9 @@ def __init__( self.is_dataclass: bool = False """Whether this class is a dataclass.""" + self.type_params: list[nodes.TypeVar] = [] + """PEP 695 (Python 3.12+) type params, e.g. class MyClass[T]: ...""" + super().__init__( lineno=lineno, col_offset=col_offset, @@ -1866,6 +1890,7 @@ def postinit( *, position: Position | None = None, doc_node: Const | None = None, + type_params: list[nodes.TypeVar] | None = None, ) -> None: if keywords is not None: self.keywords = keywords @@ -1876,6 +1901,7 @@ def postinit( self._metaclass = metaclass self.position = position self.doc_node = doc_node + self.type_params = type_params or [] def _newstyle_impl(self, context: InferenceContext | None = None): if context is None: diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 64c1c12362..71456e3c18 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -18,7 +18,7 @@ from astroid import nodes from astroid._ast import ParserModule, get_parser_module, parse_function_type_comment -from astroid.const import IS_PYPY, PY38, PY39_PLUS, Context +from astroid.const import IS_PYPY, PY38, PY39_PLUS, PY312_PLUS, Context from astroid.manager import AstroidManager from astroid.nodes import NodeNG from astroid.nodes.utils import Position @@ -432,6 +432,16 @@ def visit(self, node: ast.TryStar, parent: NodeNG) -> nodes.TryStar: def visit(self, node: ast.Tuple, parent: NodeNG) -> nodes.Tuple: ... + if sys.version_info >= (3, 12): + + @overload + def visit(self, node: ast.TypeAlias, parent: NodeNG) -> nodes.TypeAlias: + ... + + @overload + def visit(self, node: ast.TypeVar, parent: NodeNG) -> nodes.TypeVar: + ... + @overload def visit(self, node: ast.UnaryOp, parent: NodeNG) -> nodes.UnaryOp: ... @@ -870,6 +880,9 @@ def visit_classdef( ], position=self._get_position_info(node, newnode), doc_node=self.visit(doc_ast_node, newnode), + type_params=[self.visit(param, newnode) for param in node.type_params] + if PY312_PLUS + else [], ) return newnode @@ -1170,6 +1183,9 @@ def _visit_functiondef( type_comment_args=type_comment_args, position=self._get_position_info(node, newnode), doc_node=self.visit(doc_ast_node, newnode), + type_params=[self.visit(param, newnode) for param in node.type_params] + if PY312_PLUS + else [], ) self._global_names.pop() return newnode @@ -1669,6 +1685,33 @@ def visit_tuple(self, node: ast.Tuple, parent: NodeNG) -> nodes.Tuple: newnode.postinit([self.visit(child, newnode) for child in node.elts]) return newnode + def visit_typealias(self, node: ast.TypeAlias, parent: NodeNG) -> nodes.TypeAlias: + """Visit a TypeAlias node by returning a fresh instance of it.""" + newnode = nodes.TypeAlias( + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + parent=parent, + ) + newnode.postinit( + type_params=[self.visit(p, newnode) for p in node.type_params], + value=self.visit(node.value, newnode), + ) + return newnode + + def visit_typevar(self, node: ast.TypeVar, parent: NodeNG) -> nodes.TypeVar: + """Visit a TypeVar node by returning a fresh instance of it.""" + newnode = nodes.TypeVar( + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=node.end_lineno, + end_col_offset=node.end_col_offset, + parent=parent, + ) + newnode.postinit(name=node.name, bound=self.visit(node.bound, newnode)) + return newnode + def visit_unaryop(self, node: ast.UnaryOp, parent: NodeNG) -> nodes.UnaryOp: """Visit a UnaryOp node by returning a fresh instance of it.""" newnode = nodes.UnaryOp( diff --git a/doc/api/astroid.nodes.rst b/doc/api/astroid.nodes.rst index 7783b45d3d..26b8b15527 100644 --- a/doc/api/astroid.nodes.rst +++ b/doc/api/astroid.nodes.rst @@ -79,6 +79,8 @@ Nodes astroid.nodes.TryFinally astroid.nodes.TryStar astroid.nodes.Tuple + astroid.nodes.TypeAlias + astroid.nodes.TypeVar astroid.nodes.UnaryOp astroid.nodes.Unknown astroid.nodes.While @@ -226,6 +228,10 @@ Nodes .. autoclass:: astroid.nodes.Tuple +.. autoclass:: astroid.nodes.TypeAlias + +.. autoclass:: astroid.nodes.TypeVar + .. autoclass:: astroid.nodes.UnaryOp .. autoclass:: astroid.nodes.Unknown diff --git a/tests/test_type_params.py b/tests/test_type_params.py new file mode 100644 index 0000000000..179c053200 --- /dev/null +++ b/tests/test_type_params.py @@ -0,0 +1,36 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt + +import pytest + +from astroid import extract_node +from astroid.const import PY312_PLUS +from astroid.nodes import Subscript, TypeAlias, TypeVar + + +@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher") +def test_type_alias() -> None: + node = extract_node("type Point[T] = list[float, float]") + assert isinstance(node, TypeAlias) + assert isinstance(node.type_params[0], TypeVar) + assert node.type_params[0].name == "T" + assert node.type_params[0].bound is None + + assert isinstance(node.value, Subscript) + assert node.value.value.name == "list" + assert node.value.slice.name == "tuple" + assert all(elt.name == "float" for elt in node.value.slice.elts) + + +@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher") +def test_type_param() -> None: + func_node = extract_node("def func[T]() -> T: ...") + assert isinstance(func_node.type_params[0], TypeVar) + assert func_node.type_params[0].name == "T" + assert func_node.type_params[0].bound is None + + class_node = extract_node("class MyClass[T]: ...") + assert isinstance(class_node.type_params[0], TypeVar) + assert class_node.type_params[0].name == "T" + assert class_node.type_params[0].bound is None