[DRR] add attention_fuse_pass #58205

Merged · 17 commits merged on Oct 20, 2023
Changes from all commits (17 commits):
e41ee6d
fix cudnn 8.7+ bug on cudnnConvolutionBiasActivationForward
yuanlehome Jul 13, 2023
9ecefe9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Jul 14, 2023
d092636
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Jul 17, 2023
b74a6df
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Jul 24, 2023
9ffdf8e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Aug 1, 2023
2ebb74e
Merge branch 'develop' of https://github.com/yuanlehome/Paddle into d…
yuanlehome Aug 1, 2023
8b63d93
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Aug 7, 2023
92a6c19
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Aug 24, 2023
48fa89a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Sep 15, 2023
bf17cef
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Sep 15, 2023
7869b8c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Sep 21, 2023
40bdd66
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Oct 16, 2023
6b2a3bf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Oct 18, 2023
61a7645
[pir] add attention_fuse_pass
yuanlehome Oct 18, 2023
28fc296
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
yuanlehome Oct 19, 2023
15fa7ae
update
yuanlehome Oct 19, 2023
e47bba3
update
yuanlehome Oct 19, 2023
14 changes: 8 additions & 6 deletions paddle/fluid/pir/transforms/CMakeLists.txt
@@ -1,3 +1,10 @@
file(GLOB FUSION_PASS_SRCS "fusion/*.cc")

cc_library(
fusion_passes
SRCS ${FUSION_PASS_SRCS}
DEPS drr)

cc_library(
transform_general_functions
SRCS transform_general_functions.cc
@@ -9,15 +16,10 @@ cc_library(
DEPS pd_kernel_dialect pd_op_dialect pd_op_dialect_utils)

cc_library(
_constant_folding_pass
pd_constant_folding_pass
SRCS constant_folding_pass.cc
DEPS standalone_executor pd_op_to_kernel_pass transform_general_functions)

cc_library(
fused_gemm_epilogue_pass
SRCS fused_gemm_epilogue_pass.cc
DEPS drr)

cc_library(
pd_inplace_pass
SRCS inplace_pass.cc
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/build_cinn_pass.cc
@@ -575,11 +575,11 @@ void ReplaceWithGroupOp(pir::Block* block,

class BuildCinnPass : public pir::Pass {
public:
BuildCinnPass() : pir::Pass("BuildCinnPass", /*opt_level=*/1) {}
BuildCinnPass() : pir::Pass("build_cinn_pass", /*opt_level=*/1) {}

void Run(pir::Operation* op) override {
auto module_op = op->dyn_cast<pir::ModuleOp>();
IR_ENFORCE(module_op, "InplacePass should run on module op.");
IR_ENFORCE(module_op, "build_cinn_pass should run on module op.");
auto* block = module_op.block();

std::vector<GroupOpsVec> groups =
3 changes: 1 addition & 2 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -192,8 +192,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {

class ConstantFoldingPass : public pir::Pass {
public:
// TODO(liuyuanle): Naming convention for pass.
ConstantFoldingPass() : pir::Pass("ConstantFoldingPass", 1) {}
ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}

bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
253 changes: 253 additions & 0 deletions paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc
@@ -0,0 +1,253 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"

#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

class MultiHeadMatmulFusePattern
: public pir::drr::DrrPatternBase<MultiHeadMatmulFusePattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
//
// Source Pattern.
//
pir::drr::SourcePattern src = ctx->SourcePattern();
// The first path to matmul with scale (q).
const auto &matmul_1 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_1_transpose_x")},
{"transpose_y", src.Attr("matmul_1_transpose_y")}});
src.Tensor("matmul_1_out") =
matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2"));
const auto &add_1 = src.Op("pd_op.add");
src.Tensor("add_1_out") =
add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2"));
const auto &full_int_array_1 =
src.Op("pd_op.full_int_array",
{{"value", src.Attr("full_int_array_1_value")}});
const auto &reshape_1 = src.Op("pd_op.reshape");
reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()},
{&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")});
const auto &transpose_1 = src.Op("pd_op.transpose");
src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out"));
const auto &full_1 =
src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}});
const auto &scale = src.Op("pd_op.scale");
src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1());

// The second path to matmul (k).
const auto &matmul_2 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_2_transpose_x")},
{"transpose_y", src.Attr("matmul_2_transpose_y")}});
src.Tensor("matmul_2_out") =
matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2"));
const auto &add_2 = src.Op("pd_op.add");
src.Tensor("add_2_out") =
add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2"));
const auto &full_int_array_2 = src.Op("pd_op.full_int_array");
const auto &reshape_2 = src.Op("pd_op.reshape");
reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()},
{&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")});
const auto &transpose_2 = src.Op("pd_op.transpose");
src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out"));

// The third path to matmul (v).
const auto &matmul_3 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_3_transpose_x")},
{"transpose_y", src.Attr("matmul_3_transpose_y")}});
src.Tensor("matmul_3_out") =
matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2"));
const auto &add_3 = src.Op("pd_op.add");
src.Tensor("add_3_out") =
add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2"));
const auto &full_int_array_3 = src.Op("pd_op.full_int_array");
const auto &reshape_3 = src.Op("pd_op.reshape");
reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()},
{&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")});
const auto &transpose_3 = src.Op("pd_op.transpose");
src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out"));

// softmax(qk)v
const auto &matmul_4 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_4_transpose_x")},
{"transpose_y", src.Attr("matmul_4_transpose_y")}});
src.Tensor("matmul_4_out") =
matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out"));
const auto &add_4 = src.Op("pd_op.add");
src.Tensor("add_4_out") =
add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2"));
const auto &softmax =
src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out"));
const auto &matmul_5 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_5_transpose_x")},
{"transpose_y", src.Attr("matmul_5_transpose_y")}});
src.Tensor("matmul_5_out") =
matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out"));
const auto &transpose_4 = src.Op("pd_op.transpose");
src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out"));
const auto &full_int_array_4 = src.Op("pd_op.full_int_array");
const auto &reshape_4 = src.Op("pd_op.reshape");
reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()},
{&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")});

//
// Constraints.
//
src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis");
if (softmax_axis != -1 && softmax_axis != 3) return false;

bool matmul_1_transpose_x = match_ctx.Attr<bool>("matmul_1_transpose_x");
bool matmul_1_transpose_y = match_ctx.Attr<bool>("matmul_1_transpose_y");
if (matmul_1_transpose_x || matmul_1_transpose_y) return false;

bool matmul_2_transpose_x = match_ctx.Attr<bool>("matmul_2_transpose_x");
bool matmul_2_transpose_y = match_ctx.Attr<bool>("matmul_2_transpose_y");
if (matmul_2_transpose_x || matmul_2_transpose_y) return false;

bool matmul_3_transpose_x = match_ctx.Attr<bool>("matmul_3_transpose_x");
bool matmul_3_transpose_y = match_ctx.Attr<bool>("matmul_3_transpose_y");
if (matmul_3_transpose_x || matmul_3_transpose_y) return false;

bool matmul_4_transpose_x = match_ctx.Attr<bool>("matmul_4_transpose_x");
bool matmul_4_transpose_y = match_ctx.Attr<bool>("matmul_4_transpose_y");
if (matmul_4_transpose_x || !matmul_4_transpose_y) return false;

bool matmul_5_transpose_x = match_ctx.Attr<bool>("matmul_5_transpose_x");
bool matmul_5_transpose_y = match_ctx.Attr<bool>("matmul_5_transpose_y");
if (matmul_5_transpose_x || matmul_5_transpose_y) return false;

return true;
});

//
// Result Pattern.
//
pir::drr::ResultPattern res = src.ResultPattern();
// W combine.
const auto &combine_1 = res.Op("builtin.combine");
combine_1({&res.Tensor("matmul_1_in_2"),
&res.Tensor("matmul_2_in_2"),
&res.Tensor("matmul_3_in_2")},
{&res.Tensor("combine_1_out")});
const auto &concat_axis = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> int { return 0; });
const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}});
res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out"));
const auto &reshape_5_shape = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape();
return {-1, 3, matmul_1_in_2.at(1)};
});
const auto &reshape_5 =
res.Op("pd_op.reshape", {{"shape", reshape_5_shape}});
reshape_5({&res.Tensor("concat_1_out")},
{&res.Tensor("reshape_5_out"), &res.NoneTensor()});

// Bias combine.
const auto &combine_2 = res.Op("builtin.combine");
combine_2({&res.Tensor("add_1_in_2"),
&res.Tensor("add_2_in_2"),
&res.Tensor("add_3_in_2")},
{&res.Tensor("combine_2_out")});
const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}});
res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out"));
const auto &reshape_6_shape = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
return {3, -1};
});
const auto &reshape_6 =
res.Op("pd_op.reshape", {{"shape", reshape_6_shape}});
reshape_6({&res.Tensor("concat_2_out")},
{&res.Tensor("reshape_6_out"), &res.NoneTensor()});

const auto &head_number =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> int {
const auto &full_int_array_1_value =
match_ctx.Attr<std::vector<int64_t>>("full_int_array_1_value");
return full_int_array_1_value.at(2);
});
const auto &alpha =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return match_ctx.Attr<float>("full_1_value");
});
const auto &multihead_matmul = res.Op(
"pd_op.multihead_matmul",
{{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) {
return false;
})},
{"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) {
return true;
})},
{"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) {
return false;
})},
{"head_number", head_number},
{"alpha", alpha}});
multihead_matmul({&res.Tensor("matmul_1_in_1"),
&res.Tensor("reshape_5_out"),
&res.Tensor("reshape_6_out"),
&res.Tensor("add_4_in_2")},
{&res.Tensor("reshape_4_out")});
}
};

class AttentionFusePass : public pir::Pass {
public:
AttentionFusePass() : pir::Pass("attention_fuse_pass", 2) {}

bool Initialize(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(MultiHeadMatmulFusePattern().Build(context));
// Add other attention variant fuse pattern.

patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}

void Run(pir::Operation *op) override {
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}

bool CanApplyOn(pir::Operation *op) const override {
return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
}

private:
pir::FrozenRewritePatternSet patterns_;
};

} // namespace

namespace pir {
std::unique_ptr<Pass> CreateAttentionFusePass() {
return std::make_unique<AttentionFusePass>();
}
} // namespace pir

REGISTER_IR_PASS(attention_fuse_pass, AttentionFusePass);
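
For context, the snippet below is a usage sketch and is not part of this PR's diff: it shows how the newly registered attention_fuse_pass might be applied through the PIR pass manager. The include paths and the pir::PassManager, pir::Program, and pir::IrContext calls reflect the usual PIR API and are assumptions here; ApplyAttentionFuse is a hypothetical helper.

// Hypothetical usage sketch (not part of this diff); assumes the standard
// pir::PassManager / pir::Program / pir::IrContext APIs.
#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass_manager.h"

void ApplyAttentionFuse(pir::Program *program) {
  pir::IrContext *ctx = pir::IrContext::Instance();
  // opt_level 2 matches the level attention_fuse_pass is constructed with above.
  pir::PassManager pm(ctx, /*opt_level=*/2);
  pm.AddPass(pir::CreateAttentionFusePass());
  // Rewrites matching multi-head attention subgraphs into pd_op.multihead_matmul.
  pm.Run(program);
}
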
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateAttentionFusePass();

} // namespace pir
paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.cc → paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h"
#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h"

#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/pir/pass/pass.h"
@@ -254,7 +254,7 @@ class FusedLinearReluGradPattern

class FusedGemmEpiloguePass : public pir::Pass {
public:
FusedGemmEpiloguePass() : pir::Pass("FusedGemmEpiloguePass", 1) {}
FusedGemmEpiloguePass() : pir::Pass("fused_gemm_epilogue_pass", 1) {}

bool Initialize(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
@@ -292,4 +292,4 @@ std::unique_ptr<Pass> CreateFusedGemmEpiloguePass() {

} // namespace pir

REGISTER_IR_PASS(fused_gemm_epilogue, FusedGemmEpiloguePass);
REGISTER_IR_PASS(fused_gemm_epilogue_pass, FusedGemmEpiloguePass);
6 changes: 3 additions & 3 deletions paddle/fluid/pir/transforms/inplace_pass.cc
@@ -320,11 +320,11 @@ static std::unordered_map<pir::Operation*, std::string> GetInplaceOps(

class InplacePass : public pir::Pass {
public:
InplacePass() : pir::Pass("InplacePass", 3) {}
InplacePass() : pir::Pass("inplace_pass", 3) {}

void Run(pir::Operation* op) override {
auto module_op = op->dyn_cast<pir::ModuleOp>();
IR_ENFORCE(module_op, "InplacePass should run on module op.");
IR_ENFORCE(module_op, "inplace_pass should run on module op.");
auto* block = module_op.block();

auto inplace_ops = details::GetInplaceOps(block);
Expand Down Expand Up @@ -365,4 +365,4 @@ std::unique_ptr<pir::Pass> CreateInplacePass() {

} // namespace pir

REGISTER_IR_PASS(inplace, InplacePass);
REGISTER_IR_PASS(inplace_pass, InplacePass);
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/pir.cc
@@ -65,8 +65,8 @@ using pir::Type;
using pir::Value;
using pybind11::return_value_policy;

USE_PASS(dead_code_elimination);
USE_PASS(inplace);
USE_PASS(dead_code_elimination_pass);
USE_PASS(inplace_pass);

PHI_DECLARE_bool(print_ir);
