Skip to content

Commit

Permalink
feat: Add object support to ScriptContentReader
Browse files Browse the repository at this point in the history
  • Loading branch information
kvankova committed Nov 3, 2024
1 parent 0a02c58 commit e76384c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
57 changes: 44 additions & 13 deletions src/script_content_reader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import re
from typing import Protocol

Expand Down Expand Up @@ -36,27 +37,34 @@ def _read_full_script(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadat

def _process_scripts(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]:
full_scripts = [script for script in scripts if not script.extraction_part]
scripts_with_sections = [script for script in scripts if script.extraction_part]
scripts_with_extraction_part = [script for script in scripts if script.extraction_part]

if scripts_with_sections:
scripts_with_sections = self._read_script_section(scripts_with_sections)
if scripts_with_extraction_part:
scripts_with_extraction_part = self._read_script_part(scripts_with_extraction_part)

return full_scripts + scripts_with_sections
return full_scripts + scripts_with_extraction_part

def _read_script_section(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]:
def _read_script_part(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]:
return [
ScriptMetadata(
path=script.path,
extraction_part=script.extraction_part,
readme_start=script.readme_start,
readme_end=script.readme_end,
content=self._extract_section(script),
content=self._extract_part(script),
)
for script in scripts
]

def _extract_section(self, script: ScriptMetadata) -> str:
def _extract_part(self, script: ScriptMetadata) -> str:
lines = script.content.split("\n")

is_object, start, end = self._find_object_bounds(script)

if is_object:
return "\n".join(lines[start:end])

# Else, it is a section
section_bounds = self._find_section_bounds(lines)

if not section_bounds:
Expand All @@ -67,17 +75,40 @@ def _extract_section(self, script: ScriptMetadata) -> str:
return "\n".join(lines[start:end])

def _find_section_bounds(self, lines: list[str]) -> tuple[int, int] | None:
section_start = None
section_end = None
start = None
end = None

for i, line in enumerate(lines):
if re.search(self._section_start_regex, line):
section_start = i + 1
start = i + 1
elif re.search(self._section_end_regex, line):
section_end = i
end = i
break

if section_start is None or section_end is None:
if start is None or end is None:
return None

return section_start, section_end
return start, end

def _find_object_bounds(
    self, script: ScriptMetadata
) -> tuple[bool, int | None, int | None]:
    """Find the line span of the function or class named ``script.extraction_part``.

    ``script.content`` is parsed as Python source and searched (module level
    and nested scopes alike) for a ``def``, ``async def`` or ``class`` whose
    name equals ``script.extraction_part``. The first match in ``ast.walk``
    order wins.

    Returns:
        ``(True, start, end)`` when a matching definition exists, where
        ``start``/``end`` are 0-based slice bounds into
        ``script.content.split("\n")`` (``lines[start:end]`` is the whole
        definition, starting at its ``def``/``class`` line).
        ``(False, None, None)`` when there is no such definition or the
        content cannot be parsed as Python.
    """
    try:
        tree = ast.parse(script.content)
    except (SyntaxError, ValueError):
        # Content that is not valid Python cannot contain a named object.
        # Report "not an object" so the caller can still fall back to
        # section markers instead of crashing on the parse error.
        return False, None, None

    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    for node in ast.walk(tree):
        if isinstance(node, definition_types) and node.name == script.extraction_part:
            # ``lineno`` is 1-based and inclusive; shifting only the start
            # by one turns the pair into a half-open 0-based slice.
            return True, node.lineno - 1, node.end_lineno

    return False, None, None
2 changes: 1 addition & 1 deletion tests/test_script_content_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,4 @@ def test_read_script_section(
):
script_content_reader = ScriptContentReader()

assert script_content_reader._read_script_section(scripts) == expected_scripts
assert script_content_reader._read_script_part(scripts) == expected_scripts

0 comments on commit e76384c

Please sign in to comment.