Fix top_k for multiclass-f1score (#2839)
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
4 people authored Dec 21, 2024
1 parent 8827e64 commit 4d9c843
Showing 6 changed files with 258 additions and 25 deletions.
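
For context before the file-by-file diff: the tests added below assert that, after this fix, the macro F1 score is non-decreasing in top_k and reaches 1.0 once top_k equals the number of classes. A minimal sketch of that user-facing behaviour, mirroring the added test_multiclassf1score_with_top_k test (the data here is random and purely illustrative):

import torch
from torchmetrics.classification import MulticlassF1Score

num_classes = 5
preds = torch.randn(200, num_classes).softmax(dim=-1)  # random class probabilities
target = torch.randint(num_classes, (200,))

scores = [
    MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro")(preds, target)
    for k in range(1, num_classes + 1)
]
# With the fix, the score never drops as top_k grows ...
assert all(a <= b for a, b in zip(scores, scores[1:]))
# ... and every sample counts as correct once top_k == num_classes.
assert torch.isclose(scores[-1], torch.tensor(1.0))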
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))


- Fixed `top_k` for `multiclassf1score` with one-hot encoding ([#2839](https://github.com/Lightning-AI/torchmetrics/issues/2839))


---

## [1.6.0] - 2024-11-12
27 changes: 27 additions & 0 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -340,6 +340,30 @@ def _multiclass_stat_scores_format(
return preds, target


def _refine_preds_oh(preds: Tensor, preds_oh: Tensor, target: Tensor, top_k: int) -> Tensor:
    """Refine prediction one-hot encodings by replacing entries with the target one-hot when there is an intersection.

    When no intersection is found between the top-k predictions and the target, the top-1 prediction is used.

    Args:
        preds: Original prediction tensor with probabilities/logits
        preds_oh: Current one-hot encoded predictions from top-k selection
        target: Target tensor with class indices
        top_k: Number of top predictions to consider

    Returns:
        Refined one-hot encoded predictions tensor

    """
    preds = preds.squeeze()
    target = target.squeeze()
    top_k_indices = torch.topk(preds, k=top_k, dim=1).indices
    top_1_indices = top_k_indices[:, 0]
    target_in_topk = torch.any(top_k_indices == target.unsqueeze(1), dim=1)
    result = torch.where(target_in_topk, target, top_1_indices)
    return torch.zeros_like(preds_oh, dtype=torch.int32).scatter_(-1, result.unsqueeze(1).unsqueeze(1), 1)
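
To make the refinement concrete, here is a small standalone sketch of the same idea on a toy batch, using only public torch ops (the helper above additionally scatters the result back into the one-hot tensor; the tensor names here are purely illustrative):

import torch

preds = torch.tensor([
    [0.1, 0.2, 0.7],
    [0.4, 0.4, 0.2],
    [0.3, 0.3, 0.4],
])
target = torch.tensor([0, 1, 2])
top_k = 2

topk_idx = torch.topk(preds, k=top_k, dim=1).indices    # (N, k) candidate classes per sample
top1_idx = topk_idx[:, 0]                               # fallback: the single best guess
in_topk = (topk_idx == target.unsqueeze(1)).any(dim=1)  # is the target among the top-k?
refined = torch.where(in_topk, target, top1_idx)
# refined == tensor([2, 1, 2]): rows 1 and 2 are credited with their target class,
# while row 0 falls back to its top-1 class because class 0 is outside its top-2.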


def _multiclass_stat_scores_update(
preds: Tensor,
target: Tensor,
@@ -371,13 +395,16 @@ def _multiclass_stat_scores_update(

    if top_k > 1:
        preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
        preds_oh = _refine_preds_oh(preds, preds_oh, target, top_k)
    else:
        preds_oh = torch.nn.functional.one_hot(
            preds.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
        )

    target_oh = torch.nn.functional.one_hot(
        target.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
    )

    if ignore_index is not None:
        if 0 <= ignore_index <= num_classes - 1:
            target_oh[target == ignore_index, :] = -1
107 changes: 102 additions & 5 deletions tests/unittests/classification/test_f_beta.py
@@ -377,7 +377,19 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_mc_k_preds = torch.tensor([
[0.35, 0.4, 0.25],
[0.1, 0.5, 0.4],
[0.2, 0.1, 0.7],
])

_mc_k_target2 = torch.tensor([0, 1, 2, 0])
_mc_k_preds2 = torch.tensor([
[0.1, 0.2, 0.7],
[0.4, 0.4, 0.2],
[0.3, 0.3, 0.4],
[0.3, 0.3, 0.4],
])


@pytest.mark.parametrize(
@@ -391,7 +403,33 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa
("k", "preds", "target", "average", "expected_fbeta", "expected_f1"),
[
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.25), torch.tensor(0.25)),
(2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.75), torch.tensor(0.75)),
(3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.2381), torch.tensor(0.1667)),
(2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.7963), torch.tensor(0.7778)),
(3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.1786), torch.tensor(0.1250)),
(2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.7361), torch.tensor(0.7500)),
(3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0), torch.tensor(1.0)),
(
1,
_mc_k_preds2,
_mc_k_target2,
"none",
torch.tensor([0.0000, 0.0000, 0.7143]),
torch.tensor([0.0000, 0.0000, 0.5000]),
),
(
2,
_mc_k_preds2,
_mc_k_target2,
"none",
torch.tensor([0.5556, 1.0000, 0.8333]),
torch.tensor([0.6667, 1.0000, 0.6667]),
),
(3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])),
],
)
def test_top_k(
@@ -404,14 +442,73 @@ def test_top_k(
expected_fbeta: Tensor,
expected_f1: Tensor,
):
"""A simple test to check that top_k works as expected."""
"""A comprehensive test to check that top_k works as expected."""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

result = expected_fbeta if class_metric.beta != 1.0 else expected_f1

assert torch.isclose(class_metric.compute(), result)
assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4)
assert torch.allclose(
metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4
)


@pytest.mark.parametrize("num_classes", [5])
def test_multiclassf1score_with_top_k(num_classes):
"""Test that F1 score increases monotonically with top_k and equals 1 when top_k equals num_classes.
Args:
num_classes: Number of classes in the classification task.
The test verifies two properties:
1. F1 score increases or stays the same as top_k increases
2. F1 score equals 1 when top_k equals num_classes
"""
preds = torch.randn(200, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (200,))

previous_score = 0.0
for k in range(1, num_classes + 1):
f1_score = MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro")
score = f1_score(preds, target)

assert score >= previous_score, f"F1 score did not increase for top_k={k}"
previous_score = score

if k == num_classes:
assert torch.isclose(
score, torch.tensor(1.0)
), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}"


def test_multiclass_f1_score_top_k_equivalence():
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653.
Test that top-k F1 score is equivalent to corrected top-1 F1 score.
"""
num_classes = 5

preds = torch.randn(200, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (200,))

f1_val_top3 = MulticlassF1Score(num_classes=num_classes, top_k=3, average="macro")
f1_val_top1 = MulticlassF1Score(num_classes=num_classes, top_k=1, average="macro")

pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
pred_top_1 = pred_top_3[:, 0]

target_in_top3 = (target.unsqueeze(1) == pred_top_3).any(dim=1)

pred_corrected_top3 = torch.where(target_in_top3, target, pred_top_1)

score_top3 = f1_val_top3(preds, target)
score_corrected = f1_val_top1(pred_corrected_top3, target)

assert torch.isclose(
score_top3, score_corrected
), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"


def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):
42 changes: 33 additions & 9 deletions tests/unittests/classification/test_precision_recall.py
@@ -385,8 +385,16 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

_mc_k_targets2 = torch.tensor([0, 0, 2])
_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])
_mc_k_targets2 = tensor([0, 0, 2])
_mc_k_preds2 = tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])

_mc_k_target3 = tensor([0, 1, 2, 0])
_mc_k_preds3 = tensor([
[0.1, 0.2, 0.7],
[0.4, 0.4, 0.2],
[0.3, 0.3, 0.4],
[0.3, 0.3, 0.4],
])


@pytest.mark.parametrize(
@@ -395,10 +403,24 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
@pytest.mark.parametrize(
("k", "preds", "target", "average", "expected_prec", "expected_recall"),
[
(1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)),
(1, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
(2, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(3, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)),
(2, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)),
(3, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.1111), torch.tensor(0.3333)),
(2, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.8333), torch.tensor(0.8333)),
(3, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.2500), torch.tensor(0.2500)),
(2, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.7500), torch.tensor(0.7500)),
(3, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.0833), torch.tensor(0.2500)),
(2, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.8750), torch.tensor(0.7500)),
(3, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([0.0000, 0.0000, 0.3333]), torch.tensor([0.0, 0.0, 1.0])),
(2, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0000, 1.0000, 0.5000]), torch.tensor([0.5, 1.0, 1.0])),
(3, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])),
],
)
def test_top_k(
@@ -411,14 +433,16 @@ def test_top_k(
expected_prec: Tensor,
expected_recall: Tensor,
):
"""A simple test to check that top_k works as expected."""
"""A test to validate top_k functionality for precision and recall."""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

result = expected_prec if metric_class.__name__ == "MulticlassPrecision" else expected_recall

assert torch.equal(class_metric.compute(), result)
assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4)
assert torch.allclose(
metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4
)


def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):
39 changes: 32 additions & 7 deletions tests/unittests/classification/test_specificity.py
@@ -18,7 +18,7 @@
import torch
from scipy.special import expit as sigmoid
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from torch import Tensor, tensor
from torch import Tensor
from torchmetrics.classification.specificity import (
BinarySpecificity,
MulticlassSpecificity,
@@ -355,24 +355,49 @@ def test_multiclass_specificity_dtype_gpu(self, inputs, dtype):
)


_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

_mc_k_target2 = torch.tensor([0, 1, 2, 0])
_mc_k_preds2 = torch.tensor([
[0.1, 0.2, 0.7],
[0.4, 0.4, 0.2],
[0.3, 0.3, 0.4],
[0.3, 0.3, 0.4],
])


@pytest.mark.parametrize(
("k", "preds", "target", "average", "expected_spec"),
[
(1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)),
(2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)),
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.6111)),
(2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.8889)),
(3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.6250)),
(2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.8750)),
(3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.5833)),
(2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.9167)),
(3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([0.5000, 1.0000, 0.3333])),
(2, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0000, 1.0000, 0.6667])),
(3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0])),
],
)
def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spec: Tensor):
"""A simple test to check that top_k works as expected."""
class_metric = MulticlassSpecificity(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

assert torch.equal(class_metric.compute(), expected_spec)
assert torch.equal(multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3), expected_spec)
assert torch.allclose(class_metric.compute(), expected_spec, atol=1e-4, rtol=1e-4)
assert torch.allclose(
multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3),
expected_spec,
atol=1e-4,
rtol=1e-4,
)


def _reference_specificity_multilabel_global(preds, target, ignore_index, average):
