Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inconsistent alpha parameter/docs for RandGibbsNoise/RandGibbsNoised #7584

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1840,15 +1840,19 @@ class RandGibbsNoise(RandomizableTransform):

Args:
prob (float): probability of applying the transform.
alpha (Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
alpha (float, Sequence(float)): Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
If a length-2 list is given as [a,b] then the value of alpha will be
sampled uniformly from the interval [a,b]. 0 <= a <= b <= 1.
If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].
"""

backend = GibbsNoise.backend

def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None:
def __init__(self, prob: float = 0.1, alpha: float | Sequence[float] = (0.0, 1.0)) -> None:
if isinstance(alpha, float):
alpha = (0, alpha)
alpha = ensure_tuple(alpha)
if len(alpha) != 2:
raise ValueError("alpha length must be 2.")
if alpha[1] > 1 or alpha[0] < 0:
Expand Down
5 changes: 3 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,10 +1423,11 @@ class RandGibbsNoised(RandomizableTransform, MapTransform):
keys: 'image', 'label', or ['image', 'label'] depending on which data
you need to transform.
prob (float): probability of applying the transform.
alpha (float, List[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes
alpha (float, Sequence[float]): Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
If a length-2 list is given as [a,b] then the value of alpha will be sampled
uniformly from the interval [a,b].
If a float is given, then the value of alpha will be sampled uniformly from the interval [0, alpha].
allow_missing_keys: do not raise exception if key is missing.
"""

Expand All @@ -1436,7 +1437,7 @@ def __init__(
self,
keys: KeysCollection,
prob: float = 0.1,
alpha: Sequence[float] = (0.0, 1.0),
alpha: float | Sequence[float] = (0.0, 1.0),
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
Expand Down
9 changes: 9 additions & 0 deletions tests/test_rand_gibbs_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ def test_alpha(self, im_shape, input_type):
self.assertGreaterEqual(t.sampled_alpha, 0.5)
self.assertLessEqual(t.sampled_alpha, 0.51)

@parameterized.expand(TEST_CASES)
def test_alpha_single_value(self, im_shape, input_type):
im = self.get_data(im_shape, input_type)
alpha = 0.01
t = RandGibbsNoise(1.0, alpha)
_ = t(deepcopy(im))
self.assertGreaterEqual(t.sampled_alpha, 0)
self.assertLessEqual(t.sampled_alpha, 0.01)


if __name__ == "__main__":
unittest.main()
8 changes: 8 additions & 0 deletions tests/test_rand_gibbs_noised.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ def test_alpha(self, im_shape, input_type):
_ = t(deepcopy(data))
self.assertTrue(0.5 <= t.rand_gibbs_noise.sampled_alpha <= 0.51)

@parameterized.expand(TEST_CASES)
def test_alpha_single_value(self, im_shape, input_type):
data = self.get_data(im_shape, input_type)
alpha = 0.01
t = RandGibbsNoised(KEYS, 1.0, alpha)
_ = t(deepcopy(data))
self.assertTrue(0 <= t.rand_gibbs_noise.sampled_alpha <= 0.01)


if __name__ == "__main__":
unittest.main()
Loading