diff --git a/paddle/fluid/pir/transforms/CMakeLists.txt b/paddle/fluid/pir/transforms/CMakeLists.txt
index 3140d9d20dc09e..082af7b827ead0 100644
--- a/paddle/fluid/pir/transforms/CMakeLists.txt
+++ b/paddle/fluid/pir/transforms/CMakeLists.txt
@@ -1,3 +1,10 @@
+file(GLOB FUSION_PASS_SRCS "fusion/*.cc")
+
+cc_library(
+  fusion_passes
+  SRCS ${FUSION_PASS_SRCS}
+  DEPS drr)
+
 cc_library(
   transform_general_functions
   SRCS transform_general_functions.cc
@@ -9,15 +16,10 @@ cc_library(
   DEPS pd_kernel_dialect pd_op_dialect pd_op_dialect_utils)
 
 cc_library(
-  _constant_folding_pass
+  pd_constant_folding_pass
   SRCS constant_folding_pass.cc
   DEPS standalone_executor pd_op_to_kernel_pass transform_general_functions)
 
-cc_library(
-  fused_gemm_epilogue_pass
-  SRCS fused_gemm_epilogue_pass.cc
-  DEPS drr)
-
 cc_library(
   pd_inplace_pass
   SRCS inplace_pass.cc
diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc
index 2d8ea8ff0f5013..2d0fef35cb454c 100644
--- a/paddle/fluid/pir/transforms/build_cinn_pass.cc
+++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc
@@ -574,11 +574,11 @@ void ReplaceWithGroupOp(pir::Block* block,
 
 class BuildCinnPass : public pir::Pass {
  public:
-  BuildCinnPass() : pir::Pass("BuildCinnPass", /*opt_level=*/1) {}
+  BuildCinnPass() : pir::Pass("build_cinn_pass", /*opt_level=*/1) {}
 
   void Run(pir::Operation* op) override {
     auto module_op = op->dyn_cast<pir::ModuleOp>();
-    IR_ENFORCE(module_op, "InplacePass should run on module op.");
+    IR_ENFORCE(module_op, "build_cinn_pass should run on module op.");
     auto* block = module_op.block();
 
     std::vector<GroupOpsVec> groups =
diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 3b40960373a2f1..dfa26c950212f3 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -192,8 +192,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
 
 class ConstantFoldingPass : public pir::Pass {
  public:
-  // TODO(liuyuanle): Naming convention for pass.
-  ConstantFoldingPass() : pir::Pass("ConstantFoldingPass", 1) {}
+  ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
 
   bool Initialize(pir::IrContext* context) override {
     pir::RewritePatternSet ps(context);
diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc
new file mode 100644
index 00000000000000..0bd8c5e29e7efc
--- /dev/null
+++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc
@@ -0,0 +1,253 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"
+
+#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
+#include "paddle/pir/pass/pass.h"
+#include "paddle/pir/pass/pass_registry.h"
+#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
+
+namespace {
+
+class MultiHeadMatmulFusePattern
+    : public pir::drr::DrrPatternBase<MultiHeadMatmulFusePattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    //
+    // Source Pattern.
+    //
+    pir::drr::SourcePattern src = ctx->SourcePattern();
+    // The first path to matmul with scale (q).
+    const auto &matmul_1 =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("matmul_1_transpose_x")},
+                {"transpose_y", src.Attr("matmul_1_transpose_y")}});
+    src.Tensor("matmul_1_out") =
+        matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2"));
+    const auto &add_1 = src.Op("pd_op.add");
+    src.Tensor("add_1_out") =
+        add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2"));
+    const auto &full_int_array_1 =
+        src.Op("pd_op.full_int_array",
+               {{"value", src.Attr("full_int_array_1_value")}});
+    const auto &reshape_1 = src.Op("pd_op.reshape");
+    reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()},
+              {&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")});
+    const auto &transpose_1 = src.Op("pd_op.transpose");
+    src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out"));
+    const auto &full_1 =
+        src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}});
+    const auto &scale = src.Op("pd_op.scale");
+    src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1());
+
+    // The second path to matmul (k).
+    const auto &matmul_2 =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("matmul_2_transpose_x")},
+                {"transpose_y", src.Attr("matmul_2_transpose_y")}});
+    src.Tensor("matmul_2_out") =
+        matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2"));
+    const auto &add_2 = src.Op("pd_op.add");
+    src.Tensor("add_2_out") =
+        add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2"));
+    const auto &full_int_array_2 = src.Op("pd_op.full_int_array");
+    const auto &reshape_2 = src.Op("pd_op.reshape");
+    reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()},
+              {&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")});
+    const auto &transpose_2 = src.Op("pd_op.transpose");
+    src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out"));
+
+    // The third path to matmul (v).
+    const auto &matmul_3 =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("matmul_3_transpose_x")},
+                {"transpose_y", src.Attr("matmul_3_transpose_y")}});
+    src.Tensor("matmul_3_out") =
+        matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2"));
+    const auto &add_3 = src.Op("pd_op.add");
+    src.Tensor("add_3_out") =
+        add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2"));
+    const auto &full_int_array_3 = src.Op("pd_op.full_int_array");
+    const auto &reshape_3 = src.Op("pd_op.reshape");
+    reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()},
+              {&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")});
+    const auto &transpose_3 = src.Op("pd_op.transpose");
+    src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out"));
+
+    // softmax(qk)v
+    const auto &matmul_4 =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("matmul_4_transpose_x")},
+                {"transpose_y", src.Attr("matmul_4_transpose_y")}});
+    src.Tensor("matmul_4_out") =
+        matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out"));
+    const auto &add_4 = src.Op("pd_op.add");
+    src.Tensor("add_4_out") =
+        add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2"));
+    const auto &softmax =
+        src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
+    src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out"));
+    const auto &matmul_5 =
+        src.Op("pd_op.matmul",
+               {{"transpose_x", src.Attr("matmul_5_transpose_x")},
+                {"transpose_y", src.Attr("matmul_5_transpose_y")}});
+    src.Tensor("matmul_5_out") =
+        matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out"));
+    const auto &transpose_4 = src.Op("pd_op.transpose");
+    src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out"));
+    const auto &full_int_array_4 = src.Op("pd_op.full_int_array");
+    const auto &reshape_4 = src.Op("pd_op.reshape");
+    reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()},
+              {&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")});
+
+    //
+    // Constraints.
+    //
+    src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
+      const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis");
+      if (softmax_axis != -1 && softmax_axis != 3) return false;
+
+      bool matmul_1_transpose_x = match_ctx.Attr<bool>("matmul_1_transpose_x");
+      bool matmul_1_transpose_y = match_ctx.Attr<bool>("matmul_1_transpose_y");
+      if (matmul_1_transpose_x || matmul_1_transpose_y) return false;
+
+      bool matmul_2_transpose_x = match_ctx.Attr<bool>("matmul_2_transpose_x");
+      bool matmul_2_transpose_y = match_ctx.Attr<bool>("matmul_2_transpose_y");
+      if (matmul_2_transpose_x || matmul_2_transpose_y) return false;
+
+      bool matmul_3_transpose_x = match_ctx.Attr<bool>("matmul_3_transpose_x");
+      bool matmul_3_transpose_y = match_ctx.Attr<bool>("matmul_3_transpose_y");
+      if (matmul_3_transpose_x || matmul_3_transpose_y) return false;
+
+      bool matmul_4_transpose_x = match_ctx.Attr<bool>("matmul_4_transpose_x");
+      bool matmul_4_transpose_y = match_ctx.Attr<bool>("matmul_4_transpose_y");
+      if (matmul_4_transpose_x || !matmul_4_transpose_y) return false;
+
+      bool matmul_5_transpose_x = match_ctx.Attr<bool>("matmul_5_transpose_x");
+      bool matmul_5_transpose_y = match_ctx.Attr<bool>("matmul_5_transpose_y");
+      if (matmul_5_transpose_x || matmul_5_transpose_y) return false;
+
+      return true;
+    });
+
+    //
+    // Result Pattern.
+    //
+    pir::drr::ResultPattern res = src.ResultPattern();
+    // W combine.
+    const auto &combine_1 = res.Op("builtin.combine");
+    combine_1({&res.Tensor("matmul_1_in_2"),
+               &res.Tensor("matmul_2_in_2"),
+               &res.Tensor("matmul_3_in_2")},
+              {&res.Tensor("combine_1_out")});
+    const auto &concat_axis = res.Attr(
+        [](const pir::drr::MatchContext &match_ctx) -> int { return 0; });
+    const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}});
+    res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out"));
+    const auto &reshape_5_shape = res.Attr(
+        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
+          auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape();
+          return {-1, 3, matmul_1_in_2.at(1)};
+        });
+    const auto &reshape_5 =
+        res.Op("pd_op.reshape", {{"shape", reshape_5_shape}});
+    reshape_5({&res.Tensor("concat_1_out")},
+              {&res.Tensor("reshape_5_out"), &res.NoneTensor()});
+
+    // Bias combine.
+    const auto &combine_2 = res.Op("builtin.combine");
+    combine_2({&res.Tensor("add_1_in_2"),
+               &res.Tensor("add_2_in_2"),
+               &res.Tensor("add_3_in_2")},
+              {&res.Tensor("combine_2_out")});
+    const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}});
+    res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out"));
+    const auto &reshape_6_shape = res.Attr(
+        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
+          return {3, -1};
+        });
+    const auto &reshape_6 =
+        res.Op("pd_op.reshape", {{"shape", reshape_6_shape}});
+    reshape_6({&res.Tensor("concat_2_out")},
+              {&res.Tensor("reshape_6_out"), &res.NoneTensor()});
+
+    const auto &head_number =
+        res.Attr([](const pir::drr::MatchContext &match_ctx) -> int {
+          const auto &full_int_array_1_value =
+              match_ctx.Attr<std::vector<int64_t>>("full_int_array_1_value");
+          return full_int_array_1_value.at(2);
+        });
+    const auto &alpha =
+        res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
+          return match_ctx.Attr<float>("full_1_value");
+        });
+    const auto &multihead_matmul = res.Op(
+        "pd_op.multihead_matmul",
+        {{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) {
+            return false;
+          })},
+         {"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) {
+            return true;
+          })},
+         {"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) {
+            return false;
+          })},
+         {"head_number", head_number},
+         {"alpha", alpha}});
+    multihead_matmul({&res.Tensor("matmul_1_in_1"),
+                      &res.Tensor("reshape_5_out"),
+                      &res.Tensor("reshape_6_out"),
+                      &res.Tensor("add_4_in_2")},
+                     {&res.Tensor("reshape_4_out")});
+  }
+};
+
+class AttentionFusePass : public pir::Pass {
+ public:
+  AttentionFusePass() : pir::Pass("attention_fuse_pass", 2) {}
+
+  bool Initialize(pir::IrContext *context) override {
+    pir::RewritePatternSet ps(context);
+    ps.Add(MultiHeadMatmulFusePattern().Build(context));
+    // Add other attention variant fuse pattern.
+
+    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
+    return true;
+  }
+
+  void Run(pir::Operation *op) override {
+    pir::GreedyRewriteConfig cfg;
+    cfg.use_top_down_traversal = true;
+    cfg.max_iterations = 10;
+    pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
+  }
+
+  bool CanApplyOn(pir::Operation *op) const override {
+    return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
+  }
+
+ private:
+  pir::FrozenRewritePatternSet patterns_;
+};
+
+}  // namespace
+
+namespace pir {
+std::unique_ptr<Pass> CreateAttentionFusePass() {
+  return std::make_unique<AttentionFusePass>();
+}
+}  // namespace pir
+
+REGISTER_IR_PASS(attention_fuse_pass, AttentionFusePass);
diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h
new file mode 100644
index 00000000000000..0c0d2e84952ca4
--- /dev/null
+++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/pir/core/dll_decl.h"
+
+namespace pir {
+
+class Pass;
+
+IR_API std::unique_ptr<Pass> CreateAttentionFusePass();
+
+}  // namespace pir
diff --git a/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc
similarity index 98%
rename from paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.cc
rename to paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc
index 8585050e8efbf3..1096a10c85067d 100644
--- a/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.cc
+++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h"
+#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h"
 
 #include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
 #include "paddle/pir/pass/pass.h"
@@ -254,7 +254,7 @@ class FusedLinearReluGradPattern
 
 class FusedGemmEpiloguePass : public pir::Pass {
  public:
-  FusedGemmEpiloguePass() : pir::Pass("FusedGemmEpiloguePass", 1) {}
+  FusedGemmEpiloguePass() : pir::Pass("fused_gemm_epilogue_pass", 1) {}
 
   bool Initialize(pir::IrContext *context) override {
     pir::RewritePatternSet ps(context);
@@ -292,4 +292,4 @@ std::unique_ptr<Pass> CreateFusedGemmEpiloguePass() {
 
 }  // namespace pir
 
-REGISTER_IR_PASS(fused_gemm_epilogue, FusedGemmEpiloguePass);
+REGISTER_IR_PASS(fused_gemm_epilogue_pass, FusedGemmEpiloguePass);
diff --git a/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h
similarity index 100%
rename from paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h
rename to paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h
diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc
index f70fc125689911..760a78c1952ab1 100644
--- a/paddle/fluid/pir/transforms/inplace_pass.cc
+++ b/paddle/fluid/pir/transforms/inplace_pass.cc
@@ -320,11 +320,11 @@ static std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
 
 class InplacePass : public pir::Pass {
  public:
-  InplacePass() : pir::Pass("InplacePass", 3) {}
+  InplacePass() : pir::Pass("inplace_pass", 3) {}
 
   void Run(pir::Operation* op) override {
     auto module_op = op->dyn_cast<pir::ModuleOp>();
-    IR_ENFORCE(module_op, "InplacePass should run on module op.");
+    IR_ENFORCE(module_op, "inplace_pass should run on module op.");
     auto* block = module_op.block();
 
     auto inplace_ops = details::GetInplaceOps(block);
@@ -365,4 +365,4 @@ std::unique_ptr<Pass> CreateInplacePass() {
 
 }  // namespace pir
 
-REGISTER_IR_PASS(inplace, InplacePass);
+REGISTER_IR_PASS(inplace_pass, InplacePass);
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 3e50bd64ca4ac0..9b49b6c340a325 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -65,8 +65,8 @@ using pir::Type;
 using pir::Value;
 using pybind11::return_value_policy;
 
-USE_PASS(dead_code_elimination);
-USE_PASS(inplace);
+USE_PASS(dead_code_elimination_pass);
+USE_PASS(inplace_pass);
 
 PHI_DECLARE_bool(print_ir);
diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization.cc b/paddle/pir/dialect/shape/transforms/shape_optimization.cc
index 767353efdbc5f2..54f43c74cb4154 100644
--- a/paddle/pir/dialect/shape/transforms/shape_optimization.cc
+++ b/paddle/pir/dialect/shape/transforms/shape_optimization.cc
@@ -300,7 +300,7 @@ bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) {
 
 class ShapeOptimizationPass : public pir::Pass {
  public:
-  ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {}
+  ShapeOptimizationPass() : pir::Pass("shape_optimization_pass", 0) {}
 
   void Run(pir::Operation* op) override {
     auto module_op = op->dyn_cast<pir::ModuleOp>();
@@ -328,4 +328,4 @@ std::unique_ptr<Pass> CreateShapeOptimizationPass() {
 
 }  // namespace pir
 
-REGISTER_IR_PASS(shape_optimization, pir::ShapeOptimizationPass);
+REGISTER_IR_PASS(shape_optimization_pass, pir::ShapeOptimizationPass);
diff --git a/paddle/pir/pass/pass_manager.h b/paddle/pir/pass/pass_manager.h
index f606be139c42f2..92faed24f1f5d2 100644
--- a/paddle/pir/pass/pass_manager.h
+++ b/paddle/pir/pass/pass_manager.h
@@ -20,13 +20,13 @@
 #include
 
 #include "paddle/pir/core/program.h"
+#include "paddle/pir/pass/pass.h"
 
 namespace pir {
 
 class IrContext;
 class Operation;
 class Program;
-class Pass;
 class PassInstrumentation;
 class PassInstrumentor;
diff --git a/paddle/pir/transforms/dead_code_elimination_pass.cc b/paddle/pir/transforms/dead_code_elimination_pass.cc
index 6216fca5037e1d..bca3394d1c55d8 100644
--- a/paddle/pir/transforms/dead_code_elimination_pass.cc
+++ b/paddle/pir/transforms/dead_code_elimination_pass.cc
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/pir/transforms/dead_code_elimination_pass.h"
+
 #include "paddle/pir/core/builtin_op.h"
 #include "paddle/pir/core/program.h"
 #include "paddle/pir/pass/pass.h"
@@ -25,11 +26,12 @@ namespace {
 // Now just a naive implementation.
 class DeadCodeEliminationPass : public pir::Pass {
  public:
-  DeadCodeEliminationPass() : pir::Pass("dead_code_elimination", 0) {}
+  DeadCodeEliminationPass() : pir::Pass("dead_code_elimination_pass", 0) {}
 
   void Run(pir::Operation *op) override {
     auto module_op = op->dyn_cast<pir::ModuleOp>();
-    IR_ENFORCE(module_op, "DcePass should run on module op.");
+    IR_ENFORCE(module_op,
+               "dead_code_elimination_pass should run on module op.");
     auto *block = module_op.block();
     std::vector<pir::Operation *> erased_op;
     for (auto &op : *block) {
@@ -76,4 +78,4 @@ std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
 
 }  // namespace pir
 
-REGISTER_IR_PASS(dead_code_elimination, DeadCodeEliminationPass);
+REGISTER_IR_PASS(dead_code_elimination_pass, DeadCodeEliminationPass);
diff --git a/paddle/pir/transforms/reorder_block_ops_pass.cc b/paddle/pir/transforms/reorder_block_ops_pass.cc
index db2d29fe9b0a73..0e25cc5f180ba9 100644
--- a/paddle/pir/transforms/reorder_block_ops_pass.cc
+++ b/paddle/pir/transforms/reorder_block_ops_pass.cc
@@ -24,11 +24,11 @@ namespace {
 
 class ReorderBlockOpsPass : public pir::Pass {
  public:
-  ReorderBlockOpsPass() : pir::Pass("ReorderBlockOpsPass", 0) {}
+  ReorderBlockOpsPass() : pir::Pass("reorder_block_ops_pass", 0) {}
 
   void Run(pir::Operation *op) override {
     IR_ENFORCE(op->num_regions() > 0,
-               "ReorderBlockOpsPass should run on Operation which regions "
+               "reorder_block_ops_pass should run on Operation which regions "
                "number greater than 0.");
     for (size_t i = 0; i < op->num_regions(); ++i) {
       for (auto *block : op->region(i)) {
diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt
index 6f92036ecd6944..3282fe5893abba 100644
--- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt
+++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(PATTERN_REWRITE_TEST_DEPS
-    _constant_folding_pass transform_general_functions gtest pd_op_dialect pir)
+    pd_constant_folding_pass transform_general_functions gtest pd_op_dialect
+    pir)
 
 if(WITH_DISTRIBUTE)
   set(PATTERN_REWRITE_TEST_DEPS ${PATTERN_REWRITE_TEST_DEPS} fleet_executor
@@ -19,20 +20,21 @@ cc_test_old(
   pd_op_dialect
   pir)
 
 cc_test_old(
-  drr_fuse_linear_test
+  drr_same_type_binding_test
  SRCS
-  drr_fuse_linear_test.cc
+  drr_same_type_binding_test.cc
  DEPS
-  fused_gemm_epilogue_pass
  drr
  gtest
  pd_op_dialect
  pir)
+
 cc_test_old(
-  drr_same_type_binding_test
+  drr_fuse_linear_test
  SRCS
-  drr_same_type_binding_test.cc
+  drr_fuse_linear_test.cc
  DEPS
+  fusion_passes
  drr
  gtest
  pd_op_dialect
@@ -42,6 +44,7 @@ cc_test_old(
   SRCS
   drr_attention_fuse_test.cc
   DEPS
+  fusion_passes
   drr
   gtest
   pd_op_dialect
diff --git a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc
index 22252e52beb394..8ac00044146f5b 100644
--- a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc
+++ b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc
@@ -12,247 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include
 #include
 #include
-#include
 #include
 
 #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
-#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
+#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"
 #include "paddle/pir/core/builtin_dialect.h"
-#include "paddle/pir/pass/pass.h"
 #include "paddle/pir/pass/pass_manager.h"
-#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
-
-class MultiHeadMatmulFusePattern
-    : public pir::drr::DrrPatternBase<MultiHeadMatmulFusePattern> {
- public:
-  void operator()(pir::drr::DrrPatternContext *ctx) const override {
-    //
-    // Source Pattern.
-    //
-    pir::drr::SourcePattern src = ctx->SourcePattern();
-    // The first path to matmul with scale (q).
-    const auto &matmul_1 =
-        src.Op("pd_op.matmul",
-               {{"transpose_x", src.Attr("matmul_1_transpose_x")},
-                {"transpose_y", src.Attr("matmul_1_transpose_y")}});
-    src.Tensor("matmul_1_out") =
-        matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2"));
-    const auto &add_1 = src.Op("pd_op.add");
-    src.Tensor("add_1_out") =
-        add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2"));
-    const auto &full_int_array_1 =
-        src.Op("pd_op.full_int_array",
-               {{"value", src.Attr("full_int_array_1_value")}});
-    const auto &reshape_1 = src.Op("pd_op.reshape");
-    reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()},
-              {&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")});
-    const auto &transpose_1 = src.Op("pd_op.transpose");
-    src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out"));
-    const auto &full_1 =
-        src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}});
-    const auto &scale = src.Op("pd_op.scale");
-    src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1());
-
-    // The second path to matmul (k).
-    const auto &matmul_2 =
-        src.Op("pd_op.matmul",
-               {{"transpose_x", src.Attr("matmul_2_transpose_x")},
-                {"transpose_y", src.Attr("matmul_2_transpose_y")}});
-    src.Tensor("matmul_2_out") =
-        matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2"));
-    const auto &add_2 = src.Op("pd_op.add");
-    src.Tensor("add_2_out") =
-        add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2"));
-    const auto &full_int_array_2 = src.Op("pd_op.full_int_array");
-    const auto &reshape_2 = src.Op("pd_op.reshape");
-    reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()},
-              {&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")});
-    const auto &transpose_2 = src.Op("pd_op.transpose");
-    src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out"));
-
-    // The third path to matmul (v).
-    const auto &matmul_3 =
-        src.Op("pd_op.matmul",
-               {{"transpose_x", src.Attr("matmul_3_transpose_x")},
-                {"transpose_y", src.Attr("matmul_3_transpose_y")}});
-    src.Tensor("matmul_3_out") =
-        matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2"));
-    const auto &add_3 = src.Op("pd_op.add");
-    src.Tensor("add_3_out") =
-        add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2"));
-    const auto &full_int_array_3 = src.Op("pd_op.full_int_array");
-    const auto &reshape_3 = src.Op("pd_op.reshape");
-    reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()},
-              {&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")});
-    const auto &transpose_3 = src.Op("pd_op.transpose");
-    src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out"));
-
-    // softmax(qk)v
-    const auto &matmul_4 =
-        src.Op("pd_op.matmul",
-               {{"transpose_x", src.Attr("matmul_4_transpose_x")},
-                {"transpose_y", src.Attr("matmul_4_transpose_y")}});
-    src.Tensor("matmul_4_out") =
-        matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out"));
-    const auto &add_4 = src.Op("pd_op.add");
-    src.Tensor("add_4_out") =
-        add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2"));
-    const auto &softmax =
-        src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
-    src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out"));
-    const auto &matmul_5 =
-        src.Op("pd_op.matmul",
-               {{"transpose_x", src.Attr("matmul_5_transpose_x")},
-                {"transpose_y", src.Attr("matmul_5_transpose_y")}});
-    src.Tensor("matmul_5_out") =
-        matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out"));
-    const auto &transpose_4 = src.Op("pd_op.transpose");
-    src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out"));
-    const auto &full_int_array_4 = src.Op("pd_op.full_int_array");
-    const auto &reshape_4 = src.Op("pd_op.reshape");
-    reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()},
-              {&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")});
-
-    //
-    // Constraints.
-    //
-    src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
-      const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis");
-      if (softmax_axis != -1 && softmax_axis != 3) return false;
-
-      bool matmul_1_transpose_x = match_ctx.Attr<bool>("matmul_1_transpose_x");
-      bool matmul_1_transpose_y = match_ctx.Attr<bool>("matmul_1_transpose_y");
-      if (matmul_1_transpose_x || matmul_1_transpose_y) return false;
-
-      bool matmul_2_transpose_x = match_ctx.Attr<bool>("matmul_2_transpose_x");
-      bool matmul_2_transpose_y = match_ctx.Attr<bool>("matmul_2_transpose_y");
-      if (matmul_2_transpose_x || matmul_2_transpose_y) return false;
-
-      bool matmul_3_transpose_x = match_ctx.Attr<bool>("matmul_3_transpose_x");
-      bool matmul_3_transpose_y = match_ctx.Attr<bool>("matmul_3_transpose_y");
-      if (matmul_3_transpose_x || matmul_3_transpose_y) return false;
-
-      bool matmul_4_transpose_x = match_ctx.Attr<bool>("matmul_4_transpose_x");
-      bool matmul_4_transpose_y = match_ctx.Attr<bool>("matmul_4_transpose_y");
-      if (matmul_4_transpose_x || !matmul_4_transpose_y) return false;
-
-      bool matmul_5_transpose_x = match_ctx.Attr<bool>("matmul_5_transpose_x");
-      bool matmul_5_transpose_y = match_ctx.Attr<bool>("matmul_5_transpose_y");
-      if (matmul_5_transpose_x || matmul_5_transpose_y) return false;
-
-      return true;
-    });
-
-    //
-    // Result Pattern.
-    //
-    pir::drr::ResultPattern res = src.ResultPattern();
-    // W combine.
-    const auto &combine_1 = res.Op("builtin.combine");
-    combine_1({&res.Tensor("matmul_1_in_2"),
-               &res.Tensor("matmul_2_in_2"),
-               &res.Tensor("matmul_3_in_2")},
-              {&res.Tensor("combine_1_out")});
-    const auto &concat_axis = res.Attr(
-        [](const pir::drr::MatchContext &match_ctx) -> int { return 0; });
-    const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}});
-    res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out"));
-    const auto &reshape_5_shape = res.Attr(
-        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
-          auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape();
-          return {-1, 3, matmul_1_in_2.at(1)};
-        });
-    const auto &reshape_5 =
-        res.Op("pd_op.reshape", {{"shape", reshape_5_shape}});
-    reshape_5({&res.Tensor("concat_1_out")},
-              {&res.Tensor("reshape_5_out"), &res.NoneTensor()});
-
-    // Bias combine.
-    const auto &combine_2 = res.Op("builtin.combine");
-    combine_2({&res.Tensor("add_1_in_2"),
-               &res.Tensor("add_2_in_2"),
-               &res.Tensor("add_3_in_2")},
-              {&res.Tensor("combine_2_out")});
-    const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}});
-    res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out"));
-    const auto &reshape_6_shape = res.Attr(
-        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
-          return {3, -1};
-        });
-    const auto &reshape_6 =
-        res.Op("pd_op.reshape", {{"shape", reshape_6_shape}});
-    reshape_6({&res.Tensor("concat_2_out")},
-              {&res.Tensor("reshape_6_out"), &res.NoneTensor()});
-
-    const auto &head_number =
-        res.Attr([](const pir::drr::MatchContext &match_ctx) -> int {
-          const auto &full_int_array_1_value =
-              match_ctx.Attr<std::vector<int64_t>>("full_int_array_1_value");
-          return full_int_array_1_value.at(2);
-        });
-    const auto &alpha =
-        res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
-          return match_ctx.Attr<float>("full_1_value");
-        });
-    const auto &multihead_matmul = res.Op(
-        "pd_op.multihead_matmul",
-        {{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) {
-            return false;
-          })},
-         {"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) {
-            return true;
-          })},
-         {"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) {
-            return false;
-          })},
-         {"head_number", head_number},
-         {"alpha", alpha}});
-    multihead_matmul({&res.Tensor("matmul_1_in_1"),
-                      &res.Tensor("reshape_5_out"),
-                      &res.Tensor("reshape_6_out"),
-                      &res.Tensor("add_4_in_2")},
-                     {&res.Tensor("reshape_4_out")});
-  }
-};
-
-class AttentionFusePass : public pir::Pass {
- public:
-  AttentionFusePass() : pir::Pass("AttentionFusePass", 1) {}
-
-  bool Initialize(pir::IrContext *context) override {
-    pir::RewritePatternSet ps(context);
-    ps.Add(MultiHeadMatmulFusePattern().Build(context));
-    // Add other attention variant fuse pattern.
-
-    patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
-    return true;
-  }
-
-  void Run(pir::Operation *op) override {
-    pir::GreedyRewriteConfig cfg;
-    cfg.use_top_down_traversal = true;
-    cfg.max_iterations = 10;
-    pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
-  }
-
-  bool CanApplyOn(pir::Operation *op) const override {
-    return op->name() == "builtin.module" && op->num_regions() > 0;
-  }
-
- private:
-  pir::FrozenRewritePatternSet patterns_;
-};
-
-namespace pir {
-std::unique_ptr<Pass> CreateAttentionFusePass() {
-  return std::make_unique<AttentionFusePass>();
-}
-}  // namespace pir
 
 void BuildProgram(pir::Builder &builder) {  // NOLINT
   paddle::dialect::FullOp matmul_1_in_1 =
diff --git a/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc
index bb2e091043d0b6..3ef77cd1f96652 100644
--- a/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc
+++ b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc
@@ -18,9 +18,8 @@
 
 #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
-#include "paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h"
+#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h"
 #include "paddle/pir/core/builtin_dialect.h"
-#include "paddle/pir/pass/pass.h"
 #include "paddle/pir/pass/pass_manager.h"
 #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
 
diff --git a/test/ir/new_ir/test_pass_manager.py b/test/ir/new_ir/test_pass_manager.py
index 5849b0bbdfeffa..44689f485af818 100644
--- a/test/ir/new_ir/test_pass_manager.py
+++ b/test/ir/new_ir/test_pass_manager.py
@@ -51,12 +51,12 @@ def test_op(self):
         self.assertTrue('pd_op.uniform' in op_names)
         pm = pir.PassManager()
         pm.add_pass(
-            'dead_code_elimination'
+            'dead_code_elimination_pass'
         )  # apply pass to elimitate dead code
         pm.run(new_program)
         op_names = [op.name() for op in new_program.global_block().ops]
         # print(op_names)
-        self.assertEqual(pm.passes(), ['dead_code_elimination'])
+        self.assertEqual(pm.passes(), ['dead_code_elimination_pass'])
         self.assertFalse(pm.empty())
         self.assertTrue(
             'pd_op.uniform' not in op_names
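
Reviewer note (not part of the diff): a minimal C++ sketch of how the renamed passes are exercised after this change. It assumes only the APIs already visible in the updated tests (pir::PassManager with AddPass/Run, the Create*Pass factories, dialect registration via IrContext); the driver function name and program setup are illustrative, not part of the PR.

  // Hypothetical driver, mirroring what drr_attention_fuse_test.cc does after this PR.
  #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
  #include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"
  #include "paddle/pir/core/ir_context.h"
  #include "paddle/pir/core/program.h"
  #include "paddle/pir/pass/pass_manager.h"
  #include "paddle/pir/transforms/dead_code_elimination_pass.h"

  void RunAttentionFusion(pir::Program *program) {
    pir::IrContext *ctx = pir::IrContext::Instance();
    // The paddle operator dialect must be registered so the DRR patterns
    // can match pd_op.* operations in the program.
    ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();

    pir::PassManager pm(ctx);
    // Passes are created through their factory functions and are now
    // registered under snake_case names (see REGISTER_IR_PASS above).
    pm.AddPass(pir::CreateAttentionFusePass());
    pm.AddPass(pir::CreateDeadCodeEliminationPass());
    pm.Run(program);
  }

On the Python side the same snake_case names apply, e.g. pm.add_pass('dead_code_elimination_pass') as updated in test/ir/new_ir/test_pass_manager.py.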