Early stopping when validation is disabled #1235

Merged 6 commits on Mar 31, 2020
Changes from all commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -43,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).

## [0.7.1] - 2020-03-07

15 changes: 12 additions & 3 deletions docs/source/early_stopping.rst
@@ -4,8 +4,8 @@ Early stopping
Default behavior
----------------
By default early stopping will be enabled if `'val_loss'`
-is found in `validation_epoch_end()` return dict. Otherwise
-training will proceed with early stopping disabled.
+is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
+return dict. Otherwise training will proceed with early stopping disabled.

Enable Early Stopping
---------------------
@@ -30,9 +30,18 @@ There are two ways to enable early stopping.
)
trainer = Trainer(early_stop_callback=early_stop_callback)
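The constructor call above is truncated in this diff view; a representative construction, with illustrative argument values, looks like:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop_callback = EarlyStopping(
    monitor='val_loss',  # metric key to watch
    min_delta=0.00,      # minimum change that counts as an improvement
    patience=3,          # epochs without improvement before stopping
    mode='min'           # lower loss is better
)
trainer = Trainer(early_stop_callback=early_stop_callback)
```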

In any case, if validation is disabled or
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`
is not defined, the callback will fall back to the training metrics
(returned by :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
and :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`)
when looking for a key to monitor, as in the sketch below.
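A minimal sketch of that fallback (the module, metric names, and `compute_loss` helper are illustrative; the test added below exercises the same path):

```python
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping


class TrainOnlySketch(LightningModule):
    # no validation_step / validation_epoch_end defined

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # illustrative stand-in for the real loss
        # With validation disabled, the keys returned here are the
        # metrics that early stopping can fall back to.
        return {'loss': loss, 'train_loss': loss}


trainer = Trainer(early_stop_callback=EarlyStopping(monitor='train_loss'))
```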


Disable Early Stopping
----------------------
-To disable early stopping pass ``False`` to the `early_stop_callback`.
+To disable early stopping pass ``False`` to the
+:paramref:`~pytorch_lightning.trainer.trainer.Trainer.early_stop_callback`.
+Note that ``None`` will not disable early stopping but will lead to the
+default behaviour.
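Concretely (a short sketch of the semantics described above):

```python
from pytorch_lightning import Trainer

trainer = Trainer(early_stop_callback=False)  # early stopping fully disabled
# early_stop_callback=None instead falls back to the default behaviour:
# enabled only when 'val_loss' is reported by validation_epoch_end.
```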

4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -367,8 +367,8 @@ def train(self):
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

# TODO wrap this logic into the callback
-if self.enable_early_stop and not self.disable_validation and is_val_epoch:
-    if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
+if self.enable_early_stop:
+    if (met_min_epochs and met_min_steps) or self.fast_dev_run:
        should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
        # stop training
        stop = should_stop and met_min_epochs
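Schematically, the gating in this hunk reduces to the function below (free-standing names assumed for illustration; a paraphrase, not the library's code). Before this PR, the outer condition additionally required `not self.disable_validation and is_val_epoch`, which is exactly what kept early stopping from ever running when validation was off:

```python
def consult_early_stopping(enable_early_stop, met_min_epochs, met_min_steps,
                           fast_dev_run, early_stop_callback):
    """Paraphrase of the training-loop gating shown above."""
    if not enable_early_stop:
        return False
    if (met_min_epochs and met_min_steps) or fast_dev_run:
        # In the real loop this is EarlyStopping.on_epoch_end(trainer, model).
        should_stop = early_stop_callback()
        # Even if the callback votes to stop, min_epochs still has to be met.
        return should_stop and met_min_epochs
    return False
```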
35 changes: 33 additions & 2 deletions tests/trainer/test_callbacks.py
@@ -1,11 +1,12 @@
import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.callbacks import EarlyStopping
from tests.base import (
+    TestModelBase,
    LightTrainDataloader,
+    LightTestMixin,
    LightValidationMixin,
-    LightTestMixin,
-    TestModelBase
)


@@ -150,3 +151,33 @@ def on_test_end(self, trainer, pl_module):

    assert test_callback.on_test_start_called
    assert test_callback.on_test_end_called


def test_early_stopping_without_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""
    tutils.reset_seed()

    class ModelWithoutValStep(LightTrainDataloader, TestModelBase):

        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            loss = output['loss']  # could be anything else
            output.update({'my_train_metric': loss})
            return output

    hparams = tutils.get_default_hparams()
    model = ModelWithoutValStep(hparams)

    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer_options = dict(
        default_save_path=tmpdir,
        early_stop_callback=stopping,
        overfit_pct=0.20,
        max_epochs=10,
    )

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'
    assert trainer.current_epoch < trainer.max_epochs - 1
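The closing assertion is what actually verifies the fix: with `min_delta=0.1`, the monitored `my_train_metric` quickly stops improving by a large enough margin, so the callback should fire and end training before the full 10 epochs. Under the pre-fix behaviour, early stopping was skipped whenever validation was disabled, the run would use all 10 epochs, and `trainer.current_epoch < trainer.max_epochs - 1` would fail.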