bugfix: batch_size parameter for DataModules remaining (#344)
* bugfix: batch_size for DataModules remaining

* Update sklearn datamodule tests

* Fix default_transforms. Keep internal for every data module

* fix typo on binary_mnist_datamodule

thanks @akihironitta

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>

hecoding and akihironitta authored Dec 1, 2020
1 parent 2e903c3 commit 7c2e651
Showing 12 changed files with 90 additions and 108 deletions.
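
In short: batch_size now lives on the DataModule itself, and the dataloader methods no longer accept per-call batch_size or transforms arguments. A minimal usage sketch of the API after this commit (class and parameter names are taken from the signatures in this diff; the import path and the positional data_dir argument are assumptions):

from pl_bolts.datamodules import FashionMNISTDataModule  # assumed import path

# batch_size is set once, in the constructor (the change in this commit)
dm = FashionMNISTDataModule('.', val_split=5000, batch_size=64)
dm.prepare_data()

# every dataloader now reads self.batch_size and takes no arguments
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()
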
4 changes: 2 additions & 2 deletions docs/source/introduction_guide.rst
@@ -425,7 +425,7 @@ Here's an example for logistic regression
# use any numpy or sklearn dataset
X, y = load_iris(return_X_y=True)
-dm = SklearnDataModule(X, y)
+dm = SklearnDataModule(X, y, batch_size=12)
# build model
model = LogisticRegression(input_dim=4, num_classes=3)
@@ -434,7 +434,7 @@ Here's an example for logistic regression
trainer = pl.Trainer(tpu_cores=8, precision=16)
trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())
-trainer.test(test_dataloaders=dm.test_dataloader(batch_size=12))
+trainer.test(test_dataloaders=dm.test_dataloader())
Any input will be flattened across all dimensions except the first one (batch).
This means images, sound, etc... work out of the box.
33 changes: 12 additions & 21 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -53,6 +53,7 @@ def __init__(
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
+batch_size: int = 32,
*args,
**kwargs,
):
@@ -62,6 +63,7 @@ def __init__(
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
normalize: If true applies image normalize
+batch_size: size of batch
"""
super().__init__(*args, **kwargs)

@@ -76,6 +78,7 @@ def __init__(
self.num_workers = num_workers
self.normalize = normalize
self.seed = seed
+self.batch_size = batch_size

@property
def num_classes(self):
@@ -92,15 +95,11 @@ def prepare_data(self):
BinaryMNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
BinaryMNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())

-def train_dataloader(self, batch_size=32, transforms=None):
+def train_dataloader(self):
"""
MNIST train set removes a subset to use for validation
-Args:
-batch_size: size of batch
-transforms: custom transforms
"""
-transforms = transforms or self.train_transforms or self._default_transforms()
+transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
@@ -111,23 +110,19 @@ def train_dataloader(self, batch_size=32, transforms=None):
)
loader = DataLoader(
dataset_train,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

-def val_dataloader(self, batch_size=32, transforms=None):
+def val_dataloader(self):
"""
MNIST val set uses a subset of the training set for validation
-Args:
-batch_size: size of batch
-transforms: custom transforms
"""
-transforms = transforms or self.val_transforms or self._default_transforms()
+transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
@@ -137,28 +132,24 @@ def val_dataloader(self, batch_size=32, transforms=None):
)
loader = DataLoader(
dataset_val,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

-def test_dataloader(self, batch_size=32, transforms=None):
+def test_dataloader(self):
"""
MNIST test set uses the test split
-Args:
-batch_size: size of batch
-transforms: custom transforms
"""
-transforms = transforms or self.val_transforms or self._default_transforms()
+transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
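
With the per-call transforms argument gone, custom transforms come from the DataModule's train_transforms/val_transforms/test_transforms attributes, each falling back to the internal _default_transforms() when unset. A hedged sketch of one way to use this; assigning the attribute directly is inferred from the `self.train_transforms is None` check above rather than shown in this diff:

from torchvision import transforms as transform_lib
from pl_bolts.datamodules import BinaryMNISTDataModule  # assumed import path

custom = transform_lib.Compose([
    transform_lib.RandomRotation(10),  # illustrative augmentation
    transform_lib.ToTensor(),
])

dm = BinaryMNISTDataModule('.', batch_size=32)
dm.train_transforms = custom    # train_dataloader() now uses `custom`
loader = dm.train_dataloader()  # val/test still fall back to _default_transforms()
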
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/cifar10_datamodule.py
@@ -116,7 +116,7 @@ def train_dataloader(self):
"""
CIFAR train set removes a subset to use for validation
"""
-transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
+transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
train_length = len(dataset)
@@ -139,7 +139,7 @@ def val_dataloader(self):
"""
CIFAR10 val set uses a subset of the training set for validation
"""
-transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
+transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms

dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
train_length = len(dataset)
@@ -162,7 +162,7 @@ def test_dataloader(self):
"""
CIFAR10 test set uses the test split
"""
-transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
+transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = self.DATASET(self.data_dir, train=False, download=False, transform=transforms, **self.extra_args)
loader = DataLoader(
@@ -175,7 +175,7 @@ def test_dataloader(self):
)
return loader

-def default_transforms(self):
+def _default_transforms(self):
cf10_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
cifar10_normalization()
16 changes: 8 additions & 8 deletions pl_bolts/datamodules/cityscapes_datamodule.py
@@ -111,8 +111,8 @@ def train_dataloader(self):
"""
Cityscapes train set
"""
-transforms = self.train_transforms or self.default_transforms()
-target_transforms = self.target_transforms or self.default_target_transforms()
+transforms = self.train_transforms or self._default_transforms()
+target_transforms = self.target_transforms or self._default_target_transforms()

dataset = Cityscapes(self.data_dir,
split='train',
@@ -136,8 +136,8 @@ def val_dataloader(self):
"""
Cityscapes val set
"""
-transforms = self.val_transforms or self.default_transforms()
-target_transforms = self.target_transforms or self.default_target_transforms()
+transforms = self.val_transforms or self._default_transforms()
+target_transforms = self.target_transforms or self._default_target_transforms()

dataset = Cityscapes(self.data_dir,
split='val',
@@ -161,8 +161,8 @@ def test_dataloader(self):
"""
Cityscapes test set
"""
-transforms = self.test_transforms or self.default_transforms()
-target_transforms = self.target_transforms or self.default_target_transforms()
+transforms = self.test_transforms or self._default_transforms()
+target_transforms = self.target_transforms or self._default_target_transforms()

dataset = Cityscapes(self.data_dir,
split='test',
@@ -181,7 +181,7 @@ def test_dataloader(self):
)
return loader

-def default_transforms(self):
+def _default_transforms(self):
cityscapes_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(
@@ -191,7 +191,7 @@ def default_transforms(self):
])
return cityscapes_transforms

-def default_target_transforms(self):
+def _default_target_transforms(self):
cityscapes_target_trasnforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Lambda(lambda t: t.squeeze())
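
Cityscapes keeps a separate pipeline for segmentation masks: target_transforms, with the internal _default_target_transforms() (ToTensor plus a squeeze) as the fallback. A hedged sketch of overriding just the mask pipeline; the constructor signature and direct attribute assignment are inferred from the fallback expressions above, and the Lambda is illustrative:

from torchvision import transforms as transform_lib
from pl_bolts.datamodules import CityscapesDataModule  # assumed import path

dm = CityscapesDataModule('/data/cityscapes', batch_size=8)  # assumed signature
dm.target_transforms = transform_lib.Compose([
    transform_lib.ToTensor(),
    transform_lib.Lambda(lambda t: t.squeeze().long()),  # illustrative: integer class masks
])
loader = dm.train_dataloader()  # images keep _default_transforms(); masks use the override
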
33 changes: 12 additions & 21 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -51,6 +51,7 @@ def __init__(
val_split: int = 5000,
num_workers: int = 16,
seed: int = 42,
+batch_size: int = 32,
*args,
**kwargs,
):
@@ -59,6 +60,7 @@ def __init__(
data_dir: where to save/load the data
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
+batch_size: size of batch
"""
super().__init__(*args, **kwargs)

@@ -72,6 +74,7 @@ def __init__(
self.val_split = val_split
self.num_workers = num_workers
self.seed = seed
+self.batch_size = batch_size

@property
def num_classes(self):
@@ -88,15 +91,11 @@ def prepare_data(self):
FashionMNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
FashionMNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())

-def train_dataloader(self, batch_size=32, transforms=None):
+def train_dataloader(self):
"""
FashionMNIST train set removes a subset to use for validation
-Args:
-batch_size: size of batch
-transforms: custom transforms
"""
-transforms = transforms or self.train_transforms or self._default_transforms()
+transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = FashionMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
@@ -107,23 +106,19 @@ def train_dataloader(self, batch_size=32, transforms=None):
)
loader = DataLoader(
dataset_train,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

-def val_dataloader(self, batch_size=32, transforms=None):
+def val_dataloader(self):
"""
FashionMNIST val set uses a subset of the training set for validation
-Args:
-batch_size: size of batch
-transforms: custom transforms
"""
-transforms = transforms or self.val_transforms or self._default_transforms()
+transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms

dataset = FashionMNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
@@ -134,28 +129,24 @@ def val_dataloader(self, batch_size=32, transforms=None):
)
loader = DataLoader(
dataset_val,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

-def test_dataloader(self, batch_size=32, transforms=None):
+def test_dataloader(self):
"""
FashionMNIST test set uses the test split
-Args:
-batch_size: size of batch
-transforms: custom transforms
"""
-transforms = transforms or self.test_transforms or self._default_transforms()
+transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = FashionMNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset,
-batch_size=batch_size,
+batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
16 changes: 9 additions & 7 deletions pl_bolts/datamodules/kitti_datamodule.py
@@ -74,14 +74,8 @@ def __init__(
self.num_workers = num_workers
self.seed = seed

-self.default_transforms = transforms.Compose([
-transforms.ToTensor(),
-transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
-std=[0.32064945, 0.32098866, 0.32325324])
-])

# split into train, val, test
-kitti_dataset = KittiDataset(self.data_dir, transform=self.default_transforms)
+kitti_dataset = KittiDataset(self.data_dir, transform=self._default_transforms())

val_len = round(val_split * len(kitti_dataset))
test_len = round(test_split * len(kitti_dataset))
@@ -111,3 +105,11 @@ def test_dataloader(self):
shuffle=False,
num_workers=self.num_workers)
return loader

+def _default_transforms(self):
+kitti_transforms = transforms.Compose([
+transforms.ToTensor(),
+transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
+std=[0.32064945, 0.32098866, 0.32325324])
+])
+return kitti_transforms
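
Since the KITTI normalization now lives in a _default_transforms() method rather than an attribute built in __init__, a subclass can swap the defaults without re-implementing the constructor; __init__ calls self._default_transforms() when building the splits, so an override takes effect there. A minimal sketch (the subclass and its statistics are illustrative, not part of this PR):

from torchvision import transforms
from pl_bolts.datamodules import KittiDataModule  # assumed import path

class PlainKittiDataModule(KittiDataModule):
    """Illustrative subclass that replaces the default normalization."""

    def _default_transforms(self):
        # hypothetical statistics; the parent ships KITTI-specific means/stds
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
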
18 changes: 5 additions & 13 deletions pl_bolts/datamodules/mnist_datamodule.py
@@ -61,6 +61,7 @@ def __init__(
val_split: how many of the training images to use for the validation split
num_workers: how many workers to use for loading data
normalize: If true applies image normalize
+batch_size: size of batch
"""
super().__init__(*args, **kwargs)

@@ -95,11 +96,8 @@ def prepare_data(self):
def train_dataloader(self):
"""
MNIST train set removes a subset to use for validation
-Args:
-transforms: custom transforms
"""
-transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
+transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
@@ -119,11 +117,8 @@ def train_dataloader(self):
def val_dataloader(self):
"""
MNIST val set uses a subset of the training set for validation
-Args:
-transforms: custom transforms
"""
-transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
+transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
@@ -142,11 +137,8 @@ def val_dataloader(self):
def test_dataloader(self):
"""
MNIST test set uses the test split
-Args:
-transforms: custom transforms
"""
-transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
+transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
@@ -155,7 +147,7 @@ def test_dataloader(self):
)
return loader

-def default_transforms(self):
+def _default_transforms(self):
if self.normalize:
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
