Post-paper Detection Optimizations #5444

Merged: 54 commits from references/detection_recipe into main, Apr 5, 2022.
Commits (54):
0f6fa39  Use frozen BN only if pre-trained. (datumbox, Feb 18, 2022)
7a94595  Add LSJ and ability to from scratch training. (datumbox, Feb 19, 2022)
89a5b9d  Fixing formatter (datumbox, Feb 19, 2022)
20470c1  Merge branch 'main' into references/detection_recipe (datumbox, Feb 20, 2022)
22d7f47  Merge branch 'main' into references/detection_recipe (datumbox, Feb 24, 2022)
a0322dd  Merge branch 'main' into references/detection_recipe (datumbox, Feb 24, 2022)
2943182  Merge branch 'main' into references/detection_recipe (datumbox, Feb 25, 2022)
629e149  Merge branch 'main' into references/detection_recipe (datumbox, Feb 25, 2022)
53fbd71  Merge branch 'main' into references/detection_recipe (datumbox, Feb 25, 2022)
d3b8dad  Merge branch 'main' into references/detection_recipe (datumbox, Mar 2, 2022)
5aa97c3  Merge branch 'main' into references/detection_recipe (datumbox, Mar 4, 2022)
8537c48  Adding `--opt` and `--norm-weight-decay` support in Detection. (datumbox, Mar 5, 2022)
f7f8e2f  Fix error message (datumbox, Mar 5, 2022)
ed2a24c  Make ScaleJitter proportional. (datumbox, Mar 6, 2022)
bc7a8a9  Merge branch 'main' into references/detection_recipe (datumbox, Mar 7, 2022)
a1786bb  Merge branch 'main' into references/detection_recipe (datumbox, Mar 7, 2022)
6c12921  Merge branch 'main' into references/detection_recipe (datumbox, Mar 7, 2022)
bcf0afc  Adding more norm layers in split_normalization_params. (datumbox, Mar 8, 2022)
9c66a7c  Merge branch 'main' into references/detection_recipe (datumbox, Mar 10, 2022)
65e4116  Add FixedSizeCrop (datumbox, Mar 10, 2022)
ab63af6  Temporary fix for fill values on PIL (datumbox, Mar 10, 2022)
7365cdc  Merge branch 'main' into references/detection_recipe (datumbox, Mar 11, 2022)
c714c66  Fix the bug on fill. (datumbox, Mar 11, 2022)
c415639  Merge branch 'main' into references/detection_recipe (datumbox, Mar 12, 2022)
13fb5b3  Add RandomShortestSize. (datumbox, Mar 12, 2022)
0d230ab  Skip resize when an augmentation method is used. (datumbox, Mar 12, 2022)
a187917  multiscale in [480, 800] (datumbox, Mar 13, 2022)
4b4d300  Merge branch 'main' into references/detection_recipe (datumbox, Mar 14, 2022)
8dd6975  Merge branch 'main' into references/detection_recipe (datumbox, Mar 14, 2022)
efcf9ed  Merge branch 'main' into references/detection_recipe (datumbox, Mar 22, 2022)
7542a94  Add missing star (datumbox, Mar 23, 2022)
c67893c  Add new RetinaNet variant. (datumbox, Mar 23, 2022)
0d7917c  Add tests. (datumbox, Mar 23, 2022)
7354684  Update expected file for old retina (datumbox, Mar 23, 2022)
bb8aac0  Merge branch 'main' into references/detection_recipe (datumbox, Mar 23, 2022)
38ef843  Fixing tests (datumbox, Mar 23, 2022)
cd9c302  Add FrozenBN to retinav2 (datumbox, Mar 24, 2022)
29c57f6  Fix network initialization issues (datumbox, Mar 24, 2022)
19f2b25  Adding BN support in MaskRCNNHeads and FPN (datumbox, Mar 24, 2022)
124fd8a  Adding support of FasterRCNNHeads (datumbox, Mar 24, 2022)
53aa8b7  Introduce norm_layers in backbone utils. (datumbox, Mar 25, 2022)
f9ba509  Bigger RPN head + 2x rcnn v2 models. (datumbox, Mar 25, 2022)
e5cbb97  Merge branch 'main' into references/detection_recipe (datumbox, Mar 30, 2022)
592784d  Adding gIoU support to retinanet (datumbox, Mar 30, 2022)
2cff640  Fix assert (datumbox, Mar 30, 2022)
a6f0ea7  Merge branch 'main' into references/detection_recipe (datumbox, Mar 31, 2022)
61412df  Add back nesterov momentum (datumbox, Apr 1, 2022)
99479ee  Merge branch 'main' into references/detection_recipe (datumbox, Apr 1, 2022)
08307ca  Merge branch 'main' into references/detection_recipe (datumbox, Apr 1, 2022)
a322dd2  Merge branch 'main' into references/detection_recipe (datumbox, Apr 1, 2022)
eb649e8  Merge branch 'main' into references/detection_recipe (datumbox, Apr 1, 2022)
24b8643  Rename and extend `FastRCNNConvFCHead` to support arbitrary FCs (datumbox, Apr 4, 2022)
6488c41  Fix linter (datumbox, Apr 4, 2022)
00e182a  Merge branch 'main' into references/detection_recipe (datumbox, Apr 5, 2022)
3 binary files changed (not shown).
35 changes: 35 additions & 0 deletions test/test_models.py
@@ -195,11 +195,14 @@ def _check_input_backprop(model, inputs):
    "googlenet": lambda x: x.logits,
    "inception_v3": lambda x: x.logits,
    "fasterrcnn_resnet50_fpn": lambda x: x[1],
    "fasterrcnn_resnet50_fpn_v2": lambda x: x[1],
    "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
    "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
    "maskrcnn_resnet50_fpn": lambda x: x[1],
    "maskrcnn_resnet50_fpn_v2": lambda x: x[1],
    "keypointrcnn_resnet50_fpn": lambda x: x[1],
    "retinanet_resnet50_fpn": lambda x: x[1],
    "retinanet_resnet50_fpn_v2": lambda x: x[1],
    "ssd300_vgg16": lambda x: x[1],
    "ssdlite320_mobilenet_v3_large": lambda x: x[1],
    "fcos_resnet50_fpn": lambda x: x[1],
@@ -227,6 +230,7 @@ def _check_input_backprop(model, inputs):
    "fcn_resnet101",
    "lraspp_mobilenet_v3_large",
    "maskrcnn_resnet50_fpn",
    "maskrcnn_resnet50_fpn_v2",
)

# The tests for the following quantized models are flaky possibly due to inconsistent
@@ -246,6 +250,13 @@ def _check_input_backprop(model, inputs):
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "retinanet_resnet50_fpn_v2": {
        "num_classes": 20,
        "score_thresh": 0.01,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "keypointrcnn_resnet50_fpn": {
        "num_classes": 2,
        "min_size": 224,
@@ -259,6 +270,12 @@ def _check_input_backprop(model, inputs):
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "fasterrcnn_resnet50_fpn_v2": {
        "num_classes": 20,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "fcos_resnet50_fpn": {
        "num_classes": 2,
        "score_thresh": 0.05,
@@ -272,6 +289,12 @@ def _check_input_backprop(model, inputs):
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "maskrcnn_resnet50_fpn_v2": {
        "num_classes": 10,
        "min_size": 224,
        "max_size": 224,
        "input_shape": (3, 224, 224),
    },
    "fasterrcnn_mobilenet_v3_large_fpn": {
        "box_score_thresh": 0.02076,
    },
@@ -311,6 +334,10 @@ def _check_input_backprop(model, inputs):
        "max_trainable": 5,
        "n_trn_params_per_layer": [36, 46, 65, 78, 88, 89],
    },
    "retinanet_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [44, 74, 131, 170, 200, 203],
    },
    "keypointrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [48, 58, 77, 90, 100, 101],
@@ -319,10 +346,18 @@ def _check_input_backprop(model, inputs):
        "max_trainable": 5,
        "n_trn_params_per_layer": [30, 40, 59, 72, 82, 83],
    },
    "fasterrcnn_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [50, 80, 137, 176, 206, 209],
    },
    "maskrcnn_resnet50_fpn": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [42, 52, 71, 84, 94, 95],
    },
    "maskrcnn_resnet50_fpn_v2": {
        "max_trainable": 5,
        "n_trn_params_per_layer": [66, 96, 153, 192, 222, 225],
    },
    "fasterrcnn_mobilenet_v3_large_fpn": {
        "max_trainable": 6,
        "n_trn_params_per_layer": [22, 23, 44, 70, 91, 97, 100],
28 changes: 26 additions & 2 deletions torchvision/models/detection/_utils.py
@@ -1,10 +1,11 @@
import math
from collections import OrderedDict
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn
-from torchvision.ops.misc import FrozenBatchNorm2d
+from torch.nn import functional as F
+from torchvision.ops import FrozenBatchNorm2d, generalized_box_iou_loss


class BalancedPositiveNegativeSampler:
@@ -507,3 +508,26 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
    axis_dim_val = torch._shape_as_tensor(input)[axis].unsqueeze(0)
    min_kval = torch.min(torch.cat((torch.tensor([orig_kval], dtype=axis_dim_val.dtype), axis_dim_val), 0))
    return _fake_cast_onnx(min_kval)


def _box_loss(
    type: str,
    box_coder: BoxCoder,
    anchors_per_image: Tensor,
    matched_gt_boxes_per_image: Tensor,
    bbox_regression_per_image: Tensor,
    cnf: Optional[Dict[str, float]] = None,
) -> Tensor:
    torch._assert(type in ["l1", "smooth_l1", "giou"], f"Unsupported loss: {type}")

    if type == "l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        return F.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
    elif type == "smooth_l1":
        target_regression = box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
        beta = cnf["beta"] if cnf is not None and "beta" in cnf else 1.0
        return F.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum", beta=beta)
    else:  # giou
        bbox_per_image = box_coder.decode_single(bbox_regression_per_image, anchors_per_image)
        eps = cnf["eps"] if cnf is not None and "eps" in cnf else 1e-7
        return generalized_box_iou_loss(bbox_per_image, matched_gt_boxes_per_image, reduction="sum", eps=eps)
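A minimal sketch of how this helper is called (toy tensors; in the detection heads the anchors and matched ground-truth boxes come from the box matcher):

import torch
from torchvision.models.detection import _utils as det_utils

box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 20.0, 20.0]])
matched_gt = torch.tensor([[1.0, 1.0, 11.0, 11.0], [6.0, 4.0, 21.0, 19.0]])

# The regression tensor is interpreted per loss type: encoded deltas for
# "l1"/"smooth_l1", raw deltas that get decoded back into boxes for "giou".
deltas = box_coder.encode_single(matched_gt, anchors)
print(det_utils._box_loss("smooth_l1", box_coder, anchors, matched_gt, deltas, cnf={"beta": 0.9}))
print(det_utils._box_loss("giou", box_coder, anchors, matched_gt, deltas))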
13 changes: 11 additions & 2 deletions torchvision/models/detection/backbone_utils.py
@@ -25,6 +25,7 @@ class BackboneWithFPN(nn.Module):
        in_channels_list (List[int]): number of channels for each feature map
            that is returned, in the order they are present in the OrderedDict
        out_channels (int): number of channels in the FPN.
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    Attributes:
        out_channels (int): the number of channels in the FPN
    """
@@ -36,6 +37,7 @@ def __init__(
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
@@ -47,6 +49,7 @@ def __init__(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels
@@ -115,6 +118,7 @@ def _resnet_fpn_extractor(
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:

    # select layers that won't be frozen
@@ -139,7 +143,9 @@
    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
-    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
+    return BackboneWithFPN(
+        backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+    )


def _validate_trainable_layers(
@@ -194,6 +200,7 @@ def _mobilenet_extractor(
    trainable_layers: int,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> nn.Module:
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
@@ -222,7 +229,9 @@
        return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}

        in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
-        return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
+        return BackboneWithFPN(
+            backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks, norm_layer=norm_layer
+        )
    else:
        m = nn.Sequential(
            backbone,
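To see the effect of the plumbing above, a short sketch: build a ResNet-50 FPN backbone whose FPN convolutions are followed by BatchNorm, as the v2 detectors below request it. Note that _resnet_fpn_extractor is a private helper, so this is illustration rather than public API.

import torch
from torch import nn
from torchvision.models import resnet50
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor

backbone = _resnet_fpn_extractor(resnet50(), trainable_layers=3, norm_layer=nn.BatchNorm2d)
feature_maps = backbone(torch.rand(1, 3, 224, 224))
print({name: tuple(f.shape) for name, f in feature_maps.items()})
# {'0': (1, 256, 56, 56), '1': (1, 256, 28, 28), '2': (1, 256, 14, 14), '3': (1, 256, 7, 7), 'pool': (1, 256, 4, 4)}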
115 changes: 111 additions & 4 deletions torchvision/models/detection/faster_rcnn.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -24,14 +24,22 @@
__all__ = [
    "FasterRCNN",
    "FasterRCNN_ResNet50_FPN_Weights",
    "FasterRCNN_ResNet50_FPN_V2_Weights",
    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_resnet50_fpn_v2",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
]


def _default_anchorgen():
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)
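The helper just centralizes the anchor configuration the v1 and v2 builders now share; for reference, it yields 3 anchors (one per aspect ratio) at each of the 5 FPN levels:

from torchvision.models.detection.anchor_utils import AnchorGenerator

gen = AnchorGenerator(((32,), (64,), (128,), (256,), (512,)), ((0.5, 1.0, 2.0),) * 5)
print(gen.num_anchors_per_location())  # [3, 3, 3, 3, 3]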


class FasterRCNN(GeneralizedRCNN):
"""
Implements Faster R-CNN.
@@ -216,9 +224,7 @@ def __init__(
        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
-            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
-            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-            rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
+            rpn_anchor_generator = _default_anchorgen()
        if rpn_head is None:
            rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
@@ -298,6 +304,43 @@ def forward(self, x):
        return x


class FastRCNNConvFCHead(nn.Sequential):
    def __init__(
        self,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Args:
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
        """
        in_channels, in_height, in_width = input_size

        blocks = []
        previous_channels = in_channels
        for current_channels in conv_layers:
            blocks.append(misc_nn_ops.Conv2dNormActivation(previous_channels, current_channels, norm_layer=norm_layer))
            previous_channels = current_channels
        blocks.append(nn.Flatten())
        previous_channels = previous_channels * in_height * in_width
        for current_channels in fc_layers:
            blocks.append(nn.Linear(previous_channels, current_channels))
            blocks.append(nn.ReLU(inplace=True))
            previous_channels = current_channels

        super().__init__(*blocks)
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
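A quick shape check of the new head as configured in fasterrcnn_resnet50_fpn_v2 below (four 3x3 convs at 256 channels over the 7x7 pooled RoI features, then one 1024-wide FC); the input tensor here is made up:

import torch
from torch import nn
from torchvision.models.detection.faster_rcnn import FastRCNNConvFCHead

head = FastRCNNConvFCHead((256, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d)
roi_features = torch.rand(8, 256, 7, 7)  # 8 RoIs worth of pooled features
print(head(roi_features).shape)  # torch.Size([8, 1024])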


class FastRCNNPredictor(nn.Module):
"""
Standard classification + bounding box regression layers
@@ -349,6 +392,10 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
    DEFAULT = COCO_V1


class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
    pass


class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
@@ -481,6 +528,66 @@ def fasterrcnn_resnet50_fpn(
    return model


def fasterrcnn_resnet50_fpn_v2(
    *,
    weights: Optional[FasterRCNN_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """
    Constructs an improved Faster R-CNN model with a ResNet-50-FPN backbone.

    Reference: `"Benchmarking Detection Transfer Learning with Vision Transformers"
    <https://arxiv.org/abs/2111.11429>`_.

    See :func:`~torchvision.models.detection.fasterrcnn_resnet50_fpn` for more details.

    Args:
        weights (FasterRCNN_ResNet50_FPN_V2_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
    """
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer=nn.BatchNorm2d)
    rpn_anchor_generator = _default_anchorgen()
    rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)
    box_head = FastRCNNConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
    )
    model = FasterRCNN(
        backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchor_generator,
        rpn_head=rpn_head,
        box_head=box_head,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
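Since FasterRCNN_ResNet50_FPN_V2_Weights is still an empty enum at this point in the diff (no pretrained checkpoint is published yet), a smoke test has to start from random weights; a minimal sketch:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2

model = fasterrcnn_resnet50_fpn_v2(num_classes=91).eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 480, 640)])  # list of CHW images in [0, 1]
print(predictions[0].keys())  # dict_keys(['boxes', 'labels', 'scores'])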


def _fasterrcnn_mobilenet_v3_large_fpn(
    *,
    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],