From 6db97191c51add95f2bf14bb576147ba4650632a Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 20 Feb 2024 21:13:12 +0000
Subject: [PATCH 1/6] Add basic device APIs to the top-level `torch_xla` module.

---
 torch_xla/__init__.py                   |  2 ++
 torch_xla/csrc/init_python_bindings.cpp | 10 ++++--
 torch_xla/torch_xla.py                  | 44 +++++++++++++++++++++++++
 3 files changed, 53 insertions(+), 3 deletions(-)
 create mode 100644 torch_xla/torch_xla.py

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 8e657787ee9f..4cf9295b2ea4 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -182,3 +182,5 @@ def _init_xla_lazy_backend():
 if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
   plugins.use_dynamic_plugins()
   plugins.register_installed_plugins()
+
+from .torch_xla import *
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index c023cd3a0dc4..196f60a7b18c 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -1188,14 +1188,18 @@ void InitXlaModuleBindings(py::module m) {
         runtime::GetComputationClient()->GetAllDevices();
     return all_devices;
   });
-  m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
+  m.def("_xla_real_devices", [](const std::optional<std::vector<std::string>> devices) {
+    if (!devices) {
+      return runtime::GetComputationClient()->GetLocalDevices();
+    }
+
     std::vector<std::string> xla_devices;
     {
       NoGilSection nogil;
-      xla_devices = GetXlaDevices(devices);
+      xla_devices = GetXlaDevices(*devices);
     }
     return xla_devices;
-  });
+  }, py::arg("devices") = std::nullopt);
   m.def("_xla_set_replication_devices",
         [](const std::vector<std::string>& devices) {
           auto replication_devices =
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
new file mode 100644
index 000000000000..12c6737ece48
--- /dev/null
+++ b/torch_xla/torch_xla.py
@@ -0,0 +1,44 @@
+from typing import List
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+
+
+def device(n: int = None) -> torch.device:
+  """Returns a given instance of an XLA device.
+
+  If SPMD is enabled, returns a virtual device that wraps all devices available
+  to this process.
+
+  Args:
+    n: index of the XLA device to be returned. Corresponds to index in
+      `torch_xla.devices()`.
+
+  Returns:
+    An XLA `torch.device`.
+  """
+
+  return xm.xla_device(n)
+
+
+def devices() -> List[torch.device]:
+  """Returns all devices available in the current process.
+
+  Returns:
+    A list of XLA `torch.device` instances.
+  """
+
+  return [torch.device(d) for d in xm.get_xla_supported_devices()]
+
+
+def real_devices() -> List[str]:
+  """Returns local XLA device types and indices.
+
+  Returns:
+    A list of strings representing the XLA devices available in the current process.
+  """
+
+  return torch_xla._XLAC._xla_real_devices()
+
+def device_count() -> int:
+  return len(real_devices())

From 850d8d80c0daddb236b79c4beb4b5c79e482996b Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 20 Feb 2024 21:23:10 +0000
Subject: [PATCH 2/6] unit test

---
 test/test_devices.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 test/test_devices.py

diff --git a/test/test_devices.py b/test/test_devices.py
new file mode 100644
index 000000000000..1382101ed5a0
--- /dev/null
+++ b/test/test_devices.py
@@ -0,0 +1,33 @@
+import os
+
+from absl.testing import absltest, parameterized
+import torch
+import torch_xla as xla
+import torch_xla.runtime as xr
+
+
+class TestDevices(parameterized.TestCase):
+  def setUpClass():
+    xr.set_device_type('CPU')
+    os.environ['CPU_NUM_DEVICES'] = '4'
+
+  @parameterized.parameters(
+    (None, torch.device('xla:0')),
+    (0, torch.device('xla:0')),
+    (3, torch.device('xla:3')))
+  def test_device(self, index, expected):
+    device = xla.device(n=index)
+    self.assertEqual(device, expected)
+
+  def test_devices(self):
+    self.assertEqual(xla.devices(), [torch.device(f'xla:{i}') for i in range(4)])
+
+  def test_real_devices(self):
+    self.assertEqual(xla.real_devices(), [f'CPU:{i}' for i in range(4)])
+
+  def test_device_count(self):
+    self.assertEqual(xla.device_count(), 4)
+
+
+if __name__ == "__main__":
+  absltest.main()

From db0284f607eacc50ffe0bf46e8e2ac5ffeafa40b Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 20 Feb 2024 21:25:11 +0000
Subject: [PATCH 3/6] format

---
 test/test_devices.py   | 11 ++++++-----
 torch_xla/torch_xla.py |  3 ++-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/test/test_devices.py b/test/test_devices.py
index 1382101ed5a0..a37c01afa203 100644
--- a/test/test_devices.py
+++ b/test/test_devices.py
@@ -7,20 +7,21 @@
 
 
 class TestDevices(parameterized.TestCase):
+
   def setUpClass():
     xr.set_device_type('CPU')
     os.environ['CPU_NUM_DEVICES'] = '4'
 
-  @parameterized.parameters(
-    (None, torch.device('xla:0')),
-    (0, torch.device('xla:0')),
-    (3, torch.device('xla:3')))
+  @parameterized.parameters((None, torch.device('xla:0')),
+                            (0, torch.device('xla:0')),
+                            (3, torch.device('xla:3')))
   def test_device(self, index, expected):
     device = xla.device(n=index)
     self.assertEqual(device, expected)
 
   def test_devices(self):
-    self.assertEqual(xla.devices(), [torch.device(f'xla:{i}') for i in range(4)])
+    self.assertEqual(xla.devices(),
+                     [torch.device(f'xla:{i}') for i in range(4)])
 
   def test_real_devices(self):
     self.assertEqual(xla.real_devices(), [f'CPU:{i}' for i in range(4)])
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
index 12c6737ece48..f0bd0d08d043 100644
--- a/torch_xla/torch_xla.py
+++ b/torch_xla/torch_xla.py
@@ -35,10 +35,11 @@ def real_devices() -> List[str]:
   """Returns local XLA device types and indices.
 
   Returns:
-    A list of strings representing the XLA devices available in the current process.
+    A list of strings representing the XLA devices available in the current process, e.g. `['TPU:0', 'TPU:1', ...]`.
""" return torch_xla._XLAC._xla_real_devices() + def device_count() -> int: return len(real_devices()) From fc122a94f5bfac02c315e988b97345a99e0bb80f Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 20 Feb 2024 21:27:10 +0000 Subject: [PATCH 4/6] clean up --- test/test_devices.py | 2 +- torch_xla/torch_xla.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/test_devices.py b/test/test_devices.py index a37c01afa203..ff93f64a5c50 100644 --- a/test/test_devices.py +++ b/test/test_devices.py @@ -16,7 +16,7 @@ def setUpClass(): (0, torch.device('xla:0')), (3, torch.device('xla:3'))) def test_device(self, index, expected): - device = xla.device(n=index) + device = xla.device(index) self.assertEqual(device, expected) def test_devices(self): diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index f0bd0d08d043..961f6a3217ed 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -4,21 +4,21 @@ import torch_xla.core.xla_model as xm -def device(n: int = None) -> torch.device: +def device(index: int = None) -> torch.device: """Returns a given instance of an XLA device. If SPMD enables, returns a virtual device that wraps all devices available to this process. Args: - n: index of the XLA device to be returned. Corresponds to index in + index: index of the XLA device to be returned. Corresponds to index in `torch_xla.devices()`. Returns: An XLA `torch.device`. """ - return xm.xla_device(n) + return xm.xla_device(index) def devices() -> List[torch.device]: @@ -35,11 +35,13 @@ def real_devices() -> List[str]: """Returns local XLA device types and indices. Returns: - A list strings representing the XLA devices available in the current process, e.g. `['TPU:0', 'TPU:1', ...]`. + A list strings representing the XLA devices available in the current + process, e.g. `['TPU:0', 'TPU:1', ...]`. 
""" return torch_xla._XLAC._xla_real_devices() def device_count() -> int: + """Returns number of addressable devices in the current process.""" return len(real_devices()) From db33a828d942be2402b6d6379039b016ad9b3eef Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 20 Feb 2024 21:28:23 +0000 Subject: [PATCH 5/6] add unit test to CI --- test/run_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_tests.sh b/test/run_tests.sh index be98575c45e6..ca5f042e109e 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -211,6 +211,7 @@ function run_xla_op_tests3 { run_test "$CDIR/test_torch_distributed_xla_backend.py" run_torchrun "$CDIR/pjrt/test_torchrun.py" run_test "$CDIR/test_persistent_cache.py" + run_test "$CDIR/test_devices.py" # NOTE: this line below is testing export and don't care about GPU PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py" } From b3d012c40406ed5146bb9f502aaf8d32312aac9c Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 20 Feb 2024 21:29:41 +0000 Subject: [PATCH 6/6] clang format --- torch_xla/csrc/init_python_bindings.cpp | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 196f60a7b18c..404d4282e527 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1188,18 +1188,20 @@ void InitXlaModuleBindings(py::module m) { runtime::GetComputationClient()->GetAllDevices(); return all_devices; }); - m.def("_xla_real_devices", [](const std::optional> devices) { - if (!devices) { - return runtime::GetComputationClient()->GetLocalDevices(); - } + m.def("_xla_real_devices", + [](const std::optional> devices) { + if (!devices) { + return runtime::GetComputationClient()->GetLocalDevices(); + } - std::vector xla_devices; - { - NoGilSection nogil; - xla_devices = GetXlaDevices(*devices); - } - return xla_devices; - }, py::arg("devices") = std::nullopt); + std::vector xla_devices; + { + NoGilSection nogil; + xla_devices = GetXlaDevices(*devices); + } + return xla_devices; + }, + py::arg("devices") = std::nullopt); m.def("_xla_set_replication_devices", [](const std::vector& devices) { auto replication_devices =