From 26391c17acdd21ff011a1d58009f4926b73cf0ab Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Fri, 4 Aug 2023 12:37:06 -0700
Subject: [PATCH] Fix TPU collective ops test for multi-host TPUs (#5408)

* Fix TPU collective ops test for multi-host TPUs

* formatting
---
 test/pjrt/test_collective_ops_tpu.py | 33 +++++++++++++++++++++-------
 torch_xla/_internal/tpu.py           | 14 ++++++++++++
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 85c70a8ee586..f1752901661e 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -6,10 +6,6 @@
 from torch_xla._internal import pjrt, tpu
 
 
-def _is_single_host():
-  return len(tpu.get_worker_ips())
-
-
 class TestCollectiveOpsTpu(parameterized.TestCase):
 
   @staticmethod
@@ -23,7 +19,8 @@ def _broadcast(sync):
     xm.mark_step()
     return next(model.parameters()).detach().cpu().numpy()
 
-  @absltest.skipUnless(_is_single_host, "Only implemented for single host.")
+  @absltest.skipUnless(tpu.num_tpu_workers() == 1,
+                       "Only implemented for single host.")
   @parameterized.named_parameters(('synchronized_parameters', True),
                                   ('unsynchronized_parameters', False))
   def test_broadcast_master_param(self, sync):
@@ -36,6 +33,25 @@ def test_broadcast_master_param(self, sync):
     np.testing.assert_raises(AssertionError, np.testing.assert_array_equal,
                              master_params, worker_params)
 
+  @staticmethod
+  def _all_reduce(pin_layout):
+    device = xm.xla_device()
+    # Prevent 0 and 1 from being converted to constants
+    ordinal = xm.send_cpu_data_to_device(
+        torch.tensor(xm.get_ordinal()), device=device)
+    out = xm.all_reduce(xm.REDUCE_SUM, ordinal, pin_layout=pin_layout)[0]
+    xm.mark_step()
+
+    return out.cpu().numpy()
+
+  @parameterized.named_parameters(('pinned', True), ('unpinned', False))
+  def test_all_reduce(self, pin_layout):
+    results = pjrt.run_multiprocess(self._all_reduce, pin_layout)
+
+    expected = sum(range(tpu.num_expected_global_devices()))
+    for v in results.values():
+      np.testing.assert_array_equal(v, expected)
+
   @staticmethod
   def _all_gather(pin_layout):
     device = xm.xla_device()
@@ -49,7 +65,7 @@ def _all_gather(pin_layout):
   def test_all_gather(self, pin_layout):
     results = pjrt.run_multiprocess(self._all_gather, pin_layout)
 
-    expected = list(range(len(results)))
+    expected = list(range(tpu.num_expected_global_devices()))
     for v in results.values():
       np.testing.assert_array_equal(v, expected)
 
@@ -106,9 +122,10 @@ def _all_to_all(pin_layout):
   def test_all_to_all(self, pin_layout):
     results = pjrt.run_multiprocess(self._all_to_all, pin_layout)
 
+    world_size = tpu.num_expected_global_devices()
     for ordinal, value in results.items():
-      np.testing.assert_array_equal(value, [[[-ordinal] * len(results),
-                                             list(range(len(results)))]])
+      np.testing.assert_array_equal(value, [[[-ordinal] * world_size,
+                                             list(range(world_size))]])
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py
index d13836532f6b..9e4ceaee26b0 100644
--- a/torch_xla/_internal/tpu.py
+++ b/torch_xla/_internal/tpu.py
@@ -119,6 +119,15 @@ def num_available_devices() -> int:
   return num_available_chips() * num_logical_cores_per_chip()
 
 
+def num_expected_global_devices() -> int:
+  """Returns the number of expected runtime devices in this TPU slice.
+
+  May differ from the actual number of runtime devices if TPU topology settings
+  are changed.
+  """
+  return num_available_devices() * num_tpu_workers()
+
+
 def num_local_processes() -> int:
   """Returns number of processes to create on this host."""
   local_chips = num_available_chips()
@@ -188,6 +197,11 @@ def get_worker_ips() -> List[str]:
   return hostnames if len(hostnames) > 1 else ['localhost']
 
 
+def num_tpu_workers() -> int:
+  """Returns the number of configured TPU workers."""
+  return len(get_worker_ips())
+
+
 def configure_one_chip_topology() -> None:
   """Configures TPU topology environment variables for one process and chip.
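
Note (illustrative sketch, not part of the patch): the test expectations above follow from num_expected_global_devices() = num_available_devices() * num_tpu_workers(). The snippet below walks through that arithmetic with a hypothetical slice of 2 workers and 4 local devices per worker; the slice shape is an assumption chosen only for the example.

# Hypothetical slice shape, for illustration only.
assumed_workers = 2        # e.g. number of TPU hosts in the slice
assumed_local_devices = 4  # e.g. devices visible to each host

# Mirrors num_expected_global_devices() = num_available_devices() * num_tpu_workers().
world_size = assumed_local_devices * assumed_workers  # 8 global devices

# test_all_reduce: SUM over every device's ordinal 0..world_size-1.
assert sum(range(world_size)) == world_size * (world_size - 1) // 2  # 28 here

# test_all_gather: every device ends up with the full ordinal list, in order.
assert list(range(world_size)) == [0, 1, 2, 3, 4, 5, 6, 7]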