From 31ed156341587538e9df37b274bfd0dd2884cf18 Mon Sep 17 00:00:00 2001
From: zhiqiu
Date: Tue, 22 Aug 2023 21:03:49 +0800
Subject: [PATCH 1/4] shard_tensor support static graph

---
 .../paddle/distributed/auto_parallel/api.py   | 25 ++++++++++++++++---
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index b25799d058ad2..226617fad67e0 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 
 import paddle
+from paddle.distributed.auto_parallel.interface import (
+    shard_tensor as shard_tensor_static,
+)
 from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
 from paddle.framework import core
 
@@ -55,6 +58,7 @@ def __init__(self, mesh, sharding_specs):
             for dim_name in sharding_specs
         ), 'The dimension name in sharding_specs must be an instance of str.'
 
+        self._sharding_specs = sharding_specs
         dims_mapping = [
            mesh.dim_names.index(dim_name) if dim_name is not None else -1
             for dim_name in sharding_specs
         ]
 
         # 2. init core.TensorDistAttr
         core.TensorDistAttr.__init__(self)
-        self.process_mesh = mesh
-        self.dims_mapping = dims_mapping
+        self._process_mesh = mesh
+        self._dims_mapping = dims_mapping
+
+    @property
+    def process_mesh(self):
+        return self._process_mesh
+
+    @property
+    def dims_mapping(self):
+        return self._dims_mapping
+
+    @property
+    def sharding_specs(self):
+        return self._sharding_specs
 
 
 def shard_tensor(
@@ -121,6 +137,7 @@ def shard_tensor(
     if paddle.in_dynamic_mode():
         return paddle.Tensor(data, dist_attr=dist_attr)
     else:
-        raise NotImplementedError(
-            "The `paddle.distributed.shard_tensor` for static mode will be implemented later."
+        # TODO(zhiqiu): we need to refine the static shard_tensor
+        shard_tensor_static(
+            data, dist_attr.process_mesh, dist_attr.sharding_specs
         )

From b2f5c24107a831029503402f2141c6a9108019a4 Mon Sep 17 00:00:00 2001
From: zhiqiu
Date: Tue, 22 Aug 2023 21:09:00 +0800
Subject: [PATCH 2/4] add comments

---
 python/paddle/distributed/auto_parallel/api.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 226617fad67e0..74c0ddbcb6893 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -71,14 +71,32 @@ def __init__(self, mesh, sharding_specs):
 
     @property
     def process_mesh(self):
+        """
+        Get process_mesh of the dist_attr
+
+        Returns:
+            paddle.distributed.ProcessMesh: process_mesh
+        """
         return self._process_mesh
 
     @property
     def dims_mapping(self):
+        """
+        Get dims_mapping of the dist_attr
+
+        Returns:
+            list[int]: dims_mapping
+        """
         return self._dims_mapping
 
     @property
     def sharding_specs(self):
+        """
+        Get sharding_specs of the dist_attr
+
+        Returns:
+            list[str]: sharding_specs
+        """
         return self._sharding_specs
 

From 8d769856aa190ab0b797b9f94f3dbce94845e297 Mon Sep 17 00:00:00 2001
From: zhiqiu
Date: Wed, 23 Aug 2023 12:07:21 +0800
Subject: [PATCH 3/4] add dy2static ut

---
 .../paddle/distributed/auto_parallel/api.py |  2 +-
 test/auto_parallel/test_shard_tensor_api.py | 70 ++++++++++++++-----
 2 files changed, 54 insertions(+), 18 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 74c0ddbcb6893..ad7211918e1ca 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -156,6 +156,6 @@ def shard_tensor(
         return paddle.Tensor(data, dist_attr=dist_attr)
     else:
         # TODO(zhiqiu): we need to refine the static shard_tensor
-        shard_tensor_static(
+        return shard_tensor_static(
             data, dist_attr.process_mesh, dist_attr.sharding_specs
         )

diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py
index 124c7dc7ba39e..6108372329f52 100644
--- a/test/auto_parallel/test_shard_tensor_api.py
+++ b/test/auto_parallel/test_shard_tensor_api.py
@@ -16,6 +16,10 @@
 
 import paddle
 import paddle.distributed as dist
+from paddle.distributed.auto_parallel.static.dist_context import (
+    get_default_distributed_context,
+)
+from paddle.fluid.dygraph.base import switch_to_static_graph
 
 
 class TestDistAttrBasic(unittest.TestCase):
@@ -51,27 +55,59 @@ def test_sharding_specs_argument_error(self):
         self.assertIsNotNone(exception)
 
 
-class TestShardTensorBasic(unittest.TestCase):
-    # remove this test after static mode is supported
-    def test_static_mode_unimplemented(self):
-        exception = None
-        try:
-            paddle.enable_static()
+class TestShardTensorStatic(unittest.TestCase):
+    def setUp(self):
+        self.mesh = dist.ProcessMesh(
+            [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+        )
+
+    @switch_to_static_graph
+    def test_static_mode(self):
+        dist_attr = dist.DistAttr(
+            mesh=self.mesh, sharding_specs=['x', None, None]
+        )
+
+        input = paddle.static.data(
+            name="input",
+            shape=[4, 1024, 512],
+            dtype='float32',
+        )
+        d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
+
+        default_dist_context = get_default_distributed_context()
+        dist_input = default_dist_context.get_dist_tensor_for_program(input)
+        self.assertEqual(dist_input.dist_attr.process_mesh, self.mesh)
+        self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
+
+
+class TestShardTensorStaticDy2Static(unittest.TestCase):
+    def test_dy2static(self):
+        @paddle.jit.to_static
+        def func():
             mesh = dist.ProcessMesh(
-                [[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]
+                [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
             )
-            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
-            a = paddle.to_tensor([[1, 2, 3], [5, 6, 7]])
-            d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+            dist_attr = dist.DistAttr(
+                mesh=mesh, sharding_specs=['x', None, None]
             )
-        except NotImplementedError as ex:
-            self.assertIn(
-                "The `paddle.distributed.shard_tensor` for static mode will be implemented later",
-                str(ex),
+            input = paddle.rand([4, 1024, 512])
+            d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
             )
-            exception = ex
-        paddle.disable_static()
-        self.assertIsNotNone(exception)
+            return input, mesh
+
+        dy_tensor, mesh = func()
+        static_tensor = func.outputs[0]  # get the inputs of static program
+
+        default_dist_context = get_default_distributed_context()
+        dist_input = default_dist_context.get_dist_tensor_for_program(
+            static_tensor
+        )
+        self.assertEqual(dist_input.dist_attr.process_mesh, mesh)
+        self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
 
 
 if __name__ == "__main__":

From 7ff171785dad4ba286bf4411897bdcc8645e4f78 Mon Sep 17 00:00:00 2001
From: zhiqiu
Date: Wed, 23 Aug 2023 20:27:18 +0800
Subject: [PATCH 4/4] use property in c++ side

---
 .../paddle/distributed/auto_parallel/api.py   | 27 ++++---------------
 .../distributed/auto_parallel/interface.py    |  3 ++-
 test/auto_parallel/test_shard_tensor_api.py   | 21 +++++++++++++++
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index ad7211918e1ca..251eb6bb63263 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -16,7 +16,6 @@
 from paddle.distributed.auto_parallel.interface import (
     shard_tensor as shard_tensor_static,
 )
-from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
 from paddle.framework import core
 
 # There are the auto parallel API of the unified version of dynamic and static mode.
@@ -47,7 +46,7 @@ class DistAttr(core.TensorDistAttr):
 
     def __init__(self, mesh, sharding_specs):
         # 1. inputs checking
-        if not isinstance(mesh, ProcessMesh):
+        if not isinstance(mesh, core.ProcessMesh):
             raise ValueError(
                 "The mesh must be an instance of paddle.distributed.ProcessMesh."
             )
@@ -66,28 +65,12 @@ def __init__(self, mesh, sharding_specs):
 
         # 2. init core.TensorDistAttr
         core.TensorDistAttr.__init__(self)
-        self._process_mesh = mesh
-        self._dims_mapping = dims_mapping
-    @property
-    def process_mesh(self):
-        """
-        Get process_mesh of the dist_attr
-
-        Returns:
-            paddle.distributed.ProcessMesh: process_mesh
-        """
-        return self._process_mesh
+        self.process_mesh = mesh
+        self.dims_mapping = dims_mapping
 
-    @property
-    def dims_mapping(self):
-        """
-        Get dims_mapping of the dist_attr
-
-        Returns:
-            list[int]: dims_mapping
-        """
-        return self._dims_mapping
+        self.mark_annotated("process_mesh")
+        self.mark_annotated("dims_mapping")
 
     @property
     def sharding_specs(self):

diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py
index 06a24b0c5433e..81f0133d31b08 100644
--- a/python/paddle/distributed/auto_parallel/interface.py
+++ b/python/paddle/distributed/auto_parallel/interface.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+from paddle.framework import core
 
 from .process_mesh import ProcessMesh, get_current_process_mesh
 from .static.dist_context import get_default_distributed_context
@@ -67,7 +68,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
 
     if process_mesh is not None:
         assert isinstance(
-            process_mesh, ProcessMesh
+            process_mesh, core.ProcessMesh
        ), "Argument process_mesh {} is not an instance of ProcessMesh".format(
             process_mesh
         )

diff --git a/test/auto_parallel/test_shard_tensor_api.py b/test/auto_parallel/test_shard_tensor_api.py
index 6108372329f52..764cbdc36e2d1 100644
--- a/test/auto_parallel/test_shard_tensor_api.py
+++ b/test/auto_parallel/test_shard_tensor_api.py
@@ -55,6 +55,27 @@ def test_sharding_specs_argument_error(self):
         self.assertIsNotNone(exception)
 
 
+class TestShardTensorDynamic(unittest.TestCase):
+    def setUp(self):
+        self.mesh = dist.ProcessMesh(
+            [[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"]
+        )
+
+    def test_dynamic(self):
+        dist_attr = dist.DistAttr(
+            mesh=self.mesh, sharding_specs=['x', None, None]
+        )
+
+        input = paddle.rand([4, 1024, 512])
+        d_tensor = dist.shard_tensor(input, dist_attr=dist_attr)
+        print(dist_attr.dims_mapping)
+
+        self.assertEqual(d_tensor.dist_attr.process_mesh, self.mesh)
+        self.assertEqual(d_tensor.dist_attr.dims_mapping, [0, -1, -1])
+        self.assertTrue(d_tensor.dist_attr.is_annotated("process_mesh"))
+        self.assertTrue(d_tensor.dist_attr.is_annotated("dims_mapping"))
+
+
 class TestShardTensorStatic(unittest.TestCase):
     def setUp(self):
         self.mesh = dist.ProcessMesh(
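
For reference, a minimal usage sketch of the API these patches exercise. The mesh, sharding specs, and tensor shape are taken from the new unit tests above; anything beyond that is illustrative and not part of the patch series.

    import paddle
    import paddle.distributed as dist

    # 2 x 4 process mesh with named dims, as in the new unit tests
    mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], dim_names=["x", "y"])

    # shard dim 0 of the tensor along mesh dim "x"; the other dims stay replicated
    dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None, None])

    x = paddle.rand([4, 1024, 512])
    d_tensor = dist.shard_tensor(x, dist_attr=dist_attr)  # dynamic mode

    print(dist_attr.dims_mapping)  # expected: [0, -1, -1], per the tests

The dy2static test drives the same call under paddle.jit.to_static, where shard_tensor falls through to the static-mode shard_tensor_static helper introduced in patch 1 and records the annotation in the default distributed context.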