diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8990e7991d..bd3feb3497 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -661,6 +661,27 @@ Post-processing :members: :special-members: __call__ +Regularization +^^^^^^^^^^^^^^ + +`CutMix` +"""""""" +.. autoclass:: CutMix + :members: + :special-members: __call__ + +`CutOut` +"""""""" +.. autoclass:: CutOut + :members: + :special-members: __call__ + +`MixUp` +""""""" +.. autoclass:: MixUp + :members: + :special-members: __call__ + Signal ^^^^^^^ @@ -1707,6 +1728,27 @@ Post-processing (Dict) :members: :special-members: __call__ +Regularization (Dict) +^^^^^^^^^^^^^^^^^^^^^ + +`CutMixd` +""""""""" +.. autoclass:: CutMixd + :members: + :special-members: __call__ + +`CutOutd` +""""""""" +.. autoclass:: CutOutd + :members: + :special-members: __call__ + +`MixUpd` +"""""""" +.. autoclass:: MixUpd + :members: + :special-members: __call__ + Signal (Dict) ^^^^^^^^^^^^^ diff --git a/docs/source/transforms_idx.rst b/docs/source/transforms_idx.rst index f4d02a483f..650d45db71 100644 --- a/docs/source/transforms_idx.rst +++ b/docs/source/transforms_idx.rst @@ -74,6 +74,16 @@ Post-processing post.array post.dictionary +Regularization +^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: _gen + :nosignatures: + + regularization.array + regularization.dictionary + Signal ^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2aa8fbf8a1..349533fb3e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -336,6 +336,18 @@ VoteEnsembled, VoteEnsembleDict, ) +from .regularization.array import CutMix, CutOut, MixUp +from .regularization.dictionary import ( + CutMixd, + CutMixD, + CutMixDict, + CutOutd, + CutOutD, + CutOutDict, + MixUpd, + MixUpD, + MixUpDict, +) from .signal.array import ( SignalContinuousWavelet, SignalFillEmpty, diff --git a/monai/transforms/regularization/__init__.py b/monai/transforms/regularization/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/regularization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py new file mode 100644 index 0000000000..6c9022d647 --- /dev/null +++ b/monai/transforms/regularization/array.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from math import ceil, sqrt + +import torch + +from ..transform import RandomizableTransform + +__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] + + +class Mixer(RandomizableTransform): + def __init__(self, batch_size: int, alpha: float = 1.0) -> None: + """ + Mixer is a base class providing the basic logic for the mixup-class of + augmentations. In all cases, we need to sample the mixing weights for each + sample (lambda in the notation used in the papers). Also, pairs of samples + being mixed are picked by randomly shuffling the batch samples. + + Args: + batch_size (int): number of samples per batch. That is, samples are expected tp + be of size batchsize x channels [x depth] x height x width. + alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha) + distribution. Defaults to 1.0, the uniform distribution. + """ + super().__init__() + if alpha <= 0: + raise ValueError(f"Expected positive number, but got {alpha = }") + self.alpha = alpha + self.batch_size = batch_size + + @abstractmethod + def apply(self, data: torch.Tensor): + raise NotImplementedError() + + def randomize(self, data=None) -> None: + """ + Sometimes you need may to apply the same transform to different tensors. + The idea is to get a sample and then apply it with apply() as often + as needed. You need to call this method everytime you apply the transform to a new + batch. + """ + self._params = ( + torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32), + self.R.permutation(self.batch_size), + ) + + +class MixUp(Mixer): + """MixUp as described in: + Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz. + mixup: Beyond Empirical Risk Minimization, ICLR 2018 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. + """ + + def apply(self, data: torch.Tensor): + weight, perm = self._params + nsamples, *dims = data.shape + if len(weight) != nsamples: + raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}") + + if len(dims) not in [3, 4]: + raise ValueError("Unexpected number of dimensions") + + mixweight = weight[(Ellipsis,) + (None,) * len(dims)] + return mixweight * data + (1 - mixweight) * data[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + if labels is None: + return self.apply(data) + return self.apply(data), self.apply(labels) + + +class CutMix(Mixer): + """CutMix augmentation as described in: + Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo. + CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, + ICCV 2019 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles used during for mixing. + Please refer to the paper for details. + + The most common use case is something close to: + + .. code-block:: python + + cm = CutMix(batch_size=8, alpha=0.5) + for batch in loader: + images, labels = batch + augimg, auglabels = cm(images, labels) + output = model(augimg) + loss = loss_function(output, auglabels) + ... + + """ + + def apply(self, data: torch.Tensor): + weights, perm = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + (1 - mask) * data[perm, ...] + + def apply_on_labels(self, labels: torch.Tensor): + weights, perm = self._params + nsamples, *dims = labels.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mixweight = weights[(Ellipsis,) + (None,) * len(dims)] + return mixweight * labels + (1 - mixweight) * labels[perm, ...] + + def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None): + self.randomize() + augmented = self.apply(data) + return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented + + +class CutOut(Mixer): + """Cutout as described in the paper: + Terrance DeVries, Graham W. Taylor. + Improved Regularization of Convolutional Neural Networks with Cutout, + arXiv:1708.04552 + + Class derived from :py:class:`monai.transforms.Mixer`. See corresponding + documentation for details on the constructor parameters. Here, alpha not only determines + the mixing weight but also the size of the random rectangles being cut put. + Please refer to the paper for details. + """ + + def apply(self, data: torch.Tensor): + weights, _ = self._params + nsamples, _, *dims = data.shape + if len(weights) != nsamples: + raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}") + + mask = torch.ones_like(data) + for s, weight in enumerate(weights): + coords = [torch.randint(0, d, size=(1,)) for d in dims] + lengths = [d * sqrt(1 - weight) for d in dims] + idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)] + mask[s][idx] = 0 + + return mask * data + + def __call__(self, data: torch.Tensor): + self.randomize() + return self.apply(data) diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py new file mode 100644 index 0000000000..373913da99 --- /dev/null +++ b/monai/transforms/regularization/dictionary.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from monai.config import KeysCollection +from monai.utils.misc import ensure_tuple + +from ..transform import MapTransform +from .array import CutMix, CutOut, MixUp + +__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"] + + +class MixUpd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.MixUp`. + + Notice that the mixup transformation will be the same for all entries + for consistency, i.e. images and labels must be applied the same augmenation. + """ + + def __init__( + self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixup = MixUp(batch_size, alpha) + + def __call__(self, data): + self.mixup.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixup.apply(data[k]) + return result + + +class CutMixd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutMix`. + + Notice that the mixture weights will be the same for all entries + for consistency, i.e. images and labels must be aggregated with the same weights, + but the random crops are not. + """ + + def __init__( + self, + keys: KeysCollection, + batch_size: int, + label_keys: KeysCollection | None = None, + alpha: float = 1.0, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.mixer = CutMix(batch_size, alpha) + self.label_keys = ensure_tuple(label_keys) if label_keys is not None else [] + + def __call__(self, data): + self.mixer.randomize() + result = dict(data) + for k in self.keys: + result[k] = self.mixer.apply(data[k]) + for k in self.label_keys: + result[k] = self.mixer.apply_on_labels(data[k]) + return result + + +class CutOutd(MapTransform): + """ + Dictionary-based version :py:class:`monai.transforms.CutOut`. + + Notice that the cutout is different for every entry in the dictionary. + """ + + def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None: + super().__init__(keys, allow_missing_keys) + self.cutout = CutOut(batch_size) + + def __call__(self, data): + result = dict(data) + self.cutout.randomize() + for k in self.keys: + result[k] = self.cutout(data[k]) + return result + + +MixUpD = MixUpDict = MixUpd +CutMixD = CutMixDict = CutMixd +CutOutD = CutOutDict = CutOutd diff --git a/tests/test_regularization.py b/tests/test_regularization.py new file mode 100644 index 0000000000..d381ea72ca --- /dev/null +++ b/tests/test_regularization.py @@ -0,0 +1,90 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd + + +class TestMixup(unittest.TestCase): + def test_mixup(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup = MixUp(6, 1.0) + output = mixup(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10))) + + with self.assertRaises(ValueError): + MixUp(6, -0.5) + + mixup = MixUp(6, 0.5) + for dims in [2, 3]: + with self.assertRaises(ValueError): + shape = (5, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + mixup(sample) + + def test_mixupd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + sample = {"a": t, "b": t} + mixup = MixUpd(["a", "b"], 6) + output = mixup(sample) + self.assertTrue(torch.allclose(output["a"], output["b"])) + + with self.assertRaises(ValueError): + MixUpd(["k1", "k2"], 6, -0.5) + + +class TestCutMix(unittest.TestCase): + def test_cutmix(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutmix = CutMix(6, 1.0) + output = cutmix(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10))) + + def test_cutmixd(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + t = torch.rand(*shape, dtype=torch.float32) + label = torch.randint(0, 1, shape) + sample = {"a": t, "b": t, "lbl1": label, "lbl2": label} + cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2")) + output = cutmix(sample) + # croppings are different on each application + self.assertTrue(not torch.allclose(output["a"], output["b"])) + # but mixing of labels is not affected by it + self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"])) + + +class TestCutOut(unittest.TestCase): + def test_cutout(self): + for dims in [2, 3]: + shape = (6, 3) + (32,) * dims + sample = torch.rand(*shape, dtype=torch.float32) + cutout = CutOut(6, 1.0) + output = cutout(sample) + self.assertEqual(output.shape, sample.shape) + self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10))) + + +if __name__ == "__main__": + unittest.main()