Skip to content

Commit

Permalink
ThresholdIntensity, ThresholdIntensityd (#2944)
Browse files Browse the repository at this point in the history
* ThresholdIntensity, ThresholdIntensityd

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
  • Loading branch information
rijobro authored Sep 13, 2021
1 parent 132aa37 commit 8765fc7
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 36 deletions.
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,4 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import in1d, moveaxis
from .utils_pytorch_numpy_unification import in1d, moveaxis, where
12 changes: 8 additions & 4 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils_pytorch_numpy_unification import where
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
Expand Down Expand Up @@ -655,20 +656,23 @@ class ThresholdIntensity(Transform):
cval: value to fill the remaining parts of the image, default is 0.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None:
if not isinstance(threshold, (int, float)):
raise ValueError("threshold must be a float or int number.")
self.threshold = threshold
self.above = above
self.cval = cval

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
return np.asarray(
np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype
)
mask = img > self.threshold if self.above else img < self.threshold
res = where(mask, img, self.cval)
res, *_ = convert_data_type(res, dtype=img.dtype)
return res


class ScaleIntensityRange(Transform):
Expand Down
4 changes: 3 additions & 1 deletion monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,8 @@ class ThresholdIntensityd(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ThresholdIntensity.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -675,7 +677,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.filter = ThresholdIntensity(threshold, above, cval)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.filter(d[key])
Expand Down
15 changes: 15 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
__all__ = [
"moveaxis",
"in1d",
"where",
]


Expand Down Expand Up @@ -50,3 +51,17 @@ def in1d(x, y):
if isinstance(x, np.ndarray):
return np.in1d(x, y)
return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)


def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor:
"""
Note that `torch.where` may convert y.dtype to x.dtype.
"""
result: NdarrayOrTensor
if isinstance(condition, np.ndarray):
result = np.where(condition, x, y)
else:
x = torch.as_tensor(x, device=condition.device)
y = torch.as_tensor(y, device=condition.device, dtype=x.dtype)
result = torch.where(condition, x, y)
return result
19 changes: 10 additions & 9 deletions tests/test_threshold_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@
from parameterized import parameterized

from monai.transforms import ThresholdIntensity
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]

TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]

TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]
TESTS = []
for p in TEST_NDARRAYS:
TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)])
TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)])
TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)])


class TestThresholdIntensity(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, input_param, expected_value):
test_data = np.arange(10)
@parameterized.expand(TESTS)
def test_value(self, in_type, input_param, expected_value):
test_data = in_type(np.arange(10))
result = ThresholdIntensity(**input_param)(test_data)
np.testing.assert_allclose(result, expected_value)
assert_allclose(result, expected_value)


if __name__ == "__main__":
Expand Down
52 changes: 31 additions & 21 deletions tests/test_threshold_intensityd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,41 @@
from parameterized import parameterized

from monai.transforms import ThresholdIntensityd

TEST_CASE_1 = [
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0},
(0, 0, 0, 0, 0, 0, 6, 7, 8, 9),
]

TEST_CASE_2 = [
{"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0},
(0, 1, 2, 3, 4, 0, 0, 0, 0, 0),
]

TEST_CASE_3 = [
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5},
(5, 5, 5, 5, 5, 5, 6, 7, 8, 9),
]
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(
[
p,
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0},
(0, 0, 0, 0, 0, 0, 6, 7, 8, 9),
]
)
TESTS.append(
[
p,
{"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0},
(0, 1, 2, 3, 4, 0, 0, 0, 0, 0),
]
)
TESTS.append(
[
p,
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5},
(5, 5, 5, 5, 5, 5, 6, 7, 8, 9),
]
)


class TestThresholdIntensityd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, input_param, expected_value):
test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)}
@parameterized.expand(TESTS)
def test_value(self, in_type, input_param, expected_value):
test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))}
result = ThresholdIntensityd(**input_param)(test_data)
np.testing.assert_allclose(result["image"], expected_value)
np.testing.assert_allclose(result["label"], expected_value)
np.testing.assert_allclose(result["extra"], expected_value)
assert_allclose(result["image"], expected_value)
assert_allclose(result["label"], expected_value)
assert_allclose(result["extra"], expected_value)


if __name__ == "__main__":
Expand Down

0 comments on commit 8765fc7

Please sign in to comment.