From 1400f898e7f5b14c196f86cf9f568272e15065d4 Mon Sep 17 00:00:00 2001 From: Tyson Smith Date: Mon, 25 Mar 2024 11:27:34 -0700 Subject: [PATCH] Add type hints to storage.py --- grizzly/common/storage.py | 172 ++++++++++++++++++--------------- grizzly/common/test_storage.py | 2 + 2 files changed, 96 insertions(+), 78 deletions(-) diff --git a/grizzly/common/storage.py b/grizzly/common/storage.py index 7feb5bc2..07e5423a 100644 --- a/grizzly/common/storage.py +++ b/grizzly/common/storage.py @@ -3,7 +3,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. import json -from collections import namedtuple +from dataclasses import dataclass, field from itertools import chain, product from logging import getLogger from os.path import normpath, split @@ -11,6 +11,7 @@ from shutil import copyfile, copytree, move, rmtree from tempfile import NamedTemporaryFile, mkdtemp from time import time +from typing import Any, Dict, Generator, Optional, Tuple, Union, cast from .utils import __version__, grz_tmp @@ -32,7 +33,10 @@ class TestFileExists(Exception): TestFile with the same name""" -TestFileMap = namedtuple("TestFileMap", "optional required") +@dataclass(eq=False) +class TestFileMap: + optional: Dict[str, Path] = field(default_factory=dict) + required: Dict[str, Path] = field(default_factory=dict) class TestCase: @@ -55,98 +59,101 @@ class TestCase: def __init__( self, - entry_point, - adapter_name, - data_path=None, - input_fname=None, - timestamp=None, - ): + entry_point: str, + adapter_name: str, + data_path: Optional[Path] = None, + input_fname: Optional[str] = None, + timestamp: Optional[float] = None, + ) -> None: assert entry_point self.adapter_name = adapter_name - self.assets = {} - self.assets_path = None - self.duration = None - self.env_vars = {} + self.assets: Dict[str, str] = {} + self.assets_path: Optional[Path] = None + self.duration: Optional[float] = None + self.env_vars: Dict[str, str] = {} self.hang = False self.https = False self.input_fname = input_fname # file that was used to create the test case self.entry_point = self.sanitize_path(entry_point) self.timestamp = time() if timestamp is None else timestamp self.version = __version__ - self._files = TestFileMap(optional={}, required={}) - if data_path: + self._files = TestFileMap() + if data_path is not None: self._root = data_path self._in_place = True else: self._root = Path(mkdtemp(prefix="testcase_", dir=grz_tmp("storage"))) self._in_place = False - def __contains__(self, item): + def __contains__(self, item: str) -> bool: """Scan TestCase contents for a URL. Args: - item (str): URL. + item: URL. Returns: - bool: URL found in contents. + URL found in contents. """ return item in self._files.required or item in self._files.optional - def __getitem__(self, key): + def __getitem__(self, key: str) -> Path: """Lookup file path by URL. Args: - key (str): URL. + key: URL. Returns: - Path: Test file. + Test file. """ if key in self._files.required: return self._files.required[key] return self._files.optional[key] - def __enter__(self): + def __enter__(self) -> "TestCase": return self - def __exit__(self, *exc): + def __exit__(self, *exc: Any) -> None: self.cleanup() - def __iter__(self): + def __iter__(self) -> Generator[str, None, None]: """TestCase contents. Args: None Yields: - str: URLs of contents. + URLs of contents. """ yield from self._files.required yield from self._files.optional - def __len__(self): + def __len__(self) -> int: """Number files in TestCase. Args: None Returns: - int: Number files in TestCase. + Number files in TestCase. """ return len(self._files.optional) + len(self._files.required) - def add_from_bytes(self, data, file_name, required=False): + def add_from_bytes( + self, data: bytes, file_name: str, required: bool = False + ) -> None: """Create a file and add it to the TestCase. Args: - data (bytes): Data to write to file. - file_name (str): Used as file path on disk and URI. Relative to wwwroot. - required (bool): Indicates whether the file must be served. + data: Data to write to file. + file_name: Used as file path on disk and URI. Relative to wwwroot. + required: Indicates whether the file must be served. Returns: None """ - assert isinstance(data, bytes) - with NamedTemporaryFile(delete=False, dir=grz_tmp("storage")) as in_fp: + with NamedTemporaryFile( + delete=False, dir=grz_tmp("storage"), mode="w+b" + ) as in_fp: in_fp.write(data) data_file = Path(in_fp.name) @@ -159,17 +166,23 @@ def add_from_bytes(self, data, file_name, required=False): # unless an exception occurred so remove it if needed data_file.unlink(missing_ok=True) - def add_from_file(self, src_file, file_name=None, required=False, copy=False): + def add_from_file( + self, + src_file: Union[str, Path], + file_name: Optional[str] = None, + required: bool = False, + copy: bool = False, + ) -> None: """Add a file to the TestCase. Copy or move an existing file if needed. Args: - src_file (str): Path to existing file to use. - file_name (str): Used as file path on disk and URI. Relative to wwwroot. - If file_name is not given the name of the src_file - will be used. - required (bool): Indicates whether the file must be served. Typically this - is only used for the entry point. - copy (bool): Copy existing file data. Existing data is moved by default. + src_file: Path to existing file to use. + file_name: Used as file path on disk and URI. Relative to wwwroot. + If file_name is not given the name of the src_file + will be used. + required: Indicates whether the file must be served. Typically this + is only used for the entry point. + copy: Copy existing file data. Existing data is moved by default. Returns: None @@ -198,7 +211,7 @@ def add_from_file(self, src_file, file_name=None, required=False, copy=False): else: self._files.optional[url_path] = dst_file - def cleanup(self): + def cleanup(self) -> None: """Remove all the test files. Args: @@ -210,7 +223,7 @@ def cleanup(self): if not self._in_place: rmtree(self._root, ignore_errors=True) - def clear_optional(self): + def clear_optional(self) -> None: """Clear optional files. This does not remove data from the file system. Args: @@ -221,14 +234,14 @@ def clear_optional(self): """ self._files.optional.clear() - def clone(self): + def clone(self) -> "TestCase": """Make a copy of the TestCase. Args: None Returns: - TestCase: A copy of the TestCase instance. + A copy of the TestCase instance. """ result = type(self)( self.entry_point, @@ -263,14 +276,14 @@ def clone(self): return result @property - def data_size(self): - """Total amount of data used (bytes) by the files in the TestCase. + def data_size(self) -> int: + """Total amount of data used (in bytes) by the files in the TestCase. Args: None Returns: - int: Total size of the TestCase in bytes. + Total size of the TestCase in bytes. """ total = 0 for data_file in chain( @@ -280,12 +293,12 @@ def data_size(self): total += data_file.stat().st_size return total - def dump(self, dst_path, include_details=False): + def dump(self, dst_path: Path, include_details: bool = False) -> None: """Write all the test case data to the filesystem. Args: - dst_path (Path): Path to directory to output data. - include_details (bool): Output test info file. + dst_path: Path to directory to output data. + include_details: Output test info file. Returns: None @@ -301,7 +314,6 @@ def dump(self, dst_path, include_details=False): # save test case files and meta data including: # adapter used, input file, environment info and files if include_details: - assert isinstance(self.env_vars, dict) info = { "https": self.https, "target": self.entry_point, @@ -322,23 +334,22 @@ def dump(self, dst_path, include_details=False): info["version"] = self.version # save target assets and update meta data if self.assets: - assert isinstance(self.assets, dict) assert isinstance(self.assets_path, Path) info["assets"] = self.assets info["assets_path"] = "_assets_" - copytree(self.assets_path, dst_path / info["assets_path"]) + copytree(self.assets_path, dst_path / cast(str, info["assets_path"])) with (dst_path / TEST_INFO).open("w") as out_fp: json.dump(info, out_fp, indent=2, sort_keys=True) @staticmethod - def _find_entry_point(path): + def _find_entry_point(path: Path) -> Path: """Locate potential entry point. Args: - path (Path): Directory to scan. + path: Directory to scan. Returns: - Path: Entry point. + Entry point. """ entry_point = None for entry in path.iterdir(): @@ -351,20 +362,22 @@ def _find_entry_point(path): return entry_point @classmethod - def load(cls, path, entry_point=None, catalog=False): + def load( + cls, path: Path, entry_point: Optional[Path] = None, catalog: bool = False + ) -> "TestCase": """Load a TestCase. Args: - path (Path): Path can be: + path: Path can be: - A single file to be used as a test case. - A directory containing the test case data. - entry_point (Path): File to use as entry point. - catalog (bool): Scan contents of TestCase.root and track files. + entry_point: File to use as entry point. + catalog: Scan contents of TestCase.root and track files. Untracked files will be missed when using clone() or dump(). Only the entry point will be marked as 'required'. Returns: - TestCase: A TestCase. + A TestCase. """ assert isinstance(path, Path) # load test case info @@ -419,15 +432,17 @@ def load(cls, path, entry_point=None, catalog=False): return test @classmethod - def load_meta(cls, path, entry_point=None): + def load_meta( + cls, path: Path, entry_point: Optional[Path] = None + ) -> Tuple[Path, Dict[str, Any]]: """Process and sanitize TestCase meta data. Args: - path (Path): Directory containing test info file. - entry_point (Path): See TestCase.load(). + path: Directory containing test info file. + entry_point: See TestCase.load(). Returns: - tuple(Path, dict): Test case entry point and loaded test info. + Test case entry point and loaded test info. """ # load test case info if available if path.is_dir(): @@ -440,6 +455,7 @@ def load_meta(cls, path, entry_point=None): info["target"] = entry_point.name elif info: entry_point = path / info["target"] + assert entry_point else: # attempt to determine entry point entry_point = cls._find_entry_point(path) @@ -456,26 +472,26 @@ def load_meta(cls, path, entry_point=None): return (entry_point, info) @property - def optional(self): + def optional(self) -> Generator[str, None, None]: """Get file paths of optional files. Args: None Yields: - str: File path of each optional file. + File path of each optional file. """ yield from self._files.optional @staticmethod - def read_info(path): + def read_info(path: Path) -> Dict[str, Any]: """Attempt to load test info. Args: - path (Path): Directory containing test info file. + path: Directory containing test info file. - Yields: - dict: Test info. + Returns: + Test info. """ try: with (path / TEST_INFO).open("r") as in_fp: @@ -489,38 +505,38 @@ def read_info(path): return info or {} @property - def required(self): + def required(self) -> Generator[str, None, None]: """Get file paths of required files. Args: None Yields: - str: File path of each file. + File path of each file. """ yield from self._files.required @property - def root(self): + def root(self) -> Path: """Location test data is stored on disk. This is intended to be used as wwwroot. Args: None Returns: - Path: Directory containing test case files. + Directory containing test case files. """ return self._root @staticmethod - def sanitize_path(path): + def sanitize_path(path: str) -> str: """Sanitize given path for use as a URI path. Args: - path (str): Path to sanitize. Must be relative to wwwroot. + path: Path to sanitize. Must be relative to wwwroot. Returns: - str: Sanitized path. + Sanitized path. """ assert isinstance(path, str) # check for missing filename or path containing drive letter (Windows) diff --git a/grizzly/common/test_storage.py b/grizzly/common/test_storage.py index 6416327d..f1758c60 100644 --- a/grizzly/common/test_storage.py +++ b/grizzly/common/test_storage.py @@ -286,6 +286,7 @@ def test_testcase_12(tmp_path, catalog): else: assert not any(loaded.optional) assert loaded.assets == {"example": "asset.txt"} + assert loaded.assets_path is not None assert (loaded.root / loaded.assets_path / "asset.txt").is_file() assert loaded.env_vars.get("TEST_ENV_VAR") == "100" assert len(tuple(loaded.required)) == 1 @@ -441,6 +442,7 @@ def test_testcase_17(tmp_path, remote_assets): assert not set(src.optional) ^ set(dst.optional) assert not set(src.required) ^ set(dst.required) if remote_assets is not None: + assert dst.assets_path assert (dst.assets_path / dst.assets["foo"]).is_file()