[feat] Add BasePredictionWriter 3/3 #7127

Merged: 59 commits, Apr 27, 2021
Changes from 54 commits
Commits (59)
2e9b932
wip
tchaton Apr 20, 2021
7effee2
update
tchaton Apr 20, 2021
9321a73
update
tchaton Apr 20, 2021
128cc45
update
tchaton Apr 20, 2021
c7e49e9
update
tchaton Apr 20, 2021
9f82f7a
update
tchaton Apr 20, 2021
ce85174
typo
tchaton Apr 20, 2021
d3f9f30
update on comments
tchaton Apr 21, 2021
e1ccd1a
update
tchaton Apr 21, 2021
2a994db
update
tchaton Apr 21, 2021
69b6d77
update
tchaton Apr 21, 2021
bcf3c2b
update
tchaton Apr 22, 2021
643c8e5
update changelog
tchaton Apr 22, 2021
7109c16
update
tchaton Apr 22, 2021
fea8294
Merge branch 'master' into predict_loop_1
carmocca Apr 22, 2021
ce2656d
Fix merge
carmocca Apr 22, 2021
4ba47ed
Fix merge
carmocca Apr 22, 2021
0705ca7
Merge branch 'master' into predict_loop_1
tchaton Apr 22, 2021
54a5008
Merge branch 'predict_loop_1' of https://github.com/PyTorchLightning/…
tchaton Apr 22, 2021
1bf0325
move code
tchaton Apr 22, 2021
5243c91
resolve test
tchaton Apr 22, 2021
550d3f3
add extra test
tchaton Apr 22, 2021
0169e9e
add an extra test
tchaton Apr 22, 2021
4962459
update on comments
tchaton Apr 23, 2021
a371c5c
add typing
tchaton Apr 23, 2021
a163c2d
resolve flake8
tchaton Apr 23, 2021
63551ca
Refactor and Docs
carmocca Apr 23, 2021
0937e73
Fix tests
carmocca Apr 23, 2021
d4f523e
Fix tests
carmocca Apr 23, 2021
9a44529
Fix tests
carmocca Apr 23, 2021
d66d704
Duplicate
carmocca Apr 23, 2021
71685f2
Fix tests
carmocca Apr 23, 2021
89b281e
resolve bug
tchaton Apr 26, 2021
4416fa5
update
tchaton Apr 26, 2021
b627ed0
update on comments
tchaton Apr 26, 2021
e2d202c
Update pytorch_lightning/utilities/imports.py
tchaton Apr 26, 2021
851fb5f
Update pytorch_lightning/utilities/device_parser.py
tchaton Apr 26, 2021
e58c707
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tchaton Apr 27, 2021
8b8258c
update
tchaton Apr 27, 2021
33dafe2
Merge branch 'predict_loop_1' of https://github.com/PyTorchLightning/…
tchaton Apr 27, 2021
6d752be
update
tchaton Apr 27, 2021
2012a64
update
tchaton Apr 27, 2021
e7fa7f9
update on comments
tchaton Apr 27, 2021
0ae1bdd
resolve flkae8
tchaton Apr 27, 2021
6ab6228
update test
tchaton Apr 27, 2021
681fe12
Apply suggestions from code review
carmocca Apr 27, 2021
8d4f26e
update on comments
tchaton Apr 27, 2021
89e2286
Merge branch 'predict_loop_1' of https://github.com/PyTorchLightning/…
tchaton Apr 27, 2021
f8de388
Update pytorch_lightning/callbacks/prediction_writer.py
kaushikb11 Apr 27, 2021
b8386a5
Update pytorch_lightning/callbacks/prediction_writer.py
kaushikb11 Apr 27, 2021
7c4c391
Update pytorch_lightning/callbacks/prediction_writer.py
kaushikb11 Apr 27, 2021
7e7885c
update on comments
tchaton Apr 27, 2021
0505e96
Merge branch 'predict_loop_1' of https://github.com/PyTorchLightning/…
tchaton Apr 27, 2021
f054e62
update
tchaton Apr 27, 2021
721f589
update on comment
tchaton Apr 27, 2021
c55d78a
Apply suggestions from code review
Borda Apr 27, 2021
ee5aac0
update
tchaton Apr 27, 2021
5572aa9
Merge branch 'predict_loop_1' of https://github.com/PyTorchLightning/…
tchaton Apr 27, 2021
b3f60b8
Merge branch 'master' into predict_loop_1
tchaton Apr 27, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -120,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `trainer.predict(return_predictions=None|False|True)` ([#7215](https://github.com/PyTorchLightning/pytorch-lightning/pull/7215))


- Added `BasePredictionWriter` callback to implement prediction saving ([#7127](https://github.com/PyTorchLightning/pytorch-lightning/pull/7127))


### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
1 change: 1 addition & 0 deletions docs/source/extensions/callbacks.rst
@@ -104,6 +104,7 @@ Lightning has a few built-in callbacks.
LearningRateMonitor
ModelCheckpoint
ModelPruning
BasePredictionWriter
ProgressBar
ProgressBarBase
QuantizationAwareTraining
2 changes: 2 additions & 0 deletions pytorch_lightning/callbacks/__init__.py
@@ -19,6 +19,7 @@
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
@@ -36,6 +37,7 @@
'LearningRateMonitor',
'ModelCheckpoint',
'ModelPruning',
'BasePredictionWriter',
'ProgressBar',
'ProgressBarBase',
'QuantizationAwareTraining',
119 changes: 119 additions & 0 deletions pytorch_lightning/callbacks/prediction_writer.py
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
BasePredictionWriter
====================

Aids in saving predictions
"""
from typing import Any, List, Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class WriteInterval(LightningEnum):
BATCH = "batch"
EPOCH = "epoch"
BATCH_AND_EPOCH = "batch_and_epoch"

@property
def on_batch(self) -> bool:
return self in (self.BATCH, self.BATCH_AND_EPOCH)

@property
def on_epoch(self) -> bool:
return self in (self.EPOCH, self.BATCH_AND_EPOCH)


class BasePredictionWriter(Callback):
"""
Base class to implement how the predictions should be stored.

Args:
write_interval: When to write.

Example::

import os
from typing import Any, List

import torch

from pytorch_lightning.callbacks import BasePredictionWriter

class CustomWriter(BasePredictionWriter):

def __init__(self, output_dir: str, write_interval: str):
super().__init__(write_interval)
self.output_dir = output_dir

def write_on_batch_end(
self, trainer, pl_module: 'LightningModule', prediction: Any, batch_indices: List[int], batch: Any,
batch_idx: int, dataloader_idx: int
):
torch.save(prediction, os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.pt"))

def write_on_epoch_end(
self, trainer, pl_module: 'LightningModule', predictions: List[Any], batch_indices: List[Any]
):
torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))
"""

def __init__(self, write_interval: str = "batch") -> None:
if write_interval not in list(WriteInterval):
raise MisconfigurationException(f"`write_interval` should be one of {[i.value for i in WriteInterval]}.")
self.interval = WriteInterval(write_interval)

def write_on_batch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
prediction: Any,
batch_indices: Optional[List[int]],
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
"""Override with the logic to write a single batch."""
raise NotImplementedError()

def write_on_epoch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
predictions: List[Any],
batch_indices: Optional[List[Any]],
) -> None:
"""Override with the logic to write all batches."""
raise NotImplementedError()

def on_predict_batch_end(
self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
outputs: Any,
batch: Any,
batch_idx: int,
dataloader_idx: int,
) -> None:
if not self.interval.on_batch:
return
is_distributed = trainer.accelerator_connector.is_distributed
batch_indices = trainer.predict_loop.batch_indices if is_distributed else None
self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

def on_predict_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: List[Any]) -> None:
if not self.interval.on_epoch:
return
is_distributed = trainer.accelerator_connector.is_distributed
epoch_batch_indices = trainer.predict_loop.epoch_batch_indices if is_distributed else None
self.write_on_epoch_end(trainer, pl_module, trainer.predict_loop.predictions, epoch_batch_indices)
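For context, here is a minimal sketch of how the new callback could be wired into `Trainer.predict`. The `PredictionSaver` class and the `./preds` output directory are illustrative choices for this sketch, not part of the PR; it mirrors the pattern used in the docstring example and tests below.

import os

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from tests.helpers import BoringModel


class PredictionSaver(BasePredictionWriter):
    """Toy writer that persists all predictions once the predict epoch finishes."""

    def __init__(self, output_dir: str, write_interval: str = "epoch"):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # `predictions` is a list with one entry per dataloader, each holding the per-batch outputs
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))


model = BoringModel()
trainer = Trainer(limit_predict_batches=4, callbacks=[PredictionSaver("./preds")])
trainer.predict(model, dataloaders=model.train_dataloader())

With `write_interval="epoch"`, only `write_on_epoch_end` needs to be overridden; `on_predict_batch_end` returns early because the interval does not write on batch.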
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/predict_loop.py
@@ -51,6 +51,10 @@ def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
# For plugins other than ``DDPSpawnPlugin``, `return_predictions` is True by default unless the user decides otherwise.
self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions

@property
def should_store_predictions(self) -> bool:
return self.return_predictions or any(c.interval.on_epoch for c in self.trainer.prediction_writer_callbacks)

def on_trainer_init(self):
self.trainer.num_predict_batches = []

@@ -112,14 +116,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:

self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

if self.return_predictions:
if self.should_store_predictions:
self.predictions[dataloader_idx].append(predictions)

def _store_batch_indices(self, dataloader_idx: int) -> None:
batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
if isinstance(batch_sampler, IndexBatchSamplerWrapper):
self.batch_indices = batch_sampler.batch_indices
if self.return_predictions:
if self.should_store_predictions:
self.epoch_batch_indices[dataloader_idx].append(batch_sampler.batch_indices)

def on_predict_start(self) -> None:
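The practical effect of `should_store_predictions`, sketched below with the hypothetical `PredictionSaver` from the earlier example: even when the caller opts out of getting predictions back, the loop still accumulates them so that any epoch-interval writer receives the full list.

trainer = Trainer(limit_predict_batches=4, callbacks=[PredictionSaver("./preds")])
# return_predictions=False skips handing predictions back to the caller,
# but the epoch-interval writer is still passed every stored prediction
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)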
9 changes: 9 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -23,6 +23,7 @@
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
@@ -309,6 +310,14 @@ def early_stopping_callbacks(self) -> List[EarlyStopping]:
"""
return [c for c in self.callbacks if isinstance(c, EarlyStopping)]

@property
def prediction_writer_callbacks(self) -> Optional[List[BasePredictionWriter]]:
"""
A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter`
found in the Trainer.callbacks list.
"""
return [c for c in self.callbacks if isinstance(c, BasePredictionWriter)]

@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
"""
69 changes: 69 additions & 0 deletions tests/callbacks/test_prediction_writer.py
@@ -0,0 +1,69 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


def test_prediction_writer(tmpdir):

class CustomPredictionWriter(BasePredictionWriter):

def __init__(self, writer_interval: str):
super().__init__(writer_interval)

self.write_on_batch_end_called = False
self.write_on_epoch_end_called = False

def write_on_batch_end(self, *args, **kwargs):
self.write_on_batch_end_called = True

def write_on_epoch_end(self, *args, **kwargs):
self.write_on_epoch_end_called = True

with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"):
CustomPredictionWriter("something")

model = BoringModel()
cb = CustomPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader())
assert len(results) == 4
assert cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called

cb = CustomPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called
assert results == 1

cb = CustomPredictionWriter("batch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert cb.write_on_batch_end_called
assert not cb.write_on_epoch_end_called
assert results == 1

cb = CustomPredictionWriter("epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert not cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called
assert results == 1
2 changes: 1 addition & 1 deletion tests/loggers/test_all.py
@@ -102,7 +102,7 @@ def test_epoch_end(self, outputs) -> None:

class StoreHistoryLogger(logger_class):

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.history = []

2 changes: 1 addition & 1 deletion tests/loggers/test_base.py
@@ -170,7 +170,7 @@ def test_adding_step_key(tmpdir):

class CustomTensorBoardLogger(TensorBoardLogger):

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.logged_step = 0

49 changes: 44 additions & 5 deletions tests/trainer/test_trainer.py
@@ -31,6 +31,7 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
@@ -1512,7 +1513,33 @@ def predict_dataloader(self):
return self._dataloaders


class CustomPredictionWriter(Callback):
class CustomPredictionWriter(BasePredictionWriter):

write_on_batch_end_called = False
write_on_epoch_end_called = False

def __init__(self, output_dir: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.output_dir = output_dir

def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, *args, **kwargs):
assert prediction.shape == torch.Size([1, 2])
if trainer.accelerator_connector.is_distributed:
assert len(batch_indices) == 1
else:
assert batch_indices is None
self.write_on_batch_end_called = True

def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
expected = 1 if trainer.accelerator_connector.is_distributed else 2
assert len(predictions) == 2
assert len(predictions[0]) == expected
if trainer.accelerator_connector.is_distributed:
assert len(batch_indices) == 2
assert len(batch_indices[0]) == expected
else:
assert batch_indices is None
self.write_on_epoch_end_called = True

def on_predict_epoch_end(self, trainer, pl_module, outputs):
if trainer.accelerator_connector.is_distributed:
@@ -1522,12 +1549,17 @@ def on_predict_epoch_end(self, trainer, pl_module, outputs):
super().on_predict_epoch_end(trainer, pl_module, outputs)


def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True, pbrr=None):
def predict(
tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True, pbrr=None, use_callbacks=True
):
Comment on lines +1552 to +1554

Contributor: need to be careful that we don't overload these test helper functions with complexity. If the test functions are too complex we would need tests for the tests and so this goes in circles xD

Member: also rather keep them protected so noone would import them...

dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]

model = model or BoringModel()
dm = TestLightningDataModule(dataloaders)

cb = CustomPredictionWriter(tmpdir, write_interval="batch")
cb_1 = CustomPredictionWriter(tmpdir, write_interval="epoch")

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
@@ -1538,7 +1570,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None,
num_processes=num_processes,
plugins=plugins,
progress_bar_refresh_rate=pbrr,
callbacks=[CustomPredictionWriter()]
callbacks=[cb, cb_1] if use_callbacks else []
)
if accelerator == "ddp_spawn":
with pytest.raises(MisconfigurationException):
@@ -1550,6 +1582,13 @@
results = trainer.predict(model, dataloaders=dataloaders)

if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin):
if use_callbacks:
assert cb.write_on_batch_end_called
assert not cb.write_on_epoch_end_called

assert not cb_1.write_on_batch_end_called
assert cb_1.write_on_epoch_end_called

num_samples = 1 if accelerator == "ddp" else 2
assert len(results) == 2
assert len(results[0]) == num_samples
@@ -1572,7 +1611,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
return super().predict_step(batch, batch_idx, dataloader_idx)

with pytest.warns(UserWarning, match='predict returned None'):
predict(tmpdir, None, None, 1, model=CustomBoringModel())
predict(tmpdir, None, None, 1, model=CustomBoringModel(), use_callbacks=False)


def test_trainer_predict_grad(tmpdir):
@@ -1583,7 +1622,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
assert batch.expand_as(batch).grad_fn is None
return super().predict_step(batch, batch_idx, dataloader_idx)

predict(tmpdir, None, None, 1, model=CustomBoringModel())
predict(tmpdir, None, None, 1, model=CustomBoringModel(), use_callbacks=False)

x = torch.zeros(1, requires_grad=True)
assert x.expand_as(x).grad_fn is not None