diff --git a/CHANGELOG.md b/CHANGELOG.md
index b8f73495da4..e7222d27248 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))
 
+- Fixed `top_k` for `MulticlassF1Score` not working correctly ([#1653](https://github.com/Lightning-AI/torchmetrics/issues/1653))
 
 
 ## [1.5.2] - 2024-11-07
diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py
index d096f03b857..847ceb8611b 100644
--- a/tests/unittests/classification/test_f_beta.py
+++ b/tests/unittests/classification/test_f_beta.py
@@ -425,55 +425,49 @@ def test_multiclassf1score_with_top_k(num_classes):
         2. F1 score equals 1 when top_k equals num_classes
 
     """
-    # Create test data
     preds = torch.randn(200, num_classes).softmax(dim=-1)
     target = torch.randint(num_classes, (200,))
 
-    previous_score = 0.0  # To track the last F1 score
+    previous_score = 0.0
     for k in range(1, num_classes + 1):
-        # Calculate F1 score with top_k
         f1_score = MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro")
         score = f1_score(preds, target)
 
-        # Check if the score increases as top_k increases
         assert score >= previous_score, f"F1 score did not increase for top_k={k}"
         previous_score = score
 
-        # Check if F1 score is 1 when top_k equals num_classes
         if k == num_classes:
             assert torch.isclose(
                 score, torch.tensor(1.0)
             ), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}"
 
 
-# def test_MulticlassF1Score_top_k_equivalence(num_classes):
-#     """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653"""
-#     # Create test data
-#     preds = torch.randn(200, num_classes).softmax(dim=-1)
-#     target = torch.randint(num_classes, (200,))
+def test_multiclass_f1_score_top_k_equivalence():
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653.
 
-#     # Initialize metrics
-#     f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
-#     f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")
+    Test that top-k F1 score is equivalent to corrected top-1 F1 score.
+    """
+    num_classes = 5
+
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    f1_val_top3 = MulticlassF1Score(num_classes=num_classes, top_k=3, average="macro")
+    f1_val_top1 = MulticlassF1Score(num_classes=num_classes, top_k=1, average="macro")
+
+    pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
+    pred_top_1 = pred_top_3[:, 0]
 
-#     # Get top-k predictions
-#     pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
-#     pred_top_1 = pred_top_3[:, 0]
+    target_in_top3 = (target.unsqueeze(1) == pred_top_3).any(dim=1)
 
-#     # Correct predictions if target is in top-3
-#     pred_corrected_top3 = torch.where(
-#         functorch.vmap(lambda t1, t2: torch.isin(t1, t2))(target, pred_top_3),
-#         target,
-#         pred_top_1
-#     )
+    pred_corrected_top3 = torch.where(target_in_top3, target, pred_top_1)
 
-#     # Calculate F1 scores
-#     score_top3 = f1_val_top3(preds, target)
-#     score_corrected = f1_val_top1(pred_corrected_top3, target)
+    score_top3 = f1_val_top3(preds, target)
+    score_corrected = f1_val_top1(pred_corrected_top3, target)
 
-#     # Assert that both methods give the same result
-#     assert torch.isclose(score_top3, score_corrected), \
-#         f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"
+    assert torch.isclose(
+        score_top3, score_corrected
+    ), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"
 
 
 def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):