ScaleIntensityRange, ScaleIntensityRanged, ScaleIntensityRangePercentiles, ScaleIntensityRangePercentilesd #2943

Merged
8 commits merged on Sep 14, 2021
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
@@ -524,4 +524,4 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import in1d, moveaxis, where
from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile, where
18 changes: 11 additions & 7 deletions monai/transforms/intensity/array.py
@@ -28,7 +28,7 @@
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils_pytorch_numpy_unification import where
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
@@ -688,14 +688,16 @@ class ScaleIntensityRange(Transform):
clip: whether to perform clipping after scaling.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None:
self.a_min = a_min
self.a_max = a_max
self.b_min = b_min
self.b_max = b_max
self.clip = clip

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
@@ -706,7 +708,7 @@ def __call__(self, img: np.ndarray):
img = (img - self.a_min) / (self.a_max - self.a_min)
img = img * (self.b_max - self.b_min) + self.b_min
if self.clip:
img = np.asarray(np.clip(img, self.b_min, self.b_max))
img = clip(img, self.b_min, self.b_max)
return img
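With the new backend support, the transform accepts either container type and returns the same kind it was given. A minimal sketch (the values here are illustrative, not from the PR):

```python
import numpy as np
import torch

from monai.transforms import ScaleIntensityRange

# map intensities from [20, 108] to [50, 80], clipping anything outside
scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, clip=True)

img_np = np.array([[10.0, 20.0, 64.0, 108.0, 120.0]])
img_pt = torch.as_tensor(img_np)

out_np = scaler(img_np)  # stays a numpy array
out_pt = scaler(img_pt)  # stays a torch tensor
assert np.allclose(out_np, out_pt.numpy())
```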


@@ -835,6 +837,8 @@ class ScaleIntensityRangePercentiles(Transform):
relative: whether to scale to the corresponding percentiles of [b_min, b_max].
"""

backend = ScaleIntensityRange.backend

def __init__(
self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False
) -> None:
@@ -849,12 +853,12 @@ def __init__(
self.clip = clip
self.relative = relative

def __call__(self, img: np.ndarray):
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
a_min = np.percentile(img, self.lower)
a_max = np.percentile(img, self.upper)
a_min: float = percentile(img, self.lower) # type: ignore
a_max: float = percentile(img, self.upper) # type: ignore
b_min = self.b_min
b_max = self.b_max

@@ -866,7 +870,7 @@ def __call__(self, img: np.ndarray):
img = scalar(img)

if self.clip:
img = np.asarray(np.clip(img, self.b_min, self.b_max))
img = clip(img, self.b_min, self.b_max)

return img
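The `relative` flag is easiest to see with concrete numbers: with `relative=True`, the output window itself is shrunk to the [lower, upper] percentiles of [b_min, b_max]. A sketch with assumed values:

```python
import numpy as np

from monai.transforms import ScaleIntensityRangePercentiles

img = np.arange(100, dtype=np.float32).reshape(1, 10, 10)

# absolute: the [10th, 90th] percentiles of `img` map to [0, 255]
scaler = ScaleIntensityRangePercentiles(lower=10, upper=90, b_min=0, b_max=255)
out = scaler(img)

# relative: the same input percentiles map to the [10th, 90th] percentiles
# of [0, 255], i.e. the narrower window [25.5, 229.5]
rel_scaler = ScaleIntensityRangePercentiles(lower=10, upper=90, b_min=0, b_max=255, relative=True)
rel_out = rel_scaler(img)
```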

8 changes: 6 additions & 2 deletions monai/transforms/intensity/dictionary.py
@@ -699,6 +699,8 @@ class ScaleIntensityRanged(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ScaleIntensityRange.backend

def __init__(
self,
keys: KeysCollection,
@@ -712,7 +714,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
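The dictionary wrapper simply routes each selected key through the array transform, so mixed numpy/torch dictionaries keep their per-key types. A minimal sketch with assumed keys:

```python
import numpy as np
import torch

from monai.transforms import ScaleIntensityRanged

data = {"img": torch.rand(1, 32, 32), "seg": np.ones((1, 32, 32), dtype=np.uint8)}
scaler = ScaleIntensityRanged(keys="img", a_min=0.0, a_max=1.0, b_min=0.0, b_max=255.0)
out = scaler(data)  # out["img"] is still a torch tensor; "seg" is left untouched
```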
@@ -816,6 +818,8 @@ class ScaleIntensityRangePercentilesd(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ScaleIntensityRangePercentiles.backend

def __init__(
self,
keys: KeysCollection,
@@ -830,7 +834,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.scaler(d[key])
54 changes: 54 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
@@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numpy as np
import torch

@@ -17,6 +19,8 @@
__all__ = [
"moveaxis",
"in1d",
"clip",
"percentile",
"where",
]

@@ -53,6 +57,56 @@ def in1d(x, y):
return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)


def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
"""`np.clip` with equivalent implementation for torch."""
result: NdarrayOrTensor
if isinstance(a, np.ndarray):
result = np.clip(a, a_min, a_max)
else:
# `torch.clip` only exists from v1.7 (as an alias of `clamp`); use `clamp` to support earlier versions
result = torch.clamp(a, a_min, a_max)
return result
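A quick sanity check of the dispatch, with illustrative values:

```python
import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import clip

# each call returns the same container type it was given
assert isinstance(clip(np.array([1.0, 5.0, 9.0]), 2, 8), np.ndarray)
assert isinstance(clip(torch.tensor([1.0, 5.0, 9.0]), 2, 8), torch.Tensor)
```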


def percentile(x: NdarrayOrTensor, q) -> Union[NdarrayOrTensor, float, int]:
"""`np.percentile` with equivalent implementation for torch.

PyTorch provides `torch.quantile`, but only from v1.7 onwards. For earlier
versions, we compute it ourselves. This fallback doesn't interpolate, so it
is equivalent to ``numpy.percentile(..., interpolation="nearest")``.

Args:
x: input data
q: percentile to compute (should be in the range 0 <= q <= 100)

Returns:
Resulting value (scalar for a scalar `q`, otherwise an array/tensor of values)
"""
if np.isscalar(q):
if not 0 <= q <= 100:
raise ValueError(f"q should be in the range [0, 100], got {q}.")
else:
if any(q < 0) or any(q > 100):
raise ValueError(f"all values of q should be in the range [0, 100], got {q}.")
result: Union[NdarrayOrTensor, float, int]
if isinstance(x, np.ndarray):
result = np.percentile(x, q)
else:
q = torch.tensor(q, device=x.device)
if hasattr(torch, "quantile"):
result = torch.quantile(x, q / 100.0)
else:
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
k = 1 + (0.01 * q * (x.numel() - 1)).round().int()
if k.numel() > 1:
r = [x.view(-1).kthvalue(int(_k)).values.item() for _k in k]
result = torch.tensor(r, device=x.device)
else:
result = x.view(-1).kthvalue(int(k)).values.item()

return result
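The one-based `kthvalue` arithmetic is easier to follow with a worked example. A sketch of the pre-1.7 nearest-rank path in isolation (values assumed for illustration):

```python
import numpy as np
import torch

x = torch.arange(10, dtype=torch.float32)  # already-sorted values 0..9
q = 25.0

# 0.01 * 25 * (10 - 1) = 2.25 -> rounds to 2 -> k = 3 (kthvalue is one-based)
k = 1 + int(round(0.01 * q * (x.numel() - 1)))
nearest = x.view(-1).kthvalue(k).values.item()  # 3rd smallest value: 2.0

# agrees with numpy's non-interpolating percentile
assert nearest == np.percentile(x.numpy(), q, interpolation="nearest")
```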


def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor:
"""
Note that `torch.where` may convert y.dtype to x.dtype.
13 changes: 6 additions & 7 deletions tests/test_scale_intensity_range.py
@@ -11,19 +11,18 @@

import unittest

import numpy as np

from monai.transforms import ScaleIntensityRange
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class IntensityScaleIntensityRange(NumpyImageTestCase2D):
def test_image_scale_intensity_range(self):
scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80)
scaled = scaler(self.imt)
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
self.assertTrue(np.allclose(scaled, expected))
for p in TEST_NDARRAYS:
scaled = scaler(p(self.imt))
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
assert_allclose(scaled, expected)


if __name__ == "__main__":
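The loop over `TEST_NDARRAYS` is the pattern used throughout these tests: `tests.utils.TEST_NDARRAYS` is, roughly, a tuple of converters producing each supported container type, and `assert_allclose` normalises tensors back to numpy before comparing. A simplified stand-in for these helpers (the real ones also cover CUDA tensors):

```python
import numpy as np
import torch

# simplified sketch of the helpers assumed from tests.utils
TEST_NDARRAYS = (np.array, torch.as_tensor)

def assert_allclose(actual, expected, atol=1e-8):
    if isinstance(actual, torch.Tensor):
        actual = actual.cpu().numpy()  # compare everything as numpy
    np.testing.assert_allclose(actual, expected, atol=atol)
```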
10 changes: 7 additions & 3 deletions tests/test_scale_intensity_range_percentiles.py
@@ -14,7 +14,7 @@
import numpy as np

from monai.transforms.intensity.array import ScaleIntensityRangePercentiles
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class TestScaleIntensityRangePercentiles(NumpyImageTestCase2D):
@@ -30,7 +30,9 @@ def test_scaling(self):
expected = (img - a_min) / (a_max - a_min)
expected = (expected * (b_max - b_min)) + b_min
scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max)
self.assertTrue(np.allclose(expected, scaler(img)))
for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(expected, result)

def test_relative_scaling(self):
img = self.imt
@@ -47,7 +49,9 @@ def test_relative_scaling(self):
expected_img = (img - expected_a_min) / (expected_a_max - expected_a_min)
expected_img = (expected_img * (expected_b_max - expected_b_min)) + expected_b_min

self.assertTrue(np.allclose(expected_img, scaler(img)))
for p in TEST_NDARRAYS:
result = scaler(p(img))
assert_allclose(expected_img, result)

def test_invalid_instantiation(self):
self.assertRaises(ValueError, ScaleIntensityRangePercentiles, lower=-10, upper=99, b_min=0, b_max=255)
13 changes: 6 additions & 7 deletions tests/test_scale_intensity_ranged.py
@@ -11,20 +11,19 @@

import unittest

import numpy as np

from monai.transforms import ScaleIntensityRanged
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class IntensityScaleIntensityRanged(NumpyImageTestCase2D):
def test_image_scale_intensity_ranged(self):
key = "img"
scaler = ScaleIntensityRanged(keys=key, a_min=20, a_max=108, b_min=50, b_max=80)
scaled = scaler({key: self.imt})
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
self.assertTrue(np.allclose(scaled[key], expected))
for p in TEST_NDARRAYS:
scaled = scaler({key: p(self.imt)})
expected = (self.imt - 20) / 88
expected = expected * 30 + 50
assert_allclose(scaled[key], expected)


if __name__ == "__main__":
46 changes: 46 additions & 0 deletions tests/test_utils_pytorch_numpy_unification.py
@@ -0,0 +1,46 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import percentile
from tests.utils import TEST_NDARRAYS, assert_allclose, set_determinism


class TestPytorchNumpyUnification(unittest.TestCase):
def setUp(self) -> None:
set_determinism(0)

def test_percentile(self):
for size in (1, 100):
q = np.random.randint(0, 100, size=size)
results = []
for p in TEST_NDARRAYS:
arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))
results.append(percentile(arr, q))
# pre torch 1.7, no `quantile`. Our own method doesn't interpolate,
# so we can only be accurate to 0.5
atol = 0.5 if not hasattr(torch, "quantile") else 1e-4
assert_allclose(results[0], results[-1], atol=atol)

def test_fails(self):
for p in TEST_NDARRAYS:
for q in (-1, 101):
arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))
with self.assertRaises(ValueError):
percentile(arr, q)


if __name__ == "__main__":
unittest.main()