Skip to content

Commit

Permalink
Fix incorrectly frozen BN on ResNet FPN backbone (#3396)
Browse files Browse the repository at this point in the history
* Avoid freezing bn1 if all layers are trainable.

* Remove misleading comments.
  • Loading branch information
datumbox authored Feb 15, 2021
1 parent 067b9dc commit eca37cf
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def resnet_fpn_backbone(
# select layers that wont be frozen
assert 0 <= trainable_layers <= 5
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
# freeze layers only if pretrained backbone is used
if trainable_layers == 5:
layers_to_train.append('bn1')
for name, parameter in backbone.named_parameters():
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)
Expand Down Expand Up @@ -152,7 +153,6 @@ def mobilenet_backbone(
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

# freeze layers only if pretrained backbone is used
for b in backbone[:freeze_before]:
for parameter in b.parameters():
parameter.requires_grad_(False)
Expand Down

2 comments on commit eca37cf

@vadimkantorov
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't running stats still be updated if the full network moves to train() mode? Or are there other ways to fix that?

@datumbox
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already replace the BN with frozen versions where the running stats are not updated:

norm_layer=misc_nn_ops.FrozenBatchNorm2d,

Please sign in to comment.