
Add warning when `Trainer(log_every_n_steps)` is not well chosen #7734

Merged
merged 9 commits on Jun 7, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -50,6 +50,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))


- Added a warning when `Trainer(log_every_n_steps)` is set higher than the number of batches in the training dataloader ([#X](https://github.com/PyTorchLightning/pytorch-lightning/pull/X))


### Changed

- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -51,6 +51,7 @@ class TrainerDataLoadingMixin(ABC):
    test_dataloaders: Optional[List[DataLoader]]
    num_test_batches: List[Union[int, float]]
    limit_train_batches: Union[int, float]
    log_every_n_steps: int
    overfit_batches: Union[int, float]
    distributed_sampler_kwargs: dict
    accelerator: Accelerator
@@ -299,6 +300,13 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                self.val_check_batch = max(1, self.val_check_batch)

        if self.logger is not None and self.num_training_batches < self.log_every_n_steps:
            rank_zero_warn(
                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                f" you want to see logs for the training epoch."
            )

    def _reset_eval_dataloader(
        self,
        model: LightningModule,
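For context on the behaviour this check guards against, here is a minimal sketch of a `Trainer` configuration that would trigger the new warning and one that would not. It assumes the `BoringModel`/`RandomDataset` helpers used in the test below and the `Trainer.fit(model, train_dataloader=...)` signature shown there; it is an illustration, not part of the diff.

```python
# A sketch only: illustrates when the warning added above fires. Assumes the
# BoringModel / RandomDataset helpers from tests/trainer/test_dataloaders.py.
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel, RandomDataset

# 8 samples with the default batch_size=1 -> 8 training batches per epoch.
train_loader = DataLoader(RandomDataset(32, length=8))

# log_every_n_steps=50 exceeds the 8 available batches, so no training logs
# would be written during the epoch; the new check emits a UserWarning.
Trainer(max_epochs=1, log_every_n_steps=50).fit(BoringModel(), train_dataloader=train_loader)

# log_every_n_steps=5 is within the epoch length; no warning is raised.
Trainer(max_epochs=1, log_every_n_steps=5).fit(BoringModel(), train_dataloader=train_loader)
```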
18 changes: 18 additions & 0 deletions tests/trainer/test_dataloaders.py
@@ -895,6 +895,24 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
    trainer.fit(model, train_dataloader=dataloader)


def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
    model = BoringModel()
    dataloader = DataLoader(RandomDataset(32, length=10))
    model.train_dataloader = lambda: dataloader

    with pytest.warns(UserWarning, match=r"The number of training samples \(10\) is smaller than the logging interval"):
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            log_every_n_steps=11,
        )
        trainer.fit(model)

    with pytest.warns(UserWarning, match=r"The number of training samples \(1\) is smaller than the logging interval"):
        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1)
        trainer.fit(model)


def test_warning_with_iterable_dataset_and_len(tmpdir):
    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
    model = BoringModel()
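If a user deliberately sets `log_every_n_steps` above the epoch length, the new message can be filtered with the standard `warnings` machinery, since it is raised as a plain `UserWarning` (which is what the `pytest.warns` assertions above rely on). A minimal, hypothetical sketch:

```python
# A sketch only: silences the new warning when a high log_every_n_steps is
# intentional. The "message" pattern must match the start of the warning text.
import warnings

warnings.filterwarnings(
    "ignore",
    message=r"The number of training samples \(\d+\) is smaller than the logging interval",
    category=UserWarning,
)
```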