Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Enable prediction step during training (#266)
### Description Following #148, I have been exploring how to predict during training. This PR would allow adding `Callback` that use `predict_step` during Training. - **What**: Allow callbacks to call `predict_step` during training. - **Why**: Some applications might require predicting consistently on full images to assess training performances throughout training. - **How**: Modified `FCNModule.predict_step` to make it compatible with a `TrainDataModule` (all calls to `trainer.datamodule` were written with the expectation that it returns a `PredictDataModule`. ### Changes Made - **Modified**: `lightning_module.py`, `test_lightning_module.py` ### Related Issues - Resolves #148 ### Additional Notes and Examples ```python import numpy as np from pytorch_lightning import Callback, Trainer from careamics import CAREamist, Configuration from careamics.lightning import PredictDataModule, create_predict_datamodule from careamics.prediction_utils import convert_outputs config = Configuration(**minimum_configuration) class CustomPredictAfterValidationCallback(Callback): def __init__(self, pred_datamodule: PredictDataModule): self.pred_datamodule = pred_datamodule # prepare data and setup self.pred_datamodule.prepare_data() self.pred_datamodule.setup() self.pred_dataloader = pred_datamodule.predict_dataloader() def on_validation_epoch_end(self, trainer: Trainer, pl_module): if trainer.sanity_checking: # optional skip return # update statistics in the prediction dataset for coherence # (they can computed on-line by the training dataset) self.pred_datamodule.predict_dataset.image_means = ( trainer.datamodule.train_dataset.image_stats.means ) self.pred_datamodule.predict_dataset.image_stds = ( trainer.datamodule.train_dataset.image_stats.stds ) # predict on the dataset outputs = [] for idx, batch in enumerate(self.pred_dataloader): batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0) outputs.append(pl_module.predict_step(batch, batch_idx=idx)) data = convert_outputs(outputs, self.pred_datamodule.tiled) # can save data here array = np.arange(32 * 32).reshape((32, 32)) pred_datamodule = create_predict_datamodule( pred_data=array, data_type=config.data_config.data_type, axes=config.data_config.axes, image_means=[11.8], # random placeholder image_stds=[3.14], # can do tiling here ) predict_after_val_callback = CustomPredictAfterValidationCallback( pred_datamodule=pred_datamodule ) engine = CAREamist(config, callbacks=[predict_after_val_callback]) engine.train(train_source=array) ``` Currently, this current implementation is not fully satisfactory and here are a few important points: - For this PR to work we need to discriminate between `TrainDataModule` and `PredictDataModule` in `predict_step`, which is a bit of a hack as it currently check `hasattr(..., "tiled")`. The reason is to avoid a circular import of `PredictDataModule`. We should revisit that. - `TrainDataModule` and `PredictDataModule` have incompatible members: `PredictDataModule` has `.tiled`, and the two have different naming conventions for the statistics (`PredictDataModule` has `image_means` and `image_stds`, while `TrainDataModule` has them wrapped in a `stats` dataclass). These statistics are retrieved either through `_trainer.datamodule.predict_dataset` or `_trainer.datamodule.train_dataset`. - We do not provide the `Callable` that would allow to use such feature. We might want to some heavy lifting here as well (see example). - Probably the most serious issue, normalization is done in the datasets but denormalization is performed in the `predict_step`. In our case, that means that normalization could be applied by a `PredictDataModule` (in the `Callback` and the denormalization by the `TrainDataModule` (in `predict_step`). That is incoherent and due to the way we wrote CAREamics. All in all, this draft exemplifies two problems with CAREamics: - `TrainDataModule` and `PredictDataModule` have different members - Normalization is done by the `DataModule` but denormalization by `LightningModule` --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [x] PR to the documentation exists (for bug fixes / features)
- Loading branch information