add top_k equivalence test
rittik9 committed Nov 22, 2024
1 parent 6d40bb4 commit a949834
Showing 2 changed files with 23 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))

+- Fixed `top_k` for `MulticlassF1Score` not working correctly ([#1653](https://github.com/Lightning-AI/torchmetrics/issues/1653))

## [1.5.2] - 2024-11-07

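As context for the changelog entry above, here is a minimal sketch (not part of the diff) of the property the fix restores, assuming the 5-class setup described in issue #1653: with `average="macro"`, a larger `top_k` gives the model strictly more chances to hit the target, so `top_k=3` should never score below `top_k=1` on the same inputs.

```python
import torch
from torchmetrics.classification import MulticlassF1Score

# Hypothetical random 5-class batch, mirroring the shapes used in the tests below.
preds = torch.randn(200, 5).softmax(dim=-1)
target = torch.randint(5, (200,))

score_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")(preds, target)
score_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")(preds, target)

# With the fix in place, the top-3 score is at least the top-1 score.
assert score_top3 >= score_top1
```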
50 changes: 22 additions & 28 deletions tests/unittests/classification/test_f_beta.py
@@ -425,55 +425,49 @@ def test_multiclassf1score_with_top_k(num_classes):
    2. F1 score equals 1 when top_k equals num_classes
    """
    # Create test data
    preds = torch.randn(200, num_classes).softmax(dim=-1)
    target = torch.randint(num_classes, (200,))

-    previous_score = 0.0  # To track the last F1 score
+    previous_score = 0.0
    for k in range(1, num_classes + 1):
        # Calculate F1 score with top_k
        f1_score = MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro")
        score = f1_score(preds, target)

        # Check if the score increases as top_k increases
        assert score >= previous_score, f"F1 score did not increase for top_k={k}"
        previous_score = score

        # Check if F1 score is 1 when top_k equals num_classes
        if k == num_classes:
            assert torch.isclose(
                score, torch.tensor(1.0)
            ), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}"


-# def test_MulticlassF1Score_top_k_equivalence(num_classes):
-#     """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653"""
-#     # Create test data
-#     preds = torch.randn(200, num_classes).softmax(dim=-1)
-#     target = torch.randint(num_classes, (200,))
-#     # Initialize metrics
-#     f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
-#     f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")
-#     # Get top-k predictions
-#     pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
-#     pred_top_1 = pred_top_3[:, 0]
-#     # Correct predictions if target is in top-3
-#     pred_corrected_top3 = torch.where(
-#         functorch.vmap(lambda t1, t2: torch.isin(t1, t2))(target, pred_top_3),
-#         target,
-#         pred_top_1
-#     )
-#     # Calculate F1 scores
-#     score_top3 = f1_val_top3(preds, target)
-#     score_corrected = f1_val_top1(pred_corrected_top3, target)
-#     # Assert that both methods give the same result
-#     assert torch.isclose(score_top3, score_corrected), \
-#         f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"
+def test_multiclass_f1_score_top_k_equivalence():
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653.
+
+    Test that top-k F1 score is equivalent to corrected top-1 F1 score.
+    """
+    num_classes = 5
+
+    preds = torch.randn(200, num_classes).softmax(dim=-1)
+    target = torch.randint(num_classes, (200,))
+
+    f1_val_top3 = MulticlassF1Score(num_classes=num_classes, top_k=3, average="macro")
+    f1_val_top1 = MulticlassF1Score(num_classes=num_classes, top_k=1, average="macro")
+
+    pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
+    pred_top_1 = pred_top_3[:, 0]
+
+    target_in_top3 = (target.unsqueeze(1) == pred_top_3).any(dim=1)
+
+    pred_corrected_top3 = torch.where(target_in_top3, target, pred_top_1)
+
+    score_top3 = f1_val_top3(preds, target)
+    score_corrected = f1_val_top1(pred_corrected_top3, target)
+
+    assert torch.isclose(
+        score_top3, score_corrected
+    ), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})"


def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division):
