Skip to content

Commit

Permalink
Minor changes to NMS Op implementation (#2552)
Browse files Browse the repository at this point in the history
  • Loading branch information
quic-saksrai authored Nov 7, 2023
1 parent 1b483b4 commit 0e7430a
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,27 @@ def forward(self, *args) -> torch.Tensor:
res = []
for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)):
for class_index, classes_score in enumerate(scores):
filtered_score_ind = (classes_score > self.score_threshold).nonzero()[:, 0]
boxes = boxes[filtered_score_ind, :]
classes_score = classes_score[filtered_score_ind]
temp_res = torchvision.ops.nms(boxes, classes_score, self.iou_threshold)
res_ = filtered_score_ind[temp_res]
for val in res_:
res.append([index, class_index, val.detach()])
res = res[:(self.max_output_boxes_per_class *(index+1))]
nms_output = self.perform_nms_per_class(boxes, classes_score)
res_per_class = []
for val in nms_output:
res_per_class.append([index, class_index, val.detach()])
res_per_class = res_per_class[:self.max_output_boxes_per_class]
res.extend(res_per_class)
return torch.Tensor(res).type(torch.int64)

def perform_nms_per_class(self, boxes: torch.Tensor, classes_score: torch.Tensor) -> torch.Tensor:
"""
Performs NMS per class
:param boxes: boxes on which NMS should be performed
:param classes_score: corresponding class scores for the boxes
:return: returns box indices filtered out by NMS
"""
filtered_score_ind = (classes_score > self.score_threshold).nonzero()[:, 0]
filtered_boxes = boxes[filtered_score_ind]
filtered_classes_score = classes_score[filtered_score_ind]
res_ = torchvision.ops.nms(filtered_boxes, filtered_classes_score, self.iou_threshold)
return filtered_score_ind[res_]


class GatherNd(torch.nn.Module):
""" GatherNd op implementation"""
Expand Down

0 comments on commit 0e7430a

Please sign in to comment.