[SPMD][DTensor] introduce xla_distribute_module for DTensor integration
yeounoh committed Mar 7, 2024
1 parent 6bcf4fd commit a157894
Showing 2 changed files with 107 additions and 2 deletions.
36 changes: 35 additions & 1 deletion test/spmd/test_dtensor_integration.py
@@ -8,7 +8,7 @@
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd import xla_distribute_tensor
from torch_xla.distributed.spmd import xla_distribute_tensor, xla_distribute_module

import unittest

@@ -47,6 +47,40 @@ def test_xla_distribute_tensor(self):
    self.assertTrue(dist_tensor.global_tensor.requires_grad)
    self.assertTrue(dist_tensor.is_leaf)

  def test_xla_distribute_module(self):
    model = self.SimpleLinear().to(xm.xla_device())

    device_count = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(device_count)))

    def shard_params(mod_name, mod, mesh):
      shard_spec = [Shard(0)]
      # annotate fc1 and fc2
      if isinstance(mod, nn.Linear):
        for name, param in mod.named_parameters():
          dist_param = xla_distribute_tensor(param, mesh, shard_spec)
          mod.register_parameter(name, dist_param)

    sharded_model = xla_distribute_module(model, device_mesh, shard_params)
    self.assertTrue(
        torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != "")
    self.assertTrue(
        torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != "")

    sharded_model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    data = torch.randn(128, 128).to(xm.xla_device())
    target = torch.zeros(128).to(xm.xla_device())
    loss_fn = nn.CrossEntropyLoss()
    for i in range(3):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      optimizer.step()
      xm.mark_step()


  def test_optimizer_step_with_sharding(self):
    # Use simple linear model to test model parameter sharding
    model = self.SimpleLinear().to(xm.xla_device())
73 changes: 72 additions & 1 deletion torch_xla/distributed/spmd/api.py
@@ -1,4 +1,6 @@
import logging
import warnings
import inspect
import os
from functools import wraps
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
@@ -179,4 +181,73 @@ def xla_distribute_module(
    input_fn: Optional[Callable[..., None]] = None,
    output_fn: Optional[Callable[..., None]] = None,
) -> nn.Module:
  raise NotImplementedError
"""
This function annotates all module parameters for auto-partitioning with
PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters
according to the `partition_fn` specified. It could also control the input or
output of the module by specifying the `input_fn` and `output_fn`. (i.e. convert
the input to :class:`XLAShardedTensor`, convert the output back to torch.Tensor)
Args:
module (:class:`nn.Module`): user module to be partitioned.
device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
partition_fn (Callable): the function to partition parameters (i.e. shard certain
parameters across the `device_mesh`). If `partition_fn` is not specified,
by default we replicate all module parameters of `module` across the mesh.
input_fn (Callable): specify the input distribution, i.e. could control how the
input of the module is sharded. `input_fn` will be installed as a module
`forward_pre_hook` (pre forward hook).
output_fn (Callable): specify the output distribution, i.e. could control how the
output is sharded, or convert it back to torch.Tensor. output_fn will be
installed as a module `forward_hook` (post forward hook).
  Returns:
    A module whose parameters/buffers have been annotated or wrapped as
    :class:`XLAShardedTensor`.
  """

  if partition_fn:
    # apply partition_fn to submodules
    for name, submod in module.named_modules():
      partition_fn(name, submod, device_mesh)
  # non-partitioned (unannotated) submodules and parameters are implicitly replicated

  # register input_fn as module forward pre hook
  if input_fn is not None:
    # check the input_fn signature
    num_args = len(inspect.signature(input_fn).parameters)
    if num_args == 2:
      # input_fn only takes in inputs and device mesh
      warnings.warn(
          "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
          "please use input_fn that takes in (module, inputs, device_mesh) instead!",
      )
      module.register_forward_pre_hook(
          lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
    elif num_args == 3:
      # input_fn takes in module, inputs, device mesh
      module.register_forward_pre_hook(
          lambda mod, inputs: input_fn(mod, inputs, device_mesh))
    else:
      raise ValueError(
          f"input_fn should take in 3 arguments, but got {num_args} arguments!")

  # register output_fn as module forward hook
  if output_fn is not None:
    num_args = len(inspect.signature(output_fn).parameters)
    if num_args == 2:
      # output_fn only takes in outputs and device mesh
      warnings.warn(
          "Deprecating output_fn that takes two arguments (outputs, device_mesh), "
          "please use output_fn that takes in (module, outputs, device_mesh) instead!",
      )
      module.register_forward_hook(
          lambda mod, inputs, outputs: output_fn(outputs, device_mesh)
      )  # type: ignore[call-arg]
    elif num_args == 3:
      module.register_forward_hook(
          lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
    else:
      raise ValueError(
          f"output_fn should take in 3 arguments, but got {num_args} arguments!"
      )

  return module
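
For reference, here is a minimal usage sketch of the new API (not part of this commit). It assumes DeviceMesh and Shard are imported from torch.distributed._tensor as in the test file above, that SPMD mode is enabled via xr.use_spmd(), and that the helper names shard_params and shard_input are purely illustrative.

# Minimal usage sketch (not part of this commit): shard a toy linear layer with
# xla_distribute_module and register a new-style, three-argument input_fn.
# shard_params and shard_input are illustrative names, not part of the API.
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, Shard

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import xla_distribute_module, xla_distribute_tensor

xr.use_spmd()  # enable PyTorch/XLA SPMD mode
mesh = DeviceMesh("xla", list(range(xr.global_runtime_device_count())))
model = nn.Linear(128, 64).to(xm.xla_device())


def shard_params(mod_name, mod, mesh):
  # partition_fn: annotate dim 0 of every nn.Linear parameter for sharding.
  # xla_distribute_tensor annotates the underlying XLA tensor in place, so
  # parameters left unannotated stay implicitly replicated by the partitioner.
  if isinstance(mod, nn.Linear):
    for _, param in mod.named_parameters():
      xla_distribute_tensor(param, mesh, [Shard(0)])


def shard_input(mod, inputs, mesh):
  # New-style input_fn with the (module, inputs, device_mesh) signature; it is
  # installed as a forward pre-hook. Annotate the batch dimension of the input;
  # returning None leaves the original inputs untouched.
  xla_distribute_tensor(inputs[0], mesh, [Shard(0)])


model = xla_distribute_module(
    model, mesh, partition_fn=shard_params, input_fn=shard_input)

output = model(torch.randn(128, 128).to(xm.xla_device()))
xm.mark_step()
# Both the weight and the input should now carry a non-empty sharding spec.
print(torch_xla._XLAC._get_xla_sharding_spec(model.weight))

Because xla_distribute_tensor annotates the underlying XLA tensors in place, the module returned by xla_distribute_module is the same object that was passed in and can be used directly in the training loop, as the test above does.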
