forked from facebookresearch/ClassyVision
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mixup data augmentation (facebookresearch#469)
Summary: Pull Request resolved: facebookresearch#469. This diff implements the mixup data augmentation from the paper `mixup: Beyond Empirical Risk Minimization` (https://arxiv.org/abs/1710.09412). Empirically, it is much faster to apply the mixup transform on the GPU than on the CPU. Results (accuracy gain): 1.0% with 135 training epochs; 1.3% with 270 training epochs. [TODO]: fix the accuracy meter during training phases. Reviewed By: mannatsingh. Differential Revision: D20911088. fbshipit-source-id: 339c1939eaa224125a072fe971a2e1ce958ca26a
- Loading branch information
1 parent
c635e82
commit 3539f57
Showing
7 changed files
with
144 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Any, Dict, Optional | ||
|
||
import torch | ||
from classy_vision.generic.util import convert_to_one_hot | ||
from torch.distributions.beta import Beta | ||
|
||
|
||
class MixupTransform:
    """
    Batch-level mixup data augmentation from the paper
    "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412).

    Each example of a batch is blended with another example of the same batch,
    chosen by a random permutation. A single coefficient drawn from
    Beta(alpha, alpha) is shared by the whole batch and is applied to both the
    inputs and the targets.
    """

    def __init__(self, alpha: float, num_classes: Optional[int] = None):
        """
        Args:
            alpha: the hyperparameter of Beta distribution used to sample mixup
                coefficient.
            num_classes: number of classes in the dataset. Only required when
                targets are given as a 1D tensor, which is expanded to one-hot
                rows before mixing.
        """
        self.alpha = alpha
        self.num_classes = num_classes

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply mixup to one batch.

        Args:
            sample: the batch data. Must hold tensors under the keys "input"
                and "target", both indexed by batch along dim 0. The target is
                either 1D (converted to one-hot using num_classes) or 2D.

        Returns:
            The same dict object that was passed in; its "input" and "target"
            entries are replaced by the mixed tensors.

        Raises:
            AssertionError: if the target is 1D and num_classes was not given,
                or if the target is neither 1D nor 2D.
        """
        target = sample["target"]
        # Raise explicitly rather than with `assert` statements: those are
        # stripped under `python -O`, which would let invalid targets through.
        # AssertionError is kept so existing callers see the same exception.
        if target.ndim == 1:
            if self.num_classes is None:
                raise AssertionError("num_classes is expected for 1D target")
            sample["target"] = convert_to_one_hot(
                target.view(-1, 1), self.num_classes
            )
        elif target.ndim != 2:
            raise AssertionError("target tensor shape must be 1D or 2D")

        # One coefficient per batch, placed on the target's device.
        c = Beta(self.alpha, self.alpha).sample().to(device=sample["target"].device)
        # The same permutation is used for inputs and targets so that each
        # mixed input keeps the matching mixed label.
        permuted_indices = torch.randperm(sample["target"].shape[0])
        for key in ["input", "target"]:
            # Indexing only dim 0 is equivalent to `[permuted_indices, :]` for
            # tensors of rank >= 2 and additionally supports rank-1 inputs.
            sample[key] = c * sample[key] + (1.0 - c) * sample[key][permuted_indices]

        return sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import unittest | ||
|
||
import torch | ||
from classy_vision.dataset.transforms.mixup import MixupTransform | ||
|
||
|
||
class DatasetTransformsMixupTest(unittest.TestCase):
    """Unit tests for MixupTransform on single-label and multi-label batches."""

    def test_mixup_transform_single_label(self):
        """1D class-index targets become mixed (batch, num_classes) rows."""
        alpha = 2.0
        num_classes = 3
        mixup_transform = MixupTransform(alpha, num_classes)
        sample = {
            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
        }
        # The transform writes its results back into the dict it is given, so
        # the expected shape has to be recorded before the call; comparing
        # `sample` with the returned dict afterwards can never fail.
        input_shape = sample["input"].shape
        sample_mixup = mixup_transform(sample)
        self.assertEqual(sample_mixup["input"].shape, input_shape)
        self.assertEqual(sample_mixup["target"].shape, (4, num_classes))
        # Mixing two one-hot rows with weights c and (1 - c) keeps a row sum of 1.
        self.assertTrue(
            torch.allclose(sample_mixup["target"].sum(dim=1), torch.ones(4))
        )

    def test_mixup_transform_single_label_missing_num_classes(self):
        """1D targets cannot be expanded to one-hot without num_classes."""
        alpha = 2.0
        mixup_transform = MixupTransform(alpha, None)
        sample = {
            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
            "target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
        }
        with self.assertRaises(AssertionError):
            mixup_transform(sample)

    def test_mixup_transform_multi_label(self):
        """2D multi-hot targets are mixed directly, without num_classes."""
        alpha = 2.0
        mixup_transform = MixupTransform(alpha, None)
        sample = {
            "input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
            "target": torch.as_tensor(
                [[1, 0, 0, 0], [0, 1, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]],
                dtype=torch.int32,
            ),
        }
        # Keep references to the pre-transform tensors: the dict entries are
        # replaced by the transform.
        input_shape = sample["input"].shape
        original_target = sample["target"]
        sample_mixup = mixup_transform(sample)
        self.assertEqual(sample_mixup["input"].shape, input_shape)
        self.assertEqual(sample_mixup["target"].shape, original_target.shape)
        # Blending the batch with a permutation of itself preserves the total
        # weight assigned to each label across the batch.
        self.assertTrue(
            torch.allclose(
                sample_mixup["target"].sum(dim=0),
                original_target.sum(dim=0).to(sample_mixup["target"].dtype),
            )
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters