[PJRT] Separate collective ops test from TPU runtime test. (#5396)
* [PJRT] Separate collective ops test from TPU runtime test.

* formatting
will-cromar committed Sep 14, 2023
1 parent a4a742d commit 1d99226
Showing 3 changed files with 116 additions and 102 deletions.
115 changes: 115 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
@@ -0,0 +1,115 @@
import numpy as np
import torch
import torch.nn as nn
from absl.testing import absltest, parameterized
import torch_xla.core.xla_model as xm
from torch_xla._internal import pjrt, tpu


def _is_single_host():
  # True when the TPU slice spans only a single worker host.
  return len(tpu.get_worker_ips()) == 1


class TestCollectiveOpsTpu(parameterized.TestCase):

@staticmethod
def _broadcast(sync):
torch.manual_seed(xm.get_ordinal())
device = xm.xla_device()
model = nn.Linear(5, 5).to(device)
if sync:
xm.broadcast_master_param(model)

xm.mark_step()
return next(model.parameters()).detach().cpu().numpy()

  @absltest.skipUnless(_is_single_host(), "Only implemented for single host.")
@parameterized.named_parameters(('synchronized_parameters', True),
('unsynchronized_parameters', False))
def test_broadcast_master_param(self, sync):
results = pjrt.run_multiprocess(self._broadcast, sync)
master_params = results[0]
for ordinal, worker_params in results.items():
if sync:
np.testing.assert_array_equal(master_params, worker_params)
elif ordinal != 0:
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal,
master_params, worker_params)

@staticmethod
def _all_gather(pin_layout):
device = xm.xla_device()
ordinal = torch.tensor([xm.get_ordinal()], device=device)
out = xm.all_gather(ordinal, pin_layout=pin_layout)
xm.mark_step()

return out.cpu().numpy()

@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_all_gather(self, pin_layout):
results = pjrt.run_multiprocess(self._all_gather, pin_layout)

expected = list(range(len(results)))
for v in results.values():
np.testing.assert_array_equal(v, expected)

@staticmethod
def _reduce_scatter(pin_layout):
device = xm.xla_device()
world_size = xm.xrt_world_size()
tensor = -torch.arange(world_size, dtype=torch.float32).to(device)

out = xm.reduce_scatter(
xm.REDUCE_SUM,
tensor,
scale=1.0 / world_size,
scatter_dim=0,
shard_count=world_size,
pin_layout=pin_layout,
)
xm.mark_step()

return out.cpu().numpy()

@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_reduce_scatter(self, pin_layout):
results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout)

for ordinal, value in results.items():
np.testing.assert_array_equal(value, [-ordinal])

@staticmethod
def _all_to_all(pin_layout):
device = xm.xla_device()
world_size = xm.xrt_world_size()

tensor = torch.cat(
[
-torch.arange(world_size, dtype=torch.float32).view(-1, 1, 1),
torch.ones(world_size, 1, 1) * xm.get_ordinal(),
],
dim=1,
).to(device)
xm.mark_step()

out = xm.all_to_all(
tensor,
split_dimension=0,
concat_dimension=2,
split_count=world_size,
pin_layout=pin_layout,
)

return out.cpu().numpy()

@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_all_to_all(self, pin_layout):
results = pjrt.run_multiprocess(self._all_to_all, pin_layout)

for ordinal, value in results.items():
np.testing.assert_array_equal(value, [[[-ordinal] * len(results),
list(range(len(results)))]])


if __name__ == '__main__':
absltest.main()
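
Note on the test harness: pjrt.run_multiprocess, used throughout the tests above, spawns one process per local TPU device and collects each process's return value into a dict keyed by global ordinal, which is why the tests index results[0] and iterate results.items(). A minimal standalone sketch, assuming a TPU host with torch_xla installed; the _ordinal helper is hypothetical and exists only for illustration:

import torch_xla.core.xla_model as xm
from torch_xla._internal import pjrt


def _ordinal():
  # Runs once in each spawned process; xm.get_ordinal() is that
  # process's global device index.
  return xm.get_ordinal()


if __name__ == '__main__':
  # Sketch: expected to return {ordinal: return_value}, e.g.
  # {0: 0, 1: 1, 2: 2, 3: 3} on a host with four addressable TPU devices.
  print(pjrt.run_multiprocess(_ordinal))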
102 changes: 0 additions & 102 deletions test/pjrt/test_runtime_tpu.py
@@ -5,9 +5,7 @@
from unittest import mock
import requests

import numpy as np
import torch
import torch.nn as nn
from absl.testing import absltest, parameterized
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
@@ -238,105 +236,5 @@ def test_execute_time_metric(self):
f"{expected_time_seconds} seconds, got {v / 1e9} seconds")


class TestTpuCollectiveOps(parameterized.TestCase):

@staticmethod
def _broadcast(sync):
torch.manual_seed(xm.get_ordinal())
device = xm.xla_device()
model = nn.Linear(5, 5).to(device)
if sync:
xm.broadcast_master_param(model)

xm.mark_step()
return next(model.parameters()).detach().cpu().numpy()

@parameterized.named_parameters(('synchronized_parameters', True),
('unsynchronized_parameters', False))
def test_broadcast_master_param(self, sync):
results = pjrt.run_multiprocess(self._broadcast, sync)
master_params = results[0]
for ordinal, worker_params in results.items():
if sync:
np.testing.assert_array_equal(master_params, worker_params)
elif ordinal != 0:
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal,
master_params, worker_params)

@staticmethod
def _all_gather(pin_layout):
device = xm.xla_device()
ordinal = torch.tensor([xm.get_ordinal()], device=device)
out = xm.all_gather(ordinal, pin_layout=pin_layout)
xm.mark_step()

return out.cpu().numpy()

@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_all_gather(self, pin_layout):
results = pjrt.run_multiprocess(self._all_gather, pin_layout)

expected = list(range(len(results)))
for v in results.values():
np.testing.assert_array_equal(v, expected)

@staticmethod
def _reduce_scatter(pin_layout):
device = xm.xla_device()
world_size = xm.xrt_world_size()
tensor = -torch.arange(world_size, dtype=torch.float32).to(device)

out = xm.reduce_scatter(
xm.REDUCE_SUM,
tensor,
scale=1.0 / world_size,
scatter_dim=0,
shard_count=world_size,
pin_layout=pin_layout,
)
xm.mark_step()

return out.cpu().numpy()

@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_reduce_scatter(self, pin_layout):
results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout)

for ordinal, value in results.items():
np.testing.assert_array_equal(value, [-ordinal])

@staticmethod
def _all_to_all(pin_layout):
device = xm.xla_device()
world_size = xm.xrt_world_size()

tensor = torch.cat(
[
-torch.arange(world_size, dtype=torch.float32).view(-1, 1, 1),
torch.ones(world_size, 1, 1) * xm.get_ordinal(),
],
dim=1,
).to(device)
xm.mark_step()

out = xm.all_to_all(
tensor,
split_dimension=0,
concat_dimension=2,
split_count=world_size,
pin_layout=pin_layout,
)

return out.cpu().numpy()

@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_all_to_all(self, pin_layout):
results = pjrt.run_multiprocess(self._all_to_all, pin_layout)

for ordinal, value in results.items():
np.testing.assert_array_equal(value, [[[-ordinal] * len(results),
list(range(len(results)))]])


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -45,6 +45,7 @@ spec:
python3 /src/pytorch/xla/test/test_operations.py -v
python3 /src/pytorch/xla/test/pjrt/test_runtime_tpu.py
python3 /src/pytorch/xla/test/pjrt/test_collective_ops_tpu.py
python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
python3 /src/pytorch/xla/test/spmd/test_xla_distributed_checkpoint.py
