Skip to content

Commit

Permalink
work in progress trystar refs #1516
Browse files Browse the repository at this point in the history
Visit child nodes of TryStar

Test children

Avoid creating TryExcept under TryStar

block range tests

Proper coverage of block_range in try star node

Add text for name lookup

Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
  • Loading branch information
Pierre-Sassoulas and jacobtylerwalls committed Mar 5, 2023
1 parent 857232e commit da728ca
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 0 deletions.
1 change: 1 addition & 0 deletions astroid/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
Subscript,
TryExcept,
TryFinally,
TryStar,
Tuple,
UnaryOp,
Unknown,
Expand Down
101 changes: 101 additions & 0 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4216,6 +4216,107 @@ def get_children(self):
yield from self.finalbody


class TryStar(_base_nodes.MultiLineWithElseBlockNode, _base_nodes.Statement):
"""Class representing an :class:`ast.TryStar` node."""

_astroid_fields = ("body", "handlers", "orelse", "finalbody")
_multi_line_block_fields = ("body", "handlers", "orelse", "finalbody")

def __init__(
self,
*,
lineno: int | None = None,
col_offset: int | None = None,
end_lineno: int | None = None,
end_col_offset: int | None = None,
parent: NodeNG | None = None,
) -> None:
"""
:param lineno: The line that this node appears on in the source code.
:param col_offset: The column that this node appears on in the
source code.
:param parent: The parent node in the syntax tree.
:param end_lineno: The last line this node appears on in the source code.
:param end_col_offset: The end column this node appears on in the
source code. Note: This is after the last symbol.
"""
self.body: list[NodeNG] = []
"""The contents of the block to catch exceptions from."""

self.handlers: list[ExceptHandler] = []
"""The exception handlers."""

self.orelse: list[NodeNG] = []
"""The contents of the ``else`` block."""

self.finalbody: list[NodeNG] = []
"""The contents of the ``finally`` block."""

super().__init__(
lineno=lineno,
col_offset=col_offset,
end_lineno=end_lineno,
end_col_offset=end_col_offset,
parent=parent,
)

def postinit(
self,
*,
body: list[NodeNG] | None = None,
handlers: list[ExceptHandler] | None = None,
orelse: list[NodeNG] | None = None,
finalbody: list[NodeNG] | None = None,
) -> None:
"""Do some setup after initialisation.
:param body: The contents of the block to catch exceptions from.
:param handlers: The exception handlers.
:param orelse: The contents of the ``else`` block.
:param finalbody: The contents of the ``finally`` block.
"""
if body:
self.body = body
if handlers:
self.handlers = handlers
if orelse:
self.orelse = orelse
if finalbody:
self.finalbody = finalbody

def _infer_name(self, frame, name):
return name

def block_range(self, lineno: int) -> tuple[int, int]:
"""Get a range from a given line number to where this node ends."""
if lineno == self.fromlineno:
return lineno, lineno
if self.body and self.body[0].fromlineno <= lineno <= self.body[-1].tolineno:
# Inside try body - return from lineno till end of try body
return lineno, self.body[-1].tolineno
for exhandler in self.handlers:
if exhandler.type and lineno == exhandler.type.fromlineno:
return lineno, lineno
if exhandler.body[0].fromlineno <= lineno <= exhandler.body[-1].tolineno:
return lineno, exhandler.body[-1].tolineno
if self.orelse:
if self.orelse[0].fromlineno - 1 == lineno:
return lineno, lineno
if self.orelse[0].fromlineno <= lineno <= self.orelse[-1].tolineno:
return lineno, self.orelse[-1].tolineno
if self.finalbody:
if self.finalbody[0].fromlineno - 1 == lineno:
return lineno, lineno
if self.finalbody[0].fromlineno <= lineno <= self.finalbody[-1].tolineno:
return lineno, self.finalbody[-1].tolineno
return lineno, self.tolineno

def get_children(self):
yield from self.body
yield from self.handlers
yield from self.orelse
yield from self.finalbody


class Tuple(BaseContainer):
"""Class representing an :class:`ast.Tuple` node.
Expand Down
16 changes: 16 additions & 0 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1822,6 +1822,22 @@ def visit_try(
return self.visit_tryexcept(node, parent)
return None

def visit_trystar(self, node: ast.TryStar, parent: NodeNG) -> nodes.TryStar:
newnode = nodes.TryStar(
lineno=node.lineno,
col_offset=node.col_offset,
end_lineno=getattr(node, "end_lineno", None),
end_col_offset=getattr(node, "end_col_offset", None),
parent=parent,
)
newnode.postinit(
body=[self.visit(n, newnode) for n in node.body],
handlers=[self.visit(n, newnode) for n in node.handlers],
orelse=[self.visit(n, newnode) for n in node.orelse],
finalbody=[self.visit(n, newnode) for n in node.finalbody],
)
return newnode

def visit_tuple(self, node: ast.Tuple, parent: NodeNG) -> nodes.Tuple:
"""Visit a Tuple node by returning a fresh instance of it."""
context = self._get_context(node)
Expand Down
111 changes: 111 additions & 0 deletions tests/test_group_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
import textwrap

import pytest

from astroid import (
AssignName,
ExceptHandler,
For,
Name,
TryExcept,
Uninferable,
bases,
extract_node,
)
from astroid.const import PY311_PLUS
from astroid.context import InferenceContext
from astroid.nodes import Expr, Raise, TryStar


@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher")
def test_group_exceptions() -> None:
node = extract_node(
textwrap.dedent(
"""
try:
raise ExceptionGroup("group", [ValueError(654)])
except ExceptionGroup as eg:
for err in eg.exceptions:
if isinstance(err, ValueError):
print("Handling ValueError")
elif isinstance(err, TypeError):
print("Handling TypeError")"""
)
)
assert isinstance(node, TryExcept)
handler = node.handlers[0]
exception_group_block_range = (1, 4)
assert node.block_range(lineno=1) == exception_group_block_range
assert node.block_range(lineno=2) == (2, 2)
assert node.block_range(lineno=5) == (5, 9)
assert isinstance(handler, ExceptHandler)
assert handler.type.name == "ExceptionGroup"
children = list(handler.get_children())
assert len(children) == 3
exception_group, short_name, for_loop = children
assert isinstance(exception_group, Name)
assert exception_group.block_range(1) == exception_group_block_range
assert isinstance(short_name, AssignName)
assert isinstance(for_loop, For)


@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher")
def test_star_exceptions() -> None:
node = extract_node(
textwrap.dedent(
"""
try:
raise ExceptionGroup("group", [ValueError(654)])
except* ValueError:
print("Handling ValueError")
except* TypeError:
print("Handling TypeError")
else:
sys.exit(127)
finally:
sys.exit(0)"""
)
)
assert isinstance(node, TryStar)
assert isinstance(node.body[0], Raise)
assert node.block_range(1) == (1, 11)
assert node.block_range(2) == (2, 2)
assert node.block_range(3) == (3, 3)
assert node.block_range(4) == (4, 4)
assert node.block_range(5) == (5, 5)
assert node.block_range(6) == (6, 6)
assert node.block_range(7) == (7, 7)
assert node.block_range(8) == (8, 8)
assert node.block_range(9) == (9, 9)
assert node.block_range(10) == (10, 10)
assert node.block_range(11) == (11, 11)
assert node.handlers
handler = node.handlers[0]
assert isinstance(handler, ExceptHandler)
assert handler.type.name == "ValueError"
orelse = node.orelse[0]
assert isinstance(orelse, Expr)
assert orelse.value.args[0].value == 127
final = node.finalbody[0]
assert isinstance(final, Expr)
assert final.value.args[0].value == 0


@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher")
def test_star_exceptions_infer_name() -> None:
trystar = extract_node(
"""
try:
1/0
except* ValueError:
pass"""
)
name = "arbitraryName"
context = InferenceContext()
context.lookupname = name
stmts = bases._infer_stmts([trystar], context)
assert list(stmts) == [Uninferable]
assert context.lookupname == name

0 comments on commit da728ca

Please sign in to comment.