Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FashionMNIST/EMNIST Datamodules #871

Merged
merged 8 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 24 additions & 33 deletions pl_bolts/datamodules/emnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.transforms.dataset_normalizations import emnist_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -14,7 +13,6 @@
EMNIST = object


@under_review()
class EMNISTDataModule(VisionDataModule):
"""
.. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png
Expand Down Expand Up @@ -76,6 +74,23 @@ class EMNISTDataModule(VisionDataModule):
|
Args:
data_dir: Root directory of dataset.
split: The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``.
This argument is passed to :class:`torchvision.datasets.EMNIST`.
val_split: Percent (float) or number (int) of samples to use for the validation split.
num_workers: How many workers to use for loading data
normalize: If ``True``, applies image normalize.
batch_size: How many samples per batch to load.
seed: Random seed to be used for train/val/test splits.
shuffle: If ``True``, shuffles the train data every epoch.
pin_memory: If ``True``, the data loader will copy Tensors into
CUDA pinned memory before returning them.
drop_last: If ``True``, drops the last incomplete batch.
strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``.
Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits.
Here is the default EMNIST, train, val, test-splits and transforms.
Transforms::
Expand All @@ -87,8 +102,10 @@ class EMNISTDataModule(VisionDataModule):
Example::
from pl_bolts.datamodules import EMNISTDataModule
dm = EMNISTDataModule('.')
model = LitModel()
Trainer().fit(model, datamodule=dm)
"""

Expand Down Expand Up @@ -119,25 +136,6 @@ def __init__(
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data.
split: The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``.
This argument is passed to :class:`torchvision.datasets.EMNIST`.
val_split: Percent (float) or number (int) of samples
to use for the validation split.
num_workers: How many workers to use for loading data
normalize: If ``True``, applies image normalize.
batch_size: How many samples per batch to load.
seed: Random seed to be used for train/val/test splits.
shuffle: If ``True``, shuffles the train data every epoch.
pin_memory: If ``True``, the data loader will copy Tensors into
CUDA pinned memory before returning them.
drop_last: If ``True``, drops the last incomplete batch.
strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``.
Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits.
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use MNIST dataset loaded from `torchvision` which is not installed yet."
Expand Down Expand Up @@ -183,13 +181,11 @@ def num_classes(self) -> int:

def prepare_data(self, *args: Any, **kwargs: Any) -> None:
"""Saves files to ``data_dir``."""

self.dataset_cls(self.data_dir, split=self.split, train=True, download=True)
self.dataset_cls(self.data_dir, split=self.split, train=False, download=True)

def setup(self, stage: Optional[str] = None) -> None:
"""Creates train, val, and test dataset."""

if stage == "fit" or stage is None:
train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
Expand All @@ -212,14 +208,9 @@ def setup(self, stage: Optional[str] = None) -> None:
)

def default_transforms(self) -> Callable:
if self.normalize:
emnist_transforms = transform_lib.Compose([transform_lib.ToTensor(), emnist_normalization(self.split)])
else:
emnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])

return (
transform_lib.Compose(
[
transform_lib.ToTensor(),
emnist_normalization(self.split),
]
)
if self.normalize
else transform_lib.Compose([transform_lib.ToTensor()])
)
return emnist_transforms
32 changes: 13 additions & 19 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -13,14 +12,25 @@
FashionMNIST = None


@under_review()
class FashionMNISTDataModule(VisionDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/
wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png
:width: 400
:alt: Fashion MNIST
Args:
data_dir: Root directory of dataset.
val_split: Percent (float) or number (int) of samples to use for the validation split.
num_workers: Number of workers to use for loading data.
normalize: If ``True``, applies image normalization.
batch_size: Number of samples per batch to load.
seed: Random seed to be used for train/val/test splits.
shuffle: If ``True``, shuffles the train data every epoch.
pin_memory: If ``True``, the data loader will copy Tensors into CUDA pinned memory before
returning them.
drop_last: If ``True``, drops the last incomplete batch.
Specs:
- 10 classes (1 per type)
- Each image is (1 x 28 x 28)
Expand Down Expand Up @@ -61,19 +71,6 @@ def __init__(
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalize
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet."
Expand All @@ -95,10 +92,7 @@ def __init__(

@property
def num_classes(self) -> int:
"""
Return:
10
"""
"""Returns the number of classes."""
return 10

def default_transforms(self) -> Callable:
Expand Down
6 changes: 3 additions & 3 deletions tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_sr_datamodule(datadir):

@pytest.mark.parametrize("split", ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"])
@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
otaj marked this conversation as resolved.
Show resolved Hide resolved
def test_emnist_datamodules(datadir, dm_cls, split):
def test_emnist_datamodules(datadir, catch_warnings, dm_cls, split):
"""Test BinaryEMNIST and EMNIST datamodules download data and have the correct shape."""
dm = _create_dm(dm_cls, datadir, split=split)
train_loader = dm.train_dataloader()
Expand All @@ -129,7 +129,7 @@ def test_emnist_datamodules(datadir, dm_cls, split):


@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule])
def test_emnist_datamodules_with_invalid_split(datadir, dm_cls):
def test_emnist_datamodules_with_invalid_split(datadir, catch_warnings, dm_cls):
"""Test EMNIST datamodules raise an exception if the provided `split` doesn't exist."""

with pytest.raises(ValueError, match="Unknown value"):
Expand All @@ -148,7 +148,7 @@ def test_emnist_datamodules_with_invalid_split(datadir, dm_cls):
("mnist", 10_000),
],
)
def test_emnist_datamodules_with_strict_val_split(datadir, dm_cls, split, expected_val_split):
def test_emnist_datamodules_with_strict_val_split(datadir, catch_warnings, dm_cls, split, expected_val_split):
"""Test EMNIST datamodules when strict_val_split is specified to use the validation set defined in the paper.
Refer to https://arxiv.org/abs/1702.05373 for `expected_val_split` values.
Expand Down