Skip to content

Commit

Permalink
add tests for LSUN (#3454)
Browse files Browse the repository at this point in the history
Summary:

Reviewed By: fmassa

Differential Revision: D26756274

fbshipit-source-id: b489819c79dfb03393a7ee9c2638f9a5bc35c11e

Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
  • Loading branch information
3 people authored and facebook-github-bot committed Mar 4, 2021
1 parent 368c2ed commit a4f3f6d
Showing 1 changed file with 88 additions and 0 deletions.
88 changes: 88 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import shutil
import json
import random
import string
import io


try:
Expand Down Expand Up @@ -954,5 +956,91 @@ def _create_annotation_file(self, root, name, video_files):
fh.writelines(f"{file}\n" for file in sorted(video_files))


class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``torchvision.datasets.LSUN`` backed by small fake LMDB databases.

    For each requested class a tiny LMDB environment is created on the fly,
    populated with a handful of randomly generated PNG images.
    """

    DATASET_CLASS = datasets.LSUN

    REQUIRED_PACKAGES = ("lmdb",)
    # ``classes`` is either a split name ("train"/"test"/"val"), which expands to
    # every category of that split, or an explicit list of "<category>_<split>" names.
    CONFIGS = datasets_utils.combinations_grid(
        classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
    )

    # All LSUN scene categories; used to expand a bare split name.
    _CATEGORIES = (
        "bedroom",
        "bridge",
        "church_outdoor",
        "classroom",
        "conference_room",
        "dining_room",
        "kitchen",
        "living_room",
        "restaurant",
        "tower",
    )

    def inject_fake_data(self, tmpdir, config):
        """Create one fake LMDB per requested class and return the total image count."""
        root = pathlib.Path(tmpdir)

        num_images = 0
        for cls in self._parse_classes(config["classes"]):
            num_images += self._create_lmdb(root, cls)

        return num_images

    @contextlib.contextmanager
    def create_dataset(
        self,
        *args, **kwargs
    ):
        # Currently datasets.LSUN caches the keys in the current directory rather than in the root directory. Thus,
        # this creates a number of unique _cache_* files in the current directory that will not be removed together
        # with the temporary directory. The cleanup runs in a ``finally`` block so the
        # cache files are removed even when the enclosed test body raises (a
        # ``@contextmanager`` re-raises the caller's exception at the ``yield``).
        try:
            with super().create_dataset(*args, **kwargs) as output:
                yield output
        finally:
            for file in os.listdir(os.getcwd()):
                if file.startswith("_cache_"):
                    os.remove(file)

    def _parse_classes(self, classes):
        """Normalize the ``classes`` config value to a list of "<category>_<split>" names."""
        if not isinstance(classes, str):
            # Already an explicit list of class names.
            return classes

        split = classes
        if split == "test":
            # The test split is a single undivided LMDB, not one per category.
            return [split]

        return [f"{category}_{split}" for category in self._CATEGORIES]

    def _create_lmdb(self, root, cls):
        """Build a fake ``<cls>_lmdb`` database under ``root``; return the number of images stored."""
        lmdb = datasets_utils.lazy_importer.lmdb
        hexdigits_lowercase = string.digits + string.ascii_lowercase[:6]

        folder = f"{cls}_lmdb"

        num_images = torch.randint(1, 4, size=()).item()
        image_format = "png"  # renamed from ``format`` to avoid shadowing the builtin
        files = datasets_utils.create_image_folder(root, folder, lambda idx: f"{idx}.{image_format}", num_images)

        with lmdb.open(str(root / folder)) as env, env.begin(write=True) as txn:
            for file in files:
                # LSUN keys are 40 lowercase hex characters; mimic that format.
                key = "".join(random.choice(hexdigits_lowercase) for _ in range(40)).encode()

                # Store the encoded image bytes as the value.
                buffer = io.BytesIO()
                Image.open(file).save(buffer, image_format)
                buffer.seek(0)
                value = buffer.read()

                txn.put(key, value)

                # The dataset reads from the LMDB only; the loose file is no longer needed.
                os.remove(file)

        return num_images

    def test_not_found_or_corrupted(self):
        # LSUN does not raise built-in exception, but a custom one. It is expressive enough to not 'cast' it to
        # RuntimeError or FileNotFoundError that are normally checked by this test.
        with self.assertRaises(datasets_utils.lazy_importer.lmdb.Error):
            super().test_not_found_or_corrupted()


# Allow running this test module directly with the stdlib unittest runner.
if __name__ == "__main__":
    unittest.main()

0 comments on commit a4f3f6d

Please sign in to comment.