feat: Enable prediction step during training (#266)
### Description

Following #148, I have been exploring how to predict during training. This PR
allows adding a `Callback` that uses `predict_step` during training.

- **What**: Allow callbacks to call `predict_step` during training.
- **Why**: Some applications might require predicting consistently on
full images to assess performance throughout training.
- **How**: Modified `FCNModule.predict_step` to make it compatible with
a `TrainDataModule` (all calls to `trainer.datamodule` were previously written
with the expectation that it returns a `PredictDataModule`); a minimal sketch
of the resulting check follows this list.
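
As a rough illustration of that check, here is a sketch only: the helper below is hypothetical and not part of the PR; `predict_step` performs the equivalent branching inline, as shown in the diff further down.

```python
from typing import Any


def resolve_image_stats(datamodule: Any) -> tuple[list, list]:
    """Hypothetical helper illustrating how `predict_step` discriminates
    between the two datamodules and where each one stores its statistics."""
    if hasattr(datamodule, "tiled"):
        # PredictDataModule: statistics live directly on the prediction dataset
        dataset = datamodule.predict_dataset
        return dataset.image_means, dataset.image_stds
    # TrainDataModule: statistics are wrapped in a stats dataclass
    dataset = datamodule.train_dataset
    return dataset.image_stats.means, dataset.image_stats.stds
```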

### Changes Made

- **Modified**: `lightning_module.py`, `test_lightning_module.py`

### Related Issues

- Resolves #148

### Additional Notes and Examples

```python
import numpy as np
from pytorch_lightning import Callback, Trainer

from careamics import CAREamist, Configuration
from careamics.lightning import PredictDataModule, create_predict_datamodule
from careamics.prediction_utils import convert_outputs

# `minimum_configuration` is assumed to be a minimal CAREamics configuration
# dictionary (it is provided as a pytest fixture in the test suite)
config = Configuration(**minimum_configuration)


class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule: PredictDataModule):
        self.pred_datamodule = pred_datamodule

        # prepare data and setup
        self.pred_datamodule.prepare_data()
        self.pred_datamodule.setup()
        self.pred_dataloader = pred_datamodule.predict_dataloader()

    def on_validation_epoch_end(self, trainer: Trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return

        # update statistics in the prediction dataset for coherence
        # (they can be computed on-line by the training dataset)
        self.pred_datamodule.predict_dataset.image_means = (
            trainer.datamodule.train_dataset.image_stats.means
        )
        self.pred_datamodule.predict_dataset.image_stds = (
            trainer.datamodule.train_dataset.image_stats.stds
        )

        # predict on the dataset
        outputs = []
        for idx, batch in enumerate(self.pred_dataloader):
            batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0)
            outputs.append(pl_module.predict_step(batch, batch_idx=idx))

        data = convert_outputs(outputs, self.pred_datamodule.tiled)
        # the predictions can be saved or logged here


array = np.arange(32 * 32).reshape((32, 32))
pred_datamodule = create_predict_datamodule(
    pred_data=array,
    data_type=config.data_config.data_type,
    axes=config.data_config.axes,
    image_means=[11.8],  # random placeholder
    image_stds=[3.14],
    # tiling parameters can be passed here
)

predict_after_val_callback = CustomPredictAfterValidationCallback(
    pred_datamodule=pred_datamodule
)
engine = CAREamist(config, callbacks=[predict_after_val_callback])
engine.train(train_source=array)
```



The current implementation is not fully satisfactory; here are a few important
points:

- For this PR to work, we need to discriminate between `TrainDataModule`
and `PredictDataModule` in `predict_step`, which is a bit of a hack as
it currently checks `hasattr(..., "tiled")`. The reason is to avoid a
circular import of `PredictDataModule`. We should revisit that.
- `TrainDataModule` and `PredictDataModule` have incompatible members:
`PredictDataModule` has `.tiled`, and the two have different naming
conventions for the statistics (`PredictDataModule` has `image_means`
and `image_stds`, while `TrainDataModule` has them wrapped in a `stats`
dataclass). These statistics are retrieved either through
`_trainer.datamodule.predict_dataset` or
`_trainer.datamodule.train_dataset`.
- We do not provide the `Callback` that would allow users to use this feature;
we might want to do some of the heavy lifting here as well (see example).
- Probably the most serious issue: normalization is done in the datasets,
but denormalization is performed in `predict_step`. In our case, that means
that normalization could be applied by the `PredictDataModule` (in the
`Callback`) and the denormalization with the `TrainDataModule` statistics (in
`predict_step`). That is incoherent and stems from the way CAREamics is
written (a small numerical sketch follows this list).
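
To make the last point concrete, here is a small, self-contained numerical sketch (made-up statistics, no CAREamics imports): the normalize/denormalize round trip only cancels out if the two datamodules hold the same statistics, which is exactly what the callback above enforces by copying them over.

```python
import numpy as np

patch = np.array([10.0, 20.0, 30.0])

# normalization as done by the PredictDataModule's dataset, with its own stats
pred_mean, pred_std = 11.8, 3.14  # placeholder values from the example above
normalized = (patch - pred_mean) / pred_std

# denormalization as done in predict_step, with the TrainDataModule's stats
train_mean, train_std = 15.0, 5.0  # hypothetical stats computed during training
denormalized = normalized * train_std + train_mean

# the round trip only recovers the input if the two sets of statistics match
print(np.allclose(denormalized, patch))  # False with mismatched stats
```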

All in all, this draft exemplifies two problems with CAREamics:
- `TrainDataModule` and `PredictDataModule` have different members
- Normalization is done by the `DataModule`, but denormalization by the
`LightningModule`

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [x] PR to the documentation exists (for bug fixes / features)
jdeschamps authored Nov 15, 2024
1 parent aa79625 commit c7e2912
Showing 2 changed files with 91 additions and 4 deletions.
`src/careamics/lightning/lightning_module.py` (27 additions, 4 deletions)

```diff
@@ -14,6 +14,7 @@
     SupportedOptimizer,
     SupportedScheduler,
 )
+from careamics.config.tile_information import TileInformation
 from careamics.losses import loss_factory
 from careamics.models.lvae.likelihoods import (
     GaussianLikelihood,
@@ -163,15 +164,28 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
         Any
             Model output.
         """
-        if self._trainer.datamodule.tiled:
+        # TODO refactor when redoing datasets
+        # hacky way to determine if it is PredictDataModule, otherwise there is a
+        # circular import to solve with isinstance
+        from_prediction = hasattr(self._trainer.datamodule, "tiled")
+        is_tiled = (
+            len(batch) > 1
+            and isinstance(batch[1], list)
+            and isinstance(batch[1][0], TileInformation)
+        )
+
+        if is_tiled:
             x, *aux = batch
         else:
             x = batch
             aux = []
 
         # apply test-time augmentation if available
         # TODO: probably wont work with batch size > 1
-        if self._trainer.datamodule.prediction_config.tta_transforms:
+        if (
+            from_prediction
+            and self._trainer.datamodule.prediction_config.tta_transforms
+        ):
             tta = ImageRestorationTTA()
             augmented_batch = tta.forward(x)  # list of augmented tensors
             augmented_output = []
@@ -183,9 +197,18 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
             output = self.model(x)
 
         # Denormalize the output
+        # TODO incompatible API between predict and train datasets
         denorm = Denormalize(
-            image_means=self._trainer.datamodule.predict_dataset.image_means,
-            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+            image_means=(
+                self._trainer.datamodule.predict_dataset.image_means
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.means
+            ),
+            image_stds=(
+                self._trainer.datamodule.predict_dataset.image_stds
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.stds
+            ),
         )
         denormalized_output = denorm(patch=output.cpu().numpy())
 
```
`tests/lightning/test_lightning_module.py` (64 additions, 0 deletions)

```diff
@@ -308,3 +308,67 @@ def test_fcn_module_unet_depth_3_channels_3D(n_channels):
     x = torch.rand((1, n_channels, 16, 64, 64))
     y: torch.Tensor = model.forward(x)
     assert y.shape == x.shape
+
+
+@pytest.mark.parametrize("tiled", [False, True])
+def test_prediction_callback_during_training(minimum_configuration, tiled):
+    import numpy as np
+    from pytorch_lightning import Callback, Trainer
+
+    from careamics import CAREamist, Configuration
+    from careamics.lightning import PredictDataModule, create_predict_datamodule
+    from careamics.prediction_utils import convert_outputs
+
+    config = Configuration(**minimum_configuration)
+
+    class CustomPredictAfterValidationCallback(Callback):
+        def __init__(self, pred_datamodule: PredictDataModule):
+            self.pred_datamodule = pred_datamodule
+
+            # prepare data and setup
+            self.pred_datamodule.prepare_data()
+            self.pred_datamodule.setup()
+            self.pred_dataloader = pred_datamodule.predict_dataloader()
+
+            self.data = None
+
+        def on_validation_epoch_end(self, trainer: Trainer, pl_module):
+            if trainer.sanity_checking:  # optional skip
+                return
+
+            # update statistics in the prediction dataset for coherence
+            # (they can computed on-line by the training dataset)
+            self.pred_datamodule.predict_dataset.image_means = (
+                trainer.datamodule.train_dataset.image_stats.means
+            )
+            self.pred_datamodule.predict_dataset.image_stds = (
+                trainer.datamodule.train_dataset.image_stats.stds
+            )
+
+            # predict on the dataset
+            outputs = []
+            for idx, batch in enumerate(self.pred_dataloader):
+                batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0)
+                outputs.append(pl_module.predict_step(batch, batch_idx=idx))
+
+            self.data = convert_outputs(outputs, self.pred_datamodule.tiled)
+
+    array = np.arange(64 * 64).reshape((64, 64))
+    pred_datamodule = create_predict_datamodule(
+        pred_data=array,
+        data_type=config.data_config.data_type,
+        axes=config.data_config.axes,
+        image_means=[11.8],  # random placeholder
+        image_stds=[3.14],
+        tile_size=(16, 16) if tiled else None,
+        tile_overlap=(8, 8) if tiled else None,
+        batch_size=2,
+    )
+
+    predict_after_val_callback = CustomPredictAfterValidationCallback(
+        pred_datamodule=pred_datamodule
+    )
+    engine = CAREamist(config, callbacks=[predict_after_val_callback])
+    engine.train(train_source=array)
+
+    assert not np.allclose(array, predict_after_val_callback.data)
```