From e70fb298f32809570338cc452bec0f823d463a3a Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Thu, 27 Jul 2023 10:01:17 -0700 Subject: [PATCH] Add python test for SPMD+Runtime Python API (#5349) * Add python test for SPMD+Runtime Python API * replace test name * Update test_xla_spmd_python_api_interaction.py --- test/run_tests.sh | 1 + ...> test_xla_spmd_python_api_interaction.py} | 50 +++++++++++++++++++ test/tpu/xla_test_job.yaml | 2 +- 3 files changed, 52 insertions(+), 1 deletion(-) rename test/spmd/{test_spmd_xla_model_api.py => test_xla_spmd_python_api_interaction.py} (52%) diff --git a/test/run_tests.sh b/test/run_tests.sh index 683068864024..9ca4a0aecc12 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -172,6 +172,7 @@ function run_xla_op_tests { run_test "$CDIR/spmd/test_xla_virtual_device.py" run_test "$CDIR/spmd/test_dynamo_spmd.py" run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py" + run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py" run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py" run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py" run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY diff --git a/test/spmd/test_spmd_xla_model_api.py b/test/spmd/test_xla_spmd_python_api_interaction.py similarity index 52% rename from test/spmd/test_spmd_xla_model_api.py rename to test/spmd/test_xla_spmd_python_api_interaction.py index fe044e04e16e..06cfbba06d04 100644 --- a/test/spmd/test_spmd_xla_model_api.py +++ b/test/spmd/test_xla_spmd_python_api_interaction.py @@ -5,6 +5,7 @@ import torch import torch_xla import torch_xla.core.xla_model as xm +from torch_xla import runtime as xr import test_xla_sharding_base @@ -54,6 +55,55 @@ def test_xla_replication_devices(self): self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0']) +class BasicRuntimeAPITest(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + os.environ["XLA_USE_SPMD"] = "1" + super().setUpClass() + + def test_local_process_count(self): + self.assertEqual(xr.local_process_count(), 1) + + def test_global_device_count(self): + self.assertEqual(xr.global_device_count(), 1) + + def test_world_size(self): + self.assertEqual(xr.world_size(), 1) + + def test_local_device_count(self): + self.assertEqual(xr.local_device_count(), 1) + + def test_addressable_device_count(self): + self.assertEqual(xr.addressable_device_count(), 1) + + def test_global_ordinal(self): + self.assertEqual(xr.global_ordinal(), 0) + + def test_local_ordinal(self): + self.assertEqual(xr.local_ordinal(), 0) + + def test_process_index(self): + self.assertEqual(xr.process_index(), 0) + + def test_process_count(self): + self.assertEqual(xr.process_count(), 1) + + def test_global_runtime_device_count(self): + device_type = os.environ['PJRT_DEVICE'] + if device_type == "TPU": + self.assertGreaterEqual(xr.global_runtime_device_count(), 4) + elif device_type == "CPU": + self.assertEqual(xr.global_runtime_device_count(), 1) + + def test_addressable_runtime_device_count(self): + device_type = os.environ['PJRT_DEVICE'] + if device_type == "TPU": + self.assertGreaterEqual(xr.addressable_runtime_device_count(), 4) + elif device_type == "CPU": + self.assertEqual(xr.addressable_runtime_device_count(), 1) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index 2d37d6c767ea..ce5672bf676d 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -48,7 +48,7 @@ spec: python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py - python3 /src/pytorch/xla/test/spmd/test_spmd_xla_model_api.py + python3 /src/pytorch/xla/test/spmd/test_xla_spmd_python_api_interaction.py XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v python3 /src/pytorch/xla/test/test_autocast.py