From 5fa9f67853a7112e08185ed416de7907b8e524da Mon Sep 17 00:00:00 2001 From: Ned Batchelder Date: Sat, 4 May 2024 08:20:53 -0400 Subject: [PATCH] fix: avoid max recursion errors in ast code. #1774 --- CHANGES.rst | 7 +++++++ coverage/phystokens.py | 32 ++++++++++++-------------------- coverage/regions.py | 28 ++++++++++++++++++++-------- tests/test_phystokens.py | 1 + 4 files changed, 40 insertions(+), 28 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 707acfd2b..97d2654b2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -28,6 +28,10 @@ Unreleased on the first line. This closes `issue 754`_. The fix was contributed by `Daniel Diniz `_. +- Fix: very complex source files like `this one `_ could + cause a maximum recursion error when creating an HTML report. This is now + fixed, closing `issue 1774`_. + - HTML report improvements: - Support files (JavaScript and CSS) referenced by the HTML report now have @@ -41,10 +45,13 @@ Unreleased - Column sort order is remembered better as you move between the index pages, fixing `issue 1766`_. Thanks, `Daniel Diniz `_. + +.. _resolvent_lookup: https://github.com/sympy/sympy/blob/130950f3e6b3f97fcc17f4599ac08f70fdd2e9d4/sympy/polys/numberfields/resolvent_lookup.py .. _issue 754: https://github.com/nedbat/coveragepy/issues/754 .. _issue 1766: https://github.com/nedbat/coveragepy/issues/1766 .. _pull 1768: https://github.com/nedbat/coveragepy/pull/1768 .. _pull 1773: https://github.com/nedbat/coveragepy/pull/1773 +.. _issue 1774: https://github.com/nedbat/coveragepy/issues/1774 .. scriv-start-here diff --git a/coverage/phystokens.py b/coverage/phystokens.py index 626e9967e..a42d184a6 100644 --- a/coverage/phystokens.py +++ b/coverage/phystokens.py @@ -78,26 +78,19 @@ def _phys_tokens(toks: TokenInfos) -> TokenInfos: last_lineno = elineno -class SoftKeywordFinder(ast.NodeVisitor): +def find_soft_key_lines(source: str) -> set[TLineNo]: """Helper for finding lines with soft keywords, like match/case lines.""" - def __init__(self, source: str) -> None: - # This will be the set of line numbers that start with a soft keyword. - self.soft_key_lines: set[TLineNo] = set() - self.visit(ast.parse(source)) - - if sys.version_info >= (3, 10): - def visit_Match(self, node: ast.Match) -> None: - """Invoked by ast.NodeVisitor.visit""" - self.soft_key_lines.add(node.lineno) + soft_key_lines: set[TLineNo] = set() + + for node in ast.walk(ast.parse(source)): + if sys.version_info >= (3, 10) and isinstance(node, ast.Match): + soft_key_lines.add(node.lineno) for case in node.cases: - self.soft_key_lines.add(case.pattern.lineno) - self.generic_visit(node) + soft_key_lines.add(case.pattern.lineno) + elif sys.version_info >= (3, 12) and isinstance(node, ast.TypeAlias): + soft_key_lines.add(node.lineno) - if sys.version_info >= (3, 12): - def visit_TypeAlias(self, node: ast.TypeAlias) -> None: - """Invoked by ast.NodeVisitor.visit""" - self.soft_key_lines.add(node.lineno) - self.generic_visit(node) + return soft_key_lines def source_token_lines(source: str) -> TSourceTokenLines: @@ -124,7 +117,7 @@ def source_token_lines(source: str) -> TSourceTokenLines: tokgen = generate_tokens(source) if env.PYBEHAVIOR.soft_keywords: - soft_key_lines = SoftKeywordFinder(source).soft_key_lines + soft_key_lines = find_soft_key_lines(source) for ttype, ttext, (sline, scol), (_, ecol), _ in _phys_tokens(tokgen): mark_start = True @@ -151,8 +144,7 @@ def source_token_lines(source: str) -> TSourceTokenLines: # Need the version_info check to keep mypy from borking # on issoftkeyword here. if env.PYBEHAVIOR.soft_keywords and keyword.issoftkeyword(ttext): - # Soft keywords appear at the start of the line, - # on lines that start match or case statements. + # Soft keywords appear at the start of their line. if len(line) == 0: is_start_of_line = True elif (len(line) == 1) and line[0][0] == "ws": diff --git a/coverage/regions.py b/coverage/regions.py index ad319b0e2..7954be69d 100644 --- a/coverage/regions.py +++ b/coverage/regions.py @@ -21,7 +21,7 @@ class Context: lines: set[int] -class RegionFinder(ast.NodeVisitor): +class RegionFinder: """An ast visitor that will find and track regions of code. Functions and classes are tracked by name. Results are in the .regions @@ -34,13 +34,27 @@ def __init__(self) -> None: def parse_source(self, source: str) -> None: """Parse `source` and walk the ast to populate the .regions attribute.""" - self.visit(ast.parse(source)) + self.handle_node(ast.parse(source)) def fq_node_name(self) -> str: """Get the current fully qualified name we're processing.""" return ".".join(c.name for c in self.context) - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + def handle_node(self, node: ast.AST) -> None: + """Recursively handle any node.""" + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + self.handle_FunctionDef(node) + elif isinstance(node, ast.ClassDef): + self.handle_ClassDef(node) + else: + self.handle_node_body(node) + + def handle_node_body(self, node: ast.AST) -> None: + """Recursively handle the nodes in this node's body, if any.""" + for body_node in getattr(node, "body", ()): + self.handle_node(body_node) + + def handle_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: """Called for `def` or `async def`.""" lines = set(range(node.body[0].lineno, cast(int, node.body[-1].end_lineno) + 1)) if self.context and self.context[-1].kind == "class": @@ -60,12 +74,10 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: lines=lines, ) ) - self.generic_visit(node) + self.handle_node_body(node) self.context.pop() - visit_AsyncFunctionDef = visit_FunctionDef # type: ignore[assignment] - - def visit_ClassDef(self, node: ast.ClassDef) -> None: + def handle_ClassDef(self, node: ast.ClassDef) -> None: """Called for `class`.""" # The lines for a class are the lines in the methods of the class. # We start empty, and count on visit_FunctionDef to add the lines it @@ -80,7 +92,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: lines=lines, ) ) - self.generic_visit(node) + self.handle_node_body(node) self.context.pop() # Class bodies should be excluded from the enclosing classes. for ancestor in reversed(self.context): diff --git a/tests/test_phystokens.py b/tests/test_phystokens.py index ca1efeae5..063e517d6 100644 --- a/tests/test_phystokens.py +++ b/tests/test_phystokens.py @@ -135,6 +135,7 @@ def match(): global case """) tokens = list(source_token_lines(source)) + print(tokens) assert tokens[0][0] == ("key", "match") assert tokens[0][4] == ("nam", "match") assert tokens[1][1] == ("key", "case")