Skip to content

Commit

Permalink
Added (Rand)ScaleScaleIntensityFixedMean(d) and modified (Rand)Adjust…
Browse files Browse the repository at this point in the history
…Contrast(d) (#6542)

### Description

This PR adds the intensity transform **ScaleIntensityFixedMean**
(including its random and random dictionary version) and modifies the
intensity transform **AdjustContrast** and its dictionary version. It
adds functionality available in the corresponding nnU-Net transforms
[**ContrastAugmentationTransform**](https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/transforms/color_transforms.py#L25)
and
[**GammaTransform**](https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/transforms/color_transforms.py#L132).

Specifically, **ScaleIntensityFixedMean** scales the intensity of the
input image by a factor _v = v * (1 + factor)_ (same as the existing
**ScaleIntensity** transform when used with the factor argument). The
added functionality is provided by two arguments:

1. _fixed_mean_: subtract the mean intensity before scaling with
_factor_, then add the same value after scaling to ensure that the
output has the same mean intensity as the input.
  
2. _preserve_range_: clips the output array/tensor to the range of the
input array/tensor
  
AdjustContrast was modified by adding two arguments:

1. _invert_image_: multiplies all intensity values by -1 before gamma
transform and again after gamma transform
  
2. _retain_stats_: applies a scaling factor and an offset to all
intensity values after the gamma transform to ensure that the output
intensity distribution has the same mean and standard deviation as the
intensity distribution of the input


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Aaron Kujawa <askujawa@gmail.com>
  • Loading branch information
aaronkujawa authored Jun 8, 2023
1 parent c33f1ba commit 52b3ed2
Show file tree
Hide file tree
Showing 9 changed files with 593 additions and 37 deletions.
18 changes: 18 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,18 @@ Intensity
:members:
:special-members: __call__

`ScaleIntensityFixedMean`
"""""""""""""""""""""""""
.. autoclass:: ScaleIntensityFixedMean
:members:
:special-members: __call__

`RandScaleIntensityFixedMean`
"""""""""""""""""""""""""""""
.. autoclass:: RandScaleIntensityFixedMean
:members:
:special-members: __call__

`NormalizeIntensity`
""""""""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensity.png
Expand Down Expand Up @@ -1386,6 +1398,12 @@ Intensity (Dict)
:members:
:special-members: __call__

`RandScaleIntensityFixedMeand`
"""""""""""""""""""""""""""""""
.. autoclass:: RandScaleIntensityFixedMeand
:members:
:special-members: __call__

`NormalizeIntensityd`
"""""""""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/NormalizeIntensityd.png
Expand Down
5 changes: 5 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@
RandKSpaceSpikeNoise,
RandRicianNoise,
RandScaleIntensity,
RandScaleIntensityFixedMean,
RandShiftIntensity,
RandStdShiftIntensity,
SavitzkyGolaySmooth,
ScaleIntensity,
ScaleIntensityFixedMean,
ScaleIntensityRange,
ScaleIntensityRangePercentiles,
ShiftIntensity,
Expand Down Expand Up @@ -198,6 +200,9 @@
RandScaleIntensityd,
RandScaleIntensityD,
RandScaleIntensityDict,
RandScaleIntensityFixedMeand,
RandScaleIntensityFixedMeanD,
RandScaleIntensityFixedMeanDict,
RandShiftIntensityd,
RandShiftIntensityD,
RandShiftIntensityDict,
Expand Down
231 changes: 223 additions & 8 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
"RandBiasField",
"ScaleIntensity",
"RandScaleIntensity",
"ScaleIntensityFixedMean",
"RandScaleIntensityFixedMean",
"NormalizeIntensity",
"ThresholdIntensity",
"ScaleIntensityRange",
Expand Down Expand Up @@ -466,6 +468,161 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return ret


class ScaleIntensityFixedMean(Transform):
"""
Scale the intensity of input image by ``v = v * (1 + factor)``, then shift the output so that the output image has the
same mean as the input.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
factor: float = 0,
preserve_range: bool = False,
fixed_mean: bool = True,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
factor: factor scale by ``v = v * (1 + factor)``.
preserve_range: clips the output array/tensor to the range of the input array/tensor
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.
"""
self.factor = factor
self.preserve_range = preserve_range
self.fixed_mean = fixed_mean
self.channel_wise = channel_wise
self.dtype = dtype

def __call__(self, img: NdarrayOrTensor, factor=None) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
Args:
img: the input tensor/array
factor: factor scale by ``v = v * (1 + factor)``
"""

factor = factor if factor is not None else self.factor

img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
ret: NdarrayOrTensor
if self.channel_wise:
out = []
for d in img_t:
if self.preserve_range:
clip_min = d.min()
clip_max = d.max()

if self.fixed_mean:
mn = d.mean()
d = d - mn

out_channel = d * (1 + factor)

if self.fixed_mean:
out_channel = out_channel + mn

if self.preserve_range:
out_channel = clip(out_channel, clip_min, clip_max)

out.append(out_channel)
ret = torch.stack(out) # type: ignore
else:
if self.preserve_range:
clip_min = img_t.min()
clip_max = img_t.max()

if self.fixed_mean:
mn = img_t.mean()
img_t = img_t - mn

ret = img_t * (1 + factor)

if self.fixed_mean:
ret = ret + mn

if self.preserve_range:
ret = clip(ret, clip_min, clip_max)

ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img_t.dtype)[0]
return ret


class RandScaleIntensityFixedMean(RandomizableTransform):
"""
Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
is randomly picked. Subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
"""

backend = ScaleIntensityFixedMean.backend

def __init__(
self,
prob: float = 0.1,
factors: Sequence[float] | float = 0,
fixed_mean: bool = True,
preserve_range: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
factors: factor range to randomly scale by ``v = v * (1 + factor)``.
if single number, factor value is picked from (-factors, factors).
preserve_range: clips the output array/tensor to the range of the input array/tensor
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.
"""
RandomizableTransform.__init__(self, prob)
if isinstance(factors, (int, float)):
self.factors = (min(-factors, factors), max(-factors, factors))
elif len(factors) != 2:
raise ValueError("factors should be a number or pair of numbers.")
else:
self.factors = (min(factors), max(factors))
self.factor = self.factors[0]
self.fixed_mean = fixed_mean
self.preserve_range = preserve_range
self.dtype = dtype

self.scaler = ScaleIntensityFixedMean(
factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype
)

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if randomize:
self.randomize()

if not self._do_transform:
return convert_data_type(img, dtype=self.dtype)[0]

return self.scaler(img, self.factor)


class RandScaleIntensity(RandomizableTransform):
"""
Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
Expand Down Expand Up @@ -799,48 +956,99 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

class AdjustContrast(Transform):
"""
Changes image intensity by gamma. Each pixel/voxel intensity is updated as::
Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
Args:
gamma: gamma value to adjust the contrast as function.
invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
function.
retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
ensure that the output intensity distribution has the same mean and standard deviation as the intensity
distribution of the input. This behaviour is mimicked from `nnU-Net
<https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
function.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, gamma: float) -> None:
def __init__(self, gamma: float, invert_image: bool = False, retain_stats: bool = False) -> None:
if not isinstance(gamma, (int, float)):
raise ValueError(f"gamma must be a float or int number, got {type(gamma)} {gamma}.")
self.gamma = gamma
self.invert_image = invert_image
self.retain_stats = retain_stats

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
def __call__(self, img: NdarrayOrTensor, gamma=None) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
gamma: gamma value to adjust the contrast as function.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
gamma = gamma if gamma is not None else self.gamma

if self.invert_image:
img = -img

if self.retain_stats:
mn = img.mean()
sd = img.std()

epsilon = 1e-7
img_min = img.min()
img_range = img.max() - img_min
ret: NdarrayOrTensor = ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min
ret: NdarrayOrTensor = ((img - img_min) / float(img_range + epsilon)) ** gamma * img_range + img_min

if self.retain_stats:
# zero mean and normalize
ret = ret - ret.mean()
ret = ret / (ret.std() + 1e-8)
# restore old mean and standard deviation
ret = sd * ret + mn

if self.invert_image:
ret = -ret

return ret


class RandAdjustContrast(RandomizableTransform):
"""
Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as::
Randomly changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
Args:
prob: Probability of adjustment.
gamma: Range of gamma values.
If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).
invert_image: whether to invert the image before applying gamma augmentation. If True, multiply all intensity
values with -1 before the gamma transform and again after the gamma transform. This behaviour is mimicked
from `nnU-Net <https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
function.
retain_stats: if True, applies a scaling factor and an offset to all intensity values after gamma transform to
ensure that the output intensity distribution has the same mean and standard deviation as the intensity
distribution of the input. This behaviour is mimicked from `nnU-Net
<https://www.nature.com/articles/s41592-020-01008-z>`_, specifically `this
<https://github.com/MIC-DKFZ/batchgenerators/blob/7fb802b28b045b21346b197735d64f12fbb070aa/batchgenerators/augmentations/color_augmentations.py#L107>`_
function.
"""

backend = AdjustContrast.backend

def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5)) -> None:
def __init__(
self,
prob: float = 0.1,
gamma: Sequence[float] | float = (0.5, 4.5),
invert_image: bool = False,
retain_stats: bool = False,
) -> None:
RandomizableTransform.__init__(self, prob)

if isinstance(gamma, (int, float)):
Expand All @@ -854,7 +1062,13 @@ def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5
else:
self.gamma = (min(gamma), max(gamma))

self.gamma_value: float | None = None
self.gamma_value: float = 1.0
self.invert_image: bool = invert_image
self.retain_stats: bool = retain_stats

self.adjust_contrast = AdjustContrast(
self.gamma_value, invert_image=self.invert_image, retain_stats=self.retain_stats
)

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
Expand All @@ -875,7 +1089,8 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen

if self.gamma_value is None:
raise RuntimeError("gamma_value is not set, please call `randomize` function first.")
return AdjustContrast(self.gamma_value)(img)

return self.adjust_contrast(img, self.gamma_value)


class ScaleIntensityRangePercentiles(Transform):
Expand Down
Loading

0 comments on commit 52b3ed2

Please sign in to comment.