update dygraph auto parallel API interface.
wuhuachao committed Nov 22, 2023
1 parent 4b05d5a commit ef3a6f7
Showing 43 changed files with 633 additions and 690 deletions.
9 changes: 9 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -71,9 +71,18 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
placements,
DenseTensorMeta(global_value->dtype(), global_value->dims()));

std::vector<int64_t> partial_dims;
size_t idx = 0;
for (auto p : placements) {
if (p->is_partial()) {
partial_dims.push_back(idx);
}
idx++;
}
TensorDistAttr dist_attr(vectorize(dist_tensor_meta_.dims()));
dist_attr.set_process_mesh(dist_tensor_meta_.process_mesh());
dist_attr.set_dims_mapping(dist_tensor_meta_.dim_mapping());
dist_attr.set_partial_status(partial_dims);
dist_attr.mark_annotated("process_mesh");
dist_attr.mark_annotated("dims_mapping");
dist_attr_ = dist_attr;
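For reference, a minimal Python sketch of the same bookkeeping the new C++ block performs: it walks a placements list and records the indices that are Partial, mirroring the partial_dims loop added above. The helper name is made up for illustration; the placement classes are the ones this commit exports from paddle.distributed.

import paddle.distributed as dist

def collect_partial_dims(placements):
    # Record the index of every placement that is Partial, just as the C++ loop
    # above pushes idx into partial_dims when p->is_partial() is true.
    return [idx for idx, p in enumerate(placements) if isinstance(p, dist.Partial)]

print(collect_partial_dims([dist.Shard(0), dist.Partial()]))  # -> [1]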
9 changes: 6 additions & 3 deletions python/paddle/base/framework.py
@@ -7449,10 +7449,13 @@ def from_tensor(cls, tensor, **kwargs):
param = cls(tensor.shape, tensor.dtype, **kwargs)

# 2. transform data if needed
dist_attr = kwargs.get('dist_attr', None)
mesh = kwargs.get("process_mesh", None)
placements = kwargs.get("placements", None)
src_tensor = tensor
if dist_attr is not None:
src_tensor = core.eager.Tensor(tensor, dist_attr=dist_attr)
if mesh is not None and placements is not None:
src_tensor = core.eager.Tensor(
tensor, process_mesh=mesh, placements=placements
)

# 3. set param data
param._set_impl(src_tensor)
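A hedged sketch of how this code path is reached: sharding an existing parameter (an EagerParamBase) with the new API forwards process_mesh and placements through from_tensor as shown above. It assumes a two-device job launched with paddle.distributed.launch.

# Hedged usage sketch (not part of the commit).
# Assumes: python -m paddle.distributed.launch --devices=0,1 demo.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
linear = paddle.nn.Linear(4, 4)
# linear.weight is an EagerParamBase; shard_tensor rebuilds it as a distributed
# parameter through the from_tensor branch shown above.
dist_weight = dist.shard_tensor(linear.weight, mesh, [dist.Replicate()])
print(dist_weight.is_dist())  # True on each rank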
13 changes: 13 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -68,6 +68,14 @@

from .auto_parallel.process_mesh import ProcessMesh

from .auto_parallel.placement_type import (
ReduceType,
Placement,
Shard,
Replicate,
Partial,
)

from .auto_parallel import shard_op # noqa: F401

from .auto_parallel.api import (
@@ -144,4 +152,9 @@
"dtensor_from_fn",
"reshard",
"shard_layer",
"ReduceType",
"Placement",
"Shard",
"Replicate",
"Partial",
]
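A minimal sketch of what the new exports provide, assuming Shard, Replicate and Partial all derive from the exported Placement base class (an assumption, not shown in this diff).

import paddle.distributed as dist

placements = [dist.Shard(0), dist.Replicate(), dist.Partial()]
# Assumption: Shard, Replicate and Partial share the exported Placement base class.
assert all(isinstance(p, dist.Placement) for p in placements)
print(placements)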
72 changes: 45 additions & 27 deletions python/paddle/distributed/auto_parallel/api.py
@@ -23,6 +23,8 @@
)
from paddle.framework import core

from .placement_type import get_shard_spec

# These are the auto parallel APIs of the unified version of dynamic and static mode.
# Some APIs have the same name as the previous API implementations; this is
# a temporary state, and the APIs here will eventually be used.
@@ -92,7 +94,7 @@ def sharding_specs(self):


def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
data, mesh, placements, dtype=None, place=None, stop_gradient=True
):
"""
Constructs a ``paddle.Tensor`` with distributed attributes from ``data``,
@@ -103,6 +105,9 @@ def shard_tensor(
Args:
data(scalar|tuple|list|ndarray|Tensor): Initial data for the tensor.
Can be a scalar, list, tuple, numpy.ndarray, paddle.Tensor.
mesh(paddle.distributed.ProcessMesh): The ``ProcessMesh`` object that describes the Cartesian topology of the processes used.
placements(list[paddle.distributed.Placement]): The placements describe how to place the tensor on the ProcessMesh; each placement
can be Shard, Replicate or Partial.
dtype(str|np.dtype, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' ,
'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8',
'complex64' , 'complex128'. Default: None, infers dtype from ``data``
@@ -111,7 +116,6 @@ def shard_tensor(
CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.
dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
Returns:
Tensor: A Tensor constructed from ``data`` with distributed attributes.
@@ -146,33 +150,34 @@ def shard_tensor(
data, dtype=dtype, place=place, stop_gradient=stop_gradient
)

# 2. create dist tensor
assert len(dist_attr.dims_mapping) == len(
list(tensor.shape)
), "The length of sharding_specs must be same as the shape of the input tensor."

if paddle.in_dynamic_mode():
# here the dist tensor is deep copy constructed
if isinstance(data, EagerParamBase):
return EagerParamBase.from_tensor(
tensor, dist_attr=dist_attr, **tensor.__dict__
tensor,
process_mesh=mesh,
placements=placements,
**tensor.__dict__
)
else:
return paddle.Tensor(tensor, dist_attr=dist_attr, place=place)
return paddle.Tensor(
tensor, process_mesh=mesh, placements=placements, place=place
)
else:
# TODO(zhiqiu): we need to refine the static shard_tensor
return shard_tensor_static(
tensor, dist_attr.process_mesh, dist_attr.sharding_specs
)
sharding_specs = get_shard_spec(mesh, placements, tensor.ndim)
return shard_tensor_static(tensor, mesh, sharding_specs)
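A hedged usage sketch of the new shard_tensor(data, mesh, placements) signature. It assumes the script runs under a two-device paddle.distributed.launch job; nothing here is taken from the diff beyond the signature itself.

# Hedged usage sketch, assumes:
# python -m paddle.distributed.launch --devices=0,1 demo.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
dense = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# One placement per mesh dimension: split the tensor's dim 0 across mesh axis "x".
d_tensor = dist.shard_tensor(dense, mesh, [dist.Shard(0)])
print(d_tensor.shape)  # the logical (global) shape is kept: [2, 3]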


def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
def dtensor_from_fn(fn, mesh, placements, *args, **kwargs):
"""
Construct a Distributed Tensor from a function and its arguments.
Args:
fn (callable): A callable that returns a dense ``paddle.Tensor``, which is then converted into a Distributed Tensor.
dist_attr (paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
mesh(paddle.distributed.ProcessMesh): The ``ProcessMesh`` object that describes the Cartesian topology of the processes used.
placements(list[paddle.distributed.Placement]): The placements describe how to place the tensor on the ProcessMesh; each placement
can be Shard, Replicate or Partial.
*args (tuple): A tuple of arguments to be passed to the ``fn`` function.
**kwargs (dict): A dict of arguments to be passed to the ``fn`` function.
@@ -193,19 +198,21 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
"""
tensor = fn(*args, **kwargs)
return shard_tensor(tensor, dist_attr=dist_attr)
return shard_tensor(tensor, mesh, placements)
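A hedged usage sketch of the updated dtensor_from_fn(fn, mesh, placements, *args, **kwargs) signature, again assuming a two-device launch.

# Hedged usage sketch, assumes a two-device paddle.distributed.launch job.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
# paddle.ones is called with shape=[2, 3]; its result is then replicated over the mesh.
d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[2, 3])
print(d_tensor.shape)  # [2, 3]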


# Part3: Data conversion related APIs


def reshard(dist_tensor, dist_attr):
def reshard(dist_tensor, mesh, placements):
"""
Reshard a distributed ``paddle.Tensor`` with the given distributed attributes.
Args:
dist_tensor(Tensor): the distributed tensor to be resharded.
dist_attr(paddle.distributed.DistAttr): Specify how tensors are distributed or sliced on ProcessMesh.
mesh(paddle.distributed.ProcessMesh): The ``ProcessMesh`` object that describes the Cartesian topology of the processes used.
placements(list[paddle.distributed.Placement]): The placements describe how to place the tensor on the ProcessMesh; each placement
can be Shard, Replicate or Partial.
Returns:
Tensor: A Distributed Tensor resharded with the given distributed attributes.
@@ -238,6 +245,17 @@ def reshard(dist_tensor, dist_attr):
"""

if paddle.framework.in_dynamic_mode():
# TODO(LiYuRio): this is still static-mode logic; reshard should be switched to the dygraph logic.
# Once reshard has been aligned with the dygraph logic, delete this block.
sharding_specs = get_shard_spec(mesh, placements, dist_tensor.ndim)
dist_attr = DistAttr(mesh, sharding_specs)
partial_dims = []
for i, p in enumerate(placements):
if isinstance(p, dist.Partial):
partial_dims.append(i)
if len(partial_dims) > 0:
dist_attr._set_partial_dims(partial_dims)

return paddle.base.core.reshard(dist_tensor, dist_attr)
else:
# TODO(GhostScreaming): Support static DistTensor later.
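A hedged usage sketch of the updated reshard(dist_tensor, mesh, placements) signature, assuming a two-device launch; it redistributes a row-sharded tensor to a replicated one.

# Hedged usage sketch, assumes a two-device paddle.distributed.launch job.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
d_tensor = dist.shard_tensor(paddle.ones([4, 3]), mesh, [dist.Shard(0)])
# Redistribute from row-sharded to fully replicated.
replicated = dist.reshard(d_tensor, mesh, [dist.Replicate()])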
@@ -339,26 +357,26 @@ def replicate_layer_params_and_buffers(
) -> None:
for key, param in layer._parameters.items():
if param is not None and not param.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(param.shape))],
)
placements = [
paddle.distributed.Replicate()
for _ in range(len(param.shape))
]
layer.add_parameter(
key,
shard_tensor(param, dist_attr=replicated_dist_attr),
shard_tensor(param, mesh, placements),
)
else:
# do nothing, the dist parameters have already been sharded by shard_fn
pass
for key, buffer in layer._buffers.items():
if buffer is not None and not buffer.is_dist():
replicated_dist_attr = dist.DistAttr(
mesh=mesh,
sharding_specs=[None for _ in range(len(buffer.shape))],
)
placements = [
paddle.distributed.Replicate()
for _ in range(len(buffer.shape))
]
layer.register_buffer(
key,
shard_tensor(buffer, dist_attr=replicated_dist_attr),
shard_tensor(buffer, mesh, placements),
)
else:
# do nothing, the dist buffers have already been sharded by shard_fn
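A hedged sketch of how this replication helper is reached in practice through shard_layer: parameters and buffers that shard_fn leaves dense are replicated by the code above. The shard_fn callback signature (layer_name, layer, process_mesh) is an assumption taken from the public docs, not from this diff.

# Hedged sketch; assumes a two-device paddle.distributed.launch job.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(8, 8)
        self.fc2 = paddle.nn.Linear(8, 8)

def shard_fn(layer_name, layer, process_mesh):
    # Shard only fc1's weight; fc2 and all biases fall through to the
    # replication helper shown above.
    if layer_name == "fc1":
        layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])

mlp = dist.shard_layer(MLP(), mesh, shard_fn)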