[SemiAuto] add static branch for shard_tensor #56561

Merged (4 commits) on Aug 24, 2023
python/paddle/distributed/auto_parallel/api.py (26 changes: 22 additions & 4 deletions)
@@ -13,7 +13,9 @@
# limitations under the License.

import paddle
-from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.interface import (
+    shard_tensor as shard_tensor_static,
+)
from paddle.framework import core

# There are the auto parallel API of the unified version of dynamic and static mode.
@@ -44,7 +46,7 @@ class DistAttr(core.TensorDistAttr):

def __init__(self, mesh, sharding_specs):
# 1. inputs checking
-        if not isinstance(mesh, ProcessMesh):
+        if not isinstance(mesh, core.ProcessMesh):
raise ValueError(
"The mesh must be an instance of paddle.distributed.ProcessMesh."
)
@@ -55,16 +57,31 @@ def __init__(self, mesh, sharding_specs):
for dim_name in sharding_specs
), 'The dimension name in sharding_specs must be an instance of str.'

+        self._sharding_specs = sharding_specs
dims_mapping = [
mesh.dim_names.index(dim_name) if dim_name is not None else -1
for dim_name in sharding_specs
]

# 2. init core.TensorDistAttr
core.TensorDistAttr.__init__(self)

self.process_mesh = mesh
self.dims_mapping = dims_mapping

self.mark_annotated("process_mesh")
self.mark_annotated("dims_mapping")

+    @property
+    def sharding_specs(self):
+        """
+        Get sharding_specs of the dist_attr
+
+        Returns:
+            list[str]: sharding_specs
+        """
+        return self._sharding_specs


def shard_tensor(
data, dtype=None, place=None, stop_gradient=True, dist_attr=None
@@ -121,6 +138,7 @@ def shard_tensor(
    if paddle.in_dynamic_mode():
        return paddle.Tensor(data, dist_attr=dist_attr)
    else:
-        raise NotImplementedError(
-            "The `paddle.distributed.shard_tensor` for static mode will be implemented later."
+        # TODO(zhiqiu): we need to refine the static shard_tensor
+        return shard_tensor_static(
+            data, dist_attr.process_mesh, dist_attr.sharding_specs
        )
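
For context, a minimal usage sketch of the unified API after this change, mirroring the tests added below. The mesh ranks, dim names, and tensor shape are illustrative, a single-process run is assumed, and paddle.enable_static() stands in for the switch_to_static_graph decorator used by the tests:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])
# Shard axis 0 of a 3-D tensor along mesh dim "x"; keep the other axes replicated.
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None, None])

# Dynamic mode: returns a dist tensor annotated with the mesh and dims_mapping.
dense = paddle.rand([4, 1024, 512])
d_tensor = dist.shard_tensor(dense, dist_attr=dist_attr)
print(d_tensor.dist_attr.dims_mapping)  # [0, -1, -1]

# Static mode: shard_tensor no longer raises NotImplementedError; it delegates to
# the static-graph shard_tensor and annotates the program variable instead.
paddle.enable_static()
x = paddle.static.data(name="x", shape=[4, 1024, 512], dtype="float32")
dist.shard_tensor(x, dist_attr=dist_attr)
paddle.disable_static()
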
python/paddle/distributed/auto_parallel/interface.py (3 changes: 2 additions & 1 deletion)
@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
+from paddle.framework import core

from .process_mesh import ProcessMesh, get_current_process_mesh
from .static.dist_context import get_default_distributed_context
@@ -67,7 +68,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):

if process_mesh is not None:
assert isinstance(
-            process_mesh, ProcessMesh
+            process_mesh, core.ProcessMesh
), "Argument process_mesh {} is not an instance of ProcessMesh".format(
process_mesh
)
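
For context on the looser check above: the unified DistAttr stores its mesh on the underlying core.TensorDistAttr, and the new static branch in api.py forwards dist_attr.process_mesh to this helper, so the value arriving here is a core.ProcessMesh rather than the Python-level ProcessMesh wrapper (presumably why the assert was relaxed). A minimal sketch, with illustrative mesh values:

import paddle.distributed as dist
from paddle.framework import core

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])

# Reading the mesh back from the C++-level TensorDistAttr yields a core.ProcessMesh,
# which is what interface.shard_tensor receives from the static branch.
assert isinstance(dist_attr.process_mesh, core.ProcessMesh)
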
test/auto_parallel/test_shard_tensor_api.py (91 changes: 74 additions & 17 deletions)
@@ -16,6 +16,10 @@

import paddle
import paddle.distributed as dist
+from paddle.distributed.auto_parallel.static.dist_context import (
+    get_default_distributed_context,
+)
+from paddle.fluid.dygraph.base import switch_to_static_graph


class TestDistAttrBasic(unittest.TestCase):
@@ -51,27 +55,80 @@ def test_sharding_specs_argument_error(self):
self.assertIsNotNone(exception)


-class TestShardTensorBasic(unittest.TestCase):
-    # remove this test after static mode is supported
-    def test_static_mode_unimplemented(self):
-        exception = None
-        try:
-            paddle.enable_static()
+class TestShardTensorDynamic(unittest.TestCase):
+    def setUp(self):
+        self.mesh = dist.ProcessMesh(
+            [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+        )

+    def test_dynamic(self):
+        dist_attr = dist.DistAttr(
+            mesh=self.mesh, sharding_specs=['x', None, None]
+        )

+        input = paddle.rand([4, 1024, 512])
+        d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
+        print(dist_attr.dims_mapping)

+        self.assertEqual(d_tensor.dist_attr.process_mesh, self.mesh)
+        self.assertEqual(d_tensor.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping"))


+class TestShardTensorStatic(unittest.TestCase):
+    def setUp(self):
+        self.mesh = dist.ProcessMesh(
+            [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+        )

+    @switch_to_static_graph
+    def test_static_mode(self):
+        dist_attr = dist.DistAttr(
+            mesh=self.mesh, sharding_specs=['x', None, None]
+        )

+        input = paddle.static.data(
+            name="input",
+            shape=[4, 1024, 512],
+            dtype='float32',
+        )
+        d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)

+        default_dist_context = get_default_distributed_context()
+        dist_input = default_dist_context.get_dist_tensor_for_program(input)
+        self.assertEqual(dist_input.dist_attr.process_mesh, self.mesh)
+        self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))


+class TestShardTensorStaticDy2Static(unittest.TestCase):
+    def test_dy2static(self):
+        @paddle.jit.to_static
+        def func():
mesh = dist.ProcessMesh(
-                [[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
+                [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
)
-            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
-            a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
-            d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
-        except NotImplementedError as ex:
-            self.assertIn(
-                "The `paddle.distributed.shard_tensor` for static mode will be implemented later",
-                str(ex),
+            dist_attr = dist.DistAttr(
+                mesh=mesh, sharding_specs=['x', None, None]
)
-            exception = ex
-            paddle.disable_static()

-        self.assertIsNotNone(exception)
+        input = paddle.rand([4, 1024, 512])
+        d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
+        return input, mesh

+        dy_tensor, mesh = func()
+        static_tensor = func.outputs[0]  # get the inputs of static program

+        default_dist_context = get_default_distributed_context()
+        dist_input = default_dist_context.get_dist_tensor_for_program(
+            static_tensor
+        )
+        self.assertEqual(dist_input.dist_attr.process_mesh, mesh)
+        self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))


if __name__ == "__main__":