diff --git a/astroid/nodes/__init__.py b/astroid/nodes/__init__.py index 68ddad74b0..b527ff7c3f 100644 --- a/astroid/nodes/__init__.py +++ b/astroid/nodes/__init__.py @@ -84,6 +84,7 @@ Subscript, TryExcept, TryFinally, + TryStar, Tuple, UnaryOp, Unknown, diff --git a/astroid/nodes/node_classes.py b/astroid/nodes/node_classes.py index 3cec089189..b7772c3c62 100644 --- a/astroid/nodes/node_classes.py +++ b/astroid/nodes/node_classes.py @@ -4216,6 +4216,107 @@ def get_children(self): yield from self.finalbody +class TryStar(_base_nodes.MultiLineWithElseBlockNode, _base_nodes.Statement): + """Class representing an :class:`ast.TryStar` node.""" + + _astroid_fields = ("body", "handlers", "orelse", "finalbody") + _multi_line_block_fields = ("body", "handlers", "orelse", "finalbody") + + def __init__( + self, + *, + lineno: int | None = None, + col_offset: int | None = None, + end_lineno: int | None = None, + end_col_offset: int | None = None, + parent: NodeNG | None = None, + ) -> None: + """ + :param lineno: The line that this node appears on in the source code. + :param col_offset: The column that this node appears on in the + source code. + :param parent: The parent node in the syntax tree. + :param end_lineno: The last line this node appears on in the source code. + :param end_col_offset: The end column this node appears on in the + source code. Note: This is after the last symbol. + """ + self.body: list[NodeNG] = [] + """The contents of the block to catch exceptions from.""" + + self.handlers: list[ExceptHandler] = [] + """The exception handlers.""" + + self.orelse: list[NodeNG] = [] + """The contents of the ``else`` block.""" + + self.finalbody: list[NodeNG] = [] + """The contents of the ``finally`` block.""" + + super().__init__( + lineno=lineno, + col_offset=col_offset, + end_lineno=end_lineno, + end_col_offset=end_col_offset, + parent=parent, + ) + + def postinit( + self, + *, + body: list[NodeNG] | None = None, + handlers: list[ExceptHandler] | None = None, + orelse: list[NodeNG] | None = None, + finalbody: list[NodeNG] | None = None, + ) -> None: + """Do some setup after initialisation. + :param body: The contents of the block to catch exceptions from. + :param handlers: The exception handlers. + :param orelse: The contents of the ``else`` block. + :param finalbody: The contents of the ``finally`` block. + """ + if body: + self.body = body + if handlers: + self.handlers = handlers + if orelse: + self.orelse = orelse + if finalbody: + self.finalbody = finalbody + + def _infer_name(self, frame, name): + return name + + def block_range(self, lineno: int) -> tuple[int, int]: + """Get a range from a given line number to where this node ends.""" + if lineno == self.fromlineno: + return lineno, lineno + if self.body and self.body[0].fromlineno <= lineno <= self.body[-1].tolineno: + # Inside try body - return from lineno till end of try body + return lineno, self.body[-1].tolineno + for exhandler in self.handlers: + if exhandler.type and lineno == exhandler.type.fromlineno: + return lineno, lineno + if exhandler.body[0].fromlineno <= lineno <= exhandler.body[-1].tolineno: + return lineno, exhandler.body[-1].tolineno + if self.orelse: + if self.orelse[0].fromlineno - 1 == lineno: + return lineno, lineno + if self.orelse[0].fromlineno <= lineno <= self.orelse[-1].tolineno: + return lineno, self.orelse[-1].tolineno + if self.finalbody: + if self.finalbody[0].fromlineno - 1 == lineno: + return lineno, lineno + if self.finalbody[0].fromlineno <= lineno <= self.finalbody[-1].tolineno: + return lineno, self.finalbody[-1].tolineno + return lineno, self.tolineno + + def get_children(self): + yield from self.body + yield from self.handlers + yield from self.orelse + yield from self.finalbody + + class Tuple(BaseContainer): """Class representing an :class:`ast.Tuple` node. diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 0407dbfb74..6e996defdc 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -1822,6 +1822,22 @@ def visit_try( return self.visit_tryexcept(node, parent) return None + def visit_trystar(self, node: ast.TryStar, parent: NodeNG) -> nodes.TryStar: + newnode = nodes.TryStar( + lineno=node.lineno, + col_offset=node.col_offset, + end_lineno=getattr(node, "end_lineno", None), + end_col_offset=getattr(node, "end_col_offset", None), + parent=parent, + ) + newnode.postinit( + body=[self.visit(n, newnode) for n in node.body], + handlers=[self.visit(n, newnode) for n in node.handlers], + orelse=[self.visit(n, newnode) for n in node.orelse], + finalbody=[self.visit(n, newnode) for n in node.finalbody], + ) + return newnode + def visit_tuple(self, node: ast.Tuple, parent: NodeNG) -> nodes.Tuple: """Visit a Tuple node by returning a fresh instance of it.""" context = self._get_context(node) diff --git a/tests/test_group_exceptions.py b/tests/test_group_exceptions.py new file mode 100644 index 0000000000..173c25ed00 --- /dev/null +++ b/tests/test_group_exceptions.py @@ -0,0 +1,111 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt +import textwrap + +import pytest + +from astroid import ( + AssignName, + ExceptHandler, + For, + Name, + TryExcept, + Uninferable, + bases, + extract_node, +) +from astroid.const import PY311_PLUS +from astroid.context import InferenceContext +from astroid.nodes import Expr, Raise, TryStar + + +@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher") +def test_group_exceptions() -> None: + node = extract_node( + textwrap.dedent( + """ + try: + raise ExceptionGroup("group", [ValueError(654)]) + except ExceptionGroup as eg: + for err in eg.exceptions: + if isinstance(err, ValueError): + print("Handling ValueError") + elif isinstance(err, TypeError): + print("Handling TypeError")""" + ) + ) + assert isinstance(node, TryExcept) + handler = node.handlers[0] + exception_group_block_range = (1, 4) + assert node.block_range(lineno=1) == exception_group_block_range + assert node.block_range(lineno=2) == (2, 2) + assert node.block_range(lineno=5) == (5, 9) + assert isinstance(handler, ExceptHandler) + assert handler.type.name == "ExceptionGroup" + children = list(handler.get_children()) + assert len(children) == 3 + exception_group, short_name, for_loop = children + assert isinstance(exception_group, Name) + assert exception_group.block_range(1) == exception_group_block_range + assert isinstance(short_name, AssignName) + assert isinstance(for_loop, For) + + +@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher") +def test_star_exceptions() -> None: + node = extract_node( + textwrap.dedent( + """ + try: + raise ExceptionGroup("group", [ValueError(654)]) + except* ValueError: + print("Handling ValueError") + except* TypeError: + print("Handling TypeError") + else: + sys.exit(127) + finally: + sys.exit(0)""" + ) + ) + assert isinstance(node, TryStar) + assert isinstance(node.body[0], Raise) + assert node.block_range(1) == (1, 11) + assert node.block_range(2) == (2, 2) + assert node.block_range(3) == (3, 3) + assert node.block_range(4) == (4, 4) + assert node.block_range(5) == (5, 5) + assert node.block_range(6) == (6, 6) + assert node.block_range(7) == (7, 7) + assert node.block_range(8) == (8, 8) + assert node.block_range(9) == (9, 9) + assert node.block_range(10) == (10, 10) + assert node.block_range(11) == (11, 11) + assert node.handlers + handler = node.handlers[0] + assert isinstance(handler, ExceptHandler) + assert handler.type.name == "ValueError" + orelse = node.orelse[0] + assert isinstance(orelse, Expr) + assert orelse.value.args[0].value == 127 + final = node.finalbody[0] + assert isinstance(final, Expr) + assert final.value.args[0].value == 0 + + +@pytest.mark.skipif(not PY311_PLUS, reason="Requires Python 3.11 or higher") +def test_star_exceptions_infer_name() -> None: + trystar = extract_node( + """ +try: + 1/0 +except* ValueError: + pass""" + ) + name = "arbitraryName" + context = InferenceContext() + context.lookupname = name + stmts = bases._infer_stmts([trystar], context) + assert list(stmts) == [Uninferable] + assert context.lookupname == name