Skip to content

Commit

Permalink
Merge pull request #29 from ornlneutronimaging/troubleshoot_ms_denoise
Browse files Browse the repository at this point in the history
Improve multi-scale vertical streak removal quality
  • Loading branch information
KedoKudo authored Aug 28, 2024
2 parents a31cfee + f25f1db commit 1eb54f2
Show file tree
Hide file tree
Showing 9 changed files with 393 additions and 253 deletions.
6 changes: 3 additions & 3 deletions notebooks/demo_denoise_mode.ipynb

Large diffs are not rendered by default.

36 changes: 18 additions & 18 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

118 changes: 88 additions & 30 deletions notebooks/example_ms.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions src/bm3dornl/block_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def get_signal_patch_positions(
# Note: raise error when couldn't find a single signal patch from the entire
# sinogram, which usually indicating a bad background estimation.
if len(signal_patches) == 0:
raise ValueError(
"Couldn't find any signal patches in the image! Please check the background threshold."
)
raise ValueError("Couldn't find any signal patches in the image!")

return np.array(signal_patches)

Expand Down
209 changes: 166 additions & 43 deletions src/bm3dornl/bm3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import logging
import numpy as np
import cupy as cp
from typing import Tuple, Callable
from scipy.signal import medfilt2d
from .block_matching import (
get_signal_patch_positions,
get_patch_numba,
Expand Down Expand Up @@ -32,8 +32,8 @@
hadamard_transform,
)
from .utils import (
horizontal_binning,
horizontal_debinning,
downscale_2d_horizontal,
upscale_2d_horizontal,
)

# NOTE: These default parameters are based on the parameter tuning study.
Expand Down Expand Up @@ -295,6 +295,66 @@ def global_fourier_thresholding(
return new_noisy_image


def padded_piecewise_weighted_denoising(
sinogram: np.ndarray,
window_size: int = 50,
step_size: int = 10,
pad_size: int = None,
) -> np.ndarray:
"""
Perform piecewise weighted denoising on a sinogram with padding using CuPy for GPU acceleration.
Parameters
----------
sinogram : np.ndarray
Input sinogram to be denoised.
window_size : int, optional
Size of the window used for piecewise denoising, by default 50.
step_size : int, optional
Step size for moving the window, by default 10.
pad_size : int, optional
Padding size added to the sinogram before denoising. If None, pad_size is set to 2 * window_size.
Returns
-------
np.ndarray
Denoised sinogram after piecewise weighted denoising.
"""
if pad_size is None:
pad_size = window_size * 2

# Move data to GPU
sinogram_gpu = cp.asarray(sinogram)

# Pad the sinogram
padded_sinogram = cp.pad(sinogram_gpu, ((pad_size, pad_size), (0, 0)), mode="wrap")

rows, _ = padded_sinogram.shape
new_img = cp.zeros_like(padded_sinogram, dtype=cp.float32)
new_wgt = cp.zeros_like(padded_sinogram, dtype=cp.float32)

for i in range(0, rows, step_size):
end = min(i + window_size, rows)
start = max(0, end - window_size)
window = padded_sinogram[start:end, :]
median = cp.median(window, axis=0)
new_img[start:end, :] += window - median
new_wgt[start:end, :] += 1

# Avoid division by zero
new_wgt = cp.maximum(new_wgt, 1e-10)
denoised = new_img / new_wgt

# Restore the overall intensity level
denoised += cp.median(sinogram_gpu)

# Remove padding
denoised = denoised[pad_size:-pad_size, :]

# Move result back to CPU
return cp.asnumpy(denoised)


def estimate_noise_free_sinogram(sinogram: np.ndarray) -> np.ndarray:
"""
Estimate noise-free sinogram from noisy sinogram.
Expand All @@ -309,14 +369,18 @@ def estimate_noise_free_sinogram(sinogram: np.ndarray) -> np.ndarray:
np.ndarray
Noise-free sinogram.
"""
# subtract column-wise median
sinogram = sinogram - np.median(sinogram, axis=0)
# perform median filtering to remove salt-and-pepper noise
sinogram = medfilt2d(sinogram, kernel_size=3)
# rescale to [0, 1]
sinogram -= sinogram.min()
sinogram /= sinogram.max()
return sinogram
# use piecewise weighted denoising
window_size = sinogram.shape[0] // 4
step_size = 1
denoised = padded_piecewise_weighted_denoising(
sinogram, window_size=window_size, step_size=step_size
)

# normalize to [0, 1]
denoised -= np.min(denoised)
denoised /= np.max(denoised)

return denoised


def bm3d_full(
Expand Down Expand Up @@ -654,14 +718,62 @@ def bm3d_ring_artifact_removal(
raise ValueError(f"Unknown mode: {mode}")


def get_scale_adjusted_blockmatching_params(
original_params: dict, scale_factor: int
) -> dict:
"""Scale the parameters based on the given factor.
Parameters
----------
original_params : dict
The original parameters.
scale_factor : int
The scale factor.
Returns
-------
dict
The adjusted parameters.
"""
adjusted_params = original_params.copy()

# Adjust patch size
# minimum patch size is 3x3
adjusted_params["patch_size"] = tuple(
max(3, int(x * scale_factor)) for x in original_params["patch_size"]
)

# Adjust stride
# minimum stride is 1
adjusted_params["stride"] = max(1, int(original_params["stride"] * scale_factor))

# Adjust cut-off distance
# minimum cut-off distance is 8
adjusted_params["cut_off_distance"] = tuple(
max(8, int(x / scale_factor)) for x in original_params["cut_off_distance"]
)

# Optionally adjust number of patches per group
if scale_factor > 1:
adjusted_params["num_patches_per_group"] = max(
16, original_params["num_patches_per_group"] // scale_factor
)

return adjusted_params


def bm3d_ring_artifact_removal_ms(
sinogram: np.ndarray,
k: int = 3,
mode: str = "simple", # express, simple, full
block_matching_kwargs: dict = default_block_matching_kwargs,
filter_kwargs: dict = default_filter_kwargs,
use_iterative_refinement: bool = True,
refinement_iterations: int = 3,
scale_factor_base: int = 2,
) -> np.ndarray:
"""Multiscale BM3D for streak removal
"""
Multiscale BM3D for streak removal
Parameters
----------
Expand All @@ -675,6 +787,12 @@ def bm3d_ring_artifact_removal_ms(
The block matching parameters.
filter_kwargs : dict
The filter parameters.
use_iterative_refinement : bool, optional
Whether to use iterative refinement in upscaling, by default True
refinement_iterations : int, optional
Number of refinement iterations if using iterative refinement, by default 3
scale_factor_base : int, optional
The base scale factor for binning, by default 2
Returns
-------
Expand All @@ -686,8 +804,8 @@ def bm3d_ring_artifact_removal_ms(
[1] ref: `Collaborative Filtering of Correlated Noise <https://doi.org/10.1109/TIP.2020.3014721>`_
[2] ref: `Ring artifact reduction via multiscale nonlocal collaborative filtering of spatially correlated noise <https://doi.org/10.1107/S1600577521001910>`_
"""
# step 0: median filter the sinogram
sino_star = sinogram
# step 0: initialize
sino_star = np.array(sinogram)

if k == 0:
# single pass
Expand All @@ -698,46 +816,51 @@ def bm3d_ring_artifact_removal_ms(
filter_kwargs=filter_kwargs,
)

denoised_sino = None
# Make a copy of an original sinogram
sino_orig = horizontal_binning(sino_star, 1, dim=0)
binned_sinos_orig = [np.copy(sino_orig)]

# Contains upscaled denoised sinograms
binned_sinos = [np.zeros(0)]
binned_sinos_orig = [sino_star]

# Bin horizontally
for i in range(0, k):
binned_sinos_orig.append(
horizontal_binning(binned_sinos_orig[-1], fac=2, dim=1)
)
binned_sinos.append(np.zeros(0))

binned_sinos[-1] = binned_sinos_orig[-1]
for i in range(k):
binned_sinos_orig.append(downscale_2d_horizontal(binned_sinos_orig[-1], 2))

# Multi-scale denoising
for i in range(k, -1, -1):
logging.info(f"Processing binned sinogram {i + 1} of {k}")
logging.info(f"Processing binned sinogram {i + 1} of {k + 1}")

# compute the adjusted parameters
scale_factor = int(scale_factor_base ** (i / 2))
adjusted_block_matching_kwargs = get_scale_adjusted_blockmatching_params(
block_matching_kwargs, scale_factor
)
adjusted_filter_kwargs = filter_kwargs.copy()
adjusted_filter_kwargs["shrinkage_factor"] = (
filter_kwargs["shrinkage_factor"] / scale_factor
)

# Denoise binned sinogram
denoised_sino = bm3d_ring_artifact_removal(
binned_sinos[i],
binned_sinos_orig[i],
mode=mode,
block_matching_kwargs=block_matching_kwargs,
filter_kwargs=filter_kwargs,
block_matching_kwargs=adjusted_block_matching_kwargs,
filter_kwargs=adjusted_filter_kwargs,
)

# For iterations except the last, create the next noisy image with a finer scale residual
if i > 0:
debinned_sino = horizontal_debinning(
denoised_sino - binned_sinos_orig[i],
binned_sinos_orig[i - 1].shape[1],
# Calculate the noise at current scale
noise_at_scale_i = binned_sinos_orig[i] - denoised_sino

# Upscale the noise to the next finer scale
upscaled_noise = upscale_2d_horizontal(
noise_at_scale_i,
2,
30,
dim=1,
original_width=binned_sinos_orig[i - 1].shape[1],
use_iterative_refinement=use_iterative_refinement,
refinement_iterations=refinement_iterations,
)
binned_sinos[i - 1] = binned_sinos_orig[i - 1] + debinned_sino

# residual
sino_star = sino_star + horizontal_debinning(
denoised_sino - sino_orig, sino_star.shape[0], fac=1, n_iter=30, dim=0
)
# Remove the upscaled noise from the finer scale
# NOTE: The subtraction of noise will also be upscaled in the next iteration, therefore
# propagating the noise removal from coarser to finer scales
binned_sinos_orig[i - 1] -= upscaled_noise

return sino_star
return binned_sinos_orig[0]
Loading

0 comments on commit 1eb54f2

Please sign in to comment.