mixup data augmentation (facebookresearch#469)
Summary:
Pull Request resolved: facebookresearch#469

This diff implements the mixup data augmentation from the paper `mixup: Beyond Empirical Risk Minimization` (https://arxiv.org/abs/1710.09412).

Empirically, applying the mixup transform on GPU is much faster than applying it on CPU, so the transform is invoked in `train_step` after the batch has been copied to the GPU.
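Recall the core operation from the paper: given two examples (x_i, y_i) and (x_j, y_j), mixup trains on λ·x_i + (1 - λ)·x_j with label λ·y_i + (1 - λ)·y_j, where λ is sampled from Beta(α, α). The transform below uses a single λ per batch and pairs each example with a randomly permuted example from the same batch.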

# Results
Accuracy gain:
- +1.0% with 135 training epochs
- +1.3% with 270 training epochs

[TODO]: fix the accuracy meter during the training phase.

Reviewed By: mannatsingh

Differential Revision: D20911088

fbshipit-source-id: 339c1939eaa224125a072fe971a2e1ce958ca26a
stephenyan1231 authored and facebook-github-bot committed Apr 11, 2020
1 parent c635e82 commit 3539f57
Showing 7 changed files with 144 additions and 8 deletions.
48 changes: 48 additions & 0 deletions classy_vision/dataset/transforms/mixup.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch
from classy_vision.generic.util import convert_to_one_hot
from torch.distributions.beta import Beta


class MixupTransform:
"""
This implements the mixup data augmentation in the paper
"mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412)
"""

def __init__(self, alpha: float, num_classes: Optional[int] = None):
"""
Args:
alpha: the hyperparameter of the Beta distribution used to sample the mixup
coefficient.
num_classes: number of classes in the dataset.
"""
self.alpha = alpha
self.num_classes = num_classes

def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""
Args:
sample: the batch data.
"""
if sample["target"].ndim == 1:
assert self.num_classes is not None, "num_classes is expected for 1D target"
sample["target"] = convert_to_one_hot(
sample["target"].view(-1, 1), self.num_classes
)
else:
assert sample["target"].ndim == 2, "target tensor shape must be 1D or 2D"

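# sample a single mixing coefficient from Beta(alpha, alpha) and mix every example
# (input and target) with a randomly permuted example from the same batch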
c = Beta(self.alpha, self.alpha).sample().to(device=sample["target"].device)
permuted_indices = torch.randperm(sample["target"].shape[0])
for key in ["input", "target"]:
sample[key] = c * sample[key] + (1.0 - c) * sample[key][permuted_indices, :]

return sample
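
A minimal usage sketch of the new transform (shapes and label values here are illustrative; the unit tests below exercise the same paths):

```python
import torch

from classy_vision.dataset.transforms.mixup import MixupTransform

mixup = MixupTransform(alpha=2.0, num_classes=3)
sample = {
    "input": torch.rand(4, 3, 224, 224),
    "target": torch.tensor([0, 1, 2, 2]),  # 1D integer labels are one-hot encoded first
}
mixed = mixup(sample)
# mixed["input"] keeps its shape; mixed["target"] is now a 4 x 3 soft-label tensor
```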
11 changes: 5 additions & 6 deletions classy_vision/generic/util.py
@@ -736,12 +736,11 @@ def maybe_convert_to_one_hot(target, model_output):
):
target = convert_to_one_hot(target.view(-1, 1), model_output.shape[1])

assert (target.shape == model_output.shape) and (
torch.min(target.eq(0) + target.eq(1)) == 1
), (
"Target must be one-hot/multi-label encoded and of the "
"same shape as model_output."
)
# targets are not necessarily a hard 0/1 encoding; they can be soft
# (i.e. fractional) in some cases, e.g. mixup labels
assert (
target.shape == model_output.shape
), "Target must be of the same shape as model_output."

return target

9 changes: 8 additions & 1 deletion classy_vision/losses/soft_target_cross_entropy_loss.py
@@ -10,6 +10,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from classy_vision.generic.util import convert_to_one_hot
from classy_vision.losses import ClassyLoss, register_loss


@@ -58,13 +59,19 @@ def from_config(cls, config: Dict[str, Any]) -> "SoftTargetCrossEntropyLoss":
def forward(self, output, target):
"""for N examples and C classes
- output: N x C these are raw outputs (without softmax/sigmoid)
- target: N x C corresponding targets
- target: N x C or N corresponding targets
Target elements set to ignore_index contribute 0 loss.
Samples where all entries are ignore_index do not contribute to the loss
reduction.
"""
# check if targets are provided as class integers
if target.ndim == 1:
assert (
output.shape[0] == target.shape[0]
), "SoftTargetCrossEntropyLoss requires output and target to have same batch size"
target = convert_to_one_hot(target.view(-1, 1), output.shape[1])
assert (
output.shape == target.shape
), "SoftTargetCrossEntropyLoss requires output and target to be same"
1 change: 0 additions & 1 deletion classy_vision/meters/accuracy_meter.py
@@ -145,7 +145,6 @@ def update(self, model_output, target, **kwargs):
for i, k in enumerate(self._topk):
self._curr_correct_predictions_k[i] += (
torch.gather(target, dim=1, index=pred[:, :k])
.long()
.max(dim=1)
.values.sum()
.item()
26 changes: 26 additions & 0 deletions classy_vision/tasks/classification_task.py
@@ -14,6 +14,7 @@
import torch
import torch.nn as nn
from classy_vision.dataset import ClassyDataset, build_dataset
from classy_vision.dataset.transforms.mixup import MixupTransform
from classy_vision.generic.distributed_util import (
all_reduce_mean,
barrier,
@@ -141,6 +142,7 @@ def __init__(self):
BroadcastBuffersMode.DISABLED
)
self.amp_args = None
self.mixup_transform = None
self.perf_log = []
self.last_batch = None
self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
@@ -326,6 +328,19 @@ def set_amp_args(self, amp_args: Optional[Dict[str, Any]]):
logging.info(f"AMP enabled with args {amp_args}")
return self

def set_mixup_transform(self, mixup_transform: Optional["MixupTransform"]):
"""Disable / enable mixup transform for data augmentation
Args::
mixup_transform: a callable object which performs mixup data augmentation
"""
self.mixup_transform = mixup_transform
if mixup_transform is None:
logging.info("mixup disabled")
else:
logging.info("mixup enabled")
return self

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
"""Instantiates a ClassificationTask from a configuration.
@@ -353,6 +368,13 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
meters = build_meters(config.get("meters", {}))
model = build_model(config["model"])

mixup_transform = None
if config.get("mixup") is not None:
assert "alpha" in config["mixup"], "key alpha is missing in mixup dict"
mixup_transform = MixupTransform(
config["mixup"]["alpha"], config["mixup"].get("num_classes")
)

# hooks config is optional
hooks_config = config.get("hooks")
hooks = []
@@ -371,6 +393,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
.set_optimizer(optimizer)
.set_meters(meters)
.set_amp_args(amp_args)
.set_mixup_transform(mixup_transform)
.set_distributed_options(
broadcast_buffers_mode=BroadcastBuffersMode[
config.get("broadcast_buffers", "disabled").upper()
@@ -775,6 +798,9 @@ def train_step(self):
for key, value in sample.items():
sample[key] = recursive_copy_to_gpu(value, non_blocking=True)

if self.mixup_transform is not None:
sample = self.mixup_transform(sample)

with torch.enable_grad():
# Forward pass
output = self.model(sample["input"])
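
For config-driven training, mixup is enabled through a top-level `mixup` section in the task config; `alpha` is required and `num_classes` is only needed when the dataset yields 1D integer targets. A sketch with illustrative values:

```python
# hypothetical task config fragment; the alpha / num_classes values are examples only
task_config = {
    # ... dataset / model / loss / optimizer sections as usual ...
    "mixup": {
        "alpha": 0.2,         # parameter of the Beta(alpha, alpha) mixing distribution
        "num_classes": 1000,  # used to one-hot encode 1D integer targets
    },
}
```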
50 changes: 50 additions & 0 deletions test/dataset_transforms_mixup_test.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch
from classy_vision.dataset.transforms.mixup import MixupTransform


class DatasetTransformsMixupTest(unittest.TestCase):
def test_mixup_transform_single_label(self):
alpha = 2.0
num_classes = 3
mixup_transform = MixupTransform(alpha, num_classes)
sample = {
"input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
"target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
}
sample_mixup = mixup_transform(sample)
self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
self.assertTrue(sample_mixup["target"].shape[0] == 4)
self.assertTrue(sample_mixup["target"].shape[1] == 3)

def test_mixup_transform_single_label_missing_num_classes(self):
alpha = 2.0
mixup_transform = MixupTransform(alpha, None)
sample = {
"input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
"target": torch.as_tensor([0, 1, 2, 2], dtype=torch.int32),
}
with self.assertRaises(Exception):
mixup_transform(sample)

def test_mixup_transform_multi_label(self):
alpha = 2.0
mixup_transform = MixupTransform(alpha, None)
sample = {
"input": torch.rand(4, 3, 224, 224, dtype=torch.float32),
"target": torch.as_tensor(
[[1, 0, 0, 0], [0, 1, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]],
dtype=torch.int32,
),
}
sample_mixup = mixup_transform(sample)
self.assertTrue(sample["input"].shape == sample_mixup["input"].shape)
self.assertTrue(sample["target"].shape == sample_mixup["target"].shape)
7 changes: 7 additions & 0 deletions test/losses_soft_target_cross_entropy_loss_test.py
@@ -47,6 +47,13 @@ def test_soft_target_cross_entropy(self):
targets = torch.tensor([[-1, 0, 0, 0, 1]])
self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)

def test_soft_target_cross_entropy_integer_label(self):
config = self._get_config()
crit = SoftTargetCrossEntropyLoss.from_config(config)
outputs = self._get_outputs()
targets = torch.tensor([4])
self.assertAlmostEqual(crit(outputs, targets).item(), 5.01097918)

def test_unnormalized_soft_target_cross_entropy(self):
config = {
"name": "soft_target_cross_entropy",
