diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index b8f57e0dbe..d1083a641b 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -308,24 +308,30 @@ Intensity :members: :special-members: __call__ +`RandCoarseTransform` +""""""""""""""""""""" + .. autoclass:: RandCoarseTransform + :members: + :special-members: __call__ + `RandCoarseDropout` """"""""""""""""""" .. autoclass:: RandCoarseDropout :members: :special-members: __call__ +`RandCoarseShuffle` +""""""""""""""""""" + .. autoclass:: RandCoarseShuffle + :members: + :special-members: __call__ + `HistogramNormalize` """""""""""""""""""" .. autoclass:: HistogramNormalize :members: :special-members: __call__ -`LocalPatchShuffling` -""""""""""""""""""""" -.. autoclass:: LocalPatchShuffling - :members: - :special-members: __call__ - IO ^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2ea7e3aa63..b9ba303ed7 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -85,12 +85,13 @@ GibbsNoise, HistogramNormalize, KSpaceSpikeNoise, - LocalPatchShuffling, MaskIntensity, NormalizeIntensity, RandAdjustContrast, RandBiasField, RandCoarseDropout, + RandCoarseShuffle, + RandCoarseTransform, RandGaussianNoise, RandGaussianSharpen, RandGaussianSmooth, diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index de0c89e7b0..d10c3017a3 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -13,7 +13,7 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ -import copy +from abc import abstractmethod from collections.abc import Iterable from functools import partial from typing import Any, Callable, List, Optional, Sequence, Tuple, Union @@ -69,9 +69,10 @@ "RandGibbsNoise", "KSpaceSpikeNoise", "RandKSpaceSpikeNoise", + "RandCoarseTransform", "RandCoarseDropout", + "RandCoarseShuffle", "HistogramNormalize", - "LocalPatchShuffling", ] @@ -1630,13 +1631,11 @@ def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]: return tuple((i * 0.95, i * 1.1) for i in shifted_means) -class RandCoarseDropout(RandomizableTransform): +class RandCoarseTransform(RandomizableTransform): """ - Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. - Or keep the rectangular regions and fill in the other areas with specified value. - Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379 - And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/ - #albumentations.augmentations.transforms.CoarseDropout. + Randomly select coarse regions in the image, then execute transform operations for the regions. + It's the base class of all kinds of region transforms. + Refer to papers: https://arxiv.org/abs/1708.04552 Args: holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to @@ -1646,12 +1645,6 @@ class RandCoarseDropout(RandomizableTransform): if some components of the `spatial_size` are non-positive values, the transform will use the corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size of img is `64`. - dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and - dropout the outside and fill value. default to `True`. - fill_value: target value to fill the dropout regions, if providing a number, will use it as constant - value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select - value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max` - value of input image then randomly select value to fill, default to None. max_holes: if not None, define the maximum number to randomly select the expected number of regions. max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. if some components of the `max_spatial_size` are non-positive values, the transform will use the @@ -1665,8 +1658,6 @@ def __init__( self, holes: int, spatial_size: Union[Sequence[int], int], - dropout_holes: bool = True, - fill_value: Optional[Union[Tuple[float, float], float]] = None, max_holes: Optional[int] = None, max_spatial_size: Optional[Union[Sequence[int], int]] = None, prob: float = 0.1, @@ -1676,11 +1667,6 @@ def __init__( raise ValueError("number of holes must be greater than 0.") self.holes = holes self.spatial_size = spatial_size - self.dropout_holes = dropout_holes - if isinstance(fill_value, (tuple, list)): - if len(fill_value) != 2: - raise ValueError("fill value should contain 2 numbers if providing the `min` and `max`.") - self.fill_value = fill_value self.max_holes = max_holes self.max_spatial_size = max_spatial_size self.hole_coords: List = [] @@ -1697,28 +1683,126 @@ def randomize(self, img_size: Sequence[int]) -> None: valid_size = get_valid_patch_size(img_size, size) self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R)) + @abstractmethod + def _transform_holes(self, img: np.ndarray) -> np.ndarray: + """ + Transform the randomly selected `self.hole_coords` in input images. + + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + def __call__(self, img: np.ndarray): self.randomize(img.shape[1:]) - ret = img if self._do_transform: - fill_value = (img.min(), img.max()) if self.fill_value is None else self.fill_value - - if self.dropout_holes: - for h in self.hole_coords: - if isinstance(fill_value, (tuple, list)): - ret[h] = self.R.uniform(fill_value[0], fill_value[1], size=img[h].shape) - else: - ret[h] = fill_value - else: + img = self._transform_holes(img=img) + + return img + + +class RandCoarseDropout(RandCoarseTransform): + """ + Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value. + Or keep the rectangular regions and fill in the other areas with specified value. + Refer to papers: https://arxiv.org/abs/1708.04552, https://arxiv.org/pdf/1604.07379 + And other implementation: https://albumentations.ai/docs/api_reference/augmentations/transforms/ + #albumentations.augmentations.transforms.CoarseDropout. + + Args: + holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to + randomly select the expected number of regions. + spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg + as the minimum spatial size to randomly select size for every region. + if some components of the `spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + dropout_holes: if `True`, dropout the regions of holes and fill value, if `False`, keep the holes and + dropout the outside and fill value. default to `True`. + fill_value: target value to fill the dropout regions, if providing a number, will use it as constant + value to fill all the regions. if providing a tuple for the `min` and `max`, will randomly select + value for every pixel / voxel from the range `[min, max)`. if None, will compute the `min` and `max` + value of input image then randomly select value to fill, default to None. + max_holes: if not None, define the maximum number to randomly select the expected number of regions. + max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region. + if some components of the `max_spatial_size` are non-positive values, the transform will use the + corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + prob: probability of applying the transform. + + """ + + def __init__( + self, + holes: int, + spatial_size: Union[Sequence[int], int], + dropout_holes: bool = True, + fill_value: Optional[Union[Tuple[float, float], float]] = None, + max_holes: Optional[int] = None, + max_spatial_size: Optional[Union[Sequence[int], int]] = None, + prob: float = 0.1, + ) -> None: + super().__init__( + holes=holes, + spatial_size=spatial_size, + max_holes=max_holes, + max_spatial_size=max_spatial_size, + prob=prob, + ) + self.dropout_holes = dropout_holes + if isinstance(fill_value, (tuple, list)): + if len(fill_value) != 2: + raise ValueError("fill value should contain 2 numbers if providing the `min` and `max`.") + self.fill_value = fill_value + + def _transform_holes(self, img: np.ndarray): + """ + Fill the randomly selected `self.hole_coords` in input images. + Please note that we usually only use `self.R` in `randomize()` method, here is a special case. + + """ + fill_value = (img.min(), img.max()) if self.fill_value is None else self.fill_value + + if self.dropout_holes: + for h in self.hole_coords: if isinstance(fill_value, (tuple, list)): - ret = self.R.uniform(fill_value[0], fill_value[1], size=img.shape).astype(img.dtype) + img[h] = self.R.uniform(fill_value[0], fill_value[1], size=img[h].shape) else: - ret = np.full_like(img, fill_value) - for h in self.hole_coords: - ret[h] = img[h] + img[h] = fill_value + ret = img + else: + if isinstance(fill_value, (tuple, list)): + ret = self.R.uniform(fill_value[0], fill_value[1], size=img.shape).astype(img.dtype) + else: + ret = np.full_like(img, fill_value) + for h in self.hole_coords: + ret[h] = img[h] return ret +class RandCoarseShuffle(RandCoarseTransform): + """ + Randomly select regions in the image, then shuffle the pixels within every region. + It shuffles every channel separately. + Refer to paper: + Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). + https://arxiv.org/abs/1707.07103 + + """ + + def _transform_holes(self, img: np.ndarray): + """ + Shuffle the content of randomly selected `self.hole_coords` in input images. + Please note that we usually only use `self.R` in `randomize()` method, here is a special case. + + """ + for h in self.hole_coords: + # shuffle every channel separately + for i, c in enumerate(img[h]): + patch_channel = c.flatten() + self.R.shuffle(patch_channel) + img[h][i] = patch_channel.reshape(c.shape) + return img + + class HistogramNormalize(Transform): """ Apply the histogram normalization to input image. @@ -1759,95 +1843,3 @@ def __call__(self, img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.nda max=self.max, dtype=self.dtype, ) - - -class LocalPatchShuffling(RandomizableTransform): - """ - Takes a 3D image and based on input of the local patch size, shuffles the pixels of the local patch within it. - This process is repeated a for N number of times where every time a different random block is selected for local - pixel shuffling. - - Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017). - """ - - def __init__( - self, - prob: float = 1.0, - number_blocks: int = 1000, - blocksize_ratio: int = 10, - channel_wise: bool = True, - device: Optional[torch.device] = None, - image_only: bool = False, - ) -> None: - """ - Args: - prob: The chance of this transform occuring on the given volume. - number_blocks: Total number of time a random 3D block will be selected for local shuffling of pixels/voxels - contained in the block. - blocksize_ratio: This ratio can be used to estimate the local 3D block sizes that will be selected. - channel_wise: If True, treats each channel of the image separately. - device: device on which the tensor will be allocated. - image_only: if True return only the image volume, otherwise return (image, affine). - """ - RandomizableTransform.__init__(self, prob) - self.prob = prob - self.number_blocks = number_blocks - self.blocksize_ratio = blocksize_ratio - self.channel_wise = channel_wise - - def _local_patch_shuffle(self, img: Union[torch.Tensor, np.ndarray], number_blocks: int, blocksize_ratio: int): - im_shape = img.shape - img_copy = copy.deepcopy(img) - for _each_block in range(number_blocks): - - block_size_x = self.R.randint(1, im_shape[0] // blocksize_ratio) - block_size_y = self.R.randint(1, im_shape[1] // blocksize_ratio) - block_size_z = self.R.randint(1, im_shape[2] // blocksize_ratio) - - noise_x = self.R.randint(0, im_shape[0] - block_size_x) - noise_y = self.R.randint(0, im_shape[1] - block_size_y) - noise_z = self.R.randint(0, im_shape[2] - block_size_z) - - local_patch = img[ - noise_x : noise_x + block_size_x, - noise_y : noise_y + block_size_y, - noise_z : noise_z + block_size_z, - ] - - local_patch = local_patch.flatten() - self.R.shuffle(local_patch) - local_patch = local_patch.reshape((block_size_x, block_size_y, block_size_z)) - - img_copy[ - noise_x : noise_x + block_size_x, noise_y : noise_y + block_size_y, noise_z : noise_z + block_size_z - ] = local_patch - - shuffled_image = img_copy - return shuffled_image - - def __call__( - self, - img: Union[np.ndarray, torch.Tensor], - # spatial_size: Optional[Union[Sequence[int], int]] = None, - # mode: Optional[Union[GridSampleMode, str]] = None, - # padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ): - """ - Args: - img: shape must be (num_channels, H, W[, D]), - - """ - - super().randomize(None) - if not self._do_transform: - return img - - if self.channel_wise: - # img = self._local_patch_shuffle(img=img) - for i, _d in enumerate(img): - img[i] = self._local_patch_shuffle( - img=img[i], blocksize_ratio=self.blocksize_ratio, number_blocks=self.number_blocks - ) - else: - raise AssertionError("If channel_wise is False, the image needs to be set to channel first") - return img diff --git a/tests/test_rand_local_patch_shuffle.py b/tests/test_rand_coarse_shuffle.py similarity index 59% rename from tests/test_rand_local_patch_shuffle.py rename to tests/test_rand_coarse_shuffle.py index 8e2eefb5d1..97d492fd24 100644 --- a/tests/test_rand_local_patch_shuffle.py +++ b/tests/test_rand_coarse_shuffle.py @@ -14,32 +14,39 @@ import numpy as np from parameterized import parameterized -from monai.transforms import LocalPatchShuffling +from monai.transforms import RandCoarseShuffle TEST_CASES = [ [ - {"number_blocks": 10, "blocksize_ratio": 1, "prob": 0.0}, + {"holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0}, {"img": np.arange(8).reshape((1, 2, 2, 2))}, np.arange(8).reshape((1, 2, 2, 2)), ], [ - {"number_blocks": 10, "blocksize_ratio": 1, "prob": 1.0}, + {"holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, {"img": np.arange(27).reshape((1, 3, 3, 3))}, - [ + np.asarray( [ - [[9, 1, 2], [3, 4, 5], [6, 7, 8]], - [[0, 10, 11], [12, 4, 14], [15, 16, 17]], - [[18, 19, 20], [21, 22, 23], [24, 25, 26]], - ] - ], + [ + [[8, 19, 26], [24, 6, 15], [0, 13, 25]], + [[17, 3, 5], [10, 1, 12], [22, 4, 11]], + [[21, 20, 23], [14, 2, 16], [18, 9, 7]], + ], + ], + ), + ], + [ + {"holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0}, + {"img": np.arange(16).reshape((2, 2, 2, 2))}, + np.asarray([[[[6, 1], [4, 3]], [[0, 2], [7, 5]]], [[[14, 10], [9, 8]], [[12, 15], [13, 11]]]]), ], ] -class TestLocalPatchShuffle(unittest.TestCase): +class TestRandCoarseShuffle(unittest.TestCase): @parameterized.expand(TEST_CASES) def test_local_patch_shuffle(self, input_param, input_data, expected_val): - g = LocalPatchShuffling(**input_param) + g = RandCoarseShuffle(**input_param) g.set_random_state(seed=12) result = g(**input_data) np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)