
Adding auto-augment and random-erase in the training scripts.
datumbox committed on Jan 1, 2021 · parent: 5198385 · commit: e4d130f
Showing 1 changed file with 23 additions and 14 deletions: references/classification/train.py
@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path


-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
+        trans = [
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+        ]
+        if args.auto_augment is not None:
+            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
+            trans.append(transforms.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            normalize,
+        ])
+        if args.random_erase > 0:
+            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-        if cache_dataset:
+            transforms.Compose(trans))
+        if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)

     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
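
For context, when both new options are enabled, the else branch above builds a pipeline equivalent to the following standalone sketch (the policy name imagenet and the erase probability 0.1 are illustrative values, not defaults from this commit). The ordering matters: AutoAugment operates on the PIL image, so it is inserted before ToTensor, while RandomErasing expects a tensor, so it is appended after normalization.

    from torchvision import transforms

    # Standard ImageNet normalization, as defined earlier in train.py.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Roughly what --auto-augment imagenet --random-erase 0.1 produces.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        # Applied to the PIL image, hence placed before ToTensor.
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("imagenet")),
        transforms.ToTensor(),
        normalize,
        # Applied to the normalized tensor, hence placed last.
        transforms.RandomErasing(p=0.1),
    ])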
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
                 transforms.ToTensor(),
                 normalize,
             ]))
-        if cache_dataset:
+        if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)

     print("Creating data loaders")
-    if distributed:
+    if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
@@ -155,8 +163,7 @@ def main(args):

     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
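
A side note on the surrounding context lines: when the script runs with --distributed, the DistributedSampler shards the dataset across processes, and the training loop must reseed it each epoch so the shards are reshuffled. A minimal sketch of that interaction, using standard PyTorch APIs rather than code from this commit (dataset construction and hyperparameters are placeholders):

    import torch

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=32, sampler=train_sampler,
        num_workers=4, pin_memory=True)
    for epoch in range(num_epochs):
        # Without set_epoch, every epoch would iterate the same shard ordering.
        train_sampler.set_epoch(epoch)
        for images, targets in data_loader:
            pass  # forward/backward pass goes here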
@@ -283,6 +290,8 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
+    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')

     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
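
With both arguments registered, the new augmentations can be switched on from the command line. A hypothetical invocation (the dataset path and values are illustrative, and --data-path/--model are the script's pre-existing flags; the --auto-augment string must name a torchvision AutoAugmentPolicy, i.e. imagenet, cifar10, or svhn):

    python train.py --data-path /path/to/imagenet --model resnet50 \
        --auto-augment imagenet --random-erase 0.1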
