Skip to content

Commit

Permalink
Added RandSimulateLowResolution(d) array and dictionary transforms an…
Browse files Browse the repository at this point in the history
…d corresponding unit tests (#6806)

Fixes #3781.

### Description
Random simulation of low resolution corresponding to nnU-Net's
(https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23).
First, the array/tensor is resampled at lower resolution as determined
by the zoom_factor which is uniformly sampled from the `zoom_range`.
Then, the array/tensor is resampled at the original resolution. MONAI's
`Resize` transform is used for the resampling operations.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Aaron Kujawa <askujawa@gmail.com>
  • Loading branch information
aaronkujawa authored Aug 1, 2023
1 parent 4c22a27 commit d6bafc9
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 1 deletion.
13 changes: 13 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,12 @@ Spatial
:members:
:special-members: __call__

`RandSimulateLowResolution`
"""""""""""""""""""""""""""
.. autoclass:: RandSimulateLowResolution
:members:
:special-members: __call__


Smooth Field
^^^^^^^^^^^^
Expand Down Expand Up @@ -1886,6 +1892,13 @@ Spatial (Dict)
:members:
:special-members: __call__

`RandSimulateLowResolutiond`
""""""""""""""""""""""""""""
.. autoclass:: RandSimulateLowResolutiond
:members:
:special-members: __call__


Smooth Field (Dict)
^^^^^^^^^^^^^^^^^^^

Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@
RandGridPatch,
RandRotate,
RandRotate90,
RandSimulateLowResolution,
RandZoom,
Resample,
ResampleToMatch,
Expand Down Expand Up @@ -437,6 +438,9 @@
RandRotated,
RandRotateD,
RandRotateDict,
RandSimulateLowResolutiond,
RandSimulateLowResolutionD,
RandSimulateLowResolutionDict,
RandZoomd,
RandZoomD,
RandZoomDict,
Expand Down
95 changes: 94 additions & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
Expand Down Expand Up @@ -111,6 +111,7 @@
"RandAffine",
"Rand2DElastic",
"Rand3DElastic",
"RandSimulateLowResolution",
]

RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]]
Expand Down Expand Up @@ -3456,3 +3457,95 @@ def __call__(self, array: NdarrayOrTensor, randomize: bool = True):
if randomize:
self.randomize(array)
return super().__call__(array)


class RandSimulateLowResolution(RandomizableTransform):
"""
Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
(https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)
First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled
from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.
"""

backend = Affine.backend

def __init__(
self,
prob: float = 0.1,
downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
zoom_range: Sequence[float] = (0.5, 1.0),
align_corners=False,
device: torch.device | None = None,
) -> None:
"""
Args:
prob: probability of performing this augmentation
downsample_mode: interpolation mode for downsampling operation
upsample_mode: interpolation mode for upsampling operation
zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is
sampled. It determines the shape of the downsampled tensor.
align_corners: This only has an effect when downsample_mode or upsample_mode is 'linear', 'bilinear',
'bicubic' or 'trilinear'. Default: False
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
device: device on which the tensor will be allocated.
"""
RandomizableTransform.__init__(self, prob)

self.downsample_mode = downsample_mode
self.upsample_mode = upsample_mode
self.zoom_range = zoom_range
self.align_corners = align_corners
self.device = device
self.zoom_factor = 1.0

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
self.zoom_factor = self.R.uniform(self.zoom_range[0], self.zoom_range[1])
if not self._do_transform:
return None

def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
"""
Args:
img: shape must be (num_channels, H, W[, D]),
randomize: whether to execute `randomize()` function first, defaults to True.
"""
if randomize:
self.randomize()

if self._do_transform:
input_shape = img.shape[1:]
target_shape = np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_)

resize_tfm_downsample = Resize(
spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False
)

resize_tfm_upsample = Resize(
spatial_size=input_shape,
size_mode="all",
mode=self.upsample_mode,
anti_aliasing=False,
align_corners=self.align_corners,
)
# temporarily disable metadata tracking, since we do not want to invert the two Resize functions during
# post-processing
original_tack_meta_value = get_track_meta()
set_track_meta(False)

img_downsampled = resize_tfm_downsample(img)
img_upsampled = resize_tfm_upsample(img_downsampled)

# reset metadata tracking to original value
set_track_meta(original_tack_meta_value)

# copy metadata from original image to down-and-upsampled image
img_upsampled = MetaTensor(img_upsampled)
img_upsampled.copy_meta_from(img)

return img_upsampled

else:
return img
93 changes: 93 additions & 0 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
RandGridDistortion,
RandGridPatch,
RandRotate,
RandSimulateLowResolution,
RandZoom,
ResampleToMatch,
Resize,
Expand Down Expand Up @@ -140,6 +141,9 @@
"RandGridPatchd",
"RandGridPatchD",
"RandGridPatchDict",
"RandSimulateLowResolutiond",
"RandSimulateLowResolutionD",
"RandSimulateLowResolutionDict",
]


Expand Down Expand Up @@ -2518,6 +2522,94 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.RandSimulateLowResolution`.
Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
(https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)
First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled
from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.
"""

backend = RandAffine.backend

def __init__(
self,
keys: KeysCollection,
prob: float = 0.1,
downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
zoom_range=(0.5, 1.0),
align_corners=False,
allow_missing_keys: bool = False,
device: torch.device | None = None,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
prob: probability of performing this augmentation
downsample_mode: interpolation mode for downsampling operation
upsample_mode: interpolation mode for upsampling operation
zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is
sampled. It determines the shape of the downsampled tensor.
align_corners: This only has an effect when downsample_mode or upsample_mode is 'linear', 'bilinear',
'bicubic' or 'trilinear'. Default: False
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
allow_missing_keys: don't raise exception if key is missing.
device: device on which the tensor will be allocated.
See also:
- :py:class:`monai.transforms.compose.MapTransform`
"""
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)

self.downsample_mode = downsample_mode
self.upsample_mode = upsample_mode
self.zoom_range = zoom_range
self.align_corners = align_corners
self.device = device

self.sim_lowres_tfm = RandSimulateLowResolution(
prob=1.0, # probability is handled by dictionary class
downsample_mode=self.downsample_mode,
upsample_mode=self.upsample_mode,
zoom_range=self.zoom_range,
align_corners=self.align_corners,
device=self.device,
)

def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
) -> RandSimulateLowResolutiond:
super().set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
"""
Args:
data: a dictionary containing the tensor-like data to be transformed. The ``keys`` specified
in this dictionary must be tensor like arrays that are channel first and have at most
three spatial dimensions
"""
d = dict(data)
first_key: Hashable = self.first_key(d)
if first_key == ():
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out

self.randomize(None)

for key in self.key_iterator(d):
# do the transform
if self._do_transform:
d[key] = self.sim_lowres_tfm(d[key]) # type: ignore
else:
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
return d


SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
Expand All @@ -2541,3 +2633,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
GridSplitD = GridSplitDict = GridSplitd
GridPatchD = GridPatchDict = GridPatchd
RandGridPatchD = RandGridPatchDict = RandGridPatchd
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
83 changes: 83 additions & 0 deletions tests/test_rand_simulate_low_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import RandSimulateLowResolution
from tests.utils import TEST_NDARRAYS, assert_allclose

TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(
[
dict(prob=1.0, zoom_range=(0.8, 0.81)),
p(
np.array(
[
[
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],
[[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]],
[[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]],
[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63]],
]
]
)
),
np.array(
[
[
[
[0.0000, 0.6250, 1.3750, 2.0000],
[2.5000, 3.1250, 3.8750, 4.5000],
[5.5000, 6.1250, 6.8750, 7.5000],
[8.0000, 8.6250, 9.3750, 10.0000],
],
[
[10.0000, 10.6250, 11.3750, 12.0000],
[12.5000, 13.1250, 13.8750, 14.5000],
[15.5000, 16.1250, 16.8750, 17.5000],
[18.0000, 18.6250, 19.3750, 20.0000],
],
[
[22.0000, 22.6250, 23.3750, 24.0000],
[24.5000, 25.1250, 25.8750, 26.5000],
[27.5000, 28.1250, 28.8750, 29.5000],
[30.0000, 30.6250, 31.3750, 32.0000],
],
[
[32.0000, 32.6250, 33.3750, 34.0000],
[34.5000, 35.1250, 35.8750, 36.5000],
[37.5000, 38.1250, 38.8750, 39.5000],
[40.0000, 40.6250, 41.3750, 42.0000],
],
]
]
),
]
)


class TestRandGaussianSmooth(unittest.TestCase):
@parameterized.expand(TESTS)
def test_value(self, arguments, image, expected_data):
randsimlowres = RandSimulateLowResolution(**arguments)
randsimlowres.set_random_state(seed=0)
result = randsimlowres(image)
assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor")


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit d6bafc9

Please sign in to comment.