diff --git a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py index 344d69ac..7d70113a 100644 --- a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py +++ b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py @@ -47,11 +47,12 @@ def forward(self, outputs, orig_target_sizes): else: scores = F.softmax(logits)[:, :, :-1] scores, labels = scores.max(dim=-1) + boxes = bbox_pred if scores.shape[1] > self.num_top_queries: scores, index = torch.topk(scores, self.num_top_queries, dim=-1) labels = torch.gather(labels, dim=1, index=index) boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) - + # TODO for onnx export if self.deploy_mode: return labels, boxes, scores