Skip to content

Commit

Permalink
NMS Op modification (#2568)
Browse files Browse the repository at this point in the history
Signed-off-by: Alankar Mahajan <quic_alanmaha@quicinc.com>
  • Loading branch information
quic-alanmaha authored Nov 15, 2023
1 parent ae15bf0 commit 2f2e4d7
Showing 1 changed file with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,12 @@ def forward(self, *args) -> torch.Tensor:
res_per_class.append([index, class_index, val.detach()])
res_per_class = res_per_class[:self.max_output_boxes_per_class]
res.extend(res_per_class)
return torch.Tensor(res).type(torch.int64)

res = torch.Tensor(res).type(torch.int64)
out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, dtype=torch.int64)
indices = torch.arange(0, len(res) * len(res[0]), dtype=torch.int64)
out.put_(indices, res)
return out

def perform_nms_per_class(self, boxes: torch.Tensor, classes_score: torch.Tensor) -> torch.Tensor:
"""
Expand Down

0 comments on commit 2f2e4d7

Please sign in to comment.