
Delete Device2Host caused by comm with device and host #2840

Merged: 25 commits from async_cpu_gpu_20241122 into master, Dec 21, 2024

Commits (all changes shown):
afd1fb8  async host/device .  (Nov 22, 2024)
89a616a  unittest .  (Nov 24, 2024)
9afd681  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Nov 24, 2024)
377f77d  update unittest .  (Nov 26, 2024)
92ef211  Merge branch 'master' into async_cpu_gpu_20241122  (Borda, Nov 26, 2024)
5146a6d  chlog  (Borda, Nov 27, 2024)
5e25b28  revert test file  (SkafteNicki, Dec 3, 2024)
08d9ece  general conditional compute function  (SkafteNicki, Dec 3, 2024)
ae09c7d  fix code multiple locations using new function  (SkafteNicki, Dec 3, 2024)
9c82719  make cpu + sigmoid skipping less restrictive for classification  (SkafteNicki, Dec 3, 2024)
a07363d  Merge branch 'master' into async_cpu_gpu_20241122  (SkafteNicki, Dec 3, 2024)
eb41cee  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 11, 2024)
49f0a9a  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 11, 2024)
1f1833e  compat cpu and device  (zhaozheng09, Dec 11, 2024)
1d57e10  fixed not defined  (zhaozheng09, Dec 11, 2024)
70faf9d  Merge branch 'master' into async_cpu_gpu_20241122  (zhaozheng09, Dec 11, 2024)
2ead3cc  add softmax support  (zhaozheng09, Dec 11, 2024)
953a98d  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 11, 2024)
2b7ae74  fix tuple bug.  (zhaozheng09, Dec 12, 2024)
e108721  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
fd1c189  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
d85f94c  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
c945a37  Apply suggestions from code review  (Borda, Dec 17, 2024)
1cdba8f  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
6ced9ed  Merge branch 'master' into async_cpu_gpu_20241122  (zhaozheng09, Dec 20, 2024)
CHANGELOG.md  (3 changes: 3 additions & 0 deletions)

@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

+- Delete `Device2Host` caused by comm with device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840))
+
+
 - Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))
src/torchmetrics/functional/classification/calibration_error.py

@@ -23,6 +23,7 @@
     _multiclass_confusion_matrix_format,
     _multiclass_confusion_matrix_tensor_validation,
 )
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel

@@ -239,8 +240,7 @@ def _multiclass_calibration_error_update(
     preds: Tensor,
     target: Tensor,
 ) -> tuple[Tensor, Tensor]:
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
+    preds = normalize_logits_if_needed(preds, "softmax")
     confidences, predictions = preds.max(dim=1)
     accuracies = predictions.eq(target)
     return confidences.float(), accuracies.float()
src/torchmetrics/functional/classification/confusion_matrix.py

@@ -18,6 +18,7 @@
 from typing_extensions import Literal

 from torchmetrics.utilities.checks import _check_same_shape
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.prints import rank_zero_warn

@@ -137,9 +138,7 @@ def _binary_confusion_matrix_format(
         target = target[idx]

     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            # preds is logits, convert with sigmoid
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         if convert_to_labels:
             preds = preds > threshold

@@ -491,8 +490,7 @@ def _multilabel_confusion_matrix_format(

     """
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         if should_threshold:
             preds = preds > threshold
     preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels)
src/torchmetrics/functional/classification/hinge.py  (5 changes: 2 additions & 3 deletions)

@@ -23,6 +23,7 @@
     _multiclass_confusion_matrix_format,
     _multiclass_confusion_matrix_tensor_validation,
 )
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import to_onehot
 from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel

@@ -153,9 +154,7 @@ def _multiclass_hinge_loss_update(
     squared: bool,
     multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
 ) -> tuple[Tensor, Tensor]:
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
-
+    preds = normalize_logits_if_needed(preds, "softmax")
     target = to_onehot(target, max(2, preds.shape[1])).bool()
     if multiclass_mode == "crammer-singer":
         margin = preds[target]
src/torchmetrics/functional/classification/precision_recall_curve.py

@@ -21,7 +21,7 @@
 from typing_extensions import Literal

 from torchmetrics.utilities.checks import _check_same_shape
-from torchmetrics.utilities.compute import _safe_divide, interp
+from torchmetrics.utilities.compute import _safe_divide, interp, normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount, _cumsum
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.prints import rank_zero_warn

@@ -182,8 +182,7 @@ def _binary_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]

-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+    preds = normalize_logits_if_needed(preds, "sigmoid")

     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds

@@ -452,8 +451,7 @@ def _multiclass_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]

-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
+    preds = normalize_logits_if_needed(preds, "softmax")

     if average == "micro":
         preds = preds.flatten()

@@ -761,8 +759,8 @@ def _multilabel_precision_recall_curve_format(

     """
     preds = preds.transpose(0, 1).reshape(num_labels, -1).T
     target = target.transpose(0, 1).reshape(num_labels, -1).T
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+
+    preds = normalize_logits_if_needed(preds, "sigmoid")

     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     if ignore_index is not None and thresholds is not None:
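Every formatting helper above now funnels through the same utility, so the public functional interfaces keep accepting raw logits exactly as before; only the internal range check changed. A minimal usage sketch (shapes and values are illustrative, not taken from the PR):

```python
import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

preds = torch.randn(100)              # raw logits, clearly outside [0, 1]
target = torch.randint(0, 2, (100,))

# logits are detected and passed through sigmoid internally; on GPU this now
# happens without forcing the host to wait on the device
precision, recall, thresholds = binary_precision_recall_curve(preds, target, thresholds=11)
```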
src/torchmetrics/functional/classification/stat_scores.py  (8 changes: 3 additions & 5 deletions)

@@ -18,6 +18,7 @@
 from typing_extensions import Literal

 from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount, select_topk
 from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, MDMCAverageMethod

@@ -105,9 +106,7 @@ def _binary_stat_scores_format(

     """
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            # preds is logits, convert with sigmoid
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         preds = preds > threshold

     preds = preds.reshape(preds.shape[0], -1)

@@ -659,8 +658,7 @@ def _multilabel_stat_scores_format(

     """
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         preds = preds > threshold
     preds = preds.reshape(*preds.shape[:2], -1)
     target = target.reshape(*target.shape[:2], -1)
src/torchmetrics/utilities/compute.py  (43 changes: 43 additions & 0 deletions)

@@ -15,6 +15,7 @@

 import torch
 from torch import Tensor
+from typing_extensions import Literal


 def _safe_matmul(x: Tensor, y: Tensor) -> Tensor:

@@ -184,3 +185,45 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
     indices = torch.clamp(indices, 0, len(m) - 1)

     return m[indices] * x + b[indices]
+
+
+def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid", "softmax"]) -> Tensor:
+    """Normalize logits if needed.
+
+    If the input tensor has values outside [0, 1], we assume logits were provided and apply the normalization.
+    ``torch.where`` is used so that no device-to-host sync is triggered.
+
+    Args:
+        tensor: input tensor that may contain logits or probabilities
+        normalization: normalization method, either ``'sigmoid'`` or ``'softmax'``
+
+    Returns:
+        the normalized tensor, or the input tensor unchanged if no normalization was needed
+
+    Example:
+        >>> import torch
+        >>> tensor = torch.tensor([-1.0, 0.0, 1.0])
+        >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
+        tensor([0.2689, 0.5000, 0.7311])
+        >>> tensor = torch.tensor([[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]])
+        >>> normalize_logits_if_needed(tensor, normalization="softmax")
+        tensor([[0.0900, 0.2447, 0.6652],
+                [0.6652, 0.2447, 0.0900]])
+        >>> tensor = torch.tensor([0.0, 0.5, 1.0])
+        >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
+        tensor([0.0000, 0.5000, 1.0000])
+
+    """
+    # on cpu the data-dependent branch is cheap, so keep it and skip the op when values are already in [0, 1]
+    if tensor.device == torch.device("cpu"):
+        if not torch.all((tensor >= 0) * (tensor <= 1)):
+            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
+        return tensor
+
+    # on accelerators, keep the condition on the device and let torch.where select, avoiding a blocking sync
+    condition = ((tensor < 0) | (tensor > 1)).any()
+    return torch.where(
+        condition,
+        torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
+        tensor,
+    )
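This function is the heart of the PR. On CPU a data-dependent Python `if` is cheap, but on an accelerator evaluating `torch.all(...)` inside an `if` forces a blocking device-to-host copy (the `Device2Host` transfer the PR title refers to), because Python needs the concrete boolean before it can branch. `torch.where` instead keeps the condition as a 0-dim tensor on the device. A standalone sketch of the two patterns (the `cuda` device and variable names are illustrative assumptions):

```python
import torch

logits = torch.randn(1024, 10, device="cuda")

# Pattern removed by this PR: turning a tensor condition into a Python bool
# triggers a blocking device-to-host transfer before the branch can be taken.
if not torch.all((logits >= 0) * (logits <= 1)):  # implicit sync point
    probs = logits.softmax(dim=1)

# Pattern introduced by this PR: the 0-dim bool condition stays on device and
# torch.where broadcasts it over both candidates, so the host never waits.
needs_norm = ((logits < 0) | (logits > 1)).any()
probs = torch.where(needs_norm, logits.softmax(dim=1), logits)
```

The trade-off is that the device path always materializes both candidates, i.e. the softmax is computed even when the input already holds probabilities; the CPU branch keeps the old short-circuit because there is no sync to avoid there.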
tests/unittests/classification/test_accuracy.py  (13 changes: 7 additions & 6 deletions)

@@ -27,6 +27,7 @@
     multilabel_accuracy,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all

@@ -153,8 +154,8 @@ def test_binary_accuracy_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,

@@ -310,8 +311,8 @@ def test_multiclass_accuracy_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,

@@ -585,8 +586,8 @@ def test_multilabel_accuracy_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
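The test changes above and in the files that follow all share one pattern: the `xfail` is now gated on the installed torch version, since the limitation it guards against, `torch.sigmoid` on CPU half-precision tensors, only applies to torch<2.1 according to the updated reason string. A hedged repro of that limitation (the exact error text is approximate):

```python
import torch

x = torch.randn(4, dtype=torch.half)  # half-precision tensor on CPU

# On torch<2.1 this raised roughly:
#   RuntimeError: "sigmoid_cpu" not implemented for 'Half'
# On torch>=2.1 it returns a half tensor, so the xfail no longer applies.
y = torch.sigmoid(x)
```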
tests/unittests/classification/test_auroc.py  (5 changes: 3 additions & 2 deletions)

@@ -24,6 +24,7 @@
 from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc
 from torchmetrics.functional.classification.roc import binary_roc
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all

@@ -102,8 +103,8 @@ def test_binary_auroc_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
tests/unittests/classification/test_average_precision.py  (5 changes: 3 additions & 2 deletions)

@@ -33,6 +33,7 @@
 )
 from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all

@@ -106,8 +107,8 @@ def test_binary_average_precision_differentiability(self, inputs):
     def test_binary_average_precision_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
tests/unittests/classification/test_calibration_error.py  (5 changes: 3 additions & 2 deletions)

@@ -29,6 +29,7 @@
     multiclass_calibration_error,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all

@@ -112,8 +113,8 @@ def test_binary_calibration_error_differentiability(self, inputs):
     def test_binary_calibration_error_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
tests/unittests/classification/test_cohen_kappa.py  (9 changes: 5 additions & 4 deletions)

@@ -21,6 +21,7 @@
 from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa
 from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all

@@ -103,8 +104,8 @@ def test_binary_cohen_kappa_dtypes_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,

@@ -206,8 +207,8 @@ def test_multiclass_cohen_kappa_dtypes_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
tests/unittests/classification/test_confusion_matrix.py  (9 changes: 5 additions & 4 deletions)

@@ -30,6 +30,7 @@
     multilabel_confusion_matrix,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all

@@ -114,8 +115,8 @@ def test_binary_confusion_matrix_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,

@@ -367,8 +368,8 @@ def test_multilabel_confusion_matrix_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
tests/unittests/classification/test_exact_match.py  (9 changes: 5 additions & 4 deletions)

@@ -20,6 +20,7 @@
 from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch
 from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all

@@ -121,8 +122,8 @@ def test_multiclass_exact_match_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,

@@ -250,8 +251,8 @@ def test_multilabel_exact_match_half_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs

-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,