diff --git a/src/careamics/lightning/lightning_module.py b/src/careamics/lightning/lightning_module.py
index 438d8080..d9b65ff8 100644
--- a/src/careamics/lightning/lightning_module.py
+++ b/src/careamics/lightning/lightning_module.py
@@ -14,6 +14,7 @@
     SupportedOptimizer,
     SupportedScheduler,
 )
+from careamics.config.tile_information import TileInformation
 from careamics.losses import loss_factory
 from careamics.models.lvae.likelihoods import (
     GaussianLikelihood,
@@ -163,7 +164,17 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
         Any
             Model output.
         """
-        if self._trainer.datamodule.tiled:
+        # TODO refactor when redoing datasets
+        # hacky way to determine if it is PredictDataModule, otherwise there is a
+        # circular import to solve with isinstance
+        from_prediction = hasattr(self._trainer.datamodule, "tiled")
+        is_tiled = (
+            len(batch) > 1
+            and isinstance(batch[1], list)
+            and isinstance(batch[1][0], TileInformation)
+        )
+
+        if is_tiled:
             x, *aux = batch
         else:
             x = batch
@@ -171,7 +182,10 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
 
         # apply test-time augmentation if available
        # TODO: probably wont work with batch size > 1
-        if self._trainer.datamodule.prediction_config.tta_transforms:
+        if (
+            from_prediction
+            and self._trainer.datamodule.prediction_config.tta_transforms
+        ):
             tta = ImageRestorationTTA()
             augmented_batch = tta.forward(x)  # list of augmented tensors
             augmented_output = []
@@ -183,9 +197,18 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
             output = self.model(x)
 
         # Denormalize the output
+        # TODO incompatible API between predict and train datasets
         denorm = Denormalize(
-            image_means=self._trainer.datamodule.predict_dataset.image_means,
-            image_stds=self._trainer.datamodule.predict_dataset.image_stds,
+            image_means=(
+                self._trainer.datamodule.predict_dataset.image_means
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.means
+            ),
+            image_stds=(
+                self._trainer.datamodule.predict_dataset.image_stds
+                if from_prediction
+                else self._trainer.datamodule.train_dataset.image_stats.stds
+            ),
         )
         denormalized_output = denorm(patch=output.cpu().numpy())
 
diff --git a/tests/lightning/test_lightning_module.py b/tests/lightning/test_lightning_module.py
index 0b0eaafc..fc809487 100644
--- a/tests/lightning/test_lightning_module.py
+++ b/tests/lightning/test_lightning_module.py
@@ -308,3 +308,67 @@ def test_fcn_module_unet_depth_3_channels_3D(n_channels):
     x = torch.rand((1, n_channels, 16, 64, 64))
     y: torch.Tensor = model.forward(x)
     assert y.shape == x.shape
+
+
+@pytest.mark.parametrize("tiled", [False, True])
+def test_prediction_callback_during_training(minimum_configuration, tiled):
+    import numpy as np
+    from pytorch_lightning import Callback, Trainer
+
+    from careamics import CAREamist, Configuration
+    from careamics.lightning import PredictDataModule, create_predict_datamodule
+    from careamics.prediction_utils import convert_outputs
+
+    config = Configuration(**minimum_configuration)
+
+    class CustomPredictAfterValidationCallback(Callback):
+        def __init__(self, pred_datamodule: PredictDataModule):
+            self.pred_datamodule = pred_datamodule
+
+            # prepare data and setup
+            self.pred_datamodule.prepare_data()
+            self.pred_datamodule.setup()
+            self.pred_dataloader = pred_datamodule.predict_dataloader()
+
+            self.data = None
+
+        def on_validation_epoch_end(self, trainer: Trainer, pl_module):
+            if trainer.sanity_checking:  # optional skip
+                return
+
+            # update statistics in the prediction dataset for coherence
+            # (they can be computed on-line by the training dataset)
+            self.pred_datamodule.predict_dataset.image_means = (
+                trainer.datamodule.train_dataset.image_stats.means
+            )
+            self.pred_datamodule.predict_dataset.image_stds = (
+                trainer.datamodule.train_dataset.image_stats.stds
+            )
+
+            # predict on the dataset
+            outputs = []
+            for idx, batch in enumerate(self.pred_dataloader):
+                batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0)
+                outputs.append(pl_module.predict_step(batch, batch_idx=idx))
+
+            self.data = convert_outputs(outputs, self.pred_datamodule.tiled)
+
+    array = np.arange(64 * 64).reshape((64, 64))
+    pred_datamodule = create_predict_datamodule(
+        pred_data=array,
+        data_type=config.data_config.data_type,
+        axes=config.data_config.axes,
+        image_means=[11.8],  # random placeholder
+        image_stds=[3.14],
+        tile_size=(16, 16) if tiled else None,
+        tile_overlap=(8, 8) if tiled else None,
+        batch_size=2,
+    )
+
+    predict_after_val_callback = CustomPredictAfterValidationCallback(
+        pred_datamodule=pred_datamodule
+    )
+    engine = CAREamist(config, callbacks=[predict_after_val_callback])
+    engine.train(train_source=array)
+
+    assert not np.allclose(array, predict_after_val_callback.data)
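
The sketch below reproduces, in isolation, the tiled-batch heuristic that the patch adds to predict_step. It is an illustration only, not part of the diff; FakeTileInfo is a hypothetical stand-in for careamics.config.tile_information.TileInformation so the snippet stays self-contained and runnable.

# Illustrative sketch only -- not part of the patch above.
import torch


class FakeTileInfo:
    """Hypothetical stand-in for careamics' TileInformation."""


def is_tiled_batch(batch, tile_info_type=FakeTileInfo) -> bool:
    # Mirrors the heuristic added in predict_step: a tiled prediction batch is a
    # (tensor, [TileInformation, ...]) pair, while a plain batch is just a tensor.
    return (
        len(batch) > 1
        and isinstance(batch[1], list)
        and isinstance(batch[1][0], tile_info_type)
    )


plain_batch = torch.rand(2, 1, 16, 16)                      # tensor only
tiled_batch = (torch.rand(2, 1, 16, 16), [FakeTileInfo()])  # tensor + tile metadata

assert not is_tiled_batch(plain_batch)
assert is_tiled_batch(tiled_batch)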