diff --git a/CHANGELOG.md b/CHANGELOG.md index 684d02fc8da..442b10026b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed warnings being suppressed in `MeanAveragePrecision` when requested ([#2501](https://github.com/Lightning-AI/torchmetrics/pull/2501)) +- Fixed cornercase in `binary_average_precision` when only negative samples are provided ([#2507](https://github.com/Lightning-AI/torchmetrics/pull/2507)) + + ## [1.3.2] - 2024-03-18 ### Fixed diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 64958267737..da12db561a1 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -23,6 +23,7 @@ from torchmetrics.utilities.compute import _safe_divide, interp from torchmetrics.utilities.data import _bincount, _cumsum from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.prints import rank_zero_warn def _binary_clf_curve( @@ -274,6 +275,12 @@ def _binary_precision_recall_curve_compute( fps, tps, thresholds = _binary_clf_curve(state[0], state[1], pos_label=pos_label) precision = tps / (tps + fps) recall = tps / tps[-1] + if (state[1] == 0).all(): # all labels are negative, recall is undefined + rank_zero_warn( + "No positive samples found in target, recall is undefined. Setting recall to one for all thresholds.", + UserWarning, + ) + recall = torch.ones_like(recall) # need to call reversed explicitly, since including that to slice would # introduce negative strides that are not yet supported in pytorch diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index 2ff0274c93c..51d40839642 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -143,6 +143,14 @@ def test_binary_average_precision_threshold_arg(self, inputs, threshold_fn): assert torch.allclose(ap1, ap2) +def test_warning_on_no_positives(): + """Test that a warning is raised when there are no positive samples in the target.""" + preds = torch.rand(100) + target = torch.zeros(100).long() + with pytest.warns(UserWarning, match="No positive samples found in target, recall is undefined. Setting recall.*"): + binary_average_precision(preds, target) + + def _reference_sklearn_avg_precision_multiclass(preds, target, average="macro", ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten()