Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix recursion error during file attachment parsing of deep nested paths #173

Merged
merged 15 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions snekbox/memfs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Memory filesystem for snekbox."""
from __future__ import annotations

import glob
import logging
import time
import warnings
import weakref
from collections.abc import Generator
Expand Down Expand Up @@ -125,6 +127,7 @@ def files(
limit: int,
pattern: str = "**/*",
exclude_files: dict[Path, float] | None = None,
timeout: float | None = None,
) -> Generator[FileAttachment, None, None]:
"""
Yields FileAttachments for files found in the output directory.
Expand All @@ -135,12 +138,17 @@ def files(
exclude_files: A dict of Paths and last modified times.
Files will be excluded if their last modified time
is equal to the provided value.
timeout: The maximum time for the file parsing. If exceeded,
a TimeoutError will be raised.
"""
count = 0
for file in self.output.rglob(pattern):
# Ignore hidden directories or files
if any(part.startswith(".") for part in file.parts):
log.info(f"Skipping hidden path {file!s}")
start_time = time.monotonic()
added = 0
files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)
for file in (Path(self.output, f) for f in files):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Elegant use of a generator expression here!

if timeout and (time.monotonic() - start_time) > timeout:
raise TimeoutError("File parsing timeout exceeded in MemFS.files")

if not file.is_file():
continue

if exclude_files and (orig_time := exclude_files.get(file)):
Expand All @@ -150,21 +158,21 @@ def files(
log.info(f"Skipping {file.name!r} as it has not been modified")
continue

if count > limit:
if added > limit:
log.info(f"Max attachments {limit} reached, skipping remaining files")
break

if file.is_file():
count += 1
log.info(f"Found valid file for upload {file.name!r}")
yield FileAttachment.from_path(file, relative_to=self.output)
added += 1
log.info(f"Found valid file for upload {file.name!r}")
yield FileAttachment.from_path(file, relative_to=self.output)

def files_list(
    self,
    limit: int,
    pattern: str,
    exclude_files: dict[Path, float] | None = None,
    preload_dict: bool = False,
    timeout: float | None = None,
) -> list[FileAttachment]:
    """
    Return a sorted list of file paths within the output directory.

    Args:
        limit: Maximum number of files to parse.
        pattern: Glob pattern to match files against.
        exclude_files: A dict of Paths and last modified times.
            Files will be excluded if their last modified time
            is equal to the provided value.
        preload_dict: Whether to preload as_dict property data.
        timeout: The maximum time for the file parsing. If exceeded,
            a TimeoutError will be raised.

    Raises:
        TimeoutError: If parsing exceeds `timeout` seconds.

    Returns:
        List of FileAttachments sorted lexically by path name.
    """
    start_time = time.monotonic()
    # Forward `timeout` to the generator: it is consumed eagerly by
    # sorted(), so without this the file-discovery phase would be
    # unbounded and only the preload loop below would be time-limited.
    res = sorted(
        self.files(
            limit=limit,
            pattern=pattern,
            exclude_files=exclude_files,
            timeout=timeout,
        ),
        key=lambda f: f.path,
    )
    if preload_dict:
        for file in res:
            if timeout and (time.monotonic() - start_time) > timeout:
                raise TimeoutError("File parsing timeout exceeded in MemFS.files_list")
            # Loads the cached property as attribute so later
            # serialization does not pay the cost per-request.
            _ = file.as_dict
    return res
33 changes: 22 additions & 11 deletions snekbox/nsjail.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from snekbox.memfs import MemFS
from snekbox.process import EvalResult
from snekbox.snekio import FileAttachment
from snekbox.utils.timed import timed
from snekbox.utils.timed import time_limit

__all__ = ("NsJail",)

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
memfs_home: str = "home",
memfs_output: str = "home",
files_limit: int | None = 100,
files_timeout: float | None = 8,
files_timeout: float | None = 5,
files_pattern: str = "**/[!_]*",
):
"""
Expand Down Expand Up @@ -267,21 +267,32 @@ def python3(

# Parse attachments with time limit
try:
attachments = timed(
MemFS.files_list,
(fs, self.files_limit, self.files_pattern),
{
"preload_dict": True,
"exclude_files": files_written,
},
timeout=self.files_timeout,
)
with time_limit(self.files_timeout):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a test that asserts this time outs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a test parsing symlink files for timeout, but that doesn't determine if the timeout occurred due to the time.monotonic() comparison timeout or by the signal. The signal timeout would be if a very large single file was parsed, or pathlib ends up in some loop.

But also added 2 new tests for time_limit successfully interrupting time.sleep and a long-running built-in C function (math.factorial) in 1af5bf0

attachments = fs.files_list(
limit=self.files_limit,
pattern=self.files_pattern,
preload_dict=True,
exclude_files=files_written,
timeout=self.files_timeout,
)
log.info(f"Found {len(attachments)} files.")
except RecursionError:
log.info("Recursion error while parsing attachments")
return EvalResult(
args,
None,
"FileParsingError: Exceeded directory depth limit while parsing attachments",
)
except TimeoutError as e:
log.info(f"Exceeded time limit while parsing attachments: {e}")
return EvalResult(
args, None, "TimeoutError: Exceeded time limit while parsing attachments"
)
except Exception as e:
log.error(f"Unexpected {type(e).__name__} while parse attachments: {e}")
return EvalResult(
args, None, "FileParsingError: Unknown error while parsing attachments"
)

log_lines = nsj_log.read().decode("utf-8").splitlines()
if not log_lines and returncode == 255:
Expand Down
41 changes: 19 additions & 22 deletions snekbox/utils/timed.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
"""Calling functions with time limits."""
import multiprocessing
from collections.abc import Callable, Iterable, Mapping
from typing import Any, TypeVar
import signal
from collections.abc import Generator
from contextlib import contextmanager
from typing import TypeVar

_T = TypeVar("_T")
_V = TypeVar("_V")

__all__ = ("timed",)
__all__ = ("time_limit",)


def timed(
func: Callable[[_T], _V],
args: Iterable = (),
kwds: Mapping[str, Any] | None = None,
timeout: float | None = None,
) -> _V:
@contextmanager
def time_limit(timeout: int | None = None) -> Generator[None, None, None]:
"""
Call a function with a time limit.
Decorator to call a function with a time limit. Uses SIGALRM, requires a UNIX system.

Args:
func: Function to call.
args: Arguments for function.
kwds: Keyword arguments for function.
timeout: Timeout limit in seconds.

Raises:
TimeoutError: If the function call takes longer than `timeout` seconds.
"""
if kwds is None:
kwds = {}
with multiprocessing.Pool(1) as pool:
result = pool.apply_async(func, args, kwds)
try:
return result.get(timeout)
except multiprocessing.TimeoutError as e:
raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e

def signal_handler(_signum, _frame):
raise TimeoutError(f"time_limit call timed out after {timeout} seconds.")

signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(timeout)

try:
yield
finally:
signal.alarm(0)
Comment on lines +29 to +34
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clever. I don't know enough about signals to say whether this is an appropriate use, but I think it's low risk to try.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to work pretty accurately for interrupting time.sleep and built-in C functions in these tests 1af5bf0, so it seems okay for our usecase, and avoids the potential broken pipes issue with multiprocessing.

23 changes: 23 additions & 0 deletions tests/test_nsjail.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,29 @@ def test_file_parsing_timeout(self):
)
self.assertEqual(result.stderr, None)

def test_file_parsing_depth_limit(self):
    """Deeply nested output paths should fail parsing gracefully, not recurse forever."""
    # Sandbox payload: build a 1000-level directory chain, then drop a
    # file at the bottom so attachment parsing has to walk the tree.
    payload = dedent(
        """
        import os

        x = ""
        for _ in range(1000):
            x += "a/"
            os.mkdir(x)

        open(f"{x}test.txt", "w").write("test")
        """
    ).strip()

    sandbox = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=5)
    outcome = sandbox.python3(["-c", payload])
    self.assertEqual(outcome.returncode, None)
    self.assertEqual(
        outcome.stdout,
        "FileParsingError: Exceeded directory depth limit while parsing attachments",
    )
    self.assertEqual(outcome.stderr, None)

def test_file_write_error(self):
"""Test errors during file write."""
result = self.nsjail.python3(
Expand Down