Skip to content

Commit

Permalink
Added (Rand)ScaleScaleIntensityFixedMean(d) and modified (Rand)Adjust…
Browse files Browse the repository at this point in the history
…Contrast(d) with by adding arguments
  • Loading branch information
aaronkujawa committed May 22, 2023
1 parent 5174165 commit dce66d0
Show file tree
Hide file tree
Showing 8 changed files with 537 additions and 34 deletions.
5 changes: 5 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@
RandKSpaceSpikeNoise,
RandRicianNoise,
RandScaleIntensity,
RandScaleIntensityFixedMean,
RandShiftIntensity,
RandStdShiftIntensity,
SavitzkyGolaySmooth,
ScaleIntensity,
ScaleIntensityFixedMean,
ScaleIntensityRange,
ScaleIntensityRangePercentiles,
ShiftIntensity,
Expand Down Expand Up @@ -198,6 +200,9 @@
RandScaleIntensityd,
RandScaleIntensityD,
RandScaleIntensityDict,
RandScaleIntensityFixedMeand,
RandScaleIntensityFixedMeanD,
RandScaleIntensityFixedMeanDict,
RandShiftIntensityd,
RandShiftIntensityD,
RandShiftIntensityDict,
Expand Down
204 changes: 199 additions & 5 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
"RandBiasField",
"ScaleIntensity",
"RandScaleIntensity",
"ScaleIntensityFixedMean",
"RandScaleIntensityFixedMean",
"NormalizeIntensity",
"ThresholdIntensity",
"ScaleIntensityRange",
Expand Down Expand Up @@ -467,6 +469,160 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return ret


class ScaleIntensityFixedMean(Transform):
"""
Scale the intensity of input image ``v = v * (1 + factor)``, then shift the output so that the output image has the
same mean as the input.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
factor: float | None = 0,
preserve_range: bool = False,
fixed_mean: bool = True,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
factor: factor scale by ``v = v * (1 + factor)``.
preserve_range: clips the output array/tensor to the range of the input array/tensor
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.
"""
self.factor = factor
self.preserve_range = preserve_range
self.fixed_mean = fixed_mean
self.channel_wise = channel_wise
self.dtype = dtype

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
Raises:
ValueError: When ``self.fixed_mean=True`` and ``self.factor=None``. Incompatible values.
"""

if self.fixed_mean and not self.factor:
raise ValueError(f"{self.fixed_mean=} and {self.factor=} is incompatible.")

img = convert_to_tensor(img, track_meta=get_track_meta())
img_t = convert_to_tensor(img, track_meta=False)
ret: NdarrayOrTensor
if self.channel_wise:
out = []
for d in img_t:
if self.preserve_range:
clip_min = d.min()
clip_max = d.max()

if self.fixed_mean:
mn = d.mean()
d = d - mn

out_channel = (d * (1 + self.factor)) if self.factor is not None else d

if self.fixed_mean:
out_channel = out_channel + mn

if self.preserve_range:
out_channel = clip(out_channel, clip_min, clip_max)

out.append(out_channel)
ret = torch.stack(out) # type: ignore
else:
if self.preserve_range:
clip_min = img_t.min()
clip_max = img_t.max()

if self.fixed_mean:
mn = img_t.mean()
img_t = img_t - mn

ret = (img_t * (1 + self.factor)) if self.factor is not None else img_t

if self.fixed_mean:
ret = ret + mn

if self.preserve_range:
ret = clip(ret, clip_min, clip_max)

ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img_t.dtype)[0]
return ret


class RandScaleIntensityFixedMean(RandomizableTransform):
"""
Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
is randomly picked. Subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
"""

backend = ScaleIntensityFixedMean.backend

def __init__(
self,
prob: float = 0.1,
factors: Sequence[float] | float = 0,
fixed_mean: bool = True,
preserve_range: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
"""
Args:
factors: factor range to randomly scale by ``v = v * (1 + factor)``.
if single number, factor value is picked from (-factors, factors).
preserve_range: clips the output array/tensor to the range of the input array/tensor
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.
"""
RandomizableTransform.__init__(self, prob)
if isinstance(factors, (int, float)):
self.factors = (min(-factors, factors), max(-factors, factors))
elif len(factors) != 2:
raise ValueError("factors should be a number or pair of numbers.")
else:
self.factors = (min(factors), max(factors))
self.factor = self.factors[0]
self.fixed_mean = fixed_mean
self.preserve_range = preserve_range
self.dtype = dtype

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if randomize:
self.randomize()

if not self._do_transform:
return convert_data_type(img, dtype=self.dtype)[0]

return ScaleIntensityFixedMean(
factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype
)(img)


class RandScaleIntensity(RandomizableTransform):
"""
Randomly scale the intensity of input image by ``v = v * (1 + factor)`` where the `factor`
Expand Down Expand Up @@ -800,48 +956,83 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

class AdjustContrast(Transform):
"""
Changes image intensity by gamma. Each pixel/voxel intensity is updated as::
Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
Args:
gamma: gamma value to adjust the contrast as function.
invert_image: multiplies all intensity values with -1 before gamma transform and again after gamma transform
retain_stats: applies a scaling factor and an offset to all intensity values after gamma transform to ensure
that the output intensity distribution has the same mean and standard deviation as the intensity
distribution of the input
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, gamma: float) -> None:
def __init__(self, gamma: float, invert_image: bool = False, retain_stats: bool = False) -> None:
if not isinstance(gamma, (int, float)):
raise ValueError(f"gamma must be a float or int number, got {type(gamma)} {gamma}.")
self.gamma = gamma
self.invert_image = invert_image
self.retain_stats = retain_stats

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())

if self.invert_image:
img = -img

if self.retain_stats:
mn = img.mean()
sd = img.std()

epsilon = 1e-7
img_min = img.min()
img_range = img.max() - img_min
ret: NdarrayOrTensor = ((img - img_min) / float(img_range + epsilon)) ** self.gamma * img_range + img_min

if self.retain_stats:
# zero mean and normalize
ret = ret - ret.mean()
ret = ret / (ret.std() + 1e-8)
# restore old mean and standard deviation
ret = sd * ret + mn

if self.invert_image:
ret = -ret

return ret


class RandAdjustContrast(RandomizableTransform):
"""
Randomly changes image intensity by gamma. Each pixel/voxel intensity is updated as::
Randomly changes image intensity with gamma transform. Each pixel/voxel intensity is updated as:
x = ((x - min) / intensity_range) ^ gamma * intensity_range + min
Args:
prob: Probability of adjustment.
gamma: Range of gamma values.
If single number, value is picked from (0.5, gamma), default is (0.5, 4.5).
invert_image: multiplies all intensity values with -1 before gamma transform and again after gamma transform
retain_stats: applies a scaling factor and an offset to all intensity values after gamma transform to ensure
that the output intensity distribution has the same mean and standard deviation as the intensity
distribution of the input
"""

backend = AdjustContrast.backend

def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5)) -> None:
def __init__(
self,
prob: float = 0.1,
gamma: Sequence[float] | float = (0.5, 4.5),
invert_image: bool = False,
retain_stats: bool = False,
) -> None:
RandomizableTransform.__init__(self, prob)

if isinstance(gamma, (int, float)):
Expand All @@ -856,6 +1047,8 @@ def __init__(self, prob: float = 0.1, gamma: Sequence[float] | float = (0.5, 4.5
self.gamma = (min(gamma), max(gamma))

self.gamma_value: float | None = None
self.invert_image: bool = invert_image
self.retain_stats: bool = retain_stats

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
Expand All @@ -876,7 +1069,8 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen

if self.gamma_value is None:
raise RuntimeError("gamma_value is not set, please call `randomize` function first.")
return AdjustContrast(self.gamma_value)(img)

return AdjustContrast(self.gamma_value, invert_image=self.invert_image, retain_stats=self.retain_stats)(img)


class ScaleIntensityRangePercentiles(Transform):
Expand Down
Loading

0 comments on commit dce66d0

Please sign in to comment.