From 4d9c8434d09f05c731a3c793bce25a78f919b5ab Mon Sep 17 00:00:00 2001
From: Rittik Panda
Date: Sat, 21 Dec 2024 19:45:30 +0530
Subject: [PATCH] Fix `top_k` for `multiclass-f1score` (#2839)

* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka B
---
 CHANGELOG.md                                   |   3 +
 .../functional/classification/stat_scores.py  |  27 +++++
 tests/unittests/classification/test_f_beta.py | 107 +++++++++++++++++-
 .../classification/test_precision_recall.py   |  42 +++++--
 .../classification/test_specificity.py        |  39 +++++--
 .../classification/test_stat_scores.py        |  65 ++++++++++-
 6 files changed, 258 insertions(+), 25 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f7ae27b4b3..769a6b7d4ad 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))

+- Fixed `top_k` for `MulticlassF1Score` with one-hot encoding ([#2839](https://github.com/Lightning-AI/torchmetrics/issues/2839))
+
+
 ---

 ## [1.6.0] - 2024-11-12

diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py
index d111b5459bb..0d32bef573a 100644
--- a/src/torchmetrics/functional/classification/stat_scores.py
+++ b/src/torchmetrics/functional/classification/stat_scores.py
@@ -340,6 +340,30 @@ def _multiclass_stat_scores_format(
     return preds, target


+def _refine_preds_oh(preds: Tensor, preds_oh: Tensor, target: Tensor, top_k: int) -> Tensor:
+    """Refines prediction one-hot encodings by replacing entries with target one-hot when there's an intersection.
+
+    When no intersection is found between the top-k predictions and the target, the top-1 prediction is used.
+ + Args: + preds: Original prediction tensor with probabilities/logits + preds_oh: Current one-hot encoded predictions from top-k selection + target: Target tensor with class indices + top_k: Number of top predictions to consider + + Returns: + Refined one-hot encoded predictions tensor + + """ + preds = preds.squeeze() + target = target.squeeze() + top_k_indices = torch.topk(preds, k=top_k, dim=1).indices + top_1_indices = top_k_indices[:, 0] + target_in_topk = torch.any(top_k_indices == target.unsqueeze(1), dim=1) + result = torch.where(target_in_topk, target, top_1_indices) + return torch.zeros_like(preds_oh, dtype=torch.int32).scatter_(-1, result.unsqueeze(1).unsqueeze(1), 1) + + def _multiclass_stat_scores_update( preds: Tensor, target: Tensor, @@ -371,13 +395,16 @@ def _multiclass_stat_scores_update( if top_k > 1: preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) + preds_oh = _refine_preds_oh(preds, preds_oh, target, top_k) else: preds_oh = torch.nn.functional.one_hot( preds.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes ) + target_oh = torch.nn.functional.one_hot( target.long(), num_classes + 1 if ignore_index is not None and not ignore_in else num_classes ) + if ignore_index is not None: if 0 <= ignore_index <= num_classes - 1: target_oh[target == ignore_index, :] = -1 diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 3c3e429f232..73e988dc36f 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -377,7 +377,19 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa _mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_mc_k_preds = torch.tensor([ + [0.35, 0.4, 0.25], + [0.1, 0.5, 0.4], + [0.2, 0.1, 0.7], +]) + +_mc_k_target2 = torch.tensor([0, 1, 2, 0]) +_mc_k_preds2 = torch.tensor([ + [0.1, 0.2, 0.7], + [0.4, 0.4, 0.2], + [0.3, 0.3, 0.4], + [0.3, 0.3, 0.4], +]) @pytest.mark.parametrize( @@ -391,7 +403,33 @@ def test_multiclass_fbeta_score_half_gpu(self, inputs, module, functional, compa ("k", "preds", "target", "average", "expected_fbeta", "expected_f1"), [ (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.25), torch.tensor(0.25)), + (2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.75), torch.tensor(0.75)), + (3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.2381), torch.tensor(0.1667)), + (2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.7963), torch.tensor(0.7778)), + (3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.1786), torch.tensor(0.1250)), + (2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.7361), torch.tensor(0.7500)), + (3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0), torch.tensor(1.0)), + ( + 1, + _mc_k_preds2, + _mc_k_target2, + "none", + torch.tensor([0.0000, 0.0000, 0.7143]), + torch.tensor([0.0000, 0.0000, 0.5000]), + ), + ( + 2, + _mc_k_preds2, + _mc_k_target2, + "none", + torch.tensor([0.5556, 1.0000, 0.8333]), + 
torch.tensor([0.6667, 1.0000, 0.6667]), + ), + (3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])), ], ) def test_top_k( @@ -404,14 +442,73 @@ def test_top_k( expected_fbeta: Tensor, expected_f1: Tensor, ): - """A simple test to check that top_k works as expected.""" + """A comprehensive test to check that top_k works as expected.""" class_metric = metric_class(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) result = expected_fbeta if class_metric.beta != 1.0 else expected_f1 - assert torch.isclose(class_metric.compute(), result) - assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) + assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4) + assert torch.allclose( + metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4 + ) + + +@pytest.mark.parametrize("num_classes", [5]) +def test_multiclassf1score_with_top_k(num_classes): + """Test that F1 score increases monotonically with top_k and equals 1 when top_k equals num_classes. + + Args: + num_classes: Number of classes in the classification task. + + The test verifies two properties: + 1. F1 score increases or stays the same as top_k increases + 2. F1 score equals 1 when top_k equals num_classes + + """ + preds = torch.randn(200, num_classes).softmax(dim=-1) + target = torch.randint(num_classes, (200,)) + + previous_score = 0.0 + for k in range(1, num_classes + 1): + f1_score = MulticlassF1Score(num_classes=num_classes, top_k=k, average="macro") + score = f1_score(preds, target) + + assert score >= previous_score, f"F1 score did not increase for top_k={k}" + previous_score = score + + if k == num_classes: + assert torch.isclose( + score, torch.tensor(1.0) + ), f"F1 score is not 1 for top_k={k} when num_classes={num_classes}" + + +def test_multiclass_f1_score_top_k_equivalence(): + """Issue: https://github.com/Lightning-AI/torchmetrics/issues/1653. + + Test that top-k F1 score is equivalent to corrected top-1 F1 score. 
+ """ + num_classes = 5 + + preds = torch.randn(200, num_classes).softmax(dim=-1) + target = torch.randint(num_classes, (200,)) + + f1_val_top3 = MulticlassF1Score(num_classes=num_classes, top_k=3, average="macro") + f1_val_top1 = MulticlassF1Score(num_classes=num_classes, top_k=1, average="macro") + + pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3] + pred_top_1 = pred_top_3[:, 0] + + target_in_top3 = (target.unsqueeze(1) == pred_top_3).any(dim=1) + + pred_corrected_top3 = torch.where(target_in_top3, target, pred_top_1) + + score_top3 = f1_val_top3(preds, target) + score_corrected = f1_val_top1(pred_corrected_top3, target) + + assert torch.isclose( + score_top3, score_corrected + ), f"Top-3 F1 score ({score_top3}) does not match corrected top-1 F1 score ({score_corrected})" def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division): diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 56d87ebf073..c95ececa1dd 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -385,8 +385,16 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional, _mc_k_target = tensor([0, 1, 2]) _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_mc_k_targets2 = torch.tensor([0, 0, 2]) -_mc_k_preds2 = torch.tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]]) +_mc_k_targets2 = tensor([0, 0, 2]) +_mc_k_preds2 = tensor([[0.9, 0.1, 0.0], [0.9, 0.1, 0.0], [0.9, 0.1, 0.0]]) + +_mc_k_target3 = tensor([0, 1, 2, 0]) +_mc_k_preds3 = tensor([ + [0.1, 0.2, 0.7], + [0.4, 0.4, 0.2], + [0.3, 0.3, 0.4], + [0.3, 0.3, 0.4], +]) @pytest.mark.parametrize( @@ -395,10 +403,24 @@ def test_multiclass_precision_recall_half_gpu(self, inputs, module, functional, @pytest.mark.parametrize( ("k", "preds", "target", "average", "expected_prec", "expected_recall"), [ - (1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)), - (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)), - (1, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)), - (2, _mc_k_preds2, _mc_k_targets2, "macro", tensor(1 / 3), tensor(1 / 2)), + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)), + (3, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)), + (2, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1 / 3), torch.tensor(1 / 2)), + (3, _mc_k_preds2, _mc_k_targets2, "macro", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.1111), torch.tensor(0.3333)), + (2, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(0.8333), torch.tensor(0.8333)), + (3, _mc_k_preds3, _mc_k_target3, "macro", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.2500), torch.tensor(0.2500)), + (2, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(0.7500), torch.tensor(0.7500)), + (3, _mc_k_preds3, _mc_k_target3, "micro", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.0833), torch.tensor(0.2500)), + (2, _mc_k_preds3, _mc_k_target3, "weighted", torch.tensor(0.8750), torch.tensor(0.7500)), + (3, _mc_k_preds3, 
_mc_k_target3, "weighted", torch.tensor(1.0), torch.tensor(1.0)), + (1, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([0.0000, 0.0000, 0.3333]), torch.tensor([0.0, 0.0, 1.0])), + (2, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0000, 1.0000, 0.5000]), torch.tensor([0.5, 1.0, 1.0])), + (3, _mc_k_preds3, _mc_k_target3, "none", torch.tensor([1.0, 1.0, 1.0]), torch.tensor([1.0, 1.0, 1.0])), ], ) def test_top_k( @@ -411,14 +433,16 @@ def test_top_k( expected_prec: Tensor, expected_recall: Tensor, ): - """A simple test to check that top_k works as expected.""" + """A test to validate top_k functionality for precision and recall.""" class_metric = metric_class(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) result = expected_prec if metric_class.__name__ == "MulticlassPrecision" else expected_recall - assert torch.equal(class_metric.compute(), result) - assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) + assert torch.allclose(class_metric.compute(), result, atol=1e-4, rtol=1e-4) + assert torch.allclose( + metric_fn(preds, target, top_k=k, average=average, num_classes=3), result, atol=1e-4, rtol=1e-4 + ) def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average, zero_division): diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 437d9e07af9..4c2c2023630 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -18,7 +18,7 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from torch import Tensor, tensor +from torch import Tensor from torchmetrics.classification.specificity import ( BinarySpecificity, MulticlassSpecificity, @@ -355,15 +355,35 @@ def test_multiclass_specificity_dtype_gpu(self, inputs, dtype): ) -_mc_k_target = tensor([0, 1, 2]) -_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) + +_mc_k_target2 = torch.tensor([0, 1, 2, 0]) +_mc_k_preds2 = torch.tensor([ + [0.1, 0.2, 0.7], + [0.4, 0.4, 0.2], + [0.3, 0.3, 0.4], + [0.3, 0.3, 0.4], +]) @pytest.mark.parametrize( ("k", "preds", "target", "average", "expected_spec"), [ - (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)), - (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)), + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6)), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.6111)), + (2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(0.8889)), + (3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.6250)), + (2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(0.8750)), + (3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.5833)), + (2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(0.9167)), + (3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor(1.0)), + (1, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([0.5000, 1.0000, 0.3333])), + (2, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0000, 1.0000, 0.6667])), + (3, _mc_k_preds2, _mc_k_target2, "none", torch.tensor([1.0, 1.0, 1.0])), ], ) def 
test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spec: Tensor): @@ -371,8 +391,13 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spe class_metric = MulticlassSpecificity(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) - assert torch.equal(class_metric.compute(), expected_spec) - assert torch.equal(multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3), expected_spec) + assert torch.allclose(class_metric.compute(), expected_spec, atol=1e-4, rtol=1e-4) + assert torch.allclose( + multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3), + expected_spec, + atol=1e-4, + rtol=1e-4, + ) def _reference_specificity_multilabel_global(preds, target, ignore_index, average): diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index fee079011be..47a7a7cab28 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -26,6 +26,7 @@ StatScores, ) from torchmetrics.functional.classification.stat_scores import ( + _refine_preds_oh, binary_stat_scores, multiclass_stat_scores, multilabel_stat_scores, @@ -362,17 +363,73 @@ def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_mes multiclass_stat_scores(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index) +@pytest.mark.parametrize( + ("top_k", "expected_result"), + [ + (1, torch.tensor([[[0, 0, 1]], [[1, 0, 0]], [[0, 1, 0]], [[1, 0, 0]]], dtype=torch.int32)), + (2, torch.tensor([[[1, 0, 0]], [[1, 0, 0]], [[0, 1, 0]], [[0, 0, 1]]], dtype=torch.int32)), + (3, torch.tensor([[[1, 0, 0]], [[0, 1, 0]], [[0, 1, 0]], [[0, 0, 1]]], dtype=torch.int32)), + ], +) +def test_refine_preds_oh(top_k, expected_result): + """Test the _refine_preds_oh function. + + This function tests the behavior of the _refine_preds_oh function with various top_k values + and checks if the output matches the expected one-hot encoded results. + + Args: + top_k: The number of top predictions to consider. + expected_result: The expected one-hot encoded tensor result after refinement. + + """ + preds = torch.tensor([ + [[0.2917], [0.0682], [0.6401]], + [[0.2582], [0.0614], [0.0704]], + [[0.0725], [0.6015], [0.3260]], + [[0.4650], [0.2448], [0.2902]], + ]) + + preds_oh = torch.tensor([[[1, 0, 1]], [[1, 0, 1]], [[0, 1, 1]], [[1, 0, 1]]], dtype=torch.int32) + + target = torch.tensor([0, 1, 1, 2]) + + result = _refine_preds_oh(preds, preds_oh, target, top_k) + assert torch.equal(result, expected_result), ( + f"Test failed for top_k={top_k}. 
" f"Expected result: {expected_result}, but got: {result}" + ) + + _mc_k_target = torch.tensor([0, 1, 2]) _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +_mc_k_target2 = torch.tensor([0, 1, 2, 0]) +_mc_k_preds2 = torch.tensor([ + [0.1, 0.2, 0.7], + [0.4, 0.4, 0.2], + [0.3, 0.3, 0.4], + [0.3, 0.3, 0.4], +]) + @pytest.mark.parametrize( ("k", "preds", "target", "average", "expected"), [ (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 0, 6, 0, 3])), (1, _mc_k_preds, _mc_k_target, None, torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), - (2, _mc_k_preds, _mc_k_target, None, torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), + (2, _mc_k_preds, _mc_k_target, None, torch.tensor([[1, 1, 1], [0, 0, 0], [2, 2, 2], [0, 0, 0], [1, 1, 1]])), + (1, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor([0.3333, 1.0000, 1.6667, 1.0000, 1.3333])), + (2, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor([1.0000, 0.3333, 2.3333, 0.3333, 1.3333])), + (3, _mc_k_preds2, _mc_k_target2, "macro", torch.tensor([1.3333, 0.0000, 2.6667, 0.0000, 1.3333])), + (1, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor([1, 3, 5, 3, 4])), + (2, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor([3, 1, 7, 1, 4])), + (3, _mc_k_preds2, _mc_k_target2, "micro", torch.tensor([4, 0, 8, 0, 4])), + (1, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor([0.2500, 1.0000, 1.5000, 1.2500, 1.5000])), + (2, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor([1.0000, 0.2500, 2.2500, 0.5000, 1.5000])), + (3, _mc_k_preds2, _mc_k_target2, "weighted", torch.tensor([1.5000, 0.0000, 2.5000, 0.0000, 1.5000])), + (1, _mc_k_preds2, _mc_k_target2, None, torch.tensor([[0, 0, 1], [1, 0, 2], [1, 3, 1], [2, 1, 0], [2, 1, 1]])), + (2, _mc_k_preds2, _mc_k_target2, None, torch.tensor([[1, 1, 1], [0, 0, 1], [2, 3, 2], [1, 0, 0], [2, 1, 1]])), + (3, _mc_k_preds2, _mc_k_target2, None, torch.tensor([[2, 1, 1], [0, 0, 0], [2, 3, 3], [0, 0, 0], [2, 1, 1]])), ], ) def test_top_k_multiclass(k, preds, target, average, expected): @@ -380,9 +437,9 @@ def test_top_k_multiclass(k, preds, target, average, expected): class_metric = MulticlassStatScores(top_k=k, average=average, num_classes=3) class_metric.update(preds, target) - assert torch.allclose(class_metric.compute().long(), expected.T) + assert torch.allclose(class_metric.compute(), expected.T, atol=1e-4, rtol=1e-4) assert torch.allclose( - multiclass_stat_scores(preds, target, top_k=k, average=average, num_classes=3).long(), expected.T + multiclass_stat_scores(preds, target, top_k=k, average=average, num_classes=3), expected.T, atol=1e-4, rtol=1e-4 )