
Add MobileNetV3 Architecture in TorchVision #3182

Merged: 26 commits (Jan 5, 2021)

Commits
22b9b12
Partial implementation of the network architecture.
datumbox Dec 15, 2020
0cff6fb
Simplify implementation and add blocks.
datumbox Dec 16, 2020
fd54fdf
Refactoring the code to make it more readable.
datumbox Dec 16, 2020
834b185
Adding first conv layers.
datumbox Dec 16, 2020
1edd16b
Moving mobilenet.py to mobilenetv2.py
datumbox Dec 16, 2020
2f52f0d
Adding mobilenet.py for BC.
datumbox Dec 16, 2020
bb2ec9e
Extending ConvBNReLU for reuse.
datumbox Dec 16, 2020
e95ee5c
Moving mobilenet.py to mobilenetv2.py
datumbox Dec 16, 2020
2ebe8ba
Adding mobilenet.py for BC.
datumbox Dec 16, 2020
0c31a33
Extending ConvBNReLU for reuse.
datumbox Dec 16, 2020
db7522b
Reduce import scope on mobilenet to only the public and versioned cla…
datumbox Dec 16, 2020
fdbcec7
Merge branch 'refactoring/mobilenetv2_bc_move' into models/mobilenetv3
datumbox Dec 16, 2020
16f55f5
Further simplify by reusing MobileNetv2 methods.
datumbox Dec 16, 2020
8162fa4
Adding the remaining implementation of mobilenetv3.
datumbox Dec 16, 2020
8615585
Adding tests, docs and init methods.
datumbox Dec 16, 2020
8664fde
Refactoring and fixing formatter.
datumbox Dec 16, 2020
cfa20b7
Fixing type issues.
datumbox Dec 16, 2020
c189ae1
Using built-in Hardsigmoid and Hardswish.
datumbox Dec 16, 2020
5030435
Merge branch 'master' into models/mobilenetv3
datumbox Dec 17, 2020
9a758a8
Code review nits.
datumbox Dec 17, 2020
25f8b26
Putting inplace on Dropout.
datumbox Dec 17, 2020
5198385
Adding RMSprop support to train.py
datumbox Jan 1, 2021
e4d130f
Adding auto-augment and random-erase in the training scripts.
datumbox Jan 1, 2021
5d0a664
Merge branch 'master' into models/mobilenetv3
datumbox Jan 3, 2021
c0a13a2
Adding support for reduced tail on MobileNetV3.
datumbox Jan 5, 2021
2414d2d
Tagging blocks with comments.
datumbox Jan 5, 2021
3 changes: 2 additions & 1 deletion hubconf.py
@@ -11,7 +11,8 @@
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
mnasnet1_3

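With these hubconf.py entries in place, the new architectures are also reachable through torch.hub. A minimal sketch, assuming a torchvision checkout that already contains this PR; no pretrained weights are assumed to be published yet, hence pretrained=False:

import torch

# Pull one of the newly registered entry points through torch.hub.
model = torch.hub.load('pytorch/vision', 'mobilenet_v3_small', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])
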
49 changes: 33 additions & 16 deletions references/classification/train.py
@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
return cache_path


def load_data(traindir, valdir, cache_dataset, distributed):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
trans = [
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
]
if args.auto_augment is not None:
aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
trans.append(transforms.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
normalize,
])
if args.random_erase > 0:
trans.append(transforms.RandomErasing(p=args.random_erase))
dataset = torchvision.datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
if cache_dataset:
transforms.Compose(trans))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)

print("Loading validation data")
cache_path = _get_cache_path(valdir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
transforms.ToTensor(),
normalize,
]))
if cache_dataset:
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)

print("Creating data loaders")
if distributed:
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
else:
@@ -155,8 +163,7 @@ def main(args):

train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
args.cache_dataset, args.distributed)
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
@@ -173,8 +180,15 @@

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
opt_name = args.opt.lower()
if opt_name == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)

Review comment from the PR author (datumbox, Contributor):
These hardcoded params are crucial for convergence. They can be turned into args.

else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

if args.apex:
model, optimizer = amp.initialize(model, optimizer,
@@ -238,6 +252,7 @@ def parse_args():
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
@@ -275,6 +290,8 @@ def parse_args():
help="Use pre-trained models from the modelzoo",
action="store_true",
)
parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')

# Mixed precision training parameters
parser.add_argument('--apex', action='store_true',
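
Following up on the review comment above, the hardcoded eps/alpha values could later be promoted to command-line flags. A minimal sketch of that idea; the --rmsprop-eps and --rmsprop-alpha flags (and the defaults below) are hypothetical, not part of this PR:

import argparse

import torch
import torchvision

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.1, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--weight-decay', default=1e-4, type=float)  # illustrative value
# Hypothetical flags replacing the values hardcoded in the diff above.
parser.add_argument('--rmsprop-eps', default=0.0316, type=float)
parser.add_argument('--rmsprop-alpha', default=0.9, type=float)
args = parser.parse_args([])  # empty list: keep the defaults for this sketch

model = torchvision.models.mobilenet_v3_large()
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                eps=args.rmsprop_eps, alpha=args.rmsprop_alpha)
print(optimizer)
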
Two binary files (not shown).
17 changes: 9 additions & 8 deletions test/test_models.py
@@ -275,16 +275,17 @@ def test_mobilenetv2_residual_setting(self):
out = model(x)
self.assertEqual(out.shape[-1], 1000)

def test_mobilenetv2_norm_layer(self):
model = models.__dict__["mobilenet_v2"]()
self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
def test_mobilenet_norm_layer(self):
for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
model = models.__dict__[name]()
self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))

def get_gn(num_channels):
return nn.GroupNorm(32, num_channels)
def get_gn(num_channels):
return nn.GroupNorm(32, num_channels)

model = models.__dict__["mobilenet_v2"](norm_layer=get_gn)
self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
model = models.__dict__[name](norm_layer=get_gn)
self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))

def test_inceptionv3_eval(self):
# replacement for models.inception_v3(pretrained=True) that does not download weights
3 changes: 2 additions & 1 deletion torchvision/models/mobilenet.py
@@ -1,3 +1,4 @@
from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all
from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all

__all__ = mv2_all
__all__ = mv2_all + mv3_all
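
With this shim, the old torchvision.models.mobilenet import path keeps working while the versioned modules host the actual implementations. A quick sanity-check sketch:

from torchvision.models import mobilenet, mobilenetv2, mobilenetv3

# The shim re-exports the versioned symbols, so both paths point at the same objects.
assert mobilenet.MobileNetV2 is mobilenetv2.MobileNetV2
assert mobilenet.mobilenet_v3_large is mobilenetv3.mobilenet_v3_large
print(mobilenet.__all__)  # contains both the V2 and V3 entries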