From ef4d1bc7acae4b97b3094d933641b3ce72109fa4 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Wed, 11 Jan 2023 19:16:03 +0800 Subject: [PATCH 01/13] Suitable for Ascend. Modify SSD and RetinaHead Model --- mmdet/core/bbox/assigners/__init__.py | 5 +- .../bbox/assigners/ascend_assign_result.py | 37 ++ .../bbox/assigners/ascend_max_iou_assigner.py | 141 +++++++ mmdet/models/dense_heads/__init__.py | 6 +- .../models/dense_heads/ascend_anchor_head.py | 394 ++++++++++++++++++ .../models/dense_heads/ascend_retina_head.py | 123 ++++++ mmdet/models/dense_heads/ascend_ssd_head.py | 318 ++++++++++++++ mmdet/utils/__init__.py | 4 +- mmdet/utils/ascend_util.py | 39 ++ .../test_dense_heads/test_ascend_head.py | 228 ++++++++++ 10 files changed, 1291 insertions(+), 4 deletions(-) create mode 100644 mmdet/core/bbox/assigners/ascend_assign_result.py create mode 100644 mmdet/core/bbox/assigners/ascend_max_iou_assigner.py create mode 100644 mmdet/models/dense_heads/ascend_anchor_head.py create mode 100644 mmdet/models/dense_heads/ascend_retina_head.py create mode 100644 mmdet/models/dense_heads/ascend_ssd_head.py create mode 100644 mmdet/utils/ascend_util.py create mode 100644 tests/test_models/test_dense_heads/test_ascend_head.py diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py index 5eaf7fa3af6..ec092b25672 100644 --- a/mmdet/core/bbox/assigners/__init__.py +++ b/mmdet/core/bbox/assigners/__init__.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
class AscendAssignResult(util_mixins.NiceRepr):
    """Store the batched assignment between bboxes and ground truth boxes.

    Every value covers all images of a batch at once. The constructor stores
    its arguments unchanged; it performs no validation or conversion.

    Attributes:
        concat_num_gts (list[int] | None): The number of truth boxes
            considered.
        concat_pos_mask (IntTensor): Positive samples mask in all images.
        concat_neg_mask (IntTensor): Negative samples mask in all images.
        concat_max_overlaps (FloatTensor): The max overlaps of all bboxes
            and ground truth boxes.
        concat_anchor_gt_indes (None | LongTensor): The assigned truth box
            index of all anchors. The spelling of the attribute name is kept
            as is because it is part of the public interface.
        concat_anchor_gt_labels (None | LongTensor): The gt labels of all
            anchors.
    """

    def __init__(self,
                 concat_num_gts,
                 concat_pos_mask,
                 concat_neg_mask,
                 concat_max_overlaps,
                 concat_anchor_gt_indes=None,
                 concat_anchor_gt_labels=None):
        self.concat_num_gts = concat_num_gts
        self.concat_pos_mask = concat_pos_mask
        self.concat_neg_mask = concat_neg_mask
        self.concat_max_overlaps = concat_max_overlaps
        self.concat_anchor_gt_indes = concat_anchor_gt_indes
        self.concat_anchor_gt_labels = concat_anchor_gt_labels
        # Interface for possible user-defined properties
        self._extra_properties = {}


@BBOX_ASSIGNERS.register_module()
class AscendMaxIoUAssigner(BaseAssigner):
    """Assign a gt bbox or background to each bbox of a whole batch.

    The assigner works on batched tensors and describes the outcome with
    masks instead of per-bbox integer codes: the returned
    :obj:`AscendAssignResult` holds a positive mask, a negative mask, the
    max overlap of every bbox and the index (0-based) of the gt with the
    highest overlap for every bbox.

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
            A float ``t`` selects ``0 <= iou < t``; a 2-tuple ``(lo, hi)``
            selects ``lo <= iou < hi``. Any other type yields an all-zero
            negative mask.
        min_pos_iou (float): Minimum iou for a bbox to be considered as a
            positive bbox. Positive samples can have smaller IoU than
            pos_iou_thr due to the low quality matching step (assign max IoU
            sample to each gt). `min_pos_iou` is set to avoid assigning
            bboxes that have extremely small iou with GT as positive
            samples. It brings about 0.3 mAP improvements in 1x schedule but
            does not affect the performance of 3x schedule. More comparisons
            can be found in `PR #7464
            <https://github.com/open-mmlab/mmdetection/pull/7464>`_.
        gt_max_assign_all (bool): Whether to assign all bboxes with the same
            highest overlap with some gt to that gt.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Non-positive values mean not
            ignoring any bboxes.
        ignore_wrt_candidates (bool): Whether to compute the iof between
            `bboxes` and `gt_bboxes_ignore`, or the contrary.
        match_low_quality (bool): Whether to allow low quality matches. This
            is usually allowed for RPN and single stage detectors, but not
            allowed in the second stage.
        gpu_assign_thr (int): Stored for config compatibility; the methods
            of this class do not read it.
        iou_calculator (dict): Config of the overlaps calculator.
    """

    def __init__(self,
                 pos_iou_thr,
                 neg_iou_thr,
                 min_pos_iou=.0,
                 gt_max_assign_all=True,
                 ignore_iof_thr=-1,
                 ignore_wrt_candidates=True,
                 match_low_quality=True,
                 gpu_assign_thr=-1,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.pos_iou_thr = pos_iou_thr
        self.neg_iou_thr = neg_iou_thr
        self.min_pos_iou = min_pos_iou
        self.gt_max_assign_all = gt_max_assign_all
        self.ignore_iof_thr = ignore_iof_thr
        self.ignore_wrt_candidates = ignore_wrt_candidates
        self.gpu_assign_thr = gpu_assign_thr
        self.match_low_quality = match_low_quality
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self,
               concat_bboxes,
               concat_gt_bboxes,
               concat_gt_bboxes_ignore=None,
               concat_gt_labels=None,
               concat_bboxes_ignore_mask=None,
               concat_num_gts=None):
        """Assign gt to bboxes.

        Args:
            concat_bboxes (Tensor): Bounding boxes to be assigned,
                shape (b, n, 4).
            concat_gt_bboxes (Tensor): Ground truth boxes, shape (b, k, 4).
            concat_gt_bboxes_ignore (Tensor, optional): Ground truth bboxes
                that are labelled as `ignored`, e.g., crowd boxes in COCO.
            concat_gt_labels (Tensor, optional): Label of gt_bboxes,
                shape (b, k, ).
            concat_bboxes_ignore_mask (Tensor, optional): Per-bbox mask,
                shape (b, n). It is applied with ``neg=True`` and value -1,
                so presumably entries equal to 0 mark bboxes whose overlaps
                are forced to -1 -- confirm against ``set_index``. ``None``
                skips this masking step.
            concat_num_gts (list[int], optional): Number of gts of every
                image, shape (b, ). Only forwarded to the result.

        Returns:
            :obj:`AscendAssignResult`: The assign result.
        """
        concat_overlaps = self.iou_calculator(concat_gt_bboxes, concat_bboxes)
        # The mask is optional in the signature, so only touch it when it
        # was actually provided (the unguarded call raised AttributeError).
        if concat_bboxes_ignore_mask is not None:
            concat_overlaps = set_index(
                concat_overlaps,
                concat_bboxes_ignore_mask.unsqueeze(1).float(),
                -1,
                neg=True)
        if self.ignore_iof_thr > 0 and concat_gt_bboxes_ignore is not None:
            # NOTE(review): this branch is kept exactly as written; the mask
            # shapes handed to ``set_index`` here could not be checked from
            # this file -- confirm before relying on gt_bboxes_ignore.
            if self.ignore_wrt_candidates:
                concat_ignore_overlaps = self.iou_calculator(
                    concat_bboxes, concat_gt_bboxes_ignore, mode='iof')
                concat_ignore_overlaps = set_index(
                    concat_ignore_overlaps, concat_bboxes_ignore_mask, -1)
                concat_ignore_max_overlaps, _ = concat_ignore_overlaps.max(
                    dim=2)
            else:
                concat_ignore_overlaps = self.iou_calculator(
                    concat_gt_bboxes_ignore, concat_bboxes, mode='iof')
                concat_ignore_overlaps = set_index(
                    concat_ignore_overlaps, concat_bboxes_ignore_mask, -1)
                concat_ignore_max_overlaps, _ = concat_ignore_overlaps.max(
                    dim=1)
            concat_ignore_mask = \
                concat_ignore_max_overlaps > self.ignore_iof_thr
            concat_overlaps = set_index(concat_overlaps, concat_ignore_mask,
                                        -1)
        concat_assign_result = self.concat_assign_wrt_overlaps(
            concat_overlaps, concat_gt_labels, concat_num_gts)
        return concat_assign_result

    def concat_assign_wrt_overlaps(self,
                                   concat_overlaps,
                                   concat_gt_labels=None,
                                   concat_num_gts=None):
        """Assign w.r.t. the overlaps of bboxes with gts of all images.

        Args:
            concat_overlaps (Tensor): Overlaps between gts and bboxes,
                shape (num_images, num_gts, num_bboxes).
            concat_gt_labels (Tensor, optional): Labels of the gts, indexed
                per image with the best-gt index of every bbox.
            concat_num_gts (list[int], optional): Number of gts of every
                image. Only forwarded to the result.

        Returns:
            :obj:`AscendAssignResult`: The assign result.
        """
        num_images, num_gts, num_bboxes = concat_overlaps.size()
        # For every bbox: the best overlap over all gts and the gt index.
        concat_max_overlaps, concat_argmax_overlaps = concat_overlaps.max(
            dim=1)
        if isinstance(self.neg_iou_thr, float):
            concat_neg_mask = (
                (concat_max_overlaps >= 0) &
                (concat_max_overlaps < self.neg_iou_thr)).int()
        elif isinstance(self.neg_iou_thr, tuple):
            assert len(self.neg_iou_thr) == 2
            concat_neg_mask = (
                (concat_max_overlaps >= self.neg_iou_thr[0]) &
                (concat_max_overlaps < self.neg_iou_thr[1])).int()
        else:
            concat_neg_mask = torch.zeros(
                concat_max_overlaps.size(),
                dtype=torch.int,
                device=concat_max_overlaps.device)
        concat_pos_mask = (concat_max_overlaps >= self.pos_iou_thr).int()
        if self.match_low_quality:
            # For every gt: the best overlap over all bboxes and its index.
            concat_gt_max_overlaps, concat_gt_argmax_overlaps = \
                concat_overlaps.max(dim=2)
            # Only gts whose best overlap is positive and reaches
            # min_pos_iou take part in low quality matching.
            concat_index_bool = (
                (concat_gt_max_overlaps >= self.min_pos_iou) &
                (concat_gt_max_overlaps > 0))
            if self.gt_max_assign_all:
                # Every bbox tying for the best overlap of a gt is assigned
                # to that gt; later gts overwrite earlier ones.
                pos_inds_low_quality = (
                    (concat_overlaps == concat_gt_max_overlaps.unsqueeze(2))
                    & concat_index_bool.unsqueeze(2))
                for i in range(num_gts):
                    pos_inds_low_quality_gt = pos_inds_low_quality[:, i, :]
                    concat_argmax_overlaps[pos_inds_low_quality_gt] = i
                    concat_pos_mask[pos_inds_low_quality_gt] = 1
            else:
                # Only the single argmax bbox of every qualifying gt is
                # assigned to it.
                index_temp = torch.arange(
                    0, num_gts, device=concat_max_overlaps.device)
                for index_image in range(num_images):
                    gt_argmax_overlaps = concat_gt_argmax_overlaps[
                        index_image]
                    index_bool = concat_index_bool[index_image]
                    pos_inds_low_quality = gt_argmax_overlaps[index_bool]
                    concat_argmax_overlaps[index_image][
                        pos_inds_low_quality] = index_temp[index_bool]
                    concat_pos_mask[index_image][pos_inds_low_quality] = 1
        # A bbox that ended up positive must not be negative as well.
        concat_neg_mask = concat_neg_mask * (1 - concat_pos_mask)
        if concat_gt_labels is not None:
            concat_anchor_gt_labels = torch.zeros(
                (num_images, num_bboxes),
                dtype=concat_gt_labels.dtype,
                device=concat_gt_labels.device)
            for index_image in range(num_images):
                concat_anchor_gt_labels[index_image] = torch.index_select(
                    concat_gt_labels[index_image], 0,
                    concat_argmax_overlaps[index_image])
        else:
            concat_anchor_gt_labels = None
        return AscendAssignResult(concat_num_gts, concat_pos_mask,
                                  concat_neg_mask, concat_max_overlaps,
                                  concat_argmax_overlaps,
                                  concat_anchor_gt_labels)
@HEADS.register_module()
class AscendAnchorHead(AnchorHead):
    """Ascend Anchor-based head (RetinaNet, SSD, etc.).

    Target computation is done on batched tensors: anchors, valid flags and
    ground truths of all images are stacked (ground truths are padded to a
    common length) and assigned in one pass.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        anchor_generator (dict): Config dict for anchor generator
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Default False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8, 16, 32],
                     ratios=[0.5, 1.0, 2.0],
                     strides=[4, 8, 16, 32, 64]),
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=(.0, .0, .0, .0),
                     target_stds=(1.0, 1.0, 1.0, 1.0)),
                 reg_decoded_bbox=False,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
        super(AscendAnchorHead, self).__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            feat_channels=feat_channels,
            anchor_generator=anchor_generator,
            bbox_coder=bbox_coder,
            reg_decoded_bbox=reg_decoded_bbox,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

    def _get_concat_gt_bboxes(self,
                              gt_bboxes_list,
                              num_images,
                              gt_nums,
                              device,
                              max_gt_labels):
        """Get ground truth bboxes of all image.

        The boxes of every image are written into one padded tensor of shape
        (num_images, max_gt_labels, 4). Unused rows keep a padding box built
        from ``self.min_anchor``.

        Args:
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            num_images (int): The num of images.
            gt_nums (list[int]): The ground truth bboxes num of each image.
            device (torch.device | str): Device for returned tensors
            max_gt_labels (int): The padded ground truth bboxes num per
                image, i.e. the size of dim 1 of the returned tensor.

        Returns:
            Tensor | None: Ground truth bboxes of all image, or None when
                ``gt_bboxes_list`` is None.
        """
        # Both attributes are created lazily on first use.
        if not hasattr(self, 'concat_gt_bboxes'):
            self.concat_gt_bboxes = {}
        if not hasattr(self, 'min_anchor'):
            # Coordinates used for padding rows: (x1, y1) get the first
            # value and (x2, y2) the second. Presumably chosen far outside
            # any image so padding never matches an anchor -- TODO confirm.
            self.min_anchor = (-1354, -1344)
        if gt_bboxes_list is None:
            return None

        # The padded template depends on the batch size, dtype and device as
        # well as on the padded length. Keying on the padded length alone
        # handed a stale tensor of the wrong shape to a batch of another
        # size.
        gt_dtype = gt_bboxes_list[0].dtype
        cache_key = (num_images, max_gt_labels, gt_dtype, str(device))
        template = self.concat_gt_bboxes.get(cache_key)
        if template is None:
            template = torch.zeros((num_images, max_gt_labels, 4),
                                   dtype=gt_dtype,
                                   device=device)
            template[:, :, :2] = self.min_anchor[0]
            template[:, :, 2:] = self.min_anchor[1]
            self.concat_gt_bboxes[cache_key] = template
        # Never write into the cached template itself.
        concat_gt_bboxes = template.clone()
        for index_imgs, gt_bboxes in enumerate(gt_bboxes_list):
            concat_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes
        return concat_gt_bboxes

    def _get_concat_gt_bboxes_ignore(self,
                                     gt_bboxes_ignore_list,
                                     num_images,
                                     gt_nums,
                                     device):
        """Ground truth bboxes to be ignored of all image.

        Args:
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            num_images (int): The num of images. Currently unused.
            gt_nums (list[int]): The ground truth bboxes num of each image.
                Currently unused.
            device (torch.device | str): Device for returned tensors.
                Currently unused.

        Returns:
            None: Ignored boxes are not supported yet, so the only valid
                result is None.

        Raises:
            RuntimeError: If ``gt_bboxes_ignore_list`` is not None.
        """
        # TODO: support gt_bboxes_ignore_list
        if gt_bboxes_ignore_list is None:
            concat_gt_bboxes_ignore = None
        else:
            raise RuntimeError("gt_bboxes_ignore not support yet")
        return concat_gt_bboxes_ignore

    def _get_concat_gt_labels(self,
                              gt_labels_list,
                              num_images,
                              gt_nums,
                              device,
                              max_gt_labels):
        """Get ground truth labels of all image.

        Args:
            gt_labels_list (list[Tensor]): Ground truth labels.
            num_images (int): The num of images.
            gt_nums (list[int]): The ground truth bboxes num of each image.
            device (torch.device | str): Device for returned tensors
            max_gt_labels (int): The padded ground truth num per image,
                i.e. the size of dim 1 of the returned tensor.

        Returns:
            Tensor | None: Ground truth labels of all image with shape
                (num_images, max_gt_labels); padding entries are 0. None
                when ``gt_labels_list`` is None.
        """
        if gt_labels_list is None:
            concat_gt_labels = None
        else:
            concat_gt_labels = torch.zeros((num_images, max_gt_labels),
                                           dtype=gt_labels_list[0].dtype,
                                           device=device)
            for index_imgs, gt_labels in enumerate(gt_labels_list):
                concat_gt_labels[index_imgs, :gt_nums[index_imgs]] = gt_labels

        return concat_gt_labels

    def _get_targets_concat(self,
                            concat_anchors,
                            concat_valid_flags,
                            concat_gt_bboxes,
                            concat_gt_bboxes_ignore,
                            concat_gt_labels,
                            img_metas,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in all
        images.

        Args:
            concat_anchors (Tensor): anchors of all image, which are
                concatenated into a single tensor of
                shape (num_imgs, num_anchors ,4).
            concat_valid_flags (Tensor): valid flags of all image,
                which are concatenated into a single tensor of
                shape (num_imgs, num_anchors,).
            concat_gt_bboxes (Tensor): Ground truth bboxes of all image,
                shape (num_imgs, max_gt_nums, 4).
            concat_gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_imgs, num_ignored_gts, 4).
            concat_gt_labels (Tensor): Ground truth labels of each box,
                shape (num_imgs, max_gt_nums,).
            img_metas (list[dict]): Meta info of each image. Not read here.
            label_channels (int): Channel of label. Not read here.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Not read here.

        Returns:
            tuple:
                concat_labels (Tensor): Labels of all level
                concat_label_weights (Tensor): Label weights of all level
                concat_bbox_targets (Tensor): BBox targets of all level
                concat_bbox_weights (Tensor): BBox weights of all level
                concat_pos_mask (Tensor): Positive samples mask in all images
                concat_neg_mask (Tensor): Negative samples mask in all images
                sampling_result (Sampling): The result of sampling,
                    default: None.
        """
        num_imgs, num_anchors, _ = concat_anchors.size()
        # assign gt and sample concat_anchors
        assign_result = self.assigner.assign(
            concat_anchors,
            concat_gt_bboxes,
            concat_gt_bboxes_ignore,
            None if self.sampling else concat_gt_labels,
            concat_bboxes_ignore_mask=concat_valid_flags)
        # TODO: support sampling_result
        sampling_result = None
        concat_pos_mask = assign_result.concat_pos_mask
        concat_neg_mask = assign_result.concat_neg_mask
        concat_anchor_gt_indes = assign_result.concat_anchor_gt_indes
        concat_anchor_gt_labels = assign_result.concat_anchor_gt_labels

        # Gather, per image, the gt box selected for every anchor.
        concat_anchor_gt_bboxes = torch.zeros(
            concat_anchors.size(),
            dtype=concat_anchors.dtype,
            device=concat_anchors.device)
        for index_imgs in range(num_imgs):
            concat_anchor_gt_bboxes[index_imgs] = torch.index_select(
                concat_gt_bboxes[index_imgs], 0,
                concat_anchor_gt_indes[index_imgs])

        concat_bbox_targets = torch.zeros_like(concat_anchors)
        concat_bbox_weights = torch.zeros_like(concat_anchors)
        # Labels start out as ``num_classes`` for every anchor.
        concat_labels = concat_anchors.new_full((num_imgs, num_anchors),
                                                self.num_classes,
                                                dtype=torch.int)
        concat_label_weights = concat_anchors.new_zeros(
            (num_imgs, num_anchors), dtype=torch.float)

        # Targets are computed for all anchors and masked to the positives
        # afterwards.
        if not self.reg_decoded_bbox:
            concat_pos_bbox_targets = self.bbox_coder.encode(
                concat_anchors, concat_anchor_gt_bboxes)
        else:
            concat_pos_bbox_targets = concat_anchor_gt_bboxes

        concat_bbox_targets = set_index(concat_bbox_targets,
                                        concat_pos_mask.unsqueeze(2),
                                        concat_pos_bbox_targets)
        concat_bbox_weights = set_index(concat_bbox_weights,
                                        concat_pos_mask.unsqueeze(2), 1.0)
        if concat_gt_labels is None:
            concat_labels = set_index(concat_labels, concat_pos_mask, 0.0)
        else:
            concat_labels = set_index(concat_labels, concat_pos_mask,
                                      concat_anchor_gt_labels)
        if self.train_cfg.pos_weight <= 0:
            concat_label_weights = set_index(concat_label_weights,
                                             concat_pos_mask, 1.0)
        else:
            concat_label_weights = set_index(concat_label_weights,
                                             concat_pos_mask,
                                             self.train_cfg.pos_weight)
        concat_label_weights = set_index(concat_label_weights,
                                         concat_neg_mask, 1.0)
        return (concat_labels, concat_label_weights, concat_bbox_targets,
                concat_bbox_weights, concat_pos_mask, concat_neg_mask,
                sampling_result)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False,
                    return_level=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors, )
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored. Must be None.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Must be True.
            return_sampling_results (bool): Whether to return the result of
                sample. Must be False.
            return_level (bool): Whether to map outputs back to the levels
                of feature map sizes.

        Returns:
            tuple: If ``return_level`` is True, a tuple containing learning
            targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each
                  level.
                - bbox_weights_list (list[Tensor]): BBox weights of each
                  level.
                - num_total_pos (Tensor): Number of positive samples in all
                  images, counted as at least one per image.
                - num_total_neg (Tensor): Number of negative samples in all
                  images, counted as at least one per image.

            followed by any additional values returned by
            ``self._get_targets_concat`` (split per level).

            If ``return_level`` is False, the batched tensors are returned
            instead: ``(concat_labels, concat_label_weights,
            concat_bbox_targets, concat_bbox_weights, concat_pos_mask,
            concat_neg_mask, sampling_result, num_total_pos, num_total_neg,
            concat_anchors)``.
        """
        # Restrictions of the batched implementation.
        assert gt_bboxes_ignore_list is None
        assert unmap_outputs is True
        assert return_sampling_results is False
        assert self.train_cfg.allowed_border < 0
        assert isinstance(self.assigner, AscendMaxIoUAssigner)
        assert isinstance(self.sampler, PseudoSampler)
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        device = anchor_list[0][0].device
        # Anchor count of every level, taken from the first image.
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # Flatten the levels of every image, then stack the images.
        concat_anchor_list = []
        concat_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            concat_anchor_list.append(torch.cat(anchor_list[i]))
            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
        concat_anchors = torch.cat(
            [torch.unsqueeze(anchor, 0) for anchor in concat_anchor_list], 0)
        concat_valid_flags = torch.cat([
            torch.unsqueeze(concat_valid_flag, 0)
            for concat_valid_flag in concat_valid_flag_list
        ], 0)

        gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list]
        # Used as the padded gt length of the batch tensors below.
        max_gt_nums = generate_max_gt_nums(gt_nums)
        concat_gt_bboxes = self._get_concat_gt_bboxes(
            gt_bboxes_list, num_imgs, gt_nums, device, max_gt_nums)
        concat_gt_bboxes_ignore = self._get_concat_gt_bboxes_ignore(
            gt_bboxes_ignore_list, num_imgs, gt_nums, device)
        concat_gt_labels = self._get_concat_gt_labels(
            gt_labels_list, num_imgs, gt_nums, device, max_gt_nums)

        results = self._get_targets_concat(
            concat_anchors,
            concat_valid_flags,
            concat_gt_bboxes,
            concat_gt_bboxes_ignore,
            concat_gt_labels,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)

        (concat_labels, concat_label_weights, concat_bbox_targets,
         concat_bbox_weights, concat_pos_mask, concat_neg_mask,
         sampling_result) = results[:7]
        rest_results = list(results[7:])  # user-added return values

        # sampled anchors of all images; every image counts at least once
        min_num = torch.ones((num_imgs, ),
                             dtype=concat_pos_mask.dtype,
                             device=concat_pos_mask.device)
        num_total_pos = torch.sum(
            torch.max(torch.sum(concat_pos_mask, dim=1), min_num))
        num_total_neg = torch.sum(
            torch.max(torch.sum(concat_neg_mask, dim=1), min_num))
        if return_level is True:
            labels_list = images_to_levels(concat_labels, num_level_anchors)
            label_weights_list = images_to_levels(concat_label_weights,
                                                  num_level_anchors)
            bbox_targets_list = images_to_levels(concat_bbox_targets,
                                                 num_level_anchors)
            bbox_weights_list = images_to_levels(concat_bbox_weights,
                                                 num_level_anchors)
            res = (labels_list, label_weights_list, bbox_targets_list,
                   bbox_weights_list, num_total_pos, num_total_neg)
            if return_sampling_results:
                res = res + (sampling_result, )
            for i, r in enumerate(rest_results):  # user-added return values
                rest_results[i] = images_to_levels(r, num_level_anchors)

            return res + tuple(rest_results)
        else:
            res = (concat_labels, concat_label_weights, concat_bbox_targets,
                   concat_bbox_weights, concat_pos_mask, concat_neg_mask,
                   sampling_result, num_total_pos, num_total_neg,
                   concat_anchors)
            return res
@HEADS.register_module()
class AscendRetinaHead(RetinaHead, AscendAnchorHead):
    r"""An anchor-based head used in `RetinaNet
    <https://arxiv.org/pdf/1708.02002.pdf>`_.

    The head contains two subnetworks. The first classifies anchor boxes and
    the second regresses deltas for the anchors. Target computation is
    delegated to :class:`AscendAnchorHead`.

    Example:
        >>> import torch
        >>> self = AscendRetinaHead(11, 7)
        >>> x = torch.rand(1, 7, 32, 32)
        >>> cls_score, bbox_pred = self.forward_single(x)
        >>> # Each anchor predicts a score for each class except background
        >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
        >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
        >>> assert cls_per_anchor == (self.num_classes)
        >>> assert box_per_anchor == 4
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=4,
                     scales_per_octave=3,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        super(AscendRetinaHead, self).__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            stacked_convs=stacked_convs,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            anchor_generator=anchor_generator,
            init_cfg=init_cfg,
            **kwargs)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False,
                    return_level=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        All arguments are forwarded unchanged to
        ``AscendAnchorHead.get_targets``.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner
                list corresponds to feature levels of the image. Each element
                of the inner list is a tensor of shape (num_anchors, )
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.
            return_sampling_results (bool): Whether to return the result of
                sample.
            return_level (bool): Whether to map outputs back to the levels
                of feature map sizes.

        Returns:
            tuple: Whatever ``AscendAnchorHead.get_targets`` returns for the
            same arguments. With ``return_level=True`` this is usually:

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each
                  level.
                - bbox_weights_list (list[Tensor]): BBox weights of each
                  level.
                - num_total_pos: Number of positive samples in all images.
                - num_total_neg: Number of negative samples in all images.
        """
        # Call the Ascend implementation explicitly instead of relying on
        # the method resolution order of the two base classes.
        return AscendAnchorHead.get_targets(
            self,
            anchor_list,
            valid_flag_list,
            gt_bboxes_list,
            img_metas,
            gt_bboxes_ignore_list,
            gt_labels_list,
            label_channels,
            unmap_outputs,
            return_sampling_results,
            return_level)
It should be `True` when + using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. + train_cfg (dict): Training config of anchor head. + test_cfg (dict): Testing config of anchor head. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ # noqa: W605 + + def __init__(self, + num_classes=80, + in_channels=(512, 1024, 512, 256, 256, 256), + stacked_convs=0, + feat_channels=256, + use_depthwise=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + anchor_generator=dict( + type='SSDAnchorGenerator', + scale_major=False, + input_size=300, + strides=[8, 16, 32, 64, 100, 300], + ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]), + basesize_ratio_range=(0.1, 0.9)), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + clip_border=True, + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0], + ), + reg_decoded_bbox=False, + train_cfg=None, + test_cfg=None, + init_cfg=dict( + type='Xavier', + layer='Conv2d', + distribution='uniform', + bias=0)): + super(AscendSSDHead, self).__init__( + num_classes=num_classes, + in_channels=in_channels, + stacked_convs=stacked_convs, + feat_channels=feat_channels, + use_depthwise=use_depthwise, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + anchor_generator=anchor_generator, + bbox_coder=bbox_coder, + reg_decoded_bbox=reg_decoded_bbox, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg + ) + assert self.reg_decoded_bbox is False, \ + 'reg_decoded_bbox only support False now.' + + def get_static_anchors(self, featmap_sizes, img_metas, device="cuda"): + """Get static anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + img_metas (list[dict]): Image meta info. + device (torch.device | str): Device for returned tensors + + Returns: + tuple: + anchor_list (list[Tensor]): Anchors of each image. + valid_flag_list (list[Tensor]): Valid flags of each image. 
+ """ + if not hasattr(self, 'static_anchors') or not hasattr(self, 'static_valid_flags'): + static_anchors, static_valid_flags = self.get_anchors(featmap_sizes, img_metas, device) + self.static_anchors = static_anchors + self.static_valid_flags = static_valid_flags + return self.static_anchors, self.static_valid_flags + + def get_targets(self, + anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True, + return_sampling_results=False, + return_level=True): + """Compute regression and classification targets for anchors in + multiple images. + + Args: + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, 4). + valid_flag_list (list[list[Tensor]]): Multi level valid flags of + each image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, ) + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be + ignored. + gt_labels_list (list[Tensor]): Ground truth labels of each box. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + return_sampling_results (bool): Whether to return the result of + sample. + return_level (bool): Whether to map outputs back to the levels + of feature map sizes. + Returns: + tuple: Usually returns a tuple containing learning targets. + + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each + level. + - bbox_targets_list (list[Tensor]): BBox targets of each level. 
+ - bbox_weights_list (list[Tensor]): BBox weights of each level. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + + additional_returns: This function enables user-defined returns from + `self._get_targets_single`. These returns are currently refined + to properties at each feature map (i.e. having HxW dimension). + The results will be concatenated after the end + """ + return AscendAnchorHead.get_targets( + self, + anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list, + gt_labels_list, + label_channels, + unmap_outputs, + return_sampling_results, + return_level, + ) + + def concat_loss(self, + concat_cls_score, concat_bbox_pred, + concat_anchor, concat_labels, + concat_label_weights, + concat_bbox_targets, concat_bbox_weights, + concat_pos_mask, concat_neg_mask, + num_total_samples): + """Compute loss of all images. + + Args: + concat_cls_score (Tensor): Box scores for all image + Has shape (num_imgs, num_total_anchors, num_classes). + concat_bbox_pred (Tensor): Box energies / deltas for all image + level with shape (num_imgs, num_total_anchors, 4). + concat_anchor (Tensor): Box reference for all image with shape + (num_imgs, num_total_anchors, 4). + concat_labels (Tensor): Labels of all anchors with shape + (num_imgs, num_total_anchors,). + concat_label_weights (Tensor): Label weights of all anchor with + shape (num_imgs, num_total_anchors,) + concat_bbox_targets (Tensor): BBox regression targets of all anchor + weight shape (num_imgs, num_total_anchors, 4). + concat_bbox_weights (Tensor): BBox regression loss weights of + all anchor with shape (num_imgs, num_total_anchors, 4). + concat_pos_mask (Tensor): Positive samples mask in all images. + concat_neg_mask (Tensor): negative samples mask in all images. 
+ num_total_samples (int): If sampling, num total samples equal to + the number of total anchors; Otherwise, it is the number of + positive anchors. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_images, num_anchors, _ = concat_anchor.size() + + concat_loss_cls_all = F.cross_entropy( + concat_cls_score.view((-1, self.cls_out_channels)), concat_labels.view(-1), + reduction='none').view(concat_label_weights.size()) * concat_label_weights + # # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + concat_num_pos_samples = torch.sum(concat_pos_mask, dim=1) + concat_num_neg_samples = self.train_cfg.neg_pos_ratio * concat_num_pos_samples + + concat_num_neg_samples_max = torch.sum(concat_neg_mask, dim=1) + concat_num_neg_samples = torch.min(concat_num_neg_samples, concat_num_neg_samples_max) + + concat_topk_loss_cls_neg, _ = torch.topk(concat_loss_cls_all * concat_neg_mask, k=num_anchors, dim=1) + concat_loss_cls_pos = torch.sum(concat_loss_cls_all * concat_pos_mask, dim=1) + + anchor_index = torch.arange(end=num_anchors, dtype=torch.float, device=concat_anchor.device).view((1, -1)) + topk_loss_neg_mask = (anchor_index < concat_num_neg_samples.view(-1, 1)).float() + + concat_loss_cls_neg = torch.sum(concat_topk_loss_cls_neg * topk_loss_neg_mask, dim=1) + loss_cls = (concat_loss_cls_pos + concat_loss_cls_neg) / num_total_samples + + if self.reg_decoded_bbox: + # TODO: support self.reg_decoded_bbox is True + raise RuntimeError + + loss_bbox_all = smooth_l1_loss( + concat_bbox_pred, + concat_bbox_targets, + concat_bbox_weights, + reduction="none", + beta=self.train_cfg.smoothl1_beta, + avg_factor=num_total_samples) + eps = torch.finfo(torch.float32).eps + + sum_dim = (i for i in range(1, len(loss_bbox_all.size()))) + loss_bbox = loss_bbox_all.sum(tuple(sum_dim)) / (num_total_samples + eps) + return loss_cls[None], loss_bbox + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + 
gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + gt_bboxes (list[Tensor]): each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=1, + unmap_outputs=True, + return_level=False) + if cls_reg_targets is None: + return None + + (concat_labels, concat_label_weights, concat_bbox_targets, concat_bbox_weights, concat_pos_mask, + concat_neg_mask, sampling_result, num_total_pos, num_total_neg, concat_anchors) = cls_reg_targets + + num_imgs = len(img_metas) + concat_cls_score = torch.cat([ + s.permute(0, 2, 3, 1).reshape( + num_imgs, -1, self.cls_out_channels) for s in cls_scores + ], 1) + + concat_bbox_pred = torch.cat([ + b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + for b in bbox_preds + ], -2) + + concat_losses_cls, concat_losses_bbox = self.concat_loss( + concat_cls_score, concat_bbox_pred, + concat_anchors, concat_labels, + concat_label_weights, + concat_bbox_targets, concat_bbox_weights, + 
concat_pos_mask, concat_neg_mask, + num_total_pos) + losses_cls = [concat_losses_cls[:, index_imgs] for index_imgs in range(num_imgs)] + losses_bbox = [losses_bbox for losses_bbox in concat_losses_bbox] + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index 9b6e0295a4c..42c8e42c4f5 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -5,15 +5,15 @@ from .memory import AvoidCUDAOOM, AvoidOOM from .misc import find_latest_checkpoint, update_data_root from .replace_cfg_vals import replace_cfg_vals -from .rfnext import rfnext_init_model from .setup_env import setup_multi_processes from .split_batch import split_batch from .util_distribution import build_ddp, build_dp, get_device +from .ascend_util import set_index, images_to_levels, generate_max_gt_nums __all__ = [ 'get_root_logger', 'collect_env', 'find_latest_checkpoint', 'update_data_root', 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp', 'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM', - 'rfnext_init_model' + 'set_index', 'images_to_levels', 'generate_max_gt_nums' ] diff --git a/mmdet/utils/ascend_util.py b/mmdet/utils/ascend_util.py new file mode 100644 index 00000000000..fa35425613f --- /dev/null +++ b/mmdet/utils/ascend_util.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +def set_index(ori_tensor, mask, new_value, neg=False): + if mask is None: + return ori_tensor + else: + if neg: + return ori_tensor * mask + new_value * (1 - mask) + else: + return ori_tensor * (1 - mask) + new_value * mask + + +def images_to_levels(target, num_levels): + """Convert targets by image to targets by feature level. + + [target_img0, target_img1] -> [target_level0, target_level1, ...] 
+ """ + if not isinstance(target, torch.Tensor): + target = torch.stack(target, 0) + level_targets = [] + start = 0 + for n in num_levels: + end = start + n + # level_targets.append(target[:, start:end].squeeze(0)) + level_targets.append(target[:, start:end]) + start = end + return level_targets + + +def generate_max_gt_nums(gt_nums, minimum_gt_nums=32, maximum_gt_nums=1024): + max_gt_nums = max(gt_nums) + max_gt_nums_align = minimum_gt_nums + while max_gt_nums_align < max_gt_nums: + max_gt_nums_align *= 2 + if max_gt_nums_align > maximum_gt_nums: + raise RuntimeError + return max_gt_nums_align diff --git a/tests/test_models/test_dense_heads/test_ascend_head.py b/tests/test_models/test_dense_heads/test_ascend_head.py new file mode 100644 index 00000000000..ae91dbe4ba2 --- /dev/null +++ b/tests/test_models/test_dense_heads/test_ascend_head.py @@ -0,0 +1,228 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import torch + +from mmdet.models.dense_heads import AscendAnchorHead +from mmdet.models.dense_heads import AscendRetinaHead +from mmdet.models.dense_heads import AscendSSDHead + + +def test_ascend_anchor_head_loss(): + """Tests AscendAnchorHead loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + + cfg = mmcv.Config( + dict( + assigner=dict( + type='AscendMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) + ) + self = AscendAnchorHead(num_classes=4, in_channels=1, train_cfg=cfg) + + # Anchor head expects a multiple levels of features per image + feat = [ + torch.rand(1, 1, s // (2 ** (i + 2)), s // (2 ** (i + 2))) + for i in range(len(self.prior_generator.strides)) + ] + cls_scores, bbox_preds = self.forward(feat) + + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = 
[torch.LongTensor([])] + + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_cls_loss = sum(empty_gt_losses['loss_cls']) + empty_box_loss = sum(empty_gt_losses['loss_bbox']) + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_box_loss.item() == 0, ( + 'there should be no box loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + onegt_cls_loss = sum(one_gt_losses['loss_cls']) + onegt_box_loss = sum(one_gt_losses['loss_bbox']) + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_box_loss.item() > 0, 'box loss should be non-zero' + + +def test_ascend_retina_head_loss(): + """Tests AscendRetinaHead loss when truth is empty and non-empty.""" + img_shape = (800, 1067, 3) + pad_shape = (800, 1088, 3) + num_classes = 80 + in_channels = 256 + + img_metas = [{ + 'img_shape': img_shape, + 'scale_factor': 1, + 'pad_shape': pad_shape + }] + + cfg = mmcv.Config( + dict( + assigner=dict( + type='AscendMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) + ) + self = AscendRetinaHead(num_classes=num_classes, + in_channels=in_channels, + train_cfg=cfg) + + # Anchor head expects a multiple levels of features per image + feat = [ + torch.rand(1, + in_channels, + pad_shape[0] // strides[0], + pad_shape[1] // strides[1]) + for strides in self.prior_generator.strides + ] + cls_scores, bbox_preds = self.forward(feat) + + # Test that empty ground truth encourages 
the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_cls_loss = sum(empty_gt_losses['loss_cls']) + empty_box_loss = sum(empty_gt_losses['loss_bbox']) + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_box_loss.item() == 0, ( + 'there should be no box loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + onegt_cls_loss = sum(one_gt_losses['loss_cls']) + onegt_box_loss = sum(one_gt_losses['loss_bbox']) + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_box_loss.item() > 0, 'box loss should be non-zero' + + +def test_ascend_ssd_head_loss(): + """Tests anchor head loss when truth is empty and non-empty.""" + img_shape = (320, 320, 3) + pad_shape = (320, 320, 3) + in_channels = (96, 1280, 512, 256, 256, 128) + img_metas = [ + { + 'img_shape': img_shape, + 'scale_factor': 1, + 'pad_shape': pad_shape + }, + { + 'img_shape': img_shape, + 'scale_factor': 1, + 'pad_shape': pad_shape + } + ] + + self = AscendSSDHead( + in_channels=in_channels, + num_classes=80, + use_depthwise=True, + norm_cfg=dict(type='BN', eps=0.001, momentum=0.03), + act_cfg=dict(type='ReLU6'), + init_cfg=dict(type='Normal', layer='Conv2d', std=0.001), + + anchor_generator=dict( + type='SSDAnchorGenerator', + scale_major=False, + strides=[16, 32, 64, 107, 160, 320], + ratios=[[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]], + min_sizes=[48, 100, 150, 202, 253, 
304], + max_sizes=[100, 150, 202, 253, 304, 320]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + train_cfg=mmcv.Config(dict( + assigner=dict( + type='AscendMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0., + ignore_iof_thr=-1, + gt_max_assign_all=False), + smoothl1_beta=1., + allowed_border=-1, + pos_weight=-1, + neg_pos_ratio=3, + debug=False)) + ) + + # Anchor head expects a multiple levels of features per image + feat = [ + torch.rand( + 2, + in_channels[i], + round(pad_shape[0] / self.prior_generator.strides[i][0]), + round(pad_shape[1] / self.prior_generator.strides[i][1]) + ) + for i in range(len(self.prior_generator.strides)) + ] + cls_scores, bbox_preds = self.forward(feat) + + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4)), torch.empty((0, 4))] + gt_labels = [torch.LongTensor([]), torch.LongTensor([])] + + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. 
+ empty_cls_loss = sum(empty_gt_losses['loss_cls']) + empty_box_loss = sum(empty_gt_losses['loss_bbox']) + assert empty_cls_loss.item() >= 0, 'cls loss should be non-zero' + assert empty_box_loss.item() == 0, ( + 'there should be no box loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2]), torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + onegt_cls_loss = sum(one_gt_losses['loss_cls']) + onegt_box_loss = sum(one_gt_losses['loss_bbox']) + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_box_loss.item() > 0, 'box loss should be non-zero' From aab5622e48e501bf21b9e798005ded31b9cee1d6 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Thu, 12 Jan 2023 14:57:18 +0800 Subject: [PATCH 02/13] add copyright --- mmdet/models/dense_heads/ascend_anchor_head.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index b729c836f46..9dfcece5169 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
import torch from ..builder import HEADS from .anchor_head import AnchorHead From ff43cc44f6546838bd666617cba4e321d7cfd7ce Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Thu, 12 Jan 2023 15:15:54 +0800 Subject: [PATCH 03/13] modify tensor type --- mmdet/models/dense_heads/ascend_anchor_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index 9dfcece5169..6c880a3b1ec 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -359,7 +359,7 @@ def get_targets(self, # sampled anchors of all images min_num = torch.ones((num_imgs,), - dtype=concat_pos_mask.dtype, + dtype=torch.long, device=concat_pos_mask.device) num_total_pos = torch.sum(torch.max(torch.sum(concat_pos_mask, dim=1), min_num)) From 53e4dbf6a294907e8ac1cede5047483ff004be7f Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Thu, 12 Jan 2023 19:56:24 +0800 Subject: [PATCH 04/13] clean code --- mmdet/core/bbox/assigners/__init__.py | 4 +- .../bbox/assigners/ascend_assign_result.py | 16 +- .../bbox/assigners/ascend_max_iou_assigner.py | 107 ++++++++---- mmdet/models/dense_heads/__init__.py | 6 +- .../models/dense_heads/ascend_anchor_head.py | 159 ++++++++---------- .../models/dense_heads/ascend_retina_head.py | 14 +- mmdet/models/dense_heads/ascend_ssd_head.py | 83 +++++---- mmdet/utils/__init__.py | 4 +- .../test_dense_heads/test_ascend_head.py | 81 ++++----- 9 files changed, 248 insertions(+), 226 deletions(-) diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py index ec092b25672..d6480a783be 100644 --- a/mmdet/core/bbox/assigners/__init__.py +++ b/mmdet/core/bbox/assigners/__init__.py @@ -20,6 +20,6 @@ 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner', 
'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner', - 'TaskAlignedAssigner', 'MaskHungarianAssigner', - 'AscendAssignResult', 'AscendMaxIoUAssigner' + 'TaskAlignedAssigner', 'MaskHungarianAssigner', 'AscendAssignResult', + 'AscendMaxIoUAssigner' ] diff --git a/mmdet/core/bbox/assigners/ascend_assign_result.py b/mmdet/core/bbox/assigners/ascend_assign_result.py index 39f8af6ea58..5075db2275e 100644 --- a/mmdet/core/bbox/assigners/ascend_assign_result.py +++ b/mmdet/core/bbox/assigners/ascend_assign_result.py @@ -1,6 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import torch - from mmdet.utils import util_mixins @@ -18,14 +16,18 @@ class AscendAssignResult(util_mixins.NiceRepr): and ground truth boxes. concat_anchor_gt_indes(None | LongTensor): The the assigned truth - box index of all anchors - . + box index of all anchors. - concat_anchor_gt_labels(None | LongTensor): The gt labels of all anchors + concat_anchor_gt_labels(None | LongTensor): The gt labels + of all anchors """ - def __init__(self, concat_num_gts, concat_pos_mask, concat_neg_mask, - concat_max_overlaps, concat_anchor_gt_indes=None, + def __init__(self, + concat_num_gts, + concat_pos_mask, + concat_neg_mask, + concat_max_overlaps, + concat_anchor_gt_indes=None, concat_anchor_gt_labels=None): self.concat_num_gts = concat_num_gts self.concat_pos_mask = concat_pos_mask diff --git a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py index a406c111237..724c93b9d9e 100644 --- a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import torch +from ....utils import set_index from ..builder import BBOX_ASSIGNERS from ..iou_calculators import build_iou_calculator from .ascend_assign_result import AscendAssignResult from .base_assigner import BaseAssigner -from ....utils import set_index @BBOX_ASSIGNERS.register_module() @@ -72,70 +72,111 @@ def assign(self, concat_bboxes_ignore_mask=None, concat_num_gts=None): """Assign gt to bboxes. + Args: - concat_bboxes (Tensor): Bounding boxes to be assigned, shape(b, n, 4). - concat_gt_bboxes (Tensor): Ground truth boxes, shape (b, k, 4). - concat_gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are - labelled as `ignored`, e.g., crowd boxes in COCO. - concat_gt_labels (Tensor, optional): Label of gt_bboxes, shape (b, k, ). + concat_bboxes (Tensor): Bounding boxes to be assigned, + shape(b, n, 4). + concat_gt_bboxes (Tensor): Ground truth boxes, + shape (b, k, 4). + concat_gt_bboxes_ignore (Tensor, optional): Ground truth + bboxes that are labelled as `ignored`, + e.g., crowd boxes in COCO. + concat_gt_labels (Tensor, optional): Label of gt_bboxes, + shape (b, k, ). concat_bboxes_ignore_mask: (b, n) concat_num_gts:(b, ) Returns: :obj:`AssignResult`: The assign result. 
""" concat_overlaps = self.iou_calculator(concat_gt_bboxes, concat_bboxes) - concat_overlaps = set_index(concat_overlaps, concat_bboxes_ignore_mask.unsqueeze(1).float(), -1, neg=True) + concat_overlaps = set_index( + concat_overlaps, + concat_bboxes_ignore_mask.unsqueeze(1).float(), + -1, + neg=True) if self.ignore_iof_thr > 0 and concat_gt_bboxes_ignore is not None: if self.ignore_wrt_candidates: - concat_ignore_overlaps = self.iou_calculator(concat_bboxes, concat_gt_bboxes_ignore, mode='iof') - concat_ignore_overlaps = set_index(concat_ignore_overlaps, concat_bboxes_ignore_mask, -1) - concat_ignore_max_overlaps, _ = concat_ignore_overlaps.max(dim=2) + concat_ignore_overlaps = self.iou_calculator( + concat_bboxes, concat_gt_bboxes_ignore, mode='iof') + concat_ignore_overlaps = set_index(concat_ignore_overlaps, + concat_bboxes_ignore_mask, + -1) + concat_ignore_max_overlaps, _ = concat_ignore_overlaps.max( + dim=2) else: - concat_ignore_overlaps = self.iou_calculator(concat_gt_bboxes_ignore, concat_bboxes, mode='iof') - concat_ignore_overlaps = set_index(concat_ignore_overlaps, concat_bboxes_ignore_mask, -1) - concat_ignore_max_overlaps, _ = concat_ignore_overlaps.max(dim=1) - concat_ignore_mask = concat_ignore_max_overlaps > self.ignore_iof_thr - concat_overlaps = set_index(concat_overlaps, concat_ignore_mask, -1) - concat_assign_result = self.concat_assign_wrt_overlaps(concat_overlaps, concat_gt_labels, concat_num_gts) + concat_ignore_overlaps = self.iou_calculator( + concat_gt_bboxes_ignore, concat_bboxes, mode='iof') + concat_ignore_overlaps = set_index(concat_ignore_overlaps, + concat_bboxes_ignore_mask, + -1) + concat_ignore_max_overlaps, _ = \ + concat_ignore_overlaps.max(dim=1) + concat_ignore_mask = \ + concat_ignore_max_overlaps > self.ignore_iof_thr + concat_overlaps = set_index(concat_overlaps, concat_ignore_mask, + -1) + concat_assign_result = self.concat_assign_wrt_overlaps( + concat_overlaps, concat_gt_labels, concat_num_gts) return 
concat_assign_result - def concat_assign_wrt_overlaps(self, concat_overlaps, concat_gt_labels=None, concat_num_gts=None): + def concat_assign_wrt_overlaps(self, + concat_overlaps, + concat_gt_labels=None, + concat_num_gts=None): num_images, num_gts, num_bboxes = concat_overlaps.size() - concat_max_overlaps, concat_argmax_overlaps = concat_overlaps.max(dim=1) + concat_max_overlaps, concat_argmax_overlaps = concat_overlaps.max( + dim=1) if isinstance(self.neg_iou_thr, float): - concat_neg_mask = ((concat_max_overlaps >= 0) & (concat_max_overlaps < self.neg_iou_thr)).int() + concat_neg_mask = \ + ((concat_max_overlaps >= 0) + & (concat_max_overlaps < self.neg_iou_thr)).int() elif isinstance(self.neg_iou_thr, tuple): assert len(self.neg_iou_thr) == 2 - concat_neg_mask = ((concat_max_overlaps >= self.neg_iou_thr[0]) & (concat_max_overlaps < self.neg_iou_thr[1])).int() + concat_neg_mask = \ + ((concat_max_overlaps >= self.neg_iou_thr[0]) + & (concat_max_overlaps < self.neg_iou_thr[1])).int() else: - concat_neg_mask = torch.zeros(concat_max_overlaps.size(), dtype=torch.int, device=concat_max_overlaps.device) + concat_neg_mask = torch.zeros( + concat_max_overlaps.size(), + dtype=torch.int, + device=concat_max_overlaps.device) concat_pos_mask = (concat_max_overlaps >= self.pos_iou_thr).int() if self.match_low_quality: - concat_gt_max_overlaps, concat_gt_argmax_overlaps = concat_overlaps.max(dim=2) - concat_index_bool = (concat_gt_max_overlaps >= self.min_pos_iou) & (concat_gt_max_overlaps > 0) + concat_gt_max_overlaps, concat_gt_argmax_overlaps = \ + concat_overlaps.max(dim=2) + concat_index_bool = (concat_gt_max_overlaps >= self.min_pos_iou) &\ + (concat_gt_max_overlaps > 0) if self.gt_max_assign_all: - pos_inds_low_quality = (concat_overlaps == concat_gt_max_overlaps.unsqueeze(2)) & concat_index_bool.unsqueeze(2) + pos_inds_low_quality = \ + (concat_overlaps == concat_gt_max_overlaps.unsqueeze(2)) &\ + concat_index_bool.unsqueeze(2) for i in range(num_gts): 
pos_inds_low_quality_gt = pos_inds_low_quality[:, i, :] concat_argmax_overlaps[pos_inds_low_quality_gt] = i concat_pos_mask[pos_inds_low_quality_gt] = 1 else: - index_temp = torch.arange(0, num_gts, device=concat_max_overlaps.device) + index_temp = torch.arange( + 0, num_gts, device=concat_max_overlaps.device) for index_image in range(num_images): gt_argmax_overlaps = concat_gt_argmax_overlaps[index_image] index_bool = concat_index_bool[index_image] pos_inds_low_quality = gt_argmax_overlaps[index_bool] - concat_argmax_overlaps[index_image][pos_inds_low_quality] = index_temp[index_bool] + concat_argmax_overlaps[index_image][pos_inds_low_quality] \ + = index_temp[index_bool] concat_pos_mask[index_image][pos_inds_low_quality] = 1 concat_neg_mask = concat_neg_mask * (1 - concat_pos_mask) if concat_gt_labels is not None: - concat_anchor_gt_labels = torch.zeros((num_images, num_bboxes), - dtype=concat_gt_labels.dtype, - device=concat_gt_labels.device) + concat_anchor_gt_labels = torch.zeros( + (num_images, num_bboxes), + dtype=concat_gt_labels.dtype, + device=concat_gt_labels.device) for index_image in range(num_images): - concat_anchor_gt_labels[index_image] = torch.index_select(concat_gt_labels[index_image], 0, - concat_argmax_overlaps[index_image]) + concat_anchor_gt_labels[index_image] = torch.index_select( + concat_gt_labels[index_image], 0, + concat_argmax_overlaps[index_image]) else: concat_anchor_gt_labels = None - return AscendAssignResult(concat_num_gts, concat_pos_mask, concat_neg_mask, - concat_max_overlaps, concat_argmax_overlaps, concat_anchor_gt_labels) + return AscendAssignResult(concat_num_gts, concat_pos_mask, + concat_neg_mask, concat_max_overlaps, + concat_argmax_overlaps, + concat_anchor_gt_labels) diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index ebffef847af..9c60ae14796 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) 
OpenMMLab. All rights reserved. from .anchor_free_head import AnchorFreeHead from .anchor_head import AnchorHead -from .atss_head import ATSSHead from .ascend_anchor_head import AscendAnchorHead from .ascend_retina_head import AscendRetinaHead from .ascend_ssd_head import AscendSSDHead +from .atss_head import ATSSHead from .autoassign_head import AutoAssignHead from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead from .centernet_head import CenterNetHead @@ -57,6 +57,6 @@ 'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead', 'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead', - 'Mask2FormerHead', 'SOLOV2Head', 'DDODHead', - 'AscendAnchorHead', 'AscendRetinaHead', 'AscendSSDHead' + 'Mask2FormerHead', 'SOLOV2Head', 'DDODHead', 'AscendAnchorHead', + 'AscendRetinaHead', 'AscendSSDHead' ] diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index 6c880a3b1ec..f52d1b93e78 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import torch -from ..builder import HEADS -from .anchor_head import AnchorHead + from ...core.bbox.assigners import AscendMaxIoUAssigner from ...core.bbox.samplers import PseudoSampler -from ...utils import set_index, images_to_levels, generate_max_gt_nums +from ...utils import generate_max_gt_nums, images_to_levels, set_index +from ..builder import HEADS +from .anchor_head import AnchorHead @HEADS.register_module() @@ -65,15 +66,10 @@ def __init__(self, loss_bbox=loss_bbox, train_cfg=train_cfg, test_cfg=test_cfg, - init_cfg=init_cfg - ) + init_cfg=init_cfg) - def _get_concat_gt_bboxes(self, - gt_bboxes_list, - num_images, - gt_nums, - device, - max_gt_labels): + def _get_concat_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, + device, max_gt_labels): """Get ground truth bboxes of all image. Args: @@ -100,16 +96,14 @@ def _get_concat_gt_bboxes(self, concat_gt_bboxes[:, :, 2:] = self.min_anchor[1] self.concat_gt_bboxes[max_gt_labels] = concat_gt_bboxes.clone() else: - concat_gt_bboxes = self.concat_gt_bboxes.get(max_gt_labels).clone() + concat_gt_bboxes = self.concat_gt_bboxes.get( + max_gt_labels).clone() for index_imgs, gt_bboxes in enumerate(gt_bboxes_list): concat_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes return concat_gt_bboxes - def _get_concat_gt_bboxes_ignore(self, - gt_bboxes_ignore_list, - num_images, - gt_nums, - device): + def _get_concat_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images, + gt_nums, device): """Ground truth bboxes to be ignored of all image. 
Args: @@ -126,15 +120,11 @@ def _get_concat_gt_bboxes_ignore(self, if gt_bboxes_ignore_list is None: concat_gt_bboxes_ignore = None else: - raise RuntimeError("gt_bboxes_ignore not support yet") + raise RuntimeError('gt_bboxes_ignore not support yet') return concat_gt_bboxes_ignore - def _get_concat_gt_labels(self, - gt_labels_list, - num_images, - gt_nums, - device, - max_gt_labels): + def _get_concat_gt_labels(self, gt_labels_list, num_images, gt_nums, + device, max_gt_labels): """Ground truth bboxes to be ignored of all image. Args: @@ -200,7 +190,9 @@ def _get_targets_concat(self, num_imgs, num_anchors, _ = concat_anchors.size() # assign gt and sample concat_anchors assign_result = self.assigner.assign( - concat_anchors, concat_gt_bboxes, concat_gt_bboxes_ignore, + concat_anchors, + concat_gt_bboxes, + concat_gt_bboxes_ignore, None if self.sampling else concat_gt_labels, concat_bboxes_ignore_mask=concat_valid_flags) # TODO: support sampling_result @@ -210,37 +202,51 @@ def _get_targets_concat(self, concat_anchor_gt_indes = assign_result.concat_anchor_gt_indes concat_anchor_gt_labels = assign_result.concat_anchor_gt_labels - concat_anchor_gt_bboxes = torch.zeros(concat_anchors.size(), - dtype=concat_anchors.dtype, - device=concat_anchors.device) + concat_anchor_gt_bboxes = torch.zeros( + concat_anchors.size(), + dtype=concat_anchors.dtype, + device=concat_anchors.device) for index_imgs in range(num_imgs): - concat_anchor_gt_bboxes[index_imgs] = torch.index_select(concat_gt_bboxes[index_imgs], 0, - concat_anchor_gt_indes[index_imgs]) + concat_anchor_gt_bboxes[index_imgs] = torch.index_select( + concat_gt_bboxes[index_imgs], 0, + concat_anchor_gt_indes[index_imgs]) concat_bbox_targets = torch.zeros_like(concat_anchors) concat_bbox_weights = torch.zeros_like(concat_anchors) - concat_labels = concat_anchors.new_full((num_imgs, num_anchors), self.num_classes, dtype=torch.int) - concat_label_weights = concat_anchors.new_zeros((num_imgs, num_anchors), 
dtype=torch.float) + concat_labels = concat_anchors.new_full((num_imgs, num_anchors), + self.num_classes, + dtype=torch.int) + concat_label_weights = concat_anchors.new_zeros( + (num_imgs, num_anchors), dtype=torch.float) if not self.reg_decoded_bbox: - concat_pos_bbox_targets = self.bbox_coder.encode(concat_anchors, concat_anchor_gt_bboxes) + concat_pos_bbox_targets = self.bbox_coder.encode( + concat_anchors, concat_anchor_gt_bboxes) else: concat_pos_bbox_targets = concat_anchor_gt_bboxes - concat_bbox_targets = set_index(concat_bbox_targets, concat_pos_mask.unsqueeze(2), concat_pos_bbox_targets) - concat_bbox_weights = set_index(concat_bbox_weights, concat_pos_mask.unsqueeze(2), 1.0) + concat_bbox_targets = set_index(concat_bbox_targets, + concat_pos_mask.unsqueeze(2), + concat_pos_bbox_targets) + concat_bbox_weights = set_index(concat_bbox_weights, + concat_pos_mask.unsqueeze(2), 1.0) if concat_gt_labels is None: concat_labels = set_index(concat_labels, concat_pos_mask, 0.0) else: - concat_labels = set_index(concat_labels, concat_pos_mask, concat_anchor_gt_labels) + concat_labels = set_index(concat_labels, concat_pos_mask, + concat_anchor_gt_labels) if self.train_cfg.pos_weight <= 0: - concat_label_weights = set_index(concat_label_weights, concat_pos_mask, 1.0) + concat_label_weights = set_index(concat_label_weights, + concat_pos_mask, 1.0) else: - concat_label_weights = set_index(concat_label_weights, concat_pos_mask, self.train_cfg.pos_weight) - concat_label_weights = set_index(concat_label_weights, concat_neg_mask, 1.0) + concat_label_weights = set_index(concat_label_weights, + concat_pos_mask, + self.train_cfg.pos_weight) + concat_label_weights = set_index(concat_label_weights, concat_neg_mask, + 1.0) return (concat_labels, concat_label_weights, concat_bbox_targets, - concat_bbox_weights, concat_pos_mask, - concat_neg_mask, sampling_result) + concat_bbox_weights, concat_pos_mask, concat_neg_mask, + sampling_result) def get_targets(self, anchor_list, @@ 
-314,33 +320,22 @@ def get_targets(self, concat_anchor_list.append(torch.cat(anchor_list[i])) concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) concat_anchors = torch.cat( - [torch.unsqueeze(anchor, 0) for anchor in concat_anchor_list], 0 - ) - concat_valid_flags = torch.cat( - [torch.unsqueeze(concat_valid_flag, 0) - for concat_valid_flag in concat_valid_flag_list], 0 - ) + [torch.unsqueeze(anchor, 0) for anchor in concat_anchor_list], 0) + concat_valid_flags = torch.cat([ + torch.unsqueeze(concat_valid_flag, 0) + for concat_valid_flag in concat_valid_flag_list + ], 0) gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list] max_gt_nums = generate_max_gt_nums(gt_nums) - concat_gt_bboxes = self._get_concat_gt_bboxes( - gt_bboxes_list, - num_imgs, - gt_nums, - device, - max_gt_nums) + concat_gt_bboxes = self._get_concat_gt_bboxes(gt_bboxes_list, num_imgs, + gt_nums, device, + max_gt_nums) concat_gt_bboxes_ignore = self._get_concat_gt_bboxes_ignore( - gt_bboxes_ignore_list, - num_imgs, - gt_nums, - device - ) - concat_gt_labels = self._get_concat_gt_labels( - gt_labels_list, - num_imgs, - gt_nums, - device, - max_gt_nums) + gt_bboxes_ignore_list, num_imgs, gt_nums, device) + concat_gt_labels = self._get_concat_gt_labels(gt_labels_list, num_imgs, + gt_nums, device, + max_gt_nums) results = self._get_targets_concat( concat_anchors, @@ -353,35 +348,30 @@ def get_targets(self, unmap_outputs=unmap_outputs) (concat_labels, concat_label_weights, concat_bbox_targets, - concat_bbox_weights, concat_pos_mask, - concat_neg_mask, sampling_result) = results[:7] + concat_bbox_weights, concat_pos_mask, concat_neg_mask, + sampling_result) = results[:7] rest_results = list(results[7:]) # user-added return values # sampled anchors of all images - min_num = torch.ones((num_imgs,), + min_num = torch.ones((num_imgs, ), dtype=torch.long, device=concat_pos_mask.device) - num_total_pos = torch.sum(torch.max(torch.sum(concat_pos_mask, dim=1), - min_num)) - num_total_neg = 
torch.sum(torch.max(torch.sum(concat_neg_mask, dim=1), - min_num)) + num_total_pos = torch.sum( + torch.max(torch.sum(concat_pos_mask, dim=1), min_num)) + num_total_neg = torch.sum( + torch.max(torch.sum(concat_neg_mask, dim=1), min_num)) if return_level is True: - labels_list = images_to_levels( - concat_labels, - num_level_anchors) - label_weights_list = images_to_levels( - concat_label_weights, - num_level_anchors) - bbox_targets_list = images_to_levels( - concat_bbox_targets, - num_level_anchors) - bbox_weights_list = images_to_levels( - concat_bbox_weights, - num_level_anchors) + labels_list = images_to_levels(concat_labels, num_level_anchors) + label_weights_list = images_to_levels(concat_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(concat_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(concat_bbox_weights, + num_level_anchors) res = (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, num_total_pos, num_total_neg) if return_sampling_results: - res = res + (sampling_result,) + res = res + (sampling_result, ) for i, r in enumerate(rest_results): # user-added return values rest_results[i] = images_to_levels(r, num_level_anchors) @@ -392,4 +382,3 @@ def get_targets(self, sampling_result, num_total_pos, num_total_neg, concat_anchors) return res - diff --git a/mmdet/models/dense_heads/ascend_retina_head.py b/mmdet/models/dense_heads/ascend_retina_head.py index 13bc52cf855..159fe75c1ca 100644 --- a/mmdet/models/dense_heads/ascend_retina_head.py +++ b/mmdet/models/dense_heads/ascend_retina_head.py @@ -110,14 +110,6 @@ def get_targets(self, The results will be concatenated after the end """ return AscendAnchorHead.get_targets( - self, - anchor_list, - valid_flag_list, - gt_bboxes_list, - img_metas, - gt_bboxes_ignore_list, - gt_labels_list, - label_channels, - unmap_outputs, - return_sampling_results, - return_level) + self, anchor_list, valid_flag_list, gt_bboxes_list, img_metas, + 
gt_bboxes_ignore_list, gt_labels_list, label_channels, + unmap_outputs, return_sampling_results, return_level) diff --git a/mmdet/models/dense_heads/ascend_ssd_head.py b/mmdet/models/dense_heads/ascend_ssd_head.py index ee6d279c92f..66d7f4dd7eb 100644 --- a/mmdet/models/dense_heads/ascend_ssd_head.py +++ b/mmdet/models/dense_heads/ascend_ssd_head.py @@ -2,6 +2,7 @@ import torch import torch.nn.functional as F from mmcv.runner import force_fp32 + from ..builder import HEADS from ..losses import smooth_l1_loss from .ascend_anchor_head import AscendAnchorHead @@ -84,12 +85,11 @@ def __init__(self, reg_decoded_bbox=reg_decoded_bbox, train_cfg=train_cfg, test_cfg=test_cfg, - init_cfg=init_cfg - ) + init_cfg=init_cfg) assert self.reg_decoded_bbox is False, \ 'reg_decoded_bbox only support False now.' - def get_static_anchors(self, featmap_sizes, img_metas, device="cuda"): + def get_static_anchors(self, featmap_sizes, img_metas, device='cuda'): """Get static anchors according to feature map sizes. Args: @@ -102,8 +102,10 @@ def get_static_anchors(self, featmap_sizes, img_metas, device="cuda"): anchor_list (list[Tensor]): Anchors of each image. valid_flag_list (list[Tensor]): Valid flags of each image. 
""" - if not hasattr(self, 'static_anchors') or not hasattr(self, 'static_valid_flags'): - static_anchors, static_valid_flags = self.get_anchors(featmap_sizes, img_metas, device) + if not hasattr(self, 'static_anchors') or \ + not hasattr(self, 'static_valid_flags'): + static_anchors, static_valid_flags = self.get_anchors( + featmap_sizes, img_metas, device) self.static_anchors = static_anchors self.static_valid_flags = static_valid_flags return self.static_anchors, self.static_valid_flags @@ -175,12 +177,9 @@ def get_targets(self, return_level, ) - def concat_loss(self, - concat_cls_score, concat_bbox_pred, - concat_anchor, concat_labels, - concat_label_weights, - concat_bbox_targets, concat_bbox_weights, - concat_pos_mask, concat_neg_mask, + def concat_loss(self, concat_cls_score, concat_bbox_pred, concat_anchor, + concat_labels, concat_label_weights, concat_bbox_targets, + concat_bbox_weights, concat_pos_mask, concat_neg_mask, num_total_samples): """Compute loss of all images. @@ -211,23 +210,34 @@ def concat_loss(self, num_images, num_anchors, _ = concat_anchor.size() concat_loss_cls_all = F.cross_entropy( - concat_cls_score.view((-1, self.cls_out_channels)), concat_labels.view(-1), - reduction='none').view(concat_label_weights.size()) * concat_label_weights + concat_cls_score.view((-1, self.cls_out_channels)), + concat_labels.view(-1), + reduction='none').view( + concat_label_weights.size()) * concat_label_weights # # FG cat_id: [0, num_classes -1], BG cat_id: num_classes concat_num_pos_samples = torch.sum(concat_pos_mask, dim=1) - concat_num_neg_samples = self.train_cfg.neg_pos_ratio * concat_num_pos_samples + concat_num_neg_samples = \ + self.train_cfg.neg_pos_ratio * concat_num_pos_samples concat_num_neg_samples_max = torch.sum(concat_neg_mask, dim=1) - concat_num_neg_samples = torch.min(concat_num_neg_samples, concat_num_neg_samples_max) + concat_num_neg_samples = torch.min(concat_num_neg_samples, + concat_num_neg_samples_max) - concat_topk_loss_cls_neg, _ 
= torch.topk(concat_loss_cls_all * concat_neg_mask, k=num_anchors, dim=1) - concat_loss_cls_pos = torch.sum(concat_loss_cls_all * concat_pos_mask, dim=1) + concat_topk_loss_cls_neg, _ = torch.topk( + concat_loss_cls_all * concat_neg_mask, k=num_anchors, dim=1) + concat_loss_cls_pos = torch.sum( + concat_loss_cls_all * concat_pos_mask, dim=1) - anchor_index = torch.arange(end=num_anchors, dtype=torch.float, device=concat_anchor.device).view((1, -1)) - topk_loss_neg_mask = (anchor_index < concat_num_neg_samples.view(-1, 1)).float() + anchor_index = torch.arange( + end=num_anchors, dtype=torch.float, + device=concat_anchor.device).view((1, -1)) + topk_loss_neg_mask = (anchor_index < concat_num_neg_samples.view( + -1, 1)).float() - concat_loss_cls_neg = torch.sum(concat_topk_loss_cls_neg * topk_loss_neg_mask, dim=1) - loss_cls = (concat_loss_cls_pos + concat_loss_cls_neg) / num_total_samples + concat_loss_cls_neg = torch.sum( + concat_topk_loss_cls_neg * topk_loss_neg_mask, dim=1) + loss_cls = \ + (concat_loss_cls_pos + concat_loss_cls_neg) / num_total_samples if self.reg_decoded_bbox: # TODO: support self.reg_decoded_bbox is True @@ -237,13 +247,14 @@ def concat_loss(self, concat_bbox_pred, concat_bbox_targets, concat_bbox_weights, - reduction="none", + reduction='none', beta=self.train_cfg.smoothl1_beta, avg_factor=num_total_samples) eps = torch.finfo(torch.float32).eps sum_dim = (i for i in range(1, len(loss_bbox_all.size()))) - loss_bbox = loss_bbox_all.sum(tuple(sum_dim)) / (num_total_samples + eps) + loss_bbox = loss_bbox_all.sum(tuple(sum_dim)) / ( + num_total_samples + eps) return loss_cls[None], loss_bbox @force_fp32(apply_to=('cls_scores', 'bbox_preds')) @@ -292,27 +303,27 @@ def loss(self, if cls_reg_targets is None: return None - (concat_labels, concat_label_weights, concat_bbox_targets, concat_bbox_weights, concat_pos_mask, - concat_neg_mask, sampling_result, num_total_pos, num_total_neg, concat_anchors) = cls_reg_targets + (concat_labels, 
concat_label_weights, concat_bbox_targets, + concat_bbox_weights, concat_pos_mask, concat_neg_mask, + sampling_result, num_total_pos, num_total_neg, + concat_anchors) = cls_reg_targets num_imgs = len(img_metas) concat_cls_score = torch.cat([ - s.permute(0, 2, 3, 1).reshape( - num_imgs, -1, self.cls_out_channels) for s in cls_scores + s.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) + for s in cls_scores ], 1) concat_bbox_pred = torch.cat([ - b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) - for b in bbox_preds + b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for b in bbox_preds ], -2) concat_losses_cls, concat_losses_bbox = self.concat_loss( - concat_cls_score, concat_bbox_pred, - concat_anchors, concat_labels, - concat_label_weights, - concat_bbox_targets, concat_bbox_weights, - concat_pos_mask, concat_neg_mask, - num_total_pos) - losses_cls = [concat_losses_cls[:, index_imgs] for index_imgs in range(num_imgs)] + concat_cls_score, concat_bbox_pred, concat_anchors, concat_labels, + concat_label_weights, concat_bbox_targets, concat_bbox_weights, + concat_pos_mask, concat_neg_mask, num_total_pos) + losses_cls = [ + concat_losses_cls[:, index_imgs] for index_imgs in range(num_imgs) + ] losses_bbox = [losses_bbox for losses_bbox in concat_losses_bbox] return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index 42c8e42c4f5..c20f1175574 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from .ascend_util import generate_max_gt_nums, images_to_levels, set_index from .collect_env import collect_env from .compat_config import compat_cfg from .logger import get_caller_name, get_root_logger, log_img_scale @@ -8,12 +9,11 @@ from .setup_env import setup_multi_processes from .split_batch import split_batch from .util_distribution import build_ddp, build_dp, get_device -from .ascend_util import set_index, images_to_levels, generate_max_gt_nums __all__ = [ 'get_root_logger', 'collect_env', 'find_latest_checkpoint', 'update_data_root', 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp', 'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM', - 'set_index', 'images_to_levels', 'generate_max_gt_nums' + 'generate_max_gt_nums', 'set_index', 'images_to_levels' ] diff --git a/tests/test_models/test_dense_heads/test_ascend_head.py b/tests/test_models/test_dense_heads/test_ascend_head.py index ae91dbe4ba2..843a55fe7b4 100644 --- a/tests/test_models/test_dense_heads/test_ascend_head.py +++ b/tests/test_models/test_dense_heads/test_ascend_head.py @@ -2,9 +2,8 @@ import mmcv import torch -from mmdet.models.dense_heads import AscendAnchorHead -from mmdet.models.dense_heads import AscendRetinaHead -from mmdet.models.dense_heads import AscendSSDHead +from mmdet.models.dense_heads import (AscendAnchorHead, AscendRetinaHead, + AscendSSDHead) def test_ascend_anchor_head_loss(): @@ -26,13 +25,12 @@ def test_ascend_anchor_head_loss(): ignore_iof_thr=-1), allowed_border=-1, pos_weight=-1, - debug=False) - ) + debug=False)) self = AscendAnchorHead(num_classes=4, in_channels=1, train_cfg=cfg) # Anchor head expects a multiple levels of features per image feat = [ - torch.rand(1, 1, s // (2 ** (i + 2)), s // (2 ** (i + 2))) + torch.rand(1, 1, s // (2**(i + 2)), s // (2**(i + 2))) for i in range(len(self.prior_generator.strides)) ] cls_scores, bbox_preds = self.forward(feat) @@ -89,17 +87,13 @@ def 
test_ascend_retina_head_loss(): ignore_iof_thr=-1), allowed_border=-1, pos_weight=-1, - debug=False) - ) - self = AscendRetinaHead(num_classes=num_classes, - in_channels=in_channels, - train_cfg=cfg) + debug=False)) + self = AscendRetinaHead( + num_classes=num_classes, in_channels=in_channels, train_cfg=cfg) # Anchor head expects a multiple levels of features per image feat = [ - torch.rand(1, - in_channels, - pad_shape[0] // strides[0], + torch.rand(1, in_channels, pad_shape[0] // strides[0], pad_shape[1] // strides[1]) for strides in self.prior_generator.strides ] @@ -139,18 +133,15 @@ def test_ascend_ssd_head_loss(): img_shape = (320, 320, 3) pad_shape = (320, 320, 3) in_channels = (96, 1280, 512, 256, 256, 128) - img_metas = [ - { - 'img_shape': img_shape, - 'scale_factor': 1, - 'pad_shape': pad_shape - }, - { - 'img_shape': img_shape, - 'scale_factor': 1, - 'pad_shape': pad_shape - } - ] + img_metas = [{ + 'img_shape': img_shape, + 'scale_factor': 1, + 'pad_shape': pad_shape + }, { + 'img_shape': img_shape, + 'scale_factor': 1, + 'pad_shape': pad_shape + }] self = AscendSSDHead( in_channels=in_channels, @@ -159,7 +150,6 @@ def test_ascend_ssd_head_loss(): norm_cfg=dict(type='BN', eps=0.001, momentum=0.03), act_cfg=dict(type='ReLU6'), init_cfg=dict(type='Normal', layer='Conv2d', std=0.001), - anchor_generator=dict( type='SSDAnchorGenerator', scale_major=False, @@ -171,29 +161,26 @@ def test_ascend_ssd_head_loss(): type='DeltaXYWHBBoxCoder', target_means=[.0, .0, .0, .0], target_stds=[0.1, 0.1, 0.2, 0.2]), - train_cfg=mmcv.Config(dict( - assigner=dict( - type='AscendMaxIoUAssigner', - pos_iou_thr=0.5, - neg_iou_thr=0.5, - min_pos_iou=0., - ignore_iof_thr=-1, - gt_max_assign_all=False), - smoothl1_beta=1., - allowed_border=-1, - pos_weight=-1, - neg_pos_ratio=3, - debug=False)) - ) + train_cfg=mmcv.Config( + dict( + assigner=dict( + type='AscendMaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0., + ignore_iof_thr=-1, + 
gt_max_assign_all=False), + smoothl1_beta=1., + allowed_border=-1, + pos_weight=-1, + neg_pos_ratio=3, + debug=False))) # Anchor head expects a multiple levels of features per image feat = [ - torch.rand( - 2, - in_channels[i], - round(pad_shape[0] / self.prior_generator.strides[i][0]), - round(pad_shape[1] / self.prior_generator.strides[i][1]) - ) + torch.rand(2, in_channels[i], + round(pad_shape[0] / self.prior_generator.strides[i][0]), + round(pad_shape[1] / self.prior_generator.strides[i][1])) for i in range(len(self.prior_generator.strides)) ] cls_scores, bbox_preds = self.forward(feat) From f851098dfa92beda922646bd458d8a5577f1c71a Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Tue, 17 Jan 2023 10:23:14 +0800 Subject: [PATCH 05/13] resolve view comments --- .../bbox/assigners/ascend_assign_result.py | 43 ++- .../bbox/assigners/ascend_max_iou_assigner.py | 156 +++++------ .../models/dense_heads/ascend_anchor_head.py | 261 +++++++++--------- mmdet/models/dense_heads/ascend_ssd_head.py | 95 ++++--- mmdet/utils/__init__.py | 4 +- mmdet/utils/ascend_util.py | 11 +- 6 files changed, 280 insertions(+), 290 deletions(-) diff --git a/mmdet/core/bbox/assigners/ascend_assign_result.py b/mmdet/core/bbox/assigners/ascend_assign_result.py index 5075db2275e..c481b67d9a3 100644 --- a/mmdet/core/bbox/assigners/ascend_assign_result.py +++ b/mmdet/core/bbox/assigners/ascend_assign_result.py @@ -5,35 +5,30 @@ class AscendAssignResult(util_mixins.NiceRepr): """Stores ascend assignments between predicted and truth boxes. - Attributes: - concat_num_gts (list[int]): the number of truth boxes considered. - - concat_pos_mask (IntTensor): Positive samples mask in all images. - - concat_neg_mask (IntTensor): Negative samples mask in all images. - - concat_max_overlaps (FloatTensor): The max overlaps of all bboxes + Arguments: + batch_num_gts (list[int]): the number of truth boxes considered. + batch_pos_mask (IntTensor): Positive samples mask in all images. 
+ batch_neg_mask (IntTensor): Negative samples mask in all images. + batch_max_overlaps (FloatTensor): The max overlaps of all bboxes and ground truth boxes. - - concat_anchor_gt_indes(None | LongTensor): The the assigned truth + batch_anchor_gt_indes(None | LongTensor): The the assigned truth box index of all anchors. - - concat_anchor_gt_labels(None | LongTensor): The gt labels + batch_anchor_gt_labels(None | LongTensor): The gt labels of all anchors """ def __init__(self, - concat_num_gts, - concat_pos_mask, - concat_neg_mask, - concat_max_overlaps, - concat_anchor_gt_indes=None, - concat_anchor_gt_labels=None): - self.concat_num_gts = concat_num_gts - self.concat_pos_mask = concat_pos_mask - self.concat_neg_mask = concat_neg_mask - self.concat_max_overlaps = concat_max_overlaps - self.concat_anchor_gt_indes = concat_anchor_gt_indes - self.concat_anchor_gt_labels = concat_anchor_gt_labels + batch_num_gts, + batch_pos_mask, + batch_neg_mask, + batch_max_overlaps, + batch_anchor_gt_indes=None, + batch_anchor_gt_labels=None): + self.batch_num_gts = batch_num_gts + self.batch_pos_mask = batch_pos_mask + self.batch_neg_mask = batch_neg_mask + self.batch_max_overlaps = batch_max_overlaps + self.batch_anchor_gt_indes = batch_anchor_gt_indes + self.batch_anchor_gt_labels = batch_anchor_gt_labels # Interface for possible user-defined properties self._extra_properties = {} diff --git a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py index 724c93b9d9e..2d4190c6ccd 100644 --- a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py @@ -65,118 +65,112 @@ def __init__(self, self.iou_calculator = build_iou_calculator(iou_calculator) def assign(self, - concat_bboxes, - concat_gt_bboxes, - concat_gt_bboxes_ignore=None, - concat_gt_labels=None, - concat_bboxes_ignore_mask=None, - concat_num_gts=None): + batch_bboxes, + batch_gt_bboxes, + 
batch_gt_bboxes_ignore=None, + batch_gt_labels=None, + batch_bboxes_ignore_mask=None, + batch_num_gts=None): """Assign gt to bboxes. Args: - concat_bboxes (Tensor): Bounding boxes to be assigned, + batch_bboxes (Tensor): Bounding boxes to be assigned, shape(b, n, 4). - concat_gt_bboxes (Tensor): Ground truth boxes, + batch_gt_bboxes (Tensor): Ground truth boxes, shape (b, k, 4). - concat_gt_bboxes_ignore (Tensor, optional): Ground truth + batch_gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are labelled as `ignored`, e.g., crowd boxes in COCO. - concat_gt_labels (Tensor, optional): Label of gt_bboxes, + batch_gt_labels (Tensor, optional): Label of gt_bboxes, shape (b, k, ). - concat_bboxes_ignore_mask: (b, n) - concat_num_gts:(b, ) + batch_bboxes_ignore_mask: (b, n) + batch_num_gts:(b, ) Returns: :obj:`AssignResult`: The assign result. """ - concat_overlaps = self.iou_calculator(concat_gt_bboxes, concat_bboxes) - concat_overlaps = set_index( - concat_overlaps, - concat_bboxes_ignore_mask.unsqueeze(1).float(), + batch_overlaps = self.iou_calculator(batch_gt_bboxes, batch_bboxes) + batch_overlaps = set_index( + batch_overlaps, + batch_bboxes_ignore_mask.unsqueeze(1).float(), -1, neg=True) - if self.ignore_iof_thr > 0 and concat_gt_bboxes_ignore is not None: + if self.ignore_iof_thr > 0 and batch_gt_bboxes_ignore is not None: if self.ignore_wrt_candidates: - concat_ignore_overlaps = self.iou_calculator( - concat_bboxes, concat_gt_bboxes_ignore, mode='iof') - concat_ignore_overlaps = set_index(concat_ignore_overlaps, - concat_bboxes_ignore_mask, - -1) - concat_ignore_max_overlaps, _ = concat_ignore_overlaps.max( - dim=2) + batch_ignore_overlaps = self.iou_calculator( + batch_bboxes, batch_gt_bboxes_ignore, mode='iof') + batch_ignore_overlaps = set_index(batch_ignore_overlaps, + batch_bboxes_ignore_mask, -1) + batch_ignore_max_overlaps, _ = batch_ignore_overlaps.max(dim=2) else: - concat_ignore_overlaps = self.iou_calculator( - concat_gt_bboxes_ignore, 
concat_bboxes, mode='iof') - concat_ignore_overlaps = set_index(concat_ignore_overlaps, - concat_bboxes_ignore_mask, - -1) - concat_ignore_max_overlaps, _ = \ - concat_ignore_overlaps.max(dim=1) - concat_ignore_mask = \ - concat_ignore_max_overlaps > self.ignore_iof_thr - concat_overlaps = set_index(concat_overlaps, concat_ignore_mask, - -1) - concat_assign_result = self.concat_assign_wrt_overlaps( - concat_overlaps, concat_gt_labels, concat_num_gts) - return concat_assign_result + batch_ignore_overlaps = self.iou_calculator( + batch_gt_bboxes_ignore, batch_bboxes, mode='iof') + batch_ignore_overlaps = set_index(batch_ignore_overlaps, + batch_bboxes_ignore_mask, -1) + batch_ignore_max_overlaps, _ = \ + batch_ignore_overlaps.max(dim=1) + batch_ignore_mask = \ + batch_ignore_max_overlaps > self.ignore_iof_thr + batch_overlaps = set_index(batch_overlaps, batch_ignore_mask, -1) + batch_assign_result = self.batch_assign_wrt_overlaps( + batch_overlaps, batch_gt_labels, batch_num_gts) + return batch_assign_result def concat_assign_wrt_overlaps(self, - concat_overlaps, - concat_gt_labels=None, - concat_num_gts=None): - num_images, num_gts, num_bboxes = concat_overlaps.size() - concat_max_overlaps, concat_argmax_overlaps = concat_overlaps.max( - dim=1) + batch_overlaps, + batch_gt_labels=None, + batch_num_gts=None): + num_images, num_gts, num_bboxes = batch_overlaps.size() + batch_max_overlaps, batch_argmax_overlaps = batch_overlaps.max(dim=1) if isinstance(self.neg_iou_thr, float): - concat_neg_mask = \ - ((concat_max_overlaps >= 0) - & (concat_max_overlaps < self.neg_iou_thr)).int() + batch_neg_mask = \ + ((batch_max_overlaps >= 0) + & (batch_max_overlaps < self.neg_iou_thr)).int() elif isinstance(self.neg_iou_thr, tuple): assert len(self.neg_iou_thr) == 2 - concat_neg_mask = \ - ((concat_max_overlaps >= self.neg_iou_thr[0]) - & (concat_max_overlaps < self.neg_iou_thr[1])).int() + batch_neg_mask = \ + ((batch_max_overlaps >= self.neg_iou_thr[0]) + & (batch_max_overlaps < 
self.neg_iou_thr[1])).int() else: - concat_neg_mask = torch.zeros( - concat_max_overlaps.size(), + batch_neg_mask = torch.zeros( + batch_max_overlaps.size(), dtype=torch.int, - device=concat_max_overlaps.device) - concat_pos_mask = (concat_max_overlaps >= self.pos_iou_thr).int() + device=batch_max_overlaps.device) + batch_pos_mask = (batch_max_overlaps >= self.pos_iou_thr).int() if self.match_low_quality: - concat_gt_max_overlaps, concat_gt_argmax_overlaps = \ - concat_overlaps.max(dim=2) - concat_index_bool = (concat_gt_max_overlaps >= self.min_pos_iou) &\ - (concat_gt_max_overlaps > 0) + batch_gt_max_overlaps, batch_gt_argmax_overlaps = \ + batch_overlaps.max(dim=2) + batch_index_bool = (batch_gt_max_overlaps >= self.min_pos_iou) & \ + (batch_gt_max_overlaps > 0) if self.gt_max_assign_all: pos_inds_low_quality = \ - (concat_overlaps == concat_gt_max_overlaps.unsqueeze(2)) &\ - concat_index_bool.unsqueeze(2) + (batch_overlaps == batch_gt_max_overlaps.unsqueeze(2)) & \ + batch_index_bool.unsqueeze(2) for i in range(num_gts): pos_inds_low_quality_gt = pos_inds_low_quality[:, i, :] - concat_argmax_overlaps[pos_inds_low_quality_gt] = i - concat_pos_mask[pos_inds_low_quality_gt] = 1 + batch_argmax_overlaps[pos_inds_low_quality_gt] = i + batch_pos_mask[pos_inds_low_quality_gt] = 1 else: index_temp = torch.arange( - 0, num_gts, device=concat_max_overlaps.device) + 0, num_gts, device=batch_max_overlaps.device) for index_image in range(num_images): - gt_argmax_overlaps = concat_gt_argmax_overlaps[index_image] - index_bool = concat_index_bool[index_image] + gt_argmax_overlaps = batch_gt_argmax_overlaps[index_image] + index_bool = batch_index_bool[index_image] pos_inds_low_quality = gt_argmax_overlaps[index_bool] - concat_argmax_overlaps[index_image][pos_inds_low_quality] \ + batch_argmax_overlaps[index_image][pos_inds_low_quality] \ = index_temp[index_bool] - concat_pos_mask[index_image][pos_inds_low_quality] = 1 - concat_neg_mask = concat_neg_mask * (1 - concat_pos_mask) - 
if concat_gt_labels is not None: - concat_anchor_gt_labels = torch.zeros( - (num_images, num_bboxes), - dtype=concat_gt_labels.dtype, - device=concat_gt_labels.device) + batch_pos_mask[index_image][pos_inds_low_quality] = 1 + batch_neg_mask = batch_neg_mask * (1 - batch_pos_mask) + if batch_gt_labels is not None: + batch_anchor_gt_labels = torch.zeros((num_images, num_bboxes), + dtype=batch_gt_labels.dtype, + device=batch_gt_labels.device) for index_image in range(num_images): - concat_anchor_gt_labels[index_image] = torch.index_select( - concat_gt_labels[index_image], 0, - concat_argmax_overlaps[index_image]) + batch_anchor_gt_labels[index_image] = torch.index_select( + batch_gt_labels[index_image], 0, + batch_argmax_overlaps[index_image]) else: - concat_anchor_gt_labels = None - return AscendAssignResult(concat_num_gts, concat_pos_mask, - concat_neg_mask, concat_max_overlaps, - concat_argmax_overlaps, - concat_anchor_gt_labels) + batch_anchor_gt_labels = None + return AscendAssignResult(batch_num_gts, batch_pos_mask, + batch_neg_mask, batch_max_overlaps, + batch_argmax_overlaps, + batch_anchor_gt_labels) diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index f52d1b93e78..d2ae41e8766 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -3,7 +3,7 @@ from ...core.bbox.assigners import AscendMaxIoUAssigner from ...core.bbox.samplers import PseudoSampler -from ...utils import generate_max_gt_nums, images_to_levels, set_index +from ...utils import batch_images_to_levels, get_max_num_gt, set_index from ..builder import HEADS from .anchor_head import AnchorHead @@ -68,8 +68,8 @@ def __init__(self, test_cfg=test_cfg, init_cfg=init_cfg) - def _get_concat_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, - device, max_gt_labels): + def get_batch_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device, + max_gt_labels): """Get ground truth bboxes 
of all image. Args: @@ -79,31 +79,31 @@ def _get_concat_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device (torch.device | str): Device for returned tensors max_gt_labels(int): The max ground truth bboxes num of all image. Returns: - concat_gt_bboxes: (Tensor): Ground truth bboxes of all image. + batch_gt_bboxes: (Tensor): Ground truth bboxes of all image. """ - if not hasattr(self, 'concat_gt_bboxes'): - self.concat_gt_bboxes = {} + if not hasattr(self, 'batch_gt_bboxes'): + self.batch_gt_bboxes = {} if not hasattr(self, 'min_anchor'): self.min_anchor = (-1354, -1344) if gt_bboxes_list is None: - concat_gt_bboxes = None + batch_gt_bboxes = None else: - if self.concat_gt_bboxes.get(max_gt_labels) is None: - concat_gt_bboxes = torch.zeros((num_images, max_gt_labels, 4), - dtype=gt_bboxes_list[0].dtype, - device=device) - concat_gt_bboxes[:, :, :2] = self.min_anchor[0] - concat_gt_bboxes[:, :, 2:] = self.min_anchor[1] - self.concat_gt_bboxes[max_gt_labels] = concat_gt_bboxes.clone() + if self.batch_gt_bboxes.get(max_gt_labels) is None: + batch_gt_bboxes = torch.zeros((num_images, max_gt_labels, 4), + dtype=gt_bboxes_list[0].dtype, + device=device) + batch_gt_bboxes[:, :, :2] = self.min_anchor[0] + batch_gt_bboxes[:, :, 2:] = self.min_anchor[1] + self.batch_gt_bboxes[max_gt_labels] = batch_gt_bboxes.clone() else: - concat_gt_bboxes = self.concat_gt_bboxes.get( + batch_gt_bboxes = self.batch_gt_bboxes.get( max_gt_labels).clone() for index_imgs, gt_bboxes in enumerate(gt_bboxes_list): - concat_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes - return concat_gt_bboxes + batch_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes + return batch_gt_bboxes - def _get_concat_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images, - gt_nums, device): + def get_batch_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images, + gt_nums, device): """Ground truth bboxes to be ignored of all image. 
Args: @@ -113,18 +113,18 @@ def _get_concat_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images, gt_nums(list[int]): The ground truth bboxes num of each image. device (torch.device | str): Device for returned tensors Returns: - concat_gt_bboxes_ignore: (Tensor): Ground truth bboxes to be + batch_gt_bboxes_ignore: (Tensor): Ground truth bboxes to be ignored of all image. """ # TODO: support gt_bboxes_ignore_list if gt_bboxes_ignore_list is None: - concat_gt_bboxes_ignore = None + batch_gt_bboxes_ignore = None else: raise RuntimeError('gt_bboxes_ignore not support yet') - return concat_gt_bboxes_ignore + return batch_gt_bboxes_ignore - def _get_concat_gt_labels(self, gt_labels_list, num_images, gt_nums, - device, max_gt_labels): + def get_batch_gt_labels(self, gt_labels_list, num_images, gt_nums, device, + max_gt_labels): """Ground truth bboxes to be ignored of all image. Args: @@ -133,25 +133,25 @@ def _get_concat_gt_labels(self, gt_labels_list, num_images, gt_nums, gt_nums(list[int]): The ground truth bboxes num of each image. device (torch.device | str): Device for returned tensors Returns: - concat_gt_labels: (Tensor): Ground truth labels of all image. + batch_gt_labels: (Tensor): Ground truth labels of all image. 
""" if gt_labels_list is None: - concat_gt_labels = None + batch_gt_labels = None else: - concat_gt_labels = torch.zeros((num_images, max_gt_labels), - dtype=gt_labels_list[0].dtype, - device=device) + batch_gt_labels = torch.zeros((num_images, max_gt_labels), + dtype=gt_labels_list[0].dtype, + device=device) for index_imgs, gt_labels in enumerate(gt_labels_list): - concat_gt_labels[index_imgs, :gt_nums[index_imgs]] = gt_labels + batch_gt_labels[index_imgs, :gt_nums[index_imgs]] = gt_labels - return concat_gt_labels + return batch_gt_labels def _get_targets_concat(self, - concat_anchors, - concat_valid_flags, - concat_gt_bboxes, - concat_gt_bboxes_ignore, - concat_gt_labels, + batch_anchors, + batch_valid_flags, + batch_gt_bboxes, + batch_gt_bboxes_ignore, + batch_gt_labels, img_metas, label_channels=1, unmap_outputs=True): @@ -159,17 +159,17 @@ def _get_targets_concat(self, images. Args: - concat_anchors (Tensor): anchors of all image, which are + batch_anchors (Tensor): anchors of all image, which are concatenated into a single tensor of shape (num_imgs, num_anchors ,4). - concat_valid_flags (Tensor): valid flags of all image, + batch_valid_flags (Tensor): valid flags of all image, which are concatenated into a single tensor of shape (num_imgs, num_anchors,). - concat_gt_bboxes (Tensor): Ground truth bboxes of all image, + batch_gt_bboxes (Tensor): Ground truth bboxes of all image, shape (num_imgs, max_gt_nums, 4). - concat_gt_bboxes_ignore (Tensor): Ground truth bboxes to be + batch_gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored, shape (num_imgs, num_ignored_gts, 4). - concat_gt_labels (Tensor): Ground truth labels of each box, + batch_gt_labels (Tensor): Ground truth labels of each box, shape (num_imgs, max_gt_nums,). img_metas (list[dict]): Meta info of each image. label_channels (int): Channel of label. 
@@ -178,74 +178,74 @@ def _get_targets_concat(self, Returns: tuple: - concat_labels (Tensor): Labels of all level - concat_label_weights (Tensor): Label weights of all level - concat_bbox_targets (Tensor): BBox targets of all level - concat_bbox_weights (Tensor): BBox weights of all level - concat_pos_mask (Tensor): Positive samples mask in all images - concat_neg_mask (Tensor): Negative samples mask in all images + batch_labels (Tensor): Labels of all level + batch_label_weights (Tensor): Label weights of all level + batch_bbox_targets (Tensor): BBox targets of all level + batch_bbox_weights (Tensor): BBox weights of all level + batch_pos_mask (Tensor): Positive samples mask in all images + batch_neg_mask (Tensor): Negative samples mask in all images sampling_result (Sampling): The result of sampling, default: None. """ - num_imgs, num_anchors, _ = concat_anchors.size() - # assign gt and sample concat_anchors + num_imgs, num_anchors, _ = batch_anchors.size() + # assign gt and sample batch_anchors assign_result = self.assigner.assign( - concat_anchors, - concat_gt_bboxes, - concat_gt_bboxes_ignore, - None if self.sampling else concat_gt_labels, - concat_bboxes_ignore_mask=concat_valid_flags) + batch_anchors, + batch_gt_bboxes, + batch_gt_bboxes_ignore, + None if self.sampling else batch_gt_labels, + batch_bboxes_ignore_mask=batch_valid_flags) # TODO: support sampling_result sampling_result = None - concat_pos_mask = assign_result.concat_pos_mask - concat_neg_mask = assign_result.concat_neg_mask - concat_anchor_gt_indes = assign_result.concat_anchor_gt_indes - concat_anchor_gt_labels = assign_result.concat_anchor_gt_labels + batch_pos_mask = assign_result.batch_pos_mask + batch_neg_mask = assign_result.batch_neg_mask + batch_anchor_gt_indes = assign_result.batch_anchor_gt_indes + batch_anchor_gt_labels = assign_result.batch_anchor_gt_labels - concat_anchor_gt_bboxes = torch.zeros( - concat_anchors.size(), - dtype=concat_anchors.dtype, - device=concat_anchors.device) 
+ batch_anchor_gt_bboxes = torch.zeros( + batch_anchors.size(), + dtype=batch_anchors.dtype, + device=batch_anchors.device) for index_imgs in range(num_imgs): - concat_anchor_gt_bboxes[index_imgs] = torch.index_select( - concat_gt_bboxes[index_imgs], 0, - concat_anchor_gt_indes[index_imgs]) + batch_anchor_gt_bboxes[index_imgs] = torch.index_select( + batch_gt_bboxes[index_imgs], 0, + batch_anchor_gt_indes[index_imgs]) - concat_bbox_targets = torch.zeros_like(concat_anchors) - concat_bbox_weights = torch.zeros_like(concat_anchors) - concat_labels = concat_anchors.new_full((num_imgs, num_anchors), - self.num_classes, - dtype=torch.int) - concat_label_weights = concat_anchors.new_zeros( - (num_imgs, num_anchors), dtype=torch.float) + batch_bbox_targets = torch.zeros_like(batch_anchors) + batch_bbox_weights = torch.zeros_like(batch_anchors) + batch_labels = batch_anchors.new_full((num_imgs, num_anchors), + self.num_classes, + dtype=torch.int) + batch_label_weights = batch_anchors.new_zeros((num_imgs, num_anchors), + dtype=torch.float) if not self.reg_decoded_bbox: - concat_pos_bbox_targets = self.bbox_coder.encode( - concat_anchors, concat_anchor_gt_bboxes) + batch_pos_bbox_targets = self.bbox_coder.encode( + batch_anchors, batch_anchor_gt_bboxes) else: - concat_pos_bbox_targets = concat_anchor_gt_bboxes + batch_pos_bbox_targets = batch_anchor_gt_bboxes - concat_bbox_targets = set_index(concat_bbox_targets, - concat_pos_mask.unsqueeze(2), - concat_pos_bbox_targets) - concat_bbox_weights = set_index(concat_bbox_weights, - concat_pos_mask.unsqueeze(2), 1.0) - if concat_gt_labels is None: - concat_labels = set_index(concat_labels, concat_pos_mask, 0.0) + batch_bbox_targets = set_index(batch_bbox_targets, + batch_pos_mask.unsqueeze(2), + batch_pos_bbox_targets) + batch_bbox_weights = set_index(batch_bbox_weights, + batch_pos_mask.unsqueeze(2), 1.0) + if batch_gt_labels is None: + batch_labels = set_index(batch_labels, batch_pos_mask, 0.0) else: - concat_labels = 
set_index(concat_labels, concat_pos_mask, - concat_anchor_gt_labels) + batch_labels = set_index(batch_labels, batch_pos_mask, + batch_anchor_gt_labels) if self.train_cfg.pos_weight <= 0: - concat_label_weights = set_index(concat_label_weights, - concat_pos_mask, 1.0) + batch_label_weights = set_index(batch_label_weights, + batch_pos_mask, 1.0) else: - concat_label_weights = set_index(concat_label_weights, - concat_pos_mask, - self.train_cfg.pos_weight) - concat_label_weights = set_index(concat_label_weights, concat_neg_mask, - 1.0) - return (concat_labels, concat_label_weights, concat_bbox_targets, - concat_bbox_weights, concat_pos_mask, concat_neg_mask, + batch_label_weights = set_index(batch_label_weights, + batch_pos_mask, + self.train_cfg.pos_weight) + batch_label_weights = set_index(batch_label_weights, batch_neg_mask, + 1.0) + return (batch_labels, batch_label_weights, batch_bbox_targets, + batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result) def get_targets(self, @@ -313,72 +313,73 @@ def get_targets(self, device = anchor_list[0][0].device num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] - concat_anchor_list = [] - concat_valid_flag_list = [] + batch_anchor_list = [] + batch_valid_flag_list = [] for i in range(num_imgs): assert len(anchor_list[i]) == len(valid_flag_list[i]) - concat_anchor_list.append(torch.cat(anchor_list[i])) - concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) - concat_anchors = torch.cat( - [torch.unsqueeze(anchor, 0) for anchor in concat_anchor_list], 0) - concat_valid_flags = torch.cat([ - torch.unsqueeze(concat_valid_flag, 0) - for concat_valid_flag in concat_valid_flag_list + batch_anchor_list.append(torch.cat(anchor_list[i])) + batch_valid_flag_list.append(torch.cat(valid_flag_list[i])) + batch_anchors = torch.cat( + [torch.unsqueeze(anchor, 0) for anchor in batch_anchor_list], 0) + batch_valid_flags = torch.cat([ + torch.unsqueeze(batch_valid_flag, 0) + for batch_valid_flag in 
batch_valid_flag_list ], 0) gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list] - max_gt_nums = generate_max_gt_nums(gt_nums) - concat_gt_bboxes = self._get_concat_gt_bboxes(gt_bboxes_list, num_imgs, - gt_nums, device, - max_gt_nums) - concat_gt_bboxes_ignore = self._get_concat_gt_bboxes_ignore( + max_gt_nums = get_max_num_gt(gt_nums) + batch_gt_bboxes = self.get_batch_gt_bboxes(gt_bboxes_list, num_imgs, + gt_nums, device, + max_gt_nums) + batch_gt_bboxes_ignore = self.get_batch_gt_bboxes_ignore( gt_bboxes_ignore_list, num_imgs, gt_nums, device) - concat_gt_labels = self._get_concat_gt_labels(gt_labels_list, num_imgs, - gt_nums, device, - max_gt_nums) + batch_gt_labels = self.get_batch_gt_labels(gt_labels_list, num_imgs, + gt_nums, device, + max_gt_nums) results = self._get_targets_concat( - concat_anchors, - concat_valid_flags, - concat_gt_bboxes, - concat_gt_bboxes_ignore, - concat_gt_labels, + batch_anchors, + batch_valid_flags, + batch_gt_bboxes, + batch_gt_bboxes_ignore, + batch_gt_labels, img_metas, label_channels=label_channels, unmap_outputs=unmap_outputs) - (concat_labels, concat_label_weights, concat_bbox_targets, - concat_bbox_weights, concat_pos_mask, concat_neg_mask, + (batch_labels, batch_label_weights, batch_bbox_targets, + batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result) = results[:7] rest_results = list(results[7:]) # user-added return values # sampled anchors of all images min_num = torch.ones((num_imgs, ), dtype=torch.long, - device=concat_pos_mask.device) + device=batch_pos_mask.device) num_total_pos = torch.sum( - torch.max(torch.sum(concat_pos_mask, dim=1), min_num)) + torch.max(torch.sum(batch_pos_mask, dim=1), min_num)) num_total_neg = torch.sum( - torch.max(torch.sum(concat_neg_mask, dim=1), min_num)) + torch.max(torch.sum(batch_neg_mask, dim=1), min_num)) if return_level is True: - labels_list = images_to_levels(concat_labels, num_level_anchors) - label_weights_list = images_to_levels(concat_label_weights, - 
num_level_anchors) - bbox_targets_list = images_to_levels(concat_bbox_targets, - num_level_anchors) - bbox_weights_list = images_to_levels(concat_bbox_weights, + labels_list = batch_images_to_levels(batch_labels, num_level_anchors) + label_weights_list = batch_images_to_levels( + batch_label_weights, num_level_anchors) + bbox_targets_list = batch_images_to_levels(batch_bbox_targets, + num_level_anchors) + bbox_weights_list = batch_images_to_levels(batch_bbox_weights, + num_level_anchors) res = (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, num_total_pos, num_total_neg) if return_sampling_results: res = res + (sampling_result, ) for i, r in enumerate(rest_results): # user-added return values - rest_results[i] = images_to_levels(r, num_level_anchors) + rest_results[i] = batch_images_to_levels(r, num_level_anchors) return res + tuple(rest_results) else: - res = (concat_labels, concat_label_weights, concat_bbox_targets, - concat_bbox_weights, concat_pos_mask, concat_neg_mask, + res = (batch_labels, batch_label_weights, batch_bbox_targets, + batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result, num_total_pos, num_total_neg, - concat_anchors) + batch_anchors) return res diff --git a/mmdet/models/dense_heads/ascend_ssd_head.py b/mmdet/models/dense_heads/ascend_ssd_head.py index 66d7f4dd7eb..9e326b48bc1 100644 --- a/mmdet/models/dense_heads/ascend_ssd_head.py +++ b/mmdet/models/dense_heads/ascend_ssd_head.py @@ -177,29 +177,29 @@ def get_targets(self, return_level, ) - def concat_loss(self, concat_cls_score, concat_bbox_pred, concat_anchor, - concat_labels, concat_label_weights, concat_bbox_targets, - concat_bbox_weights, concat_pos_mask, concat_neg_mask, - num_total_samples): + def batch_loss(self, batch_cls_score, batch_bbox_pred, batch_anchor, + batch_labels, batch_label_weights, batch_bbox_targets, + batch_bbox_weights, batch_pos_mask, batch_neg_mask, + num_total_samples): """Compute loss of all images. 
Args: - concat_cls_score (Tensor): Box scores for all image + batch_cls_score (Tensor): Box scores for all image Has shape (num_imgs, num_total_anchors, num_classes). - concat_bbox_pred (Tensor): Box energies / deltas for all image + batch_bbox_pred (Tensor): Box energies / deltas for all image level with shape (num_imgs, num_total_anchors, 4). - concat_anchor (Tensor): Box reference for all image with shape + batch_anchor (Tensor): Box reference for all image with shape (num_imgs, num_total_anchors, 4). - concat_labels (Tensor): Labels of all anchors with shape + batch_labels (Tensor): Labels of all anchors with shape (num_imgs, num_total_anchors,). - concat_label_weights (Tensor): Label weights of all anchor with + batch_label_weights (Tensor): Label weights of all anchor with shape (num_imgs, num_total_anchors,) - concat_bbox_targets (Tensor): BBox regression targets of all anchor + batch_bbox_targets (Tensor): BBox regression targets of all anchor weight shape (num_imgs, num_total_anchors, 4). - concat_bbox_weights (Tensor): BBox regression loss weights of + batch_bbox_weights (Tensor): BBox regression loss weights of all anchor with shape (num_imgs, num_total_anchors, 4). - concat_pos_mask (Tensor): Positive samples mask in all images. - concat_neg_mask (Tensor): negative samples mask in all images. + batch_pos_mask (Tensor): Positive samples mask in all images. + batch_neg_mask (Tensor): negative samples mask in all images. num_total_samples (int): If sampling, num total samples equal to the number of total anchors; Otherwise, it is the number of positive anchors. @@ -207,46 +207,46 @@ def concat_loss(self, concat_cls_score, concat_bbox_pred, concat_anchor, Returns: dict[str, Tensor]: A dictionary of loss components. 
""" - num_images, num_anchors, _ = concat_anchor.size() + num_images, num_anchors, _ = batch_anchor.size() - concat_loss_cls_all = F.cross_entropy( - concat_cls_score.view((-1, self.cls_out_channels)), - concat_labels.view(-1), + batch_loss_cls_all = F.cross_entropy( + batch_cls_score.view((-1, self.cls_out_channels)), + batch_labels.view(-1), reduction='none').view( - concat_label_weights.size()) * concat_label_weights + batch_label_weights.size()) * batch_label_weights # # FG cat_id: [0, num_classes -1], BG cat_id: num_classes - concat_num_pos_samples = torch.sum(concat_pos_mask, dim=1) - concat_num_neg_samples = \ - self.train_cfg.neg_pos_ratio * concat_num_pos_samples + batch_num_pos_samples = torch.sum(batch_pos_mask, dim=1) + batch_num_neg_samples = \ + self.train_cfg.neg_pos_ratio * batch_num_pos_samples - concat_num_neg_samples_max = torch.sum(concat_neg_mask, dim=1) - concat_num_neg_samples = torch.min(concat_num_neg_samples, - concat_num_neg_samples_max) + batch_num_neg_samples_max = torch.sum(batch_neg_mask, dim=1) + batch_num_neg_samples = torch.min(batch_num_neg_samples, + batch_num_neg_samples_max) - concat_topk_loss_cls_neg, _ = torch.topk( - concat_loss_cls_all * concat_neg_mask, k=num_anchors, dim=1) - concat_loss_cls_pos = torch.sum( - concat_loss_cls_all * concat_pos_mask, dim=1) + batch_topk_loss_cls_neg, _ = torch.topk( + batch_loss_cls_all * batch_neg_mask, k=num_anchors, dim=1) + batch_loss_cls_pos = torch.sum( + batch_loss_cls_all * batch_pos_mask, dim=1) anchor_index = torch.arange( end=num_anchors, dtype=torch.float, - device=concat_anchor.device).view((1, -1)) - topk_loss_neg_mask = (anchor_index < concat_num_neg_samples.view( + device=batch_anchor.device).view((1, -1)) + topk_loss_neg_mask = (anchor_index < batch_num_neg_samples.view( -1, 1)).float() - concat_loss_cls_neg = torch.sum( - concat_topk_loss_cls_neg * topk_loss_neg_mask, dim=1) + batch_loss_cls_neg = torch.sum( + batch_topk_loss_cls_neg * topk_loss_neg_mask, dim=1) loss_cls = 
\ - (concat_loss_cls_pos + concat_loss_cls_neg) / num_total_samples + (batch_loss_cls_pos + batch_loss_cls_neg) / num_total_samples if self.reg_decoded_bbox: # TODO: support self.reg_decoded_bbox is True raise RuntimeError loss_bbox_all = smooth_l1_loss( - concat_bbox_pred, - concat_bbox_targets, - concat_bbox_weights, + batch_bbox_pred, + batch_bbox_targets, + batch_bbox_weights, reduction='none', beta=self.train_cfg.smoothl1_beta, avg_factor=num_total_samples) @@ -303,27 +303,26 @@ def loss(self, if cls_reg_targets is None: return None - (concat_labels, concat_label_weights, concat_bbox_targets, - concat_bbox_weights, concat_pos_mask, concat_neg_mask, - sampling_result, num_total_pos, num_total_neg, - concat_anchors) = cls_reg_targets + (batch_labels, batch_label_weights, batch_bbox_targets, + batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result, + num_total_pos, num_total_neg, batch_anchors) = cls_reg_targets num_imgs = len(img_metas) - concat_cls_score = torch.cat([ + batch_cls_score = torch.cat([ s.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) for s in cls_scores ], 1) - concat_bbox_pred = torch.cat([ + batch_bbox_pred = torch.cat([ b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for b in bbox_preds ], -2) - concat_losses_cls, concat_losses_bbox = self.concat_loss( - concat_cls_score, concat_bbox_pred, concat_anchors, concat_labels, - concat_label_weights, concat_bbox_targets, concat_bbox_weights, - concat_pos_mask, concat_neg_mask, num_total_pos) + batch_losses_cls, batch_losses_bbox = self.batch_loss( + batch_cls_score, batch_bbox_pred, batch_anchors, batch_labels, + batch_label_weights, batch_bbox_targets, batch_bbox_weights, + batch_pos_mask, batch_neg_mask, num_total_pos) losses_cls = [ - concat_losses_cls[:, index_imgs] for index_imgs in range(num_imgs) + batch_losses_cls[:, index_imgs] for index_imgs in range(num_imgs) ] - losses_bbox = [losses_bbox for losses_bbox in concat_losses_bbox] + losses_bbox = [losses_bbox for 
losses_bbox in batch_losses_bbox] return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index c20f1175574..f9231e70dfd 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .ascend_util import generate_max_gt_nums, images_to_levels, set_index +from .ascend_util import batch_images_to_levels, get_max_num_gt, set_index from .collect_env import collect_env from .compat_config import compat_cfg from .logger import get_caller_name, get_root_logger, log_img_scale @@ -15,5 +15,5 @@ 'update_data_root', 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp', 'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM', - 'generate_max_gt_nums', 'set_index', 'images_to_levels' + 'get_max_num_gt', 'set_index', 'batch_images_to_levels' ] diff --git a/mmdet/utils/ascend_util.py b/mmdet/utils/ascend_util.py index fa35425613f..bc5effddb0d 100644 --- a/mmdet/utils/ascend_util.py +++ b/mmdet/utils/ascend_util.py @@ -12,10 +12,11 @@ def set_index(ori_tensor, mask, new_value, neg=False): return ori_tensor * (1 - mask) + new_value * mask -def images_to_levels(target, num_levels): +def batch_images_to_levels(target, num_levels): """Convert targets by image to targets by feature level. - [target_img0, target_img1] -> [target_level0, target_level1, ...] + [target_img0, target_img1] -> [target_level0, target_level1, ...] or + target_imgs -> [target_level0, target_level1, ...] 
""" if not isinstance(target, torch.Tensor): target = torch.stack(target, 0) @@ -29,11 +30,11 @@ def images_to_levels(target, num_levels): return level_targets -def generate_max_gt_nums(gt_nums, minimum_gt_nums=32, maximum_gt_nums=1024): +def get_max_num_gt(gt_nums, min_num_gt=32, max_num_gt=1024): max_gt_nums = max(gt_nums) - max_gt_nums_align = minimum_gt_nums + max_gt_nums_align = min_num_gt while max_gt_nums_align < max_gt_nums: max_gt_nums_align *= 2 - if max_gt_nums_align > maximum_gt_nums: + if max_gt_nums_align > max_num_gt: raise RuntimeError return max_gt_nums_align From 2f56366bdb59ec6c35495a41da4a0304ccf69f04 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Tue, 17 Jan 2023 10:35:40 +0800 Subject: [PATCH 06/13] resolve view comments --- mmdet/core/bbox/assigners/ascend_max_iou_assigner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py index 2d4190c6ccd..d4239fb4be3 100644 --- a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py @@ -115,10 +115,10 @@ def assign(self, batch_overlaps, batch_gt_labels, batch_num_gts) return batch_assign_result - def concat_assign_wrt_overlaps(self, - batch_overlaps, - batch_gt_labels=None, - batch_num_gts=None): + def batch_assign_wrt_overlaps(self, + batch_overlaps, + batch_gt_labels=None, + batch_num_gts=None): num_images, num_gts, num_bboxes = batch_overlaps.size() batch_max_overlaps, batch_argmax_overlaps = batch_overlaps.max(dim=1) if isinstance(self.neg_iou_thr, float): From 49b2657d9ef69b071782fc28bfbc76632cf26af0 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Tue, 17 Jan 2023 19:16:29 +0800 Subject: [PATCH 07/13] resolve view comments --- .../bbox/assigners/ascend_max_iou_assigner.py | 16 ++++++---- .../models/dense_heads/ascend_anchor_head.py | 32 +++++++++---------- 
mmdet/utils/__init__.py | 4 +-- mmdet/utils/ascend_util.py | 27 +++++++++++++++- 4 files changed, 53 insertions(+), 26 deletions(-) diff --git a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py index d4239fb4be3..f8f528aead6 100644 --- a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from ....utils import set_index +from ....utils import masked_fill from ..builder import BBOX_ASSIGNERS from ..iou_calculators import build_iou_calculator from .ascend_assign_result import AscendAssignResult @@ -89,7 +89,7 @@ def assign(self, :obj:`AssignResult`: The assign result. """ batch_overlaps = self.iou_calculator(batch_gt_bboxes, batch_bboxes) - batch_overlaps = set_index( + batch_overlaps = masked_fill( batch_overlaps, batch_bboxes_ignore_mask.unsqueeze(1).float(), -1, @@ -98,19 +98,21 @@ def assign(self, if self.ignore_wrt_candidates: batch_ignore_overlaps = self.iou_calculator( batch_bboxes, batch_gt_bboxes_ignore, mode='iof') - batch_ignore_overlaps = set_index(batch_ignore_overlaps, - batch_bboxes_ignore_mask, -1) + batch_ignore_overlaps = masked_fill(batch_ignore_overlaps, + batch_bboxes_ignore_mask, + -1) batch_ignore_max_overlaps, _ = batch_ignore_overlaps.max(dim=2) else: batch_ignore_overlaps = self.iou_calculator( batch_gt_bboxes_ignore, batch_bboxes, mode='iof') - batch_ignore_overlaps = set_index(batch_ignore_overlaps, - batch_bboxes_ignore_mask, -1) + batch_ignore_overlaps = masked_fill(batch_ignore_overlaps, + batch_bboxes_ignore_mask, + -1) batch_ignore_max_overlaps, _ = \ batch_ignore_overlaps.max(dim=1) batch_ignore_mask = \ batch_ignore_max_overlaps > self.ignore_iof_thr - batch_overlaps = set_index(batch_overlaps, batch_ignore_mask, -1) + batch_overlaps = masked_fill(batch_overlaps, batch_ignore_mask, -1) batch_assign_result = self.batch_assign_wrt_overlaps( 
batch_overlaps, batch_gt_labels, batch_num_gts) return batch_assign_result diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index d2ae41e8766..e6b5f6ae022 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -3,7 +3,7 @@ from ...core.bbox.assigners import AscendMaxIoUAssigner from ...core.bbox.samplers import PseudoSampler -from ...utils import batch_images_to_levels, get_max_num_gt, set_index +from ...utils import batch_images_to_levels, get_max_num_gt, masked_fill from ..builder import HEADS from .anchor_head import AnchorHead @@ -225,25 +225,25 @@ def _get_targets_concat(self, else: batch_pos_bbox_targets = batch_anchor_gt_bboxes - batch_bbox_targets = set_index(batch_bbox_targets, - batch_pos_mask.unsqueeze(2), - batch_pos_bbox_targets) - batch_bbox_weights = set_index(batch_bbox_weights, - batch_pos_mask.unsqueeze(2), 1.0) + batch_bbox_targets = masked_fill(batch_bbox_targets, + batch_pos_mask.unsqueeze(2), + batch_pos_bbox_targets) + batch_bbox_weights = masked_fill(batch_bbox_weights, + batch_pos_mask.unsqueeze(2), 1.0) if batch_gt_labels is None: - batch_labels = set_index(batch_labels, batch_pos_mask, 0.0) + batch_labels = masked_fill(batch_labels, batch_pos_mask, 0.0) else: - batch_labels = set_index(batch_labels, batch_pos_mask, - batch_anchor_gt_labels) + batch_labels = masked_fill(batch_labels, batch_pos_mask, + batch_anchor_gt_labels) if self.train_cfg.pos_weight <= 0: - batch_label_weights = set_index(batch_label_weights, - batch_pos_mask, 1.0) + batch_label_weights = masked_fill(batch_label_weights, + batch_pos_mask, 1.0) else: - batch_label_weights = set_index(batch_label_weights, - batch_pos_mask, - self.train_cfg.pos_weight) - batch_label_weights = set_index(batch_label_weights, batch_neg_mask, - 1.0) + batch_label_weights = masked_fill(batch_label_weights, + batch_pos_mask, + self.train_cfg.pos_weight) + batch_label_weights 
= masked_fill(batch_label_weights, batch_neg_mask, + 1.0) return (batch_labels, batch_label_weights, batch_bbox_targets, batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result) diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index f9231e70dfd..c4c4e2b81c2 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .ascend_util import batch_images_to_levels, get_max_num_gt, set_index +from .ascend_util import batch_images_to_levels, get_max_num_gt, masked_fill from .collect_env import collect_env from .compat_config import compat_cfg from .logger import get_caller_name, get_root_logger, log_img_scale @@ -15,5 +15,5 @@ 'update_data_root', 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp', 'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM', - 'get_max_num_gt', 'set_index', 'batch_images_to_levels' + 'get_max_num_gt', 'masked_fill', 'batch_images_to_levels' ] diff --git a/mmdet/utils/ascend_util.py b/mmdet/utils/ascend_util.py index bc5effddb0d..ff50b569412 100644 --- a/mmdet/utils/ascend_util.py +++ b/mmdet/utils/ascend_util.py @@ -2,7 +2,18 @@ import torch -def set_index(ori_tensor, mask, new_value, neg=False): +def masked_fill(ori_tensor, mask, new_value, neg=False): + """The Value of ori_tensor is new_value, depending on mask. + + Args: + ori_tensor (Tensor): Input tensor. + mask (Tensor): If select new_value. + new_value(Tensor | scalar): Value selected for ori_tensor. + neg (bool): If True, select ori_tensor. If False, select new_value. + Returns: + ori_tensor: (Tensor): The Value of ori_tensor is new_value, + depending on mask. + """ if mask is None: return ori_tensor else: @@ -17,6 +28,11 @@ def batch_images_to_levels(target, num_levels): [target_img0, target_img1] -> [target_level0, target_level1, ...] or target_imgs -> [target_level0, target_level1, ...] 
+ Args: + target (Tensor | List[Tensor]): Tensor split to image levels. + num_levels (List[int]): Image levels num. + Returns: + level_targets: (Tensor): Tensor split by image levels. """ if not isinstance(target, torch.Tensor): target = torch.stack(target, 0) @@ -31,6 +47,15 @@ def batch_images_to_levels(target, num_levels): def get_max_num_gt(gt_nums, min_num_gt=32, max_num_gt=1024): + """Count max num of gt. + + Args: + gt_nums (List[int]): Ground truth bboxes num of images. + min_num_gt (int): Min num of ground truth bboxes. + max_num_gt (int): Max num of ground truth bboxes. + Returns: + max_gt_nums_align: (int): max num of ground truth bboxes. + """ max_gt_nums = max(gt_nums) max_gt_nums_align = min_num_gt while max_gt_nums_align < max_gt_nums: From bc00aed7321565a5a5037d560195c547e98c0804 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Tue, 17 Jan 2023 20:33:27 +0800 Subject: [PATCH 08/13] resolve view comments --- mmdet/core/bbox/assigners/ascend_assign_result.py | 2 +- mmdet/models/dense_heads/ascend_anchor_head.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mmdet/core/bbox/assigners/ascend_assign_result.py b/mmdet/core/bbox/assigners/ascend_assign_result.py index c481b67d9a3..03d33c2b59a 100644 --- a/mmdet/core/bbox/assigners/ascend_assign_result.py +++ b/mmdet/core/bbox/assigners/ascend_assign_result.py @@ -11,7 +11,7 @@ class AscendAssignResult(util_mixins.NiceRepr): batch_neg_mask (IntTensor): Negative samples mask in all images. batch_max_overlaps (FloatTensor): The max overlaps of all bboxes and ground truth boxes. - batch_anchor_gt_indes(None | LongTensor): The the assigned truth + batch_anchor_gt_indes(None | LongTensor): The assigned truth box index of all anchors. 
batch_anchor_gt_labels(None | LongTensor): The gt labels of all anchors diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index e6b5f6ae022..7c379899300 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -81,8 +81,10 @@ def get_batch_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device, Returns: batch_gt_bboxes: (Tensor): Ground truth bboxes of all image. """ + # a static ground truth boxes if not hasattr(self, 'batch_gt_bboxes'): self.batch_gt_bboxes = {} + # a min anchor filled the excess anchor if not hasattr(self, 'min_anchor'): self.min_anchor = (-1354, -1344) if gt_bboxes_list is None: From d6f8776ec4ac00064b6d742c20b7a8e70094ba86 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Wed, 18 Jan 2023 09:59:02 +0800 Subject: [PATCH 09/13] resolve view comments --- tests/test_utils/test_assigner.py | 43 +++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py index a53d5304b0a..43adcb35729 100644 --- a/tests/test_utils/test_assigner.py +++ b/tests/test_utils/test_assigner.py @@ -9,6 +9,7 @@ import torch from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner, + AscendMaxIoUAssigner, CenterRegionAssigner, HungarianAssigner, MaskHungarianAssigner, MaxIoUAssigner, PointAssigner, SimOTAAssigner, @@ -661,3 +662,45 @@ def test_mask_hungarian_match_assigner(): dice_cost=dict(type='DiceCost', weight=0.0, pred_act=True, eps=1.0)) with pytest.raises(AssertionError): self = MaskHungarianAssigner(**assigner_cfg) + + +def test_ascend_max_iou_assigner(): + self = AscendMaxIoUAssigner( + pos_iou_thr=0.5, + neg_iou_thr=0.5, + ) + batch_bboxes = torch.FloatTensor([ + [ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ] + ]) + batch_gt_bboxes = torch.FloatTensor([ + [ + [0, 0, 10, 9], + [0, 10, 10, 19], + ] + ]) + 
batch_gt_labels = torch.LongTensor([[2, 3]]) + batch_bboxes_ignore_mask = torch.IntTensor( + [ + [1, 1, 1, 1] + ] + ) + assign_result = self.assign( + batch_bboxes, batch_gt_bboxes, + batch_gt_labels=batch_gt_labels, + batch_bboxes_ignore_mask=batch_bboxes_ignore_mask) + expected_batch_pos_mask = torch.IntTensor([1, 0, 1, 0]) + expected_batch_anchor_gt_indes = torch.IntTensor([0, 0, 1, 0]) + expected_batch_anchor_gt_labels = torch.IntTensor([2, 0, 3, 0]) + + assert torch.all(assign_result.batch_pos_mask == expected_batch_pos_mask) + assert torch.all( + assign_result.batch_anchor_gt_indes * assign_result.batch_pos_mask + == expected_batch_anchor_gt_indes) + assert torch.all( + assign_result.batch_anchor_gt_labels * assign_result.batch_pos_mask == + expected_batch_anchor_gt_labels) From 030ecfba8813274985d0b0ac0155fa0e67db38e4 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Wed, 18 Jan 2023 11:36:23 +0800 Subject: [PATCH 10/13] resolve view comments --- mmdet/models/dense_heads/ascend_anchor_head.py | 8 +++++--- mmdet/utils/__init__.py | 5 +++-- mmdet/utils/ascend_util.py | 8 ++++++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index 7c379899300..fc4caba07b6 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -3,7 +3,8 @@ from ...core.bbox.assigners import AscendMaxIoUAssigner from ...core.bbox.samplers import PseudoSampler -from ...utils import batch_images_to_levels, get_max_num_gt, masked_fill +from ...utils import batch_images_to_levels, \ + get_max_num_gt_division_factor, masked_fill from ..builder import HEADS from .anchor_head import AnchorHead @@ -81,7 +82,8 @@ def get_batch_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device, Returns: batch_gt_bboxes: (Tensor): Ground truth bboxes of all image. 
""" - # a static ground truth boxes + # a static ground truth boxes. + # Save static gt. Related to Ascend. Helps improve performance if not hasattr(self, 'batch_gt_bboxes'): self.batch_gt_bboxes = {} # a min anchor filled the excess anchor @@ -329,7 +331,7 @@ def get_targets(self, ], 0) gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list] - max_gt_nums = get_max_num_gt(gt_nums) + max_gt_nums = get_max_num_gt_division_factor(gt_nums) batch_gt_bboxes = self.get_batch_gt_bboxes(gt_bboxes_list, num_imgs, gt_nums, device, max_gt_nums) diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index c4c4e2b81c2..6db1160e939 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .ascend_util import batch_images_to_levels, get_max_num_gt, masked_fill +from .ascend_util import batch_images_to_levels, \ + get_max_num_gt_division_factor, masked_fill from .collect_env import collect_env from .compat_config import compat_cfg from .logger import get_caller_name, get_root_logger, log_img_scale @@ -15,5 +16,5 @@ 'update_data_root', 'setup_multi_processes', 'get_caller_name', 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp', 'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM', - 'get_max_num_gt', 'masked_fill', 'batch_images_to_levels' + 'get_max_num_gt_division_factor', 'masked_fill', 'batch_images_to_levels' ] diff --git a/mmdet/utils/ascend_util.py b/mmdet/utils/ascend_util.py index ff50b569412..df90dec8205 100644 --- a/mmdet/utils/ascend_util.py +++ b/mmdet/utils/ascend_util.py @@ -46,20 +46,24 @@ def batch_images_to_levels(target, num_levels): return level_targets -def get_max_num_gt(gt_nums, min_num_gt=32, max_num_gt=1024): +def get_max_num_gt_division_factor(gt_nums, + min_num_gt=32, + max_num_gt=1024, + division_factor=2): """Count max num of gt. Args: gt_nums (List[int]): Ground truth bboxes num of images. min_num_gt (int): Min num of ground truth bboxes. 
max_num_gt (int): Max num of ground truth bboxes. + division_factor (int): Division factor of result. Returns: max_gt_nums_align: (int): max num of ground truth bboxes. """ max_gt_nums = max(gt_nums) max_gt_nums_align = min_num_gt while max_gt_nums_align < max_gt_nums: - max_gt_nums_align *= 2 + max_gt_nums_align *= division_factor if max_gt_nums_align > max_num_gt: raise RuntimeError return max_gt_nums_align From 3dc7c610b8124e5326a024078f61e8925918a0c3 Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Wed, 18 Jan 2023 11:37:19 +0800 Subject: [PATCH 11/13] resolve view comments --- mmdet/models/dense_heads/ascend_anchor_head.py | 4 ++-- mmdet/utils/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py index fc4caba07b6..7d100ba9218 100644 --- a/mmdet/models/dense_heads/ascend_anchor_head.py +++ b/mmdet/models/dense_heads/ascend_anchor_head.py @@ -3,8 +3,8 @@ from ...core.bbox.assigners import AscendMaxIoUAssigner from ...core.bbox.samplers import PseudoSampler -from ...utils import batch_images_to_levels, \ - get_max_num_gt_division_factor, masked_fill +from ...utils import (batch_images_to_levels, get_max_num_gt_division_factor, + masked_fill) from ..builder import HEADS from .anchor_head import AnchorHead diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py index 6db1160e939..5a384feafdf 100644 --- a/mmdet/utils/__init__.py +++ b/mmdet/utils/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .ascend_util import batch_images_to_levels, \ - get_max_num_gt_division_factor, masked_fill +from .ascend_util import (batch_images_to_levels, + get_max_num_gt_division_factor, masked_fill) from .collect_env import collect_env from .compat_config import compat_cfg from .logger import get_caller_name, get_root_logger, log_img_scale From c8c0be417b5466e309b6564b999590738c62085f Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Wed, 18 Jan 2023 14:07:47 +0800 Subject: [PATCH 12/13] resolve view comments --- tests/test_utils/test_assigner.py | 41 +++++++++++++------------------ 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py index 43adcb35729..f93600038dd 100644 --- a/tests/test_utils/test_assigner.py +++ b/tests/test_utils/test_assigner.py @@ -669,28 +669,21 @@ def test_ascend_max_iou_assigner(): pos_iou_thr=0.5, neg_iou_thr=0.5, ) - batch_bboxes = torch.FloatTensor([ - [ - [0, 0, 10, 10], - [10, 10, 20, 20], - [5, 5, 15, 15], - [32, 32, 38, 42], - ] - ]) - batch_gt_bboxes = torch.FloatTensor([ - [ - [0, 0, 10, 9], - [0, 10, 10, 19], - ] - ]) + batch_bboxes = torch.FloatTensor([[ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]]) + batch_gt_bboxes = torch.FloatTensor([[ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]]) batch_gt_labels = torch.LongTensor([[2, 3]]) - batch_bboxes_ignore_mask = torch.IntTensor( - [ - [1, 1, 1, 1] - ] - ) + batch_bboxes_ignore_mask = torch.IntTensor([[1, 1, 1, 1]]) assign_result = self.assign( - batch_bboxes, batch_gt_bboxes, + batch_bboxes, + batch_gt_bboxes, batch_gt_labels=batch_gt_labels, batch_bboxes_ignore_mask=batch_bboxes_ignore_mask) expected_batch_pos_mask = torch.IntTensor([1, 0, 1, 0]) @@ -699,8 +692,8 @@ def test_ascend_max_iou_assigner(): assert torch.all(assign_result.batch_pos_mask == expected_batch_pos_mask) assert torch.all( - assign_result.batch_anchor_gt_indes * assign_result.batch_pos_mask - 
== expected_batch_anchor_gt_indes) + assign_result.batch_anchor_gt_indes * + assign_result.batch_pos_mask == expected_batch_anchor_gt_indes) assert torch.all( - assign_result.batch_anchor_gt_labels * assign_result.batch_pos_mask == - expected_batch_anchor_gt_labels) + assign_result.batch_anchor_gt_labels * + assign_result.batch_pos_mask == expected_batch_anchor_gt_labels) From 8f927c0680f01496ecf851ec408a1054363bf81f Mon Sep 17 00:00:00 2001 From: akstt <13146229470@163.com> Date: Wed, 18 Jan 2023 14:10:45 +0800 Subject: [PATCH 13/13] resolve view comments --- tests/test_utils/test_assigner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py index f93600038dd..7cdb08ba0fb 100644 --- a/tests/test_utils/test_assigner.py +++ b/tests/test_utils/test_assigner.py @@ -686,6 +686,7 @@ def test_ascend_max_iou_assigner(): batch_gt_bboxes, batch_gt_labels=batch_gt_labels, batch_bboxes_ignore_mask=batch_bboxes_ignore_mask) + expected_batch_pos_mask = torch.IntTensor([1, 0, 1, 0]) expected_batch_anchor_gt_indes = torch.IntTensor([0, 0, 1, 0]) expected_batch_anchor_gt_labels = torch.IntTensor([2, 0, 3, 0])