
Commit: linter fix
yeounoh committed Nov 8, 2023
1 parent 0842869 commit ccb2c4c
Showing 3 changed files with 176 additions and 113 deletions.
81 changes: 81 additions & 0 deletions test/spmd/test_dtensor_integration.py
@@ -0,0 +1,81 @@
import os
import sys

import torch
from torch import nn
import torch.optim as optim
from torch.distributed._tensor import DeviceMesh, Shard
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.experimental.spmd import xla_distribute_tensor

import unittest

import test_xla_sharding_base


class DTensorIntegrationTest(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
xr.use_spmd()
super().setUpClass()

def test_xla_distribute_tensor(self):
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]

for requires_grad in [True, False]:
tensor_to_shard = torch.randn(
3 * device_count,
3,
requires_grad=requires_grad,
device=xm.xla_device())
dist_tensor = xla_distribute_tensor(tensor_to_shard, device_mesh,
shard_spec)
# TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
assert type(dist_tensor).__name__ == "XLAShardedTensor"
assert len(dist_tensor.sharding_spec) > 0

global_tensor = dist_tensor.global_tensor # type:ignore[attr-defined]
self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
local_tensor = dist_tensor.local_shards[0].data
self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
if requires_grad:
self.assertTrue(dist_tensor.global_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)

def test_optimizer_step_with_sharding(self):
# Use simple linear model to test model parameter sharding
model = self.SimpleLinear().to(xm.xla_device())

# Running the same mark_sharding test with xla_distribute_tensor instead
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))
shard_spec = [Shard(0)]
xla_distribute_tensor(model.fc1.weight, device_mesh, shard_spec)
sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight)

model.train()
optimizer = optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(128, 128).to(xm.xla_device())
target = torch.zeros(128).to(xm.xla_device())
loss_fn = nn.CrossEntropyLoss()
for i in range(3):
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
      # Sharding is persisted across mark_step calls; test that the sharded
      # computation can repeat more than once without crashing.
self.assertEqual(sharding_spec,
torch_xla._XLAC._get_xla_sharding_spec(model.fc1.weight))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
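
For context only (not part of the commit): a minimal standalone sketch of the xla_distribute_tensor flow that the new test exercises, assuming torch_xla is installed, SPMD mode is available, and the process can see at least one XLA device.

```python
# Minimal sketch mirroring test_xla_distribute_tensor above.
import torch
from torch.distributed._tensor import DeviceMesh, Shard

import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.experimental.spmd import xla_distribute_tensor

xr.use_spmd()  # enable SPMD mode, as the test's setUpClass does

device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(device_count)))  # 1-D mesh over all devices

# Shard dimension 0 across the mesh; the global shape stays the same,
# while each device holds a [3, 3] local shard.
t = torch.randn(3 * device_count, 3, device=xm.xla_device())
sharded = xla_distribute_tensor(t, mesh, [Shard(0)])

print(type(sharded).__name__)               # XLAShardedTensor
print(sharded.global_tensor.size())         # torch.Size([3 * device_count, 3])
print(sharded.local_shards[0].data.size())  # torch.Size([3, 3])
```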
24 changes: 5 additions & 19 deletions torch_xla/experimental/spmd/__init__.py
@@ -1,25 +1,11 @@
from .xla_sharded_tensor import XLAShard, XLAShardedTensor
from .xla_sharding import (Mesh,
HybridMesh,
ShardingType,
ShardingSpec,
XLAPatchedLinear,
mark_sharding,
clear_sharding,
from .xla_sharding import (Mesh, HybridMesh, ShardingType, ShardingSpec,
XLAPatchedLinear, mark_sharding, clear_sharding,
wrap_if_sharded)
from .api import xla_distribute_tensor, xla_distribute_module

__all__ = [
"XLAShard",
"XLAShardedTensor",
"Mesh",
"HybridMesh",
"ShardingType",
"ShardingSpec",
"XLAPatchedLinear",
"mark_sharding",
"clear_sharding",
"wrap_if_sharded",
"xla_distribute_tensor",
"xla_distribute_module"
"XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
"ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
"wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module"
]
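
As a quick reference (not part of the commit), the reflowed __all__ keeps the same public import surface; a minimal smoke check, assuming torch_xla is installed:

```python
# Every name listed in __all__ remains importable from the package root.
import torch_xla.experimental.spmd as xs

for name in ("XLAShard", "XLAShardedTensor", "Mesh", "HybridMesh", "ShardingType",
             "ShardingSpec", "XLAPatchedLinear", "mark_sharding", "clear_sharding",
             "wrap_if_sharded", "xla_distribute_tensor", "xla_distribute_module"):
  assert hasattr(xs, name), f"missing public symbol: {name}"
```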
184 changes: 90 additions & 94 deletions torch_xla/experimental/spmd/api.py
@@ -13,45 +13,44 @@

TORCH_XLA_INITIALIZED = False
try:
import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401
import torch_xla.runtime as xr # type:ignore[import]
from torch_xla.experimental.spmd import ( # type:ignore[import]
XLAShardedTensor,
)
from torch_xla.experimental.spmd import ( # type:ignore[import]
mark_sharding,
Mesh,
ShardingType,
)

TORCH_XLA_INITIALIZED = True
import torch_xla.core.xla_model as xm # type:ignore[import] # noqa: F401
import torch_xla.runtime as xr # type:ignore[import]
from torch_xla.experimental.spmd import ( # type:ignore[import]
XLAShardedTensor,)
from torch_xla.experimental.spmd import ( # type:ignore[import]
mark_sharding, Mesh, ShardingType,
)

TORCH_XLA_INITIALIZED = True
except ImportError as e:
log.warning(e.msg)
log.warning(e.msg)


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
assert func is not None

@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
) -> None:
if TORCH_XLA_INITIALIZED:
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
os.environ["XLA_USE_SPMD"] = "1"
return func(self, *args, **kwargs) # type: ignore[misc]
else:
raise ImportError(
"torch.distributed._tensor._xla API requires torch_xla package installation."
)

return wrapper
assert func is not None

@wraps(func) # pyre-ignore[6]
def wrapper(
self,
*args: Tuple[object],
**kwargs: Dict[str, Any] # type: ignore[misc]
) -> None:
if TORCH_XLA_INITIALIZED:
# TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
os.environ["XLA_USE_SPMD"] = "1"
return func(self, *args, **kwargs) # type: ignore[misc]
else:
raise ImportError(
"torch.distributed._tensor._xla API requires torch_xla package installation."
)

return wrapper


@with_xla
def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh":
"""
"""
  Convert DTensor `dt_mesh` to XLAShardedTensor `Mesh`.
Example (1x4 logical device mesh topology):
@@ -65,17 +64,16 @@ def convert_to_xla_mesh(dt_mesh: DeviceMesh) -> "Mesh":
>> [1, 4]
```
"""
assert dt_mesh.size() == xr.global_runtime_device_count()
return Mesh(
dt_mesh.mesh.flatten(), tuple(dt_mesh.mesh.size()), dt_mesh.mesh_dim_names
)
assert dt_mesh.size() == xr.global_runtime_device_count()
return Mesh(dt_mesh.mesh.flatten(), tuple(dt_mesh.mesh.size()),
dt_mesh.mesh_dim_names)


@with_xla
def convert_to_xla_partition_spec(
tensor: torch.Tensor, placements: Sequence[Placement]
) -> Tuple[Union[Tuple, int, None]]:
"""
tensor: torch.Tensor,
placements: Sequence[Placement]) -> Tuple[Union[Tuple, int, None]]:
"""
  Convert DTensor `placements` to XLAShardedTensor `partition_spec`.
This supports Shard and Replicate Placement types.
@@ -96,19 +94,19 @@ def convert_to_xla_partition_spec(
>> (1, None, None)
```
"""
# per tensor dimension sharding
sharding_spec = [None] * len(tensor.shape)
for mesh_idx, spec in enumerate(placements):
if spec.is_shard(): # type:ignore[truthy-function]
# mesh_idx to tensor_idx (spec.dim)
tensor_idx = spec.dim # type:ignore[attr-defined]
sharding_spec[tensor_idx] = mesh_idx # type:ignore[call-overload]
elif spec.is_replicate():
# spec.dim is already set to None by default
continue
else:
raise ValueError(f"Unsupported placement type: {type(spec).__name__}")
return tuple(sharding_spec) # type:ignore[return-value]
# per tensor dimension sharding
sharding_spec = [None] * len(tensor.shape)
for mesh_idx, spec in enumerate(placements):
if spec.is_shard(): # type:ignore[truthy-function]
# mesh_idx to tensor_idx (spec.dim)
tensor_idx = spec.dim # type:ignore[attr-defined]
sharding_spec[tensor_idx] = mesh_idx # type:ignore[call-overload]
elif spec.is_replicate():
# spec.dim is already set to None by default
continue
else:
raise ValueError(f"Unsupported placement type: {type(spec).__name__}")
return tuple(sharding_spec) # type:ignore[return-value]


@with_xla
@@ -117,7 +115,7 @@ def xla_distribute_tensor(
device_mesh: DeviceMesh,
placements: Optional[Sequence[Placement]] = None,
) -> "XLAShardedTensor":
"""
"""
Distribute a torch.Tensor to the `device_mesh` according to the `placements`
specified. The rank of `device_mesh` and `placements` must be the same.
@@ -144,49 +142,47 @@
happens lazily, is transparent to the user. In the future, we will introduce
a new DTensor type for this kind of programming-mode (single-controller) and return.
"""
# device_mesh is not optional in xla_distribute_tensor
dt_mesh = device_mesh
assert dt_mesh.device_type == "xla"

# convert to XLA device mesh
xla_mesh = convert_to_xla_mesh(dt_mesh)
assert xla_mesh.mesh_shape == tuple(dt_mesh.mesh.size())

# convert tensor to the corresponding device type if it's not in that device type
if not tensor.is_meta:
tensor = tensor.to(dt_mesh.device_type)
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(dt_mesh.ndim)]
# device_mesh is not optional in xla_distribute_tensor
dt_mesh = device_mesh
assert dt_mesh.device_type == "xla"

# convert to XLA device mesh
xla_mesh = convert_to_xla_mesh(dt_mesh)
assert xla_mesh.mesh_shape == tuple(dt_mesh.mesh.size())

# convert tensor to the corresponding device type if it's not in that device type
if not tensor.is_meta:
tensor = tensor.to(dt_mesh.device_type)
# set default placements to replicated if not specified
if placements is None:
placements = [Replicate() for _ in range(dt_mesh.ndim)]
assert (len(placements) == dt_mesh.ndim
), "`placements` must have the same length as `device_mesh.ndim`! "
f"Found placements length: {len(placements)}, and device_mesh.ndim: {dt_mesh.ndim}."
# convert placements to xla partition spec
partition_spec = convert_to_xla_partition_spec(tensor, placements)
assert len(tensor.shape) == len(
partition_spec
), "`partition_spec` from `placements` must have the same length as `tensor.length`! "
f"Found tensor shape length: {len(tensor.shape)}, and partition_spec length: {len(partition_spec)}."

global_tensor = tensor
if type(tensor).__name__ == "DTensor":
raise ValueError(
"Cannot distribute a DTensor with local tensor on xla devices."
"The input tensor must be global.")
if type(tensor).__name__ == "XLAShardedTensor":
sharding_type = tensor.sharding_type # type:ignore[attr-defined]
assert (
len(placements) == dt_mesh.ndim
), "`placements` must have the same length as `device_mesh.ndim`! "
f"Found placements length: {len(placements)}, and device_mesh.ndim: {dt_mesh.ndim}."
# convert placements to xla partition spec
partition_spec = convert_to_xla_partition_spec(tensor, placements)
assert len(tensor.shape) == len(
partition_spec
), "`partition_spec` from `placements` must have the same length as `tensor.length`! "
f"Found tensor shape length: {len(tensor.shape)}, and partition_spec length: {len(partition_spec)}."

global_tensor = tensor
if type(tensor).__name__ == "DTensor":
raise ValueError(
"Cannot distribute a DTensor with local tensor on xla devices."
"The input tensor must be global."
)
if type(tensor).__name__ == "XLAShardedTensor":
sharding_type = tensor.sharding_type # type:ignore[attr-defined]
assert (
sharding_type is None or sharding_type == ShardingType.REPLICATED
), "XLAShardedTensor `tensor` is already annotated with non-replication sharding. "
"Clear the existing sharding annotation first, by callling torch_xla.experimental.spmd.clear_sharding API."
global_tensor = tensor.global_tensor # type:ignore[attr-defined]
assert global_tensor is not None, "distributing a tensor should not be None"

# Annotates sharding and returns an XLAShardedTensor
xla_tensor = mark_sharding(global_tensor, xla_mesh, partition_spec)
return xla_tensor
sharding_type is None or sharding_type == ShardingType.REPLICATED
), "XLAShardedTensor `tensor` is already annotated with non-replication sharding. "
"Clear the existing sharding annotation first, by callling torch_xla.experimental.spmd.clear_sharding API."
global_tensor = tensor.global_tensor # type:ignore[attr-defined]
assert global_tensor is not None, "distributing a tensor should not be None"

# Annotates sharding and returns an XLAShardedTensor
xla_tensor = mark_sharding(global_tensor, xla_mesh, partition_spec)
return xla_tensor


@with_xla
@@ -197,4 +193,4 @@ def xla_distribute_module(
input_fn: Optional[Callable[..., None]] = None,
output_fn: Optional[Callable[..., None]] = None,
) -> nn.Module:
raise NotImplementedError
raise NotImplementedError
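
For readers without XLA devices at hand, a standalone sketch (not part of the commit) that mirrors the placement-to-partition-spec mapping implemented by convert_to_xla_partition_spec above. The helper name placements_to_partition_spec is hypothetical and used only for illustration:

```python
# Shard(dim) on mesh axis i maps tensor dim `dim` to i; Replicate() leaves None.
from typing import Optional, Sequence, Tuple

import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed._tensor.placement_types import Placement


def placements_to_partition_spec(
    tensor: torch.Tensor,
    placements: Sequence[Placement]) -> Tuple[Optional[int], ...]:
  spec: list = [None] * len(tensor.shape)
  for mesh_idx, placement in enumerate(placements):
    if placement.is_shard():
      spec[placement.dim] = mesh_idx  # type:ignore[attr-defined]
    elif placement.is_replicate():
      continue
    else:
      raise ValueError(f"Unsupported placement type: {type(placement).__name__}")
  return tuple(spec)


# Rank-3 tensor on a 1-D mesh: shard dim 0 -> (0, None, None); replicate -> all None.
t = torch.empty(8, 4, 2)
assert placements_to_partition_spec(t, [Shard(0)]) == (0, None, None)
assert placements_to_partition_spec(t, [Replicate()]) == (None, None, None)
```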
