Skip to content

Commit

Permalink
Add more dunder methods to TestCase to improve usability
Browse files Browse the repository at this point in the history
  • Loading branch information
tysmith committed Nov 27, 2023
1 parent 5e07e66 commit f4070d5
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 112 deletions.
4 changes: 2 additions & 2 deletions grizzly/adapter/no_op_adapter/test_no_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def test_no_op_01():
adapter.setup(None, None)
with TestCase("a", adapter.name) as test:
assert not test.data_size
assert "a" not in test.contents
assert "a" not in test
adapter.generate(test, None)
assert "a" in test.contents
assert "a" in test
2 changes: 1 addition & 1 deletion grizzly/common/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def post_launch(self, delay=-1):
with TestCase("post_launch_delay.html", "None") as content:
content.add_from_file(
Path(__file__).parent / "post_launch_delay.html",
content.entry_point,
file_name=content.entry_point,
copy=True,
)
srv_map = ServerMap()
Expand Down
129 changes: 77 additions & 52 deletions grizzly/common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class TestFileExists(Exception):
TestFile with the same name"""


TestFile = namedtuple("TestFile", "file_name data_file")
TestFileMap = namedtuple("TestFileMap", "optional required")


Expand Down Expand Up @@ -76,20 +75,67 @@ def __init__(
self.time_limit = time_limit
self.timestamp = time() if timestamp is None else timestamp
self.version = __version__
self._files = TestFileMap(optional=[], required=[])
self._files = TestFileMap(optional={}, required={})
if data_path:
self._root = data_path
self._in_place = True
else:
self._root = Path(mkdtemp(prefix="testcase_", dir=grz_tmp("storage")))
self._in_place = False

def __contains__(self, item):
"""Scan TestCase contents for a URL.
Args:
item (str): URL.
Returns:
bool: URL found in contents.
"""
return item in self._files.required or item in self._files.optional

def __getitem__(self, key):
"""Lookup file path by URL.
Args:
key (str): URL.
Returns:
Path: Test file.
"""
if key in self._files.required:
return self._files.required[key]
return self._files.optional[key]

def __enter__(self):
return self

def __exit__(self, *exc):
self.cleanup()

def __iter__(self):
"""TestCase contents.
Args:
None
Yields:
str: URLs of contents.
"""
yield from self._files.required
yield from self._files.optional

def __len__(self):
"""Number files in TestCase.
Args:
None
Returns:
int: Number files in TestCase.
"""
return len(self._files.optional) + len(self._files.required)

def add_from_bytes(self, data, file_name, required=False):
"""Create a file and add it to the TestCase.
Expand Down Expand Up @@ -135,24 +181,24 @@ def add_from_file(self, src_file, file_name=None, required=False, copy=False):
url_path = self.sanitize_path(src_file.name)
else:
url_path = self.sanitize_path(file_name)
if url_path in self.contents:
if url_path in self:
raise TestFileExists(f"{url_path!r} exists in test")

test_file = TestFile(url_path, self._root / url_path)
dst_file = self._root / url_path
# don't move/copy data is already in place
if src_file.resolve() != test_file.data_file.resolve():
if src_file.resolve() != dst_file.resolve():
assert not self._in_place
test_file.data_file.parent.mkdir(parents=True, exist_ok=True)
dst_file.parent.mkdir(parents=True, exist_ok=True)
if copy:
copyfile(src_file, test_file.data_file)
copyfile(src_file, dst_file)
else:
move(src_file, test_file.data_file)
move(src_file, dst_file)

# entry_point is always 'required'
if required or test_file.file_name == self.entry_point:
self._files.required.append(test_file)
if required or url_path == self.entry_point:
self._files.required[url_path] = dst_file
else:
self._files.optional.append(test_file)
self._files.optional[url_path] = dst_file

def cleanup(self):
"""Remove all the test files.
Expand Down Expand Up @@ -210,28 +256,15 @@ def clone(self):
result.hang = self.hang
result.https = self.https
# copy test data files
for entry, required in chain(
product(self._files.required, [True]),
product(self._files.optional, [False]),
for (file_name, data_file), required in chain(
product(self._files.required.items(), [True]),
product(self._files.optional.items(), [False]),
):
result.add_from_file(
entry.data_file, file_name=entry.file_name, required=required, copy=True
data_file, file_name=file_name, required=required, copy=True
)
return result

@property
def contents(self):
"""All files in TestCase.
Args:
None
Yields:
str: File path (relative to wwwroot).
"""
for tfile in chain(self._files.required, self._files.optional):
yield tfile.file_name

@property
def data_size(self):
"""Total amount of data used (bytes) by the files in the TestCase.
Expand All @@ -243,8 +276,11 @@ def data_size(self):
int: Total size of the TestCase in bytes.
"""
total = 0
for group in self._files:
total += sum(x.data_file.stat().st_size for x in group)
for data_file in chain(
self._files.required.values(),
self._files.optional.values(),
):
total += data_file.stat().st_size
return total

def dump(self, dst_path, include_details=False):
Expand All @@ -259,10 +295,13 @@ def dump(self, dst_path, include_details=False):
"""
dst_path = Path(dst_path)
# save test files to dst_path
for test_file in chain(self._files.required, self._files.optional):
dst_file = dst_path / test_file.file_name
for src_url, src_file in chain(
self._files.required.items(),
self._files.optional.items(),
):
dst_file = dst_path / src_url
dst_file.parent.mkdir(parents=True, exist_ok=True)
copyfile(test_file.data_file, dst_file)
copyfile(src_file, dst_file)
# save test case files and meta data including:
# adapter used, input file, environment info and files
if include_details:
Expand Down Expand Up @@ -309,20 +348,6 @@ def _find_entry_point(path):
raise TestCaseLoadFailure("Could not determine entry point")
return entry_point

def get_file(self, path):
"""Lookup and return the TestFile with the specified file name.
Args:
path (str): Path (relative to wwwroot) of TestFile to retrieve.
Returns:
TestFile: TestFile with matching path otherwise None.
"""
for tfile in chain(self._files.optional, self._files.required):
if tfile.file_name == path:
return tfile
return None

@property
def landing_page(self):
"""TestCase.landing_page is deprecated!
Expand Down Expand Up @@ -400,7 +425,9 @@ def load(cls, path, entry_point=None, catalog=False):
and test.assets_path not in entry.parents
and entry.name != "test_info.json"
):
test.add_from_file(entry, entry.relative_to(test.root).as_posix())
test.add_from_file(
entry, file_name=entry.relative_to(test.root).as_posix()
)
else:
# add entry point
test.add_from_file(entry_point, required=True)
Expand Down Expand Up @@ -456,8 +483,7 @@ def optional(self):
Yields:
str: File path of each optional file.
"""
for test in self._files.optional:
yield test.file_name
yield from self._files.optional

@staticmethod
def read_info(path):
Expand Down Expand Up @@ -490,8 +516,7 @@ def required(self):
Yields:
str: File path of each file.
"""
for test in self._files.required:
yield test.file_name
yield from self._files.required

@property
def root(self):
Expand Down
2 changes: 1 addition & 1 deletion grizzly/common/test_iomanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_iomanager_02(report_size, iters):
for current in range(1, iters + 1):
tcase = iom.create_testcase("test-adapter", 10)
assert iom._generated == current
assert iom._test
assert iom._test is not None
precommit_size = len(iom.tests)
iom.commit()
assert iom._test is None
Expand Down
22 changes: 11 additions & 11 deletions grizzly/common/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_runner_02(mocker):
server.serve_path.return_value = (Served.ALL, {"a.bin": ""})
testcase = mocker.Mock(
spec_set=TestCase,
contents=serv_files,
__iter__=serv_files,
entry_point=serv_files[0],
required=serv_files,
)
Expand Down Expand Up @@ -172,13 +172,13 @@ def test_runner_04(mocker, ignore, status, idle, check_result):
"""test reporting timeout"""
server = mocker.Mock(spec_set=Sapphire)
target = mocker.Mock(spec_set=Target)
serv_files = {"a.bin": ""}
test = mocker.Mock(
spec_set=TestCase,
contents=["a.bin"],
__iter__=tuple(serv_files),
entry_point="a.bin",
required=["a.bin"],
)
serv_files = {"a.bin": ""}
server.serve_path.return_value = (Served.TIMEOUT, serv_files)
target.check_result.return_value = Result.FOUND
target.handle_hang.return_value = idle
Expand Down Expand Up @@ -349,9 +349,9 @@ def test_runner_10(mocker, tmp_path):
result = runner.run([], smap, test)
assert result.attempted
assert result.status == Result.NONE
assert "inc_file.bin" in test.contents
assert "nested/nested_inc.bin" in test.contents
assert "test/inc_file3.txt" in test.contents
assert "inc_file.bin" in test
assert "nested/nested_inc.bin" in test
assert "test/inc_file3.txt" in test


def test_runner_11(mocker):
Expand All @@ -366,8 +366,8 @@ def test_runner_11(mocker):
test.add_from_bytes(b"", "other.html")
# add untracked file
(test.root / "extra.js").touch()
assert "extra.html" not in test.contents
assert "other.html" in test.contents
assert "extra.html" not in test
assert "other.html" in test
server.serve_path.return_value = (
Served.ALL,
{
Expand All @@ -378,9 +378,9 @@ def test_runner_11(mocker):
result = runner.run([], ServerMap(), test)
assert result.attempted
assert result.status == Result.NONE
assert "test.html" in test.contents
assert "extra.js" in test.contents
assert "other.html" not in test.contents
assert "test.html" in test
assert "extra.js" in test
assert "other.html" not in test


@mark.parametrize(
Expand Down
Loading

0 comments on commit f4070d5

Please sign in to comment.