
Commit

update doc for dygraph auto parallel API.
wuhuachao committed Nov 22, 2023
1 parent ef3a6f7 commit b07083f
Showing 2 changed files with 43 additions and 22 deletions.
25 changes: 7 additions & 18 deletions python/paddle/distributed/auto_parallel/api.py
@@ -127,15 +127,14 @@ def shard_tensor(
 >>> import paddle.distributed as dist
 >>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
->>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
 >>> # dense tensor
 >>> a = paddle.to_tensor([[1,2,3],
 ... [5,6,7]])
 >>> # doctest: +REQUIRES(env:DISTRIBUTED)
 >>> # distributed tensor
->>> d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Shard(0), dist.Shard(1)])
 >>> print(d_tensor)
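The hunk above keeps the original 2-D mesh and swaps the dist_attr argument for a list of placements. For readers new to the placements API, here is a standalone sketch (not part of the commit) of the same call on a simpler 1-D mesh; the mesh layout, tensor values, and launch setup are illustrative assumptions, and the script is meant to run on two processes, e.g. via python -m paddle.distributed.launch.

import paddle
import paddle.distributed as dist

# 1-D process mesh over ranks 0 and 1.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Ordinary dense tensor, created identically on every rank.
a = paddle.to_tensor([[1, 2, 3],
                      [5, 6, 7]])

# Shard dimension 0 of `a` across the "x" axis of the mesh; with a 2-D mesh,
# one placement is supplied per mesh dimension, e.g. [dist.Shard(0), dist.Shard(1)].
d_tensor = dist.shard_tensor(a, mesh, [dist.Shard(0)])
print(d_tensor)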
@@ -191,9 +190,8 @@ def dtensor_from_fn(fn, mesh, placements, *args, **kwargs):
 >>> import paddle.distributed as dist
 >>> # Create a distributed attribute
 >>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
->>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])
 >>> # Call the function dtensor_from_fn with dist_attr parameter
->>> d_tensor = dist.dtensor_from_fn(paddle.ones, dist_attr=dist_attr, shape=[1])
+>>> d_tensor = dist.dtensor_from_fn(paddle.ones, mesh, [dist.Replicate()], shape=[1])
 >>> print(d_tensor)
 """
@@ -223,23 +221,15 @@ def reshard(dist_tensor, mesh, placements):
 >>> import paddle
 >>> import paddle.distributed as dist
->>> mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
->>> dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
->>> out_mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=['x', 'y'])
->>> out_dist_attr = dist.DistAttr(mesh=out_mesh, sharding_specs=[None, None])
+>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
 >>> # dense tensor
->>> a = paddle.to_tensor([[1,2,3],
-... [5,6,7]])
+>>> a = paddle.ones([10, 20])
 >>> # doctest: +REQUIRES(env:DISTRIBUTED)
 >>> # distributed tensor
->>> d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+>>> d_tensor = dist.shard_tensor(a, mesh, [dist.Partial()])
->>> out_d_tensor = dist.reshard(d_tensor, out_dist_attr)
+>>> out_d_tensor = dist.reshard(d_tensor, mesh, [dist.Replicate()])
 >>> print(d_tensor)
 >>> print(out_d_tensor)
 """
@@ -330,9 +320,8 @@ def output_fn(outputs, process_mesh) -> list(paddle.Tensor)
 ...         return self.fc2(self.fc1(input))
 >>> def shard_fn(layer_name, layer, process_mesh):
-...     dist_attr = dist.DistAttr(mesh=process_mesh, sharding_specs=['x', None])
 ...     if layer_name == 'fc1':
-...         layer.weight = dist.shard_tensor(layer.weight, dist_attr=dist_attr)
+...         layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])
 >>> layer = MLP()
 >>> layer = dist.shard_layer(layer, mesh, shard_fn)
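For context, a complete version of the shard_layer example might look like the sketch below. The MLP layer sizes are assumptions made for illustration; only fc1's weight is sharded, and every other parameter stays replicated.

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(8, 8)
        self.fc2 = paddle.nn.Linear(8, 8)

    def forward(self, input):
        return self.fc2(self.fc1(input))

def shard_fn(layer_name, layer, process_mesh):
    # Shard only fc1's weight along dimension 0 of the weight matrix.
    if layer_name == 'fc1':
        layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])

layer = dist.shard_layer(MLP(), mesh, shard_fn)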
40 changes: 36 additions & 4 deletions python/paddle/distributed/auto_parallel/placement_type.py
@@ -15,8 +15,24 @@

 from paddle.base import core

+__all__ = ["ReduceType", "Placement", "Replicate", "Shard", "Partial"]
+

 class ReduceType:
+    """
+    Specify the type of reduction operation used by paddle.distributed.Partial().
+    It should be one of the following values:
+
+        ReduceType.kRedSum
+        ReduceType.kRedMax
+        ReduceType.kRedMin
+        ReduceType.kRedProd
+        ReduceType.kRedAvg
+        ReduceType.kRedAny
+        ReduceType.kRedAll
+    """

     kRedSum = 0
     kRedMax = 1
     kRedMin = 2
@@ -27,7 +43,11 @@ class ReduceType:


 class Placement(core.Placement):
-    # Placement (base class)
+    """
+    `Placement` is the base class that describes how a tensor is placed on a ProcessMesh.
+    """

     def __init__(self):
         super().__init__()

@@ -48,7 +68,11 @@ def __repr__(self):


 class Replicate(core.Replicate):
-    # Replicate placement
+    """
+    `Replicate` describes a tensor that is fully replicated on every process of the ProcessMesh; see `paddle.distributed.shard_tensor`.
+    """

     def __init__(self):
         super().__init__()

@@ -62,7 +86,11 @@ def __repr__(self):


 class Shard(core.Shard):
-    # Shard placement
+    """
+    `Shard` describes a tensor that is sharded along tensor dimension `dim` across a dimension of the ProcessMesh; see `paddle.distributed.shard_tensor`.
+    """

     def __init__(self, dim):
         super().__init__(dim)
         self.dim = dim
@@ -80,7 +108,11 @@ def __repr__(self):


 class Partial(core.Partial):
-    # Partial placement
+    """
+    `Partial` describes a tensor whose value on each process of the ProcessMesh is only a partial result that still needs to be reduced across processes; see `paddle.distributed.reshard`.
+    """

     def __init__(self, reduce_type=None):
         if reduce_type is None:
             self.reduce_type = core.ReduceType.kRedSum
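To tie the new docstrings together, here is a short illustrative sketch (not part of the commit) of how the placement classes are typically combined with a mesh; the 2x2 mesh layout is an assumption.

import paddle.distributed as dist

# 2-D mesh: one placement is supplied per mesh dimension.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

# Shard tensor dim 0 across mesh axis "x", keep full copies along mesh axis "y".
placements = [dist.Shard(0), dist.Replicate()]

# Partial marks values that still need a cross-process reduction; as the diff
# above shows, the default reduce type is ReduceType.kRedSum.
partial = dist.Partial()

print(placements, partial)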
