diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h index cf4046950964a..ba468269b8230 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h @@ -19,7 +19,6 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/elementwise_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h" -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/reduction_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h" @@ -32,9 +31,6 @@ namespace paddle { namespace distributed { namespace auto_parallel { -// matmul rule -REGISTER_SPMD_RULE(matmul, MatmulSPMDRule); - // reduction rules REGISTER_SPMD_RULE(all, ReductionSPMDRule); REGISTER_SPMD_RULE(amax, ReductionSPMDRule); diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index e03292faa9e42..6f639f145dcea 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -15,13 +15,16 @@ #include #include +#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/pybind/auto_parallel_py.h" +#include "paddle/fluid/pybind/pybind_variant_caster.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" #include "paddle/utils/optional.h" #include "paddle/utils/pybind.h" @@ -32,6 +35,10 @@ #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#endif + namespace py = pybind11; namespace paddle { @@ -42,6 +49,7 @@ using paddle::distributed::auto_parallel::kDefault; using paddle::distributed::auto_parallel::OperatorDistAttr; using paddle::distributed::auto_parallel::SPMDRuleBase; using paddle::distributed::auto_parallel::SPMDRuleMap; +using paddle::framework::BlockDesc; using paddle::framework::OpDesc; using paddle::framework::VarDesc; using phi::distributed::ProcessMesh; @@ -343,6 +351,41 @@ void BindAutoParallel(py::module *m) { &SPMDRuleBase::InferBackward)); // .def("infer_backward", &SPMDRuleBase::InferBackward) [revert in future] + py::class_(*m, "SpmdRule") + .def("infer_forward", + [](const phi::distributed::SpmdRule &self, + const std::vector &input_specs, + const std::vector &attrs) { + phi::distributed::InferSpmdContext ctx; + for (auto &spec : input_specs) { + ctx.EmplaceBackInput(phi::distributed::DistMetaTensor( + phi::make_ddim(spec.shape()), spec.dist_attr())); + } + for (auto &attr : attrs) { + ctx.EmplaceBackAttr(attr); + } + return self.InferForward(ctx); + }) + .def("infer_backward", + [](const phi::distributed::SpmdRule &self, + const std::vector &input_specs, + const std::vector &output_specs, + const std::vector &attrs) { + phi::distributed::InferSpmdContext ctx; + for (auto &spec : input_specs) { + ctx.EmplaceBackInput(phi::distributed::DistMetaTensor( + phi::make_ddim(spec.shape()), spec.dist_attr())); + } + for (auto &spec : output_specs) { + ctx.EmplaceBackInput(phi::distributed::DistMetaTensor( + phi::make_ddim(spec.shape()), spec.dist_attr())); + } + for (auto &attr : attrs) { + ctx.EmplaceBackAttr(attr); + } + return self.InferBackward(ctx); + }); + py::class_(*m, "DistTensorSpec") .def(py::init<>()) .def(py::init()) @@ -472,6 +515,14 @@ void BindAutoParallel(py::module *m) { }, py::return_value_policy::reference); + m->def( + "get_phi_spmd_rule", + [](const std::string op_type) { + return phi::distributed::SpmdRuleFactory::Instance().GetSpmdRule( + op_type); + }, + py::return_value_policy::reference); + // TODO(liuzhenhai): DistributedMapper is not used for now, but // dist_mapper_test need the symbols forch DistributedMapper to be linked, // remove it latter diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index 0aee1b5363838..63fbe9ecd677c 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -9,6 +9,8 @@ collect_srcs( dist_mapper.cc reshard_utils.cc dist_tensor.cc + dist_meta_tensor.cc + inferspmd_utils.cc reshard_function.cc reshard_split_functor.cc reshard_concat_functor.cc diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc new file mode 100644 index 0000000000000..dc5d6c20e62b3 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" + +namespace phi { +namespace distributed { + +phi::DDim DistMetaTensor::dims() const { + // member values in tensor_ have higher priority than those in DistMetaTensor + if (tensor_ != nullptr) { + PADDLE_ENFORCE_EQ(this->is_dist(), + true, + phi::errors::InvalidArgument( + "The current MetaTensor doesn't contains " + "DistTensor when call `dist_attr` method.")); + return MetaTensor::dims(); + } else { + return dims_; + } +} + +const distributed::TensorDistAttr& DistMetaTensor::dist_attr() const { + // member values in tensor_ have higher priority than those in DistMetaTensor + if (tensor_ != nullptr) { + PADDLE_ENFORCE_EQ(this->is_dist(), + true, + phi::errors::InvalidArgument( + "The current MetaTensor doesn't contains " + "DistTensor when call `dist_attr` method.")); + return static_cast(tensor_)->dist_attr(); + } else { + return dist_attr_; + } +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h new file mode 100644 index 0000000000000..efbf38d28f9f0 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/meta_tensor.h" + +namespace phi { +namespace distributed { + +class DistMetaTensor : public MetaTensor { + public: + // supporting implicit construction is easier to use + DistMetaTensor(TensorBase* tensor) // NOLINT + : MetaTensor(tensor) {} + DistMetaTensor(const TensorBase& tensor) // NOLINT + : MetaTensor(tensor) {} + DistMetaTensor(const TensorBase* tensor) // NOLINT + : MetaTensor(tensor) {} + DistMetaTensor(TensorBase& tensor) // NOLINT + : MetaTensor(tensor) {} + // For static mode only + DistMetaTensor(const phi::DDim& dims, const TensorDistAttr& dist_attr) + : dims_(dims), dist_attr_(dist_attr) {} + + DistMetaTensor(DistMetaTensor&&) = default; + DistMetaTensor& operator=(DistMetaTensor&&) = default; + DistMetaTensor(const DistMetaTensor&) = default; + DistMetaTensor& operator=(const DistMetaTensor&) = default; + + virtual ~DistMetaTensor() = default; + + DDim dims() const override; + + const distributed::TensorDistAttr& dist_attr() const; + + private: + /** + * Note: When using the semi-automatic parallel segmentation derivation rules + * of the static graph, in order to facilitate the packaging of the input + * parameters of the construction, the DistMetaTensor is inherited and + * encapsulated, and the class members dims_ and dist_attr_ are added to it. + * + * The information contained in these two members is also in the tensor of the + * meta_tensor of the base class, and there is redundancy. + * + * We need to pay attention when using it to ensure the consistency. + * These two members are read-only, and their values cannot be changed + * after construction. To change their values, they need to be set + * directly in tensor_*/ + phi::DDim dims_; + TensorDistAttr dist_attr_; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc new file mode 100644 index 0000000000000..3b94dc017e5e7 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc @@ -0,0 +1,97 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" + +namespace phi { +namespace distributed { + +void InferSpmdContext::EmplaceBackInput(DistMetaTensor input) { + inputs_.emplace_back(std::move(input)); +} + +void InferSpmdContext::EmplaceBackAttr(Attribute attr) { + attrs_.emplace_back(std::move(attr)); +} + +const DistMetaTensor& InferSpmdContext::InputAt(size_t idx) const { + return inputs_.at(idx); +} + +template +AttrType InferSpmdContext::AttrAt(size_t idx) const { + try { + return paddle::get(attrs_.at(idx)); + } catch (paddle::bad_variant_access const& e) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Attribute cast error in InferSpmd Context, the input attr type is " + "`%s`, but the expected attribute type is `%s`.", + attrs_.at(idx).type().name(), + std::type_index(typeid(AttrType)).name())); + } +} + +template <> +bool InferSpmdContext::AttrAt(size_t idx) const { + try { + auto attr = attrs_.at(idx); + if (attr.type() == typeid(int)) { + return static_cast(paddle::get(attr)); + } else { + return paddle::get(attr); + } + } catch (paddle::bad_variant_access const& e) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Attribute cast error in InferSpmd Context, the input attr type is " + "`%s`, but the expected attribute type is `bool`.", + attrs_.at(idx).type().name())); + } +} + +const Attribute& InferSpmdContext::AttrAt(size_t idx) const { + return attrs_.at(idx); +} + +SpmdRuleFactory& SpmdRuleFactory::Instance() { + static SpmdRuleFactory g_spmd_rule_map; + return g_spmd_rule_map; +} + +bool SpmdRuleFactory::ContainsSpmdRule(const std::string& kernel_name) const { + return spmd_rule_map_.count(kernel_name) > 0; +} + +int SpmdRuleFactory::InsertSpmdRule(std::string kernel_name, SpmdRule rule) { + PADDLE_ENFORCE_NE( + ContainsSpmdRule(kernel_name), + true, + phi::errors::AlreadyExists( + "`%s` Kernel's Spmd rules has been registered.", kernel_name)); + spmd_rule_map_.insert({std::move(kernel_name), std::move(rule)}); + return 0; +} + +const SpmdRule& SpmdRuleFactory::GetSpmdRule( + const std::string& kernel_name) const { + auto it = spmd_rule_map_.find(kernel_name); + PADDLE_ENFORCE_NE( + it, + spmd_rule_map_.end(), + phi::errors::NotFound("`%s` Kernel's Spmd rules is not registered.", + kernel_name)); + return it->second; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h new file mode 100644 index 0000000000000..bccee2bf5981a --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h @@ -0,0 +1,186 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/attribute.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/macros.h" +#include "paddle/phi/core/type_defs.h" +#include "paddle/utils/any.h" +#include "paddle/utils/flat_hash_map.h" +#include "paddle/utils/small_vector.h" + +namespace phi { +namespace distributed { + +class InferSpmdContext { + public: + InferSpmdContext() = default; + InferSpmdContext( + paddle::small_vector inputs, + paddle::small_vector attrs) + : inputs_(std::move(inputs)), attrs_(std::move(attrs)) {} + + void EmplaceBackInput(DistMetaTensor input); + void EmplaceBackAttr(Attribute attr); + + const DistMetaTensor& InputAt(size_t idx) const; + + template + AttrType AttrAt(size_t idx) const; + + const Attribute& AttrAt(size_t idx) const; + + private: + // Now we only need `inputs`, for backward, the `output` is passed as input + paddle::small_vector inputs_; + // Because the attribute arguments of dygraph do not have `attr name`, + // so we use vector instead of map + paddle::small_vector attrs_; +}; + +using InferSpmdFn = SpmdInfo (*)(const InferSpmdContext&); + +#define PD_INFER_SPMD(...) \ + ::phi::distributed::InferSpmdFnImpl::Call + +template +struct InferSpmdTypeTag {}; + +template +struct InferSpmdFnImpl; + +template +struct InferSpmdFnImpl { + static SpmdInfo Call(const InferSpmdContext& ctx) { + return InferSpmdFnCallHelper>:: + template Call<0, 0>(ctx); + } + + private: + template + struct InferSpmdFnCallHelper; + + // TODO(chenweihang): support other input type later as needed + template + struct InferSpmdFnCallHelper { + template + static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferSpmd's Input should appear before Attributes."); + const DistMetaTensor& arg = ctx.InputAt(in_idx); + return InferSpmdFnCallHelper::template Call( + ctx, pargs..., arg); + } + }; + +#define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(attr_type) \ + template \ + struct InferSpmdFnCallHelper { \ + template \ + static SpmdInfo Call(const InferSpmdContext& ctx, \ + PreviousArgs&... pargs) { \ + attr_type arg = ctx.AttrAt(attr_idx); \ + return InferSpmdFnCallHelper::template Call( \ + ctx, pargs..., arg); \ + } \ + } + + // TODO(chenweihang): support other attr type later as needed + PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(bool); + + /* End case */ + template + struct InferSpmdFnCallHelper> { + template + static SpmdInfo Call(const InferSpmdContext& ctx UNUSED, Args&... args) { + return infer_spmd_fn(args...); + } + }; +}; + +class SpmdRule { + public: + explicit SpmdRule(InferSpmdFn forward_fn) + : forward_fn_(forward_fn), backward_fn_(nullptr) {} + + SpmdRule(InferSpmdFn forward_fn, InferSpmdFn backward_fn) + : forward_fn_(forward_fn), backward_fn_(backward_fn) {} + + SpmdInfo InferForward(const InferSpmdContext& ctx) const { + PADDLE_ENFORCE_NE( + forward_fn_, + nullptr, + phi::errors::NotFound("Current SpmdRule's forward function is not " + "found, Please make sure " + "that you have registered the rule correctly.")); + return forward_fn_(ctx); + } + + SpmdInfo InferBackward(const InferSpmdContext& ctx) const { + PADDLE_ENFORCE_NE( + backward_fn_, + nullptr, + phi::errors::NotFound("Current SpmdRule's backward function is not " + "found, Please make sure " + "that you have registered the rule correctly.")); + return backward_fn_(ctx); + } + + private: + InferSpmdFn forward_fn_; + InferSpmdFn backward_fn_; +}; + +// SpmdRuleFactory manage the spmd rules and cache the propagate results +// TODO(chenweihang): Add spmd caching impl later +class SpmdRuleFactory { + public: + static SpmdRuleFactory& Instance(); + + bool ContainsSpmdRule(const std::string& kernel_name) const; + + int InsertSpmdRule(std::string kernel_name, SpmdRule rule); + + const SpmdRule& GetSpmdRule(const std::string& kernel_name) const; + + private: + SpmdRuleFactory() = default; + + paddle::flat_hash_map spmd_rule_map_; + + DISABLE_COPY_AND_ASSIGN(SpmdRuleFactory); +}; + +#define PD_REGISTER_SPMD_RULE(kernel_name, ...) \ + UNUSED static int ___registrar_spmd_rule_for_##kernel_name = \ + ::phi::distributed::SpmdRuleFactory::Instance().InsertSpmdRule( \ + #kernel_name, ::phi::distributed::SpmdRule(__VA_ARGS__)); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/type_defs.h b/paddle/phi/core/distributed/type_defs.h new file mode 100644 index 0000000000000..cd201ac5c5aaf --- /dev/null +++ b/paddle/phi/core/distributed/type_defs.h @@ -0,0 +1,29 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace phi { +namespace distributed { +class TensorDistAttr; + +using SpmdInfo = + std::pair, std::vector>; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 4ef20d9958284..146e0bc4fc662 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -17,8 +17,6 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" - -#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/string_tensor.h" @@ -208,6 +206,9 @@ bool MetaTensor::is_dense() const { return DenseTensor::classof(tensor_); } bool MetaTensor::is_selected_rows() const { return SelectedRows::classof(tensor_); } +bool MetaTensor::is_dist() const { + return distributed::DistTensor::classof(tensor_); +} bool MetaTensor::is_tensor_array() const { return false; } bool MetaTensor::is_same_tensor(const MetaTensor& meta_tensor) const { diff --git a/paddle/phi/core/meta_tensor.h b/paddle/phi/core/meta_tensor.h index 900bfb3eb6b3f..e7ccc1a61c5f2 100644 --- a/paddle/phi/core/meta_tensor.h +++ b/paddle/phi/core/meta_tensor.h @@ -23,7 +23,6 @@ limitations under the License. */ namespace phi { -// TODO(chenweihang): add other flags if needed struct MetaConfig { bool is_runtime{true}; bool is_run_mkldnn_kernel{false}; @@ -82,6 +81,8 @@ class MetaTensor { virtual bool is_selected_rows() const; virtual bool is_dense() const; + virtual bool is_dist() const; + // TODO(YuanRisheng) This API is for compatible with Fluid // and it will be deleted in the future. virtual bool is_tensor_array() const; @@ -97,7 +98,7 @@ class MetaTensor { protected: static void unspecified_bool_true() {} - private: + protected: // Because the lod in compiletime and runtime is different, // so `LoD` cannot in public methods const LoD& lod() const; diff --git a/paddle/phi/infermeta/CMakeLists.txt b/paddle/phi/infermeta/CMakeLists.txt index f53f655b24409..ef68ac8632ce4 100644 --- a/paddle/phi/infermeta/CMakeLists.txt +++ b/paddle/phi/infermeta/CMakeLists.txt @@ -1,5 +1,10 @@ add_subdirectory(strings) add_subdirectory(sparse) + +if(WITH_DISTRIBUTE) + add_subdirectory(spmd_rules) +endif() + collect_srcs( infermeta_srcs SRCS diff --git a/paddle/phi/infermeta/spmd_rules/CMakeLists.txt b/paddle/phi/infermeta/spmd_rules/CMakeLists.txt new file mode 100644 index 0000000000000..c28cd85c718c8 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/CMakeLists.txt @@ -0,0 +1,6 @@ +file( + GLOB spmd_rules_srcs + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "*.cc") + +collect_srcs(infermeta_srcs SRCS ${spmd_rules_srcs}) diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc similarity index 69% rename from paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc rename to paddle/phi/infermeta/spmd_rules/matmul.cc index d280ccec37d7a..088f9ab16363a 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -12,22 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h" +#include "paddle/phi/infermeta/spmd_rules/matmul.h" +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" -namespace paddle { +namespace phi { namespace distributed { -namespace auto_parallel { + using phi::distributed::auto_parallel::str_join; -TensorDistAttr GetInferedDistAttr( +////////////////// Utils Functions ////////////////// + +TensorDistAttr GetMatmulInferedDistAttr( const TensorDistAttr& origin_dist_attr, const std::vector& shape, const std::string& tensor_axis, const std::unordered_map& axis_to_dim_map, - const bool trans_axis) { - TensorDistAttr dist_attr_ = CopyTensorDistAttrForOutput(origin_dist_attr); + bool trans_axis) { + TensorDistAttr dist_attr = CopyTensorDistAttrForOutput(origin_dist_attr); std::vector infered_dims_mapping; infered_dims_mapping.reserve(tensor_axis.size()); @@ -50,8 +57,8 @@ TensorDistAttr GetInferedDistAttr( infered_dims_mapping.end() - 1); } - dist_attr_.set_dims_mapping(infered_dims_mapping); - return dist_attr_; + dist_attr.set_dims_mapping(infered_dims_mapping); + return dist_attr; } void FillMatmulOperandNotation(const int x_ndim, @@ -105,42 +112,35 @@ void FillMatmulOperandNotation(const int x_ndim, } } -std::pair, std::vector> -MatmulSPMDRule::InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - // step0: verify input args based on matmul logic - auto input_specs_size = input_specs.size(); - PADDLE_ENFORCE_EQ( - input_specs_size, - 2, - phi::errors::InvalidArgument( - "The size of InputSpec of matmul should be 2, but got [%d].", - input_specs_size)); - auto x_shape = input_specs[0].shape(); - auto y_shape = input_specs[1].shape(); +////////////////// InferMeta(Contains SPMD) Functions ////////////////// + +SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, + const DistMetaTensor& y, + bool trans_x, + bool trans_y) { + // Step0: verify input args based on matmul logic + auto x_shape = phi::vectorize(x.dims()); + auto y_shape = phi::vectorize(y.dims()); int x_ndim = x_shape.size(); int y_ndim = y_shape.size(); - auto x_dist_attr_src = input_specs[0].dist_attr(); - auto y_dist_attr_src = input_specs[1].dist_attr(); + auto x_dist_attr_src = x.dist_attr(); + auto y_dist_attr_src = y.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); std::vector y_dims_mapping = y_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( x_ndim, x_dims_mapping.size(), - phi::errors::InvalidArgument( - "Mismatch of X's tensor size: [%d] and X's dims_mapping size [%d].", - x_ndim, - x_dims_mapping.size())); + phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's " + "dims_mapping size [%d] are not matched.", + x_ndim, + x_dims_mapping.size())); PADDLE_ENFORCE_EQ( y_ndim, y_dims_mapping.size(), - phi::errors::InvalidArgument( - "Mismatch of Y's tensor size: [%d] and Y's dims_mapping size [%d].", - y_ndim, - y_dims_mapping.size())); - - bool trans_x = ExtractAttr("trans_x", attrs); - bool trans_y = ExtractAttr("trans_y", attrs); + phi::errors::InvalidArgument("The Tensor Y's rank [%d] and Y's " + "dims_mapping size [%d] are not matched.", + y_ndim, + y_dims_mapping.size())); VLOG(6) << "MatmulSPMDRule InferForward Inputs: " << "X shape: [" << str_join(x_shape) << "], x_dims_mapping: [" @@ -151,37 +151,37 @@ MatmulSPMDRule::InferForward(const std::vector& input_specs, << "trans_y: " << "[" << (trans_y ? "true" : "false") << "]; "; - // step1: build Einsum Notation + // Step1: build Einsum Notation std::string x_axes; std::string y_axes; std::string out_axes; FillMatmulOperandNotation(x_ndim, y_ndim, &x_axes, &y_axes, &out_axes); - // step2: Sharding Propogation + // Step2: Sharding Propogation if (trans_x) { - PADDLE_ENFORCE_GE( - x_ndim, - 2, - phi::errors::InvalidArgument("When trans_x is True, the size of X " - "tensor should be 2, but got [%d].", - x_ndim)); + PADDLE_ENFORCE_GE(x_ndim, + 2, + phi::errors::InvalidArgument( + "When trans_x is True, the size of X " + "tensor should be greater than 2, but got [%d].", + x_ndim)); std::iter_swap(x_dims_mapping.end() - 2, x_dims_mapping.end() - 1); } if (trans_y) { - PADDLE_ENFORCE_GE( - y_ndim, - 2, - phi::errors::InvalidArgument("When trans_x is True, the size of X " - "tensor should be 2, but got [%d].", - y_ndim)); + PADDLE_ENFORCE_GE(y_ndim, + 2, + phi::errors::InvalidArgument( + "When trans_y is True, the size of Y " + "tensor should be greater than 2, but got [%d].", + y_ndim)); std::iter_swap(y_dims_mapping.end() - 2, y_dims_mapping.end() - 1); } - // step2.1: Sharding Merge + // Step2.1: Sharding Merge std::pair> x_pair(x_axes, x_dims_mapping); std::pair> y_pair(y_axes, y_dims_mapping); auto axis_to_dim_map = ShardingMergeForTensors({x_pair, y_pair}); - // step2.2: Infer Output's Dims Mapping. + // Step2.2: Infer Output's Dims Mapping. TensorDistAttr output_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); std::vector out_dims_mapping; @@ -191,13 +191,13 @@ MatmulSPMDRule::InferForward(const std::vector& input_specs, } output_dist_attr_dst.set_dims_mapping(out_dims_mapping); - // step2.3: Merge and get Inputs' New Dims Mapping. - TensorDistAttr x_dist_attr_dst = GetInferedDistAttr( + // Step2.3: Merge and get Inputs' New Dims Mapping. + TensorDistAttr x_dist_attr_dst = GetMatmulInferedDistAttr( x_dist_attr_src, x_shape, x_axes, axis_to_dim_map, trans_x); - TensorDistAttr y_dist_attr_dst = GetInferedDistAttr( + TensorDistAttr y_dist_attr_dst = GetMatmulInferedDistAttr( y_dist_attr_src, y_shape, y_axes, axis_to_dim_map, trans_y); - // step2.3: Handle Partial + // Step2.3: Handle Partial // Step2.3.1 Output Partial std::vector partial_on_dims = ResoluteOutputPartialDimension(axis_to_dim_map, out_axes); @@ -221,24 +221,16 @@ MatmulSPMDRule::InferForward(const std::vector& input_specs, return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}}; } -std::pair, std::vector> -MatmulSPMDRule::InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) { - // extra & verify input - auto output_specs_size = output_specs.size(); - PADDLE_ENFORCE_EQ( - output_specs_size, - 1, - phi::errors::InvalidArgument( - "The size of OutputSpec of matmul should be 1, but got [%d].", - output_specs_size)); - - auto out_shape = output_specs[0].shape(); +SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, + bool trans_x, + bool trans_y) { + auto out_shape = phi::vectorize(out.dims()); int out_ndim = out_shape.size(); - auto x_shape = input_specs[0].shape(); - auto y_shape = input_specs[1].shape(); + auto x_shape = phi::vectorize(x.dims()); + auto y_shape = phi::vectorize(y.dims()); int x_ndim = x_shape.size(); int y_ndim = y_shape.size(); int max_ndim = std::max(x_ndim, y_ndim); @@ -250,10 +242,7 @@ MatmulSPMDRule::InferBackward(const std::vector& input_specs, max_ndim, out_ndim)); - bool trans_x = ExtractAttr("trans_x", attrs); - bool trans_y = ExtractAttr("trans_y", attrs); - - auto out_dist_attr_src = output_specs[0].dist_attr(); + auto out_dist_attr_src = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); // step1: build Einsum Notation @@ -267,10 +256,10 @@ MatmulSPMDRule::InferBackward(const std::vector& input_specs, auto axis_to_dim_map = ShardingMergeForTensors({{out_axes, out_dims_mapping}}, false); - TensorDistAttr x_dist_attr_dst = GetInferedDistAttr( - input_specs[0].dist_attr(), x_shape, x_axes, axis_to_dim_map, trans_x); - TensorDistAttr y_dist_attr_dst = GetInferedDistAttr( - input_specs[1].dist_attr(), y_shape, y_axes, axis_to_dim_map, trans_y); + TensorDistAttr x_dist_attr_dst = GetMatmulInferedDistAttr( + x.dist_attr(), x_shape, x_axes, axis_to_dim_map, trans_x); + TensorDistAttr y_dist_attr_dst = GetMatmulInferedDistAttr( + y.dist_attr(), y_shape, y_axes, axis_to_dim_map, trans_y); // step3: Handle Partial // NOTE we skip the partial backward inference in Partial Stage-I. @@ -289,6 +278,5 @@ MatmulSPMDRule::InferBackward(const std::vector& input_specs, return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr_src}}; } -} // namespace auto_parallel } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/matmul.h b/paddle/phi/infermeta/spmd_rules/matmul.h new file mode 100644 index 0000000000000..64cfba26a7445 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/matmul.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, + const DistMetaTensor& y, + bool trans_x, + bool trans_y); + +SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, + bool trans_x, + bool trans_y); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h new file mode 100644 index 0000000000000..ad519ff287a33 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" + +#include "paddle/phi/infermeta/spmd_rules/matmul.h" + +/** + * Design Notes: + * + * 1. SPMD info is the special meta info of DistTensor, so we put Spmd infer + * functions in `infermeta` directory. + * + * 2. Since the infer functions of Spmd forward and backward are closely related + * and need to be registered together, we manage them together in one file. + * + * 3. SPMD rules are much smaller than infermeta function, and we manage files + * in operator units. + * + * 4. The previous registration used some compile-time regular matching methods, + * which was less flexible, and the registration of SPMD rules here is declare + * directly in the header file + */ + +namespace phi { +namespace distributed { + +// matmul rule +PD_REGISTER_SPMD_RULE(matmul, + PD_INFER_SPMD(phi::distributed::MatmulSpmdInferForward), + PD_INFER_SPMD(phi::distributed::MatmulSpmdInferBackward)); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc new file mode 100644 index 0000000000000..2252de98a78b3 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -0,0 +1,159 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +std::string GetBroadcastAxes(const int64_t& tenosr_ndim, + const int64_t& broadcast_ndim, + const std::string& alphabet) { + PADDLE_ENFORCE_GE( + alphabet.size(), + broadcast_ndim, + phi::errors::InvalidArgument( + "The size of alphabet [%d] is less than broadcast ndim [%d]", + alphabet.size(), + broadcast_ndim)); + PADDLE_ENFORCE_GE(broadcast_ndim, + tenosr_ndim, + phi::errors::InvalidArgument( + "The broadcast ndim [%d] is less than tenosr ndim [%d]", + broadcast_ndim, + tenosr_ndim)); + if (tenosr_ndim <= 0) { + return std::string(); + } + return alphabet.substr(broadcast_ndim - tenosr_ndim, tenosr_ndim); +} + +// Rule1: A repicated dimension could be merged by any sharded dimension. +// Rule2: A tensor axis could at most be sharded by one mesh dimension. +// (TODO trigger heuristics cost model and reshard to handle axis sharded by +// multiple dimension case.) +int64_t ShardingMergeForAxis(const std::string& axis, + const int64_t& mesh_dim1, + const int64_t& mesh_dim2) { + if (mesh_dim1 != mesh_dim2) { + if (mesh_dim1 == -1) { + return mesh_dim2; + } else if (mesh_dim2 == -1) { + return mesh_dim1; + } else { + // (TODO) local cost model here. + PADDLE_THROW( + phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two " + "different mesh dimension [%d] and [%d].", + axis, + mesh_dim1, + mesh_dim2)); + } + + } else { + return mesh_dim1; + } +} + +std::unordered_map ShardingMergeForTensors( + const std::vector>>& + tensor_axes_to_dim_pairs, + const bool merge_conflicts) { + std::unordered_map axis_to_dim_map; + std::unordered_map dim_to_axis_map; + int64_t merge_dim; + + for (auto& pair : tensor_axes_to_dim_pairs) { + for (size_t i = 0; i < pair.second.size(); ++i) { + auto tensor_axis = pair.first.substr(i, 1); + auto mesh_dim = pair.second[i]; + + if (axis_to_dim_map.count(tensor_axis) == 0) { + merge_dim = mesh_dim; + } else { + merge_dim = ShardingMergeForAxis( + tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]); + } + axis_to_dim_map[tensor_axis] = merge_dim; + if (merge_dim != -1) { + if (dim_to_axis_map.count(merge_dim) == 0) { + dim_to_axis_map.insert({merge_dim, tensor_axis}); + } else if (dim_to_axis_map[merge_dim].find(tensor_axis) == + std::string::npos) { + dim_to_axis_map[merge_dim] += tensor_axis; + } + } + } + } + + // Resolute "mesh_dim shard by more than one axis" confict. + // Now we just naive pick the first axis naively. + // (TODO) use local cost model to pick the axis with lowest cost(in concern of + // memory or communication or computation). + for (auto& it : dim_to_axis_map) { + if (it.second.size() > 1) { + if (merge_conflicts) { + VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first + << "] are Sharding Multiple Tensor Axis: [" << it.second + << "]. The Axis: [" << it.second[0] << "] is Picked."; + for (size_t i = 1; i < it.second.size(); ++i) { + axis_to_dim_map[it.second.substr(i, 1)] = -1; + } + } else { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "Multiple Tensor Axes [%s] is sharded by same mesh dimension [%d].", + str_join(it.second), + it.first)); + } + } + } + + return axis_to_dim_map; +} + +TensorDistAttr CopyTensorDistAttrForOutput( + const TensorDistAttr& src_dist_attr) { + TensorDistAttr new_dist_attr = TensorDistAttr(); + new_dist_attr.set_process_mesh(src_dist_attr.process_mesh()); + new_dist_attr.set_batch_dim(src_dist_attr.batch_dim()); + new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims()); + // new_dist_attr.set_annotated(false); TODO unset field is false by default. + return new_dist_attr; +} + +std::vector ResoluteOutputPartialDimension( + const std::unordered_map& axis_to_dim_map, + const std::string& tensor_axes) { + std::vector partial_on_dims; + + for (auto& it : axis_to_dim_map) { + if (tensor_axes.find(it.first) == std::string::npos) { + if (it.second > -1) { + partial_on_dims.push_back(it.second); + } + } + } + return partial_on_dims; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h new file mode 100644 index 0000000000000..5e3c3a3d0961c --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +namespace phi { +namespace distributed { +class TensorDistAttr; + +// Generate the axis notation of tensor for the einsum notation of a broadcast +// operation(alignment star from the rightmost axis). tenosr_ndim: the size of +// the tensor. broadcast_ndim: the maxium size of tensors in this broadcast +// operation. alphabet: the characters used to represent the axes of tensor. +// length of alphabet should >= broadcast_ndim. +std::string GetBroadcastAxes(const int64_t& tenosr_ndim, + const int64_t& broadcast_ndim, + const std::string& alphabet); + +// Merge the sharding specification (dims mapping) for one tensor Axis. +// Rule1: A repicated dimension could be merged by any sharded dimension. +// Rule2: A tensor axis could at most be sharded by one mesh dimension. +// (TODO trigger heuristics cost model and reshard to handle axis sharded by +// multiple dimension case.) +int64_t ShardingMergeForAxis(const std::string& axis, + const int64_t& mesh_dim1, + const int64_t& mesh_dim2); + +// Merge sharding specification (dims mapping) of given tensors. +// The same axes of different tensors will be merged. +std::unordered_map ShardingMergeForTensors( + const std::vector>>& + tensor_axes_to_dim_pairs, + const bool merge_conflicts = true); + +// Intend to use for generating the TensorDistAttr of output based on the input +// activation TensorDistAttr. The process_mesh, batch_dim, dynamic_dim are +// copied with annotated is forced to False, and dims_mapping is leave to be +// null. +TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr); + +// Resolute the partial mesh dimension of a output tensor, giving the +// merged sharding specifcation of input tensors and the axis names of output +// tensor. Input are +std::vector ResoluteOutputPartialDimension( + const std::unordered_map& axis_to_dim_map, + const std::string& tensor_axes); + +} // namespace distributed +} // namespace phi diff --git a/test/auto_parallel/spmd_rules/test_matmul_rule.py b/test/auto_parallel/spmd_rules/test_matmul_rule.py index 59e47113302db..1cf2f49860b33 100644 --- a/test/auto_parallel/spmd_rules/test_matmul_rule.py +++ b/test/auto_parallel/spmd_rules/test_matmul_rule.py @@ -13,23 +13,23 @@ # limitations under the License. import unittest +from collections import OrderedDict -from paddle.distributed.auto_parallel.static.completion import get_spmd_rule from paddle.distributed.auto_parallel.static.dist_attribute import ( DistTensorSpec, TensorDistAttr, ) from paddle.distributed.fleet import auto +from paddle.framework import core class TestMatmulSPMDRule(unittest.TestCase): def setUp(self): - self.rule = get_spmd_rule("matmul") + # After replaced all spmd rules by phi impl, we can recover the + # api name to `get_spmd_rule` + self.rule = core.get_phi_spmd_rule("matmul") - self.attrs = { - 'trans_x': False, - 'trans_y': False, - } + self.attrs = OrderedDict([('trans_x', False), ('trans_y', False)]) def test_matmul_infer_forward(self): # forward setup @@ -49,7 +49,8 @@ def test_matmul_infer_forward(self): # TODO test partial: mk[1, 0],kn[0, -1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0] result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -68,7 +69,8 @@ def test_matmul_infer_forward(self): self.x_dist_tensor_spec.set_dims_mapping([1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -82,7 +84,8 @@ def test_matmul_infer_forward(self): self.x_dist_tensor_spec.set_dims_mapping([1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -95,7 +98,8 @@ def test_matmul_infer_forward(self): self.x_dist_tensor_spec.set_dims_mapping([-1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, 0]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -108,7 +112,8 @@ def test_matmul_infer_forward(self): self.x_dist_tensor_spec.set_dims_mapping([1, 0]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -122,7 +127,8 @@ def test_matmul_infer_forward(self): self.x_dist_tensor_spec.set_dims_mapping([-1, -1]) self.y_dist_tensor_spec.set_dims_mapping([1, 0]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -137,7 +143,8 @@ def test_matmul_infer_forward(self): self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -154,7 +161,8 @@ def test_matmul_infer_forward(self): self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0]) self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -173,7 +181,8 @@ def test_matmul_infer_forward(self): self.y_dist_tensor_spec.set_dims_mapping([-1, -1]) self.attrs['trans_x'] = True result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -192,7 +201,8 @@ def test_matmul_infer_forward(self): self.attrs['trans_x'] = False self.attrs['trans_y'] = True result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -215,7 +225,8 @@ def test_matmul_infer_forward(self): self.attrs['trans_x'] = True self.attrs['trans_y'] = True result_dist_attrs = self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -239,7 +250,8 @@ def test_matmul_infer_forward(self): self.attrs['trans_y'] = True with self.assertRaises(NotImplementedError): self.rule.infer_forward( - [self.x_dist_tensor_spec, self.y_dist_tensor_spec], self.attrs + [self.x_dist_tensor_spec, self.y_dist_tensor_spec], + list(self.attrs.values()), ) def test_matmul_infer_backward(self): @@ -270,7 +282,7 @@ def test_matmul_infer_backward(self): result_dist_attrs = self.rule.infer_backward( [self.x_dist_tensor_spec, self.y_dist_tensor_spec], [self.out_dist_tensor_spec], - self.attrs, + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -307,7 +319,7 @@ def test_matmul_infer_backward(self): result_dist_attrs = self.rule.infer_backward( [self.x_dist_tensor_spec, self.y_dist_tensor_spec], [self.out_dist_tensor_spec], - self.attrs, + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -329,7 +341,7 @@ def test_matmul_infer_backward(self): result_dist_attrs = self.rule.infer_backward( [self.x_dist_tensor_spec, self.y_dist_tensor_spec], [self.out_dist_tensor_spec], - self.attrs, + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -354,7 +366,7 @@ def test_matmul_infer_backward(self): result_dist_attrs = self.rule.infer_backward( [self.x_dist_tensor_spec, self.y_dist_tensor_spec], [self.out_dist_tensor_spec], - self.attrs, + list(self.attrs.values()), ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -376,7 +388,7 @@ def test_matmul_infer_backward(self): self.rule.infer_backward( [self.x_dist_tensor_spec, self.y_dist_tensor_spec], [self.out_dist_tensor_spec], - self.attrs, + list(self.attrs.values()), ) diff --git a/test/auto_parallel/test_dist_tensor.py b/test/auto_parallel/test_dist_tensor.py index 45aa8c9fbcaae..0bf2d88db4237 100644 --- a/test/auto_parallel/test_dist_tensor.py +++ b/test/auto_parallel/test_dist_tensor.py @@ -83,6 +83,21 @@ def test_relu_api_for_dist_tensor(self): dist_out.backward() self.check_tensor_eq(local_in.grad, dist_in.grad) + def test_matmul_api_for_dist_tensor(self): + x = np.random.random(size=[4, 4]).astype("float32") + y = np.random.random(size=[4, 4]).astype("float32") + local_x, dist_x = self.create_local_and_dist_tensor_pair(x) + local_y, dist_y = self.create_local_and_dist_tensor_pair(y) + local_out = paddle.matmul(local_x, local_y) + dist_out = paddle.matmul(dist_x, dist_y) + self.check_tensor_eq(local_out, dist_out) + + # test backward + local_out.backward() + dist_out.backward() + self.check_tensor_eq(local_x.grad, dist_x.grad) + self.check_tensor_eq(local_y.grad, dist_y.grad) + if __name__ == "__main__": unittest.main() diff --git a/test/cpp/auto_parallel/CMakeLists.txt b/test/cpp/auto_parallel/CMakeLists.txt index c5912a6fa1021..ae7300bf62f08 100644 --- a/test/cpp/auto_parallel/CMakeLists.txt +++ b/test/cpp/auto_parallel/CMakeLists.txt @@ -9,8 +9,7 @@ if(WITH_DISTRIBUTE) dist_tensor_test SRCS dist_tensor_test.cc DEPS phi) + cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rules) endif() cc_test_old(dist_mapper_test SRCS dist_mapper_test.cc DEPS phi) - -cc_test_old(spmd_rule_test SRCS spmd_rule_test.cc DEPS spmd_rules) diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index dfd8394faa16a..30907b707aa9e 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -14,12 +14,16 @@ limitations under the License. */ #include #include + +#include "glog/logging.h" #include "gtest/gtest.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" namespace paddle { namespace distributed { @@ -45,22 +49,20 @@ TEST(MatmulSPMDRule, Ctor) { y_dist_attr.set_dims_mapping(std::vector({-1, -1})); y_dist_attr.set_dynamic_dims(std::vector({false, false})); - DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr); - DistTensorSpec y_dist_tensor_spec = DistTensorSpec(y_shape, y_dist_attr); + size_t input_size = 2; + size_t output_size = 1; - paddle::framework::AttributeMap attrs; - attrs["trans_x"] = false; - attrs["trans_y"] = false; + phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr); + phi::distributed::DistMetaTensor y(phi::make_ddim(y_shape), y_dist_attr); - SPMDRuleBase* matmul_rule = SPMDRuleMap::Instance().Get("matmul"); + auto matmul_spmd_rule = + phi::distributed::SpmdRuleFactory::Instance().GetSpmdRule("matmul"); // mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[] - std::pair, std::vector> - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + phi::distributed::InferSpmdContext ctx( + {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); + auto infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - size_t input_size = 2; - size_t output_size = 1; EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); @@ -74,10 +76,13 @@ TEST(MatmulSPMDRule, Ctor) { VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; // mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[] - x_dist_tensor_spec.set_dims_mapping({-1, -1}); - y_dist_tensor_spec.set_dims_mapping({-1, 0}); - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({-1, -1}); + y_dist_attr.set_dims_mapping({-1, 0}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({-1, -1})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -88,10 +93,13 @@ TEST(MatmulSPMDRule, Ctor) { VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; // mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done - x_dist_tensor_spec.set_dims_mapping({1, 0}); - y_dist_tensor_spec.set_dims_mapping({-1, -1}); - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({1, 0}); + y_dist_attr.set_dims_mapping({-1, -1}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({1, 0})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -104,10 +112,13 @@ TEST(MatmulSPMDRule, Ctor) { VLOG(4) << "test3 done." << std::endl << std::endl << std::endl; // mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done - x_dist_tensor_spec.set_dims_mapping({-1, -1}); - y_dist_tensor_spec.set_dims_mapping({1, 0}); - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({-1, -1}); + y_dist_attr.set_dims_mapping({1, 0}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({-1, 1})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -121,11 +132,14 @@ TEST(MatmulSPMDRule, Ctor) { // abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = // abcmn[1, 0, -1, -1] partial[]: done - x_dist_tensor_spec.set_shape({512, 48, 64, 32}); - x_dist_tensor_spec.set_dims_mapping({0, 1, -1, -1}); - y_dist_tensor_spec.set_dims_mapping({-1, -1}); - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_shape = {512, 48, 64, 32}; + x_dist_attr.set_dims_mapping({0, 1, -1, -1}); + y_dist_attr.set_dims_mapping({-1, -1}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({0, 1, -1, -1})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -137,10 +151,13 @@ TEST(MatmulSPMDRule, Ctor) { // abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1, // -1, -1, -1] partial[0]: done - x_dist_tensor_spec.set_dims_mapping({1, -1, -1, 0}); - y_dist_tensor_spec.set_dims_mapping({-1, -1}); - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({1, -1, -1, 0}); + y_dist_attr.set_dims_mapping({-1, -1}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({1, -1, -1, 0})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -154,11 +171,13 @@ TEST(MatmulSPMDRule, Ctor) { // abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = // abcmn[1, -1, 0, -1] partial[]: done - x_dist_tensor_spec.set_dims_mapping({1, -1, -1, 0}); - y_dist_tensor_spec.set_dims_mapping({-1, -1}); - attrs["trans_x"] = true; - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({1, -1, -1, 0}); + y_dist_attr.set_dims_mapping({-1, -1}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/true, /*trans_x=*/false}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({1, -1, -1, 0})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -170,12 +189,13 @@ TEST(MatmulSPMDRule, Ctor) { // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = // abcmn[-1, -1, -1, 1] partial[0]: done - x_dist_tensor_spec.set_dims_mapping({-1, -1, -1, -1}); - y_dist_tensor_spec.set_dims_mapping({1, 0}); - attrs["trans_x"] = false; - attrs["trans_y"] = true; - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({-1, -1, -1, -1}); + y_dist_attr.set_dims_mapping({1, 0}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/false, /*trans_x=*/true}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({-1, -1, -1, 0})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -191,12 +211,13 @@ TEST(MatmulSPMDRule, Ctor) { // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = // abcmn[-1, -1, -1, 1] partial[0]: done - x_dist_tensor_spec.set_dims_mapping({-1, -1, 0, 1}); - y_dist_tensor_spec.set_dims_mapping({1, 0}); - attrs["trans_y"] = true; - attrs["trans_x"] = true; - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({-1, -1, 0, 1}); + y_dist_attr.set_dims_mapping({1, 0}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), std::vector({-1, -1, 0, 1})); EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), @@ -214,23 +235,25 @@ TEST(MatmulSPMDRule, Ctor) { // abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = // abcmn[-1, -1, -1, 1] partial[0]: done - x_dist_tensor_spec.set_dims_mapping({-1, -1, 1, 0}); - y_dist_tensor_spec.set_dims_mapping({1, 0}); - attrs["trans_y"] = true; - attrs["trans_x"] = true; - EXPECT_ANY_THROW(infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs)); + x_dist_attr.set_dims_mapping({-1, -1, 1, 0}); + y_dist_attr.set_dims_mapping({1, 0}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); + EXPECT_ANY_THROW(infered_dist_attrs = matmul_spmd_rule.InferForward(ctx)); // Error VLOG(4) << "test10 done." << std::endl << std::endl << std::endl; // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = // abcmn[-1, -1, -1, 1] partial[0]: - x_dist_tensor_spec.set_dims_mapping({-1, -1, 0, 1}); - y_dist_tensor_spec.set_dims_mapping({1, 0}); - attrs["trans_y"] = true; - attrs["trans_x"] = true; - infered_dist_attrs = matmul_rule->InferForward( - {x_dist_tensor_spec, y_dist_tensor_spec}, attrs); + x_dist_attr.set_dims_mapping({-1, -1, 0, 1}); + y_dist_attr.set_dims_mapping({1, 0}); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); + ctx = phi::distributed::InferSpmdContext( + {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); + infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); EXPECT_ANY_THROW(infered_dist_attrs.second[0].clean_partial_dims( std::vector({1}))); infered_dist_attrs.second[0].set_partial_status(std::vector({1})); @@ -242,7 +265,6 @@ TEST(MatmulSPMDRule, Ctor) { std::set({0})); infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); - VLOG(4) << "test11 done." << std::endl << std::endl << std::endl; } @@ -372,25 +394,21 @@ TEST(MatmulSPMDRuleInferBackward, Ctor) { out_dist_attr.set_dynamic_dims(std::vector({false, false})); out_dist_attr.set_partial_status(std::vector({0})); - DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr); - DistTensorSpec y_dist_tensor_spec = DistTensorSpec(y_shape, y_dist_attr); - DistTensorSpec out_dist_tensor_spec = - DistTensorSpec(out_shape, out_dist_attr); - - paddle::framework::AttributeMap attrs; - attrs["trans_x"] = false; - attrs["trans_y"] = false; + phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr); + phi::distributed::DistMetaTensor y(phi::make_ddim(y_shape), y_dist_attr); + phi::distributed::DistMetaTensor out(phi::make_ddim(out_shape), + out_dist_attr); - SPMDRuleBase* matmul_rule = SPMDRuleMap::Instance().Get("matmul"); + auto matmul_spmd_rule = + phi::distributed::SpmdRuleFactory::Instance().GetSpmdRule("matmul"); // TODO(zyc) update in future: propogate the partial in inferbackward // abmn[-1, -1, 1, -1] + partial[0] --> abmk[-1, -1, 1, -1], a1kn[-1, -1, -1, // -1] + phi::distributed::InferSpmdContext ctx( + {x, y, out}, {/*trans_x=*/false, /*trans_x=*/false}); std::pair, std::vector> - infered_dist_attrs = - matmul_rule->InferBackward({x_dist_tensor_spec, y_dist_tensor_spec}, - {out_dist_tensor_spec}, - attrs); + infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); size_t input_size = 2; size_t output_size = 1;