Ensure deterministic in MixUp, CutMix, CutOut #7813

Merged: 9 commits, May 30, 2024
Changes from 2 commits
57 changes: 40 additions & 17 deletions monai/transforms/regularization/array.py
@@ -16,6 +16,9 @@

import torch

from monai.data.meta_obj import get_track_meta
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor

from ..transform import RandomizableTransform

__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]
@@ -53,9 +56,11 @@ def randomize(self, data=None) -> None:
as needed. You need to call this method every time you apply the transform to a new
batch.
"""
super().randomize(None)
self._params = (
torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
self.R.permutation(self.batch_size),
[torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else None,
)
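With this hunk, `randomize()` draws all of its randomness from `self.R` and stores a three-element tuple instead of two. A small inspection sketch (it reads the internal `_params` attribute, just as the updated tests below do; constructor arguments follow the tests):

```python
import torch

from monai.transforms import CutMix

cutmix = CutMix(6, 1.0)          # a Mixer subclass: batch_size=6, alpha=1.0
cutmix.set_random_state(seed=0)  # seed self.R so the draw is reproducible
cutmix.randomize(torch.rand(6, 3, 32, 32))

weights, perm, coords = cutmix._params
# weights: float32 tensor of shape (6,), sampled from Beta(alpha, alpha)
# perm:    permutation of the 6 batch indices
# coords:  one random start index per spatial dimension (two here)
```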


@@ -69,7 +74,7 @@ class MixUp(Mixer):
"""

def apply(self, data: torch.Tensor):
weight, perm = self._params
weight, perm, _ = self._params
nsamples, *dims = data.shape
if len(weight) != nsamples:
raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
@@ -80,11 +85,20 @@ def apply(self, data: torch.Tensor):
mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
return mixweight * data + (1 - mixweight) * data[perm, ...]

def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
self.randomize()
def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
data = convert_to_tensor(data, track_meta=get_track_meta())
data_t = convert_to_tensor(data, track_meta=False)
if labels is not None:
labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
labels_t = convert_to_tensor(labels, track_meta=False)
if randomize:
self.randomize()
if labels is None:
return self.apply(data)
return self.apply(data), self.apply(labels)
return convert_to_dst_type(self.apply(data_t), dst=data)[0]
return (
convert_to_dst_type(self.apply(data_t), dst=data)[0],
convert_to_dst_type(self.apply(labels_t), dst=labels)[0],
)
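Routing the weights and permutation through `self.R` makes MixUp reproducible: two instances seeded identically now agree on the same input. A minimal sketch (constructor arguments as in the tests below):

```python
import torch

from monai.transforms import MixUp

sample = torch.rand(6, 3, 32, 32)

mixup_a = MixUp(6, 1.0)
mixup_a.set_random_state(seed=0)

mixup_b = MixUp(6, 1.0)
mixup_b.set_random_state(seed=0)

# same seed and same input give the same mixed batch
assert torch.allclose(mixup_a(sample), mixup_b(sample))
```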


class CutMix(Mixer):
@@ -113,33 +127,40 @@ class CutMix(Mixer):
"""

def apply(self, data: torch.Tensor):
weights, perm = self._params
weights, perm, coords = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

mask = torch.ones_like(data)
for s, weight in enumerate(weights):
coords = [torch.randint(0, d, size=(1,)) for d in dims]
lengths = [d * sqrt(1 - weight) for d in dims]
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
mask[s][idx] = 0

return mask * data + (1 - mask) * data[perm, ...]

def apply_on_labels(self, labels: torch.Tensor):
weights, perm = self._params
weights, perm, _ = self._params
nsamples, *dims = labels.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
return mixweight * labels + (1 - mixweight) * labels[perm, ...]

def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
self.randomize()
augmented = self.apply(data)
return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
data = convert_to_tensor(data, track_meta=get_track_meta())
data_t = convert_to_tensor(data, track_meta=False)
if labels is not None:
labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
labels_t = convert_to_tensor(labels, track_meta=False)
if randomize:
self.randomize(data)
augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0]
if labels is not None:
augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0]
return (augmented, augmented_label) if labels is not None else augmented

Hello, I'm a researcher currently working with CutMix and other similar methods. My group and I noticed inconsistent behavior of CutMix on segmentation labels across our different machines. After some digging, we think it may come from this PR: before this change, self.apply_on_labels(labels) was used for the labels, but the modified code uses self.apply(labels_t), which applies the same cropping mechanism as for the image.

This PR does not mention the reason for this change; could you enlighten me, please?

Thank you in advance for your time.

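For reference, the two label paths the comment contrasts can be run side by side. This is only an illustration of the behavioural difference, not code from the PR; the shapes and the direct calls to `apply`/`apply_on_labels` are chosen for the demonstration:

```python
import torch

from monai.transforms import CutMix

cutmix = CutMix(6, 1.0)
cutmix.set_random_state(seed=0)

images = torch.rand(6, 3, 32, 32)
labels = torch.randint(0, 2, (6, 1, 32, 32)).float()

cutmix.randomize(images)
patched = cutmix.apply(labels)            # box-cut mixing, same mechanism as the image
blended = cutmix.apply_on_labels(labels)  # per-sample linear blend, no spatial cut
# `patched` keeps hard 0/1 values (each voxel comes from one of the two samples),
# while `blended` can be fractional wherever the paired labels disagree.
```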


class CutOut(Mixer):
@@ -155,20 +176,22 @@
"""

def apply(self, data: torch.Tensor):
weights, _ = self._params
weights, _, coords = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")

mask = torch.ones_like(data)
for s, weight in enumerate(weights):
coords = [torch.randint(0, d, size=(1,)) for d in dims]
lengths = [d * sqrt(1 - weight) for d in dims]
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
mask[s][idx] = 0

return mask * data

def __call__(self, data: torch.Tensor):
self.randomize()
return self.apply(data)
def __call__(self, data: torch.Tensor, randomize=True):
data = convert_to_tensor(data, track_meta=get_track_meta())
data_t = convert_to_tensor(data, track_meta=False)
if randomize:
self.randomize(data)
return convert_to_dst_type(self.apply(data_t), dst=data)[0]
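The `randomize` flag added here lets callers sample the cut-out parameters once and reuse them, for example when several batches must share the same box; a usage sketch under that assumption:

```python
import torch

from monai.transforms import CutOut

cutout = CutOut(6, 1.0)
cutout.set_random_state(seed=42)

img_a = torch.rand(6, 3, 32, 32)
img_b = torch.rand(6, 3, 32, 32)

cutout.randomize(img_a)                 # sample weights and box coordinates once
out_a = cutout(img_a, randomize=False)  # reuse the stored parameters
out_b = cutout(img_b, randomize=False)  # the same region is zeroed in both batches
```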
73 changes: 50 additions & 23 deletions monai/transforms/regularization/dictionary.py
@@ -11,16 +11,23 @@

from __future__ import annotations

from collections.abc import Hashable

import numpy as np

from monai.config import KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.utils import convert_to_tensor
from monai.utils.misc import ensure_tuple

from ..transform import MapTransform
from ..transform import MapTransform, RandomizableTransform
from .array import CutMix, CutOut, MixUp

__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]


class MixUpd(MapTransform):
class MixUpd(MapTransform, RandomizableTransform):
"""
Dictionary-based version :py:class:`monai.transforms.MixUp`.

@@ -31,18 +38,24 @@ class MixUpd(MapTransform):
def __init__(
self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
) -> None:
super().__init__(keys, allow_missing_keys)
MapTransform.__init__(self, keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd:
super().set_random_state(seed, state)
self.mixup.set_random_state(seed, state)
return self

def __call__(self, data):
self.mixup.randomize()
result = dict(data)
for k in self.keys:
result[k] = self.mixup.apply(data[k])
return result
d = dict(data)
# all the keys share the same random state
self.mixup.randomize(None)
for k in self.key_iterator(self.keys):
d[k] = self.mixup(data[k], randomize=False)
return d
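Since `randomize()` now runs once per call and every key is transformed with `randomize=False`, all keys are mixed with the same weights and permutation. A sketch mirroring the updated test below:

```python
import torch

from monai.transforms import MixUpd

t = torch.rand(6, 3, 32, 32)
sample = {"a": t, "b": t}

mixupd = MixUpd(keys=["a", "b"], batch_size=6)
mixupd.set_random_state(seed=0)
out = mixupd(sample)

# identical inputs plus shared random parameters give identical outputs
assert torch.allclose(out["a"], out["b"])
```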


class CutMixd(MapTransform):
class CutMixd(MapTransform, RandomizableTransform):
"""
Dictionary-based version :py:class:`monai.transforms.CutMix`.

@@ -63,17 +76,27 @@ def __init__(
self.mixer = CutMix(batch_size, alpha)
self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []

def __call__(self, data):
self.mixer.randomize()
result = dict(data)
for k in self.keys:
result[k] = self.mixer.apply(data[k])
for k in self.label_keys:
result[k] = self.mixer.apply_on_labels(data[k])
return result
def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd:
super().set_random_state(seed, state)
self.mixer.set_random_state(seed, state)
return self


class CutOutd(MapTransform):
def __call__(self, data):
d = dict(data)
first_key: Hashable = self.first_key(d)
if first_key == ():
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
self.mixer.randomize(d[first_key])
for key, label_key in self.key_iterator(self.keys, self.label_keys):
ret = self.mixer(data[key], data.get(label_key, None), randomize=False)
d[key] = ret[0]
if label_key in d:
d[label_key] = ret[1]
return d
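Here `key_iterator` pairs each image key with its label key, and both are transformed with parameters drawn once from the first key. A usage sketch; the key names are illustrative and the call mirrors the updated test:

```python
import torch

from monai.transforms import CutMixd

data = {
    "img": torch.rand(6, 3, 32, 32),
    "seg": torch.randint(0, 2, (6, 1, 32, 32)).float(),
}

cutmixd = CutMixd(["img"], 6, label_keys=["seg"])  # keys, batch_size, label_keys
cutmixd.set_random_state(seed=0)

out = cutmixd(data)  # "img" and "seg" share one set of random parameters
```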


class CutOutd(MapTransform, RandomizableTransform):
"""
Dictionary-based version :py:class:`monai.transforms.CutOut`.

@@ -85,11 +108,15 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo
self.cutout = CutOut(batch_size)

def __call__(self, data):
result = dict(data)
self.cutout.randomize()
d = dict(data)
first_key: Hashable = self.first_key(d)
if first_key == ():
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
self.cutout.randomize(d[first_key])
for k in self.keys:
result[k] = self.cutout(data[k])
return result
d[k] = self.cutout(data[k])
return d


MixUpD = MixUpDict = MixUpd
62 changes: 35 additions & 27 deletions tests/test_regularization.py
@@ -13,29 +13,31 @@

import unittest

import numpy as np
import torch

from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd
from monai.utils import set_determinism
from tests.utils import assert_allclose


@unittest.skip("Mixup is non-deterministic. Skip it temporarily")
class TestMixup(unittest.TestCase):

def setUp(self) -> None:
set_determinism(seed=0)

def tearDown(self) -> None:
set_determinism(None)

def test_mixup(self):
for dims in [2, 3]:
shape = (6, 3) + (32,) * dims
sample = torch.rand(*shape, dtype=torch.float32)
mixup = MixUp(6, 1.0)
mixup.set_random_state(seed=0)
output = mixup(sample)
np.random.seed(0)
# simulate the randomize() of transform
np.random.random()
weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
perm = np.random.permutation(6)
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10)))
mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
expected = mixweight * sample + (1 - mixweight) * sample[perm, ...]
assert_allclose(output, expected, type_test=False, atol=1e-7)

with self.assertRaises(ValueError):
MixUp(6, -0.5)
@@ -53,27 +55,32 @@ def test_mixupd(self):
t = torch.rand(*shape, dtype=torch.float32)
sample = {"a": t, "b": t}
mixup = MixUpd(["a", "b"], 6)
mixup.set_random_state(seed=0)
output = mixup(sample)
self.assertTrue(torch.allclose(output["a"], output["b"]))
np.random.seed(0)
# simulate the randomize() of transform
np.random.random()
weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
perm = np.random.permutation(6)
self.assertEqual(output["a"].shape, sample["a"].shape)
mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
expected = mixweight * sample["a"] + (1 - mixweight) * sample["a"][perm, ...]
assert_allclose(output["a"], expected, type_test=False, atol=1e-7)
assert_allclose(output["a"], output["b"], type_test=False, atol=1e-7)
# self.assertTrue(torch.allclose(output["a"], output["b"]))

with self.assertRaises(ValueError):
MixUpd(["k1", "k2"], 6, -0.5)


@unittest.skip("CutMix is non-deterministic. Skip it temporarily")
class TestCutMix(unittest.TestCase):

def setUp(self) -> None:
set_determinism(seed=0)

def tearDown(self) -> None:
set_determinism(None)

def test_cutmix(self):
for dims in [2, 3]:
shape = (6, 3) + (32,) * dims
sample = torch.rand(*shape, dtype=torch.float32)
cutmix = CutMix(6, 1.0)
cutmix.set_random_state(seed=0)
output = cutmix(sample)
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))
@@ -85,30 +92,31 @@ def test_cutmixd(self):
label = torch.randint(0, 1, shape)
sample = {"a": t, "b": t, "lbl1": label, "lbl2": label}
cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2"))
cutmix.set_random_state(seed=123)
output = cutmix(sample)
# croppings are different on each application
self.assertTrue(not torch.allclose(output["a"], output["b"]))
# but mixing of labels is not affected by it
self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"]))


@unittest.skip("CutOut is non-deterministic. Skip it temporarily")
class TestCutOut(unittest.TestCase):

def setUp(self) -> None:
set_determinism(seed=0)

def tearDown(self) -> None:
set_determinism(None)

def test_cutout(self):
for dims in [2, 3]:
shape = (6, 3) + (32,) * dims
sample = torch.rand(*shape, dtype=torch.float32)
cutout = CutOut(6, 1.0)
cutout.set_random_state(seed=123)
output = cutout(sample)
np.random.seed(123)
# simulate the randomize() of transform
np.random.random()
weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
perm = np.random.permutation(6)
coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]]
assert_allclose(weight, cutout._params[0])
assert_allclose(perm, cutout._params[1])
self.assertSequenceEqual(coords, cutout._params[2])
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10)))


if __name__ == "__main__":