Skip to content

Commit

Permalink
Ensure target class indices are of type long in loss calculations (#4143
Browse files Browse the repository at this point in the history
)

* Ensure target class indices are of type long in loss calculations

* update changelog
  • Loading branch information
eugene123tw authored Dec 4, 2024
1 parent 5d6f8d3 commit 1608d9b
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4131>)
- Fix tensor type compatibility in dynamic soft label assigner and RTMDet head
(<https://github.com/openvinotoolkit/training_extensions/pull/4140>)
- Fix DETR target class indices are of type long in loss calculations
(<https://github.com/openvinotoolkit/training_extensions/pull/4143>)

## \[v2.1.0\]

Expand Down
2 changes: 1 addition & 1 deletion src/otx/algo/detection/losses/rtdetr_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def loss_labels_vfl(
src_logits = outputs["pred_logits"]
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device)
target_classes[idx] = target_classes_o
target_classes[idx] = target_classes_o.long()
target = nn.functional.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
Expand Down
2 changes: 1 addition & 1 deletion src/otx/algo/detection/utils/matchers/hungarian_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def forward(
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]

# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_ids = torch.cat([v["labels"] for v in targets]).long()
tgt_bbox = torch.cat([v["boxes"] for v in targets])

# Compute the classification cost. Contrary to the loss, we don't use the NLL,
Expand Down

0 comments on commit 1608d9b

Please sign in to comment.