Commit
[PJRT] Separate collective ops test from TPU runtime test. (#5396)
* [PJRT] Separate collective ops test from TPU runtime test.
* formatting
1 parent a4a742d · commit 1d99226
Showing 3 changed files with 116 additions and 102 deletions.
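Every test in the new file follows the same harness pattern: pjrt.run_multiprocess spawns the test function once per local TPU process and returns a dict mapping each process's ordinal to its return value, which the test then asserts on. A minimal sketch of that pattern, for illustration only (the _echo_ordinal helper below is hypothetical and not part of this commit):

import torch_xla.core.xla_model as xm
from torch_xla._internal import pjrt


def _echo_ordinal():
  # Hypothetical helper: each spawned process reports its own ordinal.
  return xm.get_ordinal()


if __name__ == '__main__':
  # run_multiprocess returns {ordinal: return_value}; here every value
  # should simply match its key.
  results = pjrt.run_multiprocess(_echo_ordinal)
  assert sorted(results.keys()) == sorted(results.values())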
@@ -0,0 +1,115 @@
import numpy as np
import torch
import torch.nn as nn
from absl.testing import absltest, parameterized
import torch_xla.core.xla_model as xm
from torch_xla._internal import pjrt, tpu


def _is_single_host():
  # A single-host TPU slice reports exactly one worker IP.
  return len(tpu.get_worker_ips()) <= 1

class TestCollectiveOpsTpu(parameterized.TestCase):

  @staticmethod
  def _broadcast(sync):
    # Give every process different initial weights, then optionally sync
    # them from the master before reading the parameters back.
    torch.manual_seed(xm.get_ordinal())
    device = xm.xla_device()
    model = nn.Linear(5, 5).to(device)
    if sync:
      xm.broadcast_master_param(model)

    xm.mark_step()
    return next(model.parameters()).detach().cpu().numpy()

  @absltest.skipUnless(_is_single_host(), "Only implemented for single host.")
  @parameterized.named_parameters(('synchronized_parameters', True),
                                  ('unsynchronized_parameters', False))
  def test_broadcast_master_param(self, sync):
    results = pjrt.run_multiprocess(self._broadcast, sync)
    master_params = results[0]
    for ordinal, worker_params in results.items():
      if sync:
        np.testing.assert_array_equal(master_params, worker_params)
      elif ordinal != 0:
        np.testing.assert_raises(AssertionError, np.testing.assert_array_equal,
                                 master_params, worker_params)

  @staticmethod
  def _all_gather(pin_layout):
    device = xm.xla_device()
    ordinal = torch.tensor([xm.get_ordinal()], device=device)
    out = xm.all_gather(ordinal, pin_layout=pin_layout)
    xm.mark_step()

    return out.cpu().numpy()

  @parameterized.named_parameters(('pinned', True), ('unpinned', False))
  def test_all_gather(self, pin_layout):
    results = pjrt.run_multiprocess(self._all_gather, pin_layout)

    expected = list(range(len(results)))
    for v in results.values():
      np.testing.assert_array_equal(v, expected)

  @staticmethod
  def _reduce_scatter(pin_layout):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    tensor = -torch.arange(world_size, dtype=torch.float32).to(device)

    out = xm.reduce_scatter(
        xm.REDUCE_SUM,
        tensor,
        scale=1.0 / world_size,
        scatter_dim=0,
        shard_count=world_size,
        pin_layout=pin_layout,
    )
    xm.mark_step()

    return out.cpu().numpy()

  @parameterized.named_parameters(('pinned', True), ('unpinned', False))
  def test_reduce_scatter(self, pin_layout):
    results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout)

    # Summing the identical -arange inputs over world_size processes and
    # scaling by 1 / world_size yields -arange again; scattering it along
    # dim 0 leaves element -ordinal on each process.
    for ordinal, value in results.items():
      np.testing.assert_array_equal(value, [-ordinal])

  @staticmethod
  def _all_to_all(pin_layout):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()

    # Shape (world_size, 2, 1): channel 0 is -arange(world_size), channel 1
    # is this process's ordinal.
    tensor = torch.cat(
        [
            -torch.arange(world_size, dtype=torch.float32).view(-1, 1, 1),
            torch.ones(world_size, 1, 1) * xm.get_ordinal(),
        ],
        dim=1,
    ).to(device)
    xm.mark_step()

    out = xm.all_to_all(
        tensor,
        split_dimension=0,
        concat_dimension=2,
        split_count=world_size,
        pin_layout=pin_layout,
    )

    return out.cpu().numpy()

  @parameterized.named_parameters(('pinned', True), ('unpinned', False))
  def test_all_to_all(self, pin_layout):
    results = pjrt.run_multiprocess(self._all_to_all, pin_layout)

    for ordinal, value in results.items():
      np.testing.assert_array_equal(value, [[[-ordinal] * len(results),
                                             list(range(len(results)))]])


if __name__ == '__main__':
  absltest.main()
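The expected value in test_all_to_all is the least obvious of these assertions. A pure-NumPy sketch (no TPU needed; a world size of 4 is assumed here purely for illustration) reproduces the same data movement, splitting along dim 0, exchanging chunks, and concatenating along dim 2, and arrives at the [[[-ordinal] * world_size, range(world_size)]] pattern the test checks:

import numpy as np

world_size = 4  # assumed for illustration; any positive process count works


def _input_for(ordinal):
  # Mirrors the tensor each process builds in _all_to_all: shape
  # (world_size, 2, 1), where channel 0 holds -arange(world_size) and
  # channel 1 holds the sending process's ordinal.
  return np.concatenate(
      [
          -np.arange(world_size, dtype=np.float32).reshape(-1, 1, 1),
          np.ones((world_size, 1, 1), dtype=np.float32) * ordinal,
      ],
      axis=1,
  )


def _all_to_all_result(ordinal):
  # Each sender splits its input into world_size chunks along dim 0 and
  # sends chunk `ordinal` here; the receiver concatenates what it got
  # along dim 2.
  chunks = [
      np.split(_input_for(sender), world_size, axis=0)[ordinal]
      for sender in range(world_size)
  ]
  return np.concatenate(chunks, axis=2)


for ordinal in range(world_size):
  expected = [[[-ordinal] * world_size, list(range(world_size))]]
  np.testing.assert_array_equal(_all_to_all_result(ordinal), expected)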