Skip to content

Commit

Permalink
Use features of C4 expansion.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Jan 13, 2021
1 parent 2186b07 commit 381fa48
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
19 changes: 10 additions & 9 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torchvision.ops import misc as misc_nn_ops
from .._utils import IntermediateLayerGetter
from .. import mobilenet
from .. import mobilenetv3
from .. import resnet


Expand Down Expand Up @@ -125,17 +125,16 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
return trainable_backbone_layers


def mobilenet_backbone(
def mobilenetv3_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2
):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indeces = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
# Gather the indeces of blocks which are strided. These are the locations of C0, C1, C2, C3, C4 blocks.
stage_indeces = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)]
num_stages = len(stage_indeces)

# find the index of the layer from which we wont freeze
Expand All @@ -147,13 +146,15 @@ def mobilenet_backbone(
for parameter in b.parameters():
parameter.requires_grad_(False)

backbone_channels = backbone[-1].out_channels
C4_pos = stage_indeces[-1]
C4_expansion = backbone[C4_pos].block[0]
out_channels = 256

m = nn.Sequential(
backbone,
*backbone[:C4_pos],
C4_expansion,
# depthwise linear combination of channels to reduce their size
nn.Conv2d(backbone_channels, out_channels, 1),
nn.Conv2d(C4_expansion.out_channels, out_channels, 1),
)
m.out_channels = out_channels
return m
15 changes: 7 additions & 8 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import _utils as det_utils
from .anchor_utils import AnchorGenerator
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenetv3_backbone
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
Expand Down Expand Up @@ -559,8 +559,7 @@ def forward(self, images, targets=None):

# TODO: replace with pytorch links
model_urls = {
'retinanet_mobilenet_v3_large_coco':
'https://github.com/datumbox/torchvision-models/raw/main/retinanet_mobilenet_v3_large-1aa2fe5a.pth',
'retinanet_mobilenet_v3_large_coco': None,
'retinanet_resnet50_fpn_coco':
'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth',
}
Expand Down Expand Up @@ -646,16 +645,16 @@ def retinanet_mobilenet_v3_large(pretrained=False, progress=True, num_classes=91
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
# check default parameters and by default set it to 2 if possible
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 2)

if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone,
trainable_layers=trainable_backbone_layers)
backbone = mobilenetv3_backbone("mobilenet_v3_large", pretrained_backbone,
trainable_layers=trainable_backbone_layers)

anchor_sizes = ((128, 256, 512,), )
aspect_ratios = ((0.5, 1.0, 2.0), )
Expand Down

0 comments on commit 381fa48

Please sign in to comment.