Skip to content

ThresholdIntensity, ThresholdIntensityd #2944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,4 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import in1d, moveaxis
from .utils_pytorch_numpy_unification import in1d, moveaxis, where
12 changes: 8 additions & 4 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils_pytorch_numpy_unification import where
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
Expand Down Expand Up @@ -655,20 +656,23 @@ class ThresholdIntensity(Transform):
cval: value to fill the remaining parts of the image, default is 0.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None:
if not isinstance(threshold, (int, float)):
raise ValueError("threshold must be a float or int number.")
self.threshold = threshold
self.above = above
self.cval = cval

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
return np.asarray(
np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype
)
mask = img > self.threshold if self.above else img < self.threshold
res = where(mask, img, self.cval)
res, *_ = convert_data_type(res, dtype=img.dtype)
return res


class ScaleIntensityRange(Transform):
Expand Down
4 changes: 3 additions & 1 deletion monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,8 @@ class ThresholdIntensityd(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ThresholdIntensity.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -675,7 +677,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.filter = ThresholdIntensity(threshold, above, cval)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.filter(d[key])
Expand Down
15 changes: 15 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
__all__ = [
"moveaxis",
"in1d",
"where",
]


Expand Down Expand Up @@ -50,3 +51,17 @@ def in1d(x, y):
if isinstance(x, np.ndarray):
return np.in1d(x, y)
return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)


def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor:
"""
Note that `torch.where` may convert y.dtype to x.dtype.
"""
result: NdarrayOrTensor
if isinstance(condition, np.ndarray):
result = np.where(condition, x, y)
else:
x = torch.as_tensor(x, device=condition.device)
y = torch.as_tensor(y, device=condition.device, dtype=x.dtype)
result = torch.where(condition, x, y)
return result
19 changes: 10 additions & 9 deletions tests/test_threshold_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@
from parameterized import parameterized

from monai.transforms import ThresholdIntensity
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]

TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]

TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]
TESTS = []
for p in TEST_NDARRAYS:
TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)])
TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)])
TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)])


class TestThresholdIntensity(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, input_param, expected_value):
test_data = np.arange(10)
@parameterized.expand(TESTS)
def test_value(self, in_type, input_param, expected_value):
test_data = in_type(np.arange(10))
result = ThresholdIntensity(**input_param)(test_data)
np.testing.assert_allclose(result, expected_value)
assert_allclose(result, expected_value)


if __name__ == "__main__":
Expand Down
52 changes: 31 additions & 21 deletions tests/test_threshold_intensityd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,41 @@
from parameterized import parameterized

from monai.transforms import ThresholdIntensityd

TEST_CASE_1 = [
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0},
(0, 0, 0, 0, 0, 0, 6, 7, 8, 9),
]

TEST_CASE_2 = [
{"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0},
(0, 1, 2, 3, 4, 0, 0, 0, 0, 0),
]

TEST_CASE_3 = [
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5},
(5, 5, 5, 5, 5, 5, 6, 7, 8, 9),
]
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(
[
p,
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0},
(0, 0, 0, 0, 0, 0, 6, 7, 8, 9),
]
)
TESTS.append(
[
p,
{"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0},
(0, 1, 2, 3, 4, 0, 0, 0, 0, 0),
]
)
TESTS.append(
[
p,
{"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5},
(5, 5, 5, 5, 5, 5, 6, 7, 8, 9),
]
)


class TestThresholdIntensityd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_value(self, input_param, expected_value):
test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)}
@parameterized.expand(TESTS)
def test_value(self, in_type, input_param, expected_value):
test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))}
result = ThresholdIntensityd(**input_param)(test_data)
np.testing.assert_allclose(result["image"], expected_value)
np.testing.assert_allclose(result["label"], expected_value)
np.testing.assert_allclose(result["extra"], expected_value)
assert_allclose(result["image"], expected_value)
assert_allclose(result["label"], expected_value)
assert_allclose(result["extra"], expected_value)


if __name__ == "__main__":
Expand Down