Skip to content

Commit

Permalink
Disable train dataloader shuffle when overfit_batches is active.
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilJd committed Sep 15, 2020
1 parent 4dc4c8c commit 8cf5ad5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed getting `experiment_id` from MLFlow only once instead of each training loop ([#3394](https://github.com/PyTorchLightning/pytorch-lightning/pull/3394))

- Fixed overfit_batches which now correctly disables shuffling for the training loader. ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
model: The current `LightningModule`
"""
self.train_dataloader = self.request_dataloader(model.train_dataloader)
if (self.overfit_batches > 0):
if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
' We are turning it off for you.')
self.train_dataloader = self.replace_sampler(
self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))

# debugging
self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])
Expand Down Expand Up @@ -247,7 +253,7 @@ def _reset_eval_dataloader(

# when overfitting, the dataloader should not have sampler
if self.overfit_batches > 0:
rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
' We are turning it off for you.')
dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))

Expand Down
7 changes: 5 additions & 2 deletions tests/trainer/test_trainer_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,13 @@ def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
# get the training loader and batch
# ------------------------------------------------------
# Create a reference train dataloader without shuffling.
train_loader = DataLoader(model.train_dataloader().dataset, shuffle=False)
(xa, ya) = next(iter(train_loader))
train_loader = DataLoader(model.train_dataloader().dataset, shuffle=True)
full_train_samples = len(train_loader)
num_train_samples = int(0.11 * full_train_samples)

(xa, ya) = next(iter(train_loader))

# ------------------------------------------------------
# set VAL and Test loaders
Expand All @@ -87,7 +89,8 @@ def test_overfit_batch_limits(tmpdir):

trainer = Trainer(overfit_batches=0.11)
trainer.reset_train_dataloader(model)
assert trainer.train_dataloader is train_loader
# The dataloader should have been overwritten with a Sequential sampler.
assert trainer.train_dataloader is not train_loader
assert trainer.num_training_batches == num_train_samples

# make sure the loaders are the same
Expand Down

0 comments on commit 8cf5ad5

Please sign in to comment.