diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index 785e80a3abeaa..ac02cd0fc87ac 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -74,6 +74,7 @@ using paddle::distributed::auto_parallel::SPMDRuleMap;
 using paddle::framework::BlockDesc;
 using paddle::framework::OpDesc;
 using paddle::framework::VarDesc;
+using phi::distributed::ArgDistAttr;
 using phi::distributed::ProcessMesh;
 using phi::distributed::TensorDistAttr;
 using phi::distributed::auto_parallel::Device;
@@ -143,9 +144,9 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
   dist_attr->clear_annotated();
 }
 
-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
 infer_forward(const phi::distributed::SpmdRule &self, const py::args &args);
-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
 infer_backward(const phi::distributed::SpmdRule &self, const py::args &args);
 
 void BindAutoParallel(py::module *m) {
@@ -703,7 +704,7 @@ static void prepare_ctx(phi::distributed::InferSpmdContext *ctx,
     parse_single_pyobject(obj, ctx, i);
   }
 }
-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
 infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) {
   VLOG(6) << "infer_forward ";
   phi::distributed::InferSpmdContext ctx;
@@ -711,7 +712,7 @@ infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) {
   return self.InferForward(ctx);
 }
 
-static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+static std::pair<std::vector<ArgDistAttr>, std::vector<ArgDistAttr>>
 infer_backward(const phi::distributed::SpmdRule &self, const py::args &args) {
   VLOG(6) << "infer_backward ";
   phi::distributed::InferSpmdContext ctx;
diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index acb45d058038e..2f1333042fe68 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -165,7 +165,14 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
     auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
     (*kernel_fn)(*dev_ctx, input_x, dense_out);
   }
-  auto current_process_mesh = spmd_info.first[0].process_mesh();
+  PADDLE_ENFORCE_EQ(
+      paddle::holds_alternative<phi::distributed::TensorDistAttr>(
+          spmd_info.first[0]),
+      true,
+      phi::errors::PreconditionNotMet(
+          "Arg must be a single TensorDistAttr"));
+  auto current_process_mesh =
+      paddle::get<0>(spmd_info.first[0]).process_mesh();
   SetReplicatedDistAttrForOutput(dist_out, current_process_mesh);
   return api_output;
 }
diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc
index c020d0332c570..a39010ac2f73b 100644
--- a/paddle/phi/api/lib/api_gen_utils.cc
+++ b/paddle/phi/api/lib/api_gen_utils.cc
@@ -559,6 +559,15 @@ phi::distributed::DistTensor* SetKernelDistOutput(
   return nullptr;
 }
 
+phi::distributed::DistTensor* SetKernelDistOutput(
+    Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
+  PADDLE_ENFORCE_EQ(
+      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
+      true,
+      phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
+  return SetKernelDistOutput(out, paddle::get<0>(dist_attr));
+}
+
 std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
     Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
   if (out) {
@@ -568,6 +577,19 @@ std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
   return nullptr;
 }
 
+std::shared_ptr<phi::distributed::DistTensor> CreateKernelDistOutput(
+    Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) {
+  if (out) {
+    PADDLE_ENFORCE_EQ(
+        paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
+        true,
+        phi::errors::PreconditionNotMet("Arg must be a single 
TensorDistAttr")); + return std::make_shared( + phi::DDim(), paddle::get<0>(dist_attr)); + } + return nullptr; +} + std::vector SetKernelDistOutput( std::vector out) { std::vector result; diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index 5272c14209e0d..13f68ab7defbb 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -147,11 +148,17 @@ phi::distributed::DistTensor* SetKernelDistOutput( const phi::distributed::TensorDistAttr& dist_attr = phi::distributed::TensorDistAttr()); +phi::distributed::DistTensor* SetKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr); + std::shared_ptr CreateKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr = phi::distributed::TensorDistAttr()); +std::shared_ptr CreateKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr); + std::vector SetKernelDistOutput( std::vector out); diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 561d8ce379b9d..7c88bd3df44b0 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -640,6 +640,70 @@ std::shared_ptr ReshardApiInputToKernelInput( return nullptr; } +std::shared_ptr ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); + const auto& tensor_dist_attr = paddle::get<0>(dist_attr); + return ReshardApiInputToKernelInput(dev_ctx, tensor, tensor_dist_attr); +} + +std::vector> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const phi::distributed::ArgDistAttr& dist_attrs) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative>( + dist_attrs), + true, + phi::errors::PreconditionNotMet( + "Arg must be a vector of TensorDistAttr")); + const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs); + return ReshardApiInputToKernelInput(dev_ctx, tensors, tensor_dist_attrs); +} + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs) { + std::vector> output; + PADDLE_ENFORCE_EQ(tensors.size(), + dist_attrs.size(), + phi::errors::PreconditionNotMet( + "tensors size and dist_attrs size not equal: %d vs %d", + tensors.size(), + dist_attrs.size())); + for (size_t i = 0; i < dist_attrs.size(); i++) { + output.push_back( + ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i])); + } + return output; +} + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs) { + std::vector> output; + PADDLE_ENFORCE_EQ(tensors.size(), + dist_attrs.size(), + phi::errors::PreconditionNotMet( + "tensors size and dist_attrs size not equal: %d vs %d", + tensors.size(), + dist_attrs.size())); + for (size_t i = 0; i < dist_attrs.size(); i++) { + 
output.push_back( + ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i])); + } + return output; +} + std::shared_ptr ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, @@ -688,6 +752,63 @@ ReshardApiInputToReplicatedKernelInput( return result; } +std::shared_ptr +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); + const auto& tensor_dist_attr = paddle::get<0>(dist_attr); + return ReshardApiInputToReplicatedKernelInput( + dev_ctx, tensor, tensor_dist_attr); +} + +paddle::optional> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); + const auto& tensor_dist_attr = paddle::get<0>(dist_attr); + return ReshardApiInputToReplicatedKernelInput( + dev_ctx, tensor, tensor_dist_attr); +} + +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs) { + std::vector> outputs; + for (size_t i = 0; i < tensors.size(); ++i) { + outputs.push_back(ReshardApiInputToReplicatedKernelInput( + dev_ctx, tensors[i], dist_attrs[i])); + } + return outputs; +} + +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative>( + dist_attr), + true, + phi::errors::PreconditionNotMet( + "Arg must be a vector of TensorDistAttr")); + const auto& tensor_dist_attrs = paddle::get<1>(dist_attr); + return ReshardApiInputToReplicatedKernelInput( + dev_ctx, tensors, tensor_dist_attrs); +} + void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) { if (out_tensor->dist_attr().is_partial()) { diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 8df013860a5ab..02d86622e2aa6 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -180,24 +181,76 @@ std::shared_ptr ReshardApiInputToKernelInput( const Tensor& tensor, const phi::distributed::TensorDistAttr& dist_attr); +std::shared_ptr ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::ArgDistAttr& dist_attr); + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs); + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs); + +std::vector> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const phi::distributed::ArgDistAttr& dist_attrs); + +std::shared_ptr +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::ArgDistAttr& dist_attr); + std::shared_ptr ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const Tensor& tensor, const phi::distributed::TensorDistAttr& dist_attr); +std::shared_ptr +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::ArgDistAttr& dist_attr); + paddle::optional> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional& tensor, const phi::distributed::TensorDistAttr& dist_attr); +paddle::optional> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::ArgDistAttr& dist_attr); + std::vector> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const std::vector& tensors, const std::vector& dist_attrs); +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs); + +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensor, + const phi::distributed::ArgDistAttr& dist_attr); + void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor); diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index d6c90584cb540..059320c0058ed 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -81,13 +81,21 @@ # 1. InferSPMD SINGLE_DIST_META_IN_TEMPLATE = """ auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());""" + +LIST_DIST_META_IN_TEMPLATE = """ + std::vector meta_dist_input_{name}; + for(auto& e: {name}){{ + meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl())); + }} +""" + OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE = """ auto meta_dist_input_{name} = {name} ? 
MakeDistMetaTensor(*(*{name}).impl()) : phi::distributed::DistMetaTensor();""" INFER_SPMD_TEMPLATE = """ auto spmd_info = phi::distributed::{}({}); """ GENERAL_INFER_SPMD_TEMPLATE = """ - auto spmd_info = phi::distributed::VariadicReplicatedInferSpmd({}); + auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic({}); """ UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE = """ // API `{}` does not support InferSpmd now @@ -246,8 +254,9 @@ auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); auto input_{arg} = &dist_input_{arg}->value(); """ +# dist_input_ prefix VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name}_vec = PrepareDataForDistTensor({prefix}{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; for (auto tmp : dist_input_{name}_vec) {{ dense_input_{name}_vec.emplace_back(&tmp->value()); @@ -258,6 +267,7 @@ dense_input_{name}_meta_ptr_vec[i] = &dense_input_{name}_meta_vec[i]; }} """ + OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE = """ dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional((*dist_input_{name})->value()) : paddle::none; @@ -266,8 +276,10 @@ auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional(dist_input_{name}->value()) : paddle::none; """ + +# dist_input_ prefix OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name}_vec = PrepareDataForDistTensor({prefix}{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; if ({name}) {{ for (auto tmp : *dist_input_{name}_vec) {{ @@ -357,7 +369,8 @@ # 10. Set Output DistAttr for Default impl # Dist Branch will not generated in the API that doesn't have input tensor. CURRENT_PROCESS_MESH_TEMPLATE = """ - auto current_process_mesh = spmd_info.first[0].process_mesh();""" + auto current_process_mesh = paddle::holds_alternative(spmd_info.first[0]) ? + paddle::get<0>(spmd_info.first[0]).process_mesh() : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh();""" SET_SINGLE_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ SetReplicatedDistAttrForOutput({}, current_process_mesh);""" SET_VECTOR_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ @@ -695,6 +708,15 @@ def generate_specialized_infer_spmd_code(self) -> str: name=param ) input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] + == "const std::vector&" + ): + input_decl_code += LIST_DIST_META_IN_TEMPLATE.format( + name=param + ) + input_args_code += "meta_dist_input_" + param + ", " + else: raise ValueError( f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported." 
@@ -747,7 +769,13 @@ def generate_general_infer_spmd_code(self) -> str: elif ( self.inputs['input_info'][param] == "const std::vector&" - or self.inputs['input_info'][param] + ): + input_decl_code += LIST_DIST_META_IN_TEMPLATE.format( + name=param + ) + input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] == "const paddle::optional>&" ): # TODO(chenweihang): support other input type later, @@ -1015,6 +1043,22 @@ def generate_reshard_input_code(self) -> str: arg=param, idx=i ) ) + elif ( + self.inputs['input_info'][param] + == "const std::vector&" + ): + if self.generate_general_infer_spmd is True: + input_reshard_code += ( + SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( + arg=param, idx=i + ) + ) + else: + input_reshard_code += ( + SINGLE_INPUT_RESHARD_TEMPLATE.format( + arg=param, idx=i + ) + ) else: raise ValueError( f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." @@ -1067,8 +1111,9 @@ def generate_vector_dense_input( kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names - + prefix = "dist_input_" if self.generate_infer_spmd else "" input_tensor_code += VECTOR_PREPARE_DATA_TEMPLATE.format( + prefix=prefix, name=input_name, index=kernel_param.index(input_name), trans_flag=trans_flag, @@ -1116,8 +1161,9 @@ def generate_optional_vector_dense_input( kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names - + prefix = "dist_input_" if self.generate_infer_spmd else "" input_tensor_code += OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE.format( + prefix=prefix, name=input_name, index=kernel_param.index(input_name), trans_flag=trans_flag, diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e6b11884f74eb..5a0c6abc7688b 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -505,6 +505,7 @@ infer_meta : func : ConcatInferMeta param : [x, axis] + spmd_rule : ConcatInferSpmdDynamic kernel : func : concat data_type : x diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc index 3c95f2c3ff66f..052a6d457ca8b 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc @@ -399,7 +399,7 @@ bool TensorDistAttr::is_replicated(int64_t mesh_axis) const { bool TensorDistAttr::is_shard(int64_t mesh_axis, int64_t tensor_axis) const { auto placement = to_placement(); if (mesh_axis == -1) { - return std::all_of(placement.begin(), + return std::any_of(placement.begin(), placement.end(), [tensor_axis](std::shared_ptr status) { return status->is_shard(tensor_axis); diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.h b/paddle/phi/core/distributed/auto_parallel/dist_attr.h index f051592b7bf7e..6689750d24ad9 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.h @@ -32,6 +32,8 @@ limitations under the License. 
*/ namespace phi { namespace distributed { +constexpr int kReplicateDim = -1; + class PlacementStatus { public: virtual ~PlacementStatus() = default; diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h index 4781b5d872001..2d444decf640a 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h @@ -125,6 +125,28 @@ struct InferSpmdFnImpl { } }; + // direct vector + template + struct InferSpmdFnCallHelper&, Tail...> { + template + static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferSpmd's Input should appear before Attributes."); + // TODO(liuzhenhai): parse input list as vector directly + const std::pair range = ctx.InputRangeAt(in_idx); + std::vector tmp_arg = + ctx.InputsBetween(range.first, range.second); + std::vector arg; + std::transform(tmp_arg.begin(), + tmp_arg.end(), + std::back_inserter(arg), + [](const DistMetaTensor* arg_ptr) { return *arg_ptr; }); + return InferSpmdFnCallHelper::template Call( + ctx, pargs..., arg); + } + }; + #define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(attr_type) \ template \ struct InferSpmdFnCallHelper { \ diff --git a/paddle/phi/core/distributed/type_defs.h b/paddle/phi/core/distributed/type_defs.h index cd201ac5c5aaf..1b7035c1a4528 100644 --- a/paddle/phi/core/distributed/type_defs.h +++ b/paddle/phi/core/distributed/type_defs.h @@ -18,12 +18,16 @@ #include #include +#include "paddle/utils/variant.h" + namespace phi { namespace distributed { class TensorDistAttr; -using SpmdInfo = - std::pair, std::vector>; +using ArgDistAttr = + paddle::variant>; + +using SpmdInfo = std::pair, std::vector>; } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/concat.cc b/paddle/phi/infermeta/spmd_rules/concat.cc new file mode 100644 index 0000000000000..fd036cfad603a --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/concat.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/concat.h" + +#include +#include + +#include "paddle/phi/infermeta/spmd_rules/elementwise.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +static bool IsEmpty(const std::vector& shape) { + return shape.empty() || shape.at(0) == 0; +} + +SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { + /* +# paddle.concat requires all tensors must either have the same shape (except +# in the concatenating dimension) or be "empty". "Empty" here strictly means +# tensor.shape is torch.Size([0]). When tensor.ndim > 1, it will be treated +# as a non-empty tensor and the shape must match on non-cat dimensions. 
+ */ + + // 1、check tensors shapes + std::vector> tensor_shapes; + std::transform(x.begin(), + x.end(), + std::back_inserter(tensor_shapes), + [](const DistMetaTensor& meta) { + return phi::vectorize(meta.dims()); + }); + bool all_empty = + std::all_of(tensor_shapes.begin(), tensor_shapes.end(), IsEmpty); + if (all_empty) { + return SpmdInfo(); + } + + auto non_empty_iter = + std::find_if(tensor_shapes.begin(), tensor_shapes.end(), [](auto& shape) { + return !IsEmpty(shape); + }); + auto non_empty_index = non_empty_iter - tensor_shapes.begin(); + int64_t ndim = static_cast(tensor_shapes[non_empty_index].size()); + // normlize dim + int64_t dim = axis; + dim = dim < 0 ? dim + ndim : dim; + + std::vector input_attrs; + // 2、make sure all tensors replicated on concat dim + auto n_inputs = x.size(); + for (size_t i = 0; i < n_inputs; ++i) { + const auto& dist_attr = x[i].dist_attr(); + if ((!IsEmpty(tensor_shapes[i])) && IsDimSharded(dist_attr, dim)) { + auto sharded_dist_attr = ReplicateTensorDim(dist_attr, dim); + input_attrs.emplace_back(sharded_dist_attr); + } else { + input_attrs.emplace_back(dist_attr); + } + } + // 3、align non-concat dimensions according to cost + std::vector>> inputs_placements; + std::transform( + input_attrs.begin(), + input_attrs.end(), + std::back_inserter(inputs_placements), + [](const TensorDistAttr& attr) { return attr.to_placement(); }); + const auto& process_mess = input_attrs[non_empty_index].process_mesh(); + auto has_mismatch = [&](int32_t mesh_dim) { + bool mismatch = false; + for (size_t i = 0; i < n_inputs; i++) { + if ((!IsEmpty(tensor_shapes[i])) && + !PlacementEqual(inputs_placements[non_empty_index][mesh_dim], + inputs_placements[i][mesh_dim])) { + mismatch = true; + break; + } + } + return mismatch; + }; + bool need_reshard = false; + int32_t n_mesh_dim = process_mess.ndim(); + std::vector> best_placements( + n_mesh_dim, std::make_shared()); + // a dim can not be sharded twice along diffrent mesh_dim + std::set sharded_dims = {dim}; + + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + // use the old placement + auto& best = inputs_placements[non_empty_index][mesh_dim]; + if (best->is_shard()) { + auto shard_placement = std::dynamic_pointer_cast(best); + sharded_dims.insert(shard_placement->get_axis()); + } + best_placements[mesh_dim] = best; + } + } + + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + continue; + } + need_reshard = true; + std::vector costs; + for (int32_t shard_dim = 0; shard_dim < ndim; shard_dim++) { + double cost = std::numeric_limits::infinity(); + if (!sharded_dims.count(shard_dim)) { + cost = 0.0; + for (size_t i = 0; i < n_inputs; i++) { + auto& tensor_shape = tensor_shapes[i]; + auto& tensor_dist_attr = input_attrs[i]; + if (IsEmpty(tensor_shape)) { + continue; + } + + if (tensor_shape[shard_dim] < process_mess.dim_size(mesh_dim)) { + // should not be selected + cost += std::numeric_limits::infinity(); + continue; + } + if (IsDimSharded(tensor_dist_attr, shard_dim)) { + continue; + } + int64_t num = std::accumulate(tensor_shape.begin(), + tensor_shape.end(), + 1, + std::multiplies()); + if (num == static_cast(0)) { + continue; + } + std::vector local_shape = + GetLocalShape(tensor_shape, process_mess, inputs_placements[i]); + cost += std::accumulate(local_shape.begin(), + local_shape.end(), + 1, + std::multiplies()) * + process_mess.dim_size(mesh_dim); + } + } + costs.push_back(cost); + } + auto min_itr = 
std::min_element(costs.begin(), costs.end()); + auto min_dim = min_itr - costs.begin(); + if (!sharded_dims.count(min_dim)) { + best_placements[mesh_dim] = std::make_shared(min_dim); + sharded_dims.insert(min_dim); + } + } + // set placement to the best placements + if (need_reshard) { + std::vector new_input_attrs; + for (auto& e : input_attrs) { + new_input_attrs.emplace_back(FromPlacements(e, best_placements)); + } + std::swap(input_attrs, new_input_attrs); + } + return {{input_attrs}, {input_attrs[non_empty_index]}}; +} + +SpmdInfo ConcatInferSpmdReverse(const std::vector& x, + const DistMetaTensor& output, + int axis) { + // TODO(liuzhenhai): add latter + return SpmdInfo(); +} +SpmdInfo ConcatInferSpmdDynamic(const std::vector& x, + const Scalar& axis) { + return ConcatInferSpmd(x, axis.to()); +} +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/concat.h b/paddle/phi/infermeta/spmd_rules/concat.h new file mode 100644 index 0000000000000..0f7435bec0b23 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/concat.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { +SpmdInfo ConcatInferSpmd(const std::vector& x, int axis); + +SpmdInfo ConcatInferSpmdReverse(const std::vector& x, + const DistMetaTensor& output, + int axis); + +SpmdInfo ConcatInferSpmdDynamic(const std::vector& x, + const Scalar& axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc index eb469200a7ec8..7a3639147f1ee 100644 --- a/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc +++ b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc @@ -95,7 +95,8 @@ SpmdInfo DefaultDataParallelInferSpmd( << str_join(output_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } SpmdInfo DefaultDataParallelInferSpmdReverse( const std::vector& ins, @@ -157,7 +158,8 @@ SpmdInfo DefaultDataParallelInferSpmdReverse( << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 6befef19cfef1..1dfe8bf19c296 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -275,7 +275,7 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, } VLOG(4) << std::endl; - return {input_dist_attrs, output_dist_attrs}; + return 
{ToArgDistAttr(input_dist_attrs), ToArgDistAttr(output_dist_attrs)}; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc index 4893c7071f19e..60c7acacf0478 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -291,17 +291,22 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& out_grad, bool trans_x, bool trans_y) { - auto confirm_dist_attr_same_fn = [&](const TensorDistAttr& x_dist_attr, + auto get_attr = [](const ArgDistAttr& attr) -> const TensorDistAttr& { + return paddle::get(attr); + }; + + auto confirm_dist_attr_same_fn = [&](const ArgDistAttr& x_dist_attr, const DistMetaTensor& y, const char* debug_msg) { + const auto& x_single_dist_attr = get_attr(x_dist_attr); PADDLE_ENFORCE_EQ( - DistAttrsAreBasicallyEqual(x_dist_attr, y.dist_attr()), + DistAttrsAreBasicallyEqual(x_single_dist_attr, y.dist_attr()), true, phi::errors::Unavailable("The matmul grad infer spmd `%s` verify " "error: left dist attr is %s, " "right dist attr is %s.", debug_msg, - x_dist_attr, + x_single_dist_attr, y.dist_attr())); }; @@ -313,8 +318,8 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, // so it cannot be handled correctly in the backward for the time being // For this case, we uniformly transition the input to the Replicated state. auto fwd_spmd_info = MatmulInferSpmd(x, y, trans_x, trans_y); - if (x.dist_attr() != fwd_spmd_info.first[0] || - y.dist_attr() != fwd_spmd_info.first[1]) { + if (x.dist_attr() != get_attr(fwd_spmd_info.first[0]) || + y.dist_attr() != get_attr(fwd_spmd_info.first[1])) { auto x_r_dist_attr = GetReplicatedDistAttr(x.dist_attr()); auto y_r_dist_attr = GetReplicatedDistAttr(y.dist_attr()); return {{x_r_dist_attr, diff --git a/paddle/phi/infermeta/spmd_rules/replicated.cc b/paddle/phi/infermeta/spmd_rules/replicated.cc index b2b3b019be039..d0c90f7b2d2a9 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.cc +++ b/paddle/phi/infermeta/spmd_rules/replicated.cc @@ -86,7 +86,8 @@ SpmdInfo ReplicatedInferSpmd(const std::vector& ins, << str_join(output_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } SpmdInfo ReplicatedInferSpmdReverse( @@ -135,7 +136,53 @@ SpmdInfo ReplicatedInferSpmdReverse( << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; +} + +SpmdInfo ReplicatedInferDynamic( + const std::vector*>>& + inputs) { + std::vector nonnull_inputs; + int64_t ninputs = inputs.size(); + SpmdInfo spmd_info; + + auto build_tensor_dist_attr = + [&nonnull_inputs](const DistMetaTensor& dist_meta_tensor) { + int ndim = dist_meta_tensor.dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(dist_meta_tensor.dist_attr()); + // `ndim == -1` means input is nullptr + if (ndim >= 0) { + std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + nonnull_inputs.push_back(&dist_meta_tensor); + } + return dist_attr_dst; + }; + + for (int64_t i = 0; i < ninputs; i++) { + if (paddle::holds_alternative(inputs[i])) { + auto dist_meta_tensor_ptr = paddle::get<0>(inputs[i]); + auto& dist_meta_tensor = *dist_meta_tensor_ptr; + auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor); + VLOG(4) << 
"input " << i << ": dist attr: " << dist_attr_dst.to_string(); + spmd_info.first.emplace_back(dist_attr_dst); + } else { + std::vector list_dist_attr; + auto dist_meta_tensors_ptr = paddle::get<1>(inputs[i]); + auto& dist_meta_tensors = *dist_meta_tensors_ptr; + for (const auto& dist_meta_tensor : dist_meta_tensors) { + auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor); + VLOG(4) << "input " << i + << ": dist attr: " << dist_attr_dst.to_string(); + list_dist_attr.emplace_back(std::move(dist_attr_dst)); + } + spmd_info.first.emplace_back(std::move(list_dist_attr)); + } + } + return spmd_info; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/replicated.h b/paddle/phi/infermeta/spmd_rules/replicated.h index a8d6c0719f2ec..1f3a26cb426d4 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.h +++ b/paddle/phi/infermeta/spmd_rules/replicated.h @@ -41,6 +41,19 @@ SpmdInfo ReplicatedInferSpmdReverse( const std::vector& ins, const std::vector& outs); +SpmdInfo ReplicatedInferDynamic( + const std::vector*>>& + inputs); + +// For phi api +template +SpmdInfo VariadicReplicatedInferSpmdDynamic(const Args&... args) { + return detail::ReplicateInferSpmdDynamicHelper() + .apply(args...) + .Infer(); +} + // For phi api template SpmdInfo VariadicReplicatedInferSpmd(const Args&... args) { diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 3eb63b5e7d0ee..eb3a97ce053c3 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/infermeta/spmd_rules/concat.h" #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" #include "paddle/phi/infermeta/spmd_rules/elementwise.h" #include "paddle/phi/infermeta/spmd_rules/embedding.h" @@ -523,6 +524,10 @@ PD_REGISTER_SPMD_RULE(slice, PD_INFER_SPMD(phi::distributed::SliceInferSpmd), PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); +PD_REGISTER_SPMD_RULE(concat, + PD_INFER_SPMD(phi::distributed::ConcatInferSpmd), + PD_INFER_SPMD(phi::distributed::ConcatInferSpmdReverse)); + // transpose rule PD_REGISTER_SPMD_RULE( transpose, diff --git a/paddle/phi/infermeta/spmd_rules/split.cc b/paddle/phi/infermeta/spmd_rules/split.cc index 4bc2a9ce0bdb1..0856fec2e89df 100644 --- a/paddle/phi/infermeta/spmd_rules/split.cc +++ b/paddle/phi/infermeta/spmd_rules/split.cc @@ -92,8 +92,10 @@ SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis) { << str_join(out_dims_mapping) << "]"; } VLOG(4) << std::endl; - - return {{x_dist_attr_dst}, out_dist_attrs}; + // TODO(liuzhenhai): remedy this + // should return list in list [] + // return {{x_dist_attr_dst}, {out_dist_attrs}}; + return {{x_dist_attr_dst}, ToArgDistAttr(out_dist_attrs)}; } SpmdInfo SplitWithNumInferSpmdReverse( @@ -193,8 +195,9 @@ SpmdInfo SplitWithNumInferSpmdReverse( } VLOG(4) << "Input shape: [" << str_join(x_shape) << "] " << "dims_mapping: [" << str_join(x_dims_mapping) << "]\n\n"; - - return {{x_dist_attr}, out_dist_attrs}; + // TODO(liuzhenhai): remedy this + // return {{x_dist_attr}, {out_dist_attrs}}; + return {{x_dist_attr}, ToArgDistAttr(out_dist_attrs)}; } SpmdInfo SplitInferSpmd(const DistMetaTensor& x, diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index dc6141f3ec0ce..42bbc659b2f2b 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ 
b/paddle/phi/infermeta/spmd_rules/utils.cc
@@ -164,6 +164,99 @@ TensorDistAttr GetReplicatedDistAttr(const TensorDistAttr& dist_attr) {
   return dst_dist_attr;
 }
 
+TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim) {
+  TensorDistAttr dst_dist_attr = CopyTensorDistAttrForOutput(dist_attr);
+  std::vector<int64_t> dims_mapping = dist_attr.dims_mapping();
+  dims_mapping[dim] = kReplicateDim;
+  dst_dist_attr.set_dims_mapping(dims_mapping);
+  return dst_dist_attr;
+}
+
+bool IsDimSharded(const TensorDistAttr& dist_attr, int dim) {
+  return dist_attr.is_shard(-1, dim);
+}
+
+bool PlacementEqual(const std::shared_ptr<PlacementStatus>& a,
+                    const std::shared_ptr<PlacementStatus>& b) {
+  if (a->is_partial()) {
+    if (!b->is_partial()) {
+      return false;
+    }
+    auto a_partial = std::dynamic_pointer_cast<PartialStatus>(a);
+    auto b_partial = std::dynamic_pointer_cast<PartialStatus>(b);
+    return a_partial->get_reduce_type() == b_partial->get_reduce_type();
+  }
+  if (a->is_replicated()) {
+    if (b->is_replicated()) {
+      return true;
+    }
+    return false;
+  }
+  if (!b->is_shard()) {
+    return false;
+  }
+
+  auto a_shard = std::dynamic_pointer_cast<ShardStatus>(a);
+  auto b_shard = std::dynamic_pointer_cast<ShardStatus>(b);
+  return a_shard->get_axis() == b_shard->get_axis();
+}
+
+TensorDistAttr FromPlacements(
+    const TensorDistAttr& dist_attr,
+    const std::vector<std::shared_ptr<PlacementStatus>>& placements) {
+  TensorDistAttr dst_dist_attr = CopyTensorDistAttrForOutput(dist_attr);
+  std::vector<int64_t> dims_mapping(dist_attr.dims_mapping().size(), -1);
+  paddle::flat_hash_map<int64_t, ReduceType> partial_status;
+
+  for (size_t mesh_dim = 0; mesh_dim < placements.size(); mesh_dim++) {
+    auto& placement = placements[mesh_dim];
+    if (placement->is_shard()) {
+      auto shard_placement = std::dynamic_pointer_cast<ShardStatus>(placement);
+      dims_mapping[shard_placement->get_axis()] = mesh_dim;
+    }
+    if (placement->is_partial()) {
+      auto partial_placement =
+          std::dynamic_pointer_cast<PartialStatus>(placement);
+      auto reduce_type = partial_placement->get_reduce_type();
+      partial_status[mesh_dim] = reduce_type;
+    }
+  }
+  dst_dist_attr.set_dims_mapping(dims_mapping);
+  dst_dist_attr.set_partial_status(partial_status);
+  return dst_dist_attr;
+}
+
+std::vector<ArgDistAttr> ToArgDistAttr(
+    const std::vector<TensorDistAttr>& dist_attrs) {
+  std::vector<ArgDistAttr> items_dist_attrs;
+  std::transform(
+      dist_attrs.begin(),
+      dist_attrs.end(),
+      std::back_inserter(items_dist_attrs),
+      [](const TensorDistAttr& attr) -> ArgDistAttr { return {attr}; });
+  return items_dist_attrs;
+}
+
+std::vector<int64_t> GetLocalShape(
+    const std::vector<int64_t> shape,
+    const ProcessMesh& mesh,
+    const std::vector<std::shared_ptr<PlacementStatus>>& placements) {
+  auto local_shape = shape;
+  auto n_placement = placements.size();
+  for (size_t i = 0; i < n_placement; i++) {
+    auto& placement = placements.at(i);
+    if (placement->is_shard()) {
+      auto mesh_dim_size = mesh.dim_size(i);
+      auto shard_dim =
+          std::dynamic_pointer_cast<ShardStatus>(placement)->get_axis();
+      auto split_size =
+          (shape.at(shard_dim) + mesh_dim_size - 1) / mesh_dim_size;
+      local_shape[shard_dim] = split_size;
+    }
+  }
+  return local_shape;
+}
+
 std::vector<int64_t> GetDimsMappingForAxes(
     const std::string& axes,
     const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h
index cd140c68fc8ac..b5b5e207a0ee6 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.h
+++ b/paddle/phi/infermeta/spmd_rules/utils.h
@@ -69,6 +69,25 @@ std::vector<int64_t> ResoluteOutputPartialDimension(
 // Repliacated state
 TensorDistAttr GetReplicatedDistAttr(const TensorDistAttr& dist_attr);
 
+bool IsDimSharded(const TensorDistAttr& dist_attr, int dim);
+
+std::vector<int64_t> GetLocalShape(
+    const std::vector<int64_t> 
shape, + const ProcessMesh& mesh, + const std::vector>& placements); + +TensorDistAttr FromPlacements( + const TensorDistAttr& dist_attr, + const std::vector>& placements); + +std::vector ToArgDistAttr( + const std::vector& dist_attrs); + +TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim); + +bool PlacementEqual(const std::shared_ptr& a, + const std::shared_ptr& b); + // Adaptor for variadic arguments template struct ArgsIterator { @@ -131,6 +150,28 @@ struct VariadicSpmdRuleArgumentParser SpmdInfo InferBackward() { return Fn(inputs, outputs); } }; + +using DynamicSpmdFn = SpmdInfo (*)( + const std::vector*>>&); + +template +struct ReplicateInferSpmdDynamicHelper + : public ArgsIterator> { + SpmdInfo Infer() { return Fn(inputs); } + + void operator()(const DistMetaTensor& x) { inputs.emplace_back(&x); } + void operator()(const std::vector& x) { + inputs.emplace_back(&x); + } + + void operator()(std::vector&& x) = delete; + void operator()(DistMetaTensor&& x) = delete; + + std::vector*>> + inputs; +}; } // namespace detail // Get dims mapping for the given axes according to sharding information of diff --git a/test/auto_parallel/semi_auto_parallel_for_concat.py b/test/auto_parallel/semi_auto_parallel_for_concat.py new file mode 100644 index 0000000000000..24605825d5f15 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_concat.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist + + +class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + + def test_concat_forward(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [[None, None, 'x'], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + + def test_concat_forward_reshard(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [['x', None, None], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_concat_forward() + # all to all is not supported yet for cpu + if self._backend == "gpu": + self.test_concat_forward_reshard() + + +if __name__ == '__main__': + TestSplitAndConcatSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_matmul.py b/test/auto_parallel/semi_auto_parallel_for_matmul.py index 279062f483058..470100e9c3bc8 100644 --- a/test/auto_parallel/semi_auto_parallel_for_matmul.py +++ b/test/auto_parallel/semi_auto_parallel_for_matmul.py @@ -30,7 +30,7 @@ def __init__(self): def check_tensor_eq(self, a, b): np1 = a.numpy() np2 = b.numpy() - np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + np.testing.assert_allclose(np1, np2, rtol=1e-04, verbose=True) def test_body( self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py new file mode 100644 index 0000000000000..cfb905e8382a2 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_util.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+
+
+class SemiAutoParallelTestBase:
+    def __init__(self):
+        self._dtype = os.getenv("dtype")
+        self._backend = os.getenv("backend")
+        self._seed = eval(os.getenv("seed"))
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+    def check_tensor_eq(self, a, b):
+        np1 = a.numpy()
+        np2 = b.numpy()
+        np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True)
+
+    def flatten(self, inputs, terminal_cond):
+        """
+        inputs may be a single tensor or a tuple
+        """
+
+        if terminal_cond(inputs):
+            return [inputs], "i"
+
+        assert isinstance(inputs, (tuple, list))
+        flattened = []
+        structure = []
+        for i in range(len(inputs)):
+            tmp, tmp_structure = self.flatten(inputs[i], terminal_cond)
+            flattened.extend(tmp)
+            structure.append(tmp_structure)
+
+        if isinstance(inputs, tuple):
+            structure = tuple(structure)
+        return flattened, structure
+
+    def unflatten(self, inputs, structure, offset=0):
+        """
+        inputs is a flat list of tensors
+        """
+        assert isinstance(inputs, list)
+        assert offset < len(inputs)
+        if structure == "i":
+            offset = offset + 1
+            # return the single element at the previous offset
+            return inputs[offset - 1], offset
+        assert isinstance(structure, (tuple, list))
+        unflattened = []
+        for i in range(len(structure)):
+            tmp, offset = self.unflatten(inputs, structure[i], offset)
+            unflattened.append(tmp)
+        if isinstance(structure, tuple):
+            unflattened = tuple(unflattened)
+        return unflattened, offset
+
+    def runfunc_and_check(
+        self, inputs_shape, inputs_specs, op_func, with_backward, **kwargs
+    ):
+        paddle.seed(self._seed)
+        np.random.seed(self._seed)
+
+        flat_inputs = []
+        flat_dist_inputs = []
+
+        def terminal_cond(x):
+            return isinstance(x, list) and all(
+                not isinstance(e, (list, tuple)) for e in x
+            )
+
+        flat_inputs_specs, inputs_structure = self.flatten(
+            inputs_specs, terminal_cond
+        )
+        flat_inputs_shape, _ = self.flatten(inputs_shape, terminal_cond)
+        assert len(flat_inputs_specs) == len(flat_inputs_shape)
+
+        for shape, spec in zip(flat_inputs_shape, flat_inputs_specs):
+            input_np = np.random.random(size=shape).astype(self._dtype)
+            input = paddle.to_tensor(input_np)
+            input.stop_gradient = False
+            input_dist_attr = dist.DistAttr(
+                mesh=self._mesh, sharding_specs=spec
+            )
+            dist_input = dist.shard_tensor(input, dist_attr=input_dist_attr)
+            dist_input.stop_gradient = False
+            flat_inputs.append(input)
+            flat_dist_inputs.append(dist_input)
+        inputs, _ = self.unflatten(flat_inputs, inputs_structure)
+        dist_inputs, _ = self.unflatten(flat_dist_inputs, inputs_structure)
+
+        def wrap_tuple(e):
+            return e if isinstance(e, tuple) else (e,)
+
+        op_inputs = wrap_tuple(inputs)
+        op_dist_input = wrap_tuple(dist_inputs)
+
+        out = op_func(*op_inputs, **kwargs)
+        dist_out = op_func(*op_dist_input, **kwargs)
+
+        if with_backward:
+
+            def terminal_cond2(x):
+                return not isinstance(x, (list, tuple))
+
+            flat_out, _ = self.flatten(out, terminal_cond2)
+            flat_dist_out, _ = self.flatten(dist_out, terminal_cond2)
+            assert len(flat_out) == len(flat_dist_out)
+            for output, dist_output in zip(flat_out, flat_dist_out):
+                self.check_tensor_eq(output, dist_output)
+                output.backward()
+                dist_output.backward()
+
+            for x, dist_x in zip(flat_inputs, flat_dist_inputs):
+                self.check_tensor_eq(x.grad, dist_x.grad)
+
+        return dist_inputs, dist_out
diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt
index f5d45ecaafc3f..5c8f78b6c6544 100644
--- a/test/auto_parallel/spmd_rules/CMakeLists.txt
+++ b/test/auto_parallel/spmd_rules/CMakeLists.txt
@@ -20,6 +20,7 @@ if(WITH_DISTRIBUTE)
   py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule)
   py_test_modules(test_slice_rule MODULES test_slice_rule)
   py_test_modules(test_flatten_rule MODULES test_flatten_rule)
+  py_test_modules(test_concat_rule MODULES test_concat_rule)
   # End of unittests WITH single card WITHOUT timeout
 endif()
 
diff --git a/test/auto_parallel/spmd_rules/test_concat_rule.py b/test/auto_parallel/spmd_rules/test_concat_rule.py
new file mode 100644
index 0000000000000..b1e1c11a0622e
--- /dev/null
+++ b/test/auto_parallel/spmd_rules/test_concat_rule.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from paddle.distributed.auto_parallel.static.dist_attribute import (
+    DistTensorSpec,
+    TensorDistAttr,
+)
+from paddle.distributed.fleet import auto
+from paddle.framework import core
+
+
+class TestConcatSPMDRule(unittest.TestCase):
+    """
+    Unit tests for the concat spmd rule.
+    """
+
+    def setUp(self):
+        self.process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])
+        self.shapes = [[16, 16, 16], [4, 16, 16], [2, 16, 16]]
+        self.dim_mappings = [[-1, 0, 1], [-1, 1, 0], [-1, -1, 0]]
+
+    def build_inputs(self):
+        inputs = []
+        for shape, dim_mapping in zip(self.shapes, self.dim_mappings):
+            tensor_dist_attr = TensorDistAttr()
+            tensor_dist_attr.dims_mapping = dim_mapping
+            tensor_dist_attr.process_mesh = self.process_mesh
+            inputs.append(DistTensorSpec(shape, tensor_dist_attr))
+        return inputs
+
+    def test_infer_forward(self):
+        inputs = self.build_inputs()
+        rule = core.get_phi_spmd_rule("concat")
+        infered_dist_attrs = rule.infer_forward(inputs, 0)
+        infered_input_dist_attrs = infered_dist_attrs[0]
+        self.assertEqual(len(infered_input_dist_attrs), 1)
+        infered_output_dist_attrs = infered_dist_attrs[1]
+        self.assertEqual(len(infered_output_dist_attrs), 1)
+        for input_dist_attr in infered_input_dist_attrs[0]:
+            self.assertEqual(input_dist_attr.dims_mapping, [-1, 1, 0])
+        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py
index 3730e019f7506..b1132a6a3a8dc 100644
--- a/test/auto_parallel/test_semi_auto_parallel_basic.py
+++ b/test/auto_parallel/test_semi_auto_parallel_basic.py
@@ -46,6 +46,16 @@ def test_elementwise_api(self):
                 user_defined_envs=envs,
             )
 
+    def test_concat_api(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "semi_auto_parallel_for_concat.py",
+                user_defined_envs=envs,
+            )
+
     def test_reduction_api(self):
         envs_list = test_base.gen_product_envs_list(
             self._default_envs, self._changeable_envs
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc
index 42476d7bb323f..eb6d08542b04a 100644
--- 
a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include "glog/logging.h" #include "gtest/gtest.h" @@ -23,6 +24,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/infermeta/spmd_rules/replicated.h" #include "paddle/phi/infermeta/spmd_rules/rules.h" @@ -30,6 +32,68 @@ namespace paddle { namespace distributed { namespace auto_parallel { +auto& get_dims_mapping(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.dims_mapping(); +} + +bool is_partial(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.is_partial(); +} + +auto get_partial_dims(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.partial_dims(); +} + +void check_dim_mapping(const phi::distributed::ArgDistAttr& dist_attr, + const std::vector& dim_mapping, + const std::string& line = "") { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)) + << line; + EXPECT_EQ(get_dims_mapping(dist_attr), dim_mapping) << line; +} + +void check_partial_dims(const phi::distributed::ArgDistAttr& dist_attr, + const std::set& dims, + const std::string& line = "") { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)) + << line; + EXPECT_EQ(get_partial_dims(dist_attr), dims) << line; +} + +void clean_partial_status(phi::distributed::ArgDistAttr* dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.clean_partial_status(); +} + +void clean_partial_dims(phi::distributed::ArgDistAttr* dist_attr, + std::vector dims) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.clean_partial_dims(dims); +} + +void set_partial_status(phi::distributed::ArgDistAttr* dist_attr, + std::vector dims) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.set_partial_status(dims); +} + TEST(MatmulSPMDRule, Ctor) { // build input data class std::vector x_shape = {64, 32}; @@ -66,14 +130,10 @@ TEST(MatmulSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test1 done." 
<< std::endl << std::endl << std::endl; // mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[] @@ -84,15 +144,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, 0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; - // mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done x_dist_attr.set_dims_mapping({1, 0}); y_dist_attr.set_dims_mapping({-1, -1}); @@ -101,15 +157,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.first[0], {1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test3 done." << std::endl << std::endl << std::endl; // mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done @@ -120,15 +172,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, 1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({1})); + check_dim_mapping(infered_dist_attrs.first[0], {-1, 1}); + check_dim_mapping(infered_dist_attrs.first[1], {1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, 0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {1}); VLOG(4) << "test4 done." 
<< std::endl << std::endl << std::endl; // abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = @@ -141,13 +189,10 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({0, 1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({0, 1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {0, 1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {0, 1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test5 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1, @@ -159,15 +204,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test6 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = @@ -179,13 +220,12 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, 0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, 0, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + VLOG(4) << "test7 done." 
<< std::endl << std::endl << std::endl; // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -197,17 +237,13 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, -1, 1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); - infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, -1, 1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); + clean_partial_dims(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test8 done." << std::endl << std::endl << std::endl; // abcmk[-1, -1, 0, 1]+trans_x=true, kn[1, 0]+trans_y=true --> abcmk[-1, -1, @@ -219,20 +255,16 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1, 0, 1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector( - {-1, 0})); // confilct and should be changed to [-1, 0] - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); - VLOG(4) << infered_dist_attrs.second[0].to_string(); - infered_dist_attrs.second[0].clean_partial_status(); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); - EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status( - std::vector({1}))); + + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, 0, 1}); + check_dim_mapping(infered_dist_attrs.first[1], + {-1, 0}); // confilct and should be changed to [-1, 0] + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + check_partial_dims(infered_dist_attrs.second[0], {0}); + + clean_partial_status(&infered_dist_attrs.second[0]); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + EXPECT_ANY_THROW(set_partial_status(&infered_dist_attrs.second[0], {1})); VLOG(4) << "test9 done." 
<< std::endl << std::endl << std::endl; // abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -256,29 +288,21 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); // try to clean partial on a dim which is not partial - EXPECT_ANY_THROW(infered_dist_attrs.second[0].clean_partial_dims( - std::vector({1}))); - + EXPECT_ANY_THROW(clean_partial_dims(&infered_dist_attrs.second[0], {1})); // try to clean partial on a dims which is sharded - EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status( - std::vector({1}))); + EXPECT_ANY_THROW(set_partial_status(&infered_dist_attrs.second[0], {1})); // clean partial and then re-set again - infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); - infered_dist_attrs.second[0].set_partial_status(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); - + clean_partial_dims(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + set_partial_status(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test11 done." 
<< std::endl << std::endl << std::endl; } @@ -328,26 +352,18 @@ TEST(LayerNormSPMDRule, Ctor) { bias_dist_attr); phi::distributed::InferSpmdContext ctx({x, scale, bias}, {epsilon, begin_norm_axis}); - std::pair, std::vector> - infered_dist_attrs = layer_norm_rule.InferForward(ctx); + auto infered_dist_attrs = layer_norm_rule.InferForward(ctx); size_t input_size = 3; size_t output_size = 3; EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({1})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({1})); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {1}); + check_dim_mapping(infered_dist_attrs.second[2], {1}); VLOG(4) << "test1 done."; // ijk[1, 0, -1],k[0],k[0] --> ijk[1, -1, -1],z[1],z[1], @@ -364,18 +380,13 @@ TEST(LayerNormSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext({x, scale, bias}, {epsilon, begin_norm_axis}); infered_dist_attrs = layer_norm_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({1})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({1})); + + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {1}); + check_dim_mapping(infered_dist_attrs.second[2], {1}); VLOG(4) << "test2 done."; // ijk[0, -1, -1],y[-1],y[1] --> ijk[0, 1, -1], i[0], i[0], y=jk, @@ -392,18 +403,13 @@ TEST(LayerNormSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext({x, scale, bias}, {epsilon, begin_norm_axis}); infered_dist_attrs = layer_norm_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({0})); + + check_dim_mapping(infered_dist_attrs.first[0], {0, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {0, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {0}); + 
check_dim_mapping(infered_dist_attrs.second[2], {0}); VLOG(4) << "test3 done."; } @@ -449,24 +455,19 @@ TEST(MatmulSPMDRuleInferBackward, Ctor) { // -1] phi::distributed::InferSpmdContext ctx( {x, y, out}, {/*trans_x=*/false, /*trans_x=*/false}); - std::pair, std::vector> - infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); + auto infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); size_t input_size = 2; size_t output_size = 1; EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, 1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; } @@ -524,18 +525,14 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_st.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[1].dims_mapping(), - std::vector({-1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[1].is_partial(), false); + check_dim_mapping(infered_dist_attrs_st.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[1], {-1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[1]), false); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test1 done." 
<< std::endl << std::endl << std::endl; @@ -554,15 +551,10 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size); EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.first[2].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.first[2], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1, -1}); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; @@ -582,14 +574,10 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[2].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[2], {-1, -1, -1}); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test3 done." 
<< std::endl << std::endl << std::endl; @@ -649,19 +637,15 @@ TEST(DefaultDataParallelSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size); EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); + check_dim_mapping(infered_dist_attrs_st.first[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs_st.second[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[1], {0, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[1]), false); - EXPECT_EQ(infered_dist_attrs_st.first[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[1].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[1].is_partial(), false); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; @@ -682,14 +666,11 @@ TEST(DefaultDataParallelSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[2].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[2], {-1, -1, -1}); + EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test2 done." 
            << std::endl << std::endl << std::endl;
@@ -735,19 +716,101 @@ TEST(DefaultDataParallelSPMDRule, Ctor) {
   EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size);
   EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size);
   EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size);
-
-  EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(),
-            std::vector<int64_t>({0, -1, -1, -1}));
-  EXPECT_EQ(infered_dist_attrs_dy.first[1].dims_mapping(),
-            std::vector<int64_t>({0, -1}));
-  EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(),
-            std::vector<int64_t>({0, -1, -1, -1}));
-  EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(),
-            std::vector<int64_t>({0, -1, -1}));
+  check_dim_mapping(infered_dist_attrs_dy.first[0], {0, -1, -1, -1});
+  check_dim_mapping(infered_dist_attrs_dy.first[1], {0, -1});
+  check_dim_mapping(infered_dist_attrs_dy.second[0], {0, -1, -1, -1});
+  check_dim_mapping(infered_dist_attrs_dy.second[1], {0, -1, -1});
   EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first);
   EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second);
   VLOG(4) << "test4 done." << std::endl << std::endl << std::endl;
 }
+TEST(ConcatRule, Ctor) {
+  std::vector<int64_t> mesh_shape = {2, 2};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3};
+  std::vector<std::string> dim_names = {"x", "y"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  std::vector<std::vector<int64_t>> shapes = {
+      {16, 16, 16}, {4, 16, 16}, {2, 16, 16}};
+  std::vector<std::vector<int64_t>> dim_mappings = {
+      {-1, 0, 1}, {-1, 1, 0}, {-1, -1, 0}};
+  std::vector<std::vector<int64_t>> partial_status = {{}, {}, {1}};
+
+  auto build_inputs = [&] {
+    std::vector<phi::distributed::DistMetaTensor> inputs;
+    for (int i = 0; i < 3; i++) {
+      auto t_dist_attr = TensorDistAttr();
+      t_dist_attr.set_process_mesh(process_mesh);
+      t_dist_attr.set_dims_mapping(dim_mappings[i]);
+      t_dist_attr.set_dynamic_dims({false, false, false});
+      auto input = phi::distributed::DistMetaTensor(phi::make_ddim(shapes[i]),
+                                                    t_dist_attr);
+      inputs.push_back(input);
+    }
+    return inputs;
+  };
+
+  // test 1, inputs are aligned according to cost, and partial status is cleared
+  auto inputs = build_inputs();
+  auto infered_dist_attrs = phi::distributed::ConcatInferSpmd(inputs, 0);
+  // list of tensor => single tensor
+  EXPECT_EQ(infered_dist_attrs.first.size(), static_cast<size_t>(1));
+  EXPECT_EQ(infered_dist_attrs.second.size(), static_cast<size_t>(1));
+  EXPECT_TRUE(
+      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
+          infered_dist_attrs.first[0]));
+  EXPECT_TRUE(paddle::holds_alternative<phi::distributed::TensorDistAttr>(
+      infered_dist_attrs.second[0]));
+  auto& inputs_infer1 = paddle::get<1>(infered_dist_attrs.first[0]);
+  for (auto e : inputs_infer1) {
+    check_dim_mapping(e, {-1, 1, 0});
+    check_partial_dims(e, {});
+  }
+  check_dim_mapping(infered_dist_attrs.second[0], {-1, 1, 0});
+  check_partial_dims(infered_dist_attrs.second[0], {});
+
+  // test 2, force replicate along concat axis
+  inputs = build_inputs();
+  infered_dist_attrs = phi::distributed::ConcatInferSpmd(inputs, 1);
+  // list of tensor => single tensor
+  EXPECT_EQ(infered_dist_attrs.first.size(), static_cast<size_t>(1));
+  EXPECT_EQ(infered_dist_attrs.second.size(), static_cast<size_t>(1));
+  EXPECT_TRUE(
+      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
+          infered_dist_attrs.first[0]));
+  EXPECT_TRUE(paddle::holds_alternative<phi::distributed::TensorDistAttr>(
+      infered_dist_attrs.second[0]));
+  auto& inputs_infer2 = paddle::get<1>(infered_dist_attrs.first[0]);
+  for (auto e : inputs_infer2) {
+    check_dim_mapping(e, {1, -1, 0});
+    check_partial_dims(e, {});
+  }
+  check_dim_mapping(infered_dist_attrs.second[0], {1, -1, 0});
+  check_partial_dims(infered_dist_attrs.second[0], {});
+}
+TEST(Util, Ctor) {
+  // test equal and not-equal placement comparisons
+  using phi::distributed::PartialStatus;
+  using phi::distributed::PlacementEqual;
+  using phi::distributed::ReplicatedStatus;
+  using phi::distributed::ShardStatus;
+  auto a = std::make_shared<PartialStatus>(phi::ReduceType::kRedSum);
+  auto b = std::make_shared<PartialStatus>(phi::ReduceType::kRedMin);
+  EXPECT_TRUE(PlacementEqual(a, a));
+  EXPECT_TRUE(!PlacementEqual(a, b));
+  auto c = std::make_shared<ShardStatus>(0);
+  auto d = std::make_shared<ShardStatus>(1);
+  EXPECT_TRUE(!PlacementEqual(a, c));
+  EXPECT_TRUE(!PlacementEqual(b, c));
+  EXPECT_TRUE(PlacementEqual(c, c));
+  EXPECT_TRUE(!PlacementEqual(c, d));
+  auto e = std::make_shared<ReplicatedStatus>();
+  EXPECT_TRUE(PlacementEqual(e, e));
+  EXPECT_TRUE(!PlacementEqual(a, e));
+  EXPECT_TRUE(!PlacementEqual(b, e));
+  EXPECT_TRUE(!PlacementEqual(c, e));
+  EXPECT_TRUE(!PlacementEqual(d, e));
+}
 } // namespace auto_parallel
 } // namespace distributed
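
Note (not part of the patch): the test helpers above show the pattern every caller of the new variant-returning SPMD API has to follow. The sketch below is a hypothetical illustration of that pattern; `unpack_dist_attrs` is not a function in this patch, and it assumes `ArgDistAttr` is the `paddle::variant<TensorDistAttr, std::vector<TensorDistAttr>>` alias used throughout the diff.

```cpp
// Hypothetical helper, for illustration only -- not part of this patch.
#include <vector>

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/type_defs.h"

// Flattens one inferred argument into a list of TensorDistAttr, regardless of
// whether the rule produced a single attribute (e.g. matmul inputs/outputs)
// or a per-tensor list (e.g. the inputs of concat).
std::vector<phi::distributed::TensorDistAttr> unpack_dist_attrs(
    const phi::distributed::ArgDistAttr& arg) {
  if (paddle::holds_alternative<phi::distributed::TensorDistAttr>(arg)) {
    // Single-tensor argument: wrap the one attribute in a vector.
    return {paddle::get<0>(arg)};
  }
  // Tensor-list argument: return the whole list.
  return paddle::get<1>(arg);
}
```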