Unofficial and partial implementation of Fast AutoAugment in PyTorch.
- Fast AutoAugment (hereafter FAA) finds the optimal set of data augmentation operations via density matching using Bayesian optimization.
- FAA delivers comparable performance to AutoAugment but in a much shorter period of time.
- Unlike AutoAugment that discretizes the search space, FAA can handle continuous search space directly.
$ git clone https://github.com/junkwhinger/fastautoaugment_jsh.git
cd fastautoaugment_jsh
pip install -r requirements.txt
You can train or test the model with the baseline augmentation or with the optimal augmentation policies found by FAA using the following commands.
# Baseline
python train.py --model_dir experiments/baseline --eval_only
# Fast AutoAugment
python train.py --model_dir experiments/fastautoaugment --eval_only
# Baseline
python train.py --model_dir experiments/baseline
# Fast AutoAugment
python train.py --model_dir experiments/fastautoaugment
You can run Fast AutoAugment with the following commands. Note that the full search takes several hours.
- train_mode: train models on D_Ms for 5 splits (takes roughly 4.5 hours)
- bayesian_mode: run Bayesian optimization with HyperOpt to find the optimal policy (takes 3 hours)
- merge: aggregates the trials and combines the best policies from the splits. Writes the result to a file named `optimal_policy.json`. To use this policy for training, copy the file into your `experiments/fastautoaugment` folder.
# Train models on D_Ms & Bayesian Optimization & Merge
python search_fastautoaugment.py --train_mode --bayesian_mode
# Bayesian Optimization & Merge
python search_fastautoaugment.py --bayesian_mode
# Merge only
python search_fastautoaugment.py
Here are the checkpoints I made during the replication of the paper.
- for training and testing (baseline / fastautoaugment)
  - `experiments/baseline/best_model.torch`: a trained model for Baseline at epoch 200
  - `experiments/baseline/params.json`: a hyper-parameter set for Baseline
  - `experiments/baseline/train.log`: a training log for Baseline
- for FAA policy searching
  - `fastautoaugment/k0_t0_trials.pkl`: a pickled trial log for the 0th split and 0th search width
  - `fastautoaugment/model_k_0.torch`: a model file that trained on D_M[0]
  - `fastautoaugment/optimal_policy.json`: the optimal policy JSON file from the search
  - `fastautoaugment/params.json`: a hyper-parameter set for FAA
  - `fastautoaugment/train.log`: a training log for FAA
- Operation : an augmentation function (e.g. Cutout)
- Probability : (attribute of an operation) the chance that the operation is turned on. This value ranges from 0 to 1, 0 being always off, 1 always on.
- Magnitude : (attribute of an operation) the amount by which the operation transforms a given image. This value ranges from 0 to 1 and gets rescaled to the corresponding value range of its operation. For example, a magnitude of 0 for Rotate means rotating by -30 degrees.
- Sub-policy : a random sequence of operations. The length of a sub-policy is determined by the Search Width. For example, a sub-policy that has Cutout and Rotate can transform a given image in 4 different ways, since each of the two operations is independently applied or skipped according to its probability.
- Policy : a set of sub-policies. FAA aims to find the final policy by merging the best sub-policies found on each split of the train dataset.
- FAA attempts to find the probability and magnitude for the following 16 augmentation operations.
- ShearX, ShearY, TranslateX, TranslateY, Rotate, AutoContrast, Invert, Equalize, Solarize, Posterize, Contrast, Color, Brightness, Sharpness, Cutout, SamplePairing
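The probability/magnitude semantics above can be sketched in a few lines. The value ranges below are borrowed from the AutoAugment paper for illustration; the exact ranges used in this repository may differ.

```python
import random

# Hypothetical per-operation value ranges (illustrative, not the
# repository's confirmed numbers).
OP_RANGES = {
    "Rotate": (-30.0, 30.0),    # degrees
    "ShearX": (-0.3, 0.3),      # shear factor
    "Brightness": (0.1, 1.9),   # enhancement factor, 1.0 = identity
}

def scale_magnitude(op_name, magnitude):
    """Stretch a magnitude in [0, 1] onto the operation's own range."""
    lo, hi = OP_RANGES[op_name]
    return lo + magnitude * (hi - lo)

def is_applied(probability, rng=random):
    """An operation fires with the given probability, else is skipped."""
    return rng.random() < probability

print(scale_magnitude("Rotate", 0.0))       # -30.0 degrees
print(scale_magnitude("Rotate", 0.5))       # 0.0 degrees (no rotation)
print(scale_magnitude("Brightness", 0.5))   # 1.0 -> returns the original image
```

Note how a magnitude of 0.5 lands on the identity transform for operations whose range is symmetric around "no change"; this matters for interpreting the search results discussed below.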
- Inputs
- Step 1: Shuffle
- Step 2: Train
- Step 3: Explore-and-Exploit
- Step 4. Merge
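The four steps above can be sketched as a single search loop. Training and policy evaluation are stubbed out, and all function names, split logic, and trial counts here are illustrative assumptions, not the repository's actual API (the real implementation uses HyperOpt for the Explore-and-Exploit step).

```python
import random

K_SPLITS = 5    # number of splits of the train set
N_TRIALS = 20   # optimization trials per split (HyperOpt in the repo)
N_BEST = 2      # best policies kept per split at the Merge step

def train_model(d_m):
    """Step 2: train a model on D_M (stub)."""
    return {"trained_on": len(d_m)}

def evaluate_policy(model, d_a, policy):
    """Step 3: density matching -- loss of the model on augmented D_A (stub)."""
    return random.random()

def sample_policy(rng):
    """Draw one candidate sub-policy: (op, probability, magnitude)."""
    op = rng.choice(["Rotate", "Cutout", "Color"])
    return [(op, rng.random(), rng.random())]

def search(dataset, rng=random):
    final_policy = []
    for k in range(K_SPLITS):
        # Step 1: Shuffle -- split the fold into D_M (train) / D_A (search).
        # A real implementation would use stratified k-fold splits.
        d_m, d_a = dataset[::2], dataset[1::2]
        model = train_model(d_m)                       # Step 2: Train
        trials = [(evaluate_policy(model, d_a, p), p)  # Step 3: Explore-
                  for p in (sample_policy(rng)          #         and-Exploit
                            for _ in range(N_TRIALS))]
        trials.sort(key=lambda t: t[0])
        final_policy += [p for _, p in trials[:N_BEST]]  # Step 4: Merge
    return final_policy

policy = search(list(range(100)))
print(len(policy))  # K_SPLITS * N_BEST sub-policies in the merged policy
```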
Search: 7.5 GPU Hours on a single Tesla V100 16GB Memory machine (FAA in paper took 3.5 GPU Hours)
Model(CIFAR-10) | Baseline(paper) | Baseline(mine) | FAA(paper/direct) | FAA(mine/direct) |
---|---|---|---|---|
Wide-ResNet-40-2 | 5.3 | 5.6 | 3.7 | 5.5 |
- Failed to replicate the Baseline performance of the paper despite following the same hyper-parameter set.
  - While debugging the original code, I found some discrepancies regarding the dataset size that could have caused the issue (covered in depth in ETC).
  - Revision needed on `train.py` and `model/data_loader.py`.
- Failed to replicate the Fast AutoAugment performance. The improvement in Test Error that I gained via FAA (-0.1) is much smaller than the paper's result (-1.6).
  - Revision needed on `search_fastautoaugment.py`.
- The optimal policies I found appear to have a strong tendency to keep the given images unchanged as much as possible.
- The red dots mark the points with the lowest validation error.
- The Brightness, Contrast, Color, and Sharpness magnitudes are around 0.5, which maps to an enhancement factor of about 1, i.e. the operation returns the original image.
- TranslateX and TranslateY are given high probabilities, yet their magnitudes are around 0.5, making the resulting translations very subtle.
- AutoContrast, Invert, Solarize are given near zero probabilities.
- I chose a uniform distribution between 0 and 1 for the probability and magnitude of each operation. I wonder if a distribution that excludes the regions that barely change images would lead to a different result (e.g. Rotate restricted to -30 to -10 and +10 to +30 degrees).
- I did not include SamplePairing from the set of augmentation operations to optimize.
- I did not use GradualWarmupScheduler for training on the D_M splits (I did use it for training the Baseline and FAA final models).
- I did not use parallel or distributed training using ray or horovod.
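The search-space idea raised above (excluding the near-identity region of the prior) amounts to sampling from a union of intervals rather than from all of [0, 1]. A minimal sketch, with an illustrative band width that is my assumption, not a tested value:

```python
import random

def sample_uniform_magnitude(rng=random):
    """The prior used in the search: magnitude ~ Uniform(0, 1)."""
    return rng.random()

def sample_excluding_identity(rng=random, identity_band=(0.4, 0.6)):
    """Alternative prior: skip the band around 0.5 where symmetric
    operations like Rotate barely change the image (0.5 maps to 0
    degrees). Samples uniformly from [0, 0.4) U [0.6, 1.0), giving
    each half equal weight. The band width is an illustrative choice."""
    lo, hi = identity_band
    if rng.random() < 0.5:
        return rng.random() * lo              # [0, lo)
    return hi + rng.random() * (1.0 - hi)     # [hi, 1.0)

samples = [sample_excluding_identity() for _ in range(1000)]
print(all(m < 0.4 or m >= 0.6 for m in samples))  # True: band is never hit
```

With HyperOpt this could also be expressed as an `hp.choice` over two `hp.uniform` sub-spaces, keeping the rest of the search unchanged.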
- Testing: FAA official implementation
python train.py -c confs/wresnet40x2_cifar10_b512.yaml --aug fa_reduced_cifar10 --dataset cifar10
- It runs validation steps with the same 16 images every 10th epoch (AutoAugment set 7,325 images aside for validation).
- The images used in the validation phase are augmented with the optimal policies, contrary to my expectation that the validation dataset would NOT be augmented in a normal training loop.
- The image batches loaded from validloader are as follows:
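The behavior observed above follows directly from where the policy transform is wired in: if it is part of the validation dataset's pipeline, every validation batch is augmented too. A minimal sketch (names are illustrative; the official repo wires this through torchvision transforms):

```python
import random

class PolicyTransform:
    """Applies one randomly chosen sub-policy to each sample."""
    def __init__(self, policy, rng=random):
        self.policy = policy
        self.rng = rng
        self.calls = 0  # count how often augmentation actually ran

    def __call__(self, sample):
        sub_policy = self.rng.choice(self.policy)  # pick a sub-policy
        self.calls += 1
        return sample  # actual image operations elided in this sketch

policy = [[("Rotate", 0.7, 0.3)], [("Cutout", 0.4, 0.6)]]
valid_transform = PolicyTransform(policy)

# Simulate one validation pass over 16 samples: because the transform
# sits in the validation pipeline, all 16 images go through the policy.
batch = [valid_transform(x) for x in range(16)]
print(valid_transform.calls)  # 16
```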
- In the FAA paper, Algorithm 1 described on page 5 can be somewhat misleading.
- Junsik Hwang, junsik.whang@gmail.com
- Fast AutoAugment
- Paper: https://arxiv.org/abs/1905.00397
- Code: https://github.com/kakaobrain/fast-autoaugment
- GradualWarmupScheduler: https://github.com/ildoonet/pytorch-gradual-warmup-lr
- AutoAugment
- Wide Residual Network
- HyperOpt
- Official Documentation: http://hyperopt.github.io/hyperopt/
- Tutorials: https://medium.com/district-data-labs/parameter-tuning-with-hyperopt-faa86acdfdce
- FloydHub (Cloud GPU)
- Website: http://floydhub.com/