Skip to content

Commit

Permalink
FashionMNIST/EMNIST Datamodules (#871)
Browse files Browse the repository at this point in the history
  • Loading branch information
matsumotosan authored Sep 15, 2022
1 parent a4855d0 commit 9f3dd69
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 55 deletions.
57 changes: 24 additions & 33 deletions pl_bolts/datamodules/emnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.transforms.dataset_normalizations import emnist_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -14,7 +13,6 @@
EMNIST = object


@under_review()
class EMNISTDataModule(VisionDataModule):
"""
.. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png
Expand Down Expand Up @@ -76,6 +74,23 @@ class EMNISTDataModule(VisionDataModule):
|
Args:
data_dir: Root directory of dataset.
split: The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``.
This argument is passed to :class:`torchvision.datasets.EMNIST`.
val_split: Percent (float) or number (int) of samples to use for the validation split.
num_workers: How many workers to use for loading data.
normalize: If ``True``, applies image normalization.
batch_size: How many samples per batch to load.
seed: Random seed to be used for train/val/test splits.
shuffle: If ``True``, shuffles the train data every epoch.
pin_memory: If ``True``, the data loader will copy Tensors into
CUDA pinned memory before returning them.
drop_last: If ``True``, drops the last incomplete batch.
strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``.
Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits.
Here are the default EMNIST train, val, and test splits and transforms.
Transforms::
Expand All @@ -87,8 +102,10 @@ class EMNISTDataModule(VisionDataModule):
Example::
from pl_bolts.datamodules import EMNISTDataModule
dm = EMNISTDataModule('.')
model = LitModel()
Trainer().fit(model, datamodule=dm)
"""

Expand Down Expand Up @@ -119,25 +136,6 @@ def __init__(
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data.
split: The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``.
This argument is passed to :class:`torchvision.datasets.EMNIST`.
val_split: Percent (float) or number (int) of samples
to use for the validation split.
num_workers: How many workers to use for loading data
normalize: If ``True``, applies image normalize.
batch_size: How many samples per batch to load.
seed: Random seed to be used for train/val/test splits.
shuffle: If ``True``, shuffles the train data every epoch.
pin_memory: If ``True``, the data loader will copy Tensors into
CUDA pinned memory before returning them.
drop_last: If ``True``, drops the last incomplete batch.
strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``.
Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits.
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use MNIST dataset loaded from `torchvision` which is not installed yet."
Expand Down Expand Up @@ -183,13 +181,11 @@ def num_classes(self) -> int:

def prepare_data(self, *args: Any, **kwargs: Any) -> None:
"""Saves files to ``data_dir``."""

self.dataset_cls(self.data_dir, split=self.split, train=True, download=True)
self.dataset_cls(self.data_dir, split=self.split, train=False, download=True)

def setup(self, stage: Optional[str] = None) -> None:
"""Creates train, val, and test dataset."""

if stage == "fit" or stage is None:
train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
Expand All @@ -212,14 +208,9 @@ def setup(self, stage: Optional[str] = None) -> None:
)

def default_transforms(self) -> Callable:
if self.normalize:
emnist_transforms = transform_lib.Compose([transform_lib.ToTensor(), emnist_normalization(self.split)])
else:
emnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])

return (
transform_lib.Compose(
[
transform_lib.ToTensor(),
emnist_normalization(self.split),
]
)
if self.normalize
else transform_lib.Compose([transform_lib.ToTensor()])
)
return emnist_transforms
32 changes: 13 additions & 19 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -13,14 +12,25 @@
FashionMNIST = None


@under_review()
class FashionMNISTDataModule(VisionDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/
wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png
:width: 400
:alt: Fashion MNIST
Args:
data_dir: Root directory of dataset.
val_split: Percent (float) or number (int) of samples to use for the validation split.
num_workers: Number of workers to use for loading data.
normalize: If ``True``, applies image normalization.
batch_size: Number of samples per batch to load.
seed: Random seed to be used for train/val/test splits.
shuffle: If ``True``, shuffles the train data every epoch.
pin_memory: If ``True``, the data loader will copy Tensors into CUDA pinned memory before
returning them.
drop_last: If ``True``, drops the last incomplete batch.
Specs:
- 10 classes (1 per type)
- Each image is (1 x 28 x 28)
Expand Down Expand Up @@ -61,19 +71,6 @@ def __init__(
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalize
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet."
Expand All @@ -95,10 +92,7 @@ def __init__(

@property
def num_classes(self) -> int:
"""
Return:
10
"""
"""Returns the number of classes."""
return 10

def default_transforms(self) -> Callable:
Expand Down
6 changes: 3 additions & 3 deletions tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_sr_datamodule(datadir):

@pytest.mark.parametrize("split", ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"])
@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
def test_emnist_datamodules(datadir, dm_cls, split):
def test_emnist_datamodules(datadir, catch_warnings, dm_cls, split):
"""Test BinaryEMNIST and EMNIST datamodules download data and have the correct shape."""
dm = _create_dm(dm_cls, datadir, split=split)
train_loader = dm.train_dataloader()
Expand All @@ -129,7 +129,7 @@ def test_emnist_datamodules(datadir, dm_cls, split):


@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
def test_emnist_datamodules_with_invalid_split(datadir, dm_cls):
def test_emnist_datamodules_with_invalid_split(datadir, catch_warnings, dm_cls):
"""Test EMNIST datamodules raise an exception if the provided `split` doesn't exist."""

with pytest.raises(ValueError, match="Unknown value"):
Expand All @@ -148,7 +148,7 @@ def test_emnist_datamodules_with_invalid_split(datadir, dm_cls):
("mnist", 10_000),
],
)
def test_emnist_datamodules_with_strict_val_split(datadir, dm_cls, split, expected_val_split):
def test_emnist_datamodules_with_strict_val_split(datadir, catch_warnings, dm_cls, split, expected_val_split):
"""Test EMNIST datamodules when strict_val_split is specified to use the validation set defined in the paper.
Refer to https://arxiv.org/abs/1702.05373 for `expected_val_split` values.
Expand Down

0 comments on commit 9f3dd69

Please sign in to comment.