[DRR] add attention_fuse_pass (PaddlePaddle#58205)
* fix cudnn 8.7+ bug on cudnnConvolutionBiasActivationForward

* [pir] add attention_fuse_pass

* update

* update
yuanlehome authored Oct 20, 2023
1 parent 1877039 commit ec5660d
Showing 17 changed files with 321 additions and 268 deletions.
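The headline change is a DRR (Declarative Rewrite Rule) pass that matches the canonical multi-head attention subgraph — three Q/K/V projection branches (matmul + add + reshape + transpose), a scaled Q·Kᵀ, a mask add, softmax, and a final matmul with V — and collapses it into a single pd_op.multihead_matmul op. Reconstructed from the source pattern in attention_fuse_pass.cc below (the notation is mine, not from the commit):

    \mathrm{Attn}(X) = \mathrm{reshape}\big(\mathrm{transpose}\big(\mathrm{softmax}(\alpha\, Q K^{\top} + B)\, V\big)\big),
    \quad Q = X W_q + b_q,\; K = X W_k + b_k,\; V = X W_v + b_v

Here alpha is the scale captured from pd_op.full ("full_1_value") and B is the attention mask ("add_4_in_2"); the constraint block additionally requires softmax over the last axis and transpose_y = true on the Q·Kᵀ matmul.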
14 changes: 8 additions & 6 deletions paddle/fluid/pir/transforms/CMakeLists.txt
@@ -1,3 +1,10 @@
+file(GLOB FUSION_PASS_SRCS "fusion/*.cc")
+
+cc_library(
+  fusion_passes
+  SRCS ${FUSION_PASS_SRCS}
+  DEPS drr)
+
 cc_library(
   transform_general_functions
   SRCS transform_general_functions.cc
@@ -9,15 +16,10 @@ cc_library(
   DEPS pd_kernel_dialect pd_op_dialect pd_op_dialect_utils)
 
 cc_library(
-  _constant_folding_pass
+  pd_constant_folding_pass
   SRCS constant_folding_pass.cc
   DEPS standalone_executor pd_op_to_kernel_pass transform_general_functions)
 
-cc_library(
-  fused_gemm_epilogue_pass
-  SRCS fused_gemm_epilogue_pass.cc
-  DEPS drr)
-
 cc_library(
   pd_inplace_pass
   SRCS inplace_pass.cc
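Because fusion sources are now collected with file(GLOB ...), a later fusion pass dropped into transforms/fusion/ builds into fusion_passes with no CMakeLists edit — though globs are evaluated at configure time, so CMake has to re-run to notice a new file. A sketch of the equivalent explicit listing, using only the two fusion sources present in this commit:

    cc_library(
      fusion_passes
      SRCS fusion/attention_fuse_pass.cc fusion/fused_gemm_epilogue_pass.cc
      DEPS drr)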
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/build_cinn_pass.cc
@@ -574,11 +574,11 @@ void ReplaceWithGroupOp(pir::Block* block,
 
 class BuildCinnPass : public pir::Pass {
  public:
-  BuildCinnPass() : pir::Pass("BuildCinnPass", /*opt_level=*/1) {}
+  BuildCinnPass() : pir::Pass("build_cinn_pass", /*opt_level=*/1) {}
 
   void Run(pir::Operation* op) override {
     auto module_op = op->dyn_cast<pir::ModuleOp>();
-    IR_ENFORCE(module_op, "InplacePass should run on module op.");
+    IR_ENFORCE(module_op, "build_cinn_pass should run on module op.");
     auto* block = module_op.block();
 
     std::vector<GroupOpsVec> groups =
3 changes: 1 addition & 2 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -192,8 +192,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
 
 class ConstantFoldingPass : public pir::Pass {
  public:
-  // TODO(liuyuanle): Naming convention for pass.
-  ConstantFoldingPass() : pir::Pass("ConstantFoldingPass", 1) {}
+  ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
 
   bool Initialize(pir::IrContext* context) override {
     pir::RewritePatternSet ps(context);
253 changes: 253 additions & 0 deletions paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc
@@ -0,0 +1,253 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"

#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

class MultiHeadMatmulFusePattern
: public pir::drr::DrrPatternBase<MultiHeadMatmulFusePattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
//
// Source Pattern.
//
pir::drr::SourcePattern src = ctx->SourcePattern();
// The first branch: the Q path (matmul + add + reshape + transpose, then scale).
const auto &matmul_1 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_1_transpose_x")},
{"transpose_y", src.Attr("matmul_1_transpose_y")}});
src.Tensor("matmul_1_out") =
matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2"));
const auto &add_1 = src.Op("pd_op.add");
src.Tensor("add_1_out") =
add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2"));
const auto &full_int_array_1 =
src.Op("pd_op.full_int_array",
{{"value", src.Attr("full_int_array_1_value")}});
const auto &reshape_1 = src.Op("pd_op.reshape");
reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()},
{&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")});
const auto &transpose_1 = src.Op("pd_op.transpose");
src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out"));
const auto &full_1 =
src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}});
const auto &scale = src.Op("pd_op.scale");
src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1());

// The second branch: the K path.
const auto &matmul_2 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_2_transpose_x")},
{"transpose_y", src.Attr("matmul_2_transpose_y")}});
src.Tensor("matmul_2_out") =
matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2"));
const auto &add_2 = src.Op("pd_op.add");
src.Tensor("add_2_out") =
add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2"));
const auto &full_int_array_2 = src.Op("pd_op.full_int_array");
const auto &reshape_2 = src.Op("pd_op.reshape");
reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()},
{&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")});
const auto &transpose_2 = src.Op("pd_op.transpose");
src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out"));

// The third branch: the V path.
const auto &matmul_3 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_3_transpose_x")},
{"transpose_y", src.Attr("matmul_3_transpose_y")}});
src.Tensor("matmul_3_out") =
matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2"));
const auto &add_3 = src.Op("pd_op.add");
src.Tensor("add_3_out") =
add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2"));
const auto &full_int_array_3 = src.Op("pd_op.full_int_array");
const auto &reshape_3 = src.Op("pd_op.reshape");
reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()},
{&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")});
const auto &transpose_3 = src.Op("pd_op.transpose");
src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out"));

// Attention core: softmax(Q_scaled * K^T + mask) * V.
const auto &matmul_4 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_4_transpose_x")},
{"transpose_y", src.Attr("matmul_4_transpose_y")}});
src.Tensor("matmul_4_out") =
matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out"));
const auto &add_4 = src.Op("pd_op.add");
src.Tensor("add_4_out") =
add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2"));
const auto &softmax =
src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}});
src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out"));
const auto &matmul_5 =
src.Op("pd_op.matmul",
{{"transpose_x", src.Attr("matmul_5_transpose_x")},
{"transpose_y", src.Attr("matmul_5_transpose_y")}});
src.Tensor("matmul_5_out") =
matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out"));
const auto &transpose_4 = src.Op("pd_op.transpose");
src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out"));
const auto &full_int_array_4 = src.Op("pd_op.full_int_array");
const auto &reshape_4 = src.Op("pd_op.reshape");
reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()},
{&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")});

//
// Constraints.
//
src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool {
const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis");
if (softmax_axis != -1 && softmax_axis != 3) return false;

bool matmul_1_transpose_x = match_ctx.Attr<bool>("matmul_1_transpose_x");
bool matmul_1_transpose_y = match_ctx.Attr<bool>("matmul_1_transpose_y");
if (matmul_1_transpose_x || matmul_1_transpose_y) return false;

bool matmul_2_transpose_x = match_ctx.Attr<bool>("matmul_2_transpose_x");
bool matmul_2_transpose_y = match_ctx.Attr<bool>("matmul_2_transpose_y");
if (matmul_2_transpose_x || matmul_2_transpose_y) return false;

bool matmul_3_transpose_x = match_ctx.Attr<bool>("matmul_3_transpose_x");
bool matmul_3_transpose_y = match_ctx.Attr<bool>("matmul_3_transpose_y");
if (matmul_3_transpose_x || matmul_3_transpose_y) return false;

bool matmul_4_transpose_x = match_ctx.Attr<bool>("matmul_4_transpose_x");
bool matmul_4_transpose_y = match_ctx.Attr<bool>("matmul_4_transpose_y");
if (matmul_4_transpose_x || !matmul_4_transpose_y) return false;

bool matmul_5_transpose_x = match_ctx.Attr<bool>("matmul_5_transpose_x");
bool matmul_5_transpose_y = match_ctx.Attr<bool>("matmul_5_transpose_y");
if (matmul_5_transpose_x || matmul_5_transpose_y) return false;

return true;
});

//
// Result Pattern.
//
pir::drr::ResultPattern res = src.ResultPattern();
// Combine the three projection weights into a single QKV weight.
const auto &combine_1 = res.Op("builtin.combine");
combine_1({&res.Tensor("matmul_1_in_2"),
&res.Tensor("matmul_2_in_2"),
&res.Tensor("matmul_3_in_2")},
{&res.Tensor("combine_1_out")});
const auto &concat_axis = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> int { return 0; });
const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}});
res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out"));
const auto &reshape_5_shape = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape();
return {-1, 3, matmul_1_in_2.at(1)};
});
const auto &reshape_5 =
res.Op("pd_op.reshape", {{"shape", reshape_5_shape}});
reshape_5({&res.Tensor("concat_1_out")},
{&res.Tensor("reshape_5_out"), &res.NoneTensor()});

// Combine the three projection biases into a single QKV bias.
const auto &combine_2 = res.Op("builtin.combine");
combine_2({&res.Tensor("add_1_in_2"),
&res.Tensor("add_2_in_2"),
&res.Tensor("add_3_in_2")},
{&res.Tensor("combine_2_out")});
const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}});
res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out"));
const auto &reshape_6_shape = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> {
return {3, -1};
});
const auto &reshape_6 =
res.Op("pd_op.reshape", {{"shape", reshape_6_shape}});
reshape_6({&res.Tensor("concat_2_out")},
{&res.Tensor("reshape_6_out"), &res.NoneTensor()});

const auto &head_number =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> int {
const auto &full_int_array_1_value =
match_ctx.Attr<std::vector<int64_t>>("full_int_array_1_value");
return full_int_array_1_value.at(2);
});
const auto &alpha =
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return match_ctx.Attr<float>("full_1_value");
});
const auto &multihead_matmul = res.Op(
"pd_op.multihead_matmul",
{{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) {
return false;
})},
{"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) {
return true;
})},
{"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) {
return false;
})},
{"head_number", head_number},
{"alpha", alpha}});
multihead_matmul({&res.Tensor("matmul_1_in_1"),
&res.Tensor("reshape_5_out"),
&res.Tensor("reshape_6_out"),
&res.Tensor("add_4_in_2")},
{&res.Tensor("reshape_4_out")});
}
};

class AttentionFusePass : public pir::Pass {
public:
AttentionFusePass() : pir::Pass("attention_fuse_pass", 2) {}

bool Initialize(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(MultiHeadMatmulFusePattern().Build(context));
// Add other attention-variant fuse patterns here.

patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}

void Run(pir::Operation *op) override {
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
}

bool CanApplyOn(pir::Operation *op) const override {
return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
}

private:
pir::FrozenRewritePatternSet patterns_;
};

} // namespace

namespace pir {
std::unique_ptr<Pass> CreateAttentionFusePass() {
return std::make_unique<AttentionFusePass>();
}
} // namespace pir

REGISTER_IR_PASS(attention_fuse_pass, AttentionFusePass);
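A minimal sketch of how the registered pass might be driven — the pir::PassManager usage and header paths here are assumptions based on the surrounding codebase, not part of this commit:

    #include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h"
    #include "paddle/pir/core/ir_context.h"
    #include "paddle/pir/core/program.h"
    #include "paddle/pir/pass/pass_manager.h"

    // Rewrites every matched attention subgraph in `program` in place.
    void ApplyAttentionFuse(pir::Program *program) {
      pir::IrContext *ctx = pir::IrContext::Instance();
      pir::PassManager pm(ctx);
      pm.AddPass(pir::CreateAttentionFusePass());
      pm.Run(program);
    }

Since CanApplyOn() accepts only a ModuleOp with at least one region, the pass runs once over the top-level module and applies its patterns greedily, top-down, for at most 10 iterations.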
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateAttentionFusePass();

} // namespace pir
paddle/fluid/pir/transforms/{ → fusion}/fused_gemm_epilogue_pass.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h"
+#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h"
 
 #include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
 #include "paddle/pir/pass/pass.h"
@@ -254,7 +254,7 @@ class FusedLinearReluGradPattern
 
 class FusedGemmEpiloguePass : public pir::Pass {
  public:
-  FusedGemmEpiloguePass() : pir::Pass("FusedGemmEpiloguePass", 1) {}
+  FusedGemmEpiloguePass() : pir::Pass("fused_gemm_epilogue_pass", 1) {}
 
   bool Initialize(pir::IrContext *context) override {
     pir::RewritePatternSet ps(context);
@@ -292,4 +292,4 @@ std::unique_ptr<Pass> CreateFusedGemmEpiloguePass() {
 
 }  // namespace pir
 
-REGISTER_IR_PASS(fused_gemm_epilogue, FusedGemmEpiloguePass);
+REGISTER_IR_PASS(fused_gemm_epilogue_pass, FusedGemmEpiloguePass);
6 changes: 3 additions & 3 deletions paddle/fluid/pir/transforms/inplace_pass.cc
@@ -320,11 +320,11 @@ static std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
 
 class InplacePass : public pir::Pass {
  public:
-  InplacePass() : pir::Pass("InplacePass", 3) {}
+  InplacePass() : pir::Pass("inplace_pass", 3) {}
 
   void Run(pir::Operation* op) override {
     auto module_op = op->dyn_cast<pir::ModuleOp>();
-    IR_ENFORCE(module_op, "InplacePass should run on module op.");
+    IR_ENFORCE(module_op, "inplace_pass should run on module op.");
     auto* block = module_op.block();
 
     auto inplace_ops = details::GetInplaceOps(block);
@@ -365,4 +365,4 @@ std::unique_ptr<pir::Pass> CreateInplacePass() {
 
 }  // namespace pir
 
-REGISTER_IR_PASS(inplace, InplacePass);
+REGISTER_IR_PASS(inplace_pass, InplacePass);
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/pir.cc
@@ -65,8 +65,8 @@ using pir::Type;
 using pir::Value;
 using pybind11::return_value_policy;
 
-USE_PASS(dead_code_elimination);
-USE_PASS(inplace);
+USE_PASS(dead_code_elimination_pass);
+USE_PASS(inplace_pass);
 
 PHI_DECLARE_bool(print_ir);
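USE_PASS references the static registration created by REGISTER_IR_PASS through the same string name, which is why the snake_case rename has to touch both macros in lockstep. Exposing the new pass to Python would presumably add one more line in the same style (hypothetical — that hunk is not shown in this commit):

    USE_PASS(attention_fuse_pass);  // must match REGISTER_IR_PASS(attention_fuse_pass, AttentionFusePass)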