Fix top_k for multiclass-f1score #2839

Merged Dec 21, 2024 · 52 commits

Commits
8302300
Fix: Handle zero division error in binary IoU (Jaccard index) calcula…
rittik9 Sep 9, 2024
9098d0a
chlog
Borda Sep 9, 2024
7803302
Merge branch 'master' into fix/handle-zero-division-iou-calculation
Borda Sep 9, 2024
65b2714
Merge branch 'master' into fix/handle-zero-division-iou-calculation
mergify[bot] Sep 10, 2024
31087e3
Merge branch 'master' into fix/handle-zero-division-iou-calculation
mergify[bot] Sep 10, 2024
21948b2
fix: rouge_score with accumulate='best' gives mixed results #2148
rittik9 Nov 7, 2024
22c5b60
Merge branch 'master' into fix/rouge
rittik9 Nov 7, 2024
bb208f4
fix: test_rouge.py
rittik9 Nov 8, 2024
d6c9955
Merge branch 'fix/rouge' of https://github.com/rittik9/torchmetrics i…
rittik9 Nov 8, 2024
456c9d5
Merge branch 'master' into fix/rouge
rittik9 Nov 8, 2024
e09d751
test: test_rouge.py
rittik9 Nov 8, 2024
da834d0
fix: test_rouge.py
rittik9 Nov 8, 2024
beeb001
minor fix: test_rouge.py
rittik9 Nov 8, 2024
6eaa882
Update CHANGELOG.md
Borda Nov 8, 2024
27d4e59
fix: top_k for multiclassf1score by adding refine_preds_oh
rittik9 Nov 21, 2024
6d40bb4
fix: modify top_k tests for f_beta, precision_recall, specificity, st…
rittik9 Nov 22, 2024
a949834
add top_k equivalence test
rittik9 Nov 22, 2024
00c536e
Merge branch 'master' into rittik/multiclassf1_topk
rittik9 Nov 22, 2024
7c8b325
uncomment test_top_k_ignore_index_multiclass
rittik9 Nov 23, 2024
aa8ff49
refactor: modified refine_preds_oh
rittik9 Nov 23, 2024
910c537
Update CHANGELOG.md
Borda Nov 25, 2024
efab768
Update tests/unittests/classification/test_precision_recall.py
rittik9 Nov 26, 2024
9da37ad
Apply suggestions from code review
Borda Nov 26, 2024
4fd8850
Merge branch 'master' into rittik/multiclassf1_topk
rittik9 Nov 26, 2024
db24ff2
add unittest for refine_preds_oh
rittik9 Nov 26, 2024
b574fd0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2024
9aa6786
add docstring for test_refine_preds_oh
rittik9 Nov 26, 2024
64bd4be
add new tests
rittik9 Nov 27, 2024
18855dc
fix tests
rittik9 Nov 27, 2024
1832e07
refactor: test_stat_scores.py
rittik9 Nov 27, 2024
c8b47bf
fix test_stat_scores.py
rittik9 Nov 27, 2024
d53c0c5
fix test_specificity.py
rittik9 Nov 27, 2024
8508274
fix test
rittik9 Nov 27, 2024
ddd5fb6
fix test precision recall
rittik9 Nov 27, 2024
1886a81
fix: test_f_beta.py
rittik9 Nov 27, 2024
6457058
Merge branch 'master' into rittik/multiclassf1_topk
rittik9 Dec 2, 2024
1b2abfa
refactor: make _refine_preds_oh private
rittik9 Dec 11, 2024
d673d05
Update test_stat_scores.py
rittik9 Dec 11, 2024
11f31e9
Merge branch 'master' into rittik/multiclassf1_topk
rittik9 Dec 11, 2024
7178077
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2024
bca8d56
Apply suggestions from code review
Borda Dec 17, 2024
6644d01
Apply suggestions from code review
Borda Dec 17, 2024
db08ff8
Merge branch 'master' into rittik/multiclassf1_topk
Borda Dec 17, 2024
9530f18
Merge branch 'master' into rittik/multiclassf1_topk
rittik9 Dec 20, 2024
78b12d7
Apply suggestions from code review
Borda Dec 21, 2024
d9ff941
tensor
Borda Dec 21, 2024
412068f
Merge branch 'master' into rittik/multiclassf1_topk
mergify[bot] Dec 21, 2024
3219e81
Revert "Apply suggestions from code review"
Borda Dec 21, 2024
9be317d
Merge branch 'master' into rittik/multiclassf1_topk
mergify[bot] Dec 21, 2024
a7ee183
Merge branch 'rittik/multiclassf1_topk' of https://github.com/rittik9…
Borda Dec 21, 2024
0eae931
Merge branch 'master' into rittik/multiclassf1_topk
mergify[bot] Dec 21, 2024
0b92576
Apply suggestions from code review
Borda Dec 21, 2024
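For context before the file diffs: this PR changes how `top_k > 1` predictions are reduced to a single hard prediction in the stat-scores family of metrics (F-beta/F1, precision/recall, specificity). If the target class appears anywhere in a sample's top-k predictions, the sample now counts as correctly predicted; otherwise the top-1 prediction is kept. A minimal before/after sketch, with values taken from the updated tests below:

import torch
from torchmetrics.classification import MulticlassPrecision

target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

# Every target sits in its row's top-2, so with the refined semantics
# micro precision is 1.0; the old one-hot top-k counting returned 0.5 here.
metric = MulticlassPrecision(num_classes=3, top_k=2, average="micro")
print(metric(preds, target))  # tensor(1.)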
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))


- Fixed `top_k` for `multiclassf1score` with one-hot encoding ([#2839](https://github.com/Lightning-AI/torchmetrics/issues/2839))


---

## [1.6.0] - 2024-11-12
27 changes: 27 additions & 0 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -340,6 +340,30 @@ def _multiclass_stat_scores_format(
return preds, target


def _refine_preds_oh(preds: Tensor, preds_oh: Tensor, target: Tensor, top_k: int) -> Tensor:
"""Refines prediction one-hot encodings by replacing entries with target one-hot when there's an intersection.

When no intersection is found between the top-k predictions and target, uses the top-1 prediction.

Args:
preds: Original prediction tensor with probabilities/logits
preds_oh: Current one-hot encoded predictions from top-k selection
target: Target tensor with class indices
top_k: Number of top predictions to consider

Returns:
Refined one-hot encoded predictions tensor

"""
preds = preds.squeeze()
target = target.squeeze()
top_k_indices = torch.topk(preds, k=top_k, dim=1).indices
top_1_indices = top_k_indices[:, 0]
target_in_topk = torch.any(top_k_indices == target.unsqueeze(1), dim=1)
result = torch.where(target_in_topk, target, top_1_indices)
return torch.zeros_like(preds_oh, dtype=torch.int32).scatter_(-1, result.unsqueeze(1).unsqueeze(1), 1)
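
As a worked illustration of the rule above, here is a minimal sketch in plain tensor ops (restating the logic rather than calling the private helper, whose expected input shapes are set by `_multiclass_stat_scores_update`), using the four-sample fixture from the updated tests:

import torch

preds = torch.tensor([[0.1, 0.2, 0.7],
                      [0.4, 0.4, 0.2],
                      [0.3, 0.3, 0.4],
                      [0.3, 0.3, 0.4]])
target = torch.tensor([0, 1, 2, 0])
top_k = 2

# Indices of the top-2 classes per sample: [[2, 1], [0, 1], [2, 0], [2, 0]]
# (order within ties may vary).
top_k_indices = torch.topk(preds, k=top_k, dim=1).indices
# Keep the target wherever it appears among the top-k...
target_in_topk = (top_k_indices == target.unsqueeze(1)).any(dim=1)
# ...otherwise fall back to the top-1 prediction.
refined = torch.where(target_in_topk, target, top_k_indices[:, 0])
print(refined)  # tensor([2, 1, 2, 0]): only sample 0 misses its target in the top-2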


def _multiclass_stat_scores_update(
preds: Tensor,
target: Tensor,
@@ -371,13 +395,16 @@ def _multiclass_stat_scores_update(

if top_k > 1:
preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
preds_oh = _refine_preds_oh(preds, preds_oh, target, top_k)
else:
preds_oh = torch.nn.functional.one_hot(
preds.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
)

target_oh = torch.nn.functional.one_hot(
target.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes
)

if ignore_index is not None:
if 0 <= ignore_index <= num_classes - 1:
target_oh[target == ignore_index, :] = -1
107 changes: 102 additions & 5 deletions tests/unittests/classification/test_f_beta.py
@@ -377,7 +377,19 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_mc_k_preds = torch.tensor([
[0.35, 0.4, 0.25],
[0.1, 0.5, 0.4],
[0.2, 0.1, 0.7],
])

_mc_k_target2 = torch.tensor([0, 1, 2, 0])
_mc_k_preds2 = torch.tensor([
[0.1, 0.2, 0.7],
[0.4, 0.4, 0.2],
[0.3, 0.3, 0.4],
[0.3, 0.3, 0.4],
])


@pytest.mark.parametrize(
@@ -391,7 +403,33 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa
("k", "preds", "target", "average", "expected_fbeta", "expected_f1"),
[
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.25), torch.tensor(0.25)),
(2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.75), torch.tensor(0.75)),
(3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.2381), torch.tensor(0.1667)),
(2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.7963), torch.tensor(0.7778)),
(3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.1786), torch.tensor(0.1250)),
(2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.7361), torch.tensor(0.7500)),
(3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0), torch.tensor(1.0)),
(
1,
_mc_k_preds2,
_mc_k_target2,
"none",
torch.tensor([0.0000, 0.0000, 0.7143]),
torch.tensor([0.0000, 0.0000, 0.5000]),
),
(
2,
_mc_k_preds2,
_mc_k_target2,
"none",
torch.tensor([0.5556, 1.0000, 0.8333]),
torch.tensor([0.6667, 1.0000, 0.6667]),
),
(3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])),
],
)
def test_top_k(
@@ -404,14 +442,73 @@ def test_top_k(
expected_fbeta: Tensor,
expected_f1: Tensor,
):
"""A simple test to check that top_k works as expected."""
"""A comprehensive test to check that top_k works as expected."""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

result = expected_fbeta if class_metric.beta != 1.0 else expected_f1

assert torch.isclose(class_metric.compute(), result)
assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4)
assert torch.allclose(
metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4
)
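
Why the top-2 micro expectation above changed from 5/6 to 1.0: each row of `_mc_k_preds` contains its target within the top-2 (top-2 index sets {1, 0}, {1, 2}, {2, 0} against targets [0, 1, 2]), so the refined predictions all equal the targets. A minimal end-to-end check against the public API, as a sketch:

import torch
from torchmetrics.classification import MulticlassF1Score

target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

# All targets sit in their row's top-2, so top-2 micro F1 is exactly 1.0.
metric = MulticlassF1Score(num_classes=3, top_k=2, average="micro")
print(metric(preds, target))  # tensor(1.)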


@pytest.mark.parametrize("num_classes", [5])
def test_multiclassf1score_with_top_k(num_classes):
"""Test that F1 score increases monotonically with top_k and equals 1 when top_k equals num_classes.

Args:
num_classes: Number of classes in the classification task.

The test verifies two properties:
1. F1 score increases or stays the same as top_k increases
2. F1 score equals 1 when top_k equals num_classes

"""
preds = torch.randn(200, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (200,))

previous_score = 0.0
for k in range(1, num_classes + 1):
f1_score = MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro")
score = f1_score(preds, target)

assert score >= previous_score, f"F1 score decreased for top_k={k}"
previous_score = score

if k == num_classes:
assert torch.isclose(
score, torch.tensor(1.0)
), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}"


def test_multiclass_f1_score_top_k_equivalence():
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653.

Test that top-k F1 score is equivalent to corrected top-1 F1 score.
"""
num_classes = 5

preds = torch.randn(200, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (200,))

f1_val_top3 = MulticlassF1Score(num_classes=num_classes, top_k=3, average="macro")
f1_val_top1 = MulticlassF1Score(num_classes=num_classes, top_k=1, average="macro")

pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
pred_top_1 = pred_top_3[:, 0]

target_in_top3 = (target.unsqueeze(1) == pred_top_3).any(dim=1)

pred_corrected_top3 = torch.where(target_in_top3, target, pred_top_1)

score_top3 = f1_val_top3(preds, target)
score_corrected = f1_val_top1(pred_corrected_top3, target)

assert torch.isclose(
score_top3, score_corrected
), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"


def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):
42 changes: 33 additions & 9 deletions tests/unittests/classification/test_precision_recall.py
@@ -385,8 +385,16 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

_mc_k_targets2 = torch.tensor([0, 0, 2])
_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])
_mc_k_targets2 = tensor([0, 0, 2])
_mc_k_preds2 = tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]])

_mc_k_target3 = tensor([0, 1, 2, 0])
_mc_k_preds3 = tensor([
[0.1, 0.2, 0.7],
[0.4, 0.4, 0.2],
[0.3, 0.3, 0.4],
[0.3, 0.3, 0.4],
])


@pytest.mark.parametrize(
@@ -395,10 +403,24 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional,
@pytest.mark.parametrize(
("k", "preds", "target", "average", "expected_prec", "expected_recall"),
[
(1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)),
(1, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
(2, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)),
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(3, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)),
(2, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)),
(3, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.1111), torch.tensor(0.3333)),
(2, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.8333), torch.tensor(0.8333)),
(3, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.2500), torch.tensor(0.2500)),
(2, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.7500), torch.tensor(0.7500)),
(3, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.0833), torch.tensor(0.2500)),
(2, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.8750), torch.tensor(0.7500)),
(3, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(1.0), torch.tensor(1.0)),
(1, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([0.0000, 0.0000, 0.3333]), torch.tensor([0.0, 0.0, 1.0])),
(2, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0000, 1.0000, 0.5000]), torch.tensor([0.5, 1.0, 1.0])),
(3, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])),
],
)
def test_top_k(
@@ -411,14 +433,16 @@ def test_top_k(
expected_prec: Tensor,
expected_recall: Tensor,
):
"""A simple test to check that top_k works as expected."""
"""A test to validate top_k functionality for precision and recall."""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

result = expected_prec if metric_class.__name__ == "MulticlassPrecision" else expected_recall

assert torch.equal(class_metric.compute(), result)
assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4)
assert torch.allclose(
metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4
)
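
For the `_mc_k_preds3` cases, the micro values follow directly from the refined hard predictions: at k=1 the argmax predictions are [2, 0, 2, 2] against targets [0, 1, 2, 0], so one of four samples is correct (micro precision = recall = 0.25); at k=2 the refined predictions are [2, 1, 2, 0] (see the sketch in the stat_scores diff above), giving three of four correct (0.75); at k=3 all three classes are in the top-3, so every refined prediction equals its target (1.0).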


def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):
39 changes: 32 additions & 7 deletions tests/unittests/classification/test_specificity.py
@@ -18,7 +18,7 @@
import torch
from scipy.special import expit as sigmoid
from sklearn.metrics import confusion_matrix as sk_confusion_matrix
from torch import Tensor, tensor
from torch import Tensor
from torchmetrics.classification.specificity import (
BinarySpecificity,
MulticlassSpecificity,
@@ -355,24 +355,49 @@ def test_multiclass_specificity_dtype_gpu(self, inputs, dtype):
)


_mc_k_target = tensor([0, 1, 2])
_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

_mc_k_target2 = torch.tensor([0, 1, 2, 0])
_mc_k_preds2 = torch.tensor([
[0.1, 0.2, 0.7],
[0.4, 0.4, 0.2],
[0.3, 0.3, 0.4],
[0.3, 0.3, 0.4],
])


@pytest.mark.parametrize(
("k", "preds", "target", "average", "expected_spec"),
[
(1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)),
(2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)),
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.6111)),
(2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.8889)),
(3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.6250)),
(2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.8750)),
(3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.5833)),
(2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.9167)),
(3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0)),
(1, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([0.5000, 1.0000, 0.3333])),
(2, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0000, 1.0000, 0.6667])),
(3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0])),
],
)
def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spec: Tensor):
"""A simple test to check that top_k works as expected."""
class_metric = MulticlassSpecificity(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)

assert torch.equal(class_metric.compute(), expected_spec)
assert torch.equal(multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3), expected_spec)
assert torch.allclose(class_metric.compute(), expected_spec, atol=1e-4, rtol=1e-4)
assert torch.allclose(
multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3),
expected_spec,
atol=1e-4,
rtol=1e-4,
)
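
The `_mc_k_preds2` specificity values can likewise be checked by hand from per-class TN / (TN + FP): at k=1 the hard predictions [2, 0, 2, 2] against targets [0, 1, 2, 0] give 1/2, 3/3, and 1/3 for classes 0 through 2 (macro 0.6111), and at k=2 the refined predictions [2, 1, 2, 0] give 2/2, 3/3, and 2/3 (macro 0.8889); at k=3 the refined predictions match the targets exactly, so specificity is 1.0 for every class.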


def _reference_specificity_multilabel_global(preds, target, ignore_index, average):