From 2398960692e3daa10c54e0d24c9708c21e089df7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Apr 2020 00:04:01 +0200 Subject: [PATCH 1/5] reorder if clauses --- pytorch_lightning/trainer/evaluation_loop.py | 33 +++++++++++--------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index bc18958891a3c..ad01e5db085eb 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -296,20 +296,25 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_ if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)): model = model.module - # TODO: remove in v1.0.0 - if test_mode and self.is_overriden('test_end', model=model): - eval_results = model.test_end(outputs) - warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.' - ' Use `test_epoch_end` instead.', DeprecationWarning) - elif self.is_overriden('validation_end', model=model): - eval_results = model.validation_end(outputs) - warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.' - ' Use `validation_epoch_end` instead.', DeprecationWarning) - - if test_mode and self.is_overriden('test_epoch_end', model=model): - eval_results = model.test_epoch_end(outputs) - elif self.is_overriden('validation_epoch_end', model=model): - eval_results = model.validation_epoch_end(outputs) + if test_mode: + if self.is_overriden('test_end', model=model): + # TODO: remove in v1.0.0 + eval_results = model.test_end(outputs) + warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.' + ' Use `test_epoch_end` instead.', DeprecationWarning) + + elif self.is_overriden('test_epoch_end', model=model): + eval_results = model.test_epoch_end(outputs) + + else: + if self.is_overriden('validation_end', model=model): + # TODO: remove in v1.0.0 + eval_results = model.validation_end(outputs) + warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.' + ' Use `validation_epoch_end` instead.', DeprecationWarning) + + elif self.is_overriden('validation_epoch_end', model=model): + eval_results = model.validation_epoch_end(outputs) # enable train mode again model.train() From 69ae44c212de8eae5ffb4296059bbfada02cf3d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Apr 2020 03:43:53 +0200 Subject: [PATCH 2/5] fix wrong method overload in test --- tests/trainer/test_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 923a4f833d7a0..0cb78932d6e6b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -528,15 +528,15 @@ def test_disabled_validation(): class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase): validation_step_invoked = False - validation_end_invoked = False + validation_epoch_end_invoked = False def validation_step(self, *args, **kwargs): self.validation_step_invoked = True return super().validation_step(*args, **kwargs) - def validation_end(self, *args, **kwargs): - self.validation_end_invoked = True - return super().validation_end(*args, **kwargs) + def validation_epoch_end(self, *args, **kwargs): + self.validation_epoch_end_invoked = True + return super().validation_epoch_end(*args, **kwargs) hparams = tutils.get_default_hparams() model = CurrentModel(hparams) @@ -556,7 +556,7 @@ def validation_end(self, *args, **kwargs): assert result == 1, 'training failed to complete' assert trainer.current_epoch == 1 assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`' - assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`' + assert not model.validation_epoch_end_invoked, '`validation_epoch_end` should not run when `val_percent_check=0`' # check that val_percent_check has no influence when fast_dev_run is turned on model = CurrentModel(hparams) @@ -567,7 +567,7 @@ def validation_end(self, *args, **kwargs): assert result == 1, 'training failed to complete' assert trainer.current_epoch == 0 assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`' - assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`' + assert model.validation_epoch_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`' def test_nan_loss_detection(tmpdir): From 4e9435152b3f536355a91f832ac70dbb6ec6e4a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Apr 2020 03:46:48 +0200 Subject: [PATCH 3/5] fix formatting --- tests/trainer/test_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0cb78932d6e6b..499e84b65972a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -555,8 +555,10 @@ def validation_epoch_end(self, *args, **kwargs): # check that val_percent_check=0 turns off validation assert result == 1, 'training failed to complete' assert trainer.current_epoch == 1 - assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`' - assert not model.validation_epoch_end_invoked, '`validation_epoch_end` should not run when `val_percent_check=0`' + assert not model.validation_step_invoked, \ + '`validation_step` should not run when `val_percent_check=0`' + assert not model.validation_epoch_end_invoked, \ + '`validation_epoch_end` should not run when `val_percent_check=0`' # check that val_percent_check has no influence when fast_dev_run is turned on model = CurrentModel(hparams) From e3da528a76622286f153804696b6a2c76d288ded Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Apr 2020 03:56:47 +0200 Subject: [PATCH 4/5] update change_log --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fa9ce4f0a9a2..3fd2a805fb924 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251)) - Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309)) - Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)). +- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end ` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)). ## [0.7.1] - 2020-03-07 From 08d93b2e84706d8a1797e9756da06938c461d2c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Apr 2020 03:57:33 +0200 Subject: [PATCH 5/5] fix line too long --- tests/trainer/test_trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 499e84b65972a..25729b27d0ba5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -568,8 +568,10 @@ def validation_epoch_end(self, *args, **kwargs): assert result == 1, 'training failed to complete' assert trainer.current_epoch == 0 - assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`' - assert model.validation_epoch_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`' + assert model.validation_step_invoked, \ + 'did not run `validation_step` with `fast_dev_run=True`' + assert model.validation_epoch_end_invoked, \ + 'did not run `validation_epoch_end` with `fast_dev_run=True`' def test_nan_loss_detection(tmpdir):