
update fused_multi_transformer_encoder_pass support GPT new matmul API (PaddlePaddle#48953)

* fit paddle.matmul in fleetx.gpt
RichardWooSJTU authored Dec 12, 2022
1 parent 293f746 commit 0fdc140
Showing 6 changed files with 193 additions and 52 deletions.
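
Background for the change: the legacy fluid matmul op carried an `alpha` attribute, so GPT attention code could fold the 1/sqrt(head_dim) scaling into the QK matmul and the pass matched a single `matmul` op. The new `paddle.matmul` API now used by fleetx's GPT lowers to `matmul_v2`, which has no `alpha` attribute, so the scaling surfaces as a separate `scale` op, and the fuse passes below are updated to match the `matmul_v2` + `scale` sequence. A minimal sketch of the new lowering (illustrative code, not from this PR; `qk_scores`, `q`, and `k` are hypothetical names):

    import paddle

    def qk_scores(q, k):
        # q, k: [batch, heads, seq_len, head_dim]
        head_dim = q.shape[-1]
        # paddle.matmul lowers to a matmul_v2 op, which has no alpha attribute...
        scores = paddle.matmul(q, k, transpose_y=True)
        # ...so the 1/sqrt(head_dim) factor becomes a separate scale op.
        return paddle.scale(scores, scale=head_dim**-0.5)
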
86 changes: 74 additions & 12 deletions paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
@@ -478,7 +478,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
                                ->assert_is_op_output("split")
                                ->AsIntermediate()
-                               ->assert_is_op_input("matmul", "X");
+                               ->assert_is_op_input("matmul_v2", "X");
   auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
                                ->assert_is_op_output("split")
                                ->AsIntermediate()
@@ -496,7 +496,7 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
                                ->assert_is_op_output("concat")
                                ->AsIntermediate()
-                               ->assert_is_op_input("matmul")
+                               ->assert_is_op_input("matmul_v2")
                                ->assert_is_op_input("assign");
   auto* concat_v_in_var = pattern
                               ->NewNode(concat_v_in_repr())
@@ -529,10 +529,16 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   assign_v->LinksFrom({concat_v_out_var});
 
   // QK path Nodes
-  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk =
+      pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
   auto* matmul_qk_out_var =
-      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
-  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
+  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
+  auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
+  auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
+                               ->assert_is_op_output("scale")
+                               ->AsIntermediate()
+                               ->assert_is_op_input("elementwise_add", "X");
 
   auto* eltadd_qk =
       pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
@@ -554,7 +560,8 @@
   // QK path Links
   matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
       .LinksTo({matmul_qk_out_var});
-  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
+  scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
+  eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
       .LinksTo({eltadd_qk_out_var});
   softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});

@@ -799,7 +806,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* split0_q_out_var = pattern->NewNode(split0_q_out_repr())
                                ->assert_is_op_output("split")
                                ->AsIntermediate()
-                               ->assert_is_op_input("matmul", "X");
+                               ->assert_is_op_input("matmul_v2", "X");
   auto* split0_k_out_var = pattern->NewNode(split0_k_out_repr())
                                ->assert_is_op_output("split")
                                ->AsIntermediate()
@@ -817,7 +824,7 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* concat_k_out_var = pattern->NewNode(concat_k_out_repr())
                                ->assert_is_op_output("concat")
                                ->AsIntermediate()
-                               ->assert_is_op_input("matmul")
+                               ->assert_is_op_input("matmul_v2")
                                ->assert_is_op_input("assign");
   auto* concat_v_in_var = pattern
                               ->NewNode(concat_v_in_repr())
@@ -850,10 +857,16 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   assign_v->LinksFrom({concat_v_out_var});
 
   // QK path Nodes
-  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk =
+      pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul_v2");
   auto* matmul_qk_out_var =
-      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
-  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+      pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul_v2");
+  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("scale");
+  auto* scale_qk = pattern->NewNode(scale_qk_repr())->assert_is_op("scale");
+  auto* scale_qk_out_var = pattern->NewNode(scale_qk_out_repr())
+                               ->assert_is_op_output("scale")
+                               ->AsIntermediate()
+                               ->assert_is_op_input("elementwise_add", "X");
 
   auto* eltadd_qk =
       pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
@@ -875,7 +888,8 @@
   // QK path Links
   matmul_qk->LinksFrom({split0_q_out_var, concat_k_out_var})
       .LinksTo({matmul_qk_out_var});
-  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var})
+  scale_qk->LinksFrom({matmul_qk_out_var}).LinksTo({scale_qk_out_var});
+  eltadd_qk->LinksFrom({scale_qk_out_var, eltadd_qk_b_var})
       .LinksTo({eltadd_qk_out_var});
   softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});

@@ -2192,6 +2206,11 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
   GET_IR_NODE_FROM_SUBGRAPH(
       matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
 
+  GET_IR_NODE_FROM_SUBGRAPH(
+      scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
+  GET_IR_NODE_FROM_SUBGRAPH(
+      scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
+
   GET_IR_NODE_FROM_SUBGRAPH(
       eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
@@ -2296,6 +2315,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                   assign_v,
                   matmul_qk,
                   matmul_qk_out,
+                  scale_qk,
+                  scale_qk_out,
                   eltadd_qk,
                   eltadd_qk_out,
                   softmax_qk,
@@ -2382,6 +2403,23 @@ FusedMultiTransformerDecoderFuseQKVPass::
       .IsNumGT(0)
       .End();
 
+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("scale")
+      .IsType<float>()  // copied to the new op, so unconstrained.
+      .End()
+      .AddAttr("bias")
+      .IsNumEQ(0.f)
+      .End()
+      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained.
+      .IsType<bool>()
+      .End();
+
   AddOpCompat(OpCompat("matmul_v2"))
       .AddInput("X")  // the shape should be (B, S, N*H)
       .IsTensor()
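
A note on the `scale` compat rules above (editorial explanation, not PR text): the scale op computes `out = scale * x + bias` when `bias_after_scale` is true and `out = scale * (x + bias)` otherwise. Requiring `bias == 0` makes the two orderings identical, `out = scale * x`, which is why `bias_after_scale` itself can stay unconstrained and the pass only needs to carry the `scale` value into the fused op. A quick check of that equivalence (illustrative only, with assumed shapes):

    import paddle

    x = paddle.rand([2, 12, 128, 128])
    a = paddle.scale(x, scale=0.125, bias=0.0, bias_after_scale=True)
    b = paddle.scale(x, scale=0.125, bias=0.0, bias_after_scale=False)
    # With bias == 0 both orderings reduce to 0.125 * x.
    assert float((a - b).abs().max()) == 0.0
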
@@ -2917,6 +2955,11 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
   GET_IR_NODE_FROM_SUBGRAPH(
       matmul_qk_out, matmul_qk_out, fused_multi_transformer_fuse_qkv_pattern);
 
+  GET_IR_NODE_FROM_SUBGRAPH(
+      scale_qk, scale_qk, fused_multi_transformer_fuse_qkv_pattern);
+  GET_IR_NODE_FROM_SUBGRAPH(
+      scale_qk_out, scale_qk_out, fused_multi_transformer_fuse_qkv_pattern);
+
   GET_IR_NODE_FROM_SUBGRAPH(
       eltadd_qk, eltadd_qk, fused_multi_transformer_fuse_qkv_pattern);
   GET_IR_NODE_FROM_SUBGRAPH(
@@ -3031,6 +3074,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                   assign_v,
                   matmul_qk,
                   matmul_qk_out,
+                  scale_qk,
+                  scale_qk_out,
                   eltadd_qk,
                   eltadd_qk_out,
                   softmax_qk,
@@ -3124,6 +3169,23 @@ MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::
       .IsNumGT(0)
       .End();
 
+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("scale")
+      .IsType<float>()  // copied to the new op, so unconstrained.
+      .End()
+      .AddAttr("bias")
+      .IsNumEQ(0.f)
+      .End()
+      .AddAttr("bias_after_scale")  // bias is 0, so unconstrained.
+      .IsType<bool>()
+      .End();
+
   AddOpCompat(OpCompat("matmul_v2"))
       .AddInput("X")  // the shape should be (B, S, N*H)
       .IsTensor()
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h
@@ -182,6 +182,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   // Q, K matmul
   PATTERN_DECL_NODE(matmul_qk);
   PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(scale_qk);
+  PATTERN_DECL_NODE(scale_qk_out);
   PATTERN_DECL_NODE(eltadd_qk);
   PATTERN_DECL_NODE(eltadd_qk_b);
   PATTERN_DECL_NODE(eltadd_qk_out);
@@ -282,6 +284,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   // Q, K matmul
   PATTERN_DECL_NODE(matmul_qk);
   PATTERN_DECL_NODE(matmul_qk_out);
+  PATTERN_DECL_NODE(scale_qk);
+  PATTERN_DECL_NODE(scale_qk_out);
   PATTERN_DECL_NODE(eltadd_qk);
   PATTERN_DECL_NODE(eltadd_qk_b);
   PATTERN_DECL_NODE(eltadd_qk_out);
paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc
@@ -239,12 +239,13 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
   // (eltadd_0)     reshape2        -> reshape_0
   // (reshape_0)    transpose2      -> transpose_0
   // (transpose_0)  split           -> split_q, split_k,
   //                                   split_v
   // (split_k)      concat          -> concat_k
   // (split_v)      concat          -> concat_v
   // (concat_k)     assign          -> assign_k
   // (concat_v)     assign          -> assign_v
-  // (split_q, split_k) matmul      -> matmul_qk
-  // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
+  // (split_q, split_k) matmul_v2   -> matmul_qk
+  // (matmul_qk)    scale           -> scale_qk
+  // (scale_qk, bias_qk) elementwise_add -> eltadd_qk
   // (eltadd_qk)    softmax         -> softmax_qk
   // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
   // (matmul_qkv)   transpose       -> transpose_qkv
@@ -298,10 +299,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
   layers.assign(concat_v);
 
   // MHA: QK matmul
-  auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
+  auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true);
+  auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
 
   auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
-  auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
+  auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
 
   // MHA: QKV matmul
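
The `0.125` passed to `layers.scale` above is the usual 1/sqrt(head_dim) attention factor; with the 12 heads visible in the `biasqk` shape, it is consistent with 64-dim heads (the tester's hidden size itself is not shown in this excerpt). A quick sanity check:

    head_dim = 64
    assert head_dim ** -0.5 == 0.125  # the scale value used by the tester
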
@@ -361,11 +363,11 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) {
 
   PADDLE_ENFORCE_EQ(
       num_nodes_before,
-      num_nodes_after + 50,
+      num_nodes_after + 52,
       platform::errors::InvalidArgument(
           "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
           "The node num in graph should be %d, but the result is %d",
-          num_nodes_before - 50,
+          num_nodes_before - 52,
          num_nodes_after));
   PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                     1,
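
The expected node delta grows from 50 to 52 here (and from 58 to 60 in the multi-devices test below) because the fused pattern now absorbs two additional graph nodes: the `scale` op and its output variable.
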
@@ -396,8 +398,9 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
   // (split_v)      concat          -> concat_v
   // (concat_k)     assign          -> assign_k
   // (concat_v)     assign          -> assign_v
-  // (split_q, split_k) matmul      -> matmul_qk
-  // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk
+  // (split_q, split_k) matmul_v2   -> matmul_qk
+  // (matmul_qk)    scale           -> scale_qk
+  // (scale_qk, bias_qk) elementwise_add -> eltadd_qk
   // (eltadd_qk)    softmax         -> softmax_qk
   // (softmax_qk, transpose_2) matmul_v2 -> matmul_qkv
   // (matmul_qkv)   transpose       -> transpose_qkv
@@ -455,10 +458,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
   layers.assign(concat_v);
 
   // MHA: QK matmul
-  auto* matmul_qk = layers.matmul(split_q, concat_k, nullptr, false, true);
+  auto* matmul_qk = layers.matmul_v2(split_q, concat_k, nullptr, false, true);
+  auto* scale_qk = layers.scale(matmul_qk, 0.125, 0, false);
 
   auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
-  auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
+  auto* elementwise_qk = layers.elementwise_add(scale_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
 
   // MHA: QKV matmul
@@ -523,11 +527,11 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) {
 
   PADDLE_ENFORCE_EQ(
       num_nodes_before,
-      num_nodes_after + 58,
+      num_nodes_after + 60,
       platform::errors::InvalidArgument(
           "After the fused_multi_transformer_decoder_fuse_qkv_pass, "
           "The node num in graph should be %d, but the result is %d",
-          num_nodes_before - 58,
+          num_nodes_before - 60,
          num_nodes_after));
   PADDLE_ENFORCE_EQ(num_fused_nodes_after,
                     1,