From 3fb46da10e8b649a80fcbe71f186389f63c60433 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Fri, 21 Apr 2023 20:18:51 +0000
Subject: [PATCH] Rearrange the test

---
 test/test_mp_all_gather.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py
index b36b851d558b..cbcc66e62010 100644
--- a/test/test_mp_all_gather.py
+++ b/test/test_mp_all_gather.py
@@ -15,10 +15,8 @@ def _mp_fn(index):
   world_size = xm.xrt_world_size()
   if xm.xla_device_hw(device) in ('TPU', 'GPU'):
     # Testing with a single replica group
-    compiled_all_gather = torch.compile(
-        all_gather, backend='torchxla_trace_once', fullgraph=True)
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
-    result = compiled_all_gather(ordinal_tensor, dim=0)
+    result = xm.all_gather(ordinal_tensor, dim=0)
 
     cpu_result = result.cpu()
     expected = torch.arange(0, world_size, dtype=torch.float)
@@ -27,8 +25,10 @@ def _mp_fn(index):
       print(f'[{index}] {cpu_result}', file=sys.stderr)
       sys.exit(1)
 
+    compiled_all_gather = torch.compile(
+        all_gather, backend='torchxla_trace_once', fullgraph=True)
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
-    result = xm.all_gather(ordinal_tensor, dim=0)
+    result = compiled_all_gather(ordinal_tensor, dim=0)
 
     cpu_result = result.cpu()
     expected = torch.arange(0, world_size, dtype=torch.float)