diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt index 2f8949a74bca0..63fbe9ecd677c 100644 --- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt @@ -9,6 +9,7 @@ collect_srcs( dist_mapper.cc reshard_utils.cc dist_tensor.cc + dist_meta_tensor.cc inferspmd_utils.cc reshard_function.cc reshard_split_functor.cc diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc new file mode 100644 index 0000000000000..dc5d6c20e62b3 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" + +namespace phi { +namespace distributed { + +phi::DDim DistMetaTensor::dims() const { + // member values in tensor_ have higher priority than those in DistMetaTensor + if (tensor_ != nullptr) { + PADDLE_ENFORCE_EQ(this->is_dist(), + true, + phi::errors::InvalidArgument( + "The current MetaTensor doesn't contains " + "DistTensor when call `dist_attr` method.")); + return MetaTensor::dims(); + } else { + return dims_; + } +} + +const distributed::TensorDistAttr& DistMetaTensor::dist_attr() const { + // member values in tensor_ have higher priority than those in DistMetaTensor + if (tensor_ != nullptr) { + PADDLE_ENFORCE_EQ(this->is_dist(), + true, + phi::errors::InvalidArgument( + "The current MetaTensor doesn't contains " + "DistTensor when call `dist_attr` method.")); + return static_cast<phi::distributed::DistTensor*>(tensor_)->dist_attr(); + } else { + return dist_attr_; + } +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h new file mode 100644 index 0000000000000..efbf38d28f9f0 --- /dev/null +++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/meta_tensor.h" + +namespace phi { +namespace distributed { + +class DistMetaTensor : public MetaTensor { + public: + // supporting implicit construction is easier to use + DistMetaTensor(TensorBase* tensor) // NOLINT + : MetaTensor(tensor) {} + DistMetaTensor(const TensorBase& tensor) // NOLINT + : MetaTensor(tensor) {} + DistMetaTensor(const TensorBase* tensor) // NOLINT + : MetaTensor(tensor) {} + DistMetaTensor(TensorBase& tensor) // NOLINT + : MetaTensor(tensor) {} + // For static mode only + DistMetaTensor(const phi::DDim& dims, const TensorDistAttr& dist_attr) + : dims_(dims), dist_attr_(dist_attr) {} + + DistMetaTensor(DistMetaTensor&&) = default; + DistMetaTensor& operator=(DistMetaTensor&&) = default; + DistMetaTensor(const DistMetaTensor&) = default; + DistMetaTensor& operator=(const DistMetaTensor&) = default; + + virtual ~DistMetaTensor() = default; + + DDim dims() const override; + + const distributed::TensorDistAttr& dist_attr() const; + + private: + /** + * Note: When using the semi-automatic parallel segmentation derivation rules + * of the static graph, in order to facilitate the packaging of the input + * parameters of the construction, the DistMetaTensor is inherited and + * encapsulated, and the class members dims_ and dist_attr_ are added to it. + * + * The information contained in these two members is also in the tensor of the + * meta_tensor of the base class, and there is redundancy. + * + * We need to pay attention when using it to ensure the consistency. + * These two members are read-only, and their values cannot be changed + * after construction. 
To change their values, they need to be set + * directly in tensor_*/ + phi::DDim dims_; + TensorDistAttr dist_attr_; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc index 9def309efcaf1..3b94dc017e5e7 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc @@ -17,7 +17,7 @@ limitations under the License. */ namespace phi { namespace distributed { -void InferSpmdContext::EmplaceBackInput(MetaTensor input) { +void InferSpmdContext::EmplaceBackInput(DistMetaTensor input) { inputs_.emplace_back(std::move(input)); } @@ -25,7 +25,7 @@ void InferSpmdContext::EmplaceBackAttr(Attribute attr) { attrs_.emplace_back(std::move(attr)); } -const MetaTensor& InferSpmdContext::InputAt(size_t idx) const { +const DistMetaTensor& InferSpmdContext::InputAt(size_t idx) const { return inputs_.at(idx); } diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h index 55388b8db69a9..bccee2bf5981a 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h @@ -23,10 +23,10 @@ limitations under the License. 
*/ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/attribute.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/macros.h" -#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/type_defs.h" #include "paddle/utils/any.h" #include "paddle/utils/flat_hash_map.h" @@ -39,14 +39,14 @@ class InferSpmdContext { public: InferSpmdContext() = default; InferSpmdContext( - paddle::small_vector<MetaTensor, phi::kInputSmallVectorSize> inputs, + paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs, paddle::small_vector<Attribute, phi::kAttrSmallVectorSize> attrs) : inputs_(std::move(inputs)), attrs_(std::move(attrs)) {} - void EmplaceBackInput(MetaTensor input); + void EmplaceBackInput(DistMetaTensor input); void EmplaceBackAttr(Attribute attr); - const MetaTensor& InputAt(size_t idx) const; + const DistMetaTensor& InputAt(size_t idx) const; template <typename AttrType> AttrType AttrAt(size_t idx) const; @@ -55,7 +55,7 @@ class InferSpmdContext { private: // Now we only need `inputs`, for backward, the `output` is passed as input - paddle::small_vector<MetaTensor, phi::kInputSmallVectorSize> inputs_; + paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs_; // Because the attribute arguments of dygraph do not have `attr name`, // so we use vector instead of map paddle::small_vector<Attribute, phi::kAttrSmallVectorSize> attrs_; @@ -86,12 +86,12 @@ struct InferSpmdFnImpl { // TODO(chenweihang): support other input type later as needed template <typename... Tail> - struct InferSpmdFnCallHelper<const MetaTensor&, Tail...> { + struct InferSpmdFnCallHelper<const DistMetaTensor&, Tail...> { template <int in_idx, int attr_idx, typename... PreviousArgs> static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&... 
pargs) { static_assert(attr_idx == 0, "InferSpmd's Input should appear before Attributes."); - const MetaTensor& arg = ctx.InputAt(in_idx); + const DistMetaTensor& arg = ctx.InputAt(in_idx); return InferSpmdFnCallHelper<Tail...>::template Call<in_idx + 1, attr_idx>( ctx, pargs..., arg); diff --git a/paddle/phi/core/meta_tensor.cc b/paddle/phi/core/meta_tensor.cc index 052550416dbd2..146e0bc4fc662 100644 --- a/paddle/phi/core/meta_tensor.cc +++ b/paddle/phi/core/meta_tensor.cc @@ -17,8 +17,6 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" -#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/string_tensor.h" @@ -278,15 +276,4 @@ const LoD& MetaTensor::lod() const { } } -/////////////// For Auto Parallel //////////////// - -const distributed::TensorDistAttr& MetaTensor::dist_attr() const { - PADDLE_ENFORCE_EQ( - this->is_dist(), - true, - phi::errors::InvalidArgument("The current MetaTensor doesn't contains " - "DistTensor when call `dist_attr` method.")); - return static_cast<distributed::DistTensor*>(tensor_)->dist_attr(); -} - } // namespace phi diff --git a/paddle/phi/core/meta_tensor.h b/paddle/phi/core/meta_tensor.h index 8d66bea884418..e7ccc1a61c5f2 100644 --- a/paddle/phi/core/meta_tensor.h +++ b/paddle/phi/core/meta_tensor.h @@ -22,9 +22,6 @@ limitations under the License. */ #include "paddle/phi/core/tensor_meta.h" namespace phi { -namespace distributed { -class TensorDistAttr; -} // namespace distributed struct MetaConfig { bool is_runtime{true}; @@ -92,9 +89,6 @@ class MetaTensor { virtual bool is_same_tensor(const MetaTensor& meta_tensor) const; - // For auto parallel - const distributed::TensorDistAttr& dist_attr() const; - virtual operator unspecified_bool_type() const { return tensor_ == nullptr ? 
0 : unspecified_bool_true; } @@ -104,7 +98,7 @@ class MetaTensor { protected: static void unspecified_bool_true() {} - private: + protected: // Because the lod in compiletime and runtime is different, // so `LoD` cannot in public methods const LoD& lod() const; diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc index c9d01a5ec9ce1..088f9ab16363a 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -114,8 +114,8 @@ void FillMatmulOperandNotation(const int x_ndim, ////////////////// InferMeta(Contains SPMD) Functions ////////////////// -SpmdInfo MatmulSpmdInferForward(const MetaTensor& x, - const MetaTensor& y, +SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, + const DistMetaTensor& y, bool trans_x, bool trans_y) { // Step0: verify input args based on matmul logic @@ -221,9 +221,9 @@ SpmdInfo MatmulSpmdInferForward(const MetaTensor& x, return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}}; } -SpmdInfo MatmulSpmdInferBackward(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out, +SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, bool trans_x, bool trans_y) { auto out_shape = phi::vectorize(out.dims()); diff --git a/paddle/phi/infermeta/spmd_rules/matmul.h b/paddle/phi/infermeta/spmd_rules/matmul.h index 1110f85a25bba..64cfba26a7445 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.h +++ b/paddle/phi/infermeta/spmd_rules/matmul.h @@ -16,20 +16,20 @@ limitations under the License. 
*/ #include <vector> +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/type_defs.h" -#include "paddle/phi/infermeta/binary.h" namespace phi { namespace distributed { -SpmdInfo MatmulSpmdInferForward(const MetaTensor& x, - const MetaTensor& y, +SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x, + const DistMetaTensor& y, bool trans_x, bool trans_y); -SpmdInfo MatmulSpmdInferBackward(const MetaTensor& x, - const MetaTensor& y, - const MetaTensor& out, +SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out, bool trans_x, bool trans_y); diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index fee69fd3d6d95..30907b707aa9e 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" -#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" #include "paddle/phi/infermeta/spmd_rules/rules.h" @@ -53,12 +52,8 @@ TEST(MatmulSPMDRule, Ctor) { size_t input_size = 2; size_t output_size = 1; - auto dist_x = - phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - auto dist_y = - phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - phi::MetaTensor x(dist_x); - phi::MetaTensor y(dist_y); + phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr); + phi::distributed::DistMetaTensor y(phi::make_ddim(y_shape), y_dist_attr); auto matmul_spmd_rule = phi::distributed::SpmdRuleFactory::Instance().GetSpmdRule("matmul"); @@ -83,10 +78,8 
@@ TEST(MatmulSPMDRule, Ctor) { // mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[] x_dist_attr.set_dims_mapping({-1, -1}); y_dist_attr.set_dims_mapping({-1, 0}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -102,10 +95,8 @@ TEST(MatmulSPMDRule, Ctor) { // mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done x_dist_attr.set_dims_mapping({1, 0}); y_dist_attr.set_dims_mapping({-1, -1}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -123,10 +114,8 @@ TEST(MatmulSPMDRule, Ctor) { // mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done x_dist_attr.set_dims_mapping({-1, -1}); y_dist_attr.set_dims_mapping({1, 0}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), 
y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -146,10 +135,8 @@ TEST(MatmulSPMDRule, Ctor) { x_shape = {512, 48, 64, 32}; x_dist_attr.set_dims_mapping({0, 1, -1, -1}); y_dist_attr.set_dims_mapping({-1, -1}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -166,10 +153,8 @@ TEST(MatmulSPMDRule, Ctor) { // -1, -1, -1] partial[0]: done x_dist_attr.set_dims_mapping({1, -1, -1, 0}); y_dist_attr.set_dims_mapping({-1, -1}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -188,10 +173,8 @@ TEST(MatmulSPMDRule, Ctor) { // abcmn[1, -1, 0, -1] partial[]: done x_dist_attr.set_dims_mapping({1, -1, -1, 0}); y_dist_attr.set_dims_mapping({-1, -1}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = 
phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -208,10 +191,8 @@ TEST(MatmulSPMDRule, Ctor) { // abcmn[-1, -1, -1, 1] partial[0]: done x_dist_attr.set_dims_mapping({-1, -1, -1, -1}); y_dist_attr.set_dims_mapping({1, 0}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -232,10 +213,8 @@ TEST(MatmulSPMDRule, Ctor) { // abcmn[-1, -1, -1, 1] partial[0]: done x_dist_attr.set_dims_mapping({-1, -1, 0, 1}); y_dist_attr.set_dims_mapping({1, 0}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -258,10 +237,8 @@ TEST(MatmulSPMDRule, Ctor) { // abcmn[-1, -1, -1, 1] partial[0]: done x_dist_attr.set_dims_mapping({-1, -1, 1, 0}); y_dist_attr.set_dims_mapping({1, 0}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = 
phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); EXPECT_ANY_THROW(infered_dist_attrs = matmul_spmd_rule.InferForward(ctx)); @@ -272,10 +249,8 @@ TEST(MatmulSPMDRule, Ctor) { // abcmn[-1, -1, -1, 1] partial[0]: x_dist_attr.set_dims_mapping({-1, -1, 0, 1}); y_dist_attr.set_dims_mapping({1, 0}); - dist_x = phi::distributed::DistTensor(phi::make_ddim(x_shape), x_dist_attr); - dist_y = phi::distributed::DistTensor(phi::make_ddim(y_shape), y_dist_attr); - x = phi::MetaTensor(dist_x); - y = phi::MetaTensor(dist_y); + x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr); + y = phi::distributed::DistMetaTensor(phi::make_ddim(y_shape), y_dist_attr); ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); @@ -290,7 +265,6 @@ TEST(MatmulSPMDRule, Ctor) { std::set<int64_t>({0})); infered_dist_attrs.second[0].clean_partial_dims(std::vector<int64_t>({0})); EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); - VLOG(4) << "test11 done." 
<< std::endl << std::endl << std::endl; } @@ -420,25 +394,21 @@ TEST(MatmulSPMDRuleInferBackward, Ctor) { out_dist_attr.set_dynamic_dims(std::vector<bool>({false, false})); out_dist_attr.set_partial_status(std::vector<int64_t>({0})); - DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr); - DistTensorSpec y_dist_tensor_spec = DistTensorSpec(y_shape, y_dist_attr); - DistTensorSpec out_dist_tensor_spec = - DistTensorSpec(out_shape, out_dist_attr); - - paddle::framework::AttributeMap attrs; - attrs["trans_x"] = false; - attrs["trans_y"] = false; + phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr); + phi::distributed::DistMetaTensor y(phi::make_ddim(y_shape), y_dist_attr); + phi::distributed::DistMetaTensor out(phi::make_ddim(out_shape), + out_dist_attr); - SPMDRuleBase* matmul_rule = SPMDRuleMap::Instance().Get("matmul"); + auto matmul_spmd_rule = + phi::distributed::SpmdRuleFactory::Instance().GetSpmdRule("matmul"); // TODO(zyc) update in future: propogate the partial in inferbackward // abmn[-1, -1, 1, -1] + partial[0] --> abmk[-1, -1, 1, -1], a1kn[-1, -1, -1, // -1] + phi::distributed::InferSpmdContext ctx( + {x, y, out}, {/*trans_x=*/false, /*trans_x=*/false}); std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>> - infered_dist_attrs = - matmul_rule->InferBackward({x_dist_tensor_spec, y_dist_tensor_spec}, - {out_dist_tensor_spec}, - attrs); + infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); size_t input_size = 2; size_t output_size = 1;