From 870edfdf65692077a655a18a3b3f7b2d47177652 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 25 Feb 2021 11:39:35 +0100 Subject: [PATCH 1/3] [skip CI] fix kwargs forwarding in fake data utility functions (#3459) --- test/datasets_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 337a7382366..4e3fd0ac0e3 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -572,7 +572,7 @@ def create_image_file( image = create_image_or_video_tensor(size) file = pathlib.Path(root) / name - PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file) + PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file, **kwargs) return file @@ -708,6 +708,6 @@ def size(idx): os.makedirs(root) return [ - create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size) + create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs) for idx in range(num_examples) ] From a24191ed60a920f0bac2576e5ad0d7be2c8d944a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 25 Feb 2021 12:36:25 +0100 Subject: [PATCH 2/3] add version information to docstring of Phototour dataset (#3437) Co-authored-by: vfdev --- torchvision/datasets/phototour.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index ce427e04883..a64287678df 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -10,7 +10,17 @@ class PhotoTour(VisionDataset): - """`Learning Local Image Descriptors Data `_ Dataset. + """`Multi-view Stereo Correspondence `_ Dataset. + + .. 
note:: + + We only provide the newer version of the dataset, since the authors state that it + + is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the + patches are centred on real interest point detections, rather than being projections of 3D points as is the + case in the old dataset. + + The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm. Args: From 2e8c124ff7266437efd37c7a23684f64fe6904cd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 25 Feb 2021 17:54:14 +0100 Subject: [PATCH 3/3] Improve dataset test infrastructure (#3450) * always use default config as base * fix test_all_configs decorator * lint * add a utility function to create a random string * move output check of inject_fake_data to dedicated method * always disable download and extract functionality --- test/datasets_utils.py | 149 ++++++++++++++++++++++------------------- 1 file changed, 80 insertions(+), 69 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 4e3fd0ac0e3..34190b2bfbc 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -6,6 +6,8 @@ import itertools import os import pathlib +import random +import string import unittest import unittest.mock from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union @@ -32,6 +34,7 @@ "create_image_folder", "create_video_file", "create_video_folder", + "create_random_string", ] @@ -93,14 +96,6 @@ def inner_wrapper(*args, **kwargs): return outer_wrapper -# As of Python 3.7 this is provided by contextlib -# https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext -# TODO: If the minimum Python requirement is >= 3.7, replace this -@contextlib.contextmanager -def nullcontext(enter_result=None): - yield enter_result - - def test_all_configs(test): """Decorator to run test against all configurations. 
@@ -116,7 +111,7 @@ def test_foo(self, config): @functools.wraps(test) def wrapper(self): - for config in self.CONFIGS: + for config in self.CONFIGS or (self._DEFAULT_CONFIG,): with self.subTest(**config): test(self, config) @@ -207,6 +202,8 @@ def test_baz(self): CONFIGS = None REQUIRED_PACKAGES = None + _DEFAULT_CONFIG = None + _TRANSFORM_KWARGS = { "transform", "target_transform", @@ -268,7 +265,7 @@ def create_dataset( self, config: Optional[Dict[str, Any]] = None, inject_fake_data: bool = True, - disable_download_extract: Optional[bool] = None, + patch_checks: Optional[bool] = None, **kwargs: Any, ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]: r"""Create the dataset in a temporary directory. @@ -278,8 +275,8 @@ def create_dataset( default configuration is used. inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before creating the dataset. - disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating - the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``. + patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If + omitted defaults to the same value as ``inject_fake_data``. **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they overlap with ``config``. @@ -288,43 +285,28 @@ def create_dataset( info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data` for details. 
""" - if config is None: - config = self.CONFIGS[0].copy() + default_config = self._DEFAULT_CONFIG.copy() + if config is not None: + default_config.update(config) + config = default_config + + if patch_checks is None: + patch_checks = inject_fake_data special_kwargs, other_kwargs = self._split_kwargs(kwargs) + if "download" in self._HAS_SPECIAL_KWARG: + special_kwargs["download"] = False config.update(other_kwargs) - if disable_download_extract is None: - disable_download_extract = inject_fake_data + patchers = self._patch_download_extract() + if patch_checks: + patchers.update(self._patch_checks()) with get_tmp_dir() as tmpdir: args = self.dataset_args(tmpdir, config) + info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None - if inject_fake_data: - info = self.inject_fake_data(tmpdir, config) - if info is None: - raise UsageError( - "The method 'inject_fake_data' needs to return at least an integer indicating the number of " - "examples for the current configuration." - ) - elif isinstance(info, int): - info = dict(num_examples=info) - elif not isinstance(info, dict): - raise UsageError( - f"The additional information returned by the method 'inject_fake_data' must be either an " - f"integer indicating the number of examples for the current configuration or a dictionary with " - f"the same content. Got {type(info)} instead." - ) - elif "num_examples" not in info: - raise UsageError( - "The information dictionary returned by the method 'inject_fake_data' must contain a " - "'num_examples' field that holds the number of examples for the current configuration." 
- ) - else: - info = None - - cm = self._disable_download_extract if disable_download_extract else nullcontext - with cm(special_kwargs), disable_console_output(): + with self._maybe_apply_patches(patchers), disable_console_output(): dataset = self.DATASET_CLASS(*args, **config, **special_kwargs) yield dataset, info @@ -352,19 +334,17 @@ def _verify_required_public_class_attributes(cls): @classmethod def _populate_private_class_attributes(cls): argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__) + + cls._DEFAULT_CONFIG = { + kwarg: default + for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults) + if kwarg not in cls._SPECIAL_KWARGS + } + cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args} @classmethod def _process_optional_public_class_attributes(cls): - argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__) - if cls.CONFIGS is None: - config = { - kwarg: default - for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults) - if kwarg not in cls._SPECIAL_KWARGS - } - cls.CONFIGS = (config,) - if cls.REQUIRED_PACKAGES is not None: try: for pkg in cls.REQUIRED_PACKAGES: @@ -380,28 +360,44 @@ def _split_kwargs(self, kwargs): other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS} return special_kwargs, other_kwargs - @contextlib.contextmanager - def _disable_download_extract(self, special_kwargs): - inject_download_kwarg = "download" in self._HAS_SPECIAL_KWARG and "download" not in special_kwargs - if inject_download_kwarg: - special_kwargs["download"] = False + def _inject_fake_data(self, tmpdir, config): + info = self.inject_fake_data(tmpdir, config) + if info is None: + raise UsageError( + "The method 'inject_fake_data' needs to return at least an integer indicating the number of " + "examples for the current configuration." 
+ ) + elif isinstance(info, int): + info = dict(num_examples=info) + elif not isinstance(info, dict): + raise UsageError( + f"The additional information returned by the method 'inject_fake_data' must be either an " + f"integer indicating the number of examples for the current configuration or a dictionary with " + f"the same content. Got {type(info)} instead." + ) + elif "num_examples" not in info: + raise UsageError( + "The information dictionary returned by the method 'inject_fake_data' must contain a " + "'num_examples' field that holds the number of examples for the current configuration." + ) + return info + + def _patch_download_extract(self): + module = inspect.getmodule(self.DATASET_CLASS).__name__ + return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS} + def _patch_checks(self): module = inspect.getmodule(self.DATASET_CLASS).__name__ + return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS} + + @contextlib.contextmanager + def _maybe_apply_patches(self, patchers): with contextlib.ExitStack() as stack: mocks = {} - for function, kwargs in itertools.chain( - zip(self._CHECK_FUNCTIONS, [dict(return_value=True)] * len(self._CHECK_FUNCTIONS)), - zip(self._DOWNLOAD_EXTRACT_FUNCTIONS, [dict()] * len(self._DOWNLOAD_EXTRACT_FUNCTIONS)), - ): + for patcher in patchers: with contextlib.suppress(AttributeError): - patcher = unittest.mock.patch(f"{module}.{function}", **kwargs) - mocks[function] = stack.enter_context(patcher) - - try: - yield mocks - finally: - if inject_download_kwarg: - del special_kwargs["download"] + mocks[patcher.target] = stack.enter_context(patcher) + yield mocks def test_not_found_or_corrupted(self): with self.assertRaises((FileNotFoundError, RuntimeError)): @@ -469,13 +465,13 @@ def create_dataset( self, config: Optional[Dict[str, Any]] = None, inject_fake_data: bool = True, - disable_download_extract: Optional[bool] = None, + patch_checks: 
Optional[bool] = None, **kwargs: Any, ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]: with super().create_dataset( config=config, inject_fake_data=inject_fake_data, - disable_download_extract=disable_download_extract, + patch_checks=patch_checks, **kwargs, ) as (dataset, info): # PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access @@ -711,3 +707,18 @@ def size(idx): create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs) for idx in range(num_examples) ] + + +def create_random_string(length: int, *digits: str) -> str: + """Create a random string. + + Args: + length (int): Number of characters in the generated string. + *digits (str): Characters to sample from. If omitted defaults to :attr:`string.ascii_lowercase`. + """ + if not digits: + digits = string.ascii_lowercase + else: + digits = "".join(itertools.chain(*digits)) + + return "".join(random.choice(digits) for _ in range(length))