update dygraph auto parallel API interface. (#59059)
Co-authored-by: wuhuachao <wuhuachao@baidu.com>
wuhuachaocoding and wuhuachao authored Nov 27, 2023
1 parent cac0a03 commit 33854f2
Showing 49 changed files with 581 additions and 809 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/eager_utils.cc
@@ -1127,7 +1127,7 @@ PyObject* ToPyObject(const phi::distributed::ProcessMesh* value) {
}

PyObject* ToPyObject(const phi::distributed::Placement& value) {
auto obj = ::pybind11::cast(value);
auto obj = ::pybind11::cast(value, py::return_value_policy::reference);
obj.inc_ref();
return obj.ptr();
}
9 changes: 9 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -72,9 +72,18 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
placements,
DenseTensorMeta(global_value->dtype(), global_value->dims()));

std::vector<int64_t> partial_dims;
size_t idx = 0;
for (auto p : placements) {
if (p->is_partial()) {
partial_dims.push_back(idx);
}
idx++;
}
TensorDistAttr dist_attr(vectorize(dist_tensor_meta_.dims()));
dist_attr.set_process_mesh(dist_tensor_meta_.process_mesh());
dist_attr.set_dims_mapping(dist_tensor_meta_.dim_mapping());
dist_attr.set_partial_status(partial_dims);
dist_attr.mark_annotated("process_mesh");
dist_attr.mark_annotated("dims_mapping");
dist_attr_ = dist_attr;
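For context, a minimal dygraph sketch of how a Partial placement can reach this constructor (illustrative only; it assumes a multi-process distributed launch and the Python API updated later in this commit):

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    a = paddle.ones([4, 8])
    # dist.Partial() on mesh dim 0 makes placements[0]->is_partial() true above,
    # so partial_dims becomes [0] and is recorded via set_partial_status.
    d = dist.shard_tensor(a, mesh, [dist.Partial()])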
9 changes: 6 additions & 3 deletions python/paddle/base/framework.py
@@ -7465,10 +7465,13 @@ def from_tensor(cls, tensor, **kwargs):
param = cls(tensor.shape, tensor.dtype, **kwargs)

# 2. transform data if needed
dist_attr = kwargs.get('dist_attr', None)
mesh = kwargs.get("process_mesh", None)
placements = kwargs.get("placements", None)
src_tensor = tensor
if dist_attr is not None:
src_tensor = core.eager.Tensor(tensor, dist_attr=dist_attr)
if mesh is not None and placements is not None:
src_tensor = core.eager.Tensor(
tensor, process_mesh=mesh, placements=placements
)

# 3. set param data
param._set_impl(src_tensor)
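In practice this path is reached through dist.shard_tensor on a parameter; a hedged sketch of the new calling convention (assumes dygraph mode and a distributed launch; the mesh values are illustrative):

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    linear = paddle.nn.Linear(8, 8)
    # linear.weight is an EagerParamBase; shard_tensor forwards process_mesh and
    # placements to EagerParamBase.from_tensor as shown in this hunk.
    dist_w = dist.shard_tensor(linear.weight, mesh, [dist.Replicate()])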
13 changes: 13 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -68,6 +68,14 @@

from .auto_parallel.process_mesh import ProcessMesh

from .auto_parallel.placement_type import (
ReduceType,
Placement,
Shard,
Replicate,
Partial,
)

from .auto_parallel import shard_op # noqa: F401

from .auto_parallel.api import (
@@ -144,4 +152,9 @@
"dtensor_from_fn",
"reshard",
"shard_layer",
"ReduceType",
"Placement",
"Shard",
"Replicate",
"Partial",
]
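With these exports, the placement types are importable directly from paddle.distributed; a small sketch of what each one expresses (descriptions paraphrase this commit, not official documentation):

    import paddle.distributed as dist

    replicated = dist.Replicate()  # a full copy along a mesh dimension
    sharded = dist.Shard(0)        # split along tensor dimension 0 across a mesh dimension
    partial = dist.Partial()       # values carry a pending reduction across a mesh dimension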
95 changes: 52 additions & 43 deletions python/paddle/distributed/auto_parallel/api.py
@@ -23,6 +23,8 @@
)
from paddle.framework import core

from .placement_type import get_shard_spec

# These are the auto parallel APIs of the unified dynamic and static mode.
# Some APIs share names with the previous implementations; this is a temporary
# state, and the APIs here will eventually be the ones used.
@@ -92,7 +94,7 @@ def sharding_specs(self):


def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
data, mesh, placements, dtype=None, place=None, stop_gradient=True
):
"""
Constructs a ``paddle.Tensor`` with distributed attributes from ``data``,
@@ -103,6 +105,9 @@ def shard_tensor(
Args:
data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
Can be a scalar, list, tuple, numpy.ndarray, paddle.Tensor.
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object that describes the Cartesian topology of the processes used.
placements(list[paddle.distributed.Placement]): the placements describe how to place the tensor on the ProcessMesh; each
placement can be Shard, Replicate or Partial.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
'complex64' , 'complex128'. Default: None, infers dtype from ``data``
@@ -111,7 +116,6 @@ def shard_tensor(
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
Returns:
Tensor: A Tensor constructed from ``data`` with distributed attributes.
@@ -123,15 +127,14 @@ def shard_tensor(
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
>>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
>>> # dense tensor
>>> a = paddle.to_tensor([[1,2,3],
... [5,6,7]])
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
>>> # distributed tensor
>>> d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Shard(1)])
>>> print(d_tensor)
@@ -146,33 +149,34 @@ def shard_tensor(
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)

# 2. create dist tensor
assert len(dist_attr.dims_mapping) == len(
list(tensor.shape)
), "The length of sharding_specs must be same as the shape of the input tensor."

if paddle.in_dynamic_mode():
# here the dist tensor is deep copy constructed
if isinstance(data, EagerParamBase):
return EagerParamBase.from_tensor(
tensor, dist_attr=dist_attr, **tensor.__dict__
tensor,
process_mesh=mesh,
placements=placements,
**tensor.__dict__
)
else:
return paddle.Tensor(tensor, dist_attr=dist_attr, place=place)
return paddle.Tensor(
tensor, process_mesh=mesh, placements=placements, place=place
)
else:
# TODO(zhiqiu): we need to refine the static shard_tensor
return shard_tensor_static(
tensor, dist_attr.process_mesh, dist_attr.sharding_specs
)
sharding_specs = get_shard_spec(mesh, placements, tensor.ndim)
return shard_tensor_static(tensor, mesh, sharding_specs)
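A rough usage sketch of the new signature (requires a distributed launch; the 2-D mesh below is illustrative): shard tensor dim 0 across mesh dim "x" and replicate across "y".

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    a = paddle.arange(12).reshape([3, 4])
    # One placement per mesh dimension: Shard(0) on "x", Replicate() on "y".
    d = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Replicate()])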


def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
def dtensor_from_fn(fn, mesh, placements, *args, **kwargs):
"""
Construct a Distributed Tensor from a function of arguments.
Args:
fn (callable): A callable function that takes arguments of Distributed Tensor and returns tensor.
dist_attr (paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object that describes the Cartesian topology of the processes used.
placements(list[paddle.distributed.Placement]): the placements describe how to place the tensor on the ProcessMesh; each
placement can be Shard, Replicate or Partial.
*args (tuple): A tuple of arguments to be passed to the ``fn`` function.
**kwargs (dict): A dict of arguments to be passed to the ``fn`` function.
@@ -186,26 +190,27 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
>>> import paddle.distributed as dist
>>> # Create a distributed attribute
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])
>>> # Call the function dtensor_from_fn with dist_attr parameter
>>> d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[1])
>>> d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[1])
>>> print(d_tensor)
"""
tensor = fn(*args, **kwargs)
return shard_tensor(tensor, dist_attr=dist_attr)
return shard_tensor(tensor, mesh, placements)
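A minimal hedged example of the updated helper, mirroring the docstring (assumes a distributed launch; shapes are illustrative):

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    # fn is evaluated first, then its result is sharded with the given mesh/placements.
    d_zeros = dist.dtensor_from_fn(paddle.zeros, mesh, [dist.Replicate()], shape=[4, 8])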


# Part3: Data conversion related APIs


def reshard(dist_tensor, dist_attr):
def reshard(dist_tensor, mesh, placements):
"""
Reshard a distributed ``paddle.Tensor`` with given distributed attributes.
Args:
dist_tensor(Tensor): the distributed tensor to be resharded.
dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object that describes the Cartesian topology of the processes used.
placements(list[paddle.distributed.Placement]): the placements describe how to place the tensor on the ProcessMesh; each
placement can be Shard, Replicate or Partial.
Returns:
Tensor: A Distributed Tensor resharded with the given distributed attributes.
@@ -216,28 +221,33 @@ def reshard(dist_tensor, dist_attr):
>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
>>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
>>> out_mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
>>> out_dist_attr = dist.DistAttr(mesh=out_mesh, sharding_specs=[None, None])
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> # dense tensor
>>> a = paddle.to_tensor([[1,2,3],
... [5,6,7]])
>>> a = paddle.ones([10, 20])
>>> # doctest: +REQUIRES(env:DISTRIBUTED)
>>> # distributed tensor
>>> d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Partial()])
>>> out_d_tensor = dist.reshard(d_tensor, out_dist_attr)
>>> out_d_tensor = dist.reshard(d_tensor, mesh, [dist.Replicate()])
>>> print(d_tensor)
>>> print(out_d_tensor)
"""

if paddle.framework.in_dynamic_mode():
# TODO(LiYuRio): this is still the static-mode logic; once reshard is aligned
# with the dygraph logic, delete it.
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
dist_attr = DistAttr(mesh, sharding_specs)
partial_dims = []
for i, p in enumerate(placements):
if isinstance(p, dist.Partial):
partial_dims.append(i)
if len(partial_dims) > 0:
dist_attr._set_partial_dims(partial_dims)

return paddle.base.core.reshard(dist_tensor, dist_attr)
else:
# TODO(GhostScreaming): Support static DistTensor later.
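A hedged sketch of the dygraph branch above: resharding a Partial tensor to Replicate, which exercises the partial_dims handling (illustrative; requires a distributed launch):

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    a = paddle.ones([10, 20])
    d = dist.shard_tensor(a, mesh, [dist.Partial()])
    # Partial -> Replicate: the pending sum on mesh dim "x" is resolved by reshard.
    out = dist.reshard(d, mesh, [dist.Replicate()])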
@@ -312,9 +322,8 @@ def output_fn(outputs, process_mesh) -> list(paddle.Tensor)
... return self.fc2(self.fc1(input))
>>> def shard_fn(layer_name, layer, process_mesh):
... dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
... if layer_name == 'fc1':
... layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)
... layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])
>>> layer = MLP()
>>> layer = dist.shard_layer(layer, mesh, shard_fn)
@@ -339,26 +348,26 @@ def replicate_layer_params_and_buffers(
) -> None:
for key, param in layer._parameters.items():
if param is not None and not param.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(param.shape))],
)
placements = [
paddle.distributed.Replicate()
for _ in range(len(param.shape))
]
layer.add_parameter(
key,
shard_tensor(param, dist_attr=replicated_dist_attr),
shard_tensor(param, mesh, placements),
)
else:
# do nothing, the dist parameters has already been shard by shard_fn
pass
for key, buffer in layer._buffers.items():
if buffer is not None and not buffer.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(buffer.shape))],
)
placements = [
paddle.distributed.Replicate()
for _ in range(len(buffer.shape))
]
layer.register_buffer(
key,
shard_tensor(buffer, dist_attr=replicated_dist_attr),
shard_tensor(buffer, mesh, placements),
)
else:
# do nothing, the dist buffers has already been shard by shard_fn
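A hedged end-to-end sketch of shard_layer with the placement-based replication above (the MLP mirrors the docstring; sizes and mesh values are illustrative):

    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    class MLP(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc1 = paddle.nn.Linear(8, 8)
            self.fc2 = paddle.nn.Linear(8, 8)

        def forward(self, x):
            return self.fc2(self.fc1(x))

    def shard_fn(layer_name, layer, process_mesh):
        if layer_name == 'fc1':
            layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])

    layer = dist.shard_layer(MLP(), mesh, shard_fn)
    # fc1.weight is sharded by shard_fn; parameters and buffers not touched by shard_fn
    # fall through to replicate_layer_params_and_buffers and receive Replicate() placements.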
89 changes: 89 additions & 0 deletions python/paddle/distributed/auto_parallel/placement_type.py
@@ -0,0 +1,89 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast

from paddle.base.core import Partial, Placement, ReduceType, Replicate, Shard

__all__ = ["ReduceType", "Placement", "Replicate", "Shard", "Partial"]


def to_placements(dim_map, mesh, partial_idx=[]):
"""
convert dim_map to placements.
Args:
dim_map(List[int]): a list of integers giving, for each tensor dimension, the mesh dimension it is sharded on (-1 means not sharded).
mesh(paddle.distributed.ProcessMesh): The `ProcessMesh` object that describes the Cartesian topology of the processes used.
partial_idx(List[int], Optional): a list of mesh dimension indices on which the DTensor has a pending (partial) sum.
Returns:
List[Placement]: a list of `paddle.distributed.Placement` objects.
"""
placements = [Replicate() for _ in range(len(mesh.mesh.shape))]

for s in partial_idx:
placements[s] = Partial()

for i, m in enumerate(dim_map):
if m >= 0:
p = placements[m]
if p.is_shard():
p = cast(Shard, p)
raise Exception(
f"ProcessMesh dimension can not be mapped to two dimension of same tensor: {i} and {p.get_dim()}."
)
elif p.is_partial():
raise Exception(
f"ProcessMesh dimension {m} can not be both shard and partial!"
)
placements[m] = Shard(i)

return placements
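A small worked example, assuming a 2-D ProcessMesh so that len(mesh.mesh.shape) == 2 and that the module is importable as laid out in this diff (expected results shown as comments):

    import paddle.distributed as dist
    from paddle.distributed.auto_parallel.placement_type import to_placements

    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    # Tensor dim 0 is sharded on mesh dim 1; tensor dim 1 is not sharded.
    to_placements([1, -1], mesh)        # -> [Replicate(), Shard(0)]
    # Mesh dims listed in partial_idx start out as Partial().
    to_placements([-1, -1], mesh, [0])  # -> [Partial(), Replicate()]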


def to_dim_map(placements, tensor_dims):
"""
convert placements to dim_map.
Args:
placements(List[Placement]): a list of `paddle.distributed.Placement` objects.
tensor_dims(int): the number of dimensions of the dist_tensor.
Returns:
List[int]: a list of integers giving, for each tensor dimension, the mesh dimension it is sharded on (-1 means not sharded).
"""
dim_map = [-1] * tensor_dims
for i, placement in enumerate(placements):
if placement.is_shard():
shard_dim = cast(Shard, placement).get_dim()
if dim_map[shard_dim] > -1:
raise Exception(
"Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}"
)

dim_map[shard_dim] = i

return dim_map
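A brief example; the result is the inverse of the to_placements mapping above (expected value shown as a comment, import path assumed from this diff):

    from paddle.base.core import Replicate, Shard
    from paddle.distributed.auto_parallel.placement_type import to_dim_map

    placements = [Replicate(), Shard(0)]  # mesh dim 1 shards tensor dim 0
    to_dim_map(placements, 2)             # -> [1, -1]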


def get_shard_spec(mesh, placements, tensor_dims):
"""to get shard_spec for construct DistAttr for static API."""
dim_map = to_dim_map(placements, tensor_dims)
mesh_dim_names = mesh.dim_names
shard_spec = [None] * len(dim_map)
for i, d in enumerate(dim_map):
if d > -1:
shard_spec[i] = mesh_dim_names[d]

return shard_spec
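A short sketch of how placements translate into the static-mode sharding_specs consumed by DistAttr (expected value shown as a comment; mesh values are illustrative):

    import paddle.distributed as dist
    from paddle.distributed.auto_parallel.placement_type import get_shard_spec

    mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    # Mesh dim "y" shards tensor dim 0; tensor dim 1 stays replicated.
    get_shard_spec(mesh, [dist.Replicate(), dist.Shard(0)], 2)  # -> ['y', None]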