Added matmul_v2+transpose+reshape fuse pass (#36481)
* added base changes for matmul_v2+trans+resh fuse pass

* added full matmul_v2+transpose+reshape pass

* removed a file added by mistake

* added reviewers' suggestions

* Changed op types in the compatibility version check

* Deleted one statement
jakpiase authored Oct 21, 2021
1 parent 0ca2807 commit 856cb9c
Showing 15 changed files with 411 additions and 30 deletions.
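
For orientation: the new pass looks for a matmul_v2 whose output goes through transpose2 with axis {0, 2, 1, 3} and then reshape2 to rank 3, and folds the transpose and reshape into the matmul_v2 op as fused_transpose_Out / fused_reshape_Out attributes. Below is a minimal sketch of that target subgraph built with ProgramDesc, mirroring the unit test further down; the variable names, shapes, and the helper function are assumptions, not part of the patch.

// Hedged sketch (not part of the patch): the matmul_v2 -> transpose2 -> reshape2
// subgraph the new pass matches, built the same way the tester below builds it.
// Variable names, shapes and this helper are illustrative assumptions.
#include <string>
#include <vector>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"

paddle::framework::ProgramDesc BuildTargetSubgraph() {
  paddle::framework::ProgramDesc prog;
  auto *block = prog.MutableBlock(0);
  for (const auto &v : std::vector<std::string>(
           {"x", "y", "mm_out", "t_out", "t_xshape", "out", "out_xshape"})) {
    block->Var(v);
  }

  auto *matmul = block->AppendOp();
  matmul->SetType("matmul_v2");
  matmul->SetInput("X", {"x"});
  matmul->SetInput("Y", {"y"});
  matmul->SetOutput("Out", {"mm_out"});
  matmul->SetAttr("use_mkldnn", true);
  matmul->SetAttr("trans_x", false);
  matmul->SetAttr("trans_y", false);

  auto *transpose = block->AppendOp();
  transpose->SetType("transpose2");
  transpose->SetInput("X", {"mm_out"});
  transpose->SetOutput("Out", {"t_out"});
  transpose->SetOutput("XShape", {"t_xshape"});
  // The only transpose order the fuse accepts.
  transpose->SetAttr("axis", std::vector<int>({0, 2, 1, 3}));

  auto *reshape = block->AppendOp();
  reshape->SetType("reshape2");
  reshape->SetInput("X", {"t_out"});
  reshape->SetOutput("Out", {"out"});
  reshape->SetOutput("XShape", {"out_xshape"});
  // Rank-3 target shape, as in the tester; after the pass runs it ends up on
  // matmul_v2 as fused_reshape_Out (and the axis above as fused_transpose_Out).
  reshape->SetAttr("shape", std::vector<int>({4, 5, 6}));
  return prog;
}
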
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
@@ -192,7 +193,7 @@ endif()
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass)
8 changes: 5 additions & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2697,16 +2697,18 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
return matmul_out;
}

PDNode *patterns::MatmulTransposeReshapePattern::operator()() {
// shared function for matmul and matmul_v2
PDNode *patterns::MatmulTransposeReshapePattern::operator()(
const std::string &op_name) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);

auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsInput()
->assert_is_op_output("matmul", "Out")
->assert_is_op_output(op_name, "Out")
->assert_is_op_input("transpose2", "X");

auto transpose_out = pattern->NewNode(transpose_out_repr())
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1546,7 +1546,7 @@ struct MatmulTransposeReshapePattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_transpose_reshape") {}

PDNode* operator()();
PDNode* operator()(const std::string& op_name);

PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
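
With operator() now parameterized by op name (see the header change above), the existing matmul pass and the new matmul_v2 pass share one pattern definition. A minimal sketch of that call, assuming only the names visible in this diff; the wrapping function is hypothetical.

// Hedged sketch: driving the now-shared pattern for matmul_v2. Only the names
// visible in this diff are real; the wrapping function is illustrative.
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

void BuildMatmulV2TransposeReshapePattern(GraphPatternDetector *gpd) {
  patterns::MatmulTransposeReshapePattern mtrp(gpd->mutable_pattern(),
                                               "matmul_v2_transpose_reshape_fuse");
  // The new op_name argument picks the matmul flavour the pattern anchors on:
  // "matmul" for the existing pass, "matmul_v2" for the new one.
  mtrp("matmul_v2");
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
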
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
@@ -23,7 +23,9 @@ namespace framework {
namespace ir {

MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
AddOpCompat(OpCompat("matmul"))
op_name_ = "matmul";

AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
@@ -89,7 +91,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(),
name_scope_);

mtrp();
mtrp(op_name_);

int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
@@ -98,7 +100,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle matmul_transpose_reshape fuse";
VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
@@ -118,17 +120,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size;
return;
}
if (!supported_transpose_axis) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}";
return;
}
if (reshape_out_size != 3) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received "
<< reshape_out_size;
return;
@@ -152,7 +154,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_matmul_transpose_reshape_count
<< " MatmulTransposeReshape patterns";
<< " MatmulTransposeReshape patterns for " + op_name_ + " Op";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
}
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
@@ -31,6 +31,7 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
protected:
void ApplyImpl(Graph* graph) const override;
const std::string name_scope_{"matmul_transpose_reshape_fuse"};
std::string op_name_;
};
} // namespace ir
} // namespace framework
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"

namespace paddle {
namespace framework {
@@ -42,31 +42,37 @@ void SetOp(ProgramDesc *prog, const std::string &type,
op->SetAttr("transpose_X", true);
op->SetAttr("transpose_Y", true);
}
if (type == "matmul_v2") {
op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true);
op->SetAttr("trans_x", true);
op->SetAttr("trans_y", true);
}
}

ProgramDesc BuildProgramDesc() {
ProgramDesc BuildProgramDesc(const std::string &op_name) {
ProgramDesc prog;
for (auto &v : std::initializer_list<std::string>(
{"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) {
auto *var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
}

SetOp(&prog, "matmul", {"a1", "a2"}, {"b"});
SetOp(&prog, op_name, {"a1", "a2"}, {"b"});
SetOp(&prog, "transpose2", {"b"}, {"c", "cx"});
SetOp(&prog, "reshape2", {"c"}, {"d", "dx"});
SetOp(&prog, "fc", {"d"}, {"e"});

return prog;
}

void MainTest(const ProgramDesc &prog) {
void MainTest(const ProgramDesc &prog, const std::string &op_name) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

int original_nodes_num = graph->Nodes().size();

auto pass =
PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass");
PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass");
graph.reset(pass->Apply(graph.release()));

int current_nodes_num = graph->Nodes().size();
@@ -75,7 +81,7 @@ void MainTest(const ProgramDesc &prog) {
for (auto *node : graph->Nodes()) {
if (node->IsOp()) {
auto *op = node->Op();
if (op->Type() == "matmul") {
if (op->Type() == op_name) {
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"),
std::vector<int>({4, 5, 6}));
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"),
@@ -85,12 +91,18 @@
}
}

TEST(MatmulTransposeReshapeFusePass, matmul_inputs) {
auto prog = BuildProgramDesc();
MainTest(prog);
TEST(MatmulTransposeReshapeFusePass, matmul_fuse_pass) {
auto prog = BuildProgramDesc("matmul");
MainTest(prog, "matmul");
}

TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) {
auto prog = BuildProgramDesc("matmul_v2");
MainTest(prog, "matmul_v2");
}
} // namespace ir
} // namespace framework
} // namespace paddle

USE_PASS(matmul_transpose_reshape_fuse_pass);
USE_PASS(matmul_v2_transpose_reshape_fuse_pass);
paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc
@@ -0,0 +1,92 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() {
op_name_ = "matmul_v2";

AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();

AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();

AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass,
paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass);

REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h
@@ -0,0 +1,35 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {
class MatmulV2TransposeReshapeMKLDNNPass
: public MatmulTransposeReshapeMKLDNNPass {
public:
MatmulV2TransposeReshapeMKLDNNPass();
virtual ~MatmulV2TransposeReshapeMKLDNNPass() {}

protected:
const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
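
The new pass class derives from MatmulTransposeReshapeMKLDNNPass, so the graph rewrite in ApplyImpl is inherited; the subclass only sets op_name_ to "matmul_v2", supplies its own OpCompat rules, and registers under a new pass name. Once registered, it can be fetched and applied by name, just as the updated tester does. A minimal sketch, assuming the helper function is illustrative and error handling is omitted:

// Hedged sketch: fetching the freshly registered pass by name and applying it
// to a graph, as the updated tester does; the helper itself is illustrative.
#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

std::unique_ptr<paddle::framework::ir::Graph> ApplyMatmulV2Fuse(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "matmul_v2_transpose_reshape_fuse_pass");
  graph.reset(pass->Apply(graph.release()));
  return graph;
}
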
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
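
Appending the pass to CpuPassStrategy::EnableMKLDNN() means it runs automatically on the oneDNN CPU inference path. A hedged sketch of a user-side configuration that exercises it; the model directory is a placeholder and the include path may differ between the source tree and an installed inference library.

// Hedged sketch: a user-side configuration under which the new pass runs.
// The model directory is a placeholder; the include path is an assumption.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./model_dir");  // hypothetical inference model directory
  config.SwitchIrOptim(true);      // run the IR fuse passes
  config.EnableMKLDNN();           // CpuPassStrategy::EnableMKLDNN() now also
                                   // appends matmul_v2_transpose_reshape_fuse_pass
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}
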
8 changes: 8 additions & 0 deletions paddle/fluid/operators/compat/matmul_v2.pbtxt
@@ -39,4 +39,12 @@ extra {
name: "op_device"
type: STRING
}
attrs {
name: "fused_reshape_Out"
type: INTS
}
attrs {
name: "fused_transpose_Out"
type: INTS
}
}
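
Declaring fused_reshape_Out and fused_transpose_Out as extra attributes in matmul_v2.pbtxt lets a fused matmul_v2 op carry them without failing op-definition checks. A small sketch of reading them back, mirroring the assertions in the tester; the helper itself is illustrative.

// Hedged sketch: reading back the attributes the fusion writes onto matmul_v2,
// mirroring the assertions in the tester; the helper itself is illustrative.
#include <vector>

#include "paddle/fluid/framework/op_desc.h"

bool HasFusedTransposeReshape(paddle::framework::OpDesc *op) {
  auto axis = op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out");
  auto shape = op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out");
  return !axis.empty() && !shape.empty();
}
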