From 61c519596e18b7fbfa3240b4dbebd10041e10aae Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 24 Mar 2024 17:06:18 +0100 Subject: [PATCH 1/4] implementation + tests --- src/torchmetrics/metric.py | 4 ++++ tests/unittests/bases/test_ddp.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index c5e6999e89e..9636ad7635c 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -432,6 +432,10 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1: input_dict[attr] = [dim_zero_cat(input_dict[attr])] + # cornor case in distributed settings where a rank have not received any data, create empty to concatenate + if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) == 0: + input_dict[attr] = [torch.tensor([], device=self.device, dtype=self.dtype)] + output_dict = apply_to_collection( input_dict, Tensor, diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index 1a44a2145ba..ed11ede2f7f 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -277,3 +277,18 @@ def _test_sync_with_empty_lists(rank): def test_sync_with_empty_lists(): """Test that synchronization of states can be enabled and disabled for compute.""" pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES)) + + +def _test_sync_with_unequal_size_lists(rank): + """Test that synchronization of list states work even when some ranks have not received any data yet.""" + dummy = DummyListMetric() + if rank == 0: + dummy.update(torch.zeros(2)) + assert dummy.compute() == tensor([0.0, 0.0]) + + +@pytest.mark.DDP() +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_sync_with_unequal_size_lists(): + """Test that synchronization of states can be enabled and disabled for compute.""" + pytest.pool.map(_test_sync_with_unequal_size_lists, range(NUM_PROCESSES)) From fbdbd3bb1bdf1d9afe4e88f1039e389cf961239e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 24 Mar 2024 17:08:49 +0100 Subject: [PATCH 2/4] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a9a7c49d0e..74f43ca972d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) +- Fixed list synchronization with partly empty lists ([#2468](https://github.com/Lightning-AI/torchmetrics/pull/2468)) + + ## [1.3.2] - 2024-03-18 ### Fixed From 1bfef7ea05b262377f2af2378fe9824dd3244c8f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 14 Apr 2024 14:24:41 +0200 Subject: [PATCH 3/4] fix tests --- tests/unittests/bases/test_ddp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index ed11ede2f7f..fc45e07ff5d 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -269,7 +269,7 @@ def test_sync_on_compute(sync_on_compute, test_func): def _test_sync_with_empty_lists(rank): dummy = DummyListMetric() val = dummy.compute() - assert val == [] + assert torch.allclose(val, tensor([])) @pytest.mark.DDP() @@ -284,7 +284,7 @@ def _test_sync_with_unequal_size_lists(rank): dummy = DummyListMetric() if rank == 0: dummy.update(torch.zeros(2)) - assert dummy.compute() == tensor([0.0, 0.0]) + assert torch.all(dummy.compute() == tensor([0.0, 0.0])) @pytest.mark.DDP() From 92c150728fcaf51f24bbf802e09bf9d48913b996 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 19 Apr 2024 14:53:06 +0200 Subject: [PATCH 4/4] only for newer versions --- src/torchmetrics/metric.py | 8 +++++++- tests/unittests/bases/test_ddp.py | 3 +++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index a917dc21658..a210995eeaf 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -38,6 +38,7 @@ ) from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val from torchmetrics.utilities.prints import rank_zero_warn @@ -439,7 +440,12 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: input_dict[attr] = [dim_zero_cat(input_dict[attr])] # cornor case in distributed settings where a rank have not received any data, create empty to concatenate - if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) == 0: + if ( + _TORCH_GREATER_EQUAL_2_1 + and reduction_fn == dim_zero_cat + and isinstance(input_dict[attr], list) + and len(input_dict[attr]) == 0 + ): input_dict[attr] = [torch.tensor([], device=self.device, dtype=self.dtype)] output_dict = apply_to_collection( diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index fc45e07ff5d..e927816e203 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -22,6 +22,7 @@ from torchmetrics import Metric from torchmetrics.utilities.distributed import gather_all_tensors from torchmetrics.utilities.exceptions import TorchMetricsUserError +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_PROCESSES from unittests._helpers import seed_all @@ -273,6 +274,7 @@ def _test_sync_with_empty_lists(rank): @pytest.mark.DDP() +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_sync_with_empty_lists(): """Test that synchronization of states can be enabled and disabled for compute.""" @@ -288,6 +290,7 @@ def _test_sync_with_unequal_size_lists(rank): @pytest.mark.DDP() +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") def test_sync_with_unequal_size_lists(): """Test that synchronization of states can be enabled and disabled for compute."""