[DTensor][XLA] Support XLA backend in distribute_tensor API (pytorch#110275)

This addresses pytorch#92909 and enables XLA backend support for the `distribute_tensor` API.

Test plan: added a unit test case and tested with Cloud TPU. The CI should skip this unless it's an XLA workflow.

Pull Request resolved: pytorch#110275
Approved by: https://github.com/wanchaol, https://github.com/alanwaketan, https://github.com/JackCaoG
1 parent 6411ae2 · commit c5e299f
Showing 4 changed files with 343 additions and 9 deletions.
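For orientation, a minimal usage sketch of what this commit enables, assuming torch_xla is installed and the process sees several XLA devices (e.g. a Cloud TPU slice); the shapes and mesh layout mirror the unit test added below, so this is a hedged sketch rather than final API documentation:

# Hedged sketch only: mirrors the new unit test rather than documenting a final API.
import os
os.environ["XLA_USE_SPMD"] = "1"  # SPMD flag; see the TODO about xr.use_spmd() below

import torch
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard

device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(device_count)))   # 1-D logical device mesh
t = torch.randn(3 * device_count, 3)
xt = distribute_tensor(t, mesh, [Shard(0)])           # returns an XLAShardedTensor
print(type(xt).__name__, xt.global_tensor.size())     # global view stays (3*N, 3)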
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple

import numpy as np
import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.testing._internal.common_utils import run_tests, TestCase


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"
        try:
            import torch_xla  # type:ignore[import]  # noqa: F401
        except ImportError as exc:
            raise unittest.SkipTest("torch_xla is not installed.") from exc
        self.device_type = "xla"
        func(self, *args, **kwargs)  # type: ignore[misc]
        os.environ["XLA_USE_SPMD"] = "0"

    return wrapper


class DTensorXLAIntegrationTest(TestCase):
    @with_xla
    def test_xla_distribute_tensor_1d_shard(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh("xla", list(range(device_count)))
            shard_spec = [Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_1d_replicate(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))
        shard_spec = [Replicate()]

        for requires_grad in [True, False]:
            tensor_to_shard = torch.randn(
                3 * device_count, 3, requires_grad=requires_grad
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
            assert type(dist_tensor).__name__ == "XLAShardedTensor"
            global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
            self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
            local_tensor = dist_tensor.local_shards[0].data
            self.assertEqual(local_tensor.size(), torch.Size([3 * device_count, 3]))
            if requires_grad:
                self.assertTrue(dist_tensor.global_tensor.requires_grad)
                self.assertTrue(dist_tensor.is_leaf)

    @with_xla
    def test_xla_distribute_tensor_2d(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh(
                "xla", np.array(range(device_count)).reshape(2, device_count // 2)
            )
            shard_spec = [Replicate(), Shard(0)]

            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count // 2, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count // 2, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)


if __name__ == "__main__":
    run_tests()
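For concreteness, the shape arithmetic that the 2-D test above asserts, as a standalone sketch (device_count == 8 is an assumed example value; no XLA devices are needed to follow the math):

# Worked example of the sharding layout checked by test_xla_distribute_tensor_2d.
device_count = 8                              # assumed example value
mesh_shape = (2, device_count // 2)           # DeviceMesh built via reshape(2, 4)
global_shape = (3 * device_count // 2, 3)     # (12, 3)
# placements = [Replicate(), Shard(0)]: tensor dim 0 is sharded over mesh dim 1
# (size 4) and replicated over mesh dim 0 (size 2).
local_shape = (global_shape[0] // mesh_shape[1], global_shape[1])
print(local_shape)                            # (3, 3), matching the local_tensor assert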
@@ -0,0 +1,200 @@
import logging
import os
from functools import wraps
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch

import torch.nn as nn
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate

log = logging.getLogger(__name__)

TORCH_XLA_INITIALIZED = False
try:
    import torch_xla.core.xla_model as xm  # type:ignore[import]  # noqa: F401
    import torch_xla.runtime as xr  # type:ignore[import]
    from torch_xla.experimental.xla_sharded_tensor import (  # type:ignore[import]
        XLAShardedTensor,
    )
    from torch_xla.experimental.xla_sharding import (  # type:ignore[import]
        mark_sharding,
        Mesh,
        ShardingType,
    )

    TORCH_XLA_INITIALIZED = True
except ImportError as e:
    log.warning(e.msg)


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
    ) -> None:
        if TORCH_XLA_INITIALIZED:
            # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
            os.environ["XLA_USE_SPMD"] = "1"
            return func(self, *args, **kwargs)  # type: ignore[misc]
        else:
            raise ImportError(
                "torch.distributed._tensor._xla API requires torch_xla package installation."
            )

    return wrapper

@with_xla
def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh":
    """
    Convert DTensor `dt_mesh` to XLAShardedTensor `Mesh`.

    Example (1x4 logical device mesh topology):
    ```
    dt_mesh = DeviceMesh("xla", [[1, 2, 3, 4]])
    dt_mesh.mesh.shape
    >> torch.Size([1, 4])

    mesh = convert_to_xla_mesh(dt_mesh)
    mesh.mesh_shape
    >> [1, 4]
    ```
    """
    assert dt_mesh.size() == xr.global_runtime_device_count()
    return Mesh(
        dt_mesh.mesh.flatten(), tuple(dt_mesh.mesh.size()), dt_mesh.mesh_dim_names
    )

@with_xla
def convert_to_xla_partition_spec(
    tensor: torch.Tensor, placements: Sequence[Placement]
) -> Tuple[Union[Tuple, int, None], ...]:
    """
    Convert DTensor `placements` to XLAShardedTensor `partition_spec`.
    This supports Shard and Replicate placement types.

    Example:
    ```
    # Mesh partitioning, 1/4-th of the input with replicated overlaps.
    # The first input tensor dimension is sharded across the second mesh
    # dimension, and the rest is replicated over the first mesh dimension.
    t = torch.randn(4, 8, 8)
    dt_mesh = DeviceMesh("xla", torch.arange(8).reshape(2,4))
    placements = [Replicate(), Shard(0)]
    my_dtensor = distribute_tensor(t, dt_mesh, placements)

    # `placements = [Replicate(), Shard(0)]` describes sharding per mesh dim,
    # and this is equivalent to `partition_spec = (1, None, None)` which is
    # sharding per input tensor dimension.
    partition_spec = convert_to_xla_partition_spec(t, placements)
    >> (1, None, None)
    ```
    """
    # per tensor dimension sharding
    sharding_spec = [None] * len(tensor.shape)
    for mesh_idx, spec in enumerate(placements):
        if spec.is_shard():  # type:ignore[truthy-function]
            # mesh_idx to tensor_idx (spec.dim)
            tensor_idx = spec.dim  # type:ignore[attr-defined]
            sharding_spec[tensor_idx] = mesh_idx  # type:ignore[call-overload]
        elif spec.is_replicate():
            # the corresponding entry in sharding_spec is already None (replicated)
            continue
        else:
            raise ValueError(f"Unsupported placement type: {type(spec).__name__}")
    return tuple(sharding_spec)  # type:ignore[return-value]

@with_xla
def xla_distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Optional[Sequence[Placement]] = None,
) -> "XLAShardedTensor":
    """
    Distribute a torch.Tensor to the `device_mesh` according to the `placements`
    specified. The rank of `device_mesh` and `placements` must be the same.

    Args:
        tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
            want to shard a tensor on a dimension that is not evenly divisible by
            the number of devices in that mesh dimension, we use `torch.chunk`
            semantics to shard the tensor and scatter the shards.
        device_mesh (:class:`DeviceMesh`): DeviceMesh to distribute the tensor on.
            Unlike `distribute_tensor`, it is required here and may not be None.
        placements (List[:class:`Placement`], optional): the placements that
            describe how to place the tensor on the DeviceMesh; must have the same
            number of elements as `device_mesh.ndim`. If not specified, we will
            by default replicate the tensor across the `device_mesh` from the
            first rank of each dimension of the `device_mesh`.

    Returns:
        An :class:`XLAShardedTensor` object.

    .. note:: We return an XLAShardedTensor with a global view and access to local
        shards. Subsequent ops can be programmed as if on a single device, without
        calling any explicit collective ops. The actual sharded computation on the
        sharding-annotated tensor happens lazily and is transparent to the user.
        In the future, we will introduce a new DTensor type for this
        (single-controller) programming model and return it instead.
    """
    # device_mesh is not optional in xla_distribute_tensor
    dt_mesh = device_mesh
    assert dt_mesh.device_type == "xla"

    # convert to XLA device mesh
    xla_mesh = convert_to_xla_mesh(dt_mesh)
    assert xla_mesh.mesh_shape == tuple(dt_mesh.mesh.size())

    # convert the tensor to the corresponding device type if it's not already there
    if not tensor.is_meta:
        tensor = tensor.to(dt_mesh.device_type)
    # set default placements to replicated if not specified
    if placements is None:
        placements = [Replicate() for _ in range(dt_mesh.ndim)]
    assert len(placements) == dt_mesh.ndim, (
        "`placements` must have the same length as `device_mesh.ndim`! "
        f"Found placements length: {len(placements)}, and device_mesh.ndim: {dt_mesh.ndim}."
    )
    # convert placements to xla partition spec
    partition_spec = convert_to_xla_partition_spec(tensor, placements)
    assert len(tensor.shape) == len(partition_spec), (
        "`partition_spec` from `placements` must have the same length as the rank of `tensor`! "
        f"Found tensor shape length: {len(tensor.shape)}, and partition_spec length: {len(partition_spec)}."
    )

    global_tensor = tensor
    if type(tensor).__name__ == "DTensor":
        raise ValueError(
            "Cannot distribute a DTensor with local tensor on xla devices. "
            "The input tensor must be global."
        )
    if type(tensor).__name__ == "XLAShardedTensor":
        sharding_type = tensor.sharding_type  # type:ignore[attr-defined]
        assert sharding_type is None or sharding_type == ShardingType.REPLICATED, (
            "XLAShardedTensor `tensor` is already annotated with non-replication sharding. "
            "Clear the existing sharding annotation first, by calling the "
            "torch_xla.experimental.xla_sharding.clear_sharding API."
        )
        global_tensor = tensor.global_tensor  # type:ignore[attr-defined]
        assert global_tensor is not None, "the global tensor to distribute should not be None"

    # Annotates sharding and returns an XLAShardedTensor
    xla_tensor = mark_sharding(global_tensor, xla_mesh, partition_spec)
    return xla_tensor

@with_xla
def xla_distribute_module(
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[..., None]] = None,
    output_fn: Optional[Callable[..., None]] = None,
) -> nn.Module:
    raise NotImplementedError
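The new helpers can also be exercised directly. Below is a sketch assuming torch_xla is installed, 8 XLA devices are visible, and this module is importable as torch.distributed._tensor._xla (the path named in the ImportError message above); it mirrors the docstring examples rather than adding behavior:

# Hedged sketch: direct use of the helpers added in this file, assuming 8 XLA devices.
import os
os.environ["XLA_USE_SPMD"] = "1"

import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
from torch.distributed._tensor._xla import (
    convert_to_xla_partition_spec,
    xla_distribute_tensor,
)

t = torch.randn(4, 8, 8)
dt_mesh = DeviceMesh("xla", torch.arange(8).reshape(2, 4))
placements = [Replicate(), Shard(0)]

# Placement list -> per-tensor-dim partition spec, as in the docstring: (1, None, None).
print(convert_to_xla_partition_spec(t, placements))

# Annotates sharding via mark_sharding and returns an XLAShardedTensor.
xt = xla_distribute_tensor(t, dt_mesh, placements)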