Skip to content

Commit

Permalink
Add channel_wise in RandShiftIntensity (#7025)
Browse files Browse the repository at this point in the history
Fixes #6629.

### Description

Add `channel_wise` in `RandShiftIntensity` and `RandShiftIntensityd`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
  • Loading branch information
KumoLiu authored Sep 22, 2023
1 parent 2539266 commit c21df49
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 6 deletions.
25 changes: 21 additions & 4 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,18 @@ class RandShiftIntensity(RandomizableTransform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1) -> None:
def __init__(
self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1, channel_wise: bool = False
) -> None:
"""
Args:
offsets: offset range to randomly shift.
if single number, offset value is picked from (-offsets, offsets).
safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
prob: probability of shift.
channel_wise: if True, shift intensity on each channel separately. For each channel, a random offset will be chosen.
Please ensure that the first dimension represents the channel of the image if True.
"""
RandomizableTransform.__init__(self, prob)
if isinstance(offsets, (int, float)):
Expand All @@ -272,13 +276,17 @@ def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, pro
else:
self.offsets = (min(offsets), max(offsets))
self._offset = self.offsets[0]
self.channel_wise = channel_wise
self._shifter = ShiftIntensity(self._offset, safe)

def randomize(self, data: Any | None = None) -> None:
    """
    Draw the random offset(s) to be used by the next `__call__`.

    Args:
        data: when ``self.channel_wise`` is True this must be the image itself
            (the number of offsets drawn is ``data.shape[0]``, i.e. one per
            channel — assumes dim 0 is the channel dim). Otherwise unused.
    """
    super().randomize(None)
    if not self._do_transform:
        return None
    if self.channel_wise:
        # one independent offset per channel, all drawn from the same range
        self._offset = [self.R.uniform(low=self.offsets[0], high=self.offsets[1]) for _ in range(data.shape[0])]  # type: ignore
    else:
        # a single offset shared by the whole image
        self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])

def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize: bool = True) -> NdarrayOrTensor:
    """
    Apply the transform to `img`.

    Args:
        img: image to shift intensity on; must be channel-first when
            ``self.channel_wise`` is True.
        factor: if not None, each drawn offset is multiplied by this value
            before being applied.
        randomize: whether to draw fresh random offset(s) before applying;
            set False to reuse the offsets from a previous `randomize` call.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        # pass the image so channel-wise mode can size the per-channel offset list
        self.randomize(img)

    if not self._do_transform:
        return img

    ret: NdarrayOrTensor
    if self.channel_wise:
        out = []
        for i, d in enumerate(img):
            # shift each channel with its own offset
            out_channel = self._shifter(d, self._offset[i] if factor is None else self._offset[i] * factor)  # type: ignore
            out.append(out_channel)
        ret = torch.stack(out)  # type: ignore
    else:
        ret = self._shifter(img, self._offset if factor is None else self._offset * factor)
    return ret


class StdShiftIntensity(Transform):
Expand Down
14 changes: 12 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def __init__(
meta_keys: KeysCollection | None = None,
meta_key_postfix: str = DEFAULT_POST_FIX,
prob: float = 0.1,
channel_wise: bool = False,
allow_missing_keys: bool = False,
) -> None:
"""
Expand All @@ -399,6 +400,8 @@ def __init__(
used to extract the factor value is `factor_key` is not None.
prob: probability of shift.
(Default 0.1, with 10% probability it returns an array shifted intensity.)
channel_wise: if True, shift intensity on each channel separately. For each channel, a random offset will be chosen.
Please ensure that the first dimension represents the channel of the image if True.
allow_missing_keys: don't raise exception if key is missing.
"""
MapTransform.__init__(self, keys, allow_missing_keys)
Expand All @@ -409,7 +412,7 @@ def __init__(
if len(self.keys) != len(self.meta_keys):
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.shifter = RandShiftIntensity(offsets=offsets, safe=safe, prob=1.0)
self.shifter = RandShiftIntensity(offsets=offsets, safe=safe, prob=1.0, channel_wise=channel_wise)

def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
Expand All @@ -426,8 +429,15 @@ def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]:
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# expect all the specified keys have same spatial shape and share same random holes
first_key: Hashable = self.first_key(d)
if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# all the keys share the same random shift factor
self.shifter.randomize(None)
self.shifter.randomize(d[first_key])
for key, factor_key, meta_key, meta_key_postfix in self.key_iterator(
d, self.factor_key, self.meta_keys, self.meta_key_postfix
):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_rand_shift_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ def test_value(self, p):
expected = self.imt + np.random.uniform(low=-1.0, high=1.0)
assert_allclose(result, expected, type_test="tensor")

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise(self, p):
    """Verify channel_wise=True applies one independently drawn offset per channel."""
    # prob=1.0 guarantees the shift is applied
    scaler = RandShiftIntensity(offsets=3.0, channel_wise=True, prob=1.0)
    scaler.set_random_state(seed=0)
    im = p(self.imt)
    result = scaler(im)
    np.random.seed(0)
    # simulate the randomize() of transform
    # (the first draw is consumed by the prob check, so discard one sample)
    np.random.random()
    channel_num = self.imt.shape[0]
    # reproduce the per-channel offsets in the same RNG call order as the transform
    factor = [np.random.uniform(low=-3.0, high=3.0) for _ in range(channel_num)]
    expected = p(np.stack([np.asarray((self.imt[i]) + factor[i]) for i in range(channel_num)]).astype(np.float32))
    assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)


if __name__ == "__main__":
unittest.main()
16 changes: 16 additions & 0 deletions tests/test_rand_shift_intensityd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ def test_factor(self):
expected = self.imt + np.random.uniform(low=-1.0, high=1.0) * np.nanmax(self.imt)
np.testing.assert_allclose(result[key], expected)

def test_channel_wise(self):
    """Verify RandShiftIntensityd with channel_wise=True shifts each channel by its own offset."""
    key = "img"
    for p in TEST_NDARRAYS:
        # prob=1.0 guarantees the shift is applied
        scaler = RandShiftIntensityd(keys=[key], offsets=3.0, prob=1.0, channel_wise=True)
        scaler.set_random_state(seed=0)
        result = scaler({key: p(self.imt)})
        np.random.seed(0)
        # simulate the randomize function of transform
        # (the first draw is consumed by the prob check, so discard one sample)
        np.random.random()
        channel_num = self.imt.shape[0]
        # reproduce the per-channel offsets in the same RNG call order as the transform
        factor = [np.random.uniform(low=-3.0, high=3.0) for _ in range(channel_num)]
        expected = p(
            np.stack([np.asarray((self.imt[i]) + factor[i]) for i in range(channel_num)]).astype(np.float32)
        )
        assert_allclose(result[key], p(expected), type_test="tensor")


if __name__ == "__main__":
unittest.main()

0 comments on commit c21df49

Please sign in to comment.