Skip to content

Commit

Permalink
Add sample_std parameter to RandGaussianNoise. Also fixed small typo …
Browse files Browse the repository at this point in the history
…in RandGaussianNoised class docstring.
  • Loading branch information
bakert1 committed Feb 26, 2024
1 parent f6f9e81 commit 420a2d1
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 15 deletions.
14 changes: 11 additions & 3 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,24 +91,32 @@ class RandGaussianNoise(RandomizableTransform):
mean: Mean or “centre” of the distribution.
std: Standard deviation (spread) of distribution.
dtype: output data type, if None, same as input image. defaults to float32.
sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, prob: float = 0.1, mean: float = 0.0, std: float = 0.1, dtype: DtypeLike = np.float32) -> None:
def __init__(
self, prob: float = 0.1,
mean: float = 0.0,
std: float = 0.1,
dtype: DtypeLike = np.float32,
sample_std: bool = True,
) -> None:
RandomizableTransform.__init__(self, prob)
self.mean = mean
self.std = std
self.dtype = dtype
self.noise: np.ndarray | None = None
self.sample_std = sample_std

def randomize(self, img: NdarrayOrTensor, mean: float | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
rand_std = self.R.uniform(0, self.std)
noise = self.R.normal(self.mean if mean is None else mean, rand_std, size=img.shape)
std = self.R.uniform(0, self.std) if self.sample_std else self.std
noise = self.R.normal(self.mean if mean is None else mean, std, size=img.shape)
# noise is float64 array, convert to the output dtype to save memory
self.noise, *_ = convert_data_type(noise, dtype=self.dtype)

Expand Down
6 changes: 4 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
class RandGaussianNoised(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`.
Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if want to add
Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if you want to add
different noise for every field, please use this transform separately.
Args:
Expand All @@ -183,6 +183,7 @@ class RandGaussianNoised(RandomizableTransform, MapTransform):
std: Standard deviation (spread) of distribution.
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
sample_std: If True, sample the spread of the Gaussian distribution uniformly from 0 to std.
"""

backend = RandGaussianNoise.backend
Expand All @@ -195,10 +196,11 @@ def __init__(
std: float = 0.1,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
sample_std: bool = True
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype)
self.rand_gaussian_noise = RandGaussianNoise(mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std)

def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
Expand Down
12 changes: 7 additions & 5 deletions tests/test_rand_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,24 @@

TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(("test_zero_mean", p, 0, 0.1))
TESTS.append(("test_non_zero_mean", p, 1, 0.5))
TESTS.append(("test_zero_mean", p, 0, 0.1, True))
TESTS.append(("test_non_zero_mean", p, 1, 0.5, True))
TESTS.append(("test_no_sample_std", p, 1, 0.5, False))


class TestRandGaussianNoise(NumpyImageTestCase2D):

@parameterized.expand(TESTS)
def test_correct_results(self, _, im_type, mean, std):
def test_correct_results(self, _, im_type, mean, std, sample_std):
seed = 0
gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std)
gaussian_fn = RandGaussianNoise(prob=1.0, mean=mean, std=std, sample_std=sample_std)
gaussian_fn.set_random_state(seed)
im = im_type(self.imt)
noised = gaussian_fn(im)
np.random.seed(seed)
np.random.random()
expected = self.imt + np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
_std = np.random.uniform(0, std) if sample_std else std
expected = self.imt + np.random.normal(mean, _std, size=self.imt.shape)
if isinstance(noised, torch.Tensor):
noised = noised.cpu()
np.testing.assert_allclose(expected, noised, atol=1e-5)
Expand Down
12 changes: 7 additions & 5 deletions tests/test_rand_gaussian_noised.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,26 @@

TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1])
TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5])
TESTS.append(["test_zero_mean", p, ["img1", "img2"], 0, 0.1, True])
TESTS.append(["test_non_zero_mean", p, ["img1", "img2"], 1, 0.5, True])
TESTS.append(["test_no_sample_std", p, ["img1", "img2"], 1, 0.5, False])

seed = 0


class TestRandGaussianNoised(NumpyImageTestCase2D):

@parameterized.expand(TESTS)
def test_correct_results(self, _, im_type, keys, mean, std):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64)
def test_correct_results(self, _, im_type, keys, mean, std, sample_std):
gaussian_fn = RandGaussianNoised(keys=keys, prob=1.0, mean=mean, std=std, dtype=np.float64, sample_std=sample_std)
gaussian_fn.set_random_state(seed)
im = im_type(self.imt)
noised = gaussian_fn({k: im for k in keys})
np.random.seed(seed)
# simulate the randomize() of transform
np.random.random()
noise = np.random.normal(mean, np.random.uniform(0, std), size=self.imt.shape)
_std = np.random.uniform(0, std) if sample_std else std
noise = np.random.normal(mean, _std, size=self.imt.shape)
for k in keys:
expected = self.imt + noise
if isinstance(noised[k], torch.Tensor):
Expand Down

0 comments on commit 420a2d1

Please sign in to comment.