diff --git a/test/run_tests.sh b/test/run_tests.sh
index f73dc156df76..1553d53e4095 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -236,7 +236,6 @@ function run_mp_op_tests {
   run_test "$CDIR/test_mp_save.py"
   run_test "$CDIR/test_mp_mesh_reduce.py"
   run_test "$CDIR/test_mp_sync_batch_norm.py"
-  run_test "$CDIR/test_mp_early_exit.py"
   run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py"
   run_xla_backend_mp "$CDIR/test_torch_distributed_all_gather_xla_backend.py"
   run_xla_backend_mp "$CDIR/test_torch_distributed_all_reduce_xla_backend.py"
diff --git a/test/test_mp_early_exit.py b/test/test_mp_early_exit.py
deleted file mode 100644
index 837aea1751be..000000000000
--- a/test/test_mp_early_exit.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import sys
-import torch
-import torch_xla
-import torch_xla.core.xla_model as xm
-import torch_xla.distributed.parallel_loader as pl
-import torch_xla.distributed.xla_multiprocessing as xmp
-import torch_xla.utils.utils as xu
-
-
-def _mp_fn(index):
-  device = xm.xla_device()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
-    train_loader = xu.SampleGenerator(
-        data=torch.zeros(1, 12), sample_count=1024)
-    train_loader = pl.MpDeviceLoader(train_loader, device)
-    max_steps = 10
-    for step, inputs in enumerate(train_loader):
-      xm.all_reduce('sum', [inputs], scale=1.0 / xm.xrt_world_size())
-      if step > max_steps:
-        break
-  else:
-    print(f'{device} is not a TPU or GPU device', file=sys.stderr)
-
-
-if __name__ == '__main__':
-  xmp.spawn(_mp_fn, args=())
diff --git a/test/test_zero1.py b/test/test_zero1.py
index bc1ea3c10e32..17c46617973c 100644
--- a/test/test_zero1.py
+++ b/test/test_zero1.py
@@ -13,6 +13,7 @@ class XlaZeRO1Test(TestCase):
 
   @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
+  @unittest.skipIf(xr.device_type() == 'CUDA', "Crash on CUDA")
   def test_zero1(self):
     device = xm.xla_device()
 
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 8d4997e28556..d753f8f7c8f2 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -148,8 +148,6 @@ def _setup_tpu_vm_library_path() -> bool:
 
 
 def _prepare_to_exit():
-  device = _XLAC._xla_get_default_device()
-  _XLAC._set_all_reduce_token(device, None)
   _XLAC._prepare_to_exit()
   if int(os.environ.get('PT_XLA_DEBUG', '0')):
     _summarize_fn_tracker()