From afd1fb857011f8d9e28df17772b16637b6fde6aa Mon Sep 17 00:00:00 2001
From: zhaozheng09
Date: Fri, 22 Nov 2024 10:24:30 +0800
Subject: [PATCH 01/15] Avoid blocking device-to-host sync in precision-recall
 curve formatting

---
 .../classification/precision_recall_curve.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index 3c5a840efa1..25ed6c21b93 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -182,8 +182,9 @@ def _binary_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]
 
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+    out_of_bounds = (preds < 0) | (preds > 1)
+    out_of_bounds = out_of_bounds.any()
+    preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
@@ -761,8 +762,10 @@ def _multilabel_precision_recall_curve_format(
     """
     preds = preds.transpose(0, 1).reshape(num_labels, -1).T
     target = target.transpose(0, 1).reshape(num_labels, -1).T
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+
+    out_of_bounds = (preds < 0) | (preds > 1)
+    out_of_bounds = out_of_bounds.any()
+    preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     if ignore_index is not None and thresholds is not None:

From 89a616ab3ad07ad31f850a96bd7e2f0ce7049dfa Mon Sep 17 00:00:00 2001
From: meng song
Date: Sun, 24 Nov 2024 16:50:52 +0800
Subject: [PATCH 02/15] Handle half precision on CPU, where "sigmoid_cpu" is
 not implemented

---
 .../classification/precision_recall_curve.py | 22 ++++++++++++++-----
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index 25ed6c21b93..e2bc3d641c4 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -182,9 +182,14 @@ def _binary_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]
 
-    out_of_bounds = (preds < 0) | (preds > 1)
-    out_of_bounds = out_of_bounds.any()
-    preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    # "sigmoid_cpu" not implemented for 'Half'
+    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
+      out_of_bounds = (preds < 0) | (preds > 1)
+      out_of_bounds = out_of_bounds.any()
+      preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    else:
+      if not torch.all((preds >= 0) * (preds <= 1)):
+        preds = preds.sigmoid()
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
@@ -763,9 +768,14 @@ def _multilabel_precision_recall_curve_format(
     preds = preds.transpose(0, 1).reshape(num_labels, -1).T
     target = target.transpose(0, 1).reshape(num_labels, -1).T
 
-    out_of_bounds = (preds < 0) | (preds > 1)
-    out_of_bounds = out_of_bounds.any()
-    preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    # "sigmoid_cpu" not implemented for 'Half'
+    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
+      out_of_bounds = (preds < 0) | (preds > 1)
+      out_of_bounds = out_of_bounds.any()
+      preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    else:
+      if not torch.all((preds >= 0) * (preds <= 1)):
+        preds = preds.sigmoid()
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     if ignore_index is not None and thresholds is not None:

From 9afd681507a1140f901578fd670c6967462568d3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 24 Nov 2024 09:11:07 +0000
Subject: [PATCH 03/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../classification/precision_recall_curve.py | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index e2bc3d641c4..991606e4b25 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -184,12 +184,12 @@ def _binary_precision_recall_curve_format(
 
     # "sigmoid_cpu" not implemented for 'Half'
     if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
-      out_of_bounds = (preds < 0) | (preds > 1)
-      out_of_bounds = out_of_bounds.any()
-      preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+        out_of_bounds = (preds < 0) | (preds > 1)
+        out_of_bounds = out_of_bounds.any()
+        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
     else:
-      if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+        if not torch.all((preds >= 0) * (preds <= 1)):
+            preds = preds.sigmoid()
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
@@ -770,12 +770,12 @@ def _multilabel_precision_recall_curve_format(
 
     # "sigmoid_cpu" not implemented for 'Half'
     if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
-      out_of_bounds = (preds < 0) | (preds > 1)
-      out_of_bounds = out_of_bounds.any()
-      preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+        out_of_bounds = (preds < 0) | (preds > 1)
+        out_of_bounds = out_of_bounds.any()
+        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
     else:
-      if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+        if not torch.all((preds >= 0) * (preds <= 1)):
+            preds = preds.sigmoid()
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     if ignore_index is not None and thresholds is not None:

From 377f77d00845b55ce03a6d973cc63bbad572ba92 Mon Sep 17 00:00:00 2001
From: zhaozheng09
Date: Tue, 26 Nov 2024 19:29:58 +0800
Subject: [PATCH 04/15] Update the unit test for the half-precision CPU
 fallback

---
 .../test_precision_recall_curve.py | 30 +++++++++++++-------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py
index 7c034c528e6..3772073536c 100644
--- a/tests/unittests/classification/test_precision_recall_curve.py
+++ b/tests/unittests/classification/test_precision_recall_curve.py
@@ -106,15 +106,27 @@ def test_binary_precision_recall_curve_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
         if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
-        self.run_precision_test_cpu(
-            preds=preds,
-            target=target,
-            metric_module=BinaryPrecisionRecallCurve,
-            metric_functional=binary_precision_recall_curve,
-            metric_args={"thresholds": None},
-            dtype=dtype,
-        )
+            try:
+                self.run_precision_test_cpu(
+                    preds=preds,
+                    target=target,
+                    metric_module=BinaryPrecisionRecallCurve,
+                    metric_functional=binary_precision_recall_curve,
+                    metric_args={"thresholds": None},
+                    dtype=dtype,
+                )
+            except Exception as e:
+                print(f"An unexpected error occurred: {e}")
+                pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        else:
+            self.run_precision_test_cpu(
+                preds=preds,
+                target=target,
+                metric_module=BinaryPrecisionRecallCurve,
+                metric_functional=binary_precision_recall_curve,
+                metric_args={"thresholds": None},
+                dtype=dtype,
+            )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])

From 5146a6d1435d6192a55a9b56d0a985ab6bb4ba09 Mon Sep 17 00:00:00 2001
From: Jirka B
Date: Wed, 27 Nov 2024 15:56:53 +0100
Subject: [PATCH 05/15] chlog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a253f90f6ef..5d5c9554ef0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Delete `Device2Host` caused by comm with device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840))
 
 ---

From 5e25b283f7efc9e6481bd1a81b22a2450203d3d1 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 3 Dec 2024 15:16:04 +0100
Subject: [PATCH 06/15] revert test file

---
 .../test_precision_recall_curve.py | 30 ++++++------------
 1 file changed, 9 insertions(+), 21 deletions(-)

diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py
index 3772073536c..7c034c528e6 100644
--- a/tests/unittests/classification/test_precision_recall_curve.py
+++ b/tests/unittests/classification/test_precision_recall_curve.py
@@ -106,27 +106,15 @@ def test_binary_precision_recall_curve_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
         if (preds < 0).any() and dtype == torch.half:
-            try:
-                self.run_precision_test_cpu(
-                    preds=preds,
-                    target=target,
-                    metric_module=BinaryPrecisionRecallCurve,
-                    metric_functional=binary_precision_recall_curve,
-                    metric_args={"thresholds": None},
-                    dtype=dtype,
-                )
-            except Exception as e:
-                print(f"An unexpected error occurred: {e}")
-                pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
-        else:
-            self.run_precision_test_cpu(
-                preds=preds,
-                target=target,
-                metric_module=BinaryPrecisionRecallCurve,
-                metric_functional=binary_precision_recall_curve,
-                metric_args={"thresholds": None},
-                dtype=dtype,
-            )
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=BinaryPrecisionRecallCurve,
+            metric_functional=binary_precision_recall_curve,
+            metric_args={"thresholds": None},
+            dtype=dtype,
+        )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])

From 08d9ece8675524db26c7bcc45d274c40fe9f44f8 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 3 Dec 2024 15:16:29 +0100
Subject: [PATCH 07/15] general conditional compute function

---
 src/torchmetrics/utilities/compute.py | 36 +++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index cbb648a8844..7832dab7493 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -15,6 +15,7 @@
 
 import torch
 from torch import Tensor
+from typing_extensions import Literal
 
 
 def _safe_matmul(x: Tensor, y: Tensor) -> Tensor:
@@ -184,3 +185,38 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
 
     indices = torch.clamp(indices, 0, len(m) - 1)
     return m[indices] * x + b[indices]
+
+
+def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid", "softmax"]) -> Tensor:
+    """Normalize logits if needed.
+
+    If input tensor is outside the [0,1] we assume that logits are provided and apply the normalization.
+    Use torch.where to prevent device-host sync.
+
+    Args:
+        tensor: input tensor that may be logits or probabilities
+        normalization: normalization method, either 'sigmoid' or 'softmax'
+
+    Returns:
+        normalized tensor if needed
+
+    Example:
+        >>> import torch
+        >>> tensor = torch.tensor([-1.0, 0.0, 1.0])
+        >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
+        tensor([0.2689, 0.5000, 0.7311])
+        >>> tensor = torch.tensor([[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]])
+        >>> normalize_logits_if_needed(tensor, normalization="softmax")
+        tensor([[0.0900, 0.2447, 0.6652],
+                [0.6652, 0.2447, 0.0900]])
+        >>> tensor = torch.tensor([0.0, 0.5, 1.0])
+        >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
+        tensor([0.0000, 0.5000, 1.0000])
+
+    """
+    condition = ((tensor < 0) | (tensor > 1)).any()
+    return torch.where(
+        condition,
+        torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
+        tensor,
+    )

From ae09c7dc23dff31e39b75f0a99fa92e666e8bcb3 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 3 Dec 2024 15:16:51 +0100
Subject: [PATCH 08/15] fix code in multiple locations using the new function

---
 .../classification/calibration_error.py      |  4 ++--
 .../classification/confusion_matrix.py       |  8 +++----
 .../functional/classification/hinge.py       |  5 ++--
 .../classification/precision_recall_curve.py | 23 ++++---------------
 .../functional/classification/stat_scores.py |  8 +++----
 5 files changed, 14 insertions(+), 34 deletions(-)

diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py
index ca55bb2f79b..22e529588a0 100644
--- a/src/torchmetrics/functional/classification/calibration_error.py
+++ b/src/torchmetrics/functional/classification/calibration_error.py
@@ -23,6 +23,7 @@
     _multiclass_confusion_matrix_format,
     _multiclass_confusion_matrix_tensor_validation,
 )
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
 
 
@@ -239,8 +240,7 @@ def _multiclass_calibration_error_update(
     preds: Tensor,
     target: Tensor,
 ) -> tuple[Tensor, Tensor]:
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
+    preds = normalize_logits_if_needed(preds, "softmax")
     confidences, predictions = preds.max(dim=1)
     accuracies = predictions.eq(target)
     return confidences.float(), accuracies.float()
diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py
index 92059072490..e6443c00d74 100644
--- a/src/torchmetrics/functional/classification/confusion_matrix.py
+++ b/src/torchmetrics/functional/classification/confusion_matrix.py
@@ -18,6 +18,7 @@
 from typing_extensions import Literal
 
 from torchmetrics.utilities.checks import _check_same_shape
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.prints import rank_zero_warn
@@ -137,9 +138,7 @@ def _binary_confusion_matrix_format(
         target = target[idx]
 
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            # preds is logits, convert with sigmoid
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         if convert_to_labels:
             preds = preds > threshold
 
@@ -491,8 +490,7 @@ def _multilabel_confusion_matrix_format(
 
     """
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         if should_threshold:
             preds = preds > threshold
         preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels)
diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py
index 8fe7cf840b8..2e6b8740886 100644
--- a/src/torchmetrics/functional/classification/hinge.py
+++ b/src/torchmetrics/functional/classification/hinge.py
@@ -23,6 +23,7 @@
     _multiclass_confusion_matrix_format,
     _multiclass_confusion_matrix_tensor_validation,
 )
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import to_onehot
 from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
 
@@ -153,9 +154,7 @@ def _multiclass_hinge_loss_update(
     squared: bool,
     multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
 ) -> tuple[Tensor, Tensor]:
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
-
+    preds = normalize_logits_if_needed(preds, "softmax")
     target = to_onehot(target, max(2, preds.shape[1])).bool()
     if multiclass_mode == "crammer-singer":
         margin = preds[target]
diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py
index 991606e4b25..b00c1975606 100644
--- a/src/torchmetrics/functional/classification/precision_recall_curve.py
+++ b/src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -21,7 +21,7 @@
 from typing_extensions import Literal
 
 from torchmetrics.utilities.checks import _check_same_shape
-from torchmetrics.utilities.compute import _safe_divide, interp
+from torchmetrics.utilities.compute import _safe_divide, interp, normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount, _cumsum
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.prints import rank_zero_warn
@@ -182,14 +182,7 @@ def _binary_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]
 
-    # "sigmoid_cpu" not implemented for 'Half'
-    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
-        out_of_bounds = (preds < 0) | (preds > 1)
-        out_of_bounds = out_of_bounds.any()
-        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
-    else:
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            preds = preds.sigmoid()
+    preds = normalize_logits_if_needed(preds, "sigmoid")
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
@@ -458,8 +451,7 @@ def _multiclass_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]
 
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
+    preds = normalize_logits_if_needed(preds, "softmax")
 
     if average == "micro":
         preds = preds.flatten()
@@ -768,14 +760,7 @@ def _multilabel_precision_recall_curve_format(
     preds = preds.transpose(0, 1).reshape(num_labels, -1).T
     target = target.transpose(0, 1).reshape(num_labels, -1).T
 
-    # "sigmoid_cpu" not implemented for 'Half'
-    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
-        out_of_bounds = (preds < 0) | (preds > 1)
-        out_of_bounds = out_of_bounds.any()
-        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
-    else:
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            preds = preds.sigmoid()
+    preds = normalize_logits_if_needed(preds, "sigmoid")
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     if ignore_index is not None and thresholds is not None:
diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py
index fbb6098db40..d111b5459bb 100644
--- a/src/torchmetrics/functional/classification/stat_scores.py
+++ b/src/torchmetrics/functional/classification/stat_scores.py
@@ -18,6 +18,7 @@
 from typing_extensions import Literal
 
 from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount, select_topk
 from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, MDMCAverageMethod
 
@@ -105,9 +106,7 @@ def _binary_stat_scores_format(
 
     """
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            # preds is logits, convert with sigmoid
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         preds = preds > threshold
 
     preds = preds.reshape(preds.shape[0], -1)
@@ -659,8 +658,7 @@ def _multilabel_stat_scores_format(
 
     """
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         preds = preds > threshold
     preds = preds.reshape(*preds.shape[:2], -1)
     target = target.reshape(*target.shape[:2], -1)

From 9c82719983870cee08645fac280b57979452dfa7 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 3 Dec 2024 15:42:23 +0100
Subject: [PATCH 09/15] make cpu + sigmoid skipping less restrictive for
 classification

---
 tests/unittests/classification/test_accuracy.py | 13 +++++++------
 tests/unittests/classification/test_auroc.py | 5 +++--
 .../classification/test_average_precision.py | 5 +++--
 .../classification/test_calibration_error.py | 5 +++--
 tests/unittests/classification/test_cohen_kappa.py | 9 +++++----
 .../classification/test_confusion_matrix.py | 9 +++++----
 tests/unittests/classification/test_exact_match.py | 9 +++++----
 tests/unittests/classification/test_f_beta.py | 13 +++++++------
 .../unittests/classification/test_group_fairness.py | 5 +++--
 .../classification/test_hamming_distance.py | 13 +++++++------
 tests/unittests/classification/test_jaccard.py | 9 +++++----
 tests/unittests/classification/test_logauc.py | 5 +++--
 .../classification/test_matthews_corrcoef.py | 9 +++++----
 .../test_negative_predictive_value.py | 13 +++++++------
 .../classification/test_precision_fixed_recall.py | 5 +++--
 .../classification/test_precision_recall.py | 13 +++++++------
 .../classification/test_precision_recall_curve.py | 5 +++--
 tests/unittests/classification/test_ranking.py | 5 +++--
 .../classification/test_recall_fixed_precision.py | 5 +++--
 tests/unittests/classification/test_roc.py | 5 +++--
 .../classification/test_sensitivity_specificity.py | 6 +++---
 tests/unittests/classification/test_specificity.py | 13 +++++++------
 .../classification/test_specificity_sensitivity.py | 5 +++--
 tests/unittests/classification/test_stat_scores.py | 13 +++++++------
 24 files changed, 110 insertions(+), 87 deletions(-)

diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py
index 65e42c00b07..59b466d33c1 100644
--- a/tests/unittests/classification/test_accuracy.py
+++ b/tests/unittests/classification/test_accuracy.py
@@ -27,6 +27,7 @@
     multilabel_accuracy,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -153,8 +154,8 @@ def test_binary_accuracy_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -310,8 +311,8 @@ def test_multiclass_accuracy_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -585,8 +586,8 @@ def test_multilabel_accuracy_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py
index 30d4acb470c..c7fdb54d6c1 100644
--- a/tests/unittests/classification/test_auroc.py
+++ b/tests/unittests/classification/test_auroc.py
@@ -24,6 +24,7 @@
 from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc
 from torchmetrics.functional.classification.roc import binary_roc
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -102,8 +103,8 @@ def test_binary_auroc_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py
index da0dc2f56b6..cf37360e832 100644
--- a/tests/unittests/classification/test_average_precision.py
+++ b/tests/unittests/classification/test_average_precision.py
@@ -33,6 +33,7 @@
 )
 from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -106,8 +107,8 @@ def test_binary_average_precision_differentiability(self, inputs):
     def test_binary_average_precision_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py
index b8b6bfc1646..8e2556c0533 100644
--- a/tests/unittests/classification/test_calibration_error.py
+++ b/tests/unittests/classification/test_calibration_error.py
@@ -29,6 +29,7 @@
     multiclass_calibration_error,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -112,8 +113,8 @@ def test_binary_calibration_error_differentiability(self, inputs):
     def test_binary_calibration_error_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py
index 1f2585372bd..4c4a411aab7 100644
--- a/tests/unittests/classification/test_cohen_kappa.py
+++ b/tests/unittests/classification/test_cohen_kappa.py
@@ -21,6 +21,7 @@
 from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa
 from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -103,8 +104,8 @@ def test_binary_cohen_kappa_dtypes_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -206,8 +207,8 @@ def test_multiclass_cohen_kappa_dtypes_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py
index 4d27dfc2069..23777dbdc2a 100644
--- a/tests/unittests/classification/test_confusion_matrix.py
+++ b/tests/unittests/classification/test_confusion_matrix.py
@@ -30,6 +30,7 @@
     multilabel_confusion_matrix,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -114,8 +115,8 @@ def test_binary_confusion_matrix_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -367,8 +368,8 @@ def test_multilabel_confusion_matrix_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py
index 5afd4c00e40..3cb8caa2061 100644
--- a/tests/unittests/classification/test_exact_match.py
+++ b/tests/unittests/classification/test_exact_match.py
@@ -20,6 +20,7 @@
 from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch
 from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -121,8 +122,8 @@ def test_multiclass_exact_match_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -250,8 +251,8 @@ def test_multilabel_exact_match_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py
index 075e37cc699..3c3e429f232 100644
--- a/tests/unittests/classification/test_f_beta.py
+++ b/tests/unittests/classification/test_f_beta.py
@@ -40,6 +40,7 @@
     multilabel_fbeta_score,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -171,8 +172,8 @@ def test_binary_fbeta_score_half_cpu(self, inputs, module, functional, compare, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -349,8 +350,8 @@ def test_multiclass_fbeta_score_half_cpu(self, inputs, module, functional, compa
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -608,8 +609,8 @@ def test_multilabel_fbeta_score_half_cpu(self, inputs, module, functional, compa
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py
index 5f627676f09..d1899831d7e 100644
--- a/tests/unittests/classification/test_group_fairness.py
+++ b/tests/unittests/classification/test_group_fairness.py
@@ -26,6 +26,7 @@
 from torchmetrics import Metric
 from torchmetrics.classification.group_fairness import BinaryFairness
 from torchmetrics.functional.classification.group_fairness import binary_fairness
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import THRESHOLD
 from unittests._helpers import seed_all
@@ -282,8 +283,8 @@ def test_binary_fairness_half_cpu(self, inputs, dtype):
         """Test class implementation of metric."""
         preds, target, groups = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py
index a7a42db61b0..6d4f0f824cc 100644
--- a/tests/unittests/classification/test_hamming_distance.py
+++ b/tests/unittests/classification/test_hamming_distance.py
@@ -31,6 +31,7 @@
     multilabel_hamming_distance,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -138,8 +139,8 @@ def test_binary_hamming_distance_differentiability(self, inputs):
    def test_binary_hamming_distance_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -301,8 +302,8 @@ def test_multiclass_hamming_distance_differentiability(self, inputs):
    def test_multiclass_hamming_distance_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -485,8 +486,8 @@ def test_multilabel_hamming_distance_differentiability(self, inputs):
    def test_multilabel_hamming_distance_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py
index 0a20a2e458a..606825f7e71 100644
--- a/tests/unittests/classification/test_jaccard.py
+++ b/tests/unittests/classification/test_jaccard.py
@@ -32,6 +32,7 @@
     multilabel_jaccard_index,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
@@ -108,8 +109,8 @@ def test_binary_jaccard_index_differentiability(self, inputs):
    def test_binary_jaccard_index_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -355,8 +356,8 @@ def test_multilabel_jaccard_index_differentiability(self, inputs):
    def test_multilabel_jaccard_index_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_logauc.py b/tests/unittests/classification/test_logauc.py
index 26cb395f45e..6494ac72372 100644
--- a/tests/unittests/classification/test_logauc.py
+++ b/tests/unittests/classification/test_logauc.py
@@ -27,6 +27,7 @@
 from torchmetrics.functional.classification.logauc import binary_logauc, multiclass_logauc, multilabel_logauc
 from torchmetrics.functional.classification.roc import binary_roc
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -105,8 +106,8 @@ def test_binary_logauc_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
 
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py
index b340db8d713..fc4d762384b 100644
--- a/tests/unittests/classification/test_matthews_corrcoef.py
+++ b/tests/unittests/classification/test_matthews_corrcoef.py
@@ -30,6 +30,7 @@
     multilabel_matthews_corrcoef,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -105,8 +106,8 @@ def test_binary_matthews_corrcoef_differentiability(self, inputs):
    def test_binary_matthews_corrcoef_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -287,8 +288,8 @@ def test_multilabel_matthews_corrcoef_differentiability(self, inputs):
    def test_multilabel_matthews_corrcoef_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_negative_predictive_value.py b/tests/unittests/classification/test_negative_predictive_value.py
index 464884ca82a..2fb352bc74f 100644
--- a/tests/unittests/classification/test_negative_predictive_value.py
+++ b/tests/unittests/classification/test_negative_predictive_value.py
@@ -31,6 +31,7 @@
     multilabel_negative_predictive_value,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -155,8 +156,8 @@ def test_binary_negative_predictive_value_differentiability(self, inputs):
    def test_binary_negative_predictive_value_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -332,8 +333,8 @@ def test_multiclass_negative_predictive_value_differentiability(self, inputs):
    def test_multiclass_negative_predictive_value_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -538,8 +539,8 @@ def test_multilabel_negative_predictive_value_differentiability(self, inputs):
    def test_multilabel_negative_predictive_value_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py
index b6649ad869d..03c8ee7654f 100644
--- a/tests/unittests/classification/test_precision_fixed_recall.py
+++ b/tests/unittests/classification/test_precision_fixed_recall.py
@@ -32,6 +32,7 @@
     multilabel_precision_at_fixed_recall,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -125,8 +126,8 @@ def test_binary_precision_at_fixed_recall_differentiability(self, inputs):
    def test_binary_precision_at_fixed_recall_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py
index 7717ffa5b0d..56d87ebf073 100644
--- a/tests/unittests/classification/test_precision_recall.py
+++ b/tests/unittests/classification/test_precision_recall.py
@@ -40,6 +40,7 @@
     multilabel_recall,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -172,8 +173,8 @@ def test_binary_precision_recall_differentiability(self, inputs, module, functio
    def test_binary_precision_recall_half_cpu(self, inputs, module, functional, compare, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -355,8 +356,8 @@ def test_multiclass_precision_recall_differentiability(self, inputs, module, fun
    def test_multiclass_precision_recall_half_cpu(self, inputs, module, functional, compare, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
@@ -612,8 +613,8 @@ def test_multilabel_precision_recall_differentiability(self, inputs, module, fun
    def test_multilabel_precision_recall_half_cpu(self, inputs, module, functional, compare, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py
index 7c034c528e6..d6a79b9b5cb 100644
--- a/tests/unittests/classification/test_precision_recall_curve.py
+++ b/tests/unittests/classification/test_precision_recall_curve.py
@@ -32,6 +32,7 @@
multilabel_precision_recall_curve, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -105,8 +106,8 @@ def test_binary_precision_recall_curve_differentiability(self, inputs): def test_binary_precision_recall_curve_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index e85d38cde05..4727ab882a1 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -30,6 +30,7 @@ multilabel_ranking_average_precision, multilabel_ranking_loss, ) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -117,8 +118,8 @@ def test_multilabel_ranking_differentiability(self, inputs, metric, functional_m def test_multilabel_ranking_dtype_cpu(self, inputs, metric, functional_metric, ref_metric, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") if dtype == torch.half and functional_metric == multilabel_ranking_average_precision: pytest.xfail( reason="multilabel_ranking_average_precision requires torch.unique which is not implemented for half" diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index 2d73d64f264..5bbf2e55e58 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -32,6 +32,7 @@ multilabel_recall_at_fixed_precision, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -129,8 +130,8 @@ def test_binary_recall_at_fixed_precision_differentiability(self, inputs): def test_binary_recall_at_fixed_precision_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index f6cbd173128..5ad6dee35fa 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -23,6 +23,7 @@ from torchmetrics.classification.roc 
import ROC, BinaryROC, MulticlassROC, MultilabelROC from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -98,8 +99,8 @@ def test_binary_roc_differentiability(self, inputs): def test_binary_roc_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 47884bae2a3..cc85f6e4e28 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -33,7 +33,7 @@ multilabel_sensitivity_at_specificity, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3 +from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3, _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -153,8 +153,8 @@ def test_binary_sensitivity_at_specificity_differentiability(self, inputs): def test_binary_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index fe5fd8977a8..437d9e07af9 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -31,6 +31,7 @@ multilabel_specificity, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -151,8 +152,8 @@ def test_binary_specificity_differentiability(self, inputs): def test_binary_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -328,8 +329,8 @@ def test_multiclass_specificity_differentiability(self, inputs): def test_multiclass_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - 
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -532,8 +533,8 @@ def test_multilabel_specificity_differentiability(self, inputs): def test_multilabel_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 934d669678a..9e866dbabd9 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -33,6 +33,7 @@ multilabel_specificity_at_sensitivity, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -148,8 +149,8 @@ def test_binary_specificity_at_sensitivity_differentiability(self, inputs): def test_binary_specificity_at_sensitivity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 2a5a53bb8aa..fee079011be 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -31,6 +31,7 @@ multilabel_stat_scores, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -135,8 +136,8 @@ def test_binary_stat_scores_differentiability(self, inputs): def test_binary_stat_scores_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -300,8 +301,8 @@ def test_multiclass_stat_scores_differentiability(self, inputs): def test_multiclass_stat_scores_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not 
From 1d57e10cf6d0d93daddc8b89216783495f7fb39f Mon Sep 17 00:00:00 2001
From: zhaozheng <976525070@qq.com>
Date: Wed, 11 Dec 2024 15:18:21 +0800
Subject: [PATCH 11/15] fixed not defined

fixed not defined
---
 src/torchmetrics/utilities/compute.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index c0044276a77..7007b51fff7 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -216,9 +216,9 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid",
     """
     # decrease sigmoid on cpu .
     if tensor.device == torch.device("cpu"):
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            preds = preds.sigmoid()
-        return preds
+        if not torch.all((tensor >= 0) * (tensor <= 1)):
+            tensor = tensor.sigmoid()
+        return tensor

     # decrease device-host sync on device .
     condition = ((tensor < 0) | (tensor > 1)).any()
From 2ead3cc46002ea7449439637f8bf1b1770c7ee5c Mon Sep 17 00:00:00 2001
From: zhaozheng <976525070@qq.com>
Date: Thu, 12 Dec 2024 07:13:50 +0800
Subject: [PATCH 12/15] add softmax support

add softmax support
---
 src/torchmetrics/utilities/compute.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index 7007b51fff7..6e30739ffba 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -217,7 +217,7 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid",
     # decrease sigmoid on cpu .
     if tensor.device == torch.device("cpu"):
         if not torch.all((tensor >= 0) * (tensor <= 1)):
-            tensor = tensor.sigmoid()
+            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
         return tensor

     # decrease device-host sync on device .
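The replacement line above ends with a stray comma, which turns the conditional expression into a one-element tuple instead of a tensor; the two follow-up patches below first reformat it and then remove it. A self-contained reproduction of the pitfall:

    import torch

    x = torch.tensor([2.0, -3.0])
    y = x.sigmoid() if True else torch.softmax(x, dim=0),
    print(type(y))  # <class 'tuple'> -- the trailing comma builds a 1-tuple
    y = x.sigmoid() if True else torch.softmax(x, dim=0)
    print(type(y))  # <class 'torch.Tensor'>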
From 953a98d2d8ba7a31f01d3a0f5099a574242dc0f0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 11 Dec 2024 23:14:46 +0000
Subject: [PATCH 13/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/torchmetrics/utilities/compute.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index 6e30739ffba..f096fd684f5 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -217,7 +217,7 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid",
     # decrease sigmoid on cpu .
     if tensor.device == torch.device("cpu"):
         if not torch.all((tensor >= 0) * (tensor <= 1)):
-            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
+            tensor = (tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1),)
         return tensor

     # decrease device-host sync on device .
From 2b7ae748df7e201bf3f9bc5fb0470c3598c41836 Mon Sep 17 00:00:00 2001
From: zhaozheng <976525070@qq.com>
Date: Thu, 12 Dec 2024 10:23:46 +0800
Subject: [PATCH 14/15] fix tuple bug.

fix tuple bug.
---
 src/torchmetrics/utilities/compute.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index f096fd684f5..5a46993d86b 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -217,7 +217,7 @@ def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid",
     # decrease sigmoid on cpu .
     if tensor.device == torch.device("cpu"):
         if not torch.all((tensor >= 0) * (tensor <= 1)):
-            tensor = (tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1),)
+            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
         return tensor

     # decrease device-host sync on device .
From c945a3729df781532172d9b5508de4ae9ccddb12 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Date: Tue, 17 Dec 2024 20:52:02 +0900
Subject: [PATCH 15/15] Apply suggestions from code review

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cdd9e6b8e2c..62d7ffc4c5a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

 - Delete `Device2Host` caused by comm with device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840))
+
+- Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))
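Taken together, patches 10-14 leave the helper in roughly the following shape. This is a sketch reconstructed from the hunks above: the Literal annotation is abbreviated, and the value arguments of the device-side torch.where fall outside the visible diff context, so they are an assumption that mirrors the CPU branch:

    import torch
    from torch import Tensor

    def normalize_logits_if_needed(tensor: Tensor, normalization: str) -> Tensor:
        """Sketch of the net result of this series (signature simplified)."""
        # cpu: branching on the host is free here, and it skips the
        # sigmoid/softmax when values are already in [0, 1]
        if tensor.device == torch.device("cpu"):
            if not torch.all((tensor >= 0) * (tensor <= 1)):
                tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
            return tensor

        # accelerator: keep the bounds check on device so no
        # device-to-host copy stalls the stream (arms below are assumed)
        condition = ((tensor < 0) | (tensor > 1)).any()
        return torch.where(
            condition,
            tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
            tensor,
        )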