
Support ProxylessNAS with NNI NAS APIs #1863

Merged: 64 commits, merged Feb 10, 2020

Commits (showing changes from 56 of 64 commits):
bd7c0f0
update doc
zhangql08hit Nov 5, 2019
d9f3afb
update
zhangql08hit Nov 5, 2019
b5c295c
update
zhangql08hit Nov 5, 2019
8f9c7bc
update
zhangql08hit Nov 5, 2019
0e7f6b9
update
zhangql08hit Nov 5, 2019
c7c218f
Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-…
zhangql08hit Nov 11, 2019
bccb536
init commit
zhangql08hit Nov 13, 2019
5647dd0
update
zhangql08hit Nov 14, 2019
5b7cb43
update
zhangql08hit Nov 14, 2019
366b793
debug
zhangql08hit Nov 16, 2019
088a56c
update
zhangql08hit Nov 17, 2019
0a47184
update
zhangql08hit Nov 17, 2019
52dd740
update
zhangql08hit Nov 18, 2019
95b1974
update
zhangql08hit Nov 18, 2019
44145e4
update
zhangql08hit Nov 18, 2019
a0febf9
update
zhangql08hit Nov 19, 2019
7b92588
Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-…
zhangql08hit Nov 19, 2019
cc8a1fb
update
zhangql08hit Nov 19, 2019
dacbdf7
update
zhangql08hit Nov 19, 2019
007e043
update
zhangql08hit Nov 20, 2019
098fe3d
fix bug
zhangql08hit Dec 10, 2019
ca9ec6c
update
zhangql08hit Dec 11, 2019
3d2159e
update
zhangql08hit Dec 11, 2019
181f9c0
update
zhangql08hit Dec 12, 2019
5578542
update
zhangql08hit Dec 12, 2019
55c75f5
update
zhangql08hit Dec 12, 2019
5a403ec
update
zhangql08hit Dec 12, 2019
3e2ee56
update
zhangql08hit Dec 12, 2019
ed27d47
update
zhangql08hit Dec 12, 2019
80eafc4
update
zhangql08hit Dec 12, 2019
b8e29e8
update
zhangql08hit Dec 12, 2019
1354025
update
zhangql08hit Dec 13, 2019
4b611db
update
zhangql08hit Dec 13, 2019
a624c12
fix bug
zhangql08hit Dec 13, 2019
393d837
update
zhangql08hit Dec 13, 2019
f768b5a
update
zhangql08hit Dec 13, 2019
8bc69a8
update
zhangql08hit Dec 13, 2019
640103d
update
zhangql08hit Dec 13, 2019
810ea95
update
zhangql08hit Dec 13, 2019
b890fce
update
zhangql08hit Dec 13, 2019
5996d4f
update
zhangql08hit Dec 13, 2019
51128bb
update
zhangql08hit Dec 16, 2019
3b3aba4
update
zhangql08hit Dec 16, 2019
14f3f1d
add retrain
zhangql08hit Dec 16, 2019
346e5a4
update
zhangql08hit Dec 16, 2019
5ff1ccd
Merge branch 'master' of github.com:Microsoft/nni into dev-plnas
zhangql08hit Dec 16, 2019
0eddd52
update
zhangql08hit Dec 16, 2019
8d499ec
retrain tested
zhangql08hit Dec 17, 2019
cb0c2e9
update
zhangql08hit Dec 18, 2019
38fab2d
update
zhangql08hit Dec 18, 2019
eab6e22
update
zhangql08hit Dec 19, 2019
a7f59f0
update
zhangql08hit Dec 19, 2019
8ef5f6d
add doc string
zhangql08hit Dec 22, 2019
477af83
update
zhangql08hit Dec 22, 2019
aab28e2
add docstring
zhangql08hit Dec 23, 2019
d9a778d
update
zhangql08hit Dec 23, 2019
e9c7603
add doc
zhangql08hit Dec 23, 2019
4f7c662
resolve comments
QuanluZhang Dec 23, 2019
0b8cb1e
update
QuanluZhang Dec 24, 2019
b462b25
Merge branch 'master' of https://github.com/microsoft/nni into dev-plnas
Feb 10, 2020
927ab9e
update doc
Feb 10, 2020
e32bb72
update doc
Feb 10, 2020
61d2944
update toctree
Feb 10, 2020
fba009e
fix broken link
Feb 10, 2020
208 changes: 208 additions & 0 deletions examples/nas/proxylessnas/datasets.py
@@ -0,0 +1,208 @@
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import os
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def get_split_list(in_dim, child_num):
in_dim_list = [in_dim // child_num] * child_num
for _i in range(in_dim % child_num):
in_dim_list[_i] += 1
return in_dim_list

class DataProvider:
VALID_SEED = 0 # random seed for the validation set

    @staticmethod
    def name():
        """ Return the name of the dataset """
        raise NotImplementedError

    @property
    def data_shape(self):
        """ Return the shape of one data entry as a python list """
        raise NotImplementedError

    @property
    def n_classes(self):
        """ Return the number of classes as an `int` """
        raise NotImplementedError

    @property
    def save_path(self):
        """ Local path where the dataset is stored """
        raise NotImplementedError

    @property
    def data_url(self):
        """ URL from which the dataset can be downloaded """
        raise NotImplementedError

@staticmethod
def random_sample_valid_set(train_labels, valid_size, n_classes):
train_size = len(train_labels)
assert train_size > valid_size

g = torch.Generator()
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set
rand_indexes = torch.randperm(train_size, generator=g).tolist()

train_indexes, valid_indexes = [], []
per_class_remain = get_split_list(valid_size, n_classes)

for idx in rand_indexes:
label = train_labels[idx]
if isinstance(label, float):
label = int(label)
elif isinstance(label, np.ndarray):
label = np.argmax(label)
else:
assert isinstance(label, int)
if per_class_remain[label] > 0:
valid_indexes.append(idx)
per_class_remain[label] -= 1
else:
train_indexes.append(idx)
return train_indexes, valid_indexes


class ImagenetDataProvider(DataProvider):

def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None,
n_worker=32, resize_scale=0.08, distort_color=None):

self._save_path = save_path
train_transforms = self.build_train_transform(distort_color, resize_scale)
train_dataset = datasets.ImageFolder(self.train_path, train_transforms)

if valid_size is not None:
if isinstance(valid_size, float):
valid_size = int(valid_size * len(train_dataset))
else:
assert isinstance(valid_size, int), 'invalid valid_size: %s' % valid_size
train_indexes, valid_indexes = self.random_sample_valid_set(
[cls for _, cls in train_dataset.samples], valid_size, self.n_classes,
)
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

valid_dataset = datasets.ImageFolder(self.train_path, transforms.Compose([
transforms.Resize(self.resize_value),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
self.normalize,
]))

self.train = torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
self.train = torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None

self.test = torch.utils.data.DataLoader(
datasets.ImageFolder(self.valid_path, transforms.Compose([
transforms.Resize(self.resize_value),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
self.normalize,
])), batch_size=test_batch_size, shuffle=False, num_workers=n_worker, pin_memory=True,
)

if self.valid is None:
self.valid = self.test

@staticmethod
def name():
return 'imagenet'

@property
def data_shape(self):
return 3, self.image_size, self.image_size # C, H, W

@property
def n_classes(self):
return 1000

@property
def save_path(self):
if self._save_path is None:
self._save_path = '/dataset/imagenet'
return self._save_path

@property
def data_url(self):
raise ValueError('unable to download ImageNet')

@property
def train_path(self):
return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        # use save_path (not _save_path) so the default dataset root also applies here
        return os.path.join(self.save_path, 'val')

@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

def build_train_transform(self, distort_color, resize_scale):
print('Color jitter: %s' % distort_color)
if distort_color == 'strong':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif distort_color == 'normal':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if color_transform is None:
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
self.normalize,
])
else:
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
color_transform,
transforms.ToTensor(),
self.normalize,
])
return train_transforms

@property
def resize_value(self):
return 256

@property
def image_size(self):
return 224
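
For orientation, a minimal sketch of driving this data provider on its own, assuming an ImageNet-style directory (placeholder path) with train/ and val/ subfolders; the 50,000-image validation split and worker count are illustrative, not values taken from this PR:

    from datasets import ImagenetDataProvider

    # carve a class-balanced validation split out of the training set
    # (sampling is seeded by DataProvider.VALID_SEED, so the split is reproducible)
    provider = ImagenetDataProvider(save_path='/path/to/imagenet',  # placeholder dataset root
                                    train_batch_size=256,
                                    test_batch_size=500,
                                    valid_size=50000,
                                    n_worker=8,
                                    resize_scale=0.08,
                                    distort_color='normal')

    for images, labels in provider.train:   # plain torch DataLoader iteration
        pass  # training step goes here
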
109 changes: 109 additions & 0 deletions examples/nas/proxylessnas/main.py
@@ -0,0 +1,109 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
import logging
from argparse import ArgumentParser
import torch
import datasets

from putils import get_parameters
from model import SearchMobileNet
from nni.nas.pytorch.proxylessnas import ProxylessNasTrainer
from retrain import Retrain

logger = logging.getLogger('nni_proxylessnas')

if __name__ == "__main__":
parser = ArgumentParser("proxylessnas")
# configurations of the model
parser.add_argument("--n_cell_stages", default='4,4,4,4,4,1', type=str)
parser.add_argument("--stride_stages", default='2,2,2,1,2,1', type=str)
parser.add_argument("--width_stages", default='24,40,80,96,192,320', type=str)
parser.add_argument("--bn_momentum", default=0.1, type=float)
parser.add_argument("--bn_eps", default=1e-3, type=float)
parser.add_argument("--dropout_rate", default=0, type=float)
parser.add_argument("--no_decay_keys", default='bn', type=str, choices=[None, 'bn', 'bn#bias'])
# configurations of imagenet dataset
parser.add_argument("--data_path", default='/data/ssd1/v-yugzh/imagenet/', type=str)
#parser.add_argument("--data_path", default='/mnt/v-yugzh/imagenet/', type=str)
parser.add_argument("--train_batch_size", default=256, type=int)
parser.add_argument("--test_batch_size", default=500, type=int)
parser.add_argument("--n_worker", default=32, type=int)
parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
# configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
# configurations for retrain
parser.add_argument("--exported_arch_path", default=None, type=str)

args = parser.parse_args()
if args.train_mode == 'retrain' and args.exported_arch_path is None:
logger.error('When --train_mode is retrain, --exported_arch_path must be specified.')
sys.exit(-1)

model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')],
n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')],
stride_stages=[int(i) for i in args.stride_stages.split(',')],
n_classes=1000,
dropout_rate=args.dropout_rate,
bn_param=(args.bn_momentum, args.bn_eps))
logger.info('SearchMobileNet model create done')
model.init_model()
logger.info('SearchMobileNet model init done')

# move network to GPU if available
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')

logger.info('Creating data provider...')
data_provider = datasets.ImagenetDataProvider(save_path=args.data_path,
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
valid_size=None,
n_worker=args.n_worker,
resize_scale=args.resize_scale,
distort_color=args.distort_color)
logger.info('Creating data provider done')

    momentum, nesterov = 0.9, True
    # parameters matched by no_decay_keys (e.g. batch norm) are exempt from weight decay
    if args.no_decay_keys:
        keys = args.no_decay_keys
        optimizer = torch.optim.SGD([
            {'params': get_parameters(model, keys, mode='exclude'), 'weight_decay': 4e-5},
            {'params': get_parameters(model, keys, mode='include'), 'weight_decay': 0},
        ], lr=0.05, momentum=momentum, nesterov=nesterov)
    else:
        optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)

if args.train_mode == 'search':
# this is architecture search
logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model,
model_optim=optimizer,
train_loader=data_provider.train,
valid_loader=data_provider.valid,
device=device,
warmup=True,
ckpt_path=args.checkpoint_path,
arch_path=args.arch_path)

logger.info('Start to train with ProxylessNasTrainer...')
trainer.train()
logger.info('Training done')
trainer.export(args.arch_path)
logger.info('Best architecture exported in %s', args.arch_path)
elif args.train_mode == 'retrain':
# this is retrain
from nni.nas.pytorch.fixed import apply_fixed_architecture
assert os.path.isfile(args.exported_arch_path), \
"exported_arch_path {} should be a file.".format(args.exported_arch_path)
apply_fixed_architecture(model, args.exported_arch_path, device=device)
trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
trainer.run()
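
For orientation, the example might be launched roughly as follows; the ImageNet path is a placeholder, and the flags are those defined by the argument parser above:

    # architecture search: runs ProxylessNasTrainer on SearchMobileNet and exports the chosen architecture to --arch_path
    python main.py --train_mode search --data_path /path/to/imagenet --arch_path ./arch_path.pt

    # retrain: applies the exported architecture via apply_fixed_architecture and trains it for 300 epochs
    python main.py --train_mode retrain --data_path /path/to/imagenet --exported_arch_path ./arch_path.pt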