Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate write_predictions on the LightningModule #7066

Merged
merged 4 commits into from
Apr 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))


- Deprecated `TrainerLoggingMixin` in favor of a separate utilities module for metric handling ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180))


Expand Down
14 changes: 14 additions & 0 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,14 @@ def write_prediction(
when running in distributed mode, calling ``write_prediction`` will create a file for
each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ...

.. deprecated::v1.3
Will be removed in v1.5.0.
"""
rank_zero_deprecation(
ananthsub marked this conversation as resolved.
Show resolved Hide resolved
'LightningModule method `write_prediction` was deprecated in v1.3'
' and will be removed in v1.5.'
)

self.trainer.evaluation_loop.predictions._add_prediction(name, value, filename)

def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str = 'predictions.pt'):
Expand All @@ -426,7 +433,14 @@ def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str
when running in distributed mode, calling ``write_prediction_dict`` will create a file for
each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ...

.. deprecated::v1.3
Will be removed in v1.5.0.
"""
rank_zero_deprecation(
'LightningModule method `write_prediction_dict` was deprecated in v1.3 and'
' will be removed in v1.5.'
)

for k, v in predictions_dict.items():
self.write_prediction(k, v, filename)

Expand Down
37 changes: 37 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.5.0"""
import os
from unittest import mock

import pytest
Expand Down Expand Up @@ -244,6 +245,42 @@ def bar(self):
pass


def test_v1_5_0_lightning_module_write_prediction(tmpdir):

class DeprecatedWritePredictionsModel(BoringModel):

def __init__(self):
super().__init__()
self._predictions_file = os.path.join(tmpdir, "predictions.pt")

def test_step(self, batch, batch_idx):
super().test_step(batch, batch_idx)
self.write_prediction("a", torch.Tensor(0), self._predictions_file)

def test_epoch_end(self, outputs):
self.write_prediction_dict({"a": "b"}, self._predictions_file)

with pytest.deprecated_call(match="`write_prediction` was deprecated in v1.3 and will be removed in v1.5"):
model = DeprecatedWritePredictionsModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
checkpoint_callback=False,
logger=False,
)
trainer.test(model)

with pytest.deprecated_call(match="`write_prediction_dict` was deprecated in v1.3 and will be removed in v1.5"):
model = DeprecatedWritePredictionsModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
checkpoint_callback=False,
logger=False,
)
trainer.test(model)


def test_v1_5_0_trainer_logging_mixin(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, checkpoint_callback=False, logger=False)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
Expand Down