Skip to content

Commit

Permalink
[Enhancement] Update pad logic in detection heads (#168)
Browse files Browse the repository at this point in the history
* pad with register

* fix lint

Co-authored-by: AllentDan <dongchunyu@sensetime.com>
  • Loading branch information
q.yao and AllentDan authored Mar 14, 2022
1 parent 9553b8c commit 987d48c
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 36 deletions.
5 changes: 3 additions & 2 deletions mmdeploy/codebase/mmdet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .core import * # noqa: F401,F403
from .deploy import (MMDetection, ObjectDetection, clip_bboxes,
get_post_processing_params, pad_with_value)
get_post_processing_params, pad_with_value,
pad_with_value_if_necessary)
from .models import * # noqa: F401,F403

__all__ = [
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
'MMDetection', 'ObjectDetection'
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
]
5 changes: 3 additions & 2 deletions mmdeploy/codebase/mmdet/deploy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmdetection import MMDetection
from .object_detection import ObjectDetection
from .utils import clip_bboxes, get_post_processing_params, pad_with_value
from .utils import (clip_bboxes, get_post_processing_params, pad_with_value,
pad_with_value_if_necessary)

__all__ = [
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
'MMDetection', 'ObjectDetection'
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
]
61 changes: 60 additions & 1 deletion mmdeploy/codebase/mmdet/deploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker
from mmdeploy.utils import load_config
from mmdeploy.utils import Backend, load_config


def get_post_processing_params(deploy_cfg: Union[str, mmcv.Config]):
Expand Down Expand Up @@ -127,3 +127,62 @@ def pad_with_value(x: Tensor,
x_pad = x_pad.repeat(*repeat_size)
x = torch.cat([x, x_pad], dim=pad_dim)
return x


def pad_with_value_if_necessary(x: Tensor,
pad_dim: int,
pad_size: int,
pad_value: Optional[Any] = None):
"""Pad a tensor with a value along some dim if necessary.
Args:
x (Tensor): Input tensor.
pad_dim (int): Along which dim to pad.
pad_size (int): To which size to pad.
pad_value (Any): Filled value for padding. Defaults to `None`.
Returns:
Tensor: Padded tensor.
"""
return __pad_with_value_if_necessary(
x, pad_dim, pad_size=pad_size, pad_value=pad_value)


def __pad_with_value_if_necessary(x: Tensor,
pad_dim: int,
pad_size: int,
pad_value: Optional[Any] = None):
"""Pad a tensor with a value along some dim, do nothing on default.
Args:
x (Tensor): Input tensor.
pad_dim (int): Along which dim to pad.
pad_size (int): To which size to pad.
pad_value (Any): Filled value for padding. Defaults to `None`.
Returns:
Tensor: Padded tensor.
"""
return x


@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary',
backend=Backend.TENSORRT.value)
def __pad_with_value_if_necessary__tensorrt(ctx,
x: Tensor,
pad_dim: int,
pad_size: int,
pad_value: Optional[Any] = None):
"""Pad a tensor with a value along some dim.
Args:
x (Tensor): Input tensor.
pad_dim (int): Along which dim to pad.
pad_size (int): To which size to pad.
pad_value (Any): Filled value for padding. Defaults to `None`.
Returns:
Tensor: Padded tensor.
"""
return pad_with_value(x, pad_dim, pad_size=pad_size, pad_value=pad_value)
26 changes: 12 additions & 14 deletions mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from mmdet.core.bbox.transforms import distance2bbox

from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
from mmdeploy.utils import Backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -60,7 +61,6 @@ def base_dense_head__get_bbox(ctx,
"""
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
backend = get_backend(deploy_cfg)
num_levels = len(cls_scores)

featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
Expand Down Expand Up @@ -98,10 +98,8 @@ def base_dense_head__get_bbox(ctx,
self.cls_out_channels)
if self.use_sigmoid_cls:
scores = scores.sigmoid()
nms_pre_score = scores
else:
scores = scores.softmax(-1)
nms_pre_score = scores
if with_score_factors:
score_factors = score_factors.permute(0, 2, 3,
1).reshape(batch_size,
Expand All @@ -112,16 +110,16 @@ def base_dense_head__get_bbox(ctx,
priors = priors.data
priors = priors.expand(batch_size, -1, priors.size(-1))
if pre_topk > 0:
priors = pad_with_value_if_necessary(priors, 1, pre_topk)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
if with_score_factors:
score_factors = pad_with_value_if_necessary(
score_factors, 1, pre_topk, 0.)

nms_pre_score = scores
if with_score_factors:
nms_pre_score = nms_pre_score * score_factors
if backend == Backend.TENSORRT:
priors = pad_with_value(priors, 1, pre_topk)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
scores = pad_with_value(scores, 1, pre_topk, 0.)
nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
if with_score_factors:
score_factors = pad_with_value(score_factors, 1, pre_topk,
0.)

# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
Expand Down Expand Up @@ -180,7 +178,7 @@ def base_dense_head__get_bbox(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead'
'.get_bboxes',
backend='ncnn')
backend=Backend.NCNN.value)
def base_dense_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
Expand Down
15 changes: 7 additions & 8 deletions mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
from mmdeploy.utils import Backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -95,13 +96,11 @@ def rpn_head__get_bboxes(ctx,

anchors = anchors.expand_as(bbox_pred)

backend = get_backend(deploy_cfg)
# topk in tensorrt does not support shape<k
# concate zero to enable topk,
if backend == Backend.TENSORRT:
scores = pad_with_value(scores, 1, pre_topk, 0.)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
anchors = pad_with_value(anchors, 1, pre_topk)
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)

if pre_topk > 0:
_, topk_inds = scores.squeeze(2).topk(pre_topk)
Expand Down Expand Up @@ -145,7 +144,7 @@ def rpn_head__get_bboxes(ctx,


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend='ncnn')
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value)
def rpn_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
Expand Down
16 changes: 8 additions & 8 deletions mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
from mmdeploy.utils import Backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -90,13 +91,11 @@ def yolov3_head__get_bboxes(ctx,
conf_pred = torch.sigmoid(pred_map[..., 4])
cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
batch_size, -1, self.num_classes) # Cls pred one-hot.
backend = get_backend(ctx.cfg)
# topk in tensorrt does not support shape<k
# concate zero to enable topk,
if backend == Backend.TENSORRT:
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
conf_pred = pad_with_value(conf_pred, 1, pre_topk, 0.)
cls_pred = pad_with_value(cls_pred, 1, pre_topk, 0.)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
conf_pred = pad_with_value_if_necessary(conf_pred, 1, pre_topk, 0.)
cls_pred = pad_with_value_if_necessary(cls_pred, 1, pre_topk, 0.)

if pre_topk > 0:
_, topk_inds = conf_pred.topk(pre_topk)
Expand Down Expand Up @@ -161,7 +160,8 @@ def yolov3_head__get_bboxes(ctx,


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes', backend='ncnn')
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes',
backend=Backend.NCNN.value)
def yolov3_head__get_bboxes__ncnn(ctx,
self,
pred_maps,
Expand Down
12 changes: 11 additions & 1 deletion tests/test_codebase/test_mmdet/test_mmdet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from mmdeploy.codebase import import_codebase
from mmdeploy.codebase.mmdet import (clip_bboxes, get_post_processing_params,
pad_with_value)
pad_with_value,
pad_with_value_if_necessary)
from mmdeploy.utils import Codebase

import_codebase(Codebase.MMDET)
Expand All @@ -29,6 +30,15 @@ def test_pad_with_value():
assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)


def test_pad_with_value_if_necessary():
x = torch.rand(3, 2)
padded_x = pad_with_value_if_necessary(
x, pad_dim=1, pad_size=4, pad_value=0)
assert np.allclose(
padded_x.shape, torch.Size([3, 2]), rtol=1e-03, atol=1e-05)
assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)


config_with_mmdet_params = mmcv.Config(
dict(
codebase_config=dict(
Expand Down

0 comments on commit 987d48c

Please sign in to comment.