Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port test/test_datasets.py to use pytest #4215

Merged
merged 4 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
Expand Down Expand Up @@ -519,18 +520,18 @@ def _maybe_apply_patches(self, patchers):
yield mocks

def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)):
with pytest.raises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False):
pass

def test_smoke(self):
with self.create_dataset() as (dataset, _):
self.assertIsInstance(dataset, torchvision.datasets.VisionDataset)
assert isinstance(dataset, torchvision.datasets.VisionDataset)

@test_all_configs
def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _):
self.assertIsInstance(str(dataset), str)
assert isinstance(str(dataset), str)

@test_all_configs
def test_feature_types(self, config):
Expand All @@ -540,23 +541,21 @@ def test_feature_types(self, config):
if len(self.FEATURE_TYPES) > 1:
actual = len(example)
expected = len(self.FEATURE_TYPES)
self.assertEqual(
actual,
expected,
f"The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
f"{actual} != {expected}",
)
# The whole message is parenthesized so that both string fragments belong to the
# assert. An f-string placed on its own line after the assert would be a separate
# no-op statement and the "{actual} != {expected}" detail would never be reported.
assert actual == expected, (
    "The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
    f"{actual} != {expected}"
)
else:
example = (example,)

for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
self.assertIsInstance(feature, expected_feature_type)
assert isinstance(feature, expected_feature_type)

@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])
assert len(dataset) == info["num_examples"]

@test_all_configs
def test_transforms(self, config):
Expand Down
116 changes: 56 additions & 60 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import PIL
import datasets_utils
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from torchvision import datasets
Expand Down Expand Up @@ -88,20 +89,20 @@ def inject_fake_data(self, tmpdir, config):
def test_folds(self):
for fold in range(10):
with self.create_dataset(split="train", folds=fold) as (dataset, _):
self.assertEqual(len(dataset), fold + 1)
assert len(dataset) == fold + 1

def test_unlabeled(self):
with self.create_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all(label == -1 for label in labels))
assert all(label == -1 for label in labels)

def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
with self.create_dataset(folds=10):
pass

def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
with self.create_dataset(folds="0"):
pass

Expand Down Expand Up @@ -167,23 +168,19 @@ def test_combined_targets(self):

actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
# Keep both message fragments inside one parenthesized expression: a trailing
# f-string line (with or without a comma) after the assert is a separate statement
# that is evaluated and discarded, truncating the failure message.
assert actual == expected, (
    "The number of the returned combined targets does not match the the number targets if requested "
    f"individually: {actual} != {expected}"
)

for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
# Identity check on the two types; the message is parenthesized so the f-string
# with the actual/expected types is part of the assert instead of a stray
# expression statement on the following line.
assert actual is expected, (
    "Type of the combined target does not match the type of the corresponding individual target: "
    f"{actual} is not {expected}"
)


class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -363,26 +360,26 @@ def test_combined_targets(self):

with self.create_dataset(target_type=target_types) as (dataset, _):
output = dataset[0]
self.assertTrue(isinstance(output, tuple))
self.assertTrue(len(output) == 2)
self.assertTrue(isinstance(output[0], PIL.Image.Image))
self.assertTrue(isinstance(output[1], tuple))
self.assertTrue(len(output[1]) == 3)
self.assertTrue(isinstance(output[1][0], PIL.Image.Image)) # semantic
self.assertTrue(isinstance(output[1][1], dict)) # polygon
self.assertTrue(isinstance(output[1][2], PIL.Image.Image)) # color
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[0], PIL.Image.Image)
assert isinstance(output[1], tuple)
assert len(output[1]) == 3
assert isinstance(output[1][0], PIL.Image.Image) # semantic
assert isinstance(output[1][1], dict) # polygon
assert isinstance(output[1][2], PIL.Image.Image) # color

def test_feature_types_target_color(self):
with self.create_dataset(target_type='color') as (dataset, _):
color_img, color_target = dataset[0]
self.assertTrue(isinstance(color_img, PIL.Image.Image))
self.assertTrue(np.array(color_target).shape[2] == 4)
assert isinstance(color_img, PIL.Image.Image)
assert np.array(color_target).shape[2] == 4

def test_feature_types_target_polygon(self):
with self.create_dataset(target_type='polygon') as (dataset, info):
polygon_img, polygon_target = dataset[0]
self.assertTrue(isinstance(polygon_img, PIL.Image.Image))
self.assertEqual(polygon_target, info['expected_polygon_target'])
assert isinstance(polygon_img, PIL.Image.Image)
# Equivalent of the former self.assertEqual(...): a bare tuple expression here
# would be evaluated and discarded, so the polygon target would never be checked.
assert polygon_target == info['expected_polygon_target']


class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -469,7 +466,7 @@ def test_class_to_idx(self):
with self.create_dataset() as (dataset, info):
expected = {category: label for label, category in enumerate(info["categories"])}
actual = dataset.class_to_idx
self.assertEqual(actual, expected)
assert actual == expected


class CIFAR100(CIFAR10TestCase):
Expand Down Expand Up @@ -573,33 +570,29 @@ def test_combined_targets(self):

actual = len(individual_targets)
expected = len(combined_targets)
self.assertEqual(
actual,
expected,
f"The number of the returned combined targets does not match the the number targets if requested "
f"individually: {actual} != {expected}",
)
# Keep both message fragments inside one parenthesized expression: a trailing
# f-string line (with or without a comma) after the assert is a separate statement
# that is evaluated and discarded, truncating the failure message.
assert actual == expected, (
    "The number of the returned combined targets does not match the the number targets if requested "
    f"individually: {actual} != {expected}"
)

for target_type, combined_target, individual_target in zip(target_types, combined_targets, individual_targets):
with self.subTest(target_type=target_type):
actual = type(combined_target)
expected = type(individual_target)
self.assertIs(
actual,
expected,
f"Type of the combined target does not match the type of the corresponding individual target: "
f"{actual} is not {expected}",
)
# Identity check on the two types; the message is parenthesized so the f-string
# with the actual/expected types is part of the assert instead of a stray
# expression statement on the following line.
assert actual is expected, (
    "Type of the combined target does not match the type of the corresponding individual target: "
    f"{actual} is not {expected}"
)

def test_no_target(self):
with self.create_dataset(target_type=[]) as (dataset, _):
_, target = dataset[0]

self.assertIsNone(target)
assert target is None

def test_attr_names(self):
with self.create_dataset() as (dataset, info):
self.assertEqual(tuple(dataset.attr_names), info["attr_names"])
assert tuple(dataset.attr_names) == info["attr_names"]


class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -704,16 +697,16 @@ def test_annotations(self):
with self.create_dataset() as (dataset, info):
_, target = dataset[0]

self.assertIn("annotation", target)
assert "annotation" in target
annotation = target["annotation"]

self.assertIn("object", annotation)
assert "object" in annotation
objects = annotation["object"]

self.assertEqual(len(objects), 1)
assert len(objects) == 1
object = objects[0]

self.assertEqual(object, info["annotation"])
assert object == info["annotation"]


class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -789,7 +782,7 @@ def _create_annotations(self, image_ids, num_annotations_per_image):
def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertEqual(tuple(captions), tuple(info["captions"]))
assert tuple(captions) == tuple(info["captions"])


class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
Expand Down Expand Up @@ -940,7 +933,7 @@ def _create_lmdb(self, root, cls):
def test_not_found_or_corrupted(self):
# LSUN does not raise built-in exception, but a custom one. It is expressive enough to not 'cast' it to
# RuntimeError or FileNotFoundError that are normally checked by this test.
with self.assertRaises(datasets_utils.lazy_importer.lmdb.Error):
with pytest.raises(datasets_utils.lazy_importer.lmdb.Error):
super().test_not_found_or_corrupted()


Expand Down Expand Up @@ -1369,7 +1362,8 @@ def _create_captions(self, num_captions_per_image):
def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertSequenceEqual(captions, info["captions"])
assert len(captions) == len(info["captions"])
assert all([a == b for a, b in zip(captions, info["captions"])])


class Flickr30kTestCase(Flickr8kTestCase):
Expand Down Expand Up @@ -1513,7 +1507,7 @@ def test_num_examples_test50k(self):
with self.create_dataset(what="test50k") as (dataset, info):
# Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of
# created examples by this.
self.assertEqual(len(dataset), info["num_examples"] - 10000)
assert len(dataset) == info["num_examples"] - 10000


class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1578,12 +1572,13 @@ def test_is_valid_file(self, config):
with self.create_dataset(
config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])
assert len(dataset) == info["num_examples"]

@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])


class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
Expand All @@ -1603,7 +1598,8 @@ def inject_fake_data(self, tmpdir, config):
@datasets_utils.test_all_configs
def test_classes(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertSequenceEqual(dataset.classes, info["classes"])
assert len(dataset.classes) == len(info["classes"])
assert all([a == b for a, b in zip(dataset.classes, info["classes"])])


class KittiTestCase(datasets_utils.ImageDatasetTestCase):
Expand Down Expand Up @@ -1742,15 +1738,15 @@ def inject_fake_data(self, tmpdir, config):
def test_classes(self):
classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT))
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.classes, classes)
assert dataset.classes == classes

def test_class_to_idx(self):
class_to_idx = dict(self._CATEGORIES_CONTENT)
with self.create_dataset() as (dataset, _):
self.assertEqual(dataset.class_to_idx, class_to_idx)
assert dataset.class_to_idx == class_to_idx

def test_images_download_preexisting(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
with self.create_dataset({'download': True}):
pass

Expand Down Expand Up @@ -1788,9 +1784,9 @@ def test_targets(self):
with self.create_dataset(target_type=target_types, version="2021_valid") as (dataset, _):
items = [d[1] for d in dataset]
for i, item in enumerate(items):
self.assertEqual(dataset.category_name("kingdom", item[0]), "Akingdom")
self.assertEqual(dataset.category_name("phylum", item[1]), f"{i // 3}phylum")
self.assertEqual(item[6], i // 3)
assert dataset.category_name("kingdom", item[0]) == "Akingdom"
assert dataset.category_name("phylum", item[1]) == f"{i // 3}phylum"
assert item[6] == i // 3


if __name__ == "__main__":
Expand Down