Skip to content

Commit

Permalink
refactor: Refactor ScriptContentReader
Browse files Browse the repository at this point in the history
  • Loading branch information
kvankova committed Nov 3, 2024
1 parent e76384c commit 7ad0199
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 29 deletions.
81 changes: 53 additions & 28 deletions src/script_content_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,38 @@ def __init__(self) -> None:
self._section_end_regex = r".*code_embedder:.*end"

def read(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]:
script_contents = self._read_full_script(scripts)
return self._process_scripts(script_contents)
scripts_with_full_contents = self._read_full_script(scripts)
return self._process_scripts(scripts_with_full_contents)

def _read_full_script(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]:
script_contents: list[ScriptMetadata] = []
scripts_with_full_contents: list[ScriptMetadata] = []

for script in scripts:
try:
with open(script.path) as script_file:
script.content = script_file.read()

script_contents.append(script)
scripts_with_full_contents.append(script)

except FileNotFoundError:
logger.error(f"Error: {script.path} not found. Skipping.")

return script_contents
return scripts_with_full_contents

def _process_scripts(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]:
full_scripts = [script for script in scripts if not script.extraction_part]
scripts_with_extraction_part = [script for script in scripts if script.extraction_part]

if scripts_with_extraction_part:
scripts_with_extraction_part = self._read_script_part(scripts_with_extraction_part)
scripts_with_extraction_part = self._update_script_content_with_extraction_part(
scripts_with_extraction_part
)

return full_scripts + scripts_with_extraction_part

def _read_script_part(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadata]:
def _update_script_content_with_extraction_part(
self, scripts: list[ScriptMetadata]
) -> list[ScriptMetadata]:
return [
ScriptMetadata(
path=script.path,
Expand All @@ -59,22 +63,51 @@ def _read_script_part(self, scripts: list[ScriptMetadata]) -> list[ScriptMetadat
def _extract_part(self, script: ScriptMetadata) -> str:
lines = script.content.split("\n")

# Try extracting as object first, then fall back to section
is_object, start, end = self._find_object_bounds(script)

if is_object:
return "\n".join(lines[start:end])

# Else, it is a section
section_bounds = self._find_section_bounds(lines)
return self._convert_lines_to_content(lines, start, end)

if not section_bounds:
logger.error(f"Section {script.extraction_part} not found in {script.path}")
# Extract section if not an object
start, end = self._find_section_bounds(lines)
if not self._validate_section_bounds(start, end, script):
return ""

start, end = section_bounds
return self._convert_lines_to_content(lines, start, end)

def _validate_section_bounds(
self, start: int | None, end: int | None, script: ScriptMetadata
) -> bool:
if not start and not end:
logger.error(
f"Part {script.extraction_part} not found in {script.path}. Skipping."
)
return False

if not start:
logger.error(
f"Start of section {script.extraction_part} not found in {script.path}. "
"Skipping."
)
return False

if not end:
logger.error(
f"End of section {script.extraction_part} not found in {script.path}. "
"Skipping."
)
return False

return True

def _convert_lines_to_content(
self, lines: list[str], start: int | None, end: int | None
) -> str:
if start is None or end is None:
return ""
return "\n".join(lines[start:end])

def _find_section_bounds(self, lines: list[str]) -> tuple[int, int] | None:
def _find_section_bounds(self, lines: list[str]) -> tuple[int | None, int | None]:
start = None
end = None

Expand All @@ -85,30 +118,22 @@ def _find_section_bounds(self, lines: list[str]) -> tuple[int, int] | None:
end = i
break

if start is None or end is None:
return None

return start, end

def _find_object_bounds(
self, script: ScriptMetadata
) -> tuple[bool, int | None, int | None]:
tree = ast.parse(script.content)

object_nodes = []

for node in ast.walk(tree):
if (
isinstance(node, ast.FunctionDef)
| isinstance(node, ast.AsyncFunctionDef)
| isinstance(node, ast.ClassDef)
):
object_nodes.append(node)

for node in object_nodes:
if script.extraction_part == getattr(node, "name", None):
start_lineno = getattr(node, "lineno", None)
end_lineno = getattr(node, "end_lineno", None)
return (True, start_lineno - 1 if start_lineno else None, end_lineno)
if script.extraction_part == getattr(node, "name", None):
start = getattr(node, "lineno", None)
end = getattr(node, "end_lineno", None)
return True, start - 1 if start else None, end

return False, None, None
7 changes: 6 additions & 1 deletion tests/test_script_content_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@
content=(
"import re\n"
"\n"
"\n"
"# Function verifying an email is valid\n"
"def verify_email(email: str) -> bool:\n"
' return re.match(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+'
'$", email) is not None\n'
"\n"
"\n"
"class Person:\n"
" def __init__(self, name: str, age: int):\n"
" self.name = name\n"
Expand Down Expand Up @@ -206,4 +208,7 @@ def test_read_script_section(
):
script_content_reader = ScriptContentReader()

assert script_content_reader._read_script_part(scripts) == expected_scripts
assert (
script_content_reader._update_script_content_with_extraction_part(scripts)
== expected_scripts
)

0 comments on commit 7ad0199

Please sign in to comment.