Skip to content

Commit

Permalink
Add RetinaNet improved weights (#5756)
Browse files Browse the repository at this point in the history
* Add RetinaNet improved weights

* Add weights.

* Change publication date.
  • Loading branch information
datumbox authored Apr 6, 2022
1 parent 08cc9a7 commit b5481e4
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,17 +672,22 @@ def forward(self, images, targets=None):
return self.eager_outputs(losses, detections)


_COMMON_META = {
"task": "image_object_detection",
"architecture": "RetinaNet",
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}


class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights(
url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
transforms=ObjectDetection,
meta={
"task": "image_object_detection",
"architecture": "RetinaNet",
**_COMMON_META,
"publication_year": 2017,
"num_params": 34014999,
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
Expand All @@ -691,7 +696,18 @@ class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):


class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
pass
COCO_V1 = Weights(
url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
transforms=ObjectDetection,
meta={
**_COMMON_META,
"publication_year": 2019,
"num_params": 38198935,
"recipe": "https://github.com/pytorch/vision/pull/5756",
"map": 41.5,
},
)
DEFAULT = COCO_V1


@handle_legacy_interface(
Expand Down

0 comments on commit b5481e4

Please sign in to comment.