Fix TPU collective ops test for multi-host TPUs (#5408)
* Fix TPU collective ops test for multi-host TPUs

* formatting
will-cromar committed Sep 14, 2023
1 parent 2b2251f commit 26391c1
Showing 2 changed files with 39 additions and 8 deletions.
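Context for the test-side fix, as visible in the diff below: the old guard passed the `_is_single_host` function object to `skipUnless` without calling it (and the helper returned a count rather than a boolean), so the condition was always truthy and the test was never skipped on multi-host slices. A minimal standalone sketch of that failure mode, using `unittest.skipUnless` as a stand-in for `absltest.skipUnless` and a hypothetical two-worker IP list:

```python
import unittest


def _is_single_host():
  # Old helper: returns a count, which is always >= 1 and therefore truthy.
  return len(['10.0.0.1', '10.0.0.2'])  # hypothetical two-worker slice


class SkipDemo(unittest.TestCase):

  # Bug: the function object itself (not its return value) is the condition,
  # and any function object is truthy, so this test is never skipped.
  @unittest.skipUnless(_is_single_host, 'Only implemented for single host.')
  def test_never_skipped(self):
    pass

  # Shape of the fix: evaluate a real boolean at decoration time.
  @unittest.skipUnless(
      len(['10.0.0.1', '10.0.0.2']) == 1, 'Only implemented for single host.')
  def test_skipped_on_multi_host(self):
    pass


if __name__ == '__main__':
  unittest.main()
```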
33 changes: 25 additions & 8 deletions test/pjrt/test_collective_ops_tpu.py
@@ -6,10 +6,6 @@
 from torch_xla._internal import pjrt, tpu
 
 
-def _is_single_host():
-  return len(tpu.get_worker_ips())
-
-
 class TestCollectiveOpsTpu(parameterized.TestCase):
 
   @staticmethod
@@ -23,7 +19,8 @@ def _broadcast(sync):
     xm.mark_step()
     return next(model.parameters()).detach().cpu().numpy()
 
-  @absltest.skipUnless(_is_single_host, "Only implemented for single host.")
+  @absltest.skipUnless(tpu.num_tpu_workers() == 1,
+                       "Only implemented for single host.")
   @parameterized.named_parameters(('synchronized_parameters', True),
                                   ('unsynchronized_parameters', False))
   def test_broadcast_master_param(self, sync):
@@ -36,6 +33,25 @@ def test_broadcast_master_param(self, sync):
         np.testing.assert_raises(AssertionError, np.testing.assert_array_equal,
                                  master_params, worker_params)
 
+  @staticmethod
+  def _all_reduce(pin_layout):
+    device = xm.xla_device()
+    # Prevent 0 and 1 from being converted to constants
+    ordinal = xm.send_cpu_data_to_device(
+        torch.tensor(xm.get_ordinal()), device=device)
+    out = xm.all_reduce(xm.REDUCE_SUM, ordinal, pin_layout=pin_layout)[0]
+    xm.mark_step()
+
+    return out.cpu().numpy()
+
+  @parameterized.named_parameters(('pinned', True), ('unpinned', False))
+  def test_all_reduce(self, pin_layout):
+    results = pjrt.run_multiprocess(self._all_reduce, pin_layout)
+
+    expected = sum(range(tpu.num_expected_global_devices()))
+    for v in results.values():
+      np.testing.assert_array_equal(v, expected)
+
   @staticmethod
   def _all_gather(pin_layout):
     device = xm.xla_device()
@@ -49,7 +65,7 @@ def _all_gather(pin_layout):
   def test_all_gather(self, pin_layout):
     results = pjrt.run_multiprocess(self._all_gather, pin_layout)
 
-    expected = list(range(len(results)))
+    expected = list(range(tpu.num_expected_global_devices()))
     for v in results.values():
       np.testing.assert_array_equal(v, expected)
 
@@ -106,9 +122,10 @@ def _all_to_all(pin_layout):
   def test_all_to_all(self, pin_layout):
     results = pjrt.run_multiprocess(self._all_to_all, pin_layout)
 
+    world_size = tpu.num_expected_global_devices()
     for ordinal, value in results.items():
-      np.testing.assert_array_equal(value, [[[-ordinal] * len(results),
-                                             list(range(len(results)))]])
+      np.testing.assert_array_equal(value, [[[-ordinal] * world_size,
+                                             list(range(world_size))]])
 
 
 if __name__ == '__main__':
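For reference, a standalone sketch (not part of the change) of the arithmetic behind the new `test_all_reduce` check: each of the N global devices contributes its ordinal to the all-reduce sum, so every device should see 0 + 1 + ... + (N - 1) = N(N - 1)/2. The device counts below are illustrative assumptions, not values from the diff.

```python
def expected_all_reduce_sum(num_devices: int) -> int:
  # Sum of ordinals 0..N-1, i.e. N * (N - 1) / 2.
  return sum(range(num_devices))


assert expected_all_reduce_sum(8) == 28  # e.g. a single host with 8 devices
assert expected_all_reduce_sum(32) == 496  # e.g. a hypothetical 4-host slice
```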
14 changes: 14 additions & 0 deletions torch_xla/_internal/tpu.py
@@ -119,6 +119,15 @@ def num_available_devices() -> int:
   return num_available_chips() * num_logical_cores_per_chip()
 
 
+def num_expected_global_devices() -> int:
+  """Returns the number of expected runtime devices in this TPU slice.
+
+  May differ from the actual number of runtime devices if TPU topology settings
+  are changed.
+  """
+  return num_available_devices() * num_tpu_workers()
+
+
 def num_local_processes() -> int:
   """Returns number of processes to create on this host."""
   local_chips = num_available_chips()
@@ -188,6 +197,11 @@ def get_worker_ips() -> List[str]:
   return hostnames if len(hostnames) > 1 else ['localhost']
 
 
+def num_tpu_workers() -> int:
+  """Returns the number of configured TPU workers."""
+  return len(get_worker_ips())
+
+
 def configure_one_chip_topology() -> None:
   """Configures TPU topology environment variables for one process and chip.
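A hedged usage sketch of how the two new helpers compose for tests like the ones above; it assumes `torch_xla` is installed and the process runs on a TPU host:

```python
from torch_xla._internal import tpu

workers = tpu.num_tpu_workers()                  # len(tpu.get_worker_ips())
world_size = tpu.num_expected_global_devices()   # per-host devices * workers

# test_all_gather above asserts each process sees exactly these ordinals,
# independent of how many processes actually reported results.
expected_ordinals = list(range(world_size))
print(f'{workers} workers, {world_size} expected global devices')
```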
