diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index a07dee867b..9575a412b4 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -524,4 +524,4 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import in1d, moveaxis +from .utils_pytorch_numpy_unification import in1d, moveaxis, where diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 311534aa8b..b6fa1f72b7 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -28,6 +28,7 @@ from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter from monai.transforms.transform import RandomizableTransform, Transform from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array +from monai.transforms.utils_pytorch_numpy_unification import where from monai.utils import ( PT_BEFORE_1_7, InvalidPyTorchVersionError, @@ -655,6 +656,8 @@ class ThresholdIntensity(Transform): cval: value to fill the remaining parts of the image, default is 0. """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> None: if not isinstance(threshold, (int, float)): raise ValueError("threshold must be a float or int number.") @@ -662,13 +665,14 @@ def __init__(self, threshold: float, above: bool = True, cval: float = 0.0) -> N self.above = above self.cval = cval - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return np.asarray( - np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval), dtype=img.dtype - ) + mask = img > self.threshold if self.above else img < self.threshold + res = where(mask, img, self.cval) + res, *_ = convert_data_type(res, dtype=img.dtype) + return res class ScaleIntensityRange(Transform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 22b1edd5fd..cccf3e2a90 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -664,6 +664,8 @@ class ThresholdIntensityd(MapTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = ThresholdIntensity.backend + def __init__( self, keys: KeysCollection, @@ -675,7 +677,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.filter = ThresholdIntensity(threshold, above, cval) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.filter(d[key]) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2eebe3eda3..70ecb2848d 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -17,6 +17,7 @@ __all__ = [ "moveaxis", "in1d", + "where", ] @@ -50,3 +51,17 @@ def in1d(x, y): if isinstance(x, np.ndarray): return np.in1d(x, y) return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1) + + +def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor: + """ + Note that `torch.where` may convert y.dtype to x.dtype. + """ + result: NdarrayOrTensor + if isinstance(condition, np.ndarray): + result = np.where(condition, x, y) + else: + x = torch.as_tensor(x, device=condition.device) + y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) + result = torch.where(condition, x, y) + return result diff --git a/tests/test_threshold_intensity.py b/tests/test_threshold_intensity.py index a6d3895709..0614514456 100644 --- a/tests/test_threshold_intensity.py +++ b/tests/test_threshold_intensity.py @@ -15,20 +15,21 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensity +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [{"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)] - -TEST_CASE_2 = [{"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)] - -TEST_CASE_3 = [{"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append([p, {"threshold": 5, "above": True, "cval": 0}, (0, 0, 0, 0, 0, 0, 6, 7, 8, 9)]) + TESTS.append([p, {"threshold": 5, "above": False, "cval": 0}, (0, 1, 2, 3, 4, 0, 0, 0, 0, 0)]) + TESTS.append([p, {"threshold": 5, "above": True, "cval": 5}, (5, 5, 5, 5, 5, 5, 6, 7, 8, 9)]) class TestThresholdIntensity(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = np.arange(10) + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = in_type(np.arange(10)) result = ThresholdIntensity(**input_param)(test_data) - np.testing.assert_allclose(result, expected_value) + assert_allclose(result, expected_value) if __name__ == "__main__": diff --git a/tests/test_threshold_intensityd.py b/tests/test_threshold_intensityd.py index efcfcfe604..398f9cfe91 100644 --- a/tests/test_threshold_intensityd.py +++ b/tests/test_threshold_intensityd.py @@ -15,31 +15,41 @@ from parameterized import parameterized from monai.transforms import ThresholdIntensityd - -TEST_CASE_1 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, - (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), -] - -TEST_CASE_2 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, - (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), -] - -TEST_CASE_3 = [ - {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, - (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), -] +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 0}, + (0, 0, 0, 0, 0, 0, 6, 7, 8, 9), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": False, "cval": 0}, + (0, 1, 2, 3, 4, 0, 0, 0, 0, 0), + ] + ) + TESTS.append( + [ + p, + {"keys": ["image", "label", "extra"], "threshold": 5, "above": True, "cval": 5}, + (5, 5, 5, 5, 5, 5, 6, 7, 8, 9), + ] + ) class TestThresholdIntensityd(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) - def test_value(self, input_param, expected_value): - test_data = {"image": np.arange(10), "label": np.arange(10), "extra": np.arange(10)} + @parameterized.expand(TESTS) + def test_value(self, in_type, input_param, expected_value): + test_data = {"image": in_type(np.arange(10)), "label": in_type(np.arange(10)), "extra": in_type(np.arange(10))} result = ThresholdIntensityd(**input_param)(test_data) - np.testing.assert_allclose(result["image"], expected_value) - np.testing.assert_allclose(result["label"], expected_value) - np.testing.assert_allclose(result["extra"], expected_value) + assert_allclose(result["image"], expected_value) + assert_allclose(result["label"], expected_value) + assert_allclose(result["extra"], expected_value) if __name__ == "__main__":