Skip to content

Commit

Permalink
refactor code as submodule of transforms module
Browse files Browse the repository at this point in the history
  • Loading branch information
juampatronics committed Nov 6, 2023
1 parent 5107554 commit b899421
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 115 deletions.
28 changes: 0 additions & 28 deletions docs/source/regularization.rst

This file was deleted.

42 changes: 42 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,27 @@ Post-processing
:members:
:special-members: __call__

Regularization
^^^^^^^^^^^^^^

`CutMix`
""""""""
.. autoclass:: CutMix
:members:
:special-members: __call__

`CutOut`
""""""""
.. autoclass:: CutOut
:members:
:special-members: __call__

`MixUp`
"""""""
.. autoclass:: MixUp
:members:
:special-members: __call__

Signal
^^^^^^^

Expand Down Expand Up @@ -1707,6 +1728,27 @@ Post-processing (Dict)
:members:
:special-members: __call__

Regularization (Dict)
^^^^^^^^^^^^^^^^^^^^^

`CutMixd`
"""""""""
.. autoclass:: CutMixd
:members:
:special-members: __call__

`CutOutd`
"""""""""
.. autoclass:: CutOutd
:members:
:special-members: __call__

`MixUpd`
""""""""
.. autoclass:: MixUpd
:members:
:special-members: __call__

Signal (Dict)
^^^^^^^^^^^^^

Expand Down
10 changes: 10 additions & 0 deletions docs/source/transforms_idx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ Post-processing
post.array
post.dictionary

Regularization
^^^^^^^^^^^^^^

.. autosummary::
:toctree: _gen
:nosignatures:

regularization.array
regularization.dictionary

Signal
^^^^^^

Expand Down
1 change: 0 additions & 1 deletion monai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
"metrics",
"networks",
"optimizers",
"regularization",
"transforms",
"utils",
"visualize",
Expand Down
12 changes: 12 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,3 +693,15 @@
unravel_index,
where,
)
from .regularization.array import MixUp, CutMix, CutOut
from .regularization.dictionary import (
CutMixd,
CutMixD,
CutMixDict,
MixUpd,
MixUpD,
MixUpDict,
CutOutd,
CutOutD,
CutOutDict,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .mixup import MixUp, MixUpd, CutMix, CutMixd

__all__ = ["MixUp", "MixUpd", "CutMix", "CutMixd"]
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from monai.config import KeysCollection
import torch
from monai.transforms import Transform, MapTransform
from monai.utils.misc import ensure_tuple
from math import sqrt, ceil

__all__ = ["MixUp", "CutMix", "CutOut"]


class Mixer(Transform):
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
Expand Down Expand Up @@ -70,33 +71,6 @@ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
return self.apply(params, data), self.apply(params, labels)


class MixUpd(MapTransform):
    """Dictionary-based version of the MixUp augmentation.

    MixUp as described in:
    Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
    mixup: Beyond Empirical Risk Minimization, ICLR 2018

    Notice that the mixup transformation will be the same for all entries
    for consistency, i.e. images and labels must be applied the same augmentation.

    Args:
        keys: keys of the data dictionary to apply MixUp to.
        batch_size: number of samples per batch (forwarded to ``MixUp``).
        alpha: Beta-distribution parameter controlling the mixing weights.
        allow_missing_keys: do not raise an exception if a key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        batch_size: int,
        alpha: float = 1.0,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mixup = MixUp(batch_size, alpha)

    def __call__(self, data):
        result = dict(data)
        # sample once and reuse the same parameters for every key so that
        # paired entries (e.g. image and label) are mixed identically
        params = self.mixup.sample_params()
        for k in self.keys:
            result[k] = self.mixup.apply(params, data[k])
        return result


class CutMix(Mixer):
"""CutMix augmentation as described in:
Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
Expand Down Expand Up @@ -133,39 +107,6 @@ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
return (augmented, MixUp.apply(params, labels)) if labels is not None else augmented


class CutMixd(MapTransform):
    """Dictionary-based version of the CutMix augmentation.

    CutMix augmentation as described in:
    Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
    CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
    ICCV 2019

    Notice that the mixture weights will be the same for all entries
    for consistency, i.e. images and labels must be aggregated with the same weights,
    but the random crops are not.

    Args:
        keys: keys of the data dictionary to apply CutMix to.
        batch_size: number of samples per batch (forwarded to ``CutMix``).
        label_keys: optional keys whose entries are mixed via ``apply_on_labels``
            (label-style aggregation) instead of the image-style crop-and-paste.
        alpha: Beta-distribution parameter controlling the mixing weights.
        allow_missing_keys: do not raise an exception if a key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        batch_size: int,
        label_keys: KeysCollection | None = None,
        alpha: float = 1.0,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mixer = CutMix(batch_size, alpha)
        # normalize label_keys to an iterable; empty when not provided
        self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []

    def __call__(self, data):
        result = dict(data)
        # one set of sampled parameters shared by image keys and label keys
        params = self.mixer.sample_params()
        for k in self.keys:
            result[k] = self.mixer.apply(params, data[k])
        for k in self.label_keys:
            result[k] = self.mixer.apply_on_labels(params, data[k])
        return result


class CutOut(Mixer):
"""Cutout as described in the paper:
Terrance DeVries, Graham W. Taylor
Expand All @@ -191,23 +132,3 @@ def apply(cls, params: Tuple[torch.Tensor, torch.Tensor], data: torch.Tensor):

def __call__(self, data: torch.Tensor):
return self.apply(self.sample_params(), data)


class CutOutd(MapTransform):
    """Dictionary-based version of the Cutout augmentation.

    Cutout as described in the paper:
    Terrance DeVries, Graham W. Taylor
    Improved Regularization of Convolutional Neural Networks with Cutout
    arXiv:1708.04552

    Notice that the cutout is different for every entry in the dictionary:
    each key is transformed with an independently sampled cut region.

    Args:
        keys: keys of the data dictionary to apply Cutout to.
        batch_size: number of samples per batch (forwarded to ``CutOut``).
        allow_missing_keys: do not raise an exception if a key is missing.
    """

    def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)
        self.cutout = CutOut(batch_size)

    def __call__(self, data):
        result = dict(data)
        # calling the transform (rather than reusing sampled params) draws
        # a fresh random cutout for each key
        for k in self.keys:
            result[k] = self.cutout(data[k])
        return result
106 changes: 106 additions & 0 deletions monai/transforms/regularization/dictionary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from monai.config import KeysCollection
from monai.transforms import MapTransform
from monai.utils.misc import ensure_tuple
from .array import MixUp, CutMix, CutOut

__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]


class MixUpd(MapTransform):
    """Dictionary-based MixUp transform.

    MixUp as described in:
    Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
    mixup: Beyond Empirical Risk Minimization, ICLR 2018

    A single set of mixing parameters is sampled per call and applied to
    every key, so paired entries (e.g. images and labels) receive the
    same augmentation.

    Args:
        keys: keys of the data dictionary to apply MixUp to.
        batch_size: number of samples per batch (forwarded to ``MixUp``).
        alpha: Beta-distribution parameter controlling the mixing weights.
        allow_missing_keys: do not raise an exception if a key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        batch_size: int,
        alpha: float = 1.0,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mixup = MixUp(batch_size, alpha)

    def __call__(self, data):
        out = dict(data)
        shared_params = self.mixup.sample_params()
        # one shared parameter set -> identical mixing across all keys
        out.update((key, self.mixup.apply(shared_params, data[key])) for key in self.keys)
        return out


MixUpD = MixUpDict = MixUpd


class CutMixd(MapTransform):
    """Dictionary-based CutMix transform.

    CutMix augmentation as described in:
    Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo
    CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
    ICCV 2019

    The mixture weights are shared across every key for consistency
    (images and labels are aggregated with the same weights), while the
    random crop regions differ between entries.

    Args:
        keys: keys of the data dictionary to apply CutMix to.
        batch_size: number of samples per batch (forwarded to ``CutMix``).
        label_keys: optional keys mixed via ``apply_on_labels`` instead of
            the image-style crop-and-paste.
        alpha: Beta-distribution parameter controlling the mixing weights.
        allow_missing_keys: do not raise an exception if a key is missing.
    """

    def __init__(
        self,
        keys: KeysCollection,
        batch_size: int,
        label_keys: KeysCollection | None = None,
        alpha: float = 1.0,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mixer = CutMix(batch_size, alpha)
        # empty list when no label keys were supplied, so iteration is a no-op
        self.label_keys = [] if label_keys is None else ensure_tuple(label_keys)

    def __call__(self, data):
        out = dict(data)
        shared_params = self.mixer.sample_params()
        for image_key in self.keys:
            out[image_key] = self.mixer.apply(shared_params, data[image_key])
        for label_key in self.label_keys:
            out[label_key] = self.mixer.apply_on_labels(shared_params, data[label_key])
        return out


CutMixD = CutMixDict = CutMixd


class CutOutd(MapTransform):
    """Dictionary-based Cutout transform.

    Cutout as described in the paper:
    Terrance DeVries, Graham W. Taylor
    Improved Regularization of Convolutional Neural Networks with Cutout
    arXiv:1708.04552

    Each key receives an independently sampled cutout region.

    Args:
        keys: keys of the data dictionary to apply Cutout to.
        batch_size: number of samples per batch (forwarded to ``CutOut``).
        allow_missing_keys: do not raise an exception if a key is missing.
    """

    def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
        super().__init__(keys, allow_missing_keys)
        self.cutout = CutOut(batch_size)

    def __call__(self, data):
        augmented = dict(data)
        # invoking the transform per key draws a fresh random cut each time
        augmented.update({key: self.cutout(data[key]) for key in self.keys})
        return augmented


CutOutD = CutOutDict = CutOutd
2 changes: 1 addition & 1 deletion tests/test_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import torch
from monai.regularization.mixup import MixUp, MixUpd, CutMix, CutMixd, CutOut
from monai.transforms import MixUp, MixUpd, CutMix, CutMixd, CutOut
import unittest


Expand Down

0 comments on commit b899421

Please sign in to comment.