From ed4812ff8fd2d93dd8e39431593ea546af9e3ff5 Mon Sep 17 00:00:00 2001
From: Yichen Zhang <32740647+pkuzyc@users.noreply.github.com>
Date: Tue, 26 Sep 2023 16:29:06 +0800
Subject: [PATCH] 【Semi-Auto】Adapt split spmd rule to phi (#57467)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* adapt split rule to phi

* fix bugs and modify apis in unit test

* fix codestyle

* bug fix
---
 .../auto_parallel/spmd_rules/rules.h          |   5 -
 .../spmd_rules/split_spmd_rule.cc             | 218 ------------------
 .../spmd_rules/split_spmd_rule.h              |  41 ----
 paddle/phi/infermeta/spmd_rules/rules.h       |  10 +
 paddle/phi/infermeta/spmd_rules/split.cc      | 216 +++++++++++++++++
 paddle/phi/infermeta/spmd_rules/split.h       |  46 ++++
 .../spmd_rules/test_split_rule.py             | 114 +++++----
 7 files changed, 338 insertions(+), 312 deletions(-)
 delete mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.cc
 delete mode 100644 paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h
 create mode 100644 paddle/phi/infermeta/spmd_rules/split.cc
 create mode 100644 paddle/phi/infermeta/spmd_rules/split.h
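The core of the migrated rule is an einsum-style notation trick: the split
axis is tagged 'k' in the input notation and '1' in the output notation,
which pins its dims mapping to -1 (replicated), since each output owns only
a slice of that axis; every other axis propagates unchanged. A minimal
pure-Python sketch of the forward inference (the helper name is
hypothetical; the real implementation below works on TensorDistAttr
objects):

def split_infer_forward(x_dims_mapping, num, axis):
    """Sketch of the dims-mapping logic in SplitWithNumInferSpmd."""
    ndim = len(x_dims_mapping)
    axis %= ndim                      # normalize negative axis
    new_x = list(x_dims_mapping)
    new_x[axis] = -1                  # split axis must be replicated
    return new_x, [list(new_x) for _ in range(num)]

# Mirrors the unit tests below: [0, -1, -1] split on axis 1 keeps all
# mappings; [-1, -1, 0] split on axis 2 replicates everything.
assert split_infer_forward([0, -1, -1], 2, 1) == ([0, -1, -1], [[0, -1, -1]] * 2)
assert split_infer_forward([-1, -1, 0], 3, 2) == ([-1, -1, -1], [[-1, -1, -1]] * 3)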
*/ - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h" -#include -#include -#include "paddle/phi/core/distributed/auto_parallel/utils.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -using phi::distributed::auto_parallel::str_join; - -std::pair, std::vector> -SplitSPMDRule::InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - // step0: Verify Input Args Based on Elementwise Logic - int64_t ninputs = static_cast(input_specs.size()); - PADDLE_ENFORCE_EQ( - ninputs, - 1, - phi::errors::InvalidArgument("The size of InputSpec in split must " - "be equal to 1, but got [%d].", - ninputs)); - VerifySpecs(input_specs, "split"); - - // step1: Build Einsum Notation - int64_t ndim = static_cast(input_specs[0].shape().size()); - int64_t noutput = 0; - // split api uses num or sections as attribute - if (attrs.find("num") != attrs.end()) { - noutput = ExtractAttr("num", attrs); - } else if (attrs.find("sections") != attrs.end()) { - std::vector sections = - ExtractAttr>("sections", attrs); - noutput = static_cast(sections.size()); - } - int64_t axis = ExtractAttr("axis", attrs); - if (axis < 0) { - axis += ndim; - } - std::string alphabet = "abcdefghijlmnopqrstuvwxyz"; - - // get einsum notation for input, use a special - // notation 'k' to mark the splitted axis in input - std::vector input_axes_vec; - std::string input_axes = alphabet.substr(0, ndim); - input_axes[axis] = 'k'; - input_axes_vec.emplace_back(input_axes); - - // get einsum notation for output - std::string output_axes(input_axes); - // the splitted axis cannot be sharded, set its notation - // with the special '1' to set its dim mapping to -1. - output_axes[axis] = '1'; - - // step2: Sharding Propogation - // step2.1: merge input shardings - std::vector>> axes_sharding_info; - axes_sharding_info = GetAxesDimsMappingPair(input_axes_vec, input_specs); - std::unordered_map axis_to_dim_map = - ShardingMergeForTensors(axes_sharding_info); - - // step2.2: infer output dims mapping from merged input dims mapping - std::vector output_dims_mapping = - GetDimsMappingForAxes(output_axes, axis_to_dim_map); - - // get the dist attributes for all outputs, the - // dist attributes are same for all outputs. - std::vector output_dist_attrs; - for (int64_t i = 0; i < noutput; i++) { - output_dist_attrs.emplace_back( - CopyTensorDistAttrForOutput(input_specs[0].dist_attr())); - output_dist_attrs[i].set_dims_mapping(output_dims_mapping); - } - - // step2.3 get new dist attribute for input. the splitted - // cannot be sharded, if it is sharded, set it to replicated. 
- std::vector new_input_dist_attrs; - new_input_dist_attrs.emplace_back(input_specs[0].dist_attr()); - std::vector new_input_dims_mapping(input_specs[0].dims_mapping()); - new_input_dims_mapping[axis] = -1; - new_input_dist_attrs[0].set_dims_mapping(new_input_dims_mapping); - - // Step3 Handle input tensor partial (TODO) - VLOG(4) << "SplitSPMDRule InferForward: "; - for (int64_t i = 0; i < ninputs; i++) { - VLOG(4) << "Input" << std::to_string(i) << " shape: [" - << str_join(input_specs[i].shape()) << "] " - << "einsum_notation: " << input_axes << " src_dims_mapping: [" - << str_join(input_specs[i].dims_mapping()) << "] " - << "dst_dims_mapping: [" - << str_join(new_input_dist_attrs[i].dims_mapping()) << "]"; - } - for (int64_t i = 0; i < noutput; i++) { - VLOG(4) << "Output" << std::to_string(i) << " dims_mapping: [" - << str_join(output_dims_mapping) << "]"; - } - - return {new_input_dist_attrs, output_dist_attrs}; -} - -std::pair, std::vector> -SplitSPMDRule::InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) { - // step0: Verify Input Args Based on Elementwise Logic - int64_t ninputs = input_specs.size(); - int64_t noutputs = output_specs.size(); - PADDLE_ENFORCE_EQ( - ninputs, - 1, - phi::errors::InvalidArgument("The size of InputSpec in split must " - "be equal to 1, but got [%d].", - ninputs)); - VerifySpecs(output_specs, "split"); - - // check whether the size of output_specs equals - // to the specified split num in op attributes - int64_t specified_split_num = -1; - // split api uses num or sections as attribute - if (attrs.find("num") != attrs.end()) { - specified_split_num = ExtractAttr("num", attrs); - } else if (attrs.find("sections") != attrs.end()) { - std::vector sections = - ExtractAttr>("sections", attrs); - specified_split_num = sections.size(); - } - PADDLE_ENFORCE_EQ( - noutputs, - specified_split_num, - phi::errors::InvalidArgument("The size of OutputSpec [%d] is not equal " - "to the specified split number [%d]", - noutputs, - specified_split_num)); - - // step1: Build Einsum Notation - int64_t ndim = input_specs[0].shape().size(); - int64_t axis = ExtractAttr("axis", attrs); - if (axis < 0) { - axis += ndim; - } - std::string alphabet = "abcdefghijlmnopqrstuvwxyz"; - - // get einsum notation for input, use a special - // notation 'k' to mark the splitted axis in input - std::string input_axes = alphabet.substr(0, ndim); - input_axes[axis] = 'k'; - - // get einsum notation for output - std::string output_axes(input_axes); - output_axes[axis] = 'k'; - - // step2: Sharding Propogation - // step2.1: merge input shardings - std::vector output_axes_vec; - for (int64_t i = 0; i < noutputs; i++) { - output_axes_vec.emplace_back(output_axes); - } - std::vector>> axes_sharding_info; - axes_sharding_info = GetAxesDimsMappingPair(output_axes_vec, output_specs); - std::unordered_map axis_to_dim_map = - ShardingMergeForTensors(axes_sharding_info); - - // step2.2: infer input dims mapping from output dims mapping - // the split axis in input is set to -1. - std::vector input_dims_mapping = - GetDimsMappingForAxes(input_axes, axis_to_dim_map, true); - input_dims_mapping[axis] = -1; - TensorDistAttr input_dist_attr(input_specs[0].dist_attr()); - input_dist_attr.set_dims_mapping(input_dims_mapping); - - // step2.3 get new dist attribute for output. the splitted - // cannot be sharded, if it is sharded, set it to replicated. 
- std::vector output_dist_attrs; - for (int64_t i = 0; i < noutputs; i++) { - output_dist_attrs.emplace_back(output_specs[i].dist_attr()); - std::vector out_dims_mapping = - GetDimsMappingForAxes(output_axes, axis_to_dim_map, true); - out_dims_mapping[axis] = -1; - output_dist_attrs[i].set_dims_mapping(out_dims_mapping); - } - - // step3 Handle input tensor partial (TODO) - - VLOG(4) << "SplitSPMDRule InferBackward: "; - for (int64_t i = 0; i < noutputs; i++) { - VLOG(4) << "Output" << std::to_string(i) << " shape: [" - << str_join(output_specs[i].shape()) << "] " - << "einsum_notation: " << output_axes << " dims_mapping: [" - << str_join(output_specs[i].dims_mapping()) << "]"; - } - for (int64_t i = 0; i < ninputs; i++) { - VLOG(4) << "Input" << std::to_string(i) << " shape: [" - << str_join(input_specs[i].shape()) << "] " - << "einsum_notation: " << input_axes << " dims_mapping: [" - << str_join(input_dims_mapping) << "]"; - } - VLOG(4) << std::endl; - - return {{input_dist_attr}, output_dist_attrs}; -} - -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h deleted file mode 100644 index f8a1300e62409..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -class SplitSPMDRule : public SPMDRuleBase { - public: - std::pair, std::vector> - InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) override; - - std::pair, std::vector> - InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) override; -}; -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 5418b34e28b57..fa6da9beee2c7 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -24,6 +24,7 @@ limitations under the License. 
*/ #include "paddle/phi/infermeta/spmd_rules/reduction.h" #include "paddle/phi/infermeta/spmd_rules/replicated.h" #include "paddle/phi/infermeta/spmd_rules/reshape.h" +#include "paddle/phi/infermeta/spmd_rules/split.h" /** * Design Notes: @@ -485,5 +486,14 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmd), PD_INFER_SPMD(phi::distributed::EmbeddingInferSpmdReverse)); +// split rule +PD_REGISTER_SPMD_RULE(split, + PD_INFER_SPMD(phi::distributed::SplitInferSpmd), + PD_INFER_SPMD(phi::distributed::SplitInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + split_with_num, + PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmd), + PD_INFER_SPMD(phi::distributed::SplitWithNumInferSpmdReverse)); + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/split.cc b/paddle/phi/infermeta/spmd_rules/split.cc new file mode 100644 index 0000000000000..4bc2a9ce0bdb1 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/split.cc @@ -0,0 +1,216 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/split.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis) { + // Step0: Verify input args based on split logic + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + auto x_dist_attr_src = x.dist_attr(); + std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + x_ndim, + x_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's " + "dims_mapping size [%d] are not matched.", + x_ndim, + x_dims_mapping.size())); + + // Step1: Build Einsum Notation + std::string alphabet = "abcdefghijlmnopqrstuvwxyz"; + if (axis < 0) { + axis += x_ndim; + } + + // get einsum notation for input, use a special + // notation 'k' to mark the splitted axis in input + std::string x_axes = alphabet.substr(0, x_ndim); + x_axes[axis] = 'k'; + + // get einsum notation for output + std::string out_axes(x_axes); + // the splitted axis cannot be sharded, set its notation + // with the special '1' to set its dim mapping to -1. + out_axes[axis] = '1'; + + // Step2: Sharding Propogation + // Step2.1: merge input shardings + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{x_axes, x_dims_mapping}}); + + // Step2.2: infer output dims mapping from merged input dims mapping + std::vector out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map); + + // get the dist attributes for all outputs, the + // dist attributes are same for all outputs. 
+ std::vector out_dist_attrs; + for (int i = 0; i < num; i++) { + out_dist_attrs.emplace_back(CopyTensorDistAttrForOutput(x_dist_attr_src)); + out_dist_attrs[i].set_dims_mapping(out_dims_mapping); + } + + // Step2.3 get new dist attribute for input. the splitted + // cannot be sharded, if it is sharded, set it to replicated. + TensorDistAttr x_dist_attr_dst(x_dist_attr_src); + x_dims_mapping[axis] = -1; + x_dist_attr_dst.set_dims_mapping(x_dims_mapping); + + // Step3 Handle input tensor partial (TODO) + VLOG(4) << "SplitWithNumInferSpmd:"; + VLOG(4) << "Einsum Notation: " << x_axes << "-->" << out_axes; + VLOG(4) << "Input shape: [" << str_join(x_shape) << "] " + << "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]"; + for (int64_t i = 0; i < num; i++) { + VLOG(4) << "Output" << std::to_string(i) << " dims_mapping: [" + << str_join(out_dims_mapping) << "]"; + } + VLOG(4) << std::endl; + + return {{x_dist_attr_dst}, out_dist_attrs}; +} + +SpmdInfo SplitWithNumInferSpmdReverse( + const DistMetaTensor& x, + const std::vector& outs, + int num, + int axis) { + // Step0: Verify input args based on split logic + int nouts = outs.size(); + int out_ndim = phi::vectorize(outs[0]->dims()).size(); + auto x_shape = phi::vectorize(x.dims()); + int x_ndim = x_shape.size(); + auto x_dist_attr = x.dist_attr(); + std::vector x_dims_mapping = x_dist_attr.dims_mapping(); + PADDLE_ENFORCE_EQ(nouts, + num, + phi::errors::InvalidArgument( + "The size of Output Tensors [%d] is not equal " + "to the specified split number [%d]", + nouts, + num)); + PADDLE_ENFORCE_EQ( + x_ndim, + out_ndim, + phi::errors::InvalidArgument("The Tensor X's rank [%d] is not equal " + "to the Tensor Out's rank [%d]", + x_ndim, + out_ndim)); + for (int i = 0; i < num; i++) { + auto shape = phi::vectorize(outs[i]->dims()); + int ndim = shape.size(); + auto dist_attr = outs[i]->dist_attr(); + int dims_mapping_size = dist_attr.dims_mapping().size(); + PADDLE_ENFORCE_EQ( + ndim, + dims_mapping_size, + phi::errors::InvalidArgument("The Tensor Out[%d]'s rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + i, + ndim, + dims_mapping_size)); + } + + // Step1: Build Einsum Notation + if (axis < 0) { + axis += x_ndim; + } + std::string alphabet = "abcdefghijlmnopqrstuvwxyz"; + + // get einsum notation for input, use a special + // notation 'k' to mark the splitted axis in input + std::string x_axes = alphabet.substr(0, x_ndim); + x_axes[axis] = 'k'; + + // get einsum notation for output + std::string out_axes(x_axes); + out_axes[axis] = 'k'; + + // Step2: Sharding Propogation + // Step2.1: merge output shardings + std::vector>> axes_sharding_info; + for (int i = 0; i < nouts; i++) { + std::vector out_dims_mapping = outs[i]->dist_attr().dims_mapping(); + axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping)); + } + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors(axes_sharding_info); + + // Step2.2: infer input dims mapping from output dims mapping + // the split axis in input is set to -1. + x_dims_mapping = GetDimsMappingForAxes(x_axes, axis_to_dim_map, true); + x_dims_mapping[axis] = -1; + x_dist_attr.set_dims_mapping(x_dims_mapping); + + // step2.3 get new dist attribute for output. the splitted + // cannot be sharded, if it is sharded, set it to replicated. 
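The reverse rule above first merges the dims mappings of all outputs (they
share one einsum notation) and then forces the split axis back to -1. A
rough pure-Python sketch of that flow; the per-axis merge here uses a
simplified conflict policy (distinct shard dims fall back to -1), whereas
the real merge is done by ShardingMergeForTensors:

def merge_dims_mappings(mappings):
    # Merge axis by axis: keep a mesh dim only if it is the unique
    # non-(-1) value on that axis across all tensors.
    merged = []
    for dims in zip(*mappings):
        shard_dims = {d for d in dims if d != -1}
        merged.append(shard_dims.pop() if len(shard_dims) == 1 else -1)
    return merged

def split_infer_reverse(out_dims_mappings, axis):
    x_mapping = merge_dims_mappings(out_dims_mappings)
    axis %= len(x_mapping)            # normalize negative axis
    x_mapping[axis] = -1              # split axis is never sharded
    return x_mapping, [list(x_mapping) for _ in out_dims_mappings]

# Mirrors a case from the unit tests: outputs [-1, 0, -1] and [-1, -1, -1]
# merge to [-1, 0, -1]; axis -2 points at that sharded axis, so input and
# outputs all become [-1, -1, -1].
assert split_infer_reverse([[-1, 0, -1], [-1, -1, -1]], -2) == (
    [-1, -1, -1],
    [[-1, -1, -1], [-1, -1, -1]],
)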
diff --git a/paddle/phi/infermeta/spmd_rules/split.h b/paddle/phi/infermeta/spmd_rules/split.h
new file mode 100644
index 0000000000000..96b1a51e5e088
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/split.h
@@ -0,0 +1,46 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
+#include "paddle/phi/core/distributed/type_defs.h"
+
+namespace phi {
+namespace distributed {
+
+SpmdInfo SplitInferSpmd(const DistMetaTensor& x,
+                        const std::vector<int>& sections,
+                        int axis);
+
+SpmdInfo SplitInferSpmdReverse(const DistMetaTensor& x,
+                               const std::vector<const DistMetaTensor*>& outs,
+                               const std::vector<int>& sections,
+                               int axis);
+
+SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis);
+
+SpmdInfo SplitWithNumInferSpmdReverse(
+    const DistMetaTensor& x,
+    const std::vector<const DistMetaTensor*>& outs,
+    int num,
+    int axis);
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/test/auto_parallel/spmd_rules/test_split_rule.py b/test/auto_parallel/spmd_rules/test_split_rule.py
index a4f66ff638f0d..191931a23448f 100644
--- a/test/auto_parallel/spmd_rules/test_split_rule.py
+++ b/test/auto_parallel/spmd_rules/test_split_rule.py
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import unittest
+from collections import OrderedDict
 
-from paddle.distributed.auto_parallel.static.completion import get_spmd_rule
 from paddle.distributed.auto_parallel.static.dist_attribute import (
     DistTensorSpec,
     TensorDistAttr,
 )
 from paddle.distributed.fleet import auto
+from paddle.framework import core
 
 
 class TestReductionSPMDRule(unittest.TestCase):
@@ -28,8 +29,6 @@ class TestReductionSPMDRule(unittest.TestCase):
     """
 
     def setUp(self):
-        self.rule = get_spmd_rule("split")
-
         x_shape = [64, 32, 48]
         process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
 
@@ -38,21 +37,16 @@ def setUp(self):
         x_tensor_dist_attr.process_mesh = process_mesh
         self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
 
-        self.attrs = {
-            'num_or_sections': 2,
-            'axis': 1,
-        }
-
     def test_single_mesh_dim(self):
         # num_or_sections = 2, axis = 1
         # [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
-        self.rule = get_spmd_rule("split_with_num")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split_with_num")
+        self.attrs = OrderedDict()
         self.attrs['num'] = 2
         self.attrs['axis'] = 1
         self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['num'], self.attrs['axis']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -67,13 +61,13 @@ def test_single_mesh_dim(self):
 
         # num_or_sections = [15, 16, 17], axis = 2
         # [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1], [0, -1, -1]
-        self.rule = get_spmd_rule("split")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split")
+        self.attrs = OrderedDict()
         self.attrs['sections'] = [15, 16, 17]
         self.attrs['axis'] = 2
         self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['sections'], self.attrs['axis']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -89,12 +83,12 @@ def test_single_mesh_dim(self):
 
         # num_or_sections = [15, 16, 17], axis = 2
         # [-1, -1, 0] --> [-1, -1, -1], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]
-        self.attrs = {}
+        self.attrs = OrderedDict()
         self.attrs['sections'] = [15, 16, 17]
         self.attrs['axis'] = 2
         self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['sections'], self.attrs['axis']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -116,13 +110,13 @@ def test_single_mesh_dim(self):
 
         # num_or_sections = 2, axis = -2
         # [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
-        self.rule = get_spmd_rule("split_with_num")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split_with_num")
+        self.attrs = OrderedDict()
         self.attrs['num'] = 2
         self.attrs['axis'] = -2
         self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['num'], self.attrs['axis']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -142,13 +136,13 @@ def test_multi_mesh_dim(self):
 
         # num_or_sections = 3, axis = -1
         # [0, 1, -1, -1] --> [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1]
-        self.rule = get_spmd_rule("split_with_num")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split_with_num")
+        self.attrs = OrderedDict()
         self.attrs['num'] = 3
         self.attrs['axis'] = -1
         self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['num'], self.attrs['axis']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -172,13 +166,13 @@ def test_multi_mesh_dim(self):
 
         # num_or_sections = [32, 32, 32], axis = 0
         # [0, 1, -1, -1] --> [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1]
-        self.rule = get_spmd_rule("split")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split")
+        self.attrs = OrderedDict()
         self.attrs['sections'] = [32, 32, 32]
         self.attrs['axis'] = 0
         self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec], self.attrs
+            self.x_dist_tensor_spec, self.attrs['sections'], self.attrs['axis']
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -206,8 +200,8 @@ def test_backward_single_mesh_dim(self):
         # num_or_sections = 2, axis = 1
         # [0, -1, -1], [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
         # (outputs --> input, outputs)
-        self.rule = get_spmd_rule("split_with_num")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split_with_num")
+        self.attrs = OrderedDict()
         self.attrs['num'] = 2
         self.attrs['axis'] = 1
         self.out_spec_list = []
@@ -218,7 +212,10 @@ def test_backward_single_mesh_dim(self):
         self.out_spec_list[0].set_dims_mapping([0, -1, -1])
         self.out_spec_list[1].set_dims_mapping([0, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['num'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -235,8 +232,8 @@ def test_backward_single_mesh_dim(self):
         # [0, -1, -1], [0, -1, -1], [0, -1, -1] -->
         # [0, -1, -1], [0, -1, -1], [0, -1, -1], [0, -1, -1]
         # (outputs --> input, outputs)
-        self.rule = get_spmd_rule("split")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split")
+        self.attrs = OrderedDict()
         self.attrs['sections'] = [15, 16, 17]
         self.attrs['axis'] = 2
         self.out_spec_list = []
@@ -250,7 +247,10 @@ def test_backward_single_mesh_dim(self):
         self.out_spec_list[1].set_dims_mapping([0, -1, -1])
         self.out_spec_list[2].set_dims_mapping([0, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['sections'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -268,7 +268,7 @@ def test_backward_single_mesh_dim(self):
         # [-1, -1, -1], [-1, -1, -1], [-1, -1, -1] -->
         # [-1, -1, -1], [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]
         # (outputs --> input, outputs)
-        self.attrs = {}
+        self.attrs = OrderedDict()
         self.attrs['sections'] = [15, 16, 17]
         self.attrs['axis'] = 2
         self.out_spec_list = []
@@ -282,7 +282,10 @@ def test_backward_single_mesh_dim(self):
         self.out_spec_list[1].set_dims_mapping([-1, -1, -1])
         self.out_spec_list[2].set_dims_mapping([-1, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['sections'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -305,8 +308,8 @@ def test_backward_single_mesh_dim(self):
         # num_or_sections = 2, axis = -2
         # [0, -1, -1], [0, -1, -1] --> [0, -1, -1], [0, -1, -1], [0, -1, -1]
         # (outputs --> input, outputs)
-        self.rule = get_spmd_rule("split_with_num")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split_with_num")
+        self.attrs = OrderedDict()
         self.attrs['num'] = 2
         self.attrs['axis'] = -2
         self.out_spec_list = []
@@ -317,7 +320,10 @@ def test_backward_single_mesh_dim(self):
         self.out_spec_list[0].set_dims_mapping([0, -1, -1])
         self.out_spec_list[1].set_dims_mapping([0, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['num'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -333,8 +339,8 @@ def test_backward_single_mesh_dim(self):
         # num_or_sections = 2, axis = -2
         # [-1, 0, -1], [-1, -1, -1] --> [-1, -1, -1], [-1, -1, -1], [-1, -1, -1]
         # (outputs --> input, outputs)
-        self.rule = get_spmd_rule("split_with_num")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split_with_num")
+        self.attrs = OrderedDict()
         self.attrs['num'] = 2
         self.attrs['axis'] = -2
         self.out_spec_list = []
@@ -345,7 +351,10 @@ def test_backward_single_mesh_dim(self):
         self.out_spec_list[0].set_dims_mapping([-1, 0, -1])
         self.out_spec_list[1].set_dims_mapping([-1, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['num'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -372,8 +381,8 @@ def test_backward_multi_mesh_dim(self):
         # [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1] -->
         # [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1], [0, 1, -1, -1]
         # (outputs --> input, outputs)
-        self.rule = get_spmd_rule("split_with_num")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split_with_num")
+        self.attrs = OrderedDict()
         self.attrs['num'] = 3
         self.attrs['axis'] = -1
         self.out_spec_list = []
@@ -402,7 +411,10 @@ def test_backward_multi_mesh_dim(self):
         self.out_spec_list[1].set_dims_mapping([0, 1, -1, -1])
         self.out_spec_list[2].set_dims_mapping([0, 1, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['num'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -428,8 +440,8 @@ def test_backward_multi_mesh_dim(self):
         # [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1] -->
         # [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1], [-1, 1, -1, -1]
         # (outputs --> input, outputs)
-        self.rule = get_spmd_rule("split")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split")
+        self.attrs = OrderedDict()
         self.attrs['sections'] = [32, 32, 32]
         self.attrs['axis'] = 0
         self.out_spec_list = []
@@ -443,7 +455,10 @@ def test_backward_multi_mesh_dim(self):
         self.out_spec_list[1].set_dims_mapping([-1, 1, -1, -1])
         self.out_spec_list[2].set_dims_mapping([-1, 1, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['sections'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -469,8 +484,8 @@ def test_backward_multi_mesh_dim(self):
         # [0, -1, 1, -1], [-1, 1, -1, -1], [-1, -1, -1, -1] -->
         # [0, -1, -1, -1], [0, -1, -1, -1], [0, -1, -1, -1], [0, -1, -1, -1]
         # (outputs --> input, outputs)
-        self.rule = get_spmd_rule("split")
-        self.attrs = {}
+        self.rule = core.get_phi_spmd_rule("split")
+        self.attrs = OrderedDict()
         self.attrs['sections'] = [32, 32, 32]
         self.attrs['axis'] = 2
         self.out_spec_list = []
@@ -484,7 +499,10 @@ def test_backward_multi_mesh_dim(self):
         self.out_spec_list[1].set_dims_mapping([-1, 1, -1, -1])
         self.out_spec_list[2].set_dims_mapping([-1, -1, -1, -1])
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec], self.out_spec_list, self.attrs
+            self.x_dist_tensor_spec,
+            self.out_spec_list,
+            self.attrs['sections'],
+            self.attrs['axis'],
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
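For quick reference, the positional calling convention the updated tests
exercise — op inputs and attributes passed directly instead of an attrs
dict — looks like this. A sketch assembled from the test code above;
shapes, mesh, and dims mapping are illustrative:

from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto
from paddle.framework import core

# Input spec: shape [64, 32, 48] on a 1-D mesh, sharded on mesh dim 0
# along tensor axis 0 (mirrors setUp in the test above).
x_dist_attr = TensorDistAttr()
x_dist_attr.process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
x_spec = DistTensorSpec([64, 32, 48], x_dist_attr)
x_spec.set_dims_mapping([0, -1, -1])

# (x, num, axis) for split_with_num; (x, sections, axis) for split.
rule = core.get_phi_spmd_rule("split_with_num")
result = rule.infer_forward(x_spec, 2, 1)
inferred_input_attrs, inferred_output_attrs = result[0], result[1]
# Expected: input keeps [0, -1, -1]; both outputs get [0, -1, -1].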