Skip to content

Commit

Permalink
Merge pull request #7 from ornlneutronimaging/alt_step1
Browse files Browse the repository at this point in the history
Add alternative fast noise free estimate function
  • Loading branch information
KedoKudo authored May 20, 2024
2 parents 81235ae + bc1a27a commit 4491745
Show file tree
Hide file tree
Showing 11 changed files with 471 additions and 134 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ The BM3D algorithm is a state-of-the-art denoising algorithm that is widely used
The BM3D ORNL code is a Python implementation of the BM3D algorithm that has been optimized for performance using both `Numba` and `CuPy`.
The BM3D ORNL code is designed to be easy to use and easy to integrate into existing Python workflows.
The BM3D ORNL code is released under an open-source license, and is freely available for download and use.


For more information, check out our [FAQ](docs/FAQ.md).
80 changes: 80 additions & 0 deletions docs/FAQ.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# FAQ

## General

1. **What is the bm3dornl library?**
- `bm3dornl` is a Python library for removing streak artifacts in normalized sinograms to reduce ring artifacts in the final reconstruction. It uses a multiscale BM3D algorithm accelerated with CuPy and Numba.

2. **What is the purpose of the bm3dornl library?**
- The library aims to provide open-source, high-performance ring artifact removal for neutron imaging using advanced denoising techniques.

3. **Who are the contributors to the bm3dornl project?**
- Developed by the ORNL neutron software engineering team and maintained by the neutron imaging team, including MARS@HFIR and VENUS@SNS.

## Installation

1. **How do I install the bm3dornl library?**
- Currently under development. Install in developer mode:
- Clone the repository and checkout a feature branch (use `next` for latest features).
- Create a virtual environment, activate it, and install dependencies from `environment.yml`.
- Perform a developer install using `pip install -e .`.

2. **What are the system requirements for bm3dornl?**
- Requires Python 3.10 or later and a CUDA-enabled GPU for CuPy acceleration.

3. **How can I set up the environment to use bm3dornl?**
- Use the `environment.yml` file: `conda env create -f environment.yml`.

## Usage

1. **Can you provide a basic example of how to use bm3dornl for ring artifact removal?**
- For simple usage:
```python
from bm3dornl.denoiser import bm3d_streak_removal

sino_bm3dornl = bm3d_streak_removal(
sinogram=sinogram_noisy,
background_threshold=0.1,
patch_size=(8, 8),
stride=3,
cut_off_distance=(128, 128),
intensity_diff_threshold=0.2,
num_patches_per_group=300,
shrinkage_threshold=1 - 1e-4,
k=0,
fast_estimate=True,
)
```

2. **How do I use bm3dornl with CuPy for accelerated performance?**
- Ensure you have a CUDA-enabled GPU and install dependencies from `environment.yml`. Use bm3dornl functions that leverage CuPy for collaborative filtering and hard thresholding (`fast_estimate=False`).

3. **What are the main functions provided by bm3dornl?**
- Key components:
- `PatchManager` for block matching.
- `Numba` accelerated functions.
- `CuPy` accelerated functions.
- `bm3d_streak_removal` for ring artifact removal.
- Helper functions for visualization and data manipulation.

## Code and Implementation

1. **What does the `PatchManager` class do in bm3dornl?**
- Manages image patches, groups them based on spatial and intensity thresholds, and generates groups of similar patches as 4D arrays.

2. **How does the bm3dornl library utilize Numba for performance optimization?**
- `Numba` accelerates functions by compiling Python code to machine code at runtime.

3. **Can you explain the process of block matching in bm3dornl?**
- Involves finding similar patches in an image and grouping them for collaborative denoising, leveraging similar patches to enhance denoising effectiveness.

## Documentation and Support

1. **Where can I find the official documentation for bm3dornl?**
- Available in the repository’s [README](https://github.com/ornlneutronimaging/bm3dornl/blob/main/README.md) and additional docs.

2. **How can I contribute to the bm3dornl project?**
- Fork the repository, make your changes, and submit a pull request. Follow the contribution guidelines in the README.

3. **Are there any tutorials available for learning bm3dornl?**
- Check the repository for example notebooks or additional tutorial links.
227 changes: 131 additions & 96 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

48 changes: 37 additions & 11 deletions src/bm3dornl/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
horizontal_binning,
horizontal_debinning,
estimate_noise_std,
estimate_noise_free_sinogram,
)


Expand Down Expand Up @@ -130,11 +131,11 @@ def thresholding(
# Normalize by the weights to compute the average
self.estimate_denoised_image /= np.maximum(weights, 1)

# update the patch manager with the new estimate
self.patch_manager.background_threshold *= (
0.5 # reduce the threshold for background threshold further
)
self.patch_manager.image = self.estimate_denoised_image
# # update the patch manager with the new estimate
# self.patch_manager.background_threshold *= (
# 0.5 # reduce the threshold for background threshold further
# )
# self.patch_manager.image = self.estimate_denoised_image

def re_filtering(
self,
Expand Down Expand Up @@ -197,6 +198,7 @@ def denoise(
intensity_diff_threshold: float,
num_patches_per_group: int,
threshold: float,
fast_estimate: bool = True,
):
"""
Perform the BM3D denoising process on the input image.
Expand All @@ -211,16 +213,35 @@ def denoise(
The number of patch in each block.
threshold : float
The threshold value for hard thresholding during the first pass.
fast_estimate : bool
Whether to use a fast estimate for the denoised image. Default is True.
"""
logging.info("First pass: Hard thresholding")
self.thresholding(
cut_off_distance, intensity_diff_threshold, num_patches_per_group, threshold
)
# self.final_denoised_image = self.estimate_denoised_image
# step 1: estimate the noise free image
logging.info("Estimating noise free image...")
if fast_estimate:
logging.info("Using fast estimate")
self.estimate_denoised_image = estimate_noise_free_sinogram(
sinogram=self.image,
background_estimate=self.background_threshold,
)
else:
logging.info("Using block-matching with hard thresholding")
self.thresholding(
cut_off_distance=cut_off_distance,
intensity_diff_threshold=intensity_diff_threshold,
num_patches_per_group=num_patches_per_group,
threshold=threshold,
)

# step 2: update patch manager with the estimate_denoised_image
self.patch_manager.image = self.estimate_denoised_image

# step 3: re-filtering
logging.info("Second pass: Re-filtering")
self.re_filtering(
cut_off_distance, intensity_diff_threshold, num_patches_per_group
cut_off_distance=cut_off_distance,
intensity_diff_threshold=intensity_diff_threshold,
num_patches_per_group=num_patches_per_group,
)


Expand All @@ -234,6 +255,7 @@ def bm3d_streak_removal(
num_patches_per_group: int = 400,
shrinkage_threshold: float = 0.1,
k: int = 4,
fast_estimate: bool = True,
) -> np.ndarray:
"""Multiscale BM3D for streak removal
Expand All @@ -257,6 +279,8 @@ def bm3d_streak_removal(
The threshold for hard thresholding, by default 0.2
k : int, optional
The number of iterations for horizontal binning, by default 3
fast_estimate : bool, optional
Whether to use a fast estimate for the denoised image, by default True
Returns
-------
Expand Down Expand Up @@ -285,6 +309,7 @@ def bm3d_streak_removal(
intensity_diff_threshold=intensity_diff_threshold,
num_patches_per_group=num_patches_per_group,
threshold=shrinkage_threshold,
fast_estimate=fast_estimate,
)
return worker.final_denoised_image

Expand Down Expand Up @@ -314,6 +339,7 @@ def bm3d_streak_removal(
intensity_diff_threshold=intensity_diff_threshold,
num_patches_per_group=num_patches_per_group,
threshold=shrinkage_threshold,
fast_estimate=fast_estimate,
)
noise_estimate = sino - worker.final_denoised_image

Expand Down
5 changes: 4 additions & 1 deletion src/bm3dornl/phantom.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,12 @@ def generate_sinogram(
sinogram = radon(
input_img,
theta=thetas_deg,
circle=True,
circle=False,
).T # transpose to get the sinogram in the correct orientation for tomopy

# normalize sinogram to [0, 1]
sinogram = (sinogram - sinogram.min()) / (sinogram.max() - sinogram.min()) + 1e-8

return sinogram, thetas_deg


Expand Down
24 changes: 24 additions & 0 deletions src/bm3dornl/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
"""Plotting utilities for BM3D and ORNL."""

import numpy as np
from typing import Tuple


def compute_cdf(img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute the cumulative distribution function of an image.
Parameters
----------
img : np.ndarray
The input image.
Returns
-------
Tuple[np.ndarray, np.ndarray]
The sorted CDF values and the corresponding probabilities.
"""
cdf_org_sorted = np.sort(img.flatten())
p_org = 1.0 * np.arange(len(cdf_org_sorted)) / (len(cdf_org_sorted) - 1)
return cdf_org_sorted, p_org
52 changes: 52 additions & 0 deletions src/bm3dornl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from scipy.interpolate import RectBivariateSpline
from numba import jit
from typing import Tuple, List
from scipy.fft import fft2, ifft2, fftshift, ifftshift
from scipy.ndimage import gaussian_filter


@jit(nopython=True)
Expand Down Expand Up @@ -306,3 +308,53 @@ def estimate_noise_std(

# calculate the noise standard deviation
return np.std(np.abs(noisy_image - noise_free_image))


def estimate_noise_free_sinogram(
sinogram: np.ndarray,
background_estimate: float,
sigma_gaussian: float = 5.0,
) -> np.ndarray:
"""
Estimate noise-free sinogram from noisy sinogram.
Parameters
----------
sinogram : np.ndarray
Noisy sinogram.
background_estimate : float
Background estimate value.
sigma_gaussian : float, optional
Standard deviation of the 1D Gaussian filter, by default 5.0.
Returns
-------
np.ndarray
Noise-free sinogram.
"""
# Perform the hard-thresholding using FFT
sinogram_fft_shifted = fftshift(fft2(sinogram))
mask = np.ones_like(sinogram_fft_shifted)
crow = sinogram_fft_shifted.shape[0] // 2
mask[crow] = (
0 # this will suppress all vertical streaks, and some features (demerit)
)
sinogram_fft_shifted *= mask
sinogram_filtered = ifft2(ifftshift(sinogram_fft_shifted)).real

# Renormalize the sinogram to [0, 1] as the hard threshold mess up the intensity distribution
sinogram_filtered -= sinogram_filtered.min()
sinogram_filtered /= sinogram_filtered.max()

# Now reapply the background
sinogram_filtered[sinogram < background_estimate] = 0

sino_blurred = gaussian_filter(sinogram, sigma=sigma_gaussian)
scale_profile = np.sum(sinogram_filtered, axis=0) / np.sum(sino_blurred, axis=0)
sinogram_filtered /= scale_profile + 1e-8

# renormalize the sinogram to [0, 1]
sinogram_filtered -= sinogram_filtered.min()
sinogram_filtered /= sinogram_filtered.max()

return sinogram_filtered
Loading

0 comments on commit 4491745

Please sign in to comment.