Skip to content

Commit

Permalink
Merge branch 'master' into tests-lsun
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Feb 26, 2021
2 parents 4c30595 + 2e8c124 commit a55eebc
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 72 deletions.
153 changes: 82 additions & 71 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import itertools
import os
import pathlib
import random
import string
import unittest
import unittest.mock
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
Expand All @@ -32,6 +34,7 @@
"create_image_folder",
"create_video_file",
"create_video_folder",
"create_random_string",
]


Expand Down Expand Up @@ -93,14 +96,6 @@ def inner_wrapper(*args, **kwargs):
return outer_wrapper


# As of Python 3.7 this is provided by contextlib
# https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
# TODO: If the minimum Python requirement is >= 3.7, replace this
@contextlib.contextmanager
def nullcontext(enter_result=None):
yield enter_result


def test_all_configs(test):
"""Decorator to run test against all configurations.
Expand All @@ -116,7 +111,7 @@ def test_foo(self, config):

@functools.wraps(test)
def wrapper(self):
for config in self.CONFIGS:
for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
with self.subTest(**config):
test(self, config)

Expand Down Expand Up @@ -207,6 +202,8 @@ def test_baz(self):
CONFIGS = None
REQUIRED_PACKAGES = None

_DEFAULT_CONFIG = None

_TRANSFORM_KWARGS = {
"transform",
"target_transform",
Expand Down Expand Up @@ -268,7 +265,7 @@ def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None,
patch_checks: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
r"""Create the dataset in a temporary directory.
Expand All @@ -278,8 +275,8 @@ def create_dataset(
default configuration is used.
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
creating the dataset.
disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
omitted defaults to the same value as ``inject_fake_data``.
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
overlap with ``config``.
Expand All @@ -288,43 +285,28 @@ def create_dataset(
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
for details.
"""
if config is None:
config = self.CONFIGS[0].copy()
default_config = self._DEFAULT_CONFIG.copy()
if config is not None:
default_config.update(config)
config = default_config

if patch_checks is None:
patch_checks = inject_fake_data

special_kwargs, other_kwargs = self._split_kwargs(kwargs)
if "download" in self._HAS_SPECIAL_KWARG:
special_kwargs["download"] = False
config.update(other_kwargs)

if disable_download_extract is None:
disable_download_extract = inject_fake_data
patchers = self._patch_download_extract()
if patch_checks:
patchers.update(self._patch_checks())

with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, config)
info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None

if inject_fake_data:
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
else:
info = None

cm = self._disable_download_extract if disable_download_extract else nullcontext
with cm(special_kwargs), disable_console_output():
with self._maybe_apply_patches(patchers), disable_console_output():
dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)

yield dataset, info
Expand Down Expand Up @@ -352,19 +334,17 @@ def _verify_required_public_class_attributes(cls):
@classmethod
def _populate_private_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)

cls._DEFAULT_CONFIG = {
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
if kwarg not in cls._SPECIAL_KWARGS
}

cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}

@classmethod
def _process_optional_public_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
if cls.CONFIGS is None:
config = {
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
if kwarg not in cls._SPECIAL_KWARGS
}
cls.CONFIGS = (config,)

if cls.REQUIRED_PACKAGES is not None:
try:
for pkg in cls.REQUIRED_PACKAGES:
Expand All @@ -380,28 +360,44 @@ def _split_kwargs(self, kwargs):
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
return special_kwargs, other_kwargs

@contextlib.contextmanager
def _disable_download_extract(self, special_kwargs):
inject_download_kwarg = "download" in self._HAS_SPECIAL_KWARG and "download" not in special_kwargs
if inject_download_kwarg:
special_kwargs["download"] = False
def _inject_fake_data(self, tmpdir, config):
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
return info

def _patch_download_extract(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}

def _patch_checks(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}

@contextlib.contextmanager
def _maybe_apply_patches(self, patchers):
with contextlib.ExitStack() as stack:
mocks = {}
for function, kwargs in itertools.chain(
zip(self._CHECK_FUNCTIONS, [dict(return_value=True)] * len(self._CHECK_FUNCTIONS)),
zip(self._DOWNLOAD_EXTRACT_FUNCTIONS, [dict()] * len(self._DOWNLOAD_EXTRACT_FUNCTIONS)),
):
for patcher in patchers:
with contextlib.suppress(AttributeError):
patcher = unittest.mock.patch(f"{module}.{function}", **kwargs)
mocks[function] = stack.enter_context(patcher)

try:
yield mocks
finally:
if inject_download_kwarg:
del special_kwargs["download"]
mocks[patcher.target] = stack.enter_context(patcher)
yield mocks

def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)):
Expand Down Expand Up @@ -469,13 +465,13 @@ def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None,
patch_checks: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
with super().create_dataset(
config=config,
inject_fake_data=inject_fake_data,
disable_download_extract=disable_download_extract,
patch_checks=patch_checks,
**kwargs,
) as (dataset, info):
# PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
Expand Down Expand Up @@ -572,7 +568,7 @@ def create_image_file(

image = create_image_or_video_tensor(size)
file = pathlib.Path(root) / name
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file)
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file, **kwargs)
return file


Expand Down Expand Up @@ -708,6 +704,21 @@ def size(idx):
os.makedirs(root)

return [
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size)
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
for idx in range(num_examples)
]


def create_random_string(length: int, *digits: str) -> str:
"""Create a random string.
Args:
length (int): Number of characters in the generated string.
*characters (str): Characters to sample from. If omitted defaults to :attr:`string.ascii_lowercase`.
"""
if not digits:
digits = string.ascii_lowercase
else:
digits = "".join(itertools.chain(*digits))

return "".join(random.choice(digits) for _ in range(length))
12 changes: 11 additions & 1 deletion torchvision/datasets/phototour.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,17 @@


class PhotoTour(VisionDataset):
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
"""`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.
.. note::
We only provide the newer version of the dataset, since the authors state that it
is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
patches are centred on real interest point detections, rather than being projections of 3D points as is the
case in the old dataset.
The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.
Args:
Expand Down

0 comments on commit a55eebc

Please sign in to comment.