diff --git a/CHANGELOG.md b/CHANGELOG.md index c71f55f8a4db4..aaea00a18b856 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066)) + + - Deprecated `TrainerLoggingMixin` in favor of a separate utilities module for metric handling ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 445ed76d7d27d..c824502e25ec4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -406,7 +406,14 @@ def write_prediction( when running in distributed mode, calling ``write_prediction`` will create a file for each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ... + .. deprecated::v1.3 + Will be removed in v1.5.0. """ + rank_zero_deprecation( + 'LightningModule method `write_prediction` was deprecated in v1.3' + ' and will be removed in v1.5.' + ) + self.trainer.evaluation_loop.predictions._add_prediction(name, value, filename) def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str = 'predictions.pt'): @@ -426,7 +433,14 @@ def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str when running in distributed mode, calling ``write_prediction_dict`` will create a file for each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ... + .. deprecated::v1.3 + Will be removed in v1.5.0. """ + rank_zero_deprecation( + 'LightningModule method `write_prediction_dict` was deprecated in v1.3 and' + ' will be removed in v1.5.' + ) + for k, v in predictions_dict.items(): self.write_prediction(k, v, filename) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 0a838da2faf27..b8f398b9e1c0f 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test deprecated functionality which will be removed in v1.5.0""" +import os from unittest import mock import pytest @@ -244,6 +245,42 @@ def bar(self): pass +def test_v1_5_0_lightning_module_write_prediction(tmpdir): + + class DeprecatedWritePredictionsModel(BoringModel): + + def __init__(self): + super().__init__() + self._predictions_file = os.path.join(tmpdir, "predictions.pt") + + def test_step(self, batch, batch_idx): + super().test_step(batch, batch_idx) + self.write_prediction("a", torch.Tensor(0), self._predictions_file) + + def test_epoch_end(self, outputs): + self.write_prediction_dict({"a": "b"}, self._predictions_file) + + with pytest.deprecated_call(match="`write_prediction` was deprecated in v1.3 and will be removed in v1.5"): + model = DeprecatedWritePredictionsModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + checkpoint_callback=False, + logger=False, + ) + trainer.test(model) + + with pytest.deprecated_call(match="`write_prediction_dict` was deprecated in v1.3 and will be removed in v1.5"): + model = DeprecatedWritePredictionsModel() + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + checkpoint_callback=False, + logger=False, + ) + trainer.test(model) + + def test_v1_5_0_trainer_logging_mixin(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False) with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):