From 24d37578b8850b4e9f6aad04b7b5f0acced48b26 Mon Sep 17 00:00:00 2001 From: John Zielke Date: Tue, 26 Mar 2024 18:17:06 -0400 Subject: [PATCH 1/2] Fix inconsistent alpha parameter/docs for RandGibbsNoise/RandGibbsNoised Signed-off-by: John Zielke --- monai/transforms/intensity/array.py | 8 ++++++-- monai/transforms/intensity/dictionary.py | 5 +++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index a2f63a7482..0085050ee3 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1840,15 +1840,19 @@ class RandGibbsNoise(RandomizableTransform): Args: prob (float): probability of applying the transform. - alpha (Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes + alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1. + If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha]. """ backend = GibbsNoise.backend - def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None: + def __init__(self, prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0)) -> None: + if isinstance(alpha, float): + alpha = (0, alpha) + alpha = ensure_tuple(alpha) if len(alpha) != 2: raise ValueError("alpha length must be 2.") if alpha[1] > 1 or alpha[0] < 0: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 7e93464e64..5b911904b0 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -1423,10 +1423,11 @@ class RandGibbsNoised(RandomizableTransform, MapTransform): keys: 'image', 'label', or ['image', 'label'] depending on which data you need to transform. prob (float): probability of applying the transform. - alpha (float, List[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes + alpha (float, Sequence[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes values in the interval [0,1] with alpha = 0 acting as the identity mapping. If a length-2 list is given as [a,b] then the value of alpha will be sampled uniformly from the interval [a,b]. + If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha]. allow_missing_keys: do not raise exception if key is missing. """ @@ -1436,7 +1437,7 @@ def __init__( self, keys: KeysCollection, prob: float = 0.1, - alpha: Sequence[float] = (0.0, 1.0), + alpha: float | Sequence[float] = (0.0, 1.0), allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) From 4e48e07dfce8a679321ec1a7e6a2942da07400b0 Mon Sep 17 00:00:00 2001 From: John Zielke Date: Tue, 26 Mar 2024 18:25:08 -0400 Subject: [PATCH 2/2] Add tests for single alpha value Signed-off-by: John Zielke --- tests/test_rand_gibbs_noise.py | 9 +++++++++ tests/test_rand_gibbs_noised.py | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/tests/test_rand_gibbs_noise.py b/tests/test_rand_gibbs_noise.py index 4befeffbe2..5ef249a1f4 100644 --- a/tests/test_rand_gibbs_noise.py +++ b/tests/test_rand_gibbs_noise.py @@ -90,6 +90,15 @@ def test_alpha(self, im_shape, input_type): self.assertGreaterEqual(t.sampled_alpha, 0.5) self.assertLessEqual(t.sampled_alpha, 0.51) + @parameterized.expand(TEST_CASES) + def test_alpha_single_value(self, im_shape, input_type): + im = self.get_data(im_shape, input_type) + alpha = 0.01 + t = RandGibbsNoise(1.0, alpha) + _ = t(deepcopy(im)) + self.assertGreaterEqual(t.sampled_alpha, 0) + self.assertLessEqual(t.sampled_alpha, 0.01) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py index 6580189af6..382290dd39 100644 --- a/tests/test_rand_gibbs_noised.py +++ b/tests/test_rand_gibbs_noised.py @@ -105,6 +105,14 @@ def test_alpha(self, im_shape, input_type): _ = t(deepcopy(data)) self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51) + @parameterized.expand(TEST_CASES) + def test_alpha_single_value(self, im_shape, input_type): + data = self.get_data(im_shape, input_type) + alpha = 0.01 + t = RandGibbsNoised(KEYS, 1.0, alpha) + _ = t(deepcopy(data)) + self.assertTrue(0 <= t.rand_gibbs_noise.sampled_alpha <= 0.01) + if __name__ == "__main__": unittest.main()