Skip to content

Commit

Permalink
Adding fcn and deeplabv3 directly on mobilenetv3 backbone.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Jan 22, 2021
1 parent 77da44c commit 462d59a
Showing 1 changed file with 58 additions and 5 deletions.
63 changes: 58 additions & 5 deletions torchvision/models/segmentation/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from .._utils import IntermediateLayerGetter
from ..utils import load_state_dict_from_url
from .. import mobilenet
from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead


__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101']
__all__ = ['fcn_resnet50', 'fcn_resnet101', 'fcn_mobilenet_v3_large', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
'deeplabv3_mobilenet_v3_large']


model_urls = {
'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth',
'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
'fcn_mobilenet_v3_large_coco': None,
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
'deeplabv3_mobilenet_v3_large_coco': None,
}


Expand All @@ -22,7 +26,22 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)
pretrained=pretrained_backbone,
replace_stride_with_dilation=[False, True, True])
out_layer = 'layer4'
out_inplanes = 2048
aux_layer = 'layer3'
aux_inplanes = 1024
elif 'mobilenet' in backbone_name:
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained_backbone).features

# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [
len(backbone) - 1]
out_pos = stage_indices[-1]
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-2]
aux_layer = str(aux_pos)
aux_inplanes = backbone[aux_pos].out_channels
else:
raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))

Expand All @@ -33,15 +52,13 @@ def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True)

aux_classifier = None
if aux:
inplanes = 1024
aux_classifier = FCNHead(inplanes, num_classes)
aux_classifier = FCNHead(aux_inplanes, num_classes)

model_map = {
'deeplabv3': (DeepLabHead, DeepLabV3),
'fcn': (FCNHead, FCN),
}
inplanes = 2048
classifier = model_map[name][0](inplanes, num_classes)
classifier = model_map[name][0](out_inplanes, num_classes)
base_model = model_map[name][1]

model = base_model(backbone, classifier, aux_classifier)
Expand Down Expand Up @@ -71,6 +88,8 @@ def fcn_resnet50(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)

Expand All @@ -83,10 +102,26 @@ def fcn_resnet101(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)


def fcn_mobilenet_v3_large(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a Fully-Convolutional Network model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('fcn', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)


def deeplabv3_resnet50(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a DeepLabV3 model with a ResNet-50 backbone.
Expand All @@ -95,6 +130,8 @@ def deeplabv3_resnet50(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)

Expand All @@ -107,5 +144,21 @@ def deeplabv3_resnet101(pretrained=False, progress=True,
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs)


def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
num_classes=21, aux_loss=None, **kwargs):
"""Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)

0 comments on commit 462d59a

Please sign in to comment.