Delete Device2Host caused by comm with device and host #2840

Merged: 25 commits, Dec 21, 2024
Changes from 5 commits
Commits (25):
afd1fb8  async host/device .  (Nov 22, 2024)
89a616a  unittest .  (Nov 24, 2024)
9afd681  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Nov 24, 2024)
377f77d  update unittest .  (Nov 26, 2024)
92ef211  Merge branch 'master' into async_cpu_gpu_20241122  (Borda, Nov 26, 2024)
5146a6d  chlog  (Borda, Nov 27, 2024)
5e25b28  revert test file  (SkafteNicki, Dec 3, 2024)
08d9ece  general conditional compute function  (SkafteNicki, Dec 3, 2024)
ae09c7d  fix code multiple locations using new function  (SkafteNicki, Dec 3, 2024)
9c82719  make cpu + sigmoid skipping less restrictive for classification  (SkafteNicki, Dec 3, 2024)
a07363d  Merge branch 'master' into async_cpu_gpu_20241122  (SkafteNicki, Dec 3, 2024)
eb41cee  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 11, 2024)
49f0a9a  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 11, 2024)
1f1833e  compat cpu and device  (zhaozheng09, Dec 11, 2024)
1d57e10  fixed not defined  (zhaozheng09, Dec 11, 2024)
70faf9d  Merge branch 'master' into async_cpu_gpu_20241122  (zhaozheng09, Dec 11, 2024)
2ead3cc  add softmax support  (zhaozheng09, Dec 11, 2024)
953a98d  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot], Dec 11, 2024)
2b7ae74  fix tuple bug.  (zhaozheng09, Dec 12, 2024)
e108721  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
fd1c189  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
d85f94c  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
c945a37  Apply suggestions from code review  (Borda, Dec 17, 2024)
1cdba8f  Merge branch 'master' into async_cpu_gpu_20241122  (mergify[bot], Dec 17, 2024)
6ced9ed  Merge branch 'master' into async_cpu_gpu_20241122  (zhaozheng09, Dec 20, 2024)
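The diff below replaces a data-dependent Python `if` on the predictions with an on-device select. For context, here is a minimal sketch of the Device2Host round trip the PR title refers to (hypothetical tensor values, not code from this PR): reading the truth value of a reduced CUDA tensor inside `if` copies the 0-dim result back to the host and synchronizes the stream, whereas keeping the condition as a tensor leaves the kernels queued asynchronously.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.randn(1_000_000, device=device)  # hypothetical raw scores

# Old pattern: `if` needs a Python bool, so the 0-dim result of torch.all()
# is copied from the device to the host, forcing a stream synchronization.
probs = preds
if not torch.all((preds >= 0) * (preds <= 1)):
    probs = preds.sigmoid()

# Pattern used by this PR: the condition stays a 0-dim bool tensor,
# torch.where broadcasts it, and no host read-back is needed.
out_of_bounds = ((preds < 0) | (preds > 1)).any()
probs = torch.where(out_of_bounds, preds.sigmoid(), preds)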
@@ -182,8 +182,14 @@ def _binary_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]

-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+    # "sigmoid_cpu" not implemented for 'Half'
+    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
+        out_of_bounds = (preds < 0) | (preds > 1)
+        out_of_bounds = out_of_bounds.any()
+        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    else:
+        if not torch.all((preds >= 0) * (preds <= 1)):
+            preds = preds.sigmoid()

     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
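A quick sanity check (hypothetical values, not part of the PR) that the 0-dim boolean returned by .any() broadcasts inside torch.where, so preds is either passed through unchanged or replaced element-wise by its sigmoid without ever calling .item(). The trade-off is that preds.sigmoid() is always evaluated; the extra compute is exchanged for avoiding the host synchronization.

import torch

preds = torch.tensor([-2.0, 0.5, 3.0])            # contains values outside [0, 1]
flag = ((preds < 0) | (preds > 1)).any()          # tensor(True), never leaves the device
print(torch.where(flag, preds.sigmoid(), preds))  # tensor([0.1192, 0.6225, 0.9526])

probs = torch.tensor([0.1, 0.5, 0.9])             # already valid probabilities
flag = ((probs < 0) | (probs > 1)).any()          # tensor(False)
print(torch.where(flag, probs.sigmoid(), probs))  # unchanged: tensor([0.1000, 0.5000, 0.9000])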
@@ -761,8 +767,15 @@ def _multilabel_precision_recall_curve_format(
     """
     preds = preds.transpose(0, 1).reshape(num_labels, -1).T
     target = target.transpose(0, 1).reshape(num_labels, -1).T
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+
+    # "sigmoid_cpu" not implemented for 'Half'
+    if preds.dtype != torch.float16 or preds.device != torch.device("cpu"):
+        out_of_bounds = (preds < 0) | (preds > 1)
+        out_of_bounds = out_of_bounds.any()
+        preds = torch.where(out_of_bounds, preds.sigmoid(), preds)
+    else:
+        if not torch.all((preds >= 0) * (preds <= 1)):
+            preds = preds.sigmoid()

     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     if ignore_index is not None and thresholds is not None:
30 changes: 21 additions & 9 deletions tests/unittests/classification/test_precision_recall_curve.py
@@ -106,15 +106,27 @@ def test_binary_precision_recall_curve_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
         if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
-        self.run_precision_test_cpu(
-            preds=preds,
-            target=target,
-            metric_module=BinaryPrecisionRecallCurve,
-            metric_functional=binary_precision_recall_curve,
-            metric_args={"thresholds": None},
-            dtype=dtype,
-        )
+            try:
+                self.run_precision_test_cpu(
+                    preds=preds,
+                    target=target,
+                    metric_module=BinaryPrecisionRecallCurve,
+                    metric_functional=binary_precision_recall_curve,
+                    metric_args={"thresholds": None},
+                    dtype=dtype,
+                )
+            except Exception as e:
+                print(f"An unexpected error occurred: {e}")
+                pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        else:
+            self.run_precision_test_cpu(
+                preds=preds,
+                target=target,
+                metric_module=BinaryPrecisionRecallCurve,
+                metric_functional=binary_precision_recall_curve,
+                metric_args={"thresholds": None},
+                dtype=dtype,
+            )

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
     @pytest.mark.parametrize("dtype", [torch.half, torch.double])
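The test now attempts the real computation and only downgrades the cpu + half limitation to an expected failure when it actually raises, instead of calling pytest.xfail up front. A minimal standalone sketch of that imperative xfail pattern (hypothetical test, not taken from the PR):

import pytest
import torch


def test_sigmoid_half_cpu():
    # Run the real operation; if this PyTorch build has no CPU half-precision
    # sigmoid kernel, report an expected failure instead of an error.
    try:
        torch.tensor([0.5], dtype=torch.half).sigmoid()
    except RuntimeError as err:
        pytest.xfail(reason=f"cpu + half sigmoid unsupported: {err}")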