refac: updated calibration code + moved it into a separate script
federico-carrara committed Nov 27, 2024
1 parent f0fcc89 commit 0b1960e
Showing 2 changed files with 175 additions and 181 deletions.
172 changes: 172 additions & 0 deletions src/careamics/lvae_training/calibration.py
@@ -0,0 +1,172 @@
from typing import Union

import numpy as np
import torch
from scipy import stats


def get_last_index(bin_count, quantile):
    """Return the smallest number of trailing bins to drop so that the cumulative
    fraction of the remaining bins falls below `quantile` (trims sparse upper bins)."""
    cumsum = np.cumsum(bin_count)
    normalized_cumsum = cumsum / cumsum[-1]
    for i in range(1, len(normalized_cumsum)):
        if normalized_cumsum[-i] < quantile:
            return i - 1
    return None


def get_first_index(bin_count, quantile):
    """Return the index of the first bin at which the cumulative fraction of counts
    exceeds `quantile` (trims sparse lower bins)."""
    cumsum = np.cumsum(bin_count)
    normalized_cumsum = cumsum / cumsum[-1]
    for i in range(len(normalized_cumsum)):
        if normalized_cumsum[i] > quantile:
            return i
    return None
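# Illustrative example (assumed values): with bin_count = [1, 10, 50, 30, 5, 1],
# get_first_index(bin_count, 0.001) returns 0 and get_last_index(bin_count, 0.999) returns 1,
# so slicing rmv[0:-1] / rmse[0:-1] keeps all but the sparsely populated last bin.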


class Calibration:
    """Compute per-channel calibration statistics (bin-wise RMSE vs. RMV) for predicted std values."""

    def __init__(self, num_bins: int = 15):
        self._bins = num_bins
        self._bin_boundaries = None

    def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
        """Convert log-variance to standard deviation."""
        return np.exp(logvar / 2)

    def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
        """Compute the bin boundaries for `num_bins` bins over the predicted std values."""
        min_std = np.min(predict_std)
        max_std = np.max(predict_std)
        return np.linspace(min_std, max_std, self._bins + 1)

    def compute_stats(
        self, pred: np.ndarray, pred_std: np.ndarray, target: np.ndarray
    ) -> dict[int, dict[str, Union[np.ndarray, list]]]:
"""
It computes the bin-wise RMSE and RMV for each channel of the predicted image.
Recall that:
- RMSE = np.sqrt((pred - target)**2 / num_pixels)
- RMV = np.sqrt(np.mean(pred_std**2))
ALGORITHM
- For each channel:
- Given the bin boundaries, assign pixels of `std_ch` array to a specific bin index.
- For each bin index:
- Compute the RMSE, RMV, and number of pixels for that bin.
NOTE: each channel of the predicted image/logvar has its own stats.
Parameters
----------
pred: np.ndarray
Predicted patches, shape (n, h, w, c).
pred_std: np.ndarray
Std computed over the predicted patches, shape (n, h, w, c).
target: np.ndarray
Target GT image, shape (n, h, w, c).
"""
        self._bin_boundaries = {}
        stats_dict = {}
        for ch_idx in range(pred.shape[-1]):
            stats_dict[ch_idx] = {
                'bin_count': [],
                'rmv': [],
                'rmse': [],
                'bin_boundaries': None,
                'bin_matrix': [],
                'rmse_err': [],
            }
            pred_ch = pred[..., ch_idx]
            std_ch = pred_std[..., ch_idx]
            target_ch = target[..., ch_idx]
            # Assign each pixel to a bin according to its predicted std.
            boundaries = self.compute_bin_boundaries(std_ch)
            stats_dict[ch_idx]['bin_boundaries'] = boundaries
            bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
            bin_matrix = bin_matrix.reshape(std_ch.shape)
            stats_dict[ch_idx]['bin_matrix'] = bin_matrix
            error = (pred_ch - target_ch) ** 2
            for bin_idx in range(1, 1 + self._bins):
                bin_mask = bin_matrix == bin_idx
                bin_error = error[bin_mask]
                bin_size = np.sum(bin_mask)
                # Bin-wise RMSE and its (approximate) standard error; empty bins yield None.
                bin_error = np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
                stderr = np.std(error[bin_mask]) / np.sqrt(bin_size) if bin_size > 0 else None
                rmse_stderr = np.sqrt(stderr) if stderr is not None else None
                # Bin-wise RMV (root mean predicted variance).
                bin_var = np.mean(std_ch[bin_mask] ** 2) if bin_size > 0 else None
                stats_dict[ch_idx]['rmse'].append(bin_error)
                stats_dict[ch_idx]['rmse_err'].append(rmse_stderr)
                stats_dict[ch_idx]['rmv'].append(np.sqrt(bin_var) if bin_var is not None else None)
                stats_dict[ch_idx]['bin_count'].append(bin_size)
        return stats_dict


def get_calibrated_factor_for_stdev(
    pred: Union[np.ndarray, torch.Tensor],
    pred_std: Union[np.ndarray, torch.Tensor],
    target: Union[np.ndarray, torch.Tensor],
    q_s: float = 0.00001,
    q_e: float = 0.99999,
    num_bins: int = 30,
) -> dict[int, dict[str, float]]:
    """Calibrate the uncertainty by fitting a linear map (slope and offset) from the
    predicted std (RMV) to the observed error (RMSE).

    Parameters
    ----------
    pred : Union[np.ndarray, torch.Tensor]
        Predicted image, shape (n, h, w, c).
    pred_std : Union[np.ndarray, torch.Tensor]
        Predicted std, shape (n, h, w, c).
    target : Union[np.ndarray, torch.Tensor]
        Target image, shape (n, h, w, c).
    q_s : float, optional
        Start quantile, by default 0.00001.
    q_e : float, optional
        End quantile, by default 0.99999.
    num_bins : int, optional
        Number of bins to use for calibration, by default 30.

    Returns
    -------
    dict[int, dict[str, float]]
        For each channel, the calibration factors: 'scalar' (slope) and 'offset' (intercept).
    """
    calib = Calibration(num_bins=num_bins)
    stats_dict = calib.compute_stats(pred, pred_std, target)
    outputs = {}
    for ch_idx in stats_dict.keys():
        y = stats_dict[ch_idx]['rmse']
        x = stats_dict[ch_idx]['rmv']
        count = stats_dict[ch_idx]['bin_count']

        # Drop sparsely populated bins at both ends before fitting.
        first_idx = get_first_index(count, q_s)
        last_idx = get_last_index(count, q_e)
        x = x[first_idx:-last_idx]
        y = y[first_idx:-last_idx]
        slope, intercept, *_ = stats.linregress(x, y)
        outputs[ch_idx] = {'scalar': slope, 'offset': intercept}
    return outputs


def plot_calibration(ax, calibration_stats):
    """Plot the bin-wise RMSE against the RMV for the first two channels."""
    first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
    last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
    ax.plot(
        calibration_stats[0]["rmv"][first_idx:-last_idx],
        calibration_stats[0]["rmse"][first_idx:-last_idx],
        "o",
        label=r"$\hat{C}_0$: Ch1",
    )

    first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
    last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
    ax.plot(
        calibration_stats[1]["rmv"][first_idx:-last_idx],
        calibration_stats[1]["rmse"][first_idx:-last_idx],
        "o",
        label=r"$\hat{C}_1$: Ch2",
    )

    ax.set_xlabel("RMV")
    ax.set_ylabel("RMSE")
    ax.legend()
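A minimal usage sketch for the new module (not part of the commit; the synthetic data, shapes, and plotting setup are assumptions based on the docstrings above):

import matplotlib.pyplot as plt
import numpy as np

from careamics.lvae_training.calibration import (
    Calibration,
    get_calibrated_factor_for_stdev,
    plot_calibration,
)

# Synthetic example: 8 patches of 64x64 pixels with 2 channels,
# with heteroscedastic noise that matches the predicted std.
rng = np.random.default_rng(0)
target = rng.normal(size=(8, 64, 64, 2))
pred_std = rng.uniform(0.05, 0.2, size=target.shape)
pred = target + rng.normal(size=target.shape) * pred_std

# Bin-wise RMSE/RMV statistics per channel.
calib = Calibration(num_bins=15)
stats_dict = calib.compute_stats(pred=pred, pred_std=pred_std, target=target)

# Per-channel linear fit of RMSE against RMV, e.g. {0: {'scalar': ..., 'offset': ...}, 1: {...}}.
factors = get_calibrated_factor_for_stdev(pred, pred_std, target, num_bins=30)
print(factors)

# Reliability plot (plot_calibration expects stats for at least two channels).
fig, ax = plt.subplots()
plot_calibration(ax, stats_dict)
plt.show()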
184 changes: 3 additions & 181 deletions src/careamics/lvae_training/eval_utils.py
@@ -5,14 +5,13 @@
- quantify the performance of the model
- create plots to visualize the results.
"""

import math
import os
from typing import Dict, List, Literal, Union
from typing import List, Literal, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import torch
from torch import nn
from torch.utils.data import Dataset
@@ -859,181 +858,4 @@ def stitch_predictions_new(predictions, dset):
else:
raise ValueError(f"Unsupported shape {output.shape}")

return output


# ------------------------------------------------------------------------------------------


# ------------------------------------------------------------------------------------------
### Classes and Functions used for Calibration
class Calibration:

def __init__(
self, num_bins: int = 15, mode: Literal["pixelwise", "patchwise"] = "pixelwise"
):
self._bins = num_bins
self._bin_boundaries = None
self._mode = mode
assert mode in ["pixelwise", "patchwise"]
self._boundary_mode = "uniform"
assert self._boundary_mode in ["quantile", "uniform"]
# self._bin_boundaries = {}

def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
return np.exp(logvar / 2)

def compute_bin_boundaries(self, predict_logvar: np.ndarray) -> np.ndarray:
"""
Compute the bin boundaries for `num_bins` bins and the given logvar values.
"""
if self._boundary_mode == "quantile":
boundaries = np.quantile(
self.logvar_to_std(predict_logvar), np.linspace(0, 1, self._bins + 1)
)
return boundaries
else:
min_logvar = np.min(predict_logvar)
max_logvar = np.max(predict_logvar)
min_std = self.logvar_to_std(min_logvar)
max_std = self.logvar_to_std(max_logvar)
return np.linspace(min_std, max_std, self._bins + 1)

def compute_stats(
self, pred: np.ndarray, pred_logvar: np.ndarray, target: np.ndarray
) -> Dict[int, Dict[str, Union[np.ndarray, List]]]:
"""
It computes the bin-wise RMSE and RMV for each channel of the predicted image.
Recall that:
- RMSE = np.sqrt((pred - target)**2 / num_pixels)
- RMV = np.sqrt(np.mean(pred_std**2))
ALGORITHM
- For each channel:
- Given the bin boundaries, assign pixels of `std_ch` array to a specific bin index.
- For each bin index:
- Compute the RMSE, RMV, and number of pixels for that bin.
NOTE: each channel of the predicted image/logvar has its own stats.
Args:
pred: np.ndarray, shape (n, h, w, c)
pred_logvar: np.ndarray, shape (n, h, w, c)
target: np.ndarray, shape (n, h, w, c)
"""
self._bin_boundaries = {}
stats = {}
for ch_idx in range(pred.shape[-1]):
stats[ch_idx] = {
"bin_count": [],
"rmv": [],
"rmse": [],
"bin_boundaries": None,
"bin_matrix": [],
}
pred_ch = pred[..., ch_idx]
logvar_ch = pred_logvar[..., ch_idx]
std_ch = self.logvar_to_std(logvar_ch)
target_ch = target[..., ch_idx]
if self._mode == "pixelwise":
boundaries = self.compute_bin_boundaries(logvar_ch)
stats[ch_idx]["bin_boundaries"] = boundaries
bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
bin_matrix = bin_matrix.reshape(std_ch.shape)
stats[ch_idx]["bin_matrix"] = bin_matrix
error = (pred_ch - target_ch) ** 2
for bin_idx in range(self._bins):
bin_mask = bin_matrix == bin_idx
bin_error = error[bin_mask]
bin_size = np.sum(bin_mask)
bin_error = (
np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
) # RMSE
bin_var = np.sqrt(np.mean(std_ch[bin_mask] ** 2)) # RMV
stats[ch_idx]["rmse"].append(bin_error)
stats[ch_idx]["rmv"].append(bin_var)
stats[ch_idx]["bin_count"].append(bin_size)
else:
raise NotImplementedError("Patchwise mode is not implemented yet.")
return stats


def nll(x, mean, logvar):
"""
Log of the probability density of the values x under the Normal
distribution with parameters mean and logvar.
:param x: tensor of points, with shape (batch, channels, dim1, dim2)
:param mean: tensor with mean of distribution, shape
(batch, channels, dim1, dim2)
:param logvar: tensor with log-variance of distribution, shape has to be
either scalar or broadcastable
"""
var = torch.exp(logvar)
log_prob = -0.5 * (
((x - mean) ** 2) / var + logvar + torch.tensor(2 * math.pi).log()
)
nll = -log_prob
return nll


def get_calibrated_factor_for_stdev(
pred: Union[np.ndarray, torch.Tensor],
pred_logvar: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
batch_size: int = 32,
epochs: int = 500,
lr: float = 0.01,
):
"""
Here, we calibrate the uncertainty by multiplying the predicted std (mmse estimate or predicted logvar) with a scalar.
We return the calibrated scalar. This needs to be multiplied with the std.
NOTE: Why is the input logvar and not std? because the model typically predicts logvar and not std.
"""
# create a learnable scalar
scalar = torch.nn.Parameter(torch.tensor(2.0))
optimizer = torch.optim.Adam([scalar], lr=lr)

bar = tqdm(range(epochs))
for _ in bar:
optimizer.zero_grad()
# Select a random batch of predictions
mask = np.random.randint(0, pred.shape[0], batch_size)
pred_batch = torch.Tensor(pred[mask]).cuda()
pred_logvar_batch = torch.Tensor(pred_logvar[mask]).cuda()
target_batch = torch.Tensor(target[mask]).cuda()

loss = torch.mean(
nll(target_batch, pred_batch, pred_logvar_batch + torch.log(scalar))
)
loss.backward()
optimizer.step()
bar.set_description(f"nll: {loss.item()} scalar: {scalar.item()}")

return np.sqrt(scalar.item())


def plot_calibration(ax, calibration_stats):
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
ax.plot(
calibration_stats[0]["rmv"][first_idx:-last_idx],
calibration_stats[0]["rmse"][first_idx:-last_idx],
"o",
label=r"$\hat{C}_0$: Ch1",
)

first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
ax.plot(
calibration_stats[1]["rmv"][first_idx:-last_idx],
calibration_stats[1]["rmse"][first_idx:-last_idx],
"o",
label=r"$\hat{C}_1: : Ch2$",
)

ax.set_xlabel("RMV")
ax.set_ylabel("RMSE")
ax.legend()
return output
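For reference, the removed eval_utils implementation calibrated the std with a single NLL-optimized scalar (returned as np.sqrt(scalar), to be multiplied with the std), whereas the new calibration.py fits a per-channel linear map from RMV to RMSE. How the fitted factors are applied downstream is not shown in this commit; a plausible use, continuing the sketch above and assuming the linear map is applied pixel-wise, would be:

# Hypothetical application of the fitted factors (not part of this commit):
# rescale the predicted std so that its RMV tracks the observed RMSE.
ch_idx = 0
calibrated_std = factors[ch_idx]["scalar"] * pred_std[..., ch_idx] + factors[ch_idx]["offset"]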
