Skip to content

Commit

Permalink
Add TypeAlias and TypeVar nodes (Python 3.12)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtylerwalls committed Jun 22, 2023
1 parent d1ef13e commit 38fa8e9
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 4 deletions.
6 changes: 6 additions & 0 deletions astroid/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
TryFinally,
TryStar,
Tuple,
TypeAlias,
TypeVar,
UnaryOp,
Unknown,
While,
Expand Down Expand Up @@ -193,6 +195,8 @@
TryFinally,
TryStar,
Tuple,
TypeAlias,
TypeVar,
UnaryOp,
Unknown,
While,
Expand Down Expand Up @@ -285,6 +289,8 @@
"TryFinally",
"TryStar",
"Tuple",
"TypeAlias",
"TypeVar",
"UnaryOp",
"Unknown",
"unpack_infer",
Expand Down
9 changes: 9 additions & 0 deletions astroid/nodes/as_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def handle_functiondef(self, node, keyword) -> str:
if node.returns:
return_annotation = " -> " + node.returns.as_string()
trailer = return_annotation + ":"
# TODO: handle type_params
def_format = "\n%s%s %s(%s)%s%s\n%s"
return def_format % (
decorate,
Expand Down Expand Up @@ -517,6 +518,14 @@ def visit_tuple(self, node) -> str:
return f"({node.elts[0].accept(self)}, )"
return f"({', '.join(child.accept(self) for child in node.elts)})"

def visit_typealias(self, node: nodes.TypeAlias) -> str:
"""return an astroid.TypeAlias node as string"""
return f"{node.value}{node.type_params or ''}"

def visit_typevar(self, node: nodes.TypeVar) -> str:
"""return an astroid.TypeVar node as string"""
return node.name

def visit_unaryop(self, node) -> str:
"""return an astroid.UnaryOp node as string"""
if node.op == "not":
Expand Down
95 changes: 92 additions & 3 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
ClassVar,
Literal,
Optional,
TypeVar,
Union,
)

Expand Down Expand Up @@ -63,8 +62,8 @@ def _is_const(value) -> bool:
return isinstance(value, tuple(CONST_CLS))


_NodesT = TypeVar("_NodesT", bound=NodeNG)
_BadOpMessageT = TypeVar("_BadOpMessageT", bound=util.BadOperationMessage)
_NodesT = typing.TypeVar("_NodesT", bound=NodeNG)
_BadOpMessageT = typing.TypeVar("_BadOpMessageT", bound=util.BadOperationMessage)

AssignedStmtsPossibleNode = Union["List", "Tuple", "AssignName", "AssignAttr", None]
AssignedStmtsCall = Callable[
Expand Down Expand Up @@ -3311,6 +3310,96 @@ def getitem(self, index, context: InferenceContext | None = None):
return _container_getitem(self, self.elts, index, context=context)


class TypeAlias(_base_nodes.AssignTypeNode):
"""Class representing a :class:`ast.TypeAlias` node.
>>> import astroid
>>> node = astroid.extract_node('type Point = tuple[float, float]')
>>> node
<TypeAlias l.1 at 0x7f23b2e4e198>
"""

_astroid_fields = ("type_params", "value")

def __init__(
self,
lineno: int | None = None,
col_offset: int | None = None,
parent: NodeNG | None = None,
*,
end_lineno: int | None = None,
end_col_offset: int | None = None,
) -> None:
self.type_params: list[typing.TypeVar]
self.value: NodeNG
super().__init__(
lineno=lineno,
col_offset=col_offset,
end_lineno=end_lineno,
end_col_offset=end_col_offset,
parent=parent,
)

def postinit(
self,
*,
type_params: list[typing.TypeVar],
value: NodeNG,
) -> None:
self.type_params = type_params
self.value = value

assigned_stmts: ClassVar[
Callable[
[
TypeAlias,
AssignName,
InferenceContext | None,
None,
],
Generator[NodeNG, None, None],
]
]
"""Returns the assigned statement (non inferred) according to the assignment type.
See astroid/protocols.py for actual implementation.
"""


class TypeVar(_base_nodes.AssignTypeNode):
"""Class representing a :class:`ast.TypeVar` node.
>>> import astroid
>>> node = astroid.extract_node('type Point[T] = tuple[float, float]')
>>> node.type_params[0]
<TypeVar l.1 at 0x7f23b2e4e198>
"""

_astroid_fields = ("bound",)

def __init__(
self,
lineno: int | None = None,
col_offset: int | None = None,
parent: NodeNG | None = None,
*,
end_lineno: int | None = None,
end_col_offset: int | None = None,
) -> None:
self.name: str
self.bound: NodeNG | None
super().__init__(
lineno=lineno,
col_offset=col_offset,
end_lineno=end_lineno,
end_col_offset=end_col_offset,
parent=parent,
)

def postinit(self, *, name: str, bound: NodeNG | None) -> None:
self.name = name
self.bound = bound


class UnaryOp(NodeNG):
"""Class representing an :class:`ast.UnaryOp` node.
Expand Down
18 changes: 17 additions & 1 deletion astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,10 +1055,19 @@ class FunctionDef(
<FunctionDef.my_func l.2 at 0x7f23b2e71e10>
"""

_astroid_fields = ("decorators", "args", "returns", "doc_node", "body")
_astroid_fields = (
"decorators",
"args",
"returns",
"type_params",
"doc_node",
"body",
)
_multi_line_block_fields = ("body",)
returns = None

type_params = None

decorators: node_classes.Decorators | None
"""The decorators that are applied to this method or function."""

Expand Down Expand Up @@ -1123,6 +1132,9 @@ def __init__(
self.body: list[NodeNG] = []
"""The contents of the function body."""

self.type_params: list[nodes.TypeVar] = []
"""PEP 695 (Python 3.12+) type params, e.g. first 'T' in def func[T]() -> T: ..."""

self.instance_attrs: dict[str, list[NodeNG]] = {}

super().__init__(
Expand All @@ -1142,6 +1154,7 @@ def postinit(
body: list[NodeNG],
decorators: node_classes.Decorators | None = None,
returns=None,
type_params: list[nodes.TypeVar] | None = None,
type_comment_returns=None,
type_comment_args=None,
*,
Expand All @@ -1164,6 +1177,8 @@ def postinit(
Position of function keyword(s) and name.
:param doc_node:
The doc node associated with this node.
:param type_params:
The type_params associated with this node.
"""
self.args = args
self.body = body
Expand All @@ -1173,6 +1188,7 @@ def postinit(
self.type_comment_args = type_comment_args
self.position = position
self.doc_node = doc_node
self.type_params = type_params

@cached_property
def extra_decorators(self) -> list[node_classes.Call]:
Expand Down
38 changes: 38 additions & 0 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,16 @@ def visit(self, node: ast.TryStar, parent: NodeNG) -> nodes.TryStar:
def visit(self, node: ast.Tuple, parent: NodeNG) -> nodes.Tuple:
...

if sys.version_info >= (3, 12):

@overload
def visit(self, node: ast.TypeAlias, parent: NodeNG) -> nodes.TypeAlias:
...

@overload
def visit(self, node: ast.TypeVar, parent: NodeNG) -> nodes.TypeVar:
...

@overload
def visit(self, node: ast.UnaryOp, parent: NodeNG) -> nodes.UnaryOp:
...
Expand Down Expand Up @@ -1170,6 +1180,7 @@ def _visit_functiondef(
type_comment_args=type_comment_args,
position=self._get_position_info(node, newnode),
doc_node=self.visit(doc_ast_node, newnode),
type_params=[self.visit(param, newnode) for param in node.type_params],
)
self._global_names.pop()
return newnode
Expand Down Expand Up @@ -1669,6 +1680,33 @@ def visit_tuple(self, node: ast.Tuple, parent: NodeNG) -> nodes.Tuple:
newnode.postinit([self.visit(child, newnode) for child in node.elts])
return newnode

def visit_typealias(self, node: ast.TypeAlias, parent: NodeNG) -> nodes.TypeAlias:
"""Visit a TypeAlias node by returning a fresh instance of it."""
newnode = nodes.TypeAlias(
lineno=node.lineno,
col_offset=node.col_offset,
end_lineno=node.end_lineno,
end_col_offset=node.end_col_offset,
parent=parent,
)
newnode.postinit(
type_params=[self.visit(p, newnode) for p in node.type_params],
value=self.visit(node.value, newnode),
)
return newnode

def visit_typevar(self, node: ast.TypeVar, parent: NodeNG) -> nodes.TypeVar:
"""Visit a TypeVar node by returning a fresh instance of it."""
newnode = nodes.TypeVar(
lineno=node.lineno,
col_offset=node.col_offset,
end_lineno=node.end_lineno,
end_col_offset=node.end_col_offset,
parent=parent,
)
newnode.postinit(name=node.name, bound=self.visit(node.bound, newnode))
return newnode

def visit_unaryop(self, node: ast.UnaryOp, parent: NodeNG) -> nodes.UnaryOp:
"""Visit a UnaryOp node by returning a fresh instance of it."""
newnode = nodes.UnaryOp(
Expand Down
2 changes: 2 additions & 0 deletions doc/api/astroid.nodes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ Nodes
astroid.nodes.TryFinally
astroid.nodes.TryStar
astroid.nodes.Tuple
astroid.nodes.TypeAlias
astroid.nodes.TypeVar
astroid.nodes.UnaryOp
astroid.nodes.Unknown
astroid.nodes.While
Expand Down
31 changes: 31 additions & 0 deletions tests/test_type_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

import pytest

from astroid import extract_node
from astroid.const import PY312_PLUS
from astroid.nodes import Subscript, TypeAlias, TypeVar


@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
def test_type_alias() -> None:
node = extract_node("type Point[T] = list[float, float]")
assert isinstance(node, TypeAlias)
assert isinstance(node.type_params[0], TypeVar)
assert node.type_params[0].name == "T"
assert node.type_params[0].bound is None

assert isinstance(node.value, Subscript)
assert node.value.value.name == "list"
assert node.value.slice.name == "tuple"
assert all(elt.name == "float" for elt in node.value.slice.elts)


@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
def test_type_param() -> None:
node = extract_node("def func[T]() -> T: ...")
assert isinstance(node.type_params[0], TypeVar)
assert node.type_params[0].name == "T"
assert node.type_params[0].bound is None

0 comments on commit 38fa8e9

Please sign in to comment.