diff --git a/src/careamics/lightning/lightning_module.py b/src/careamics/lightning/lightning_module.py
index dc20ba7a..755b9c65 100644
--- a/src/careamics/lightning/lightning_module.py
+++ b/src/careamics/lightning/lightning_module.py
@@ -271,6 +271,12 @@ def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
         self.noise_model: NoiseModel = noise_model_factory(
             self.algorithm_config.noise_model
         )
+        # TODO: here we could check whether the noise model is not None while
+        # `self.algorithm_config.noise_model_likelihood_model.noise_model` is,
+        # instead, None. In that case we could assign the noise model to the latter.
+        # This is particularly useful when loading an algorithm config from file:
+        # in that case the noise model in the NM likelihood is likely not
+        # available, since it is excluded from serialization.
         self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
             self.algorithm_config.noise_model_likelihood_model
         )
diff --git a/src/careamics/losses/loss_factory.py b/src/careamics/losses/loss_factory.py
index 5066b970..2a7fabed 100644
--- a/src/careamics/losses/loss_factory.py
+++ b/src/careamics/losses/loss_factory.py
@@ -56,9 +56,9 @@ class LVAELossParameters:
     reconstruction_weight: float = 1.0
     """Weight for the reconstruction loss in the total net loss
     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    musplit_weight: float = 0.0
-    """Weight for the muSplit loss (used in the muSplit-deonoiSplit loss)."""
-    denoisplit_weight: float = 1.0
+    musplit_weight: float = 0.1
+    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+    denoisplit_weight: float = 0.9
     """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss)."""
     kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
     """Type of KL divergence used as KL loss."""
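For reference, the rebalanced defaults change how the two reconstruction terms are mixed. A minimal arithmetic sketch, with made-up loss values and assuming the top-level `reconstruction_weight` and `kl_weight` are both 1.0, following the `net_loss` formula quoted in the docstring above:

```python
import torch

# Hypothetical per-batch loss values, for illustration only.
musplit_loss = torch.tensor(0.42)     # Gaussian-likelihood (muSplit) term
denoisplit_loss = torch.tensor(1.37)  # noise-model (denoiSplit) term
kl_loss = torch.tensor(0.05)

# New defaults: musplit_weight=0.1, denoisplit_weight=0.9 (previously 0.0 / 1.0).
rec_loss = 0.9 * denoisplit_loss + 0.1 * musplit_loss

# net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss
net_loss = 1.0 * rec_loss + 1.0 * kl_loss
```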
diff --git a/src/careamics/losses/lvae/losses.py b/src/careamics/losses/lvae/losses.py
index fc041e1c..41c09335 100644
--- a/src/careamics/losses/lvae/losses.py
+++ b/src/careamics/losses/lvae/losses.py
@@ -137,8 +137,8 @@ def reconstruction_loss_musplit_denoisplit(
     recons_loss : torch.Tensor
         The reconstruction loss. Shape is (1, ).
     """
-    # TODO: is this safe to check for predict_logvar value?
-    # otherwise use `gaussian_likelihood.predict_logvar` (or both)
+    # TODO: refactor this function to make it closer to `get_reconstruction_loss`
+    # (or vice versa)
     if predictions.shape[1] == 2 * targets.shape[1]:
         # predictions contain both mean and log-variance
         out_mean, _ = predictions.chunk(2, dim=1)
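The shape check above relies on the channel dimension of `predictions` doubling when the network also predicts a log-variance; the two halves are then separated with `chunk`. A self-contained toy illustration:

```python
import torch

targets = torch.randn(2, 2, 8, 8)      # (B, C, Y, X)
predictions = torch.randn(2, 4, 8, 8)  # channel dim doubled: mean + log-variance

if predictions.shape[1] == 2 * targets.shape[1]:
    out_mean, out_logvar = predictions.chunk(2, dim=1)

assert out_mean.shape == targets.shape  # (2, 2, 8, 8)
```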
diff --git a/src/careamics/lvae_training/eval_utils.py b/src/careamics/lvae_training/eval_utils.py
index 6a5bdd7b..92b5f3c9 100644
--- a/src/careamics/lvae_training/eval_utils.py
+++ b/src/careamics/lvae_training/eval_utils.py
@@ -14,13 +14,19 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+from torch import nn
+from torch.utils.data import Dataset
 from matplotlib.gridspec import GridSpec
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from careamics.lightning import VAEModule
+from careamics.losses.lvae.losses import (
+    get_reconstruction_loss,
+    reconstruction_loss_musplit_denoisplit,
+)
 from careamics.models.lvae.utils import ModelType
-
-from .metrics import RangeInvariantPsnr, RunningPSNR
+from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
 
 
 # ------------------------------------------------------------------------------------------------
@@ -40,28 +46,26 @@ def clean_ax(ax):
     ax.tick_params(left=False, right=False, top=False, bottom=False)
 
 
-def get_plots_output_dir(
+def get_eval_output_dir(
     saveplotsdir: str, patch_size: int, mmse_count: int = 50
 ) -> str:
     """
     Given the path to a root directory to save plots, patch size, and mmse count,
     it returns the specific directory to save the plots.
     """
-    plotsrootdir = os.path.join(
-        saveplotsdir, f"plots/patch_{patch_size}_mmse_{mmse_count}"
+    eval_out_dir = os.path.join(
+        saveplotsdir, f"eval_outputs/patch_{patch_size}_mmse_{mmse_count}"
     )
-    os.makedirs(plotsrootdir, exist_ok=True)
-    print(plotsrootdir)
-    return plotsrootdir
+    os.makedirs(eval_out_dir, exist_ok=True)
+    print(eval_out_dir)
+    return eval_out_dir
 
 
 def get_psnr_str(tar_hsnr, pred, col_idx):
     """
     Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
     """
-    return (
-        f"{RangeInvariantPsnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
-    )
+    return f"{scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
 
 
 def add_psnr_str(ax_, psnr):
@@ -499,20 +503,40 @@ def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
 
 
 def get_dset_predictions(
-    model,
-    dset,
+    model: VAEModule,
+    dset: Dataset,
     batch_size: int,
-    model_type: ModelType = None,
+    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
     mmse_count: int = 1,
     num_workers: int = 4,
-):
-    """
-    Get predictions from a model for the entire dataset.
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]:
+    """Get patch-wise predictions from a model for the entire dataset.
 
     Parameters
     ----------
-    mmse_count : int
-        Number of samples to generate for each input and then to average over for MMSE estimation.
+    model : VAEModule
+        Lightning model used for prediction.
+    dset : Dataset
+        Dataset to predict on.
+    batch_size : int
+        Batch size to use for prediction.
+    loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+        Type of reconstruction loss used by the model.
+    mmse_count : int, optional
+        Number of samples to generate for each input and then to average over for
+        MMSE estimation, by default 1.
+    num_workers : int, optional
+        Number of workers to use for the DataLoader, by default 4.
+
+    Returns
+    -------
+    tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]
+        Tuple containing:
+        - predictions: predicted images for the dataset.
+        - predictions_std: standard deviation of the predicted images.
+        - logvar_arr: log-variance of the predicted images.
+        - losses: reconstruction losses for the predictions.
+        - psnr: PSNR values for the predictions.
     """
     dloader = DataLoader(
         dset,
@@ -521,69 +545,90 @@ def get_dset_predictions(
         shuffle=False,
         batch_size=batch_size,
     )
-    likelihood = model.model.likelihood
+
+    gauss_likelihood = model.gaussian_likelihood
+    nm_likelihood = model.noise_model_likelihood
+
     predictions = []
     predictions_std = []
     losses = []
     logvar_arr = []
-    patch_psnr_channels = [RunningPSNR() for _ in range(dset[0][1].shape[0])]
+    num_channels = dset[0][1].shape[0]
+    patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
     with torch.no_grad():
-        for batch in tqdm(dloader):
-            inp, tar = batch[:2]
+        for batch in tqdm(dloader, desc="Predicting patches"):
+            inp, tar = batch
             inp = inp.cuda()
             tar = tar.cuda()
 
-            recon_img_list = []
+            rec_img_list = []
             for mmse_idx in range(mmse_count):
-                if model_type == ModelType.Denoiser:
-                    assert model.denoise_channel in [
-                        "Ch1",
-                        "Ch2",
-                        "input",
-                    ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
-
-                    x_normalized_new, tar_new = model.get_new_input_target(
-                        (inp, tar, *batch[2:])
+
+                # TODO: case of HDN left for future refactoring
+                # if model_type == ModelType.Denoiser:
+                #     assert model.denoise_channel in [
+                #         "Ch1",
+                #         "Ch2",
+                #         "input",
+                #     ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
+
+                #     x_normalized_new, tar_new = model.get_new_input_target(
+                #         (inp, tar, *batch[2:])
+                #     )
+                #     rec, _ = model(x_normalized_new)
+                #     rec_loss, imgs = model.get_reconstruction_loss(
+                #         rec,
+                #         tar,
+                #         x_normalized_new,
+                #         return_predicted_img=True,
+                #     )
+
+                # get model output
+                rec, _ = model(inp)
+
+                # get reconstructed img
+                if model.model.predict_logvar is None:
+                    rec_img = rec
+                    logvar = torch.tensor([-1])
+                else:
+                    rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+                rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+                logvar_arr.append(logvar.cpu().numpy())
+
+                # compute reconstruction loss
+                if loss_type == "musplit":
+                    rec_loss = get_reconstruction_loss(
+                        reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
                     )
-                    tar_normalized = model.normalize_target(tar_new)
-                    recon_normalized, _ = model(x_normalized_new)
-                    rec_loss, imgs = model.get_reconstruction_loss(
-                        recon_normalized,
-                        tar_normalized,
-                        x_normalized_new,
-                        return_predicted_img=True,
+                elif loss_type == "denoisplit":
+                    rec_loss = get_reconstruction_loss(
+                        reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
                     )
-                else:
-                    x_normalized = model.normalize_input(inp)
-                    tar_normalized = model.normalize_target(tar)
-                    recon_normalized, _ = model(x_normalized)
-                    rec_loss, imgs = model.get_reconstruction_loss(
-                        recon_normalized, tar_normalized, inp, return_predicted_img=True
+                elif loss_type == "denoisplit_musplit":
+                    rec_loss = reconstruction_loss_musplit_denoisplit(
+                        predictions=rec,
+                        targets=tar,
+                        gaussian_likelihood=gauss_likelihood,
+                        nm_likelihood=nm_likelihood,
+                        nm_weight=model.loss_parameters.denoisplit_weight,
+                        gaussian_weight=model.loss_parameters.musplit_weight,
                     )
+                rec_loss = {"loss": rec_loss}  # hacky, but ok for now
 
+                # store rec loss values for first pred
                 if mmse_idx == 0:
-                    q_dic = (
-                        likelihood.distr_params(recon_normalized)
-                        if likelihood is not None
-                        else {"logvar": None}
-                    )
-                    if q_dic["logvar"] is not None:
-                        logvar_arr.append(q_dic["logvar"].cpu().numpy())
-                    else:
-                        logvar_arr.append(np.array([-1]))
-
                     try:
                         losses.append(rec_loss["loss"].cpu().numpy())
                     except:
                         losses.append(rec_loss["loss"])
 
-                for i in range(imgs.shape[1]):
-                    patch_psnr_channels[i].update(imgs[:, i], tar_normalized[:, i])
+                # update running PSNR
+                for i in range(num_channels):
+                    patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
 
-                recon_img_list.append(imgs.cpu()[None])
-
-            samples = torch.cat(recon_img_list, dim=0)
-            mmse_imgs = torch.mean(samples, dim=0)
+            # aggregate results
+            samples = torch.cat(rec_img_list, dim=0)
+            mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
             mmse_std = torch.std(samples, dim=0)
             predictions.append(mmse_imgs.cpu().numpy())
             predictions_std.append(mmse_std.cpu().numpy())
@@ -591,10 +636,10 @@ def get_dset_predictions(
     psnr = [x.get() for x in patch_psnr_channels]
     return (
         np.concatenate(predictions, axis=0),
-        np.array(losses),
+        np.concatenate(predictions_std, axis=0),
         np.concatenate(logvar_arr),
+        np.array(losses),
         psnr,
-        np.concatenate(predictions_std, axis=0),
    )
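Since the return order of `get_dset_predictions` changed (`predictions_std` moved up, `losses` moved down), call sites need updating. A hedged usage sketch, where `model` (a `VAEModule`) and `val_dset` (an evaluation dataset) are assumed to exist already:

```python
import numpy as np

# New return order: predictions, predictions_std, logvar, losses, psnr.
pred, pred_std, logvar, losses, psnr = get_dset_predictions(
    model,
    val_dset,
    batch_size=32,
    loss_type="denoisplit_musplit",
    mmse_count=10,
)
print(pred.shape, pred_std.shape, np.mean(losses), psnr)
```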
diff --git a/src/careamics/models/lvae/likelihoods.py b/src/careamics/models/lvae/likelihoods.py
index 47d1a4ec..58529b30 100644
--- a/src/careamics/models/lvae/likelihoods.py
+++ b/src/careamics/models/lvae/likelihoods.py
@@ -7,6 +7,7 @@
 import math
 from typing import Literal, Union, TYPE_CHECKING, Any, Optional
 
+import numpy as np
 import torch
 from torch import nn
 
@@ -287,30 +288,37 @@ class NoiseModelLikelihood(LikelihoodModule):
 
     def __init__(
         self,
-        data_mean: torch.Tensor,
-        data_std: torch.Tensor,
-        noiseModel: NoiseModel,  # TODO: check the type -> couldn't manage due to circular imports...
+        data_mean: Union[np.ndarray, torch.Tensor],
+        data_std: Union[np.ndarray, torch.Tensor],
+        noiseModel: NoiseModel,
     ):
         """Constructor.
 
         Parameters
         ----------
-        data_mean: torch.Tensor
+        data_mean: Union[np.ndarray, torch.Tensor]
             The mean of the data, used to unnormalize data for noise model evaluation.
-        data_std: torch.Tensor
+        data_std: Union[np.ndarray, torch.Tensor]
             The standard deviation of the data, used to unnormalize data for noise
             model evaluation.
         noiseModel: NoiseModel
             The noise model instance used to compute the likelihood.
         """
         super().__init__()
-        self.data_mean = data_mean
-        self.data_std = data_std
+        self.data_mean = torch.Tensor(data_mean)
+        self.data_std = torch.Tensor(data_std)
         self.noiseModel = noiseModel
 
-    def set_params_to_same_device_as(
+    def _set_params_to_same_device_as(
         self, correct_device_tensor: torch.Tensor
-    ) -> None:  # TODO: needed?
+    ) -> None:
+        """Set the parameters to the same device as the input tensor.
+
+        Parameters
+        ----------
+        correct_device_tensor: torch.Tensor
+            The tensor whose device is used to set the parameters.
+        """
         if self.data_mean.device != correct_device_tensor.device:
             self.data_mean = self.data_mean.to(correct_device_tensor.device)
             self.data_std = self.data_std.to(correct_device_tensor.device)
@@ -355,6 +363,7 @@ def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
         torch.Tensor
             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
         """
+        self._set_params_to_same_device_as(x)
         predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
         x_denormalized = x * self.data_std + self.data_mean
         likelihoods = self.noiseModel.likelihood(
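A minimal sketch of the updated `NoiseModelLikelihood` behavior: numpy statistics are converted to tensors once at construction, and both prediction and observation are denormalized before the noise model is evaluated. `DummyNoiseModel` is a stand-in for illustration, not the real `NoiseModel` API:

```python
import numpy as np
import torch

# Stand-in with the same `likelihood(obs, signal)` contract; NOT the real API.
class DummyNoiseModel:
    def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor:
        return torch.exp(-0.5 * (obs - signal) ** 2)

data_mean, data_std = np.array([100.0]), np.array([10.0])  # np.ndarray now accepted
mean_t, std_t = torch.Tensor(data_mean), torch.Tensor(data_std)  # as in __init__

x = torch.randn(1, 1, 8, 8)          # normalized observation
pred_mean = torch.randn(1, 1, 8, 8)  # normalized predicted mean

# log_likelihood denormalizes both tensors before calling the noise model:
ll = torch.log(
    DummyNoiseModel().likelihood(x * std_t + mean_t, pred_mean * std_t + mean_t)
)
```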
""" + self._set_params_to_same_device_as(x) predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean x_denormalized = x * self.data_std + self.data_mean likelihoods = self.noiseModel.likelihood( diff --git a/src/careamics/models/lvae/lvae.py b/src/careamics/models/lvae/lvae.py index 97bb2941..b44d4594 100644 --- a/src/careamics/models/lvae/lvae.py +++ b/src/careamics/models/lvae/lvae.py @@ -795,11 +795,18 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor # return samples - # def reset_for_different_output_size(self, output_size): - # for i in range(self.n_layers): - # sz = output_size // 2**(1 + i) - # self.bottom_up_layers[i].output_expected_shape = (sz, sz) - # self.top_down_layers[i].latent_shape = (output_size, output_size) + def reset_for_different_output_size(self, output_size: int) -> None: + """Reset shape of output and latent tensors for different output size. + + Used during evaluation to reset expected shapes of tensors when + input/output shape changes. + For instance, it is needed when the model was trained on, say, 64x64 sized + patches, but prediction is done on 128x128 patches. + """ + for i in range(self.n_layers): + sz = output_size // 2 ** (1 + i) + self.bottom_up_layers[i].output_expected_shape = (sz, sz) + self.top_down_layers[i].latent_shape = (output_size, output_size) def pad_input(self, x): """ diff --git a/src/careamics/utils/metrics.py b/src/careamics/utils/metrics.py index 92499598..cc6229fa 100644 --- a/src/careamics/utils/metrics.py +++ b/src/careamics/utils/metrics.py @@ -45,12 +45,12 @@ def _zero_mean(x: np.ndarray) -> np.ndarray: Parameters ---------- - x : NumPy array + x : np.ndarray Input array. Returns ------- - NumPy array + np.ndarray Zero-mean array. """ return x - np.mean(x)