Add Lite R-ASPP with MobileNetV3 backbone.
datumbox committed Jan 25, 2021
1 parent 10a51cf commit 359d941
Showing 5 changed files with 120 additions and 9 deletions.
Binary file not shown.
1 change: 1 addition & 0 deletions test/test_models.py
@@ -65,6 +65,7 @@ def get_available_video_models():
"fcn_resnet50",
"fcn_resnet101",
"fcn_mobilenet_v3_large",
"lraspp_mobilenet_v3_large",
)


1 change: 0 additions & 1 deletion torchvision/models/segmentation/_utils.py
@@ -1,6 +1,5 @@
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F

67 changes: 67 additions & 0 deletions torchvision/models/segmentation/lraspp.py
@@ -0,0 +1,67 @@
from collections import OrderedDict

from torch import nn, Tensor
from torch.nn import functional as F
from typing import Dict


__all__ = ["LRASPP"]


class LRASPP(nn.Module):
"""
Implements a Lite R-ASPP Network for semantic segmentation.
Args:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"high" for the high level feature map and "low" for the low level feature map.
low_channels (int): the number of channels of the low level features.
high_channels (int): the number of channels of the high level features.
num_classes (int): number of output classes of the model (including the background).
inter_channels (int, optional): the number of channels for intermediate computations.
"""

def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
super().__init__()
self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

def forward(self, input):
features = self.backbone(input)
out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)

result = OrderedDict()
result["out"] = out

return result


class LRASPPHead(nn.Module):

def __init__(self, low_channels, high_channels, num_classes, inter_channels):
super().__init__()
self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True)
)
self.scale = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.Sigmoid(),
)
self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

def forward(self, input: Dict[str, Tensor]) -> Tensor:
low = input["low"]
high = input["high"]

x = self.cbr(high)
s = self.scale(high)
x = x * s
x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)

return self.low_classifier(low) + self.high_classifier(x)
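
The only contract LRASPP places on its backbone is the one the docstring above describes: a dict with "low" and "high" feature maps. Below is a minimal sketch of that contract with a hypothetical toy backbone; the channel counts and strides are illustration values, not taken from this commit.

import torch
from torch import nn
from torchvision.models.segmentation.lraspp import LRASPP


class ToyBackbone(nn.Module):
    """Hypothetical backbone that returns the feature dict LRASPP expects."""

    def __init__(self):
        super().__init__()
        # Stand-ins for the C2 (output stride 8) and C5 (output stride 16) feature maps.
        self.low = nn.Conv2d(3, 40, kernel_size=3, stride=8, padding=1)
        self.high = nn.Conv2d(3, 960, kernel_size=3, stride=16, padding=1)

    def forward(self, x):
        return {"low": self.low(x), "high": self.high(x)}


model = LRASPP(ToyBackbone(), low_channels=40, high_channels=960, num_classes=21)
out = model(torch.rand(1, 3, 224, 224))["out"]
print(out.shape)  # torch.Size([1, 21, 224, 224])
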
60 changes: 52 additions & 8 deletions torchvision/models/segmentation/segmentation.py
@@ -4,10 +4,11 @@
from .. import resnet
from .deeplabv3 import DeepLabHead, DeepLabV3
from .fcn import FCN, FCNHead
from .lraspp import LRASPP


__all__ = ['fcn_resnet50', 'fcn_resnet101', 'fcn_mobilenet_v3_large', 'deeplabv3_resnet50', 'deeplabv3_resnet101',
'deeplabv3_mobilenet_v3_large']
'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large']


model_urls = {
@@ -17,6 +18,7 @@
'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth',
'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth',
'deeplabv3_mobilenet_v3_large_coco': None,
'lraspp_mobilenet_v3_large_coco': None,
}


@@ -69,13 +71,34 @@ def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss
aux_loss = True
model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
if pretrained:
arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
_load_weights(model, arch_type, backbone, progress)
return model


def _load_weights(model, arch_type, backbone, progress):
arch = arch_type + '_' + backbone + '_coco'
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)


def _segm_lraspp_mobilenetv3(backbone_name, num_classes, pretrained_backbone=True):
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, _dilated=True).features

# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
low_pos = stage_indices[-4] # use C2 here which has output_stride = 8
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
low_channels = backbone[low_pos].out_channels
high_channels = backbone[high_pos].out_channels

backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'})

model = LRASPP(backbone, low_channels, high_channels, num_classes)
return model
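
For reference, a minimal sketch (not part of this commit) of what IntermediateLayerGetter does with the return_layers mapping built above; it is importable from torchvision.models._utils, and the toy nn.Sequential with indices 1 and 3 is an illustrative stand-in for the real MobileNetV3 stage indices.

import torch
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter

# Hypothetical stand-in for backbone.features; indices 1 and 3 play the roles of low_pos and high_pos.
toy_features = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1),
    nn.Conv2d(16, 40, 3, stride=4, padding=1),    # pretend C2, output stride 8
    nn.Conv2d(40, 112, 3, stride=1, padding=1),
    nn.Conv2d(112, 960, 3, stride=2, padding=1),  # pretend C5, output stride 16
)
getter = IntermediateLayerGetter(toy_features, return_layers={"1": "low", "3": "high"})
features = getter(torch.rand(1, 3, 224, 224))
print({name: tensor.shape for name, tensor in features.items()})  # keys: 'low', 'high'
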


@@ -161,3 +184,24 @@ def deeplabv3_mobilenet_v3_large(pretrained=False, progress=True,
aux_loss (bool): If True, it uses an auxiliary loss
"""
return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs)


def lraspp_mobilenet_v3_large(pretrained=False, progress=True, num_classes=21, **kwargs):
"""Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone.
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
contains the same classes as Pascal VOC
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
"""
if kwargs.pop("aux_loss", False):
raise NotImplementedError('This model does not use auxiliary loss')

backbone_name = 'mobilenet_v3_large'
model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs)

if pretrained:
_load_weights(model, 'lraspp', backbone_name, progress)

return model
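
A minimal usage sketch for the new constructor, assuming a torchvision build that includes this commit; pretrained COCO weights are not available yet (the URL registered above is None), so the model is constructed from scratch.

import torch
from torchvision.models.segmentation import lraspp_mobilenet_v3_large

# pretrained_backbone=False skips downloading ImageNet backbone weights for this sketch.
model = lraspp_mobilenet_v3_large(num_classes=21, pretrained_backbone=False)
model.eval()
with torch.no_grad():
    prediction = model(torch.rand(1, 3, 520, 520))["out"]
print(prediction.shape)  # torch.Size([1, 21, 520, 520])
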
