From bc6fa4fca75d6b7ee908748fb78eb7b4348fc1dd Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 10 Oct 2023 16:57:19 +0800 Subject: [PATCH] support elementwise backward rule (#57813) --- paddle/phi/api/lib/data_transform.cc | 15 +- paddle/phi/api/yaml/legacy_backward.yaml | 2 + paddle/phi/api/yaml/legacy_ops.yaml | 2 + .../phi/infermeta/spmd_rules/elementwise.cc | 67 ++++++++ paddle/phi/infermeta/spmd_rules/elementwise.h | 5 + paddle/phi/kernels/all_gather_kernel.h | 4 +- paddle/phi/kernels/all_reduce_kernel.h | 4 +- paddle/phi/kernels/all_to_all_kernel.h | 4 +- paddle/phi/kernels/concat_kernel.h | 5 +- paddle/phi/kernels/cpu/svd_kernel.cc | 2 +- paddle/phi/kernels/reshape_kernel.h | 4 +- paddle/phi/kernels/split_kernel.h | 4 +- paddle/phi/kernels/transpose_kernel.h | 5 +- .../semi_auto_parallel_for_elementwise.py | 147 ++++++++++++++++++ .../semi_auto_parallel_for_matmul.py | 7 +- .../semi_auto_parallel_for_replicated_spmd.py | 5 +- .../semi_auto_parallel_simple_net.py | 5 +- .../test_semi_auto_parallel_basic.py | 12 +- ...test_semi_auto_parallel_single_strategy.py | 1 + 19 files changed, 272 insertions(+), 28 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_for_elementwise.py diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 8c9a57f264db4..6fd1ddf87c4a2 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -652,18 +652,9 @@ ReshardApiInputToReplicatedKernelInput( if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) { VLOG(6) << "ApiIn to Replicated KernelIn - " << ReshardDebugInfo(*dist_tensor, dist_attr); - if (dist_tensor->initialized()) { - auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, - dist_attr); - return func->Eval(dev_ctx, *dist_tensor, dist_attr); - } else { - // when no tensor data need to be reshard, we still need to set correct - // replicated dist attr and local dims 
for output - dist_tensor->unsafe_set_dist_attr(dist_attr); - auto dense_tensor_meta = dist_tensor->value().meta(); - dense_tensor_meta.dims = dist_tensor->dims(); - dist_tensor->unsafe_mutable_value()->set_meta(dense_tensor_meta); - } + auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, + dist_attr); + return func->Eval(dev_ctx, *dist_tensor, dist_attr); } return std::static_pointer_cast(tensor_in); } diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 17d53342ba277..d95bc19c57bff 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -19,6 +19,7 @@ infer_meta : func : GeneralBinaryGradInferMeta param : [x, y] + spmd_rule : ElementwiseBinaryGradInferSpmd kernel : func : add_grad no_need_buffer : x, y @@ -680,6 +681,7 @@ infer_meta : func : GeneralBinaryGradInferMeta param : [x, y] + spmd_rule : ElementwiseBinaryGradInferSpmd kernel : func : subtract_grad no_need_buffer : x, y diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 14daf99fd7f13..01acb338c987b 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -17,6 +17,7 @@ output : Tensor(out) infer_meta : func : ElementwiseInferMeta + spmd_rule : ElementwiseBinaryInferSpmd kernel : func : add inplace : (x -> out) @@ -1003,6 +1004,7 @@ output : Tensor(out) infer_meta : func : ElementwiseInferMeta + spmd_rule : ElementwiseBinaryInferSpmd kernel : func : subtract inplace : (x -> out) diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.cc b/paddle/phi/infermeta/spmd_rules/elementwise.cc index 411c43de8cc41..24d6bed03c52d 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.cc +++ b/paddle/phi/infermeta/spmd_rules/elementwise.cc @@ -309,5 +309,72 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x, return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}}; } +SpmdInfo 
ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out_grad, + int64_t axis) { + TensorDistAttr x_dist_attr = out_grad.dist_attr(); + TensorDistAttr y_dist_attr = out_grad.dist_attr(); + TensorDistAttr x_grad_dist_attr = out_grad.dist_attr(); + TensorDistAttr y_grad_dist_attr = out_grad.dist_attr(); + + PADDLE_ENFORCE_GE( + out_grad.dims().size(), + x.dims().size(), + phi::errors::InvalidArgument("If being broadcast, the dims of out_grad " + "must larger or equal to the inputs." + "But we get the rank of output as [%d] and " + "the rank of input as [%d].", + out_grad.dims().size(), + x.dims().size())); + + PADDLE_ENFORCE_GE( + out_grad.dims().size(), + y.dims().size(), + phi::errors::InvalidArgument("If being broadcast, the dims of out_grad " + "must larger or equal to the inputs." + "But we get the rank of output as [%d] and " + "the rank of input as [%d].", + out_grad.dims().size(), + y.dims().size())); + + // The backward rule of elementwise follows the principle: the dist_attr + // of each input should be equal to that of out_grad. + // Caution: when the inputs have different shapes, one of them is first + // broadcast to the shape of the other. In the backward pass, the grad of + // the broadcast input can be in partial status, which needs communication + // to get the right result.
+ if (x.dims() != out_grad.dims()) { + int64_t diff = out_grad.dims().size() - x.dims().size(); + auto dims_mapping = x_dist_attr.dims_mapping(); + dims_mapping.erase(dims_mapping.begin(), dims_mapping.begin() + diff); + x_dist_attr.set_dims_mapping(dims_mapping); + x_grad_dist_attr.set_dims_mapping(dims_mapping); + for (int64_t i = 0; i < diff; ++i) { + if (out_grad.dist_attr().dims_mapping()[i] != -1) { + x_grad_dist_attr.set_partial_status( + std::vector{out_grad.dist_attr().dims_mapping()[i]}); + } + } + } + + if (y.dims() != out_grad.dims()) { + int64_t diff = out_grad.dims().size() - y.dims().size(); + auto dims_mapping = y_dist_attr.dims_mapping(); + dims_mapping.erase(dims_mapping.begin(), dims_mapping.begin() + diff); + y_dist_attr.set_dims_mapping(dims_mapping); + y_grad_dist_attr.set_dims_mapping(dims_mapping); + for (int64_t i = 0; i < diff; ++i) { + if (out_grad.dist_attr().dims_mapping()[i] != -1) { + y_grad_dist_attr.set_partial_status( + std::vector{out_grad.dist_attr().dims_mapping()[i]}); + } + } + } + + return {{x_dist_attr, y_dist_attr, out_grad.dist_attr()}, + {x_grad_dist_attr, y_grad_dist_attr}}; +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.h b/paddle/phi/infermeta/spmd_rules/elementwise.h index 319d3ccbbdac1..736aeec35ed0a 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.h +++ b/paddle/phi/infermeta/spmd_rules/elementwise.h @@ -34,5 +34,10 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& y, const DistMetaTensor& out); +SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out_grad, + int64_t axis); + } // namespace distributed } // namespace phi diff --git a/paddle/phi/kernels/all_gather_kernel.h b/paddle/phi/kernels/all_gather_kernel.h index cc19f88202d2a..1bde193a7b5cd 100644 --- a/paddle/phi/kernels/all_gather_kernel.h +++ b/paddle/phi/kernels/all_gather_kernel.h 
@@ -34,7 +34,9 @@ void AllGather(const Context& dev_ctx, MetaTensor* out_meta_ptr = &out_meta; AllGatherInferMeta(phi::MetaTensor(x), nranks, out_meta_ptr); - AllGatherKernel(dev_ctx, x, nranks, out); + if (x.initialized()) { + AllGatherKernel(dev_ctx, x, nranks, out); + } } } // namespace phi diff --git a/paddle/phi/kernels/all_reduce_kernel.h b/paddle/phi/kernels/all_reduce_kernel.h index 3583bde3416b3..2ec072bfd3ff2 100644 --- a/paddle/phi/kernels/all_reduce_kernel.h +++ b/paddle/phi/kernels/all_reduce_kernel.h @@ -35,7 +35,9 @@ void AllReduce(const Context& dev_ctx, MetaTensor* out_meta_ptr = &out_meta; AllReduceInferMeta(phi::MetaTensor(x), out_meta_ptr); - AllReduceKernel(dev_ctx, x, reduce_type, out); + if (x.initialized()) { + AllReduceKernel(dev_ctx, x, reduce_type, out); + } } } // namespace phi diff --git a/paddle/phi/kernels/all_to_all_kernel.h b/paddle/phi/kernels/all_to_all_kernel.h index 5444960b1f69e..5ac951deba5fb 100644 --- a/paddle/phi/kernels/all_to_all_kernel.h +++ b/paddle/phi/kernels/all_to_all_kernel.h @@ -30,7 +30,9 @@ void AllToAll(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { MetaTensor* out_meta_ptr = &out_meta; AllToAllInferMeta(phi::MetaTensor(x), out_meta_ptr); - AllToAllKernel(dev_ctx, x, out); + if (x.initialized()) { + AllToAllKernel(dev_ctx, x, out); + } } } // namespace phi diff --git a/paddle/phi/kernels/concat_kernel.h b/paddle/phi/kernels/concat_kernel.h index d3b99449a06f2..e4f2d25f09833 100644 --- a/paddle/phi/kernels/concat_kernel.h +++ b/paddle/phi/kernels/concat_kernel.h @@ -41,7 +41,10 @@ void Concat(const Context& dev_ctx, MetaTensor meta_out(dense_out); ConcatInferMeta(meta_x_ptr, axis.to(), &meta_out); - ConcatKernel(dev_ctx, x, axis, dense_out); + + if (x[0]->initialized()) { + ConcatKernel(dev_ctx, x, axis, dense_out); + } } template diff --git a/paddle/phi/kernels/cpu/svd_kernel.cc b/paddle/phi/kernels/cpu/svd_kernel.cc index 136835876249d..1ae2d9cce0d40 100644 --- 
a/paddle/phi/kernels/cpu/svd_kernel.cc +++ b/paddle/phi/kernels/cpu/svd_kernel.cc @@ -98,7 +98,6 @@ void SvdKernel(const Context& dev_ctx, /*Create Tensors and output, set the dim ...*/ auto numel = X.numel(); DenseTensor trans_x = ::phi::TransposeLast2Dim(dev_ctx, X); - auto* x_data = trans_x.data(); auto x_dims = X.dims(); int rows = static_cast(x_dims[x_dims.size() - 2]); int cols = static_cast(x_dims[x_dims.size() - 1]); @@ -113,6 +112,7 @@ void SvdKernel(const Context& dev_ctx, 0, cols, errors::InvalidArgument("The col of Input(X) should be greater than 0.")); + auto* x_data = trans_x.data(); int batches = static_cast(numel / (rows * cols)); auto* U_out = dev_ctx.template Alloc>(U); auto* VH_out = dev_ctx.template Alloc>(VH); diff --git a/paddle/phi/kernels/reshape_kernel.h b/paddle/phi/kernels/reshape_kernel.h index 972d72ad706d9..d03e44c0636c8 100644 --- a/paddle/phi/kernels/reshape_kernel.h +++ b/paddle/phi/kernels/reshape_kernel.h @@ -48,7 +48,9 @@ void Reshape(const Context& dev_ctx, DenseTensor* out) { MetaTensor meta_out(out); InferMetaFromVecValue(x, shape, &meta_out); - ReshapeInferKernel(dev_ctx, x, IntArray(shape), out); + if (x.initialized()) { + ReshapeInferKernel(dev_ctx, x, IntArray(shape), out); + } } template diff --git a/paddle/phi/kernels/split_kernel.h b/paddle/phi/kernels/split_kernel.h index 2869bf3206f7d..d752a40084a22 100644 --- a/paddle/phi/kernels/split_kernel.h +++ b/paddle/phi/kernels/split_kernel.h @@ -74,7 +74,9 @@ void Split(const Context& dev_ctx, outs.push_back(&result->at(i)); } - SplitKernel(dev_ctx, x, sections, axis, outs); + if (x.initialized()) { + SplitKernel(dev_ctx, x, sections, axis, outs); + } } template diff --git a/paddle/phi/kernels/transpose_kernel.h b/paddle/phi/kernels/transpose_kernel.h index 5555586c04387..20c4af9cff1f9 100644 --- a/paddle/phi/kernels/transpose_kernel.h +++ b/paddle/phi/kernels/transpose_kernel.h @@ -19,6 +19,7 @@ #include "paddle/phi/core/dense_tensor.h" #include 
"paddle/phi/infermeta/unary.h" #include "paddle/phi/kernels/empty_kernel.h" + namespace phi { template @@ -43,7 +44,9 @@ void Transpose(const Context& dev_ctx, // do not call TransposeStridedKernel, because some other kernels call // Transpose directly - TransposeKernel(dev_ctx, x, axis, dense_out); + if (x.initialized()) { + TransposeKernel(dev_ctx, x, axis, dense_out); + } } template diff --git a/test/auto_parallel/semi_auto_parallel_for_elementwise.py b/test/auto_parallel/semi_auto_parallel_for_elementwise.py new file mode 100644 index 0000000000000..b7e3e30b89e56 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_elementwise.py @@ -0,0 +1,147 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestElementwiseApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_binary_body( + self, x_shape, y_shape, out_shape, x_specs, y_specs, binary_func + ): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x = paddle.randn(x_shape, self._dtype) + y = paddle.randn(y_shape, self._dtype) + x.stop_gradient = False + y.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_y = dist.shard_tensor(y, dist_attr=y_dist_attr) + dist_x.stop_gradient = False + dist_y.stop_gradient = False + + dist_out = binary_func(dist_x, dist_y) + out = binary_func(x, y) + self.check_tensor_eq(out, dist_out) + + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + self.check_tensor_eq(y.grad, dist_y.grad) + + def test_add_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.add, + ) + + def test_sub_x_shard(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, None], + binary_func=paddle.subtract, + ) + + def test_add_x_shard_broadcast(self): + self.test_binary_body( + x_shape=[16, 32], + y_shape=[2, 16, 32], + out_shape=[2, 16, 32], + x_specs=['x', None], + y_specs=[None, None, None], + binary_func=paddle.add, + ) + + def test_add_x_y_shard(self): + if self._backend == "cpu": + return + 
+ self.test_binary_body( + x_shape=[16, 32], + y_shape=[16, 32], + out_shape=[16, 32], + x_specs=['x', None], + y_specs=[None, 'x'], + binary_func=paddle.add, + ) + + def test_add_x_y_shard_broadcast(self): + if self._backend == "cpu": + return + + self.test_binary_body( + x_shape=[4, 16, 32], + y_shape=[16, 32], + out_shape=[4, 16, 32], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.add, + ) + + def test_sub_x_y_shard_broadcast(self): + if self._backend == "cpu": + return + + self.test_binary_body( + x_shape=[4, 16, 32], + y_shape=[16, 32], + out_shape=[4, 16, 32], + x_specs=['x', None, None], + y_specs=[None, None], + binary_func=paddle.subtract, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_add_x_shard() + self.test_add_x_shard_broadcast() + self.test_add_x_y_shard() + self.test_add_x_y_shard_broadcast() + + +if __name__ == '__main__': + TestElementwiseApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_matmul.py b/test/auto_parallel/semi_auto_parallel_for_matmul.py index bba31234ed80b..279062f483058 100644 --- a/test/auto_parallel/semi_auto_parallel_for_matmul.py +++ b/test/auto_parallel/semi_auto_parallel_for_matmul.py @@ -24,11 +24,9 @@ class TestMatmulApiForSemiAutoParallel: def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) - paddle.seed(2023) - np.random.seed(2023) - def check_tensor_eq(self, a, b): np1 = a.numpy() np2 = b.numpy() @@ -37,6 +35,9 @@ def check_tensor_eq(self, a, b): def test_body( self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False ): + paddle.seed(self._seed) + np.random.seed(self._seed) + x_np = 
np.random.random(size=x_shape).astype(self._dtype) y_np = np.random.random(size=y_shape).astype(self._dtype) x = paddle.to_tensor(x_np) diff --git a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py index 87a171091c961..b83d9ffb87e7e 100644 --- a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py +++ b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py @@ -25,10 +25,11 @@ class TestReplicatedSPmdApiForSemiAutoParallel: def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) - paddle.seed(2023) - np.random.seed(2023) + paddle.seed(self._seed) + np.random.seed(self._seed) def check_tensor_eq(self, a, b): np1 = a.numpy() diff --git a/test/auto_parallel/semi_auto_parallel_simple_net.py b/test/auto_parallel/semi_auto_parallel_simple_net.py index 1e0b1a92859fc..75f25277b81e5 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net.py @@ -135,6 +135,7 @@ class TestSimpleNetForSemiAutoParallel: def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) paddle.set_device(self._backend) @@ -144,8 +145,8 @@ def __init__(self): self.init_single_card_net_result() def init_input_data(self): - paddle.seed(2023) - np.random.seed(2023) + paddle.seed(self._seed) + np.random.seed(self._seed) self.image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype( 'float32' diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 3fe98e4d08744..8040b97d43ac9 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -23,7 +23,7 @@ def setUp(self): 
num_of_devices=2, timeout=120, ) - self._default_envs = {"dtype": "float32"} + self._default_envs = {"dtype": "float32", "seed": "2023"} self._changeable_envs = {"backend": ["cpu", "gpu"]} def test_matmul_api(self): @@ -36,6 +36,16 @@ def test_matmul_api(self): user_defined_envs=envs, ) + def test_elementwise_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_elementwise.py", + user_defined_envs=envs, + ) + def test_several_replicated_spmd_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index 5c30f8b5954be..89ef4ac6a1a10 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -22,6 +22,7 @@ def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { "dtype": "float32", + "seed": "2023", } self._changeable_envs = {"backend": ["cpu", "gpu"]}