From 0eb5034223e55592409ff2c8f022627d3ebd1943 Mon Sep 17 00:00:00 2001
From: Xiaoxu Chen
Date: Mon, 30 Oct 2023 09:18:33 +0800
Subject: [PATCH] [auto parallel] enable spmd rules for maximum (#58360)

---
 paddle/phi/api/yaml/legacy_backward.yaml      |  1 +
 paddle/phi/api/yaml/legacy_ops.yaml           |  1 +
 paddle/phi/infermeta/spmd_rules/elementwise.h |  2 +-
 .../semi_auto_parallel_for_elementwise.py     | 43 +++++++++++++++++++
 4 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 7453c7ec49a908..0cb62f2a84c765 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -363,6 +363,7 @@
   infer_meta :
     func : GeneralBinaryGradInferMeta
     param: [x, y]
+    spmd_rule: ElementwiseBinaryGradInferSpmd
   kernel :
     func : maximum_grad
   composite : maximum_grad(x, y, out_grad, axis, x_grad, y_grad)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 6b1206f617bcaa..f43ae357df3e81 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -717,6 +717,7 @@
   output : Tensor(out)
   infer_meta :
     func : ElementwiseInferMeta
+    spmd_rule : ElementwiseBinaryInferSpmd
   kernel :
     func : maximum
   backward : maximum_grad
diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.h b/paddle/phi/infermeta/spmd_rules/elementwise.h
index 188e557e6737b0..637c3b793b6c44 100644
--- a/paddle/phi/infermeta/spmd_rules/elementwise.h
+++ b/paddle/phi/infermeta/spmd_rules/elementwise.h
@@ -40,7 +40,7 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
 SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x,
                                         const DistMetaTensor& y,
                                         const DistMetaTensor& out_grad,
-                                        int64_t axis);
+                                        int64_t axis = -1);
 
 }  // namespace distributed
 }  // namespace phi
diff --git a/test/auto_parallel/semi_auto_parallel_for_elementwise.py b/test/auto_parallel/semi_auto_parallel_for_elementwise.py
index 24bf0c8be9e88b..2a55f7d02df03b 100644
--- a/test/auto_parallel/semi_auto_parallel_for_elementwise.py
+++ b/test/auto_parallel/semi_auto_parallel_for_elementwise.py
@@ -163,6 +163,49 @@ def test_relu_x_shard(self):
             unary_func=F.relu,
         )
 
+    def test_maximum_x_shard(self):
+        self.test_binary_body(
+            x_shape=[16, 32],
+            y_shape=[16, 32],
+            out_shape=[16, 32],
+            x_specs=['x', None],
+            y_specs=[None, None],
+            binary_func=paddle.maximum,
+        )
+
+    def test_maximum_x_shard_broadcast(self):
+        self.test_binary_body(
+            x_shape=[16, 32],
+            y_shape=[2, 16, 32],
+            out_shape=[2, 16, 32],
+            x_specs=['x', None],
+            y_specs=[None, None, None],
+            binary_func=paddle.maximum,
+        )
+
+    def test_maximum_x_y_shard(self):
+        if self._backend == "cpu":
+            return
+
+        self.test_binary_body(
+            x_shape=[16, 32],
+            y_shape=[16, 32],
+            out_shape=[16, 32],
+            x_specs=['x', None],
+            y_specs=[None, 'x'],
+            binary_func=paddle.maximum,
+        )
+
+    def test_maximum_x_y_shard_broadcast(self):
+        self.test_binary_body(
+            x_shape=[4, 16, 32],
+            y_shape=[16, 32],
+            out_shape=[4, 16, 32],
+            x_specs=['x', None, None],
+            y_specs=[None, None],
+            binary_func=paddle.maximum,
+        )
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
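
Editor's note (not part of the patch): the two `spmd_rule` entries hook the existing elementwise SPMD rules into the API code generated from the legacy yaml for `maximum` / `maximum_grad`, and the `axis = -1` default added in `elementwise.h` presumably lets the generated backward path call `ElementwiseBinaryGradInferSpmd` without supplying an explicit axis. Below is a minimal sketch of how the newly enabled rule is exercised from Python, mirroring `test_maximum_x_shard` above. The `ProcessMesh`/`DistAttr`-based `shard_tensor` usage is an assumption based on the conventions of the test file at the time of this PR, and the script must be launched on two devices (e.g. via `python -m paddle.distributed.launch`):

```python
# Hypothetical sketch, not part of the patch: paddle.maximum on a sharded
# input under semi-auto parallel, so ElementwiseBinaryInferSpmd decides the
# output sharding and ElementwiseBinaryGradInferSpmd the gradient shardings.
import paddle
import paddle.distributed as dist

# A 1-D process mesh over two devices; the mesh axis is named "x".
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

x = paddle.randn([16, 32])
y = paddle.randn([16, 32])
x.stop_gradient = False
y.stop_gradient = False

# Shard x row-wise across the "x" mesh axis; keep y fully replicated,
# matching x_specs=['x', None] / y_specs=[None, None] in the test above.
dist_x = dist.shard_tensor(
    x, dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None])
)
dist_y = dist.shard_tensor(
    y, dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
)

# Forward: the SPMD rule propagates x's row sharding to out.
out = paddle.maximum(dist_x, dist_y)
# Backward: the grad rule shards x_grad / y_grad accordingly.
out.mean().backward()
```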
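The four tests cover the distinct dispatch paths for the rule: equal-rank inputs with only one input sharded, rank-mismatched (broadcast) inputs where the rule must align trailing dimensions before propagating shardings, and inputs sharded along different dimensions, which requires resharding one operand to reach a compatible distribution; that last case is skipped on the CPU backend, presumably because the reshard it triggers is not supported there.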