Skip to content

Commit

Permalink
Merge pull request #173 from python-discord/file-scan-recursion-fix
Browse files Browse the repository at this point in the history
Fix recursion error during file attachment parsing of deep nested paths
  • Loading branch information
ChrisLovering authored May 9, 2023
2 parents 9acc6f5 + 90910bd commit 9804a10
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 41 deletions.
31 changes: 23 additions & 8 deletions snekbox/memfs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Memory filesystem for snekbox."""
from __future__ import annotations

import glob
import logging
import time
import warnings
import weakref
from collections.abc import Generator
Expand Down Expand Up @@ -125,6 +127,7 @@ def files(
limit: int,
pattern: str = "**/*",
exclude_files: dict[Path, float] | None = None,
timeout: float | None = None,
) -> Generator[FileAttachment, None, None]:
"""
Yields FileAttachments for files found in the output directory.
Expand All @@ -135,12 +138,18 @@ def files(
exclude_files: A dict of Paths and last modified times.
Files will be excluded if their last modified time
is equal to the provided value.
timeout: Maximum time in seconds for file parsing.
Raises:
TimeoutError: If file parsing exceeds timeout.
"""
start_time = time.monotonic()
count = 0
for file in self.output.rglob(pattern):
# Ignore hidden directories or files
if any(part.startswith(".") for part in file.parts):
log.info(f"Skipping hidden path {file!s}")
files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)
for file in (Path(self.output, f) for f in files):
if timeout and (time.monotonic() - start_time) > timeout:
raise TimeoutError("File parsing timeout exceeded in MemFS.files")

if not file.is_file():
continue

if exclude_files and (orig_time := exclude_files.get(file)):
Expand All @@ -154,17 +163,17 @@ def files(
log.info(f"Max attachments {limit} reached, skipping remaining files")
break

if file.is_file():
count += 1
log.info(f"Found valid file for upload {file.name!r}")
yield FileAttachment.from_path(file, relative_to=self.output)
count += 1
log.info(f"Found valid file for upload {file.name!r}")
yield FileAttachment.from_path(file, relative_to=self.output)

def files_list(
self,
limit: int,
pattern: str,
exclude_files: dict[Path, float] | None = None,
preload_dict: bool = False,
timeout: float | None = None,
) -> list[FileAttachment]:
"""
Return a sorted list of file paths within the output directory.
Expand All @@ -176,15 +185,21 @@ def files_list(
Files will be excluded if their last modified time
is equal to the provided value.
preload_dict: Whether to preload as_dict property data.
timeout: Maximum time in seconds for file parsing.
Returns:
List of FileAttachments sorted lexically by path name.
Raises:
TimeoutError: If file parsing exceeds timeout.
"""
start_time = time.monotonic()
res = sorted(
self.files(limit=limit, pattern=pattern, exclude_files=exclude_files),
key=lambda f: f.path,
)
if preload_dict:
for file in res:
if timeout and (time.monotonic() - start_time) > timeout:
raise TimeoutError("File parsing timeout exceeded in MemFS.files_list")
# Loads the cached property as attribute
_ = file.as_dict
return res
33 changes: 22 additions & 11 deletions snekbox/nsjail.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from snekbox.memfs import MemFS
from snekbox.process import EvalResult
from snekbox.snekio import FileAttachment
from snekbox.utils.timed import timed
from snekbox.utils.timed import time_limit

__all__ = ("NsJail",)

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
memfs_home: str = "home",
memfs_output: str = "home",
files_limit: int | None = 100,
files_timeout: float | None = 8,
files_timeout: int | None = 5,
files_pattern: str = "**/[!_]*",
):
"""
Expand Down Expand Up @@ -267,21 +267,32 @@ def python3(

# Parse attachments with time limit
try:
attachments = timed(
MemFS.files_list,
(fs, self.files_limit, self.files_pattern),
{
"preload_dict": True,
"exclude_files": files_written,
},
timeout=self.files_timeout,
)
with time_limit(self.files_timeout):
attachments = fs.files_list(
limit=self.files_limit,
pattern=self.files_pattern,
preload_dict=True,
exclude_files=files_written,
timeout=self.files_timeout,
)
log.info(f"Found {len(attachments)} files.")
except RecursionError:
log.info("Recursion error while parsing attachments")
return EvalResult(
args,
None,
"FileParsingError: Exceeded directory depth limit while parsing attachments",
)
except TimeoutError as e:
log.info(f"Exceeded time limit while parsing attachments: {e}")
return EvalResult(
args, None, "TimeoutError: Exceeded time limit while parsing attachments"
)
except Exception as e:
log.exception(f"Unexpected {type(e).__name__} while parse attachments", exc_info=e)
return EvalResult(
args, None, "FileParsingError: Unknown error while parsing attachments"
)

log_lines = nsj_log.read().decode("utf-8").splitlines()
if not log_lines and returncode == 255:
Expand Down
41 changes: 19 additions & 22 deletions snekbox/utils/timed.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
"""Calling functions with time limits."""
import multiprocessing
from collections.abc import Callable, Iterable, Mapping
from typing import Any, TypeVar
import signal
from collections.abc import Generator
from contextlib import contextmanager
from typing import TypeVar

_T = TypeVar("_T")
_V = TypeVar("_V")

__all__ = ("timed",)
__all__ = ("time_limit",)


def timed(
func: Callable[[_T], _V],
args: Iterable = (),
kwds: Mapping[str, Any] | None = None,
timeout: float | None = None,
) -> _V:
@contextmanager
def time_limit(timeout: int | None = None) -> Generator[None, None, None]:
"""
Call a function with a time limit.
Decorator to call a function with a time limit.
Args:
func: Function to call.
args: Arguments for function.
kwds: Keyword arguments for function.
timeout: Timeout limit in seconds.
Raises:
TimeoutError: If the function call takes longer than `timeout` seconds.
"""
if kwds is None:
kwds = {}
with multiprocessing.Pool(1) as pool:
result = pool.apply_async(func, args, kwds)
try:
return result.get(timeout)
except multiprocessing.TimeoutError as e:
raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e

def signal_handler(_signum, _frame):
raise TimeoutError(f"time_limit call timed out after {timeout} seconds.")

signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(timeout)

try:
yield
finally:
signal.alarm(0)
23 changes: 23 additions & 0 deletions tests/test_nsjail.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,29 @@ def test_file_parsing_timeout(self):
)
self.assertEqual(result.stderr, None)

def test_file_parsing_depth_limit(self):
# Checks that attachment parsing of a very deeply nested output tree is
# reported as a FileParsingError result.
#
# The sandboxed snippet below grows the path "a/a/.../" over 1000 loop
# iterations, calls os.mkdir on it, and writes "test" to test.txt under it.
# NOTE(review): the snippet's indentation is not visible in this view;
# presumably os.mkdir runs once per iteration so the tree ends up 1000
# levels deep — confirm against the repository.
code = dedent(
"""
import os
x = ""
for _ in range(1000):
x += "a/"
os.mkdir(x)
open(f"{x}test.txt", "w").write("test")
"""
).strip()

# A 32 MiB memfs and a 5 second attachment-parsing timeout are used here.
nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=5)
result = nsjail.python3(["-c", code])
# The expected result has no return code and no stderr; the error text
# is delivered through stdout.
self.assertEqual(result.returncode, None)
self.assertEqual(
result.stdout,
"FileParsingError: Exceeded directory depth limit while parsing attachments",
)
self.assertEqual(result.stderr, None)

def test_file_write_error(self):
"""Test errors during file write."""
result = self.nsjail.python3(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_timed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import math
import time
from unittest import TestCase

from snekbox.utils.timed import time_limit


class TimedTests(TestCase):
    """Checks that `time_limit` aborts code that overruns its allowance."""

    def test_sleep(self):
        """A sleep longer than the limit is cut short with TimeoutError."""
        reached_end = False
        began = time.perf_counter()
        with self.assertRaises(TimeoutError), time_limit(1):
            time.sleep(2)
            # Only reachable if the sleep was NOT interrupted.
            reached_end = True
        elapsed = time.perf_counter() - began
        self.assertLess(elapsed, 2)
        self.assertFalse(reached_end)

    def test_iter(self):
        """A long-running built-in call is cut short with TimeoutError."""
        outcome = 0
        began = time.perf_counter()
        with self.assertRaises(TimeoutError), time_limit(1):
            outcome = math.factorial(2**30)
        elapsed = time.perf_counter() - began
        # The assignment must never have completed.
        self.assertEqual(outcome, 0)
        self.assertLess(elapsed, 2)

0 comments on commit 9804a10

Please sign in to comment.