From 9f3dd69be38bd16957fee308cfd56e7ee99f6f93 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Thu, 15 Sep 2022 06:28:03 -0400 Subject: [PATCH] FashionMNIST/EMNIST Datamodules (#871) --- pl_bolts/datamodules/emnist_datamodule.py | 57 ++++++++----------- .../datamodules/fashion_mnist_datamodule.py | 32 +++++------ tests/datamodules/test_datamodules.py | 6 +- 3 files changed, 40 insertions(+), 55 deletions(-) diff --git a/pl_bolts/datamodules/emnist_datamodule.py b/pl_bolts/datamodules/emnist_datamodule.py index fb9831f88c..1c76cd2050 100644 --- a/pl_bolts/datamodules/emnist_datamodule.py +++ b/pl_bolts/datamodules/emnist_datamodule.py @@ -3,7 +3,6 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.transforms.dataset_normalizations import emnist_normalization from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -14,7 +13,6 @@ EMNIST = object -@under_review() class EMNISTDataModule(VisionDataModule): """ .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png @@ -76,6 +74,23 @@ class EMNISTDataModule(VisionDataModule): | + Args: + data_dir: Root directory of dataset. + split: The dataset has 6 different splits: ``byclass``, ``bymerge``, + ``balanced``, ``letters``, ``digits`` and ``mnist``. + This argument is passed to :class:`torchvision.datasets.EMNIST`. + val_split: Percent (float) or number (int) of samples to use for the validation split. + num_workers: How many workers to use for loading data + normalize: If ``True``, applies image normalize. + batch_size: How many samples per batch to load. + seed: Random seed to be used for train/val/test splits. + shuffle: If ``True``, shuffles the train data every epoch. + pin_memory: If ``True``, the data loader will copy Tensors into + CUDA pinned memory before returning them. + drop_last: If ``True``, drops the last incomplete batch. + strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``. + Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits. + Here is the default EMNIST, train, val, test-splits and transforms. Transforms:: @@ -87,8 +102,10 @@ class EMNISTDataModule(VisionDataModule): Example:: from pl_bolts.datamodules import EMNISTDataModule + dm = EMNISTDataModule('.') model = LitModel() + Trainer().fit(model, datamodule=dm) """ @@ -119,25 +136,6 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: - """ - Args: - data_dir: Where to save/load the data. - split: The dataset has 6 different splits: ``byclass``, ``bymerge``, - ``balanced``, ``letters``, ``digits`` and ``mnist``. - This argument is passed to :class:`torchvision.datasets.EMNIST`. - val_split: Percent (float) or number (int) of samples - to use for the validation split. - num_workers: How many workers to use for loading data - normalize: If ``True``, applies image normalize. - batch_size: How many samples per batch to load. - seed: Random seed to be used for train/val/test splits. - shuffle: If ``True``, shuffles the train data every epoch. - pin_memory: If ``True``, the data loader will copy Tensors into - CUDA pinned memory before returning them. - drop_last: If ``True``, drops the last incomplete batch. - strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``. - Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits. - """ if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( "You want to use MNIST dataset loaded from `torchvision` which is not installed yet." @@ -183,13 +181,11 @@ def num_classes(self) -> int: def prepare_data(self, *args: Any, **kwargs: Any) -> None: """Saves files to ``data_dir``.""" - self.dataset_cls(self.data_dir, split=self.split, train=True, download=True) self.dataset_cls(self.data_dir, split=self.split, train=False, download=True) def setup(self, stage: Optional[str] = None) -> None: """Creates train, val, and test dataset.""" - if stage == "fit" or stage is None: train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms @@ -212,14 +208,9 @@ def setup(self, stage: Optional[str] = None) -> None: ) def default_transforms(self) -> Callable: + if self.normalize: + emnist_transforms = transform_lib.Compose([transform_lib.ToTensor(), emnist_normalization(self.split)]) + else: + emnist_transforms = transform_lib.Compose([transform_lib.ToTensor()]) - return ( - transform_lib.Compose( - [ - transform_lib.ToTensor(), - emnist_normalization(self.split), - ] - ) - if self.normalize - else transform_lib.Compose([transform_lib.ToTensor()]) - ) + return emnist_transforms diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index bbb4a5a875..102e567f8a 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -2,7 +2,6 @@ from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -13,7 +12,6 @@ FashionMNIST = None -@under_review() class FashionMNISTDataModule(VisionDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/ @@ -21,6 +19,18 @@ class FashionMNISTDataModule(VisionDataModule): :width: 400 :alt: Fashion MNIST + Args: + data_dir: Root directory of dataset. + val_split: Percent (float) or number (int) of samples to use for the validation split. + num_workers: Number of workers to use for loading data. + normalize: If ``True``, applies image normalization. + batch_size: Number of samples per batch to load. + seed: Random seed to be used for train/val/test splits. + shuffle: If ``True``, shuffles the train data every epoch. + pin_memory: If ``True``, the data loader will copy Tensors into CUDA pinned memory before + returning them. + drop_last: If ``True``, drops the last incomplete batch. + Specs: - 10 classes (1 per type) - Each image is (1 x 28 x 28) @@ -61,19 +71,6 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: - """ - Args: - data_dir: Where to save/load the data - val_split: Percent (float) or number (int) of samples to use for the validation split - num_workers: How many workers to use for loading data - normalize: If true applies image normalize - batch_size: How many samples per batch to load - seed: Random seed to be used for train/val/test splits - shuffle: If true shuffles the train data every epoch - pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before - returning them - drop_last: If true drops the last incomplete batch - """ if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError( "You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet." @@ -95,10 +92,7 @@ def __init__( @property def num_classes(self) -> int: - """ - Return: - 10 - """ + """Returns the number of classes.""" return 10 def default_transforms(self) -> Callable: diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py index c2ac9842cf..c81db6406b 100644 --- a/tests/datamodules/test_datamodules.py +++ b/tests/datamodules/test_datamodules.py @@ -112,7 +112,7 @@ def test_sr_datamodule(datadir): @pytest.mark.parametrize("split", ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"]) @pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule]) -def test_emnist_datamodules(datadir, dm_cls, split): +def test_emnist_datamodules(datadir, catch_warnings, dm_cls, split): """Test BinaryEMNIST and EMNIST datamodules download data and have the correct shape.""" dm = _create_dm(dm_cls, datadir, split=split) train_loader = dm.train_dataloader() @@ -129,7 +129,7 @@ def test_emnist_datamodules(datadir, dm_cls, split): @pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule]) -def test_emnist_datamodules_with_invalid_split(datadir, dm_cls): +def test_emnist_datamodules_with_invalid_split(datadir, catch_warnings, dm_cls): """Test EMNIST datamodules raise an exception if the provided `split` doesn't exist.""" with pytest.raises(ValueError, match="Unknown value"): @@ -148,7 +148,7 @@ def test_emnist_datamodules_with_invalid_split(datadir, dm_cls): ("mnist", 10_000), ], ) -def test_emnist_datamodules_with_strict_val_split(datadir, dm_cls, split, expected_val_split): +def test_emnist_datamodules_with_strict_val_split(datadir, catch_warnings, dm_cls, split, expected_val_split): """Test EMNIST datamodules when strict_val_split is specified to use the validation set defined in the paper. Refer to https://arxiv.org/abs/1702.05373 for `expected_val_split` values.