diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index ea4c346b182e5..3a87826337465 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -2282,6 +2282,7 @@
   infer_meta :
     func : KernelWithXShapeInferMeta
     param: [xshape, out_grad]
+    spmd_rule : SqueezeGradInferSpmd
   kernel :
     func : squeeze_grad
     data_type : out_grad
@@ -2527,6 +2528,7 @@
   infer_meta :
     func : KernelWithXShapeInferMeta
     param: [xshape, out_grad]
+    spmd_rule : UnsqueezeGradInferSpmd
   kernel :
     func : unsqueeze_grad
     param : [xshape, out_grad]
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 8b0163f817afd..32157222db2f6 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2489,6 +2489,7 @@
   output : Tensor(out), Tensor(xshape)
   infer_meta :
     func : SqueezeWithXShapeInferMeta
+    spmd_rule : SqueezeInferSpmd
  kernel :
    func : squeeze
    data_type : x
@@ -2714,6 +2715,7 @@
   output : Tensor(out), Tensor(xshape)
   infer_meta :
     func : UnsqueezeWithXShapeInferMeta
+    spmd_rule : UnsqueezeInferSpmd
   kernel :
     func : unsqueeze
     data_type : x
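Note: with the four `spmd_rule` entries above registered, semi-auto parallel can propagate shardings through squeeze/unsqueeze and their grads instead of falling back to replication. A minimal usage sketch, using only APIs that already appear in this patch's test file and assuming a two-card launch via `paddle.distributed.launch`:

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x = paddle.randn([1, 4, 6])
dist_x = dist.shard_tensor(x, mesh, [dist.Shard(1)])  # shard dim 1 over the mesh

# SqueezeInferSpmd keeps the sharding on the surviving dim: the [4, 6]
# output stays sharded on its dim 0, so no reshard is needed.
out = paddle.squeeze(dist_x, axis=0)
```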
diff --git a/paddle/phi/infermeta/spmd_rules/squeeze.cc b/paddle/phi/infermeta/spmd_rules/squeeze.cc
index 046de2e049760..8080e6c3d24ac 100644
--- a/paddle/phi/infermeta/spmd_rules/squeeze.cc
+++ b/paddle/phi/infermeta/spmd_rules/squeeze.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include "paddle/phi/infermeta/spmd_rules/squeeze.h"
 #include <numeric>
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
 #include "paddle/phi/core/distributed/auto_parallel/utils.h"
 #include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
+#include "paddle/phi/infermeta/spmd_rules/reshape.h"
 #include "paddle/phi/infermeta/spmd_rules/utils.h"
 
 namespace phi {
@@ -29,6 +30,14 @@ namespace distributed {
 
 using phi::distributed::auto_parallel::str_join;
 
+TensorDistAttr CreateSqueezeXshape(const TensorDistAttr& x) {
+  TensorDistAttr out(x);
+  auto dims_mapping = x.dims_mapping();
+  dims_mapping.insert(dims_mapping.begin(), -1);
+  out.set_dims_mapping(dims_mapping);
+  return out;
+}
+
 void MakeSqueezeDimTransWithoutAxis(
     const std::vector<int64_t>& x_shape,
     std::vector<int64_t>* out_shape,
@@ -137,9 +146,18 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
   // and output with the inferred dims mapping.
   TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
   x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  if (x_dist_attr_dst.dynamic_dims().size() !=
+      x_dist_attr_dst.dims_mapping().size()) {
+    VLOG(4) << "SqueezeInferSPMD change x dist attr dynamic dims";
+    x_dist_attr_dst.set_default_dynamic_dims(x_dist_attr_dst.dims_mapping());
+  }
   TensorDistAttr out_dist_attr(x_dist_attr_src);
   out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
-
+  if (out_dist_attr.dynamic_dims().size() !=
+      out_dist_attr.dims_mapping().size()) {
+    VLOG(4) << "SqueezeInferSPMD change output dist attr dynamic dims";
+    out_dist_attr.set_default_dynamic_dims(out_dist_attr.dims_mapping());
+  }
   VLOG(4) << "SqueezeInferSpmd: X shape: [" << str_join(x_shape)
           << "] Out shape: [" << str_join(out_shape) << "]";
   VLOG(4) << "Transformation from input to output:";
@@ -151,7 +169,8 @@
           << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
           << "]\n\n";
 
-  return {{x_dist_attr_dst}, {out_dist_attr}};
+  return {{x_dist_attr_dst},
+          {out_dist_attr, CreateSqueezeXshape(x_dist_attr_dst)}};
 }
 
 SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
@@ -202,8 +221,18 @@
   // and output with the inferred dims mapping
   TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
   out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  if (out_dist_attr_dst.dynamic_dims().size() !=
+      out_dist_attr_dst.dims_mapping().size()) {
+    VLOG(4) << "SqueezeInferSPMDReverse change output dist attr dynamic dims";
+    out_dist_attr_dst.set_default_dynamic_dims(
+        out_dist_attr_dst.dims_mapping());
+  }
   TensorDistAttr x_dist_attr(x.dist_attr());
   x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
+  if (x_dist_attr.dynamic_dims().size() != x_dist_attr.dims_mapping().size()) {
+    VLOG(4) << "SqueezeInferSPMDReverse change x dist attr dynamic dims";
+    x_dist_attr.set_default_dynamic_dims(x_dist_attr.dims_mapping());
+  }
 
   VLOG(4) << "SqueezeInferSpmdReverse: Out shape: [" << str_join(out_shape)
           << "] X shape: [" << str_join(x_shape) << "]";
@@ -218,5 +247,14 @@
   return {{x_dist_attr}, {out_dist_attr_dst}};
 }
 
+SpmdInfo SqueezeGradInferSpmd(const DistMetaTensor& xshape,
+                              const DistMetaTensor& out_grad,
+                              const IntArray& axis) {
+  auto shape = phi::vectorize(xshape.dims());
+  shape = std::vector<int64_t>(shape.begin() + 1, shape.end());
+  const auto& spmd = ReshapeInferSpmd(out_grad, shape);
+  return {{xshape.dist_attr(), spmd.first[0]}, {spmd.second[0]}};
+}
+
 }  // namespace distributed
 }  // namespace phi
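The forward rule's dims_mapping bookkeeping is simple at heart: a squeezed axis has size 1 and therefore cannot be sharded, so its mapping entry drops out, while `CreateSqueezeXshape` prepends a replicated (-1) entry for xshape's artificial leading dim. A simplified Python model of that bookkeeping (the real rule goes through the DimTrans machinery and may additionally reshard the input):

```python
def squeeze_dims_mapping(x_shape, x_dims_mapping, axes):
    """Simplified model of SqueezeInferSpmd's mapping propagation."""
    axes = {a % len(x_shape) for a in axes}
    # A size-1 axis is necessarily replicated, so dropping it loses nothing.
    out_mapping = [m for i, m in enumerate(x_dims_mapping) if i not in axes]
    # xshape carries one extra leading dim, mapped -1, as in CreateSqueezeXshape.
    xshape_mapping = [-1] + list(x_dims_mapping)
    return out_mapping, xshape_mapping

# x of shape [1, 4, 1, 6] sharded on dim 1, squeezing axes 0 and 2:
out, xshape = squeeze_dims_mapping([1, 4, 1, 6], [-1, 0, -1, -1], (0, 2))
assert out == [0, -1]
assert xshape == [-1, -1, 0, -1, -1]
```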
diff --git a/paddle/phi/infermeta/spmd_rules/squeeze.h b/paddle/phi/infermeta/spmd_rules/squeeze.h
index b111c3272612f..7ccb8ecb23a64 100644
--- a/paddle/phi/infermeta/spmd_rules/squeeze.h
+++ b/paddle/phi/infermeta/spmd_rules/squeeze.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <vector>
 
+#include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
 #include "paddle/phi/core/distributed/type_defs.h"
 
@@ -28,5 +29,9 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
 SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
                                  const DistMetaTensor& out,
                                  const std::vector<int64_t>& axis);
+
+SpmdInfo SqueezeGradInferSpmd(const DistMetaTensor& xshape,
+                              const DistMetaTensor& out_grad,
+                              const IntArray& axis = {});
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc
index 73ebad83db135..a5819f5adac39 100644
--- a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc
+++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include "paddle/phi/infermeta/spmd_rules/unsqueeze.h"
 #include <numeric>
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
 #include "paddle/phi/core/distributed/auto_parallel/utils.h"
 #include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
+#include "paddle/phi/infermeta/spmd_rules/reshape.h"
 #include "paddle/phi/infermeta/spmd_rules/utils.h"
 
 namespace phi {
@@ -29,6 +30,14 @@ namespace distributed {
 
 using phi::distributed::auto_parallel::str_join;
 
+TensorDistAttr CreateUnsqueezeXshape(const TensorDistAttr& x) {
+  TensorDistAttr out(x);
+  auto dims_mapping = x.dims_mapping();
+  dims_mapping.insert(dims_mapping.begin(), -1);
+  out.set_dims_mapping(dims_mapping);
+  return out;
+}
+
 std::vector<std::shared_ptr<DimTrans>> MakeUnsqueezeDimTrans(
     const std::vector<int64_t>& x_shape,
     std::vector<int64_t>* out_shape,
@@ -119,8 +128,18 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
   // and output with the inferred dims mapping.
   TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
   x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  if (x_dist_attr_dst.dynamic_dims().size() !=
+      x_dist_attr_dst.dims_mapping().size()) {
+    VLOG(4) << "UnSqueezeInferSPMD change x dist attr dynamic dims";
+    x_dist_attr_dst.set_default_dynamic_dims(x_dist_attr_dst.dims_mapping());
+  }
   TensorDistAttr out_dist_attr(x_dist_attr_src);
   out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
+  if (out_dist_attr.dynamic_dims().size() !=
+      out_dist_attr.dims_mapping().size()) {
+    VLOG(4) << "UnSqueezeInferSPMD change output dist attr dynamic dims";
+    out_dist_attr.set_default_dynamic_dims(out_dist_attr.dims_mapping());
+  }
 
   VLOG(4) << "UnsqueezeInferSpmd: X shape: [" << str_join(x_shape)
           << "] Out shape: [" << str_join(out_shape) << "]";
@@ -134,7 +153,8 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
           << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
           << "]\n\n";
 
-  return {{x_dist_attr_dst}, {out_dist_attr}};
+  return {{x_dist_attr_dst},
+          {out_dist_attr, CreateUnsqueezeXshape(x_dist_attr_dst)}};
 }
 
 SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,
@@ -181,9 +201,18 @@
   // and output with the inferred dims mapping
   TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
   out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
+  if (out_dist_attr_dst.dynamic_dims().size() !=
+      out_dist_attr_dst.dims_mapping().size()) {
+    VLOG(4) << "UnSqueezeInferSPMDReverse change output dist attr dynamic dims";
+    out_dist_attr_dst.set_default_dynamic_dims(
+        out_dist_attr_dst.dims_mapping());
+  }
   TensorDistAttr x_dist_attr(x.dist_attr());
   x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
-
+  if (x_dist_attr.dynamic_dims().size() != x_dist_attr.dims_mapping().size()) {
+    VLOG(4) << "UnSqueezeInferSPMDReverse change x dist attr dynamic dims";
+    x_dist_attr.set_default_dynamic_dims(x_dist_attr.dims_mapping());
+  }
   VLOG(4) << "UnsqueezeInferSpmdReverse: Out shape: [" << str_join(out_shape)
           << "] X shape: [" << str_join(x_shape) << "]";
   VLOG(4) << "Transformation from output to input:";
@@ -198,5 +227,14 @@
   return {{x_dist_attr}, {out_dist_attr_dst}};
 }
 
+SpmdInfo UnsqueezeGradInferSpmd(const DistMetaTensor& xshape,
+                                const DistMetaTensor& out_grad,
+                                const IntArray& axis) {
+  auto shape = phi::vectorize(xshape.dims());
+  shape = std::vector<int64_t>(shape.begin() + 1, shape.end());
+  const auto& spmd = ReshapeInferSpmd(out_grad, shape);
+  return {{xshape.dist_attr(), spmd.first[0]}, {spmd.second[0]}};
+}
+
 }  // namespace distributed
 }  // namespace phi
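Both grad rules share one strategy: xshape records the forward input's shape behind an artificial leading dim, so stripping that dim and delegating to the reshape rule recovers the input's distribution, and `axis` can be ignored. A Python paraphrase, where `reshape_infer_spmd`, `.shape`, `.dist_attr`, `inputs`, and `outputs` are illustrative stand-ins for the C++ `ReshapeInferSpmd` and `SpmdInfo` members, not real Python API:

```python
def squeeze_like_grad_infer_spmd(xshape, out_grad, reshape_infer_spmd):
    # Drop xshape's artificial leading dim to get the forward input's shape.
    x_shape = list(xshape.shape)[1:]
    spmd = reshape_infer_spmd(out_grad, x_shape)
    # xshape's dist attr passes through untouched; out_grad and x_grad take
    # whatever the reshape rule inferred.
    inputs = (xshape.dist_attr, spmd.inputs[0])
    outputs = (spmd.outputs[0],)
    return inputs, outputs
```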
diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.h b/paddle/phi/infermeta/spmd_rules/unsqueeze.h
index a2f3490409b83..ba434677482c7 100644
--- a/paddle/phi/infermeta/spmd_rules/unsqueeze.h
+++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <vector>
 
+#include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
 #include "paddle/phi/core/distributed/type_defs.h"
 
@@ -28,5 +29,8 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
 SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,
                                    const DistMetaTensor& out,
                                    const std::vector<int64_t>& axis);
+SpmdInfo UnsqueezeGradInferSpmd(const DistMetaTensor& xshape,
+                                const DistMetaTensor& out_grad,
+                                const IntArray& axis = {});
 }  // namespace distributed
 }  // namespace phi
diff --git a/test/auto_parallel/semi_auto_parallel_for_squeeze.py b/test/auto_parallel/semi_auto_parallel_for_squeeze.py
new file mode 100644
index 0000000000000..ff2f977c23e8f
--- /dev/null
+++ b/test/auto_parallel/semi_auto_parallel_for_squeeze.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+
+
+class TestSqueezeApiForSemiAutoParallel:
+    def __init__(self):
+        self._dtype = os.getenv("dtype")
+        self._backend = os.getenv("backend")
+        self._seed = eval(os.getenv("seed"))
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+    def check_tensor_eq(self, a, b):
+        np1 = a.numpy()
+        np2 = b.numpy()
+        np.testing.assert_allclose(np1, np2, rtol=1e-06, verbose=True)
+
+    def test_body(self, x_shape, out_shape, x_placements, axis, op_func):
+        paddle.seed(self._seed)
+        np.random.seed(self._seed)
+
+        x = paddle.randn(x_shape, self._dtype)
+        x.stop_gradient = False
+
+        dist_x = dist.shard_tensor(x, self._mesh, x_placements)
+        dist_x.stop_gradient = False
+
+        dist_out = op_func(dist_x, axis=axis)
+        out = op_func(x, axis=axis)
+        self.check_tensor_eq(out, dist_out)
+        np.testing.assert_equal(dist_out.shape, out_shape, verbose=True)
+
+        dist_out.backward()
+        out.backward()
+        self.check_tensor_eq(x.grad, dist_x.grad)
+
+    def test_squeeze(self):
+        self.test_body(
+            x_shape=[1, 4, 1, 6],
+            out_shape=[4, 1, 6],
+            x_placements=[dist.Shard(1)],
+            axis=0,
+            op_func=paddle.squeeze,
+        )
+
+    def test_squeeze_multi_axes(self):
+        self.test_body(
+            x_shape=[1, 4, 1, 6],
+            out_shape=[4, 6],
+            x_placements=[dist.Shard(1)],
+            axis=(0, 2),
+            op_func=paddle.squeeze,
+        )
+
+    def test_unsqueeze(self):
+        self.test_body(
+            x_shape=[4, 6],
+            out_shape=[1, 4, 6],
+            x_placements=[dist.Shard(0)],
+            axis=0,
+            op_func=paddle.unsqueeze,
+        )
+
+    def test_unsqueeze_multi_axes(self):
+        self.test_body(
+            x_shape=[4, 6],
+            out_shape=[1, 4, 6, 1],
+            x_placements=[dist.Shard(1)],
+            axis=(0, 3),
+            op_func=paddle.unsqueeze,
+        )
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+        elif self._backend == "gpu":
+            paddle.set_device("gpu:" + str(dist.get_rank()))
+        else:
+            raise ValueError("Only support cpu or gpu backend.")
+
+        self.test_squeeze()
+        self.test_squeeze_multi_axes()
+        self.test_unsqueeze()
+        self.test_unsqueeze_multi_axes()
+
+
+if __name__ == '__main__':
+    TestSqueezeApiForSemiAutoParallel().run_test_case()
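For context, this file is not run directly by CI; the driver below (`test_semi_auto_parallel_basic.py`) launches it on two ranks and injects the `dtype`/`backend`/`seed` variables the constructor reads. Roughly equivalent to the following, where the concrete env values are assumptions for illustration, not taken from this patch:

```python
import os
import subprocess

# Hypothetical values; the real ones come from the driver's
# _default_envs/_changeable_envs, which this diff does not show.
env = dict(os.environ, dtype="float32", seed="2023", backend="gpu")
subprocess.run(
    [
        "python", "-m", "paddle.distributed.launch", "--devices=0,1",
        "semi_auto_parallel_for_squeeze.py",
    ],
    env=env,
    check=True,
)
```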
diff --git a/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py
index afb851279ca36..84367ea9cbedd 100644
--- a/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py
+++ b/test/auto_parallel/spmd_rules/test_unsqueeze_rule.py
@@ -48,7 +48,7 @@ def test_unsqueeze_infer_forward(self):
         infered_output_dist_attrs = result_dist_attrs[1]
 
         self.assertEqual(len(infered_input_dist_attrs), 1)
-        self.assertEqual(len(infered_output_dist_attrs), 1)
+        self.assertEqual(len(infered_output_dist_attrs), 2)
         self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [0, 1])
         self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1])
 
diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py
index fd37d2d864888..e86cea24f572a 100644
--- a/test/auto_parallel/test_semi_auto_parallel_basic.py
+++ b/test/auto_parallel/test_semi_auto_parallel_basic.py
@@ -188,6 +188,16 @@ def test_reshape_api(self):
             user_defined_envs=envs,
         )
 
+    def test_squeeze_api(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "semi_auto_parallel_for_squeeze.py",
+                user_defined_envs=envs,
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc
index 2f994e455fd32..b5cd1ec05dfff 100644
--- a/test/cpp/auto_parallel/spmd_rule_test.cc
+++ b/test/cpp/auto_parallel/spmd_rule_test.cc
@@ -1507,6 +1507,125 @@ TEST(EmbeddingGradInferSpmd, Ctor) {
           << std::endl;
 }
 
+TEST(SqueezeGradInferSpmd, Ctor) {
+  std::vector<int64_t> xshape_shape = {-1, 1, 32, 1, 48};
+  std::vector<int64_t> out_grad_shape = {32, 48};
+
+  std::vector<int64_t> mesh_shape = {2, 3};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
+  std::vector<std::string> dim_names = {"x", "y"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  TensorDistAttr xshape_dist_attr = TensorDistAttr();
+  xshape_dist_attr.set_process_mesh(process_mesh);
+  xshape_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, -1, 1, -1, -1}));
+  xshape_dist_attr.set_dynamic_dims(
+      std::vector<bool>({false, false, false, false}));
+
+  TensorDistAttr out_grad_dist_attr = TensorDistAttr();
+  out_grad_dist_attr.set_process_mesh(process_mesh);
+  out_grad_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 1}));
+  out_grad_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
+
+  phi::distributed::DistMetaTensor xshape(phi::make_ddim(xshape_shape),
+                                          xshape_dist_attr);
+  phi::distributed::DistMetaTensor out_grad(phi::make_ddim(out_grad_shape),
+                                            out_grad_dist_attr);
+
+  auto spmdinfo = SqueezeGradInferSpmd(xshape, out_grad);
+
+  EXPECT_EQ(spmdinfo.first.size(), 2UL);
+  EXPECT_EQ(spmdinfo.second.size(), 1UL);
+
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[0]),
+            std::vector<int64_t>({-1, -1, 1, -1, -1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[1]), std::vector<int64_t>({-1, 1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.second[0]),
+            std::vector<int64_t>({-1, -1, -1, 1}));
+  EXPECT_DOUBLE_EQ(
+      PADDLE_GET_CONST(TensorDistAttr, spmdinfo.second[0]).is_partial(), false);
+
+  xshape_dist_attr.set_dims_mapping({-1, -1, 0, -1, 1});
+  out_grad_dist_attr.set_dims_mapping({0, 1});
+  xshape = phi::distributed::DistMetaTensor(phi::make_ddim(xshape_shape),
+                                            xshape_dist_attr);
+  out_grad = phi::distributed::DistMetaTensor(phi::make_ddim(out_grad_shape),
+                                              out_grad_dist_attr);
+
+  spmdinfo = SqueezeGradInferSpmd(xshape, out_grad);
+
+  EXPECT_EQ(spmdinfo.first.size(), 2UL);
+  EXPECT_EQ(spmdinfo.second.size(), 1UL);
+
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[0]),
+            std::vector<int64_t>({-1, -1, 0, -1, 1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[1]), std::vector<int64_t>({0, 1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.second[0]),
+            std::vector<int64_t>({-1, 0, -1, 1}));
+  EXPECT_DOUBLE_EQ(
+      PADDLE_GET_CONST(TensorDistAttr, spmdinfo.second[0]).is_partial(), false);
+}
+
+TEST(UnsqueezeGradInferSpmd, Ctor) {
+  std::vector<int64_t> xshape_shape = {-1, 32, 48};
+  std::vector<int64_t> out_grad_shape = {1, 32, 1, 48};
+
+  std::vector<int64_t> mesh_shape = {2, 3};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5};
+  std::vector<std::string> dim_names = {"x", "y"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  TensorDistAttr xshape_dist_attr = TensorDistAttr();
+  xshape_dist_attr.set_process_mesh(process_mesh);
+  xshape_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 1, -1}));
+  xshape_dist_attr.set_dynamic_dims(std::vector<bool>({false, false}));
+
+  TensorDistAttr out_grad_dist_attr = TensorDistAttr();
+  out_grad_dist_attr.set_process_mesh(process_mesh);
+  out_grad_dist_attr.set_dims_mapping(std::vector<int64_t>({-1, 1, -1, -1}));
+  out_grad_dist_attr.set_dynamic_dims(
+      std::vector<bool>({false, false, false, false}));
+
+  phi::distributed::DistMetaTensor xshape(phi::make_ddim(xshape_shape),
+                                          xshape_dist_attr);
+  phi::distributed::DistMetaTensor out_grad(phi::make_ddim(out_grad_shape),
+                                            out_grad_dist_attr);
+
+  auto spmdinfo = UnsqueezeGradInferSpmd(xshape, out_grad);
+
+  EXPECT_EQ(spmdinfo.first.size(), 2UL);
+  EXPECT_EQ(spmdinfo.second.size(), 1UL);
+
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[0]),
+            std::vector<int64_t>({-1, 1, -1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[1]),
+            std::vector<int64_t>({-1, 1, -1, -1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.second[0]),
+            std::vector<int64_t>({1, -1}));
+  EXPECT_DOUBLE_EQ(
+      PADDLE_GET_CONST(TensorDistAttr, spmdinfo.second[0]).is_partial(), false);
+
+  xshape_dist_attr.set_dims_mapping({-1, 0, 1});
+  out_grad_dist_attr.set_dims_mapping({-1, 0, -1, 1});
+  xshape = phi::distributed::DistMetaTensor(phi::make_ddim(xshape_shape),
+                                            xshape_dist_attr);
+  out_grad = phi::distributed::DistMetaTensor(phi::make_ddim(out_grad_shape),
+                                              out_grad_dist_attr);
+
+  spmdinfo = UnsqueezeGradInferSpmd(xshape, out_grad);
+
+  EXPECT_EQ(spmdinfo.first.size(), 2UL);
+  EXPECT_EQ(spmdinfo.second.size(), 1UL);
+
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[0]),
+            std::vector<int64_t>({-1, 0, 1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.first[1]),
+            std::vector<int64_t>({-1, 0, -1, 1}));
+  EXPECT_EQ(get_dims_mapping(spmdinfo.second[0]), std::vector<int64_t>({0, 1}));
+  EXPECT_DOUBLE_EQ(
+      PADDLE_GET_CONST(TensorDistAttr, spmdinfo.second[0]).is_partial(), false);
+}
+
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
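As a sanity check on the expected mappings above: for these grad rules the target shape differs from out_grad's shape only by size-1 dims, so the reshape rule should reduce to "non-1 dims keep their mapping, size-1 dims become replicated". A minimal Python model under that assumption reproduces the tests' expected values (the real ReshapeInferSpmd handles general reshapes):

```python
def grad_dims_mapping(out_grad_shape, out_grad_mapping, x_shape):
    """Minimal model of the reshape rule for shapes differing only in
    size-1 dims."""
    kept = iter(m for d, m in zip(out_grad_shape, out_grad_mapping) if d != 1)
    return [next(kept) if d != 1 else -1 for d in x_shape]

# SqueezeGradInferSpmd, first case: out_grad (32, 48) mapped (-1, 1),
# x shape (1, 32, 1, 48) -> expected (-1, -1, -1, 1).
assert grad_dims_mapping((32, 48), (-1, 1), (1, 32, 1, 48)) == [-1, -1, -1, 1]

# UnsqueezeGradInferSpmd, first case: out_grad (1, 32, 1, 48) mapped
# (-1, 1, -1, -1), x shape (32, 48) -> expected (1, -1).
assert grad_dims_mapping((1, 32, 1, 48), (-1, 1, -1, -1), (32, 48)) == [1, -1]
```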