-
Notifications
You must be signed in to change notification settings - Fork 505
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic device APIs to the top-level
torch_xla
module. (#6571)
- Loading branch information
1 parent
e77a629
commit 0ec5b91
Showing
5 changed files
with
98 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
|
||
from absl.testing import absltest, parameterized | ||
import torch | ||
import torch_xla as xla | ||
import torch_xla.runtime as xr | ||
|
||
|
||
class TestDevices(parameterized.TestCase): | ||
|
||
def setUpClass(): | ||
xr.set_device_type('CPU') | ||
os.environ['CPU_NUM_DEVICES'] = '4' | ||
|
||
@parameterized.parameters((None, torch.device('xla:0')), | ||
(0, torch.device('xla:0')), | ||
(3, torch.device('xla:3'))) | ||
def test_device(self, index, expected): | ||
device = xla.device(index) | ||
self.assertEqual(device, expected) | ||
|
||
def test_devices(self): | ||
self.assertEqual(xla.devices(), | ||
[torch.device(f'xla:{i}') for i in range(4)]) | ||
|
||
def test_real_devices(self): | ||
self.assertEqual(xla.real_devices(), [f'CPU:{i}' for i in range(4)]) | ||
|
||
def test_device_count(self): | ||
self.assertEqual(xla.device_count(), 4) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import List | ||
import torch | ||
import torch_xla | ||
import torch_xla.core.xla_model as xm | ||
|
||
|
||
def device(index: int = None) -> torch.device: | ||
"""Returns a given instance of an XLA device. | ||
If SPMD enables, returns a virtual device that wraps all devices available | ||
to this process. | ||
Args: | ||
index: index of the XLA device to be returned. Corresponds to index in | ||
`torch_xla.devices()`. | ||
Returns: | ||
An XLA `torch.device`. | ||
""" | ||
|
||
return xm.xla_device(index) | ||
|
||
|
||
def devices() -> List[torch.device]: | ||
"""Returns all devices available in the current process. | ||
Returns: | ||
A list of XLA `torch.devices`. | ||
""" | ||
|
||
return [torch.device(d) for d in xm.get_xla_supported_devices()] | ||
|
||
|
||
def real_devices() -> List[str]: | ||
"""Returns local XLA device types and indices. | ||
Returns: | ||
A list strings representing the XLA devices available in the current | ||
process, e.g. `['TPU:0', 'TPU:1', ...]`. | ||
""" | ||
|
||
return torch_xla._XLAC._xla_real_devices() | ||
|
||
|
||
def device_count() -> int: | ||
"""Returns number of addressable devices in the current process.""" | ||
return len(real_devices()) |