From e76384c8e33869c4ddd85563a826d7720e78fd84 Mon Sep 17 00:00:00 2001 From: Katerina Vankova Date: Sun, 3 Nov 2024 20:02:18 +0100 Subject: [PATCH] feat: Add object support to ScriptContentReader --- src/script_content_reader.py | 57 ++++++++++++++++++++++------- tests/test_script_content_reader.py | 2 +- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/src/script_content_reader.py b/src/script_content_reader.py index a6e7b71..f383032 100644 --- a/src/script_content_reader.py +++ b/src/script_content_reader.py @@ -1,3 +1,4 @@ +import ast import re from typing import Protocol @@ -36,27 +37,34 @@ def _read_full_script(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadat def _process_scripts(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]: full_scripts = [script for script in scripts if not script.extraction_part] - scripts_with_sections = [script for script in scripts if script.extraction_part] + scripts_with_extraction_part = [script for script in scripts if script.extraction_part] - if scripts_with_sections: - scripts_with_sections = self._read_script_section(scripts_with_sections) + if scripts_with_extraction_part: + scripts_with_extraction_part = self._read_script_part(scripts_with_extraction_part) - return full_scripts + scripts_with_sections + return full_scripts + scripts_with_extraction_part - def _read_script_section(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]: + def _read_script_part(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]: return [ ScriptMetadata( path=script.path, extraction_part=script.extraction_part, readme_start=script.readme_start, readme_end=script.readme_end, - content=self._extract_section(script), + content=self._extract_part(script), ) for script in scripts ] - def _extract_section(self, script: ScriptMetadata) -> str: + def _extract_part(self, script: ScriptMetadata) -> str: lines = script.content.split("\n") + + is_object, start, end = self._find_object_bounds(script) + + if is_object: + return "\n".join(lines[start:end]) + + # Else, it is a section section_bounds = self._find_section_bounds(lines) if not section_bounds: @@ -67,17 +75,40 @@ def _extract_section(self, script: ScriptMetadata) -> str: return "\n".join(lines[start:end]) def _find_section_bounds(self, lines: list[str]) -> tuple[int, int] | None: - section_start = None - section_end = None + start = None + end = None for i, line in enumerate(lines): if re.search(self._section_start_regex, line): - section_start = i + 1 + start = i + 1 elif re.search(self._section_end_regex, line): - section_end = i + end = i break - if section_start is None or section_end is None: + if start is None or end is None: return None - return section_start, section_end + return start, end + + def _find_object_bounds( + self, script: ScriptMetadata + ) -> tuple[bool, int | None, int | None]: + tree = ast.parse(script.content) + + object_nodes = [] + + for node in ast.walk(tree): + if ( + isinstance(node, ast.FunctionDef) + | isinstance(node, ast.AsyncFunctionDef) + | isinstance(node, ast.ClassDef) + ): + object_nodes.append(node) + + for node in object_nodes: + if script.extraction_part == getattr(node, "name", None): + start_lineno = getattr(node, "lineno", None) + end_lineno = getattr(node, "end_lineno", None) + return (True, start_lineno - 1 if start_lineno else None, end_lineno) + + return False, None, None diff --git a/tests/test_script_content_reader.py b/tests/test_script_content_reader.py index 0ca5831..d271ffb 100644 --- a/tests/test_script_content_reader.py +++ b/tests/test_script_content_reader.py @@ -206,4 +206,4 @@ def test_read_script_section( ): script_content_reader = ScriptContentReader() - assert script_content_reader._read_script_section(scripts) == expected_scripts + assert script_content_reader._read_script_part(scripts) == expected_scripts