Add basic device APIs to the top-level torch_xla module. #6571
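At a glance, this PR lets users call a small set of device helpers directly on the `torch_xla` package instead of going through `torch_xla.core.xla_model`. Below is a minimal usage sketch based on the functions added in this change (`device`, `devices`, `real_devices`, `device_count`); the printed values are illustrative and depend on the configured runtime:

import torch
import torch_xla as xla

# Default XLA device for this process, e.g. xla:0.
dev = xla.device()
t = torch.ones(2, 2, device=dev)

# All addressable XLA devices and the underlying runtime devices.
print(xla.devices())       # e.g. [device(type='xla', index=0), ...]
print(xla.real_devices())  # e.g. ['TPU:0', 'TPU:1', ...]
print(xla.device_count())  # e.g. 4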
@@ -0,0 +1,34 @@
import os

from absl.testing import absltest, parameterized
import torch
import torch_xla as xla
import torch_xla.runtime as xr


class TestDevices(parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    xr.set_device_type('CPU')
    os.environ['CPU_NUM_DEVICES'] = '4'

  @parameterized.parameters((None, torch.device('xla:0')),
                            (0, torch.device('xla:0')),
                            (3, torch.device('xla:3')))
  def test_device(self, index, expected):
    device = xla.device(index)
    self.assertEqual(device, expected)

  def test_devices(self):
    self.assertEqual(xla.devices(),
                     [torch.device(f'xla:{i}') for i in range(4)])

  def test_real_devices(self):
    self.assertEqual(xla.real_devices(), [f'CPU:{i}' for i in range(4)])

  def test_device_count(self):
    self.assertEqual(xla.device_count(), 4)


if __name__ == "__main__":
  absltest.main()
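For reference, the class-level setup above is equivalent to configuring the runtime through environment variables before anything initializes the client. A rough sketch, assuming the PJRT CPU runtime and that `xr.set_device_type('CPU')` amounts to setting `PJRT_DEVICE`:

import os

os.environ['PJRT_DEVICE'] = 'CPU'    # assumed equivalent of xr.set_device_type('CPU')
os.environ['CPU_NUM_DEVICES'] = '4'  # expose four virtual CPU devices

import torch_xla as xla
assert xla.device_count() == 4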
@@ -182,3 +182,5 @@ def _init_xla_lazy_backend():
 if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS') == '1':
   plugins.use_dynamic_plugins()
   plugins.register_installed_plugins()
+
+from .torch_xla import *
Review comment: why is this needed?

Reply: This imports the contents of the new `torch_xla/torch_xla.py` module into the top-level `torch_xla` package namespace.
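Concretely, the star import is what makes the new helpers reachable from the package root; a minimal sketch of the effect, assuming the import runs at the end of package initialization:

import torch_xla

# Names defined in torch_xla/torch_xla.py are re-exported on the package itself.
dev = torch_xla.device()
n = torch_xla.device_count()

One trade-off of a star import is that it pulls every public name from the submodule into the package namespace; an explicit import list (or an `__all__` in `torch_xla.py`) would make the exported surface easier to audit.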
@@ -1188,14 +1188,20 @@ void InitXlaModuleBindings(py::module m) {
             runtime::GetComputationClient()->GetAllDevices();
         return all_devices;
       });
-  m.def("_xla_real_devices", [](const std::vector<std::string>& devices) {
-    std::vector<std::string> xla_devices;
-    {
-      NoGilSection nogil;
-      xla_devices = GetXlaDevices(devices);
-    }
-    return xla_devices;
-  });
+  m.def("_xla_real_devices",
+        [](const std::optional<std::vector<std::string>> devices) {
+          if (!devices) {
+            return runtime::GetComputationClient()->GetLocalDevices();
+          }
+
+          std::vector<std::string> xla_devices;
+          {
+            NoGilSection nogil;
+            xla_devices = GetXlaDevices(*devices);
+          }
+          return xla_devices;
+        },
+        py::arg("devices") = std::nullopt);
   m.def("_xla_set_replication_devices",
         [](const std::vector<std::string>& devices) {
           auto replication_devices =

Review comment (on `NoGilSection nogil;`): Do you know why we release the GIL in this block? Is it for TPU v2/v3 where we allow multiple threads to do some runtime job such as […]

Reply: I'm not sure to be honest. Maybe […]

Reply: oh lol it is that […]
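From Python, the practical effect of the new default argument is roughly the following (a sketch; the device strings are illustrative and depend on the configured runtime):

import torch_xla

# With no argument, the binding now returns the runtime's local devices directly.
local = torch_xla._XLAC._xla_real_devices()  # e.g. ['CPU:0', 'CPU:1', 'CPU:2', 'CPU:3']

# With explicit XLA device strings, it still translates them via GetXlaDevices,
# releasing the GIL around the call as before.
translated = torch_xla._XLAC._xla_real_devices(['xla:0', 'xla:1'])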
@@ -0,0 +1,47 @@
from typing import List, Optional
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def device(index: Optional[int] = None) -> torch.device:
  """Returns a given instance of an XLA device.

  If SPMD is enabled, returns a virtual device that wraps all devices available
  to this process.

  Args:
    index: index of the XLA device to be returned. Corresponds to index in
      `torch_xla.devices()`.

  Returns:
    An XLA `torch.device`.
  """

  return xm.xla_device(index)


def devices() -> List[torch.device]:
  """Returns all devices available in the current process.

  Returns:
    A list of XLA `torch.devices`.
  """

  return [torch.device(d) for d in xm.get_xla_supported_devices()]


def real_devices() -> List[str]:
  """Returns local XLA device types and indices.

  Returns:
    A list of strings representing the XLA devices available in the current
    process, e.g. `['TPU:0', 'TPU:1', ...]`.
  """

  return torch_xla._XLAC._xla_real_devices()


def device_count() -> int:
  """Returns the number of addressable devices in the current process."""
  return len(real_devices())
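For context, here is a rough mapping from the new top-level helpers to the existing calls they wrap, sketched as assertions that mirror the delegation in the diff above (not a documented guarantee):

import torch
import torch_xla as xla
import torch_xla.core.xla_model as xm

# The helpers delegate to existing xla_model / _XLAC APIs.
assert xla.device(0) == xm.xla_device(0)
assert xla.devices() == [torch.device(d) for d in xm.get_xla_supported_devices()]

# device_count() is simply the length of the real device list.
assert xla.device_count() == len(xla.real_devices())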
Review comment: hmm, it is OK for now, but shouldn't we also test it on GPU and TPU?

Reply: This is sufficient IMO. We're really just testing the integration of this module with the runtime client, which has the same API regardless of the underlying device. As we switch to using these functions by convention, they'll be exercised by almost every other test.