Add basic device APIs to the top-level torch_xla module. #6571

Merged (6 commits, Feb 21, 2024)
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -211,6 +211,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/test_torch_distributed_xla_backend.py"
run_torchrun "$CDIR/pjrt/test_torchrun.py"
run_test "$CDIR/test_persistent_cache.py"
run_test "$CDIR/test_devices.py"
  # NOTE: the line below tests export and doesn't care about GPU
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
}
34 changes: 34 additions & 0 deletions test/test_devices.py
@@ -0,0 +1,34 @@
import os

from absl.testing import absltest, parameterized
import torch
import torch_xla as xla
import torch_xla.runtime as xr


class TestDevices(parameterized.TestCase):

def setUpClass():
xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'
Comment on lines +11 to +13
Collaborator:

hmm, it is OK for now but shouldn't we also test it on GPU and TPU?

Collaborator Author:

This is sufficient IMO. We're really just testing the integration of this module with the runtime client, which has the same API regardless of the underlying device.

As we switch to using these functions by convention, they'll be exercised by almost every other test.
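A hedged sketch of the convention referenced above: other tests replacing the xm helper with the new top-level API. The snippet is illustrative only and not part of this diff.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

old_style = xm.xla_device()     # existing convention in most tests
new_style = torch_xla.device()  # top-level API added in this PR
t = torch.ones(2, 2, device=new_style)  # any test that allocates on-device exercises device()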


@parameterized.parameters((None, torch.device('xla:0')),
(0, torch.device('xla:0')),
(3, torch.device('xla:3')))
def test_device(self, index, expected):
device = xla.device(index)
self.assertEqual(device, expected)

def test_devices(self):
self.assertEqual(xla.devices(),
[torch.device(f'xla:{i}') for i in range(4)])

def test_real_devices(self):
self.assertEqual(xla.real_devices(), [f'CPU:{i}' for i in range(4)])

def test_device_count(self):
self.assertEqual(xla.device_count(), 4)


if __name__ == "__main__":
absltest.main()
2 changes: 2 additions & 0 deletions torch_xla/__init__.py
@@ -182,3 +182,5 @@ def _init_xla_lazy_backend():
if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
plugins.use_dynamic_plugins()
plugins.register_installed_plugins()

from .torch_xla import *
Collaborator:

why is this needed?

Collaborator Author:

This imports the contents of torch_xla.py into the torch_xla package's module scope. Otherwise, the functions would only be accessible as torch_xla.torch_xla.device, etc.; the re-export makes them available as torch_xla.device, etc.
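A brief illustration of the effect, assuming the functions defined in torch_xla/torch_xla.py later in this diff:

import torch_xla

torch_xla.device()        # resolves via the re-export rather than torch_xla.torch_xla.device()
torch_xla.device_count()  # likewise for the other new top-level functions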

22 changes: 14 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1188,14 +1188,20 @@ void InitXlaModuleBindings(py::module m) {
runtime::GetComputationClient()->GetAllDevices();
return all_devices;
});
m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
std::vector<std::string> xla_devices;
{
NoGilSection nogil;
xla_devices = GetXlaDevices(devices);
}
return xla_devices;
});
m.def("_xla_real_devices",
[](const std::optional<std::vector<std::string>> devices) {
if (!devices) {
return runtime::GetComputationClient()->GetLocalDevices();
}

std::vector<std::string> xla_devices;
{
NoGilSection nogil;
Collaborator:

Do you know why we release the GIL in this block? Is it for TPU v2/v3, where we allow multiple threads to do runtime work such as GetXlaDevices(*devices)?

Collaborator Author (@will-cromar, Feb 21, 2024):

I'm not sure to be honest. Maybe GetXlaDevices was a blocking call in XRT?

Collaborator:

oh lol it is that torch_xla

xla_devices = GetXlaDevices(*devices);
}
return xla_devices;
},
py::arg("devices") = std::nullopt);
m.def("_xla_set_replication_devices",
[](const std::vector<std::string>& devices) {
auto replication_devices =
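A short sketch of the Python-visible effect of the new optional argument. The no-argument call backs real_devices() in the new module below; the explicit-argument form and its 'xla:N' strings are an assumption based on GetXlaDevices, not verified here.

import torch_xla

torch_xla._XLAC._xla_real_devices()           # new: all local devices, e.g. ['CPU:0', ...]
torch_xla._XLAC._xla_real_devices(['xla:0'])  # original form: map the given XLA devices to real devices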
47 changes: 47 additions & 0 deletions torch_xla/torch_xla.py
@@ -0,0 +1,47 @@
from typing import List
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def device(index: int = None) -> torch.device:
"""Returns a given instance of an XLA device.

  If SPMD is enabled, returns a virtual device that wraps all devices available
to this process.

Args:
index: index of the XLA device to be returned. Corresponds to index in
`torch_xla.devices()`.

Returns:
An XLA `torch.device`.
"""

return xm.xla_device(index)


def devices() -> List[torch.device]:
"""Returns all devices available in the current process.

Returns:
    A list of XLA `torch.device` instances.
"""

return [torch.device(d) for d in xm.get_xla_supported_devices()]


def real_devices() -> List[str]:
"""Returns local XLA device types and indices.

Returns:
    A list of strings representing the XLA devices available in the current
process, e.g. `['TPU:0', 'TPU:1', ...]`.
"""

return torch_xla._XLAC._xla_real_devices()


def device_count() -> int:
"""Returns number of addressable devices in the current process."""
return len(real_devices())
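A minimal usage sketch of the new top-level API, assuming the four-CPU-device configuration used by test_devices.py above; actual output depends on the configured runtime.

import torch
import torch_xla

print(torch_xla.device())        # xla:0
print(torch_xla.devices())       # [device(type='xla', index=0), ..., device(type='xla', index=3)]
print(torch_xla.real_devices())  # ['CPU:0', 'CPU:1', 'CPU:2', 'CPU:3']
print(torch_xla.device_count())  # 4

t = torch.zeros(2, 2, device=torch_xla.device())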