Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ScaleIntensityRange, ScaleIntensityRanged, ScaleIntensityRangePercentiles, ScaleIntensityRangePercentilesd #2943

Merged
merged 8 commits into from
Sep 14, 2021
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,4 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import in1d, moveaxis
from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile
17 changes: 11 additions & 6 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
Expand Down Expand Up @@ -684,14 +685,16 @@ class ScaleIntensityRange(Transform):
clip: whether to perform clip after scaling.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None:
"""
Store the source intensity window and the target window for later scaling.

Args:
a_min: lower bound of the input intensity range (mapped to ``b_min`` in ``__call__``).
a_max: upper bound of the input intensity range (mapped to ``b_max`` in ``__call__``).
b_min: lower bound of the output intensity range.
b_max: upper bound of the output intensity range.
clip: whether to clip the scaled result to ``[b_min, b_max]``.
"""
self.a_min = a_min
self.a_max = a_max
self.b_min = b_min
self.b_max = b_max
self.clip = clip

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
Expand All @@ -702,7 +705,7 @@ def __call__(self, img: np.ndarray):
img = (img - self.a_min) / (self.a_max - self.a_min)
img = img * (self.b_max - self.b_min) + self.b_min
if self.clip:
img = np.asarray(np.clip(img, self.b_min, self.b_max))
img = clip(img, self.b_min, self.b_max)
return img


Expand Down Expand Up @@ -831,6 +834,8 @@ class ScaleIntensityRangePercentiles(Transform):
relative: whether to scale to the corresponding percentiles of [b_min, b_max].
"""

backend = ScaleIntensityRange.backend

def __init__(
self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False
) -> None:
Expand All @@ -845,12 +850,12 @@ def __init__(
self.clip = clip
self.relative = relative

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
a_min = np.percentile(img, self.lower)
a_max = np.percentile(img, self.upper)
a_min: float = percentile(img, self.lower)
a_max: float = percentile(img, self.upper)
b_min = self.b_min
b_max = self.b_max

Expand All @@ -862,7 +867,7 @@ def __call__(self, img: np.ndarray):
img = scalar(img)

if self.clip:
img = np.asarray(np.clip(img, self.b_min, self.b_max))
img = clip(img, self.b_min, self.b_max)

return img

Expand Down
8 changes: 6 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,8 @@ class ScaleIntensityRanged(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ScaleIntensityRange.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -710,7 +712,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
Expand Down Expand Up @@ -814,6 +816,8 @@ class ScaleIntensityRangePercentilesd(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ScaleIntensityRangePercentiles.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -828,7 +832,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
Expand Down
22 changes: 22 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
__all__ = [
"moveaxis",
"in1d",
"clip",
"percentile",
]


Expand Down Expand Up @@ -50,3 +52,23 @@ def in1d(x, y):
if isinstance(x, np.ndarray):
return np.in1d(x, y)
return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)


def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
    """`np.clip` with equivalent implementation for torch.

    Args:
        a: array/tensor to clip.
        a_min: lower bound of the clipping interval.
        a_max: upper bound of the clipping interval.

    Returns:
        the clipped data, of the same container type as the input.
    """
    # dispatch on the concrete container type; numpy arrays keep the numpy
    # code path, everything else is assumed to be a torch tensor
    clipped: NdarrayOrTensor = (
        np.clip(a, a_min, a_max) if isinstance(a, np.ndarray) else torch.clip(a, a_min, a_max)
    )
    return clipped


def percentile(x: NdarrayOrTensor, q):
    """`np.percentile` with equivalent implementation for torch.

    Args:
        x: input data.
        q: percentile(s) to compute, expected to be in the range ``[0, 100]``.

    Raises:
        ValueError: when ``q`` is outside of ``[0, 100]``.

    Returns:
        the computed percentile(s).
    """
    # torch.quantile expects quantiles in [0, 1]; an out-of-range q would
    # otherwise fail with a backend-specific (and opaque) error, so validate
    # up front for consistent behaviour across numpy and torch inputs.
    if np.isscalar(q) and not 0 <= q <= 100:
        raise ValueError(f"q values must be in [0, 100], got values: {q}.")
    result: NdarrayOrTensor
    if isinstance(x, np.ndarray):
        result = np.percentile(x, q)
    else:
        result = torch.quantile(x, q / 100.0)
    return result
13 changes: 6 additions & 7 deletions tests/test_scale_intensity_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@

import unittest

import numpy as np

from monai.transforms import ScaleIntensityRange
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class IntensityScaleIntensityRange(NumpyImageTestCase2D):
    def test_image_scale_intensity_range(self):
        """Scaling [20, 108] -> [50, 80] should work identically for every backend."""
        transform = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80)
        # expected result depends only on the shared fixture image
        expected = (self.imt - 20) / 88 * 30 + 50
        for in_type in TEST_NDARRAYS:
            result = transform(in_type(self.imt))
            assert_allclose(result, expected)


if __name__ == "__main__":
Expand Down
10 changes: 7 additions & 3 deletions tests/test_scale_intensity_range_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np

from monai.transforms.intensity.array import ScaleIntensityRangePercentiles
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):
Expand All @@ -30,7 +30,9 @@ def test_scaling(self):
expected = (img - a_min) / (a_max - a_min)
expected = (expected * (b_max - b_min)) + b_min
scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max)
self.assertTrue(np.allclose(expected, scaler(img)))
for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(expected, result)

def test_relative_scaling(self):
img = self.imt
Expand All @@ -47,7 +49,9 @@ def test_relative_scaling(self):
expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min)
expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min

self.assertTrue(np.allclose(expected_img, scaler(img)))
for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(expected_img, result)

def test_invalid_instantiation(self):
self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255)
Expand Down
13 changes: 6 additions & 7 deletions tests/test_scale_intensity_ranged.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@

import unittest

import numpy as np

from monai.transforms import ScaleIntensityRanged
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class IntensityScaleIntensityRanged(NumpyImageTestCase2D):
    def test_image_scale_intensity_ranged(self):
        """Dictionary transform should scale [20, 108] -> [50, 80] for every backend."""
        key = "img"
        transform = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80)
        # expected result depends only on the shared fixture image
        expected = (self.imt - 20) / 88 * 30 + 50
        for in_type in TEST_NDARRAYS:
            result = transform({key: in_type(self.imt)})
            assert_allclose(result[key], expected)


if __name__ == "__main__":
Expand Down