diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8130c91a31373..efd514f124ce5 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -233,6 +233,8 @@ if(WITH_XPU) pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu + DEPS ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS @@ -529,4 +531,8 @@ if(WITH_XPU) test_fused_multi_transformer_cachekv_layout_trans_pass SRCS xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc DEPS fused_multi_transformer_cachekv_layout_trans_pass) + cc_test( + test_multi_encoder_xpu_adaptive_seqlen_fuse_pass + SRCS xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc + DEPS multi_encoder_xpu_adaptive_seqlen_fuse_pass) endif() diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index a84c9b84e9466..9e61210e623ab 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -61,6 +61,7 @@ static const std::vector xpu_support_subgraph_passes = { "generate_sequence_xpu_fuse_pass", "embedding_with_eltwise_add_xpu_fuse_pass", "multi_encoder_xpu_fuse_pass", + "multi_encoder_xpu_adaptive_seqlen_fuse_pass", "multi_encoder_xpu_slice_fuse_pass", "fused_multi_transformer_cachekv_layout_trans_pass", "one_beam_size_fuse_pass", diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 0557104773a95..9811ac01b0b5c 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -824,6 +824,17 @@ struct Layers { return unary_op("logical_not", input); } + VarDesc* not_equal(VarDesc* x, VarDesc* y, int axis = -1) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("not_equal"); + op->SetInput("X", {x->Name()}); + op->SetInput("Y", {y->Name()}); + op->SetAttr("axis", axis); + op->SetOutput("Out", {out->Name()}); + return out; + } + VarDesc* stack(std::vector inputs, int axis = -1) { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); @@ -838,6 +849,16 @@ struct Layers { return out; } + VarDesc* tile(VarDesc* x, const std::vector& repeat_times = {2}) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("tile"); + op->SetInput("X", {x->Name()}); + op->SetAttr("repeat_times", repeat_times); + op->SetOutput("Out", {out->Name()}); + return out; + } + private: VarDesc* lod_tensor(std::string name, std::vector shape = {}, diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc new file mode 100644 index 0000000000000..e20320e29a959 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc @@ -0,0 +1,343 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h" +#include +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct AdaptiveSeqlenPatternV1 : public PatternBase { + AdaptiveSeqlenPatternV1(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(embedding_xpu); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(matmul); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(stack); + PATTERN_DECL_NODE(multi_encoder_xpu); + // declare variable node's name + PATTERN_DECL_NODE(mask); + PATTERN_DECL_NODE(embedding_xpu_out); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(matmul_out); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(stack_out); +}; + +AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr()) + ->assert_is_op("embedding_with_eltwise_add_xpu"); + auto* embedding_xpu_out = + pattern->NewNode(embedding_xpu_out_repr()) + ->assert_is_op_output("embedding_with_eltwise_add_xpu", "out") + ->assert_is_op_input("layer_norm", "X"); + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr()) + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("multi_encoder_xpu", "x"); + + auto* mask = pattern->NewNode(mask_repr()) + ->assert_is_op_input("matmul", "X") + ->assert_is_op_input("matmul", "Y"); + auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul"); + auto* matmul_out = pattern->NewNode(matmul_out_repr()) + ->assert_is_op_output("matmul", "Out") + ->assert_is_op_input("scale", "X"); + auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); + auto* scale_out = pattern->NewNode(scale_out_repr()) + ->assert_is_op_output("scale", "Out") + ->assert_is_op_input("stack", "X"); + auto* stack = pattern->NewNode(stack_repr())->assert_is_op("stack"); + auto* stack_out = pattern->NewNode(stack_out_repr()) + ->assert_is_op_output("stack", "Y") + ->assert_is_op_input("multi_encoder_xpu", "mask"); + + auto* multi_encoder_xpu = pattern->NewNode(multi_encoder_xpu_repr()) + ->assert_is_op("multi_encoder_xpu"); + + embedding_xpu->LinksTo({embedding_xpu_out}); + layer_norm->LinksFrom({embedding_xpu_out}).LinksTo({layer_norm_out}); + matmul->LinksFrom({mask}).LinksTo({matmul_out}); + scale->LinksFrom({matmul_out}).LinksTo({scale_out}); + stack->LinksFrom({scale_out}).LinksTo({stack_out}); + multi_encoder_xpu->LinksFrom({layer_norm_out, stack_out}); +} + +} // namespace patterns + +int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1( + ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::AdaptiveSeqlenPatternV1 
pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ApplyAdaptiveSeqlenPassV1 fuse"; + GET_IR_NODE(embedding_xpu); + GET_IR_NODE(layer_norm); + GET_IR_NODE(matmul); + GET_IR_NODE(scale); + GET_IR_NODE(stack); + GET_IR_NODE(multi_encoder_xpu); + GET_IR_NODE(mask); + GET_IR_NODE(embedding_xpu_out); + GET_IR_NODE(layer_norm_out); + GET_IR_NODE(matmul_out); + GET_IR_NODE(scale_out); + GET_IR_NODE(stack_out); + + std::string mask_name = mask->Name(); + std::string seq_lod_name = mask_name + "_seq_lod"; + VarDesc seq_lod_desc(seq_lod_name); + auto* seq_lod = graph->CreateVarNode(&seq_lod_desc); + std::string max_seq_len_name = mask_name + "_max_seq_len"; + VarDesc max_seq_len_desc(max_seq_len_name); + auto* max_seq_len = graph->CreateVarNode(&max_seq_len_desc); + + embedding_xpu->Op()->SetInput("mask", {mask_name}); + embedding_xpu->Op()->SetOutput("seq_lod", {seq_lod_name}); + embedding_xpu->Op()->SetOutput("max_seq_len", {max_seq_len_name}); + multi_encoder_xpu->Op()->SetInput("seq_lod", {seq_lod_name}); + multi_encoder_xpu->Op()->SetInput("max_seq_len", {max_seq_len_name}); + multi_encoder_xpu->Op()->RemoveInput("mask"); + IR_NODE_LINK_TO(mask, embedding_xpu); + IR_NODE_LINK_TO(embedding_xpu, seq_lod); + IR_NODE_LINK_TO(embedding_xpu, max_seq_len); + IR_NODE_LINK_TO(seq_lod, multi_encoder_xpu); + IR_NODE_LINK_TO(max_seq_len, multi_encoder_xpu); + + // delete useless node + std::unordered_set delete_nodes{ + matmul, scale, stack, matmul_out, scale_out, stack_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +namespace patterns { + +struct AdaptiveSeqlenPatternV2 : public PatternBase { + AdaptiveSeqlenPatternV2(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(embedding_xpu); + PATTERN_DECL_NODE(layer_norm); + PATTERN_DECL_NODE(not_equal); + PATTERN_DECL_NODE(cast); + PATTERN_DECL_NODE(unsqueeze_0); + PATTERN_DECL_NODE(matmul); + PATTERN_DECL_NODE(scale_0); + PATTERN_DECL_NODE(scale_1); + PATTERN_DECL_NODE(unsqueeze_1); + PATTERN_DECL_NODE(tile); + PATTERN_DECL_NODE(multi_encoder_xpu); + // declare variable node's name + PATTERN_DECL_NODE(mask); + PATTERN_DECL_NODE(embedding_xpu_out); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(not_equal_out); + PATTERN_DECL_NODE(cast_out); + PATTERN_DECL_NODE(unsqueeze_0_out); + PATTERN_DECL_NODE(matmul_out); + PATTERN_DECL_NODE(scale_0_out); + PATTERN_DECL_NODE(scale_1_out); + PATTERN_DECL_NODE(unsqueeze_1_out); + PATTERN_DECL_NODE(tile_out); +}; + +AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr()) + ->assert_is_op("embedding_with_eltwise_add_xpu"); + auto* embedding_xpu_out = + pattern->NewNode(embedding_xpu_out_repr()) + ->assert_is_op_output("embedding_with_eltwise_add_xpu", "out") + ->assert_is_op_input("layer_norm", "X"); + auto* layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm"); + auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr()) + ->assert_is_op_output("layer_norm", "Y") + ->assert_is_op_input("multi_encoder_xpu", "x"); + + auto* mask = + pattern->NewNode(mask_repr())->assert_is_op_input("not_equal", "X"); + auto* not_equal = + 
pattern->NewNode(not_equal_repr())->assert_is_op("not_equal"); + auto* not_equal_out = pattern->NewNode(not_equal_out_repr()) + ->assert_is_op_output("not_equal", "Out") + ->assert_is_op_input("cast", "X"); + auto* cast = pattern->NewNode(cast_repr())->assert_is_op("cast"); + auto* cast_out = pattern->NewNode(cast_out_repr()) + ->assert_is_op_output("cast", "Out") + ->assert_is_op_input("unsqueeze2", "X"); + auto* unsqueeze_0 = + pattern->NewNode(unsqueeze_0_repr())->assert_is_op("unsqueeze2"); + auto* unsqueeze_0_out = pattern->NewNode(unsqueeze_0_out_repr()) + ->assert_is_op_output("unsqueeze2", "Out") + ->assert_is_op_input("matmul_v2", "X") + ->assert_is_op_input("matmul_v2", "Y"); + auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul_v2"); + auto* matmul_out = pattern->NewNode(matmul_out_repr()) + ->assert_is_op_output("matmul_v2", "Out") + ->assert_is_op_input("scale", "X"); + auto* scale_0 = pattern->NewNode(scale_0_repr())->assert_is_op("scale"); + auto* scale_0_out = pattern->NewNode(scale_0_out_repr()) + ->assert_is_op_output("scale", "Out") + ->assert_is_op_input("scale", "X"); + auto* scale_1 = pattern->NewNode(scale_1_repr())->assert_is_op("scale"); + auto* scale_1_out = pattern->NewNode(scale_1_out_repr()) + ->assert_is_op_output("scale", "Out") + ->assert_is_op_input("unsqueeze2", "X"); + auto* unsqueeze_1 = + pattern->NewNode(unsqueeze_1_repr())->assert_is_op("unsqueeze2"); + auto* unsqueeze_1_out = pattern->NewNode(unsqueeze_1_out_repr()) + ->assert_is_op_output("unsqueeze2", "Out") + ->assert_is_op_input("tile", "X"); + auto* tile = pattern->NewNode(tile_repr())->assert_is_op("tile"); + auto* tile_out = pattern->NewNode(tile_out_repr()) + ->assert_is_op_output("tile", "Out") + ->assert_is_op_input("multi_encoder_xpu", "mask"); + + auto* multi_encoder_xpu = pattern->NewNode(multi_encoder_xpu_repr()) + ->assert_is_op("multi_encoder_xpu"); + + embedding_xpu->LinksTo({embedding_xpu_out}); + layer_norm->LinksFrom({embedding_xpu_out}).LinksTo({layer_norm_out}); + not_equal->LinksFrom({mask}).LinksTo({not_equal_out}); + cast->LinksFrom({not_equal_out}).LinksTo({cast_out}); + unsqueeze_0->LinksFrom({cast_out}).LinksTo({unsqueeze_0_out}); + matmul->LinksFrom({unsqueeze_0_out}).LinksTo({matmul_out}); + scale_0->LinksFrom({matmul_out}).LinksTo({scale_0_out}); + scale_1->LinksFrom({scale_0_out}).LinksTo({scale_1_out}); + unsqueeze_1->LinksFrom({scale_1_out}).LinksTo({unsqueeze_1_out}); + tile->LinksFrom({unsqueeze_1_out}).LinksTo({tile_out}); + multi_encoder_xpu->LinksFrom({layer_norm_out, tile_out}); +} + +} // namespace patterns + +int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV2( + ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::AdaptiveSeqlenPatternV2 pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ApplyAdaptiveSeqlenPassV2 fuse"; + GET_IR_NODE(embedding_xpu); + GET_IR_NODE(layer_norm); + GET_IR_NODE(not_equal); + GET_IR_NODE(cast); + GET_IR_NODE(unsqueeze_0); + GET_IR_NODE(matmul); + GET_IR_NODE(scale_0); + GET_IR_NODE(scale_1); + GET_IR_NODE(unsqueeze_1); + GET_IR_NODE(tile); + GET_IR_NODE(multi_encoder_xpu); + GET_IR_NODE(mask); + GET_IR_NODE(embedding_xpu_out); + GET_IR_NODE(layer_norm_out); + GET_IR_NODE(not_equal_out); + GET_IR_NODE(cast_out); + GET_IR_NODE(unsqueeze_0_out); + GET_IR_NODE(matmul_out); + GET_IR_NODE(scale_0_out); + GET_IR_NODE(scale_1_out); + 
GET_IR_NODE(unsqueeze_1_out); + GET_IR_NODE(tile_out); + + std::string mask_name = mask->Name(); + std::string seq_lod_name = mask_name + "_seq_lod"; + VarDesc seq_lod_desc(seq_lod_name); + auto* seq_lod = graph->CreateVarNode(&seq_lod_desc); + std::string max_seq_len_name = mask_name + "_max_seq_len"; + VarDesc max_seq_len_desc(max_seq_len_name); + auto* max_seq_len = graph->CreateVarNode(&max_seq_len_desc); + + embedding_xpu->Op()->SetInput("mask", {mask_name}); + embedding_xpu->Op()->SetOutput("seq_lod", {seq_lod_name}); + embedding_xpu->Op()->SetOutput("max_seq_len", {max_seq_len_name}); + multi_encoder_xpu->Op()->SetInput("seq_lod", {seq_lod_name}); + multi_encoder_xpu->Op()->SetInput("max_seq_len", {max_seq_len_name}); + multi_encoder_xpu->Op()->RemoveInput("mask"); + IR_NODE_LINK_TO(mask, embedding_xpu); + IR_NODE_LINK_TO(embedding_xpu, seq_lod); + IR_NODE_LINK_TO(embedding_xpu, max_seq_len); + IR_NODE_LINK_TO(seq_lod, multi_encoder_xpu); + IR_NODE_LINK_TO(max_seq_len, multi_encoder_xpu); + + // delete useless node + std::unordered_set delete_nodes{not_equal, + cast, + unsqueeze_0, + matmul, + scale_0, + scale_1, + unsqueeze_1, + tile, + not_equal_out, + cast_out, + unsqueeze_0_out, + matmul_out, + scale_0_out, + scale_1_out, + unsqueeze_1_out, + tile_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +void MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + int found_subgraph_count = ApplyAdaptiveSeqlenPassV1(graph); + found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(multi_encoder_xpu_adaptive_seqlen_fuse_pass, + paddle::framework::ir::MultiEncoderXPUAdaptiveSeqlenFusePass); + +REGISTER_PASS_CAPABILITY(multi_encoder_xpu_adaptive_seqlen_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "multi_encoder_xpu", 0)); diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h new file mode 100644 index 0000000000000..a21d6498dea8e --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h @@ -0,0 +1,143 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// support adaptive seq len for bert/ernie
+class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  adaptive seqlen V1, before:
+
+      input_var*    mask_var*
+         |             |
+         |             |
+   embedding_xpu     matmul
+         |             |
+         |             |
+     layer_norm      scale
+         |             |
+         |             |
+         |           stack
+          \           /
+           \         /
+        multi_encoder_xpu
+               |
+               |
+            out_var*
+
+  after:
+
+      input_var*    mask_var*
+           \           /
+            \         /
+          embedding_xpu
+            /        \
+           /          \
+  embedding_out_var*   seq_lod_var*
+          |              |
+          |              |
+      layer_norm         |
+            \           /
+             \         /
+          multi_encoder_xpu
+                 |
+                 |
+              out_var*
+  */
+  int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph) const;
+
+  /*
+  adaptive seqlen V2, before:
+
+      input_var*    mask_var*
+         |             |
+         |             |
+   embedding_xpu    not_equal
+         |             |
+         |             |
+     layer_norm       cast
+         |             |
+         |             |
+         |         unsqueeze2
+         |             |
+         |             |
+         |          matmul_v2
+         |             |
+         |             |
+         |           scale
+         |             |
+         |             |
+         |           scale
+         |             |
+         |             |
+         |         unsqueeze2
+         |             |
+         |             |
+         |            tile
+          \           /
+           \         /
+        multi_encoder_xpu
+               |
+               |
+            out_var*
+
+  after:
+
+      input_var*    mask_var*
+           \           /
+            \         /
+          embedding_xpu
+            /        \
+           /          \
+  embedding_out_var*   seq_lod_var*
+          |              |
+          |              |
+      layer_norm         |
+            \           /
+             \         /
+          multi_encoder_xpu
+                 |
+                 |
+              out_var*
+  */
+  int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph) const;
+
+ private:
+  const std::string name_scope_{"multi_encoder_xpu_adaptive_seqlen_fuse_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
new file mode 100644
index 0000000000000..556ef75415aef
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
@@ -0,0 +1,105 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(MultiEncoderXPUAdaptiveSeqlenFusePass, V1) { + Layers layers; + auto* block = layers.Block(); + + auto* embedding_xpu_out = layers.data("embedding_xpu_out"); + OpDesc* embedding_xpu = block->AppendOp(); + embedding_xpu->SetType("embedding_with_eltwise_add_xpu"); + embedding_xpu->SetOutput("out", {embedding_xpu_out->Name()}); + auto* layer_norm_out = layers.layer_norm(embedding_xpu_out)[0]; + + auto* mask = layers.data("mask"); + auto* matmul_out = layers.matmul(mask, mask); + auto* scale_out = layers.scale(matmul_out); + auto* stack_out = layers.stack({scale_out, scale_out}); + + OpDesc* multi_encoder_xpu = block->AppendOp(); + multi_encoder_xpu->SetType("multi_encoder_xpu"); + multi_encoder_xpu->SetInput("x", {layer_norm_out->Name()}); + multi_encoder_xpu->SetInput("mask", {stack_out->Name()}); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get( + "multi_encoder_xpu_adaptive_seqlen_fuse_pass"); + pass->Apply(graph.get()); + auto num = GetNumOpNodes(graph, "matmul") + GetNumOpNodes(graph, "scale") + + GetNumOpNodes(graph, "stack"); + PADDLE_ENFORCE_EQ( + num, + 0, + platform::errors::PreconditionNotMet( + "matmul/scale/stack ops should be removed from graph, but graph " + "still has %d ops.", + num)); +} + +TEST(MultiEncoderXPUAdaptiveSeqlenFusePass, V2) { + Layers layers; + auto* block = layers.Block(); + + auto* embedding_xpu_out = layers.data("embedding_xpu_out"); + OpDesc* embedding_xpu = block->AppendOp(); + embedding_xpu->SetType("embedding_with_eltwise_add_xpu"); + embedding_xpu->SetOutput("out", {embedding_xpu_out->Name()}); + auto* layer_norm_out = layers.layer_norm(embedding_xpu_out)[0]; + + auto* mask = layers.data("mask"); + auto* not_equal_y = layers.data("not_equal_y"); + auto* not_equal_out = layers.not_equal(mask, not_equal_y); + auto* cast_out = layers.cast(not_equal_out); + auto* unsqueeze_0_out = layers.unsqueeze2(cast_out); + auto* matmul_out = layers.matmul_v2(unsqueeze_0_out, unsqueeze_0_out); + auto* scale_0_out = layers.scale(matmul_out); + auto* scale_1_out = layers.scale(scale_0_out); + auto* unsqueeze_1_out = layers.unsqueeze2(scale_1_out); + auto* tile_out = layers.tile(unsqueeze_1_out); + + OpDesc* multi_encoder_xpu = block->AppendOp(); + multi_encoder_xpu->SetType("multi_encoder_xpu"); + multi_encoder_xpu->SetInput("x", {layer_norm_out->Name()}); + multi_encoder_xpu->SetInput("mask", {tile_out->Name()}); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get( + "multi_encoder_xpu_adaptive_seqlen_fuse_pass"); + pass->Apply(graph.get()); + auto num = GetNumOpNodes(graph, "not_equal") + GetNumOpNodes(graph, "cast") + + GetNumOpNodes(graph, "unsqueeze2") + + GetNumOpNodes(graph, "matmul_v2") + GetNumOpNodes(graph, "scale") + + GetNumOpNodes(graph, "tile"); + PADDLE_ENFORCE_EQ(num, + 0, + platform::errors::PreconditionNotMet( + "not_equal/cast/unsqueeze2/matmul_v2/scale ops should " + "be removed from graph, but graph " + "still has %d ops.", + num)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(multi_encoder_xpu_adaptive_seqlen_fuse_pass); diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index e4adc737a6b9e..8551d7f688113 100644 --- 
a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -1205,7 +1205,10 @@ std::vector MultiEncoderXPUFusePass::GeneratePatternParams() const { return std::vector{ // Params are arranged in alphabetic order - {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true}}; + {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true}, + {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true}, + {"gelu", "mul", "matmul", "matmul", false, true, true}, + }; } } // namespace ir diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h index 1aabd19ef4c95..7c7595a8564cd 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h @@ -51,6 +51,7 @@ Origin subgraph: | | | | | v_transpose q_transpose k_transpose | | | | + | | | (scale) | | \ / | | qk_matmul | | | diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc index 59754f9d58146..c3224d070346d 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc @@ -74,15 +74,12 @@ MultiEncoderXPUSlicePattern::MultiEncoderXPUSlicePattern( ->assert_more([](Node* node) { std::vector axes = PADDLE_GET_CONST(std::vector, node->Op()->GetAttr("axes")); - std::vector decrease_axis = PADDLE_GET_CONST( - std::vector, node->Op()->GetAttr("decrease_axis")); std::vector starts = PADDLE_GET_CONST( std::vector, node->Op()->GetAttr("starts")); std::vector ends = PADDLE_GET_CONST(std::vector, node->Op()->GetAttr("ends")); - return axes.size() == 1 && axes[0] == 1 && - decrease_axis.size() == 1 && decrease_axis[0] == 1 && - starts.size() == 1 && starts[0] == 0 && // + return axes.size() == 1 && axes[0] == 1 && starts.size() == 1 && + starts[0] == 0 && // ends.size() == 1 && ends[0] == 1; }); auto* slice_out = pattern->NewNode(slice_out_repr()) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index f9123a111771f..e5372ab78f520 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "generate_sequence_xpu_fuse_pass", "embedding_with_eltwise_add_xpu_fuse_pass", "multi_encoder_xpu_fuse_pass", + "multi_encoder_xpu_adaptive_seqlen_fuse_pass", "multi_encoder_xpu_slice_fuse_pass", "fused_multi_transformer_cachekv_layout_trans_pass", "one_beam_size_fuse_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index d5f9a9e0d481d..1e85309c28a77 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -15,13 +15,14 @@ optional : bias, branch, branch_max ,x_max - op : embedding_with_eltwise_add_xpu - args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx) - output: Tensor + args : (Tensor[] ids, Tensor[] tables, Tensor mask, int64_t padding_idx) + output: Tensor(out), Tensor(seq_lod), Tensor(max_seq_len) infer_meta : func: EmbeddingWithEltwiseAddXPUInferMeta kernel: func: embedding_with_eltwise_add_xpu data_type: tables + optional : mask, seq_lod, max_seq_len - op : fc_xpu args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool 
transpose_x, float alpha, float beta, int act_type, float act_alpha) @@ -77,11 +78,11 @@ data_type : dtype - op : multi_encoder_xpu - args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx) + args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx) output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16) infer_meta : func : MultiEncoderXPUInferMeta kernel : func : multi_encoder_xpu data_type : x - optional : mask, x_fp16, out_fp16 + optional : mask, seq_lod, max_seq_len, x_fp16, out_fp16 diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index f32ee528075db..39a76bc7143b9 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -165,7 +165,10 @@ void Conv2dXPUInferMeta(const MetaTensor& x, void EmbeddingWithEltwiseAddXPUInferMeta( const std::vector& ids, const std::vector& tables, - MetaTensor* out) { + const MetaTensor& mask, + MetaTensor* out, + MetaTensor* seq_lod, + MetaTensor* max_seq_len) { PADDLE_ENFORCE_GT(ids.size(), 0UL, phi::errors::InvalidArgument( @@ -226,6 +229,8 @@ void MultiEncoderXPUInferMeta( const std::vector& ln_scale, const std::vector& ln_bias, const MetaTensor& mask, + const MetaTensor& seq_lod, + const MetaTensor& max_seq_len, int layer_num, bool norm_before, int hidden_dim, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 0fddf3995ffa0..7a844eeff0c9b 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -44,7 +44,10 @@ void Conv2dXPUInferMeta(const MetaTensor& x, void EmbeddingWithEltwiseAddXPUInferMeta( const std::vector& ids, const std::vector& tables, - MetaTensor* out); + const MetaTensor& mask, + MetaTensor* out, + MetaTensor* seq_lod, + MetaTensor* max_seq_len); void FcXPUInferMeta(const MetaTensor& x, const MetaTensor& x_max, @@ -72,6 +75,8 @@ void MultiEncoderXPUInferMeta( const std::vector& ln_scale, const std::vector& ln_bias, const MetaTensor& mask, + const MetaTensor& seq_lod, + const MetaTensor& max_seq_len, int layer_num, bool norm_before, int hidden_dim, diff --git a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc index f18b6d5283b58..de704bec0f7e5 100644 --- a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc @@ -24,14 +24,14 @@ void EmbeddingWithEltwiseAddXpuKernel( const Context& ctx, const std::vector& ids, const std::vector& tables, + const paddle::optional& mask, int64_t padding_idx, - DenseTensor* out) { + DenseTensor* out, + DenseTensor* seq_lod, + DenseTensor* max_seq_len) { using XPUType = typename XPUTypeTrait::Type; - auto& id_dims = ids[0]->dims(); - int idx_len = id_dims[0] * id_dims[1]; - int emb_layer_num = ids.size(); - int embed_dim = tables[0]->dims()[1]; - std::vector table_lens_cpu; + int emb_dim = tables[0]->dims()[1]; + std::vector table_lens; std::vector arg_tables; for (auto* table : tables) { auto& table_dims = table->dims(); @@ -39,16 
+39,16 @@ void EmbeddingWithEltwiseAddXpuKernel( table_dims.size(), 2, errors::InvalidArgument( - "The table_dims size [%d] should be equal 2.", - table_dims.size())); /* shape like [table_len, embed_dim] */ + "The table_dims size [%d] should be equal to 2.", + table_dims.size())); /* shape like [table_len, emb_dim] */ PADDLE_ENFORCE_EQ( table_dims[1], - embed_dim, + emb_dim, errors::InvalidArgument( - "Every embed_dim [%d] should be equal the first one [%d].", + "Every emb_dim [%d] should be equal to the first one [%d].", table_dims[1], - embed_dim)); - table_lens_cpu.push_back(table_dims[0]); + emb_dim)); + table_lens.push_back(table_dims[0]); if (std::is_same::value) { DenseTensor table_data_fp32_t; ctx.template Alloc(&table_data_fp32_t, @@ -64,26 +64,99 @@ void EmbeddingWithEltwiseAddXpuKernel( arg_tables.push_back(table->data()); } } - std::vector> int_idx(emb_layer_num, - std::vector(idx_len, 0)); - std::vector> arg_ids; + + int emb_layer_num = ids.size(); for (int i = 0; i < emb_layer_num; i++) { - PADDLE_ENFORCE_EQ( - ids[i]->dtype() == phi::DataType::INT64 || - ids[i]->dtype() == phi::DataType::INT32, - true, + auto id_dtype = ids[i]->dtype(); + PADDLE_ENFORCE( + id_dtype == phi::DataType::INT64 || id_dtype == phi::DataType::INT32, errors::InvalidArgument( "The data type of ids should be int64 or int32, but got %s.", - ids[i]->dtype())); - for (int j = 0; j < idx_len; j++) { - if (ids[i]->dtype() == phi::DataType::INT64) { - int_idx[i][j] = static_cast(ids[i]->data()[j]); - } else if (ids[i]->dtype() == phi::DataType::INT32) { - int_idx[i][j] = ids[i]->data()[j]; + DataTypeToString(id_dtype))); + } + + auto& id_dims = ids[0]->dims(); + int batch_size = id_dims[0]; + int max_seq_len_value = id_dims[1]; + int ids_len = id_dims[0] * id_dims[1]; + std::vector> int_ids(emb_layer_num, + std::vector(ids_len, 0)); + std::vector> arg_ids; + auto* mask_tensor = mask.get_ptr(); + if (mask_tensor != nullptr) { + auto mask_dtype = mask_tensor->dtype(); + PADDLE_ENFORCE( + mask_dtype == phi::DataType::INT64 || + mask_dtype == phi::DataType::FLOAT32, + errors::InvalidArgument( + "The data type of mask should be int64 or float32, but got %s.", + DataTypeToString(mask_dtype))); + max_seq_len->Resize({1}); + ctx.template HostAlloc(max_seq_len)[0] = max_seq_len_value; + + seq_lod->Resize({batch_size + 1}); + int* seq_lod_data = ctx.template HostAlloc(seq_lod); + seq_lod_data[0] = 0; + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + int cur_batch_seq_len = 0; + for (int seq_idx = 0; seq_idx < max_seq_len_value; seq_idx++) { + int mask_idx = batch_idx * max_seq_len_value + seq_idx; + if ((mask_dtype == phi::DataType::INT64 && + mask->data()[mask_idx] > 0) || + (mask_dtype == phi::DataType::FLOAT32 && + fabs(mask->data()[mask_idx]) > 1e-5)) { + cur_batch_seq_len++; + } else { + break; + } + } + PADDLE_ENFORCE_GT( + cur_batch_seq_len, + 0, + errors::PreconditionNotMet( + "cur_batch_seq_len should be greater than 0, but got %d.", + cur_batch_seq_len)); + seq_lod_data[batch_idx + 1] = seq_lod_data[batch_idx] + cur_batch_seq_len; + } + out->Resize({batch_size, seq_lod_data[batch_size], emb_dim}); + + for (int i = 0; i < emb_layer_num; i++) { + if (ids[i]->dtype() == DataType::INT64) { + auto* ids_data = ids[i]->data(); + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + for (int j = 0; + j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx]; + j++) { + int_ids[i][seq_lod_data[batch_idx] + j] = + ids_data[batch_idx * max_seq_len_value + j]; + } + } + } else { + auto* 
ids_data = ids[i]->data(); + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + for (int j = 0; + j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx]; + j++) { + int_ids[i][seq_lod_data[batch_idx] + j] = + ids_data[batch_idx * max_seq_len_value + j]; + } + } + } + arg_ids.push_back( + xpu::VectorParam{int_ids[i].data(), ids_len, nullptr}); + } + } else { + for (int i = 0; i < emb_layer_num; i++) { + for (int j = 0; j < ids_len; j++) { + if (ids[i]->dtype() == phi::DataType::INT64) { + int_ids[i][j] = static_cast(ids[i]->data()[j]); + } else if (ids[i]->dtype() == phi::DataType::INT32) { + int_ids[i][j] = ids[i]->data()[j]; + } } + arg_ids.push_back( + xpu::VectorParam{int_ids[i].data(), ids_len, nullptr}); } - arg_ids.push_back( - xpu::VectorParam{int_idx[i].data(), idx_len, nullptr}); } ctx.template Alloc(out); @@ -95,10 +168,10 @@ void EmbeddingWithEltwiseAddXpuKernel( arg_tables, /* tables */ out_fp32_t.data(), arg_ids, - table_lens_cpu, - embed_dim, - std::vector(table_lens_cpu.size(), 1.0f), - std::vector(table_lens_cpu.size(), padding_idx)); + table_lens, + emb_dim, + std::vector(table_lens.size(), 1.0f), + std::vector(table_lens.size(), padding_idx)); PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu"); r = xpu::cast(ctx.x_context(), @@ -112,10 +185,10 @@ void EmbeddingWithEltwiseAddXpuKernel( arg_tables, /* tables */ out->data(), arg_ids, - table_lens_cpu, - embed_dim, - std::vector(table_lens_cpu.size(), 1.0f), - std::vector(table_lens_cpu.size(), padding_idx)); + table_lens, + emb_dim, + std::vector(table_lens.size(), 1.0f), + std::vector(table_lens.size(), padding_idx)); PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu"); } } @@ -130,4 +203,7 @@ PD_REGISTER_KERNEL(embedding_with_eltwise_add_xpu, float, phi::dtype::float16) { kernel->InputAt(0).SetBackend(phi::Backend::CPU); + kernel->InputAt(2).SetBackend(phi::Backend::CPU); + kernel->OutputAt(1).SetBackend(phi::Backend::CPU); + kernel->OutputAt(2).SetBackend(phi::Backend::CPU); } diff --git a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc index 094d2854e412a..b4a3127ca2469 100644 --- a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc @@ -29,6 +29,8 @@ void MultiEncoderXPUKernel(const Context& ctx, const std::vector& ln_scale, const std::vector& ln_bias, const paddle::optional& mask, + const paddle::optional& seq_lod, + const paddle::optional& max_seq_len, int layer_num, bool norm_before, int hidden_dim, @@ -89,19 +91,52 @@ void MultiEncoderXPUKernel(const Context& ctx, } const float* mask_data = mask.get_ptr() == nullptr ? nullptr : mask.get_ptr()->data(); + const int* seq_lod_data = + seq_lod.get_ptr() == nullptr ? nullptr : seq_lod.get_ptr()->data(); + const int* max_seq_len_data = max_seq_len.get_ptr() == nullptr + ? nullptr + : max_seq_len.get_ptr()->data(); xpu::Activation_t qkv_act(static_cast(act_type)); int batch = x.dims()[0]; - int max_seqlen = x.dims()[1]; // matmul_size * layer_num std::vector quant_types(8 * layer_num, xpu::QuantType::NOT_QUANT); - if (mask_data) { + if (seq_lod_data) { + xpu::VectorParam query_lod = { + seq_lod_data, seq_lod.get_ptr()->numel(), nullptr}; + int max_seq_len_value = slice_idx == -1 ? 
max_seq_len_data[0] : -1; + xpu::QKVAttnParam qkv_attn_param(query_lod, + head_num, + size_per_head, + qkv_act, + slice_idx, + true, + max_seq_len_value, + hidden_dim, + norm_before, + false); + qkv_attn_param.quant_type_.assign(quant_types.begin(), quant_types.end()); + qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; + int r = xpu::transformer_encoder( + ctx.x_context(), + x_fp16_data, + fc_weight_data, + out_fp16_data, + fc_input_max_data, + fc_weight_max_data, + fc_bias_data, + ln_scale_data, + ln_bias_data, + qkv_attn_param); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "multi_encoder_xpu"); + } else if (mask_data) { auto mask_dims = mask.get_ptr()->dims(); std::vector mask_shape(mask_dims.Get(), mask_dims.Get() + mask_dims.size()); + int max_seq_len_value = x.dims()[1]; xpu::QKVAttnParam qkv_attn_param(batch, - max_seqlen, + max_seq_len_value, head_num, size_per_head, mask_shape, @@ -128,9 +163,10 @@ void MultiEncoderXPUKernel(const Context& ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "multi_encoder_xpu"); } else { // When no mask input, like VIT, create LOD to act as vsl. + int max_seq_len_value = x.dims()[1]; std::vector lod; for (int i = 0; i < batch + 1; i++) { - lod.push_back(i * max_seqlen); + lod.push_back(i * max_seq_len_value); } xpu::VectorParam query_lod = { lod.data(), static_cast(lod.size()), nullptr}; @@ -180,4 +216,7 @@ PD_REGISTER_KERNEL(multi_encoder_xpu, ALL_LAYOUT, phi::fusion::MultiEncoderXPUKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(7).SetBackend(phi::Backend::CPU); + kernel->InputAt(8).SetBackend(phi::Backend::CPU); +}
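
For readers unfamiliar with the adaptive-seqlen layout, here is a minimal standalone C++ sketch (not part of this patch; names and sample data are illustrative) of the host-side preprocessing that embedding_with_eltwise_add_xpu performs when a mask is supplied: count the leading valid positions of each batch row to build seq_lod, record the padded length as max_seq_len, and gather the valid ids into a contiguous, padding-free buffer.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int batch_size = 2;
  const int max_seq_len = 4;  // padded sequence length of the ids/mask tensors
  // mask[b][s] > 0 marks a valid token; padding is assumed to be trailing,
  // matching the kernel's "count until the first zero" loop.
  std::vector<int64_t> mask = {1, 1, 1, 0,   // batch 0 -> 3 valid tokens
                               1, 1, 0, 0};  // batch 1 -> 2 valid tokens
  std::vector<int64_t> ids = {10, 11, 12, 0, 20, 21, 0, 0};

  // seq_lod has batch_size + 1 entries; seq_lod[b + 1] - seq_lod[b] is the
  // unpadded length of batch b.
  std::vector<int> seq_lod(batch_size + 1, 0);
  for (int b = 0; b < batch_size; ++b) {
    int cur_len = 0;
    for (int s = 0; s < max_seq_len; ++s) {
      if (mask[b * max_seq_len + s] > 0) {
        ++cur_len;
      } else {
        break;  // stop at the first padded position
      }
    }
    seq_lod[b + 1] = seq_lod[b] + cur_len;
  }

  // Gather the valid ids into a contiguous buffer of length seq_lod.back(),
  // i.e. the padding-free layout consumed by the fused encoder.
  std::vector<int> compact_ids(seq_lod.back(), 0);
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < seq_lod[b + 1] - seq_lod[b]; ++s) {
      compact_ids[seq_lod[b] + s] = static_cast<int>(ids[b * max_seq_len + s]);
    }
  }

  std::cout << "max_seq_len = " << max_seq_len << "\nseq_lod =";
  for (int v : seq_lod) std::cout << ' ' << v;  // 0 3 5
  std::cout << "\ncompact_ids =";
  for (int v : compact_ids) std::cout << ' ' << v;  // 10 11 12 20 21
  std::cout << std::endl;
  return 0;
}

With this layout, multi_encoder_xpu can run variable-length attention driven by seq_lod instead of a dense padded mask, which is what the adaptive-seqlen fusion above enables.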
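
The no-mask branch in multi_encoder_xpu_kernel.cc (the ViT-style case) instead builds a uniform LOD in which every row keeps its full padded length; a small sketch of that construction, again with illustrative values only:

#include <iostream>
#include <vector>

int main() {
  const int batch = 3;
  const int max_seq_len_value = 197;  // illustrative, e.g. 196 patch tokens + 1 cls token
  // lod[i] = i * max_seq_len_value, so every sequence spans exactly
  // max_seq_len_value positions and no padding is skipped.
  std::vector<int> lod;
  for (int i = 0; i < batch + 1; ++i) {
    lod.push_back(i * max_seq_len_value);
  }
  for (int v : lod) std::cout << v << ' ';  // 0 197 394 591
  std::cout << std::endl;
  return 0;
}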