diff --git a/references/classification/train.py b/references/classification/train.py
index 77f3127782e..47a7e5955e6 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path
 
 
-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,20 +88,28 @@ def load_data(traindir, valdir, cache_dataset, distributed):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
+        trans = [
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+        ]
+        if args.auto_augment is not None:
+            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
+            trans.append(transforms.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            normalize,
+        ])
+        if args.random_erase > 0:
+            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-        if cache_dataset:
+            transforms.Compose(trans))
+        if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
@@ -109,7 +117,7 @@ def load_data(traindir, valdir, cache_dataset, distributed):
 
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
             transforms.ToTensor(),
             normalize,
         ]))
-        if cache_dataset:
+        if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)
 
     print("Creating data loaders")
-    if distributed:
+    if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
@@ -155,8 +163,7 @@ def main(args):
 
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
@@ -283,6 +290,8 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
+    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
 
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',