Add python test for SPMD+Runtime Python API (#5349)
* Add python test for SPMD+Runtime Python API

* replace test name

* Update test_xla_spmd_python_api_interaction.py
JackCaoG authored Jul 27, 2023
1 parent 22ba4af commit e70fb29
Showing 3 changed files with 52 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -172,6 +172,7 @@ function run_xla_op_tests {
run_test "$CDIR/spmd/test_xla_virtual_device.py"
run_test "$CDIR/spmd/test_dynamo_spmd.py"
run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py"
run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py"
run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py"
run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
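
For reference, a minimal sketch of running the new test on its own, outside of run_tests.sh; the relative path and the PJRT_DEVICE=CPU setting are assumptions for illustration, mirroring what the CI job further down does with python3:

    import os
    import subprocess
    import sys

    # Assumed to be run from a pytorch/xla checkout; PJRT_DEVICE=CPU matches the
    # single-host case the new test asserts on.
    env = dict(os.environ, PJRT_DEVICE="CPU")
    subprocess.run(
        [sys.executable, "test/spmd/test_xla_spmd_python_api_interaction.py"],
        env=env,
        check=True)
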
50 changes: 50 additions & 0 deletions test/spmd/test_xla_spmd_python_api_interaction.py
@@ -5,6 +5,7 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
import test_xla_sharding_base


@@ -54,6 +55,55 @@ def test_xla_replication_devices(self):
self.assertEqual(xm.xla_real_devices([device]), [device_type + ':0'])


class BasicRuntimeAPITest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
os.environ["XLA_USE_SPMD"] = "1"
super().setUpClass()

def test_local_process_count(self):
self.assertEqual(xr.local_process_count(), 1)

def test_global_device_count(self):
self.assertEqual(xr.global_device_count(), 1)

def test_world_size(self):
self.assertEqual(xr.world_size(), 1)

def test_local_device_count(self):
self.assertEqual(xr.local_device_count(), 1)

def test_addressable_device_count(self):
self.assertEqual(xr.addressable_device_count(), 1)

def test_global_ordinal(self):
self.assertEqual(xr.global_ordinal(), 0)

def test_local_ordinal(self):
self.assertEqual(xr.local_ordinal(), 0)

def test_process_index(self):
self.assertEqual(xr.process_index(), 0)

def test_process_count(self):
self.assertEqual(xr.process_count(), 1)

def test_global_runtime_device_count(self):
device_type = os.environ['PJRT_DEVICE']
if device_type == "TPU":
self.assertGreaterEqual(xr.global_runtime_device_count(), 4)
elif device_type == "CPU":
self.assertEqual(xr.global_runtime_device_count(), 1)

def test_addressable_runtime_device_count(self):
device_type = os.environ['PJRT_DEVICE']
if device_type == "TPU":
self.assertGreaterEqual(xr.addressable_runtime_device_count(), 4)
elif device_type == "CPU":
self.assertEqual(xr.addressable_runtime_device_count(), 1)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
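
As a rough sketch of the interaction the new BasicRuntimeAPITest covers (assuming PJRT_DEVICE is already set; only runtime calls that appear in the test above are used):

    import os

    # SPMD must be enabled before the XLA runtime is initialized, mirroring
    # setUpClass in BasicRuntimeAPITest.
    os.environ["XLA_USE_SPMD"] = "1"

    from torch_xla import runtime as xr

    # With SPMD enabled, a single process drives execution through one virtual
    # device, so the process- and device-level counts are all 1 ...
    assert xr.process_count() == 1
    assert xr.global_device_count() == 1
    assert xr.world_size() == 1
    assert xr.global_ordinal() == 0

    # ... while the runtime device counts still report the physical devices:
    # 1 on CPU, >= 4 on a TPU host, per the assertions in the test above.
    print(xr.global_runtime_device_count())
    print(xr.addressable_runtime_device_count())
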
2 changes: 1 addition & 1 deletion test/tpu/xla_test_job.yaml
@@ -48,7 +48,7 @@ spec:
python3 /src/pytorch/xla/test/spmd/test_xla_sharding.py
python3 /src/pytorch/xla/test/spmd/test_xla_virtual_device.py
python3 /src/pytorch/xla/test/spmd/test_train_spmd_linear_model.py
python3 /src/pytorch/xla/test/spmd/test_spmd_xla_model_api.py
python3 /src/pytorch/xla/test/spmd/test_xla_spmd_python_api_interaction.py
XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select python3 /src/pytorch/xla/test/ds/test_dynamic_shapes.py -v
python3 /src/pytorch/xla/test/test_autocast.py
