Skip to content

Commit

Permalink
Merge pull request #319 from lyuwenyu/fix_postprocess
Browse files Browse the repository at this point in the history
[fix] bbox_pred
  • Loading branch information
lyuwenyu authored May 23, 2024
2 parents 2b88d5d + 186d207 commit 5b628ea
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def forward(self, outputs, orig_target_sizes):
else:
scores = F.softmax(logits)[:, :, :-1]
scores, labels = scores.max(dim=-1)
boxes = bbox_pred
if scores.shape[1] > self.num_top_queries:
scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
labels = torch.gather(labels, dim=1, index=index)
boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

# TODO for onnx export
if self.deploy_mode:
return labels, boxes, scores
Expand Down

0 comments on commit 5b628ea

Please sign in to comment.