Improve performance for NormalizeIntensity (#6887)
### Description
To implement the "nonzero" functionality of the NormalizeIntensity
transform, a mask is used. When nonzero is False, the mask is still
applied, but it is initialized to all True/1. This unnecessary masking
causes a considerable performance hit. The changed implementation skips
the mask when nonzero is False. I ran a quick benchmark on my system
comparing the old implementation, the new implementation, and
normalization using the wrapper around the torchvision Normalize
transform. The results are shown below and indicate a more than 10x
speedup in most configurations (note that the times for the old
implementation are in milliseconds, while the other times are in
microseconds):

> [-------------- torchvision ---------------]
>                        |    cpu    |   cuda 
> 1 threads: ---------------------------------
>       (250, 250, 250)  |  18847.2  |  1440.5
>       (100, 100, 100)  |    484.6  |   395.5
> 
> Times are in microseconds (us).
> 
> [--------------- monai ----------------]
>                        |   cpu   |  cuda
> 1 threads: -----------------------------
>       (250, 250, 250)  |  603.7  |  11.5
>       (100, 100, 100)  |   39.9  |   1.5
> 
> Times are in milliseconds (ms).
> 
> [------------- monai_improved ------------]
>                        |    cpu    |   cuda
> 1 threads: --------------------------------
>       (250, 250, 250)  |  17763.2  |  720.0
>       (100, 100, 100)  |    938.0  |  185.2
> 
> Times are in microseconds (us).
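
The speedup of the new implementation over the old one can be derived per configuration from the tables above. The following is a small sketch that only restates the reported numbers, converting the old timings from milliseconds to microseconds first; the dictionary layout and variable names are illustrative and not part of the PR.

```python
# Timings copied from the benchmark tables above.
# The old implementation ("monai") was reported in milliseconds,
# so it is converted to microseconds before comparing.
old_ms = {
    ("(250, 250, 250)", "cpu"): 603.7,
    ("(250, 250, 250)", "cuda"): 11.5,
    ("(100, 100, 100)", "cpu"): 39.9,
    ("(100, 100, 100)", "cuda"): 1.5,
}
new_us = {
    ("(250, 250, 250)", "cpu"): 17763.2,
    ("(250, 250, 250)", "cuda"): 720.0,
    ("(100, 100, 100)", "cpu"): 938.0,
    ("(100, 100, 100)", "cuda"): 185.2,
}

# Speedup factor = old time / new time, both in microseconds.
speedups = {key: (old_ms[key] * 1000.0) / new_us[key] for key in old_ms}

for (shape, device), factor in speedups.items():
    print(f"{shape:>16} {device:>4}: {factor:5.1f}x")
```

With these numbers the factors come out at roughly 34x and 16x for the larger volume (cpu, cuda) and roughly 43x and 8x for the smaller one.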

The benchmarks were created with the following code (the
ImprovedNormalizeIntensity class does not exist in the PR; it was my
quick fix to have both the old and the new implementation available):
```python
import torch
import torch.utils.benchmark as benchmark
from monai.transforms import TorchVision
# ImprovedNormalizeIntensity is a local copy of the patched transform; it is not part of the PR.
from monai.transforms.intensity.array import ImprovedNormalizeIntensity, NormalizeIntensity

shapes = [
    (250, 250, 250),
    (100, 100, 100),
]

normalizers = {
    'torchvision': TorchVision(name="Normalize", mean=1000, std=333),
    'monai': NormalizeIntensity(subtrahend=1000, divisor=333),
    'monai_improved': ImprovedNormalizeIntensity(subtrahend=1000, divisor=333),
}
results = []
for shape in shapes:
    for device in ['cpu', 'cuda']:
        # Prepend two singleton dimensions to the volume shape.
        torch_tensor = torch.rand((1, 1) + shape).to(device)

        for name, normalizer in normalizers.items():
            t = benchmark.Timer(
                stmt='normalizer(x)',
                globals={'normalizer': normalizer, 'x': torch_tensor},
                label=name,
                sub_label=str(shape),
                description=device,
                num_threads=1,
            )
            # Run each configuration for at least 10 seconds.
            results.append(t.blocked_autorange(min_run_time=10))

compare = benchmark.Compare(results)
compare.print()
```
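
The redundant work described in the PR description can also be illustrated in isolation with plain NumPy. The sketch below is not MONAI code: the array size and the `sub`/`div` values are arbitrary assumptions chosen for illustration. It shows that indexing with an all-True boolean mask produces the same result as operating on the whole array, while adding an extra gather and scatter step.

```python
import numpy as np

rng = np.random.default_rng(0)
img = rng.random((1, 32, 32, 32), dtype=np.float32)
sub, div = 0.5, 0.25  # illustrative values, not taken from the PR

# Old behaviour for nonzero=False: build an all-True mask and index with it.
# Boolean indexing gathers the selected elements into a new 1-D array, and
# the assignment scatters them back, even though every element is selected.
mask = np.ones_like(img, dtype=bool)
masked = img.copy()
masked[mask] = (masked[mask] - sub) / div

# New behaviour for nonzero=False: operate on the whole array directly.
direct = (img - sub) / div

# Both paths compute the same normalized values.
print(np.allclose(masked, direct))
```

Both paths agree numerically; the difference is only in the amount of memory traffic, which is what the benchmark above measures.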
 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: John Zielke <john.zielke@snkeos.com>
john-zielke-snkeos authored Aug 18, 2023
1 parent 6e47140 commit 8aabdc9
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions monai/transforms/intensity/array.py
@@ -839,29 +839,33 @@ def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTenso

         if self.nonzero:
             slices = img != 0
+            masked_img = img[slices]
+            if not slices.any():
+                return img
         else:
-            if isinstance(img, np.ndarray):
-                slices = np.ones_like(img, dtype=bool)
-            else:
-                slices = torch.ones_like(img, dtype=torch.bool)
-        if not slices.any():
-            return img
+            slices = None
+            masked_img = img

-        _sub = sub if sub is not None else self._mean(img[slices])
+        _sub = sub if sub is not None else self._mean(masked_img)
         if isinstance(_sub, (torch.Tensor, np.ndarray)):
             _sub, *_ = convert_to_dst_type(_sub, img)
-            _sub = _sub[slices]
+            if slices is not None:
+                _sub = _sub[slices]

-        _div = div if div is not None else self._std(img[slices])
+        _div = div if div is not None else self._std(masked_img)
         if np.isscalar(_div):
             if _div == 0.0:
                 _div = 1.0
         elif isinstance(_div, (torch.Tensor, np.ndarray)):
             _div, *_ = convert_to_dst_type(_div, img)
-            _div = _div[slices]
+            if slices is not None:
+                _div = _div[slices]
             _div[_div == 0.0] = 1.0

-        img[slices] = (img[slices] - _sub) / _div
+        if slices is not None:
+            img[slices] = (masked_img - _sub) / _div
+        else:
+            img = (img - _sub) / _div
         return img

     def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
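
The control flow of the patched `_normalize` can be sketched as a standalone NumPy function. This is a simplified illustration, not the MONAI implementation: `normalize_sketch` is a hypothetical name, it handles only scalar `sub`/`div` values, it copies its input, and it assumes the population standard deviation (NumPy's default) where the transform calls `self._std`.

```python
import numpy as np


def normalize_sketch(img, sub=None, div=None, nonzero=False):
    """Simplified stand-in for the patched logic (scalar sub/div only)."""
    img = np.asarray(img, dtype=np.float32).copy()

    if nonzero:
        # Only nonzero elements take part in the statistics and the update.
        slices = img != 0
        masked_img = img[slices]
        if not slices.any():
            return img  # nothing to normalize
    else:
        # No mask at all: work on the full array.
        slices = None
        masked_img = img

    _sub = sub if sub is not None else masked_img.mean()
    _div = div if div is not None else masked_img.std()
    if _div == 0.0:
        _div = 1.0  # avoid division by zero

    if slices is not None:
        img[slices] = (masked_img - _sub) / _div
    else:
        img = (img - _sub) / _div
    return img


x = np.array([0.0, 1.0, 2.0, 3.0])
out_nz = normalize_sketch(x, nonzero=True)    # zero entries are left untouched
out_all = normalize_sketch(x, nonzero=False)  # every entry is normalized
print(out_nz)
print(out_all)
```

The `slices is None` branch is the fast path introduced by the change: no boolean mask is allocated and no masked indexing takes place.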
