[DTensor][XLA] Support Xla backend in distribute_tensor API (pytorch#110275)

This addresses pytorch#92909 and enables XLA backend support for the `distribute_tensor` API.

Test plan: added a unit test case and tested on Cloud TPU. CI should skip this test unless it is an XLA workflow.
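
For context, a minimal usage sketch based on the new unit test below (not part of this change); it assumes `torch_xla` is installed, XLA devices are available, and the flag-based `XLA_USE_SPMD=1` SPMD mode used by the test:

```python
import os
import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Shard

os.environ["XLA_USE_SPMD"] = "1"  # flag-based SPMD mode, as in the unit test
import torch_xla.runtime as xr  # requires torch_xla

device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(device_count)))  # 1-D logical mesh over all XLA devices
t = torch.randn(3 * device_count, 3)
sharded = distribute_tensor(t, mesh, [Shard(0)])  # returns an XLAShardedTensor, not a DTensor
print(type(sharded).__name__, sharded.global_tensor.size())
```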

Pull Request resolved: pytorch#110275
Approved by: https://github.com/wanchaol, https://github.com/alanwaketan, https://github.com/JackCaoG
yeounoh authored and Skylion007 committed Nov 14, 2023
1 parent 6411ae2 commit c5e299f
Showing 4 changed files with 343 additions and 9 deletions.
120 changes: 120 additions & 0 deletions test/distributed/_tensor/test_xla_integration.py
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import os
import unittest
from functools import wraps
from typing import Any, Callable, Dict, Tuple

import numpy as np
import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard
from torch.testing._internal.common_utils import run_tests, TestCase


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
assert func is not None

@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
) -> None:
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
os.environ["XLA_USE_SPMD"] = "1"
try:
import torch_xla # type:ignore[import] # noqa: F401
except ImportError as exc:
raise unittest.SkipTest("torch_xla is not installed.") from exc
self.device_type = "xla"
func(self, *args, **kwargs) # type: ignore[misc]
os.environ["XLA_USE_SPMD"] = "0"

return wrapper


class DTensorXLAIntegrationTest(TestCase):
@with_xla
def test_xla_distribute_tensor_1d_shard(self):
import torch_xla.runtime as xr # type:ignore[import]

device_count = xr.global_runtime_device_count()
if device_count > 1:
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]

for requires_grad in [True, False]:
tensor_to_shard = torch.randn(
3 * device_count, 3, requires_grad=requires_grad
)
dist_tensor = distribute_tensor(
tensor_to_shard, device_mesh, shard_spec
)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
self.assertEqual(
global_tensor.size(), torch.Size([3 * device_count, 3])
)
local_tensor = dist_tensor.local_shards[0].data
self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
if requires_grad:
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)

@with_xla
def test_xla_distribute_tensor_1d_replicate(self):
import torch_xla.runtime as xr # type:ignore[import]

device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Replicate()]

for requires_grad in [True, False]:
tensor_to_shard = torch.randn(
3 * device_count, 3, requires_grad=requires_grad
)
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
local_tensor = dist_tensor.local_shards[0].data
self.assertEqual(local_tensor.size(), torch.Size([3 * device_count, 3]))
if requires_grad:
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)

@with_xla
def test_xla_distribute_tensor_2d(self):
import torch_xla.runtime as xr # type:ignore[import]

device_count = xr.global_runtime_device_count()
if device_count > 1:
device_mesh = DeviceMesh(
"xla", np.array(range(device_count)).reshape(2, device_count // 2)
)
shard_spec = [Replicate(), Shard(0)]

for requires_grad in [True, False]:
tensor_to_shard = torch.randn(
3 * device_count // 2, 3, requires_grad=requires_grad
)
dist_tensor = distribute_tensor(
tensor_to_shard, device_mesh, shard_spec
)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
self.assertEqual(
global_tensor.size(), torch.Size([3 * device_count // 2, 3])
)
local_tensor = dist_tensor.local_shards[0].data
self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
if requires_grad:
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)


if __name__ == "__main__":
run_tests()
200 changes: 200 additions & 0 deletions torch/distributed/_tensor/_xla.py
@@ -0,0 +1,200 @@
import logging
import os
from functools import wraps
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch

import torch.nn as nn
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate

log = logging.getLogger(__name__)

TORCH_XLA_INITIALIZED = False
try:
import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401
import torch_xla.runtime as xr # type:ignore[import]
from torch_xla.experimental.xla_sharded_tensor import ( # type:ignore[import]
XLAShardedTensor,
)
from torch_xla.experimental.xla_sharding import ( # type:ignore[import]
mark_sharding,
Mesh,
ShardingType,
)

TORCH_XLA_INITIALIZED = True
except ImportError as e:
log.warning(e.msg)


# wrapper to check torch_xla availability and enable XLA SPMD mode
def with_xla(func: Callable) -> Callable:
assert func is not None

@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
) -> None:
if TORCH_XLA_INITIALIZED:
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
os.environ["XLA_USE_SPMD"] = "1"
return func(self, *args, **kwargs) # type: ignore[misc]
else:
raise ImportError(
"torch.distributed._tensor._xla API requires torch_xla package installation."
)

return wrapper


@with_xla
def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh":
"""
    Convert DTensor `dt_mesh` to XLAShardedTensor `Mesh`.
Example (1x4 logical device mesh topology):
```
dt_mesh = DeviceMesh("xla", [[1, 2, 3, 4]])
dt_mesh.mesh.shape
>> torch.Size([1, 4])
mesh = convert_to_xla_mesh(dt_mesh)
    mesh.mesh_shape
    >> (1, 4)
```
"""
assert dt_mesh.size() == xr.global_runtime_device_count()
return Mesh(
dt_mesh.mesh.flatten(), tuple(dt_mesh.mesh.size()), dt_mesh.mesh_dim_names
)


@with_xla
def convert_to_xla_partition_spec(
tensor: torch.Tensor, placements: Sequence[Placement]
) -> Tuple[Union[Tuple, int, None]]:
"""
    Convert DTensor `placements` to XLAShardedTensor `partition_spec`.
    This supports Shard and Replicate placement types.
Example:
```
# Mesh partitioning, 1/4-th of the input with replicated overlaps.
# The first input tensor dimension is sharded across the second mesh
# dimension, and the rest is replicated over the first mesh dimension.
t = torch.randn(4, 8, 8)
dt_mesh = DeviceMesh("xla", torch.arange(8).reshape(2,4))
placements = [Replicate(), Shard(0)]
my_dtensor = distribute_tensor(t, dt_mesh, placements)
# `placements = [Replicate(), Shard(0)]` describes sharding per mesh dim,
# and this is equivalent to `partition_spec = (1, None, None)` which is
# sharding per input tensor dimension.
partition_spec = convert_to_xla_partition_spec(t, placements)
>> (1, None, None)
```
"""
# per tensor dimension sharding
sharding_spec = [None] * len(tensor.shape)
for mesh_idx, spec in enumerate(placements):
if spec.is_shard(): # type:ignore[truthy-function]
# mesh_idx to tensor_idx (spec.dim)
tensor_idx = spec.dim # type:ignore[attr-defined]
sharding_spec[tensor_idx] = mesh_idx # type:ignore[call-overload]
elif spec.is_replicate():
# spec.dim is already set to None by default
continue
else:
raise ValueError(f"Unsupported placement type: {type(spec).__name__}")
return tuple(sharding_spec) # type:ignore[return-value]


@with_xla
def xla_distribute_tensor(
tensor: torch.Tensor,
device_mesh: DeviceMesh,
placements: Optional[Sequence[Placement]] = None,
) -> "XLAShardedTensor":
"""
Distribute a torch.Tensor to the `device_mesh` according to the `placements`
specified. The rank of `device_mesh` and `placements` must be the same.
Args:
tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
want to shard a tensor on a dimension that is not evenly divisible by
the number of devices in that mesh dimension, we use `torch.chunk`
            semantics to shard the tensor and scatter the shards.
device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
tensor, if not specified, must be called under a DeviceMesh context
manager, default: None
        placements (List[:class:`Placement`], optional): the placements that
            describe how to place the tensor on the DeviceMesh; they must have the same
number of elements as `device_mesh.ndim`. If not specified, we will
by default replicate the tensor across the `device_mesh` from the
first rank of each dimension of the `device_mesh`.
Returns:
        An :class:`XLAShardedTensor` object.
    .. note:: We return an XLAShardedTensor with a global view and access to local shards.
        Subsequent ops can be programmed as if on a single device, without calling any
        explicit collective ops. The actual sharded computation on the sharding-annotated
        tensor happens lazily and is transparent to the user. In the future, we will
        introduce a new DTensor type for this kind of programming mode (single-controller)
        and return it instead.
"""
# device_mesh is not optional in xla_distribute_tensor
dt_mesh = device_mesh
assert dt_mesh.device_type == "xla"

# convert to XLA device mesh
xla_mesh = convert_to_xla_mesh(dt_mesh)
assert xla_mesh.mesh_shape == tuple(dt_mesh.mesh.size())

# convert tensor to the corresponding device type if it's not in that device type
if not tensor.is_meta:
tensor = tensor.to(dt_mesh.device_type)
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(dt_mesh.ndim)]
    assert len(placements) == dt_mesh.ndim, (
        "`placements` must have the same length as `device_mesh.ndim`! "
        f"Found placements length: {len(placements)}, and device_mesh.ndim: {dt_mesh.ndim}."
    )
# convert placements to xla partition spec
partition_spec = convert_to_xla_partition_spec(tensor, placements)
    assert len(tensor.shape) == len(partition_spec), (
        "`partition_spec` from `placements` must have the same length as `tensor.shape`! "
        f"Found tensor shape length: {len(tensor.shape)}, and partition_spec length: {len(partition_spec)}."
    )

global_tensor = tensor
if type(tensor).__name__ == "DTensor":
raise ValueError(
"Cannot distribute a DTensor with local tensor on xla devices."
"The input tensor must be global."
)
if type(tensor).__name__ == "XLAShardedTensor":
sharding_type = tensor.sharding_type # type:ignore[attr-defined]
        assert sharding_type is None or sharding_type == ShardingType.REPLICATED, (
            "XLAShardedTensor `tensor` is already annotated with non-replication sharding. "
            "Clear the existing sharding annotation first by calling the "
            "torch_xla.experimental.xla_sharding.clear_sharding API."
        )
global_tensor = tensor.global_tensor # type:ignore[attr-defined]
    assert global_tensor is not None, "The global tensor to distribute must not be None."

# Annotates sharding and returns an XLAShardedTensor
xla_tensor = mark_sharding(global_tensor, xla_mesh, partition_spec)
return xla_tensor


@with_xla
def xla_distribute_module(
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
input_fn: Optional[Callable[..., None]] = None,
output_fn: Optional[Callable[..., None]] = None,
) -> nn.Module:
raise NotImplementedError
16 changes: 13 additions & 3 deletions torch/distributed/_tensor/api.py
@@ -9,6 +9,7 @@
import torch.nn as nn
from torch.distributed._tensor._collective_utils import mesh_broadcast
from torch.distributed._tensor._utils import compute_global_tensor_info
from torch.distributed._tensor._xla import xla_distribute_tensor
from torch.distributed._tensor.device_mesh import DeviceMesh, mesh_resources
from torch.distributed._tensor.placement_types import (
DTensorSpec,
@@ -237,7 +238,6 @@ def __tensor_unflatten__(inner_tensors, flatten_spec):
assert (
flatten_spec is not None
), "Expecting spec to be not None from `__tensor_flatten__` return value!"
assert isinstance(inner_tensors, dict) and len(inner_tensors) == 1
local_tensor = inner_tensors["_local_tensor"]
spec, requires_grad = flatten_spec
return DTensor(
@@ -451,14 +451,25 @@ def distribute_tensor(
first rank of each dimension of the `device_mesh`.
Returns:
A :class:`DTensor` object
A :class:`DTensor` or `XLAShardedTensor` object.
Note:
        When initializing the DeviceMesh with the `xla` device_type, `distribute_tensor`
        returns an `XLAShardedTensor` instead; see [link](https://github.com/pytorch/pytorch/issues/92909)
        for more details. The XLA integration is experimental and subject to change.
"""

torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")

# get default device mesh if there's nothing specified
device_mesh = device_mesh or mesh_resources.get_current_mesh()
device_type = device_mesh.device_type
if device_type == "xla":
# call PyTorch/XLA SPMD for `xla` backend type device mesh.
# This returns XLAShardedTensor
return xla_distribute_tensor(
tensor, device_mesh, placements
) # type:ignore[return-value]

# instantiate a RNG tracker if haven't. By default DTensor uses an
# OffsetBasedRNGTracker to perform random operators.
@@ -485,7 +496,6 @@ def distribute_tensor(
f"`placements` must have the same length as `device_mesh.ndim`! "
f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
)

if isinstance(tensor, DTensor):
# if the tensor is already a DTensor, we just need to check if the
# device mesh and placements are the same
16 changes: 10 additions & 6 deletions torch/distributed/_tensor/device_mesh.py
@@ -174,12 +174,16 @@ def __init__(
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._hash = hash((self._flatten_mesh_list, self.mesh.shape))
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
# process (we need to know if the current global rank is in the mesh or not)
self._get_or_create_default_group()
if _init_process_groups:
self._init_process_groups(_validate_mesh)

# Skip process group initialization if xla device.
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
# process (we need to know if the current global rank is in the mesh or not).
self._get_or_create_default_group()
if _init_process_groups:
self._init_process_groups(_validate_mesh)

def _get_or_create_default_group(self):
default_initialized = is_initialized()
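
As a brief illustration of the `device_mesh.py` change above (a hedged sketch, not from the PR; assumes `torch_xla` is installed and XLA devices are available): an `xla`-type `DeviceMesh` can now be constructed without initializing a default process group, which is what the new unit test relies on.

```python
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh

# No dist.init_process_group() has been called here; with the change above,
# an "xla" DeviceMesh skips default process-group creation entirely.
print(dist.is_initialized())  # False
mesh = DeviceMesh("xla", list(range(4)))  # assumes 4 XLA devices are available
print(mesh.device_type, mesh.size())
```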
