From 2f2e4d7bea8d4921852906a8dd4e4a831c14fe41 Mon Sep 17 00:00:00 2001 From: Alankar Mahajan Date: Wed, 15 Nov 2023 11:18:02 +0530 Subject: [PATCH] NMS Op modification (#2568) Signed-off-by: Alankar Mahajan --- .../torch/src/python/aimet_torch/elementwise_ops.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py index c70adce82b5..a9dde58a972 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py @@ -367,7 +367,12 @@ def forward(self, *args) -> torch.Tensor: res_per_class.append([index, class_index, val.detach()]) res_per_class = res_per_class[:self.max_output_boxes_per_class] res.extend(res_per_class) - return torch.Tensor(res).type(torch.int64) + + res = torch.Tensor(res).type(torch.int64) + out = torch.zeros(batch_scores.shape[0] * batch_scores.shape[1] * self.max_output_boxes_per_class, 3, dtype=torch.int64) + indices = torch.arange(0, len(res) * len(res[0]), dtype=torch.int64) + out.put_(indices, res) + return out def perform_nms_per_class(self, boxes: torch.Tensor, classes_score: torch.Tensor) -> torch.Tensor: """