[Distributed] Improve sharding example (pytorch#937)
* [Distributed] Improve sharding example

* Add comment
kwen2501 authored and weifengpy committed Sep 26, 2024
1 parent fc6c393 commit 0043ace
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions tutorials/developer_api_guide/tensor_parallel.py
@@ -1,8 +1,9 @@
 import os
 import torch
 import torch.distributed as dist
+from typing import Sequence
 from torch.distributed import DeviceMesh
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed.tensor import DTensor, Replicate, Shard, Placement
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from my_dtype_tensor_subclass import MyDTypeTensor, fill_defaults
@@ -101,18 +102,40 @@ def quantize(m: torch.nn.Module) -> torch.nn.Module:
     )
     return m
 
+def shard(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[Placement],
+) -> DTensor:
+    """
+    A helper to simplify both colwise_shard and rowwise_shard: it accepts a
+    full tensor and returns a DTensor sharded according to the indicated
+    placements. The goal is to eventually move this function into DTensor
+    as a static method, e.g.
+        dtensor = DTensor.shard(full_tensor, device_mesh, placements)
+    """
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(
+        local_tensor, device_mesh, placements
+    )
+
 def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     """
     Shard linear layer of the model in column-wise fashion
     """
     # Column-wise is wrt to A^T, so for A it is row-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_rows = orig_weight.size(0) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[rank * n_local_rows : (rank + 1) * n_local_rows, :]
-    # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
+    dtensor = shard(orig_weight, mesh, [Shard(0)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
@@ -124,13 +147,9 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
     Shard linear layer of the model in row-wise fashion
     """
     # Row-wise is wrt to A^T, so for A it is column-wise.
-    # Number of rows per rank
     orig_weight = m.linear.weight
-    n_local_cols = orig_weight.size(1) // mesh.size()
-    rank = mesh.get_local_rank()
-    local_shard = orig_weight[:, rank * n_local_cols : (rank + 1) * n_local_cols]
-    # Construct DTensor from local shard
-    dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
+    dtensor = shard(orig_weight, mesh, [Shard(1)])
     # Replace parameter in module
     m.linear.weight = torch.nn.Parameter(
         dtensor, requires_grad=False
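A minimal driver sketch for the new shard helper, assuming a single 2-GPU host, a torchrun --nproc-per-node 2 launch, and that the helper is importable from the tutorial module tensor_parallel (all assumptions, not shown in the commit):

# Hypothetical driver for the shard helper; assumes 2 GPUs and NCCL.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

from tensor_parallel import shard  # assumed import of the helper above

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())  # single host, so rank == local rank
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

full_weight = torch.randn(128, 64, device="cuda")
dtensor = shard(full_weight, mesh, [Shard(0)])  # split dim 0 across ranks
# On a 2-rank mesh, each rank keeps a (64, 64) local shard of the (128, 64) tensor.
print(f"rank {dist.get_rank()}: local shape = {tuple(dtensor.to_local().shape)}")
dist.destroy_process_group()

Note that each rank materializes the full tensor and then slices out its own piece, which is exactly what the tutorial's quantize-then-shard flow does.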

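On the comment "Column-wise is wrt to A^T, so for A it is row-wise": F.linear computes x @ A^T, so sharding A along dim 0 (Shard(0)) splits the columns of A^T, i.e. the output features, and the per-rank outputs concatenate back to the full result. A single-process sanity check of that identity (illustrative only, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(4, 6)        # batch of activations
A = torch.randn(8, 6)        # full linear weight: out_features=8, in_features=6
A0, A1 = A.chunk(2, dim=0)   # "column-wise" shards, as Shard(0) would produce

full = F.linear(x, A)        # x @ A^T -> shape (4, 8)
pieces = torch.cat([F.linear(x, A0), F.linear(x, A1)], dim=-1)
assert torch.allclose(full, pieces)  # concatenated shard outputs == full output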

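Row-wise (Shard(1) on A) is the dual case: it splits the input features, so each rank computes a partial product, and the partials must be summed (in DTensor terms, a Partial placement that an all-reduce resolves). Again a single-process sketch, not part of the commit:

import torch
import torch.nn.functional as F

h = torch.randn(4, 8)        # activations feeding the row-wise layer
B = torch.randn(6, 8)        # full weight: out_features=6, in_features=8
B0, B1 = B.chunk(2, dim=1)   # "row-wise" shards, as Shard(1) would produce
h0, h1 = h.chunk(2, dim=-1)  # matching split of the input features

full = F.linear(h, B)        # h @ B^T -> shape (4, 6)
partial_sum = F.linear(h0, B0) + F.linear(h1, B1)
assert torch.allclose(full, partial_sum)  # summed partials == full output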