Skip to content

Commit

Permalink
Port test/test_datasets.py to use pytest (#4215)
Browse files Browse the repository at this point in the history
* Port test_datasets.py to use pytest

* A better replacement of self.assertSequenceEqual

* Migrate from equality check to identity check
  • Loading branch information
yiwen-song authored Jul 28, 2021
1 parent e2dbadb commit b29ed34
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 71 deletions.
21 changes: 10 additions & 11 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
Expand Down Expand Up @@ -519,18 +520,18 @@ def _maybe_apply_patches(self, patchers):
yield mocks

def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)):
with pytest.raises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False):
pass

def test_smoke(self):
with self.create_dataset() as (dataset, _):
self.assertIsInstance(dataset, torchvision.datasets.VisionDataset)
assert isinstance(dataset, torchvision.datasets.VisionDataset)

@test_all_configs
def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _):
self.assertIsInstance(str(dataset), str)
assert isinstance(str(dataset), str)

@test_all_configs
def test_feature_types(self, config):
Expand All @@ -540,23 +541,21 @@ def test_feature_types(self, config):
if len(self.FEATURE_TYPES) > 1:
actual = len(example)
expected = len(self.FEATURE_TYPES)
self.assertEqual(
actual,
expected,
f"The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
f"{actual} != {expected}",
)
assert (
actual == expected
), "The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
f"{actual} != {expected}"
else:
example = (example,)

for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
self.assertIsInstance(feature, expected_feature_type)
assert isinstance(feature, expected_feature_type)

@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])
assert len(dataset) == info["num_examples"]

@test_all_configs
def test_transforms(self, config):
Expand Down
116 changes: 56 additions & 60 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import PIL
import datasets_utils
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from torchvision import datasets
Expand Down Expand Up @@ -88,20 +89,20 @@ def inject_fake_data(self, tmpdir, config):
def test_folds(self):
for fold in range(10):
with self.create_dataset(split="train", folds=fold) as (dataset, _):
self.assertEqual(len(dataset), fold + 1)
assert len(dataset) == fold + 1

def test_unlabeled(self):
with self.create_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all(label == -1 for label in labels))
assert all(label == -1 for label in labels)

def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
with self.create_dataset(folds=10):
pass

def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
with self.create_dataset(folds="0"):
pass

Expand Down Expand Up @@ -167,23 +168,19 @@ def test_combined_targets(self):

actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
assert (
actual == expected
), "The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",

for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
assert (
actual is expected
), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",


class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -363,26 +360,26 @@ def test_combined_targets(self):

with self.create_dataset(target_type=target_types) as (dataset, _):
output = dataset[0]
self.assertTrue(isinstance(output, tuple))
self.assertTrue(len(output) == 2)
self.assertTrue(isinstance(output[0], PIL.Image.Image))
self.assertTrue(isinstance(output[1], tuple))
self.assertTrue(len(output[1]) == 3)
self.assertTrue(isinstance(output[1][0], PIL.Image.Image)) # semantic
self.assertTrue(isinstance(output[1][1], dict)) # polygon
self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], PIL.Image.Image)
assert isinstance(output[1], tuple)
assert len(output[1]) == 3
assert isinstance(output[1][0], PIL.Image.Image) # semantic
assert isinstance(output[1][1], dict) # polygon
assert isinstance(output[1][2], PIL.Image.Image) # color

def test_feature_types_target_color(self):
with self.create_dataset(target_type='color') as (dataset, _):
color_img, color_target = dataset[0]
self.assertTrue(isinstance(color_img, PIL.Image.Image))
self.assertTrue(np.array(color_target).shape[2] == 4)
assert isinstance(color_img, PIL.Image.Image)
assert np.array(color_target).shape[2] == 4

def test_feature_types_target_polygon(self):
with self.create_dataset(target_type='polygon') as (dataset, info):
polygon_img, polygon_target = dataset[0]
self.assertTrue(isinstance(polygon_img, PIL.Image.Image))
self.assertEqual(polygon_target, info['expected_polygon_target'])
assert isinstance(polygon_img, PIL.Image.Image)
(polygon_target, info['expected_polygon_target'])


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -469,7 +466,7 @@ def test_class_to_idx(self):
with self.create_dataset() as (dataset, info):
expected = {category: label for label, category in enumerate(info["categories"])}
actual = dataset.class_to_idx
self.assertEqual(actual, expected)
assert actual == expected


class CIFAR100(CIFAR10TestCase):
Expand Down Expand Up @@ -573,33 +570,29 @@ def test_combined_targets(self):

actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
assert (
actual == expected
), "The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",

for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
assert (
actual is expected
), "Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",

def test_no_target(self):
with self.create_dataset(target_type=[]) as (dataset, _):
_, target = dataset[0]

self.assertIsNone(target)
assert target is None

def test_attr_names(self):
with self.create_dataset() as (dataset, info):
self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
assert tuple(dataset.attr_names) == info["attr_names"]


class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -704,16 +697,16 @@ def test_annotations(self):
with self.create_dataset() as (dataset, info):
_, target = dataset[0]

self.assertIn("annotation", target)
assert "annotation" in target
annotation = target["annotation"]

self.assertIn("object", annotation)
assert "object" in annotation
objects = annotation["object"]

self.assertEqual(len(objects), 1)
assert len(objects) == 1
object = objects[0]

self.assertEqual(object, info["annotation"])
assert object == info["annotation"]


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -789,7 +782,7 @@ def _create_annotations(self, image_ids, num_annotations_per_image):
def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertEqual(tuple(captions), tuple(info["captions"]))
assert tuple(captions) == tuple(info["captions"])


class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
Expand Down Expand Up @@ -940,7 +933,7 @@ def _create_lmdb(self, root, cls):
def test_not_found_or_corrupted(self):
# LSUN does not raise built-in exception, but a custom one. It is expressive enough to not 'cast' it to
# RuntimeError or FileNotFoundError that are normally checked by this test.
with self.assertRaises(datasets_utils.lazy_importer.lmdb.Error):
with pytest.raises(datasets_utils.lazy_importer.lmdb.Error):
super().test_not_found_or_corrupted()


Expand Down Expand Up @@ -1369,7 +1362,8 @@ def _create_captions(self, num_captions_per_image):
def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertSequenceEqual(captions, info["captions"])
assert len(captions) == len(info["captions"])
assert all([a == b for a, b in zip(captions, info["captions"])])


class Flickr30kTestCase(Flickr8kTestCase):
Expand Down Expand Up @@ -1513,7 +1507,7 @@ def test_num_examples_test50k(self):
with self.create_dataset(what="test50k") as (dataset, info):
# Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of
# created examples by this.
self.assertEqual(len(dataset), info["num_examples"] - 10000)
assert len(dataset) == info["num_examples"] - 10000


class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1578,12 +1572,13 @@ def test_is_valid_file(self, config):
with self.create_dataset(
config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])
assert len(dataset) == info["num_examples"]

@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])


class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
Expand All @@ -1603,7 +1598,8 @@ def inject_fake_data(self, tmpdir, config):
@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])


class KittiTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1742,15 +1738,15 @@ def inject_fake_data(self, tmpdir, config):
def test_classes(self):
classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.classes, classes)
assert dataset.classes == classes

def test_class_to_idx(self):
class_to_idx = dict(self._CATEGORIES_CONTENT)
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.class_to_idx, class_to_idx)
assert dataset.class_to_idx == class_to_idx

def test_images_download_preexisting(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
with self.create_dataset({'download': True}):
pass

Expand Down Expand Up @@ -1788,9 +1784,9 @@ def test_targets(self):
with self.create_dataset(target_type=target_types, version="2021_valid") as (dataset, _):
items = [d[1] for d in dataset]
for i, item in enumerate(items):
self.assertEqual(dataset.category_name("kingdom", item[0]), "Akingdom")
self.assertEqual(dataset.category_name("phylum", item[1]), f"{i // 3}phylum")
self.assertEqual(item[6], i // 3)
assert dataset.category_name("kingdom", item[0]) == "Akingdom"
assert dataset.category_name("phylum", item[1]) == f"{i // 3}phylum"
assert item[6] == i // 3


if __name__ == "__main__":
Expand Down

0 comments on commit b29ed34

Please sign in to comment.