Skip to content

Commit

Permalink
ScaleIntensityRange, ScaleIntensityRanged, ScaleIntensityRangePercent…
Browse files Browse the repository at this point in the history
…iles, ScaleIntensityRangePercentilesd (#2943)

ScaleIntensityRange, ScaleIntensityRanged, ScaleIntensityRangePercentiles, ScaleIntensityRangePercentilesd
  • Loading branch information
rijobro authored Sep 14, 2021
1 parent 7ab0711 commit f26a712
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 27 deletions.
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,4 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import in1d, moveaxis, where
from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile, where
18 changes: 11 additions & 7 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils_pytorch_numpy_unification import where
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
Expand Down Expand Up @@ -688,14 +688,16 @@ class ScaleIntensityRange(Transform):
clip: whether to perform clip after scaling.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None:
self.a_min = a_min
self.a_max = a_max
self.b_min = b_min
self.b_max = b_max
self.clip = clip

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
Expand All @@ -706,7 +708,7 @@ def __call__(self, img: np.ndarray):
img = (img - self.a_min) / (self.a_max - self.a_min)
img = img * (self.b_max - self.b_min) + self.b_min
if self.clip:
img = np.asarray(np.clip(img, self.b_min, self.b_max))
img = clip(img, self.b_min, self.b_max)
return img


Expand Down Expand Up @@ -835,6 +837,8 @@ class ScaleIntensityRangePercentiles(Transform):
relative: whether to scale to the corresponding percentiles of [b_min, b_max].
"""

backend = ScaleIntensityRange.backend

def __init__(
self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False
) -> None:
Expand All @@ -849,12 +853,12 @@ def __init__(
self.clip = clip
self.relative = relative

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
a_min = np.percentile(img, self.lower)
a_max = np.percentile(img, self.upper)
a_min: float = percentile(img, self.lower) # type: ignore
a_max: float = percentile(img, self.upper) # type: ignore
b_min = self.b_min
b_max = self.b_max

Expand All @@ -866,7 +870,7 @@ def __call__(self, img: np.ndarray):
img = scalar(img)

if self.clip:
img = np.asarray(np.clip(img, self.b_min, self.b_max))
img = clip(img, self.b_min, self.b_max)

return img

Expand Down
8 changes: 6 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ class ScaleIntensityRanged(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ScaleIntensityRange.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -712,7 +714,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
Expand Down Expand Up @@ -816,6 +818,8 @@ class ScaleIntensityRangePercentilesd(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ScaleIntensityRangePercentiles.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -830,7 +834,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
Expand Down
54 changes: 54 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numpy as np
import torch

Expand All @@ -17,6 +19,8 @@
__all__ = [
"moveaxis",
"in1d",
"clip",
"percentile",
"where",
]

Expand Down Expand Up @@ -53,6 +57,56 @@ def in1d(x, y):
return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)


def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
"""`np.clip` with equivalent implementation for torch."""
result: NdarrayOrTensor
if isinstance(a, np.ndarray):
result = np.clip(a, a_min, a_max)
else:
result = torch.clip(a, a_min, a_max)
return result


def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
"""`np.percentile` with equivalent implementation for torch.
Pytorch uses `quantile`, but this functionality is only available from v1.7.
For earlier methods, we calculate it ourselves. This doesn't do interpolation,
so is the equivalent of ``numpy.percentile(..., interpolation="nearest")``.
Args:
x: input data
q: percentile to compute (should in range 0 <= q <= 100)
Returns:
Resulting value (scalar)
"""
if np.isscalar(q):
if not 0 <= q <= 100:
raise ValueError
else:
if any(q < 0) or any(q > 100):
raise ValueError
result: Union[NdarrayOrTensor, float, int]
if isinstance(x, np.ndarray):
result = np.percentile(x, q)
else:
q = torch.tensor(q, device=x.device)
if hasattr(torch, "quantile"):
result = torch.quantile(x, q / 100.0)
else:
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
k = 1 + (0.01 * q * (x.numel() - 1)).round().int()
if k.numel() > 1:
r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k]
result = torch.tensor(r, device=x.device)
else:
result = x.view(-1).kthvalue(int(k)).values.item()

return result


def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor:
"""
Note that `torch.where` may convert y.dtype to x.dtype.
Expand Down
13 changes: 6 additions & 7 deletions tests/test_scale_intensity_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@

import unittest

import numpy as np

from monai.transforms import ScaleIntensityRange
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class IntensityScaleIntensityRange(NumpyImageTestCase2D):
def test_image_scale_intensity_range(self):
scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80)
scaled = scaler(self.imt)
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
self.assertTrue(np.allclose(scaled, expected))
for p in TEST_NDARRAYS:
scaled = scaler(p(self.imt))
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
assert_allclose(scaled, expected)


if __name__ == "__main__":
Expand Down
10 changes: 7 additions & 3 deletions tests/test_scale_intensity_range_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np

from monai.transforms.intensity.array import ScaleIntensityRangePercentiles
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):
Expand All @@ -30,7 +30,9 @@ def test_scaling(self):
expected = (img - a_min) / (a_max - a_min)
expected = (expected * (b_max - b_min)) + b_min
scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max)
self.assertTrue(np.allclose(expected, scaler(img)))
for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(expected, result)

def test_relative_scaling(self):
img = self.imt
Expand All @@ -47,7 +49,9 @@ def test_relative_scaling(self):
expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min)
expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min

self.assertTrue(np.allclose(expected_img, scaler(img)))
for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(expected_img, result)

def test_invalid_instantiation(self):
self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255)
Expand Down
13 changes: 6 additions & 7 deletions tests/test_scale_intensity_ranged.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@

import unittest

import numpy as np

from monai.transforms import ScaleIntensityRanged
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class IntensityScaleIntensityRanged(NumpyImageTestCase2D):
def test_image_scale_intensity_ranged(self):
key = "img"
scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80)
scaled = scaler({key: self.imt})
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
self.assertTrue(np.allclose(scaled[key], expected))
for p in TEST_NDARRAYS:
scaled = scaler({key: p(self.imt)})
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
assert_allclose(scaled[key], expected)


if __name__ == "__main__":
Expand Down
46 changes: 46 additions & 0 deletions tests/test_utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import percentile
from tests.utils import TEST_NDARRAYS, assert_allclose, set_determinism


class TestPytorchNumpyUnification(unittest.TestCase):
def setUp(self) -> None:
set_determinism(0)

def test_percentile(self):
for size in (1, 100):
q = np.random.randint(0, 100, size=size)
results = []
for p in TEST_NDARRAYS:
arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))
results.append(percentile(arr, q))
# pre torch 1.7, no `quantile`. Our own method doesn't interpolate,
# so we can only be accurate to 0.5
atol = 0.5 if not hasattr(torch, "quantile") else 1e-4
assert_allclose(results[0], results[-1], atol=atol)

def test_fails(self):
for p in TEST_NDARRAYS:
for q in (-1, 101):
arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))
with self.assertRaises(ValueError):
percentile(arr, q)


if __name__ == "__main__":
unittest.main()

0 comments on commit f26a712

Please sign in to comment.