Commit c07bc7d

Merge pull request #9 from ornlneutronimaging/faster_block_matching

Faster block matching to improve overall speed

KedoKudo authored May 23, 2024
2 parents 5479b81 + 2a4444b
Showing 6 changed files with 705 additions and 124 deletions.
121 changes: 121 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,121 @@
# Contributing to BM3D-ORNL

Thank you for considering contributing to the BM3D-ORNL project!
We welcome contributions from the community and are grateful for your help in improving this library.
This guide provides instructions on how to contribute to the project.

## Table of Contents

- [Code of Conduct](#code-of-conduct)
- [Getting Started](#getting-started)
- [Development Workflow](#development-workflow)
- [Coding Standards](#coding-standards)
- [Testing](#testing)
- [Submitting Changes](#submitting-changes)
- [Reporting Issues](#reporting-issues)
- [Contact](#contact)

## Code of Conduct

By participating in this project, you agree to abide by the [Code of Conduct](CODE_OF_CONDUCT.md).

## Getting Started

1. **Fork the Repository**: Fork the [bm3dornl repository](https://github.com/ornlneutronimaging/bm3dornl) to your GitHub account.

2. **Clone the Fork**: Clone your forked repository to your local machine.

```bash
git clone https://github.com/your-username/bm3dornl.git
cd bm3dornl
```

3. **Set Upstream Remote**: Add the original repository as an upstream remote.

```bash
git remote add upstream https://github.com/ornlneutronimaging/bm3dornl.git
```

4. **Create a Virtual Environment**: Set up a virtual environment to manage dependencies.

```bash
micromamba create -f environment.yml
micromamba activate bm3dornl
```

## Development Workflow

- **Create a Branch**: Create a new branch for your feature or bugfix.

```bash
git checkout -b feature/your-feature-name
```

- **Make Changes**: Make your changes in the codebase. Use `pre-commit` to help you format your code and check for common issues.

```bash
pre-commit install
```

> Note: you only need to run `pre-commit install` once. After that, the pre-commit checks will run automatically before each commit.

- **Write Tests**: Write tests for your changes so they are well covered. See the [testing](#testing) section for more details.

- **Commit Changes**: Commit your changes with a meaningful commit message.

```bash
git add .
git commit -m "Description of your changes"
```

- **Push Changes**: Push your changes to your forked repository.

```bash
git push origin feature/your-feature-name
```

- **Open a Pull Request**: Open a pull request (PR) from your forked repository to the `next` branch of the original repository. Provide a clear description of your changes and any relevant information.

## Coding Standards

- **PEP 8**: Follow the PEP 8 style guide for Python code.
- **Docstrings**: Use the `numpy` docstring style to document all public modules, classes, and functions.
- **Type Annotations**: Use type annotations for function signatures.
- **Imports**: Group imports into standard library, third-party, and local module sections. Use absolute imports.
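The conventions above can be illustrated with a short example. The function `scale_patch` is invented for this sketch and is not part of the bm3dornl API; it only shows the expected import grouping, type annotations, and `numpy`-style docstring:

```python
from typing import Tuple

import numpy as np


def scale_patch(patch: np.ndarray, factor: float) -> Tuple[np.ndarray, float]:
    """Scale a patch by a constant factor.

    Parameters
    ----------
    patch : np.ndarray
        The 2D patch to scale.
    factor : float
        Multiplicative scaling factor.

    Returns
    -------
    tuple
        The scaled patch and the factor used.
    """
    # Return both the result and the factor so callers can log it.
    return patch * factor, factor
```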

## Testing

We use `pytest` for testing. Ensure that your changes are covered by tests.

- **Run Tests**: Run the tests using `pytest`.

```bash
pytest -v
```

- **Check Coverage**: Check the test coverage.

```bash
pytest --cov=src/bm3dornl
```
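As a minimal illustration of the expected test shape, here is a self-contained `pytest`-style module. The names `normalize` and `test_normalize_range` are invented for this sketch; real tests live in the repository's test directory and exercise the actual bm3dornl API:

```python
import numpy as np


def normalize(img: np.ndarray) -> np.ndarray:
    """Scale an image to the [0, 1] range."""
    rng = img.max() - img.min()
    # Guard against a flat image to avoid division by zero.
    return (img - img.min()) / rng if rng > 0 else np.zeros_like(img)


def test_normalize_range():
    img = np.array([[1.0, 3.0], [5.0, 7.0]])
    out = normalize(img)
    assert out.min() == 0.0
    assert out.max() == 1.0
```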

## Submitting Changes

1. **Ensure Tests Pass**: Make sure all tests pass and the coverage is satisfactory.

2. **Update Documentation**: If your changes affect the documentation, update the relevant sections.

3. **Open a Pull Request**: Open a pull request with a clear description of your changes. Reference any related issues in your PR description.

4. **Review Process**: Your pull request will be reviewed by the maintainers. Be prepared to make changes based on feedback.

## Reporting Issues

If you find a bug or have a feature request, please open an issue on the [GitHub issues page](https://github.com/ornlneutronimaging/bm3dornl/issues).
Provide as much detail as possible, including steps to reproduce the issue if applicable.

## Contact

If you have any questions or need further assistance, please contact the [repo maintainer](mailto:zhangc@ornl.gov).

Thank you for contributing to BM3D-ORNL!
377 changes: 350 additions & 27 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

87 changes: 25 additions & 62 deletions src/bm3dornl/block_matching.py
@@ -4,10 +4,10 @@
import numpy as np
from typing import Tuple, Optional
from bm3dornl.utils import (
compute_hyper_block,
compute_signal_blocks_matrix,
get_patch_numba,
get_signal_patch_positions,
find_candidate_patch_ids,
is_within_threshold,
pad_patch_ids,
)


@@ -74,8 +74,7 @@ def get_patch(
The patch extracted from the image.
"""
source_image = self._image if source_image is None else source_image
i, j = position
return source_image[i : i + self.patch_size[0], j : j + self.patch_size[1]]
return get_patch_numba(source_image, position, self.patch_size)

def group_signal_patches(
self, cut_off_distance: tuple, intensity_diff_threshold: float
@@ -96,36 +95,21 @@ def group_signal_patches(
# - the matrix is symmetric
# - the zero values means the patches are not similar
# - the non-zero values are the Euclidean distance between the patches, i.e smaller values means smaller distance, higher similarity
self.signal_blocks_matrix = np.zeros(
(num_patches, num_patches),
dtype=float,
)
self.signal_blocks_matrix = np.zeros((num_patches, num_patches), dtype=float)

# Cache patches as views
cached_patches = [self.get_patch(pos) for pos in self.signal_patches_pos]

for ref_patch_id in range(num_patches):
ref_patch = cached_patches[ref_patch_id]
candidate_patch_ids = find_candidate_patch_ids(
self.signal_patches_pos, ref_patch_id, cut_off_distance
)
# iterate over the candidate patches
for neightbor_patch_id in candidate_patch_ids:
if is_within_threshold(
ref_patch,
cached_patches[neightbor_patch_id],
intensity_diff_threshold,
):
val_diff = max(
np.linalg.norm(ref_patch - cached_patches[neightbor_patch_id]),
1e-8,
)
self.signal_blocks_matrix[ref_patch_id, neightbor_patch_id] = (
val_diff
)
self.signal_blocks_matrix[neightbor_patch_id, ref_patch_id] = (
val_diff
)
cached_patches = np.array(
[self.get_patch(pos) for pos in self.signal_patches_pos]
)

# Compute signal blocks matrix using Numba JIT for speed
compute_signal_blocks_matrix(
self.signal_blocks_matrix,
cached_patches,
np.array(self.signal_patches_pos),
np.array(cut_off_distance),
intensity_diff_threshold,
)

def get_hyper_block(
self,
@@ -150,35 +134,14 @@ def get_hyper_block(
-------
tuple
A tuple containing the 4D array of patch groups and the corresponding positions.
TODO:
-----
- use multi-processing to further improve the speed of block building
"""
group_size = len(self.signal_blocks_matrix)
block = np.empty(
(group_size, num_patches_per_group, *self.patch_size), dtype=np.float32
source_image = self._image if alternative_source is None else alternative_source
block, positions = compute_hyper_block(
self.signal_blocks_matrix,
np.array(self.signal_patches_pos),
self.patch_size,
num_patches_per_group,
source_image,
padding_mode,
)
positions = np.empty((group_size, num_patches_per_group, 2), dtype=np.int32)

for i, row in enumerate(self.signal_blocks_matrix):
# find the ids
candidate_patch_ids = np.where(row > 0)[0]
# get the difference
candidate_patch_val = row[candidate_patch_ids]
# sort candidate_patch_ids by candidate_patch_val, smallest first
candidate_patch_ids = candidate_patch_ids[np.argsort(candidate_patch_val)]
# pad the patch ids
padded_patch_ids = pad_patch_ids(
candidate_patch_ids, num_patches_per_group, mode=padding_mode
)
# update block and positions
block[i] = np.array(
[
self.get_patch(self.signal_patches_pos[idx], alternative_source)
for idx in padded_patch_ids
]
)
positions[i] = np.array(self.signal_patches_pos[padded_patch_ids])

return block, positions
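For readers unfamiliar with the block-matching step this commit accelerates, the following pure-NumPy sketch shows the same idea in slow, readable form. It is an illustration only: in the commit itself the work is done by Numba-compiled helpers in `bm3dornl.utils` (`compute_signal_blocks_matrix`, `compute_hyper_block`), and the exact intensity check used here is an assumption of this sketch:

```python
import numpy as np


def signal_blocks_matrix_sketch(patches, positions, cut_off_distance,
                                intensity_diff_threshold):
    """Naive reference for the pairwise patch-distance matrix.

    patches : (N, h, w) array of cached patches
    positions : (N, 2) array of top-left patch coordinates
    cut_off_distance : (dy, dx) spatial search window
    intensity_diff_threshold : max mean-intensity gap for a match
    """
    n = len(patches)
    dist = np.zeros((n, n), dtype=float)
    for i in range(n):
        for j in range(i + 1, n):
            dy, dx = np.abs(positions[i] - positions[j])
            # Skip candidates outside the spatial search window.
            if dy > cut_off_distance[0] or dx > cut_off_distance[1]:
                continue
            # Skip candidates whose mean intensities differ too much.
            if abs(patches[i].mean() - patches[j].mean()) > intensity_diff_threshold:
                continue
            d = max(np.linalg.norm(patches[i] - patches[j]), 1e-8)
            # Symmetric matrix; zero means "not similar".
            dist[i, j] = dist[j, i] = d
    return dist
```

The JIT-compiled version removes the Python-level loop overhead, which is where the speedup in this PR comes from.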
6 changes: 0 additions & 6 deletions src/bm3dornl/denoiser.py
@@ -131,12 +131,6 @@ def thresholding(
# Normalize by the weights to compute the average
self.estimate_denoised_image /= np.maximum(weights, 1)

# # update the patch manager with the new estimate
# self.patch_manager.background_threshold *= (
# 0.5 # reduce the threshold for background threshold further
# )
# self.patch_manager.image = self.estimate_denoised_image

def re_filtering(
self,
cut_off_distance: Tuple[int, int],
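The normalization kept in `thresholding` above follows the standard BM3D aggregation pattern: overlapping patch estimates are summed into the image, and each pixel is then divided by the number of contributions it received. A minimal sketch of that pattern, with shapes and values made up for illustration:

```python
import numpy as np

# Accumulate overlapping denoised patches into an estimate image,
# tracking how many times each pixel was written.
estimate = np.zeros((4, 4))
weights = np.zeros((4, 4))

patch = np.full((2, 2), 3.0)
for top, left in [(0, 0), (1, 1)]:  # two overlapping placements
    estimate[top:top + 2, left:left + 2] += patch
    weights[top:top + 2, left:left + 2] += 1.0

# np.maximum(weights, 1) guards against division by zero where no
# patch contributed; those pixels stay 0 instead of becoming NaN.
estimate /= np.maximum(weights, 1)
```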
