Skip to content

Commit

Permalink
Add basic device APIs to the top-level torch_xla module. (#6571)
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar authored Feb 21, 2024
1 parent e77a629 commit 0ec5b91
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 8 deletions.
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/test_torch_distributed_xla_backend.py"
run_torchrun "$CDIR/pjrt/test_torchrun.py"
run_test "$CDIR/test_persistent_cache.py"
run_test "$CDIR/test_devices.py"
# NOTE: this line below is testing export and don't care about GPU
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
}
Expand Down
34 changes: 34 additions & 0 deletions test/test_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os

from absl.testing import absltest, parameterized
import torch
import torch_xla as xla
import torch_xla.runtime as xr


class TestDevices(parameterized.TestCase):

def setUpClass():
xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'

@parameterized.parameters((None, torch.device('xla:0')),
(0, torch.device('xla:0')),
(3, torch.device('xla:3')))
def test_device(self, index, expected):
device = xla.device(index)
self.assertEqual(device, expected)

def test_devices(self):
self.assertEqual(xla.devices(),
[torch.device(f'xla:{i}') for i in range(4)])

def test_real_devices(self):
self.assertEqual(xla.real_devices(), [f'CPU:{i}' for i in range(4)])

def test_device_count(self):
self.assertEqual(xla.device_count(), 4)


if __name__ == "__main__":
absltest.main()
2 changes: 2 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,5 @@ def _init_xla_lazy_backend():
if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
plugins.use_dynamic_plugins()
plugins.register_installed_plugins()

from .torch_xla import *
22 changes: 14 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,14 +1188,20 @@ void InitXlaModuleBindings(py::module m) {
runtime::GetComputationClient()->GetAllDevices();
return all_devices;
});
m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
std::vector<std::string> xla_devices;
{
NoGilSection nogil;
xla_devices = GetXlaDevices(devices);
}
return xla_devices;
});
m.def("_xla_real_devices",
[](const std::optional<std::vector<std::string>> devices) {
if (!devices) {
return runtime::GetComputationClient()->GetLocalDevices();
}

std::vector<std::string> xla_devices;
{
NoGilSection nogil;
xla_devices = GetXlaDevices(*devices);
}
return xla_devices;
},
py::arg("devices") = std::nullopt);
m.def("_xla_set_replication_devices",
[](const std::vector<std::string>& devices) {
auto replication_devices =
Expand Down
47 changes: 47 additions & 0 deletions torch_xla/torch_xla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import List
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def device(index: int = None) -> torch.device:
"""Returns a given instance of an XLA device.
If SPMD enables, returns a virtual device that wraps all devices available
to this process.
Args:
index: index of the XLA device to be returned. Corresponds to index in
`torch_xla.devices()`.
Returns:
An XLA `torch.device`.
"""

return xm.xla_device(index)


def devices() -> List[torch.device]:
"""Returns all devices available in the current process.
Returns:
A list of XLA `torch.devices`.
"""

return [torch.device(d) for d in xm.get_xla_supported_devices()]


def real_devices() -> List[str]:
"""Returns local XLA device types and indices.
Returns:
A list strings representing the XLA devices available in the current
process, e.g. `['TPU:0', 'TPU:1', ...]`.
"""

return torch_xla._XLAC._xla_real_devices()


def device_count() -> int:
"""Returns number of addressable devices in the current process."""
return len(real_devices())

0 comments on commit 0ec5b91

Please sign in to comment.