Skip to content

Commit

Permalink
PR: Add UCF101 dataset tests (#2548)
Browse files Browse the repository at this point in the history
* Add fake data generator for UCF101

* Minor error correction

* Reduce total number of categories

* Fix naming

* Increase length

* Store in uint8

* Close fds

* Add assertGreater

* Add dimension tests

* Use numel instead of size

* Iterate over folds and splits
  • Loading branch information
andfoy authored Aug 4, 2020
1 parent c2bbefc commit 23295fb
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
47 changes: 47 additions & 0 deletions test/fakedata_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import torch
from common_utils import get_tmp_dir
import pickle
import random
from itertools import cycle
from torchvision.io.video import write_video


@contextlib.contextmanager
Expand Down Expand Up @@ -265,3 +268,47 @@ def voc_root():
f.write('test')

yield tmp_dir


@contextlib.contextmanager
def ucf101_root():
with get_tmp_dir() as tmp_dir:
ucf_dir = os.path.join(tmp_dir, 'UCF-101')
video_dir = os.path.join(ucf_dir, 'video')
annotations = os.path.join(ucf_dir, 'annotations')

os.makedirs(ucf_dir)
os.makedirs(video_dir)
os.makedirs(annotations)

fold_files = []
for split in {'train', 'test'}:
for fold in range(1, 4):
fold_file = '{:s}list{:02d}.txt'.format(split, fold)
fold_files.append(os.path.join(annotations, fold_file))

file_handles = [open(x, 'w') for x in fold_files]
file_iter = cycle(file_handles)

for i in range(0, 2):
current_class = 'class_{0}'.format(i + 1)
class_dir = os.path.join(video_dir, current_class)
os.makedirs(class_dir)
for group in range(0, 3):
for clip in range(0, 4):
# Save sample file
clip_name = 'v_{0}_g{1}_c{2}.avi'.format(
current_class, group, clip)
clip_path = os.path.join(class_dir, clip_name)
length = random.randrange(10, 21)
this_clip = torch.randint(
0, 256, (length * 25, 320, 240, 3), dtype=torch.uint8)
write_video(clip_path, this_clip, 25)
# Add to annotations
ann_file = next(file_iter)
ann_file.write('{0}\n'.format(
os.path.join(current_class, clip_name)))
# Close all file descriptors
for f in file_handles:
f.close()
yield (video_dir, annotations)
28 changes: 27 additions & 1 deletion test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root, voc_root
cityscapes_root, svhn_root, voc_root, ucf101_root
import xml.etree.ElementTree as ET


Expand All @@ -19,6 +19,12 @@
except ImportError:
HAS_SCIPY = False

try:
import av
HAS_PYAV = True
except ImportError:
HAS_PYAV = False


class Tester(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
Expand Down Expand Up @@ -254,6 +260,26 @@ def test_voc_parse_xml(self, mock_download_extract):
}]
}})

@unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
def test_ucf101(self):
with ucf101_root() as (root, ann_root):
for split in {True, False}:
for fold in range(1, 4):
for length in {10, 15, 20}:
dataset = torchvision.datasets.UCF101(
root, ann_root, length, fold=fold, train=split)
self.assertGreater(len(dataset), 0)

video, audio, label = dataset[0]
self.assertEqual(video.size(), (length, 320, 240, 3))
self.assertEqual(audio.numel(), 0)
self.assertEqual(label, 0)

video, audio, label = dataset[len(dataset) - 1]
self.assertEqual(video.size(), (length, 320, 240, 3))
self.assertEqual(audio.numel(), 0)
self.assertEqual(label, 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit 23295fb

Please sign in to comment.