Skip to content

Commit

Permalink
Rearrange the test
Browse files Browse the repository at this point in the history
  • Loading branch information
alanwaketan committed Apr 21, 2023
1 parent 4066833 commit 3fb46da
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/test_mp_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ def _mp_fn(index):
world_size = xm.xrt_world_size()
if xm.xla_device_hw(device) in ('TPU', 'GPU'):
# Testing with a single replica group
compiled_all_gather = torch.compile(
all_gather, backend='torchxla_trace_once', fullgraph=True)
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = compiled_all_gather(ordinal_tensor, dim=0)
result = xm.all_gather(ordinal_tensor, dim=0)

cpu_result = result.cpu()
expected = torch.arange(0, world_size, dtype=torch.float)
Expand All @@ -27,8 +25,10 @@ def _mp_fn(index):
print(f'[{index}] {cpu_result}', file=sys.stderr)
sys.exit(1)

compiled_all_gather = torch.compile(
all_gather, backend='torchxla_trace_once', fullgraph=True)
ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
result = xm.all_gather(ordinal_tensor, dim=0)
result = compiled_all_gather(ordinal_tensor, dim=0)

cpu_result = result.cpu()
expected = torch.arange(0, world_size, dtype=torch.float)
Expand Down

0 comments on commit 3fb46da

Please sign in to comment.