From ebd9fc9530242e1c9b5f3093dc62ceb4185735b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Apr 2020 15:25:32 +0200 Subject: [PATCH] Fix for incorrect run on the validation set with overwritten validation_epoch_end and test_end (#1353) * reorder if clauses * fix wrong method overload in test * fix formatting * update change_log * fix line too long --- CHANGELOG.md | 1 + pytorch_lightning/trainer/evaluation_loop.py | 33 +++++++++++--------- tests/trainer/test_trainer.py | 20 +++++++----- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d7af014edecb..5ea175f6d2e8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251)) - Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309)) - Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)). +- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)). 
## [0.7.1] - 2020-03-07 diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 7c4f4f238852c..b55535784b3ff 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -295,20 +295,25 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)): model = model.module - # TODO: remove in v1.0.0 - if test_mode and self.is_overriden('test_end', model=model): - eval_results = model.test_end(outputs) - warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.' - ' Use `test_epoch_end` instead.', DeprecationWarning) - elif self.is_overriden('validation_end', model=model): - eval_results = model.validation_end(outputs) - warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.' - ' Use `validation_epoch_end` instead.', DeprecationWarning) - - if test_mode and self.is_overriden('test_epoch_end', model=model): - eval_results = model.test_epoch_end(outputs) - elif self.is_overriden('validation_epoch_end', model=model): - eval_results = model.validation_epoch_end(outputs) + if test_mode: + if self.is_overriden('test_end', model=model): + # TODO: remove in v1.0.0 + eval_results = model.test_end(outputs) + warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.' + ' Use `test_epoch_end` instead.', DeprecationWarning) + + elif self.is_overriden('test_epoch_end', model=model): + eval_results = model.test_epoch_end(outputs) + + else: + if self.is_overriden('validation_end', model=model): + # TODO: remove in v1.0.0 + eval_results = model.validation_end(outputs) + warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.' 
+ ' Use `validation_epoch_end` instead.', DeprecationWarning) + + elif self.is_overriden('validation_epoch_end', model=model): + eval_results = model.validation_epoch_end(outputs) # enable train mode again model.train() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 30f65576f27c7..ddf79cb2f779e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -528,15 +528,15 @@ def test_disabled_validation(): class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase): validation_step_invoked = False - validation_end_invoked = False + validation_epoch_end_invoked = False def validation_step(self, *args, **kwargs): self.validation_step_invoked = True return super().validation_step(*args, **kwargs) - def validation_end(self, *args, **kwargs): - self.validation_end_invoked = True - return super().validation_end(*args, **kwargs) + def validation_epoch_end(self, *args, **kwargs): + self.validation_epoch_end_invoked = True + return super().validation_epoch_end(*args, **kwargs) hparams = tutils.get_default_hparams() model = CurrentModel(hparams) @@ -555,8 +555,10 @@ def validation_end(self, *args, **kwargs): # check that val_percent_check=0 turns off validation assert result == 1, 'training failed to complete' assert trainer.current_epoch == 1 - assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`' - assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`' + assert not model.validation_step_invoked, \ + '`validation_step` should not run when `val_percent_check=0`' + assert not model.validation_epoch_end_invoked, \ + '`validation_epoch_end` should not run when `val_percent_check=0`' # check that val_percent_check has no influence when fast_dev_run is turned on model = CurrentModel(hparams) @@ -566,8 +568,10 @@ def validation_end(self, *args, **kwargs): assert result == 1, 'training failed to complete' assert 
trainer.current_epoch == 0 - assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`' - assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`' + assert model.validation_step_invoked, \ + 'did not run `validation_step` with `fast_dev_run=True`' + assert model.validation_epoch_end_invoked, \ + 'did not run `validation_epoch_end` with `fast_dev_run=True`' def test_nan_loss_detection(tmpdir):