Skip to content

Commit

Permalink
Bugfix for empty preds or target in iou scores (#2806)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
  • Loading branch information
3 people authored Oct 29, 2024
1 parent c572289 commit fbc7877
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805))


---
Expand Down
13 changes: 8 additions & 5 deletions src/torchmetrics/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,17 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
"""Update state with predictions and targets."""
_input_validator(preds, target, ignore_score=True)

for p, t in zip(preds, target):
det_boxes = self._get_safe_item_values(p["boxes"])
gt_boxes = self._get_safe_item_values(t["boxes"])
self.groundtruth_labels.append(t["labels"])
for p_i, t_i in zip(preds, target):
det_boxes = self._get_safe_item_values(p_i["boxes"])
gt_boxes = self._get_safe_item_values(t_i["boxes"])
self.groundtruth_labels.append(t_i["labels"])

iou_matrix = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val) # N x M
if self.respect_labels:
label_eq = p["labels"].unsqueeze(1) == t["labels"].unsqueeze(0) # N x M
if det_boxes.numel() > 0 and gt_boxes.numel() > 0:
label_eq = p_i["labels"].unsqueeze(1) == t_i["labels"].unsqueeze(0) # N x M
else:
label_eq = torch.eye(iou_matrix.shape[0], dtype=bool, device=iou_matrix.device) # type: ignore[call-overload]
iou_matrix[~label_eq] = self._invalid_val
self.iou_matrix.append(iou_matrix)

Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _ciou_update(

from torchvision.ops import complete_box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = complete_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _diou_update(

from torchvision.ops import distance_box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = distance_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def _giou_update(

from torchvision.ops import generalized_box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = generalized_box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
5 changes: 5 additions & 0 deletions src/torchmetrics/functional/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def _iou_update(

from torchvision.ops import box_iou

if preds.numel() == 0: # if no boxes are predicted
return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32)
if target.numel() == 0: # if no boxes are true
return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32)

iou = box_iou(preds, target)
if iou_threshold is not None:
iou[iou < iou_threshold] = replacement_val
Expand Down
37 changes: 37 additions & 0 deletions tests/unittests/detection/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,43 @@ def test_corner_case_only_one_empty_prediction(self, class_metric, functional_me
for val in res.values():
assert val == torch.tensor(0.0)

def test_empty_preds_and_target(self, class_metric, functional_metric, reference_metric):
"""Check that for either empty preds and targets that the metric returns 0 in these cases before averaging."""
x = [
{
"boxes": torch.empty(size=(0, 4), dtype=torch.float32),
"labels": torch.tensor([], dtype=torch.long),
},
{
"boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]),
"labels": torch.LongTensor([1, 2]),
},
]

y = [
{
"boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]),
"labels": torch.LongTensor([1, 2]),
"scores": torch.FloatTensor([0.9, 0.8]),
},
{
"boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]),
"labels": torch.LongTensor([1, 2]),
"scores": torch.FloatTensor([0.9, 0.8]),
},
]
metric = class_metric()
metric.update(x, y)
res = metric.compute()
for val in res.values():
assert val == torch.tensor(0.5)

metric = class_metric()
metric.update(y, x)
res = metric.compute()
for val in res.values():
assert val == torch.tensor(0.5)


def test_corner_case():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921."""
Expand Down

0 comments on commit fbc7877

Please sign in to comment.