fix: fixing various bugs to enable microSplit training (CAREamics#242)
### Description

This PR fixes a variety of bugs present throughout the LVAE codebase, enabling
training and evaluation of the microSplit family of models. Examples of how to
train the different versions of microSplit are available [in this
repo](https://github.com/CAREamics/microSplit-reproducibility).
Comments have also been added about issues to address in future iterations.
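
For orientation, below is a minimal sketch of what training looks like with the classes touched in this PR. This is a hedged sketch, not the definitive API: the config construction is elided, and `algorithm_config`, `train_dl`, and `val_dl` are placeholders; working recipes live in the reproducibility repo linked above.

```python
# Hedged sketch: VAEModule is a LightningModule (see lightning_module.py below),
# so it can be trained with a standard PyTorch Lightning loop.
from pytorch_lightning import Trainer

from careamics.lightning import VAEModule

model = VAEModule(algorithm_config)  # a VAEAlgorithmConfig built elsewhere
trainer = Trainer(max_epochs=100, accelerator="gpu", devices=1)
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
```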

- **What**: Fixed bugs in several places that interfered with microSplit
training and evaluation. It is now possible to train the model.
- **Why**: Of course we need a way to train and evaluate the models ;)
- **How**: Case by case, depending on the particular bug encountered; see the per-file changes below.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: CatEek <zubarev.ia@gmail.com>
Co-authored-by: melisande-c <milly.croft@gmail.com>
4 people authored Sep 6, 2024
1 parent 1b60f07 commit 26a3dc7
Showing 7 changed files with 152 additions and 85 deletions.
6 changes: 6 additions & 0 deletions src/careamics/lightning/lightning_module.py
@@ -271,6 +271,12 @@ def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None:
         self.noise_model: NoiseModel = noise_model_factory(
             self.algorithm_config.noise_model
         )
+        # TODO: here we can add some code to check whether the noise model is
+        # not None and `self.algorithm_config.noise_model_likelihood_model.noise_model`
+        # is, instead, None. In that case we could assign the noise model to the
+        # latter. This is particularly useful when loading an algorithm config from
+        # file. Indeed, in that case the noise model in the NM likelihood is likely
+        # not available, since it is excluded from serialization.
         self.noise_model_likelihood: NoiseModelLikelihood = likelihood_factory(
             self.algorithm_config.noise_model_likelihood_model
         )
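The check described in that TODO is not implemented in this commit; a possible sketch, using only names visible in the hunk above and meant to live inside the same `__init__` (hypothetical, not part of the PR), could look like this:

```python
# Hypothetical sketch of the TODO above: re-attach the noise model to the
# likelihood config when the latter lost it during (de)serialization.
if (
    self.noise_model is not None
    and self.algorithm_config.noise_model_likelihood_model.noise_model is None
):
    # The likelihood config excludes the noise model from serialization, so
    # re-assign the freshly built one before creating the likelihood object.
    self.algorithm_config.noise_model_likelihood_model.noise_model = self.noise_model
```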
6 changes: 3 additions & 3 deletions src/careamics/losses/loss_factory.py
@@ -56,9 +56,9 @@ class LVAELossParameters:
     reconstruction_weight: float = 1.0
     """Weight for the reconstruction loss in the total net loss
     (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
-    musplit_weight: float = 0.0
-    """Weight for the muSplit loss (used in the muSplit-deonoiSplit loss)."""
-    denoisplit_weight: float = 1.0
+    musplit_weight: float = 0.1
+    """Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
+    denoisplit_weight: float = 0.9
     """Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""
     kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
     """Type of KL divergence used as KL loss."""
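As a reading aid, the two weights combine roughly as follows. This is a simplified sketch based on the `reconstruction_loss_musplit_denoisplit` call shown further down, not the literal implementation; the loss values are placeholders.

```python
# Simplified sketch: the mixed muSplit-denoiSplit reconstruction loss is a
# weighted sum of the Gaussian-likelihood (muSplit) and noise-model
# (denoiSplit) terms.
musplit_loss = 0.42     # placeholder: Gaussian-likelihood reconstruction term
denoisplit_loss = 1.37  # placeholder: noise-model likelihood term
rec_loss = 0.1 * musplit_loss + 0.9 * denoisplit_loss  # new default weights
```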
4 changes: 2 additions & 2 deletions src/careamics/losses/lvae/losses.py
@@ -137,8 +137,8 @@ def reconstruction_loss_musplit_denoisplit(
     recons_loss : torch.Tensor
         The reconstruction loss. Shape is (1, ).
     """
-    # TODO: is this safe to check for predict_logvar value?
-    # otherwise use `gaussian_likelihood.predict_logvar` (or both)
+    # TODO: refactor this function to make it closer to `get_reconstruction_loss`
+    # (or vice versa)
     if predictions.shape[1] == 2 * targets.shape[1]:
         # predictions contain both mean and log-variance
         out_mean, _ = predictions.chunk(2, dim=1)
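The shape check above relies on the convention that, when the model predicts a log-variance, the channel dimension of the output doubles. A toy illustration with assumed shapes (not from the repo):

```python
# When predict_logvar is enabled, predictions stack mean and log-variance
# along the channel axis, so C_pred == 2 * C_target.
import torch

targets = torch.randn(4, 2, 64, 64)      # (B, C, Y, X)
predictions = torch.randn(4, 4, 64, 64)  # (B, 2C, Y, X): mean + log-variance
assert predictions.shape[1] == 2 * targets.shape[1]
mean, logvar = predictions.chunk(2, dim=1)  # each is (B, C, Y, X)
```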
173 changes: 109 additions & 64 deletions src/careamics/lvae_training/eval_utils.py
@@ -14,13 +14,19 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
+from torch import nn
+from torch.utils.data import Dataset
 from matplotlib.gridspec import GridSpec
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from careamics.models.lvae.utils import ModelType
-
-from .metrics import RangeInvariantPsnr, RunningPSNR
+from careamics.lightning import VAEModule
+from careamics.losses.lvae.losses import (
+    get_reconstruction_loss,
+    reconstruction_loss_musplit_denoisplit,
+)
+
+from careamics.utils.metrics import scale_invariant_psnr, RunningPSNR
 
 
 # ------------------------------------------------------------------------------------------------
@@ -40,28 +40,26 @@ def clean_ax(ax):
     ax.tick_params(left=False, right=False, top=False, bottom=False)
 
 
-def get_plots_output_dir(
+def get_eval_output_dir(
     saveplotsdir: str, patch_size: int, mmse_count: int = 50
 ) -> str:
     """
     Given the path to a root directory to save plots, patch size, and mmse count,
     it returns the specific directory to save the plots.
     """
-    plotsrootdir = os.path.join(
-        saveplotsdir, f"plots/patch_{patch_size}_mmse_{mmse_count}"
+    eval_out_dir = os.path.join(
+        saveplotsdir, f"eval_outputs/patch_{patch_size}_mmse_{mmse_count}"
     )
-    os.makedirs(plotsrootdir, exist_ok=True)
-    print(plotsrootdir)
-    return plotsrootdir
+    os.makedirs(eval_out_dir, exist_ok=True)
+    print(eval_out_dir)
+    return eval_out_dir
 
 
 def get_psnr_str(tar_hsnr, pred, col_idx):
     """
     Compute PSNR between the ground truth (`tar_hsnr`) and the predicted image (`pred`).
     """
-    return (
-        f"{RangeInvariantPsnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
-    )
+    return f"{scale_invariant_psnr(tar_hsnr[col_idx][None], pred[col_idx][None]).item():.1f}"
 
 
 def add_psnr_str(ax_, psnr):
@@ -499,20 +503,40 @@ def get_predictions(idx, val_dset, model, mmse_count=50, patch_size=256):
 
 
 def get_dset_predictions(
-    model,
-    dset,
+    model: VAEModule,
+    dset: Dataset,
     batch_size: int,
-    model_type: ModelType = None,
+    loss_type: Literal["musplit", "denoisplit", "denoisplit_musplit"],
     mmse_count: int = 1,
     num_workers: int = 4,
-):
-    """
-    Get predictions from a model for the entire dataset.
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]:
+    """Get patch-wise predictions from a model for the entire dataset.
 
     Parameters
     ----------
-    mmse_count : int
-        Number of samples to generate for each input and then to average over for MMSE estimation.
+    model : VAEModule
+        Lightning model used for prediction.
+    dset : Dataset
+        Dataset to predict on.
+    batch_size : int
+        Batch size to use for prediction.
+    loss_type : Literal["musplit", "denoisplit", "denoisplit_musplit"]
+        Type of reconstruction loss used by the model.
+    mmse_count : int, optional
+        Number of samples to generate for each input and then to average over
+        for MMSE estimation, by default 1.
+    num_workers : int, optional
+        Number of workers to use for the DataLoader, by default 4.
+
+    Returns
+    -------
+    tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float]]
+        Tuple containing:
+        - predictions: predicted images for the dataset.
+        - predictions_std: standard deviation of the predicted images.
+        - logvar_arr: log-variance of the predicted images.
+        - losses: reconstruction losses for the predictions.
+        - psnr: PSNR values for the predictions.
     """
     dloader = DataLoader(
         dset,
@@ -521,80 +545,101 @@
         shuffle=False,
         batch_size=batch_size,
     )
-    likelihood = model.model.likelihood
+    gauss_likelihood = model.gaussian_likelihood
+    nm_likelihood = model.noise_model_likelihood
+
     predictions = []
     predictions_std = []
     losses = []
     logvar_arr = []
-    patch_psnr_channels = [RunningPSNR() for _ in range(dset[0][1].shape[0])]
+    num_channels = dset[0][1].shape[0]
+    patch_psnr_channels = [RunningPSNR() for _ in range(num_channels)]
     with torch.no_grad():
-        for batch in tqdm(dloader):
-            inp, tar = batch[:2]
+        for batch in tqdm(dloader, desc="Predicting patches"):
+            inp, tar = batch
             inp = inp.cuda()
             tar = tar.cuda()
 
-            recon_img_list = []
+            rec_img_list = []
             for mmse_idx in range(mmse_count):
-                if model_type == ModelType.Denoiser:
-                    assert model.denoise_channel in [
-                        "Ch1",
-                        "Ch2",
-                        "input",
-                    ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
-
-                    x_normalized_new, tar_new = model.get_new_input_target(
-                        (inp, tar, *batch[2:])
-                    )
-                    tar_normalized = model.normalize_target(tar_new)
-                    recon_normalized, _ = model(x_normalized_new)
-                    rec_loss, imgs = model.get_reconstruction_loss(
-                        recon_normalized,
-                        tar_normalized,
-                        x_normalized_new,
-                        return_predicted_img=True,
-                    )
-                else:
-                    x_normalized = model.normalize_input(inp)
-                    tar_normalized = model.normalize_target(tar)
-                    recon_normalized, _ = model(x_normalized)
-                    rec_loss, imgs = model.get_reconstruction_loss(
-                        recon_normalized, tar_normalized, inp, return_predicted_img=True
-                    )
+                # TODO: case of HDN left for future refactoring
+                # if model_type == ModelType.Denoiser:
+                #     assert model.denoise_channel in [
+                #         "Ch1",
+                #         "Ch2",
+                #         "input",
+                #     ], '"all" denoise channel not supported for evaluation. Pick one of "Ch1", "Ch2", "input"'
+                #     x_normalized_new, tar_new = model.get_new_input_target(
+                #         (inp, tar, *batch[2:])
+                #     )
+                #     rec, _ = model(x_normalized_new)
+                #     rec_loss, imgs = model.get_reconstruction_loss(
+                #         rec,
+                #         tar,
+                #         x_normalized_new,
+                #         return_predicted_img=True,
+                #     )
+
+                # get model output
+                rec, _ = model(inp)
+
+                # get reconstructed img
+                if model.model.predict_logvar is None:
+                    rec_img = rec
+                    logvar = torch.tensor([-1])
+                else:
+                    rec_img, logvar = torch.chunk(rec, chunks=2, dim=1)
+                rec_img_list.append(rec_img.cpu().unsqueeze(0))  # add MMSE dim
+                logvar_arr.append(logvar.cpu().numpy())
+
+                # compute reconstruction loss
+                if loss_type == "musplit":
+                    rec_loss = get_reconstruction_loss(
+                        reconstruction=rec, target=tar, likelihood_obj=gauss_likelihood
+                    )
+                elif loss_type == "denoisplit":
+                    rec_loss = get_reconstruction_loss(
+                        reconstruction=rec, target=tar, likelihood_obj=nm_likelihood
+                    )
+                elif loss_type == "denoisplit_musplit":
+                    rec_loss = reconstruction_loss_musplit_denoisplit(
+                        predictions=rec,
+                        targets=tar,
+                        gaussian_likelihood=gauss_likelihood,
+                        nm_likelihood=nm_likelihood,
+                        nm_weight=model.loss_parameters.denoisplit_weight,
+                        gaussian_weight=model.loss_parameters.musplit_weight,
+                    )
+                rec_loss = {"loss": rec_loss}  # hacky, but ok for now
 
                 # store rec loss values for first pred
                 if mmse_idx == 0:
-                    q_dic = (
-                        likelihood.distr_params(recon_normalized)
-                        if likelihood is not None
-                        else {"logvar": None}
-                    )
-                    if q_dic["logvar"] is not None:
-                        logvar_arr.append(q_dic["logvar"].cpu().numpy())
-                    else:
-                        logvar_arr.append(np.array([-1]))
-
-                try:
-                    losses.append(rec_loss["loss"].cpu().numpy())
-                except:
-                    losses.append(rec_loss["loss"])
+                    try:
+                        losses.append(rec_loss["loss"].cpu().numpy())
+                    except:
+                        losses.append(rec_loss["loss"])
 
-                for i in range(imgs.shape[1]):
-                    patch_psnr_channels[i].update(imgs[:, i], tar_normalized[:, i])
-
-                recon_img_list.append(imgs.cpu()[None])
+            # update running PSNR
+            for i in range(num_channels):
+                patch_psnr_channels[i].update(rec_img[:, i], tar[:, i])
 
-            samples = torch.cat(recon_img_list, dim=0)
-            mmse_imgs = torch.mean(samples, dim=0)
+            # aggregate results
+            samples = torch.cat(rec_img_list, dim=0)
+            mmse_imgs = torch.mean(samples, dim=0)  # avg over MMSE dim
             mmse_std = torch.std(samples, dim=0)
             predictions.append(mmse_imgs.cpu().numpy())
             predictions_std.append(mmse_std.cpu().numpy())
 
     psnr = [x.get() for x in patch_psnr_channels]
     return (
         np.concatenate(predictions, axis=0),
-        np.array(losses),
+        np.concatenate(predictions_std, axis=0),
         np.concatenate(logvar_arr),
+        np.array(losses),
         psnr,
-        np.concatenate(predictions_std, axis=0),
     )
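A hedged usage sketch for the updated function, assuming a trained `model` and a matching `val_dset` exist in the calling script; the unpacking order follows the Returns section above.

```python
# Hypothetical call, not from the commit: patch-wise MMSE prediction with the
# mixed denoiSplit-muSplit loss, averaging 10 samples per input patch.
preds, preds_std, logvar, losses, psnr = get_dset_predictions(
    model=model,
    dset=val_dset,
    batch_size=32,
    loss_type="denoisplit_musplit",
    mmse_count=10,
)
```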
27 changes: 18 additions & 9 deletions src/careamics/models/lvae/likelihoods.py
@@ -7,6 +7,7 @@
 import math
 from typing import Literal, Union, TYPE_CHECKING, Any, Optional
 
+import numpy as np
 import torch
 from torch import nn
 
@@ -287,30 +288,37 @@ class NoiseModelLikelihood(LikelihoodModule):

     def __init__(
         self,
-        data_mean: torch.Tensor,
-        data_std: torch.Tensor,
-        noiseModel: NoiseModel,  # TODO: check the type -> couldn't manage due to circular imports...
+        data_mean: Union[np.ndarray, torch.Tensor],
+        data_std: Union[np.ndarray, torch.Tensor],
+        noiseModel: NoiseModel,
     ):
         """Constructor.
 
         Parameters
         ----------
-        data_mean: torch.Tensor
+        data_mean: Union[np.ndarray, torch.Tensor]
             The mean of the data, used to unnormalize data for noise model evaluation.
-        data_std: torch.Tensor
+        data_std: Union[np.ndarray, torch.Tensor]
             The standard deviation of the data, used to unnormalize data for noise
             model evaluation.
         noiseModel: NoiseModel
             The noise model instance used to compute the likelihood.
         """
         super().__init__()
-        self.data_mean = data_mean
-        self.data_std = data_std
+        self.data_mean = torch.Tensor(data_mean)
+        self.data_std = torch.Tensor(data_std)
         self.noiseModel = noiseModel
 
-    def set_params_to_same_device_as(
+    def _set_params_to_same_device_as(
         self, correct_device_tensor: torch.Tensor
-    ) -> None:  # TODO: needed?
+    ) -> None:
+        """Set the parameters to the same device as the input tensor.
+
+        Parameters
+        ----------
+        correct_device_tensor: torch.Tensor
+            The tensor whose device is used to set the parameters.
+        """
         if self.data_mean.device != correct_device_tensor.device:
             self.data_mean = self.data_mean.to(correct_device_tensor.device)
             self.data_std = self.data_std.to(correct_device_tensor.device)
@@ -355,6 +363,7 @@ def log_likelihood(self, x: torch.Tensor, params: dict[str, torch.Tensor]):
         torch.Tensor
             The log-likelihood tensor. Shape is (B, C, [Z], Y, X).
         """
+        self._set_params_to_same_device_as(x)
         predicted_s_denormalized = params["mean"] * self.data_std + self.data_mean
         x_denormalized = x * self.data_std + self.data_mean
         likelihoods = self.noiseModel.likelihood(
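For context, a minimal sketch of why the device sync matters (assumed setup, not from the commit): the mean/std buffers start on CPU at construction time, while inputs may arrive on GPU.

```python
# Without the sync, `params["mean"] * self.data_std` would raise a device
# mismatch error when x lives on CUDA but data_std still lives on CPU.
import torch

data_std = torch.tensor([2.0])  # created on CPU at construction time
x = torch.randn(1, 1, 8, 8)
if torch.cuda.is_available():
    x = x.cuda()
if data_std.device != x.device:
    data_std = data_std.to(x.device)  # same move the method performs
out = x * data_std  # now safe on either device
```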
17 changes: 12 additions & 5 deletions src/careamics/models/lvae/lvae.py
@@ -795,11 +795,18 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:

         # return samples
 
-    # def reset_for_different_output_size(self, output_size):
-    #     for i in range(self.n_layers):
-    #         sz = output_size // 2**(1 + i)
-    #         self.bottom_up_layers[i].output_expected_shape = (sz, sz)
-    #         self.top_down_layers[i].latent_shape = (output_size, output_size)
+    def reset_for_different_output_size(self, output_size: int) -> None:
+        """Reset the expected output and latent tensor shapes for a new output size.
+
+        Used during evaluation to reset the expected tensor shapes when the
+        input/output shape changes. For instance, this is needed when the model
+        was trained on, say, 64x64 patches, but prediction is done on 128x128
+        patches.
+        """
+        for i in range(self.n_layers):
+            sz = output_size // 2 ** (1 + i)
+            self.bottom_up_layers[i].output_expected_shape = (sz, sz)
+            self.top_down_layers[i].latent_shape = (output_size, output_size)
 
     def pad_input(self, x):
         """
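A quick usage sketch, with hypothetical sizes but grounded in the loop body above:

```python
# Model trained on 64x64 patches, now evaluated on 128x128 patches.
model.reset_for_different_output_size(output_size=128)
# For n_layers = 3 this sets output_expected_shape to (64, 64), (32, 32),
# and (16, 16) bottom-up, and every top-down latent_shape to (128, 128).
```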
4 changes: 2 additions & 2 deletions src/careamics/utils/metrics.py
@@ -45,12 +45,12 @@ def _zero_mean(x: np.ndarray) -> np.ndarray:
     Parameters
     ----------
-    x : NumPy array
+    x : np.ndarray
         Input array.
 
     Returns
     -------
-    NumPy array
+    np.ndarray
         Zero-mean array.
     """
     return x - np.mean(x)
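A one-line check of the behavior (illustrative only):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
assert np.allclose(x - np.mean(x), [-1.0, 0.0, 1.0])  # zero-mean shift
```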
