Added matmul_v2+transpose+reshape fuse pass #36481

Merged
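Summary: this PR extends the MKL-DNN matmul+transpose2+reshape2 fuse to the matmul_v2 operator. The shared graph pattern in graph_pattern_detector is parameterized on the matmul op name, the existing pass stores that name in a new protected op_name_ member, and a small derived pass (matmul_v2_transpose_reshape_fuse_pass) reuses the whole implementation while supplying matmul_v2-specific OpCompat checks. The new pass is registered, added to the CPU MKLDNN pass list, exercised by the shared tester, and the matmul_v2 compat definition gains the two fused output attributes.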
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
+pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
@@ -189,7 +190,7 @@ endif()
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
-cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
+cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass)
8 changes: 5 additions & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2697,16 +2697,18 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
return matmul_out;
}

-PDNode *patterns::MatmulTransposeReshapePattern::operator()() {
+// shared function for matmul and matmul_v2
+PDNode *patterns::MatmulTransposeReshapePattern::operator()(
+    const std::string &op_name) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
-  auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
+  auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);

auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsInput()
-                        ->assert_is_op_output("matmul", "Out")
+                        ->assert_is_op_output(op_name, "Out")
->assert_is_op_input("transpose2", "X");

auto transpose_out = pattern->NewNode(transpose_out_repr())
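With the op name now a parameter, a single pattern serves both fuse passes. A minimal usage sketch, mirroring what ApplyImpl does further down in this diff (illustrative only, not part of the patch):

GraphPatternDetector gpd;
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(), name_scope_);
mtrp("matmul_v2");  // previously mtrp() was hard-wired to match "matmul"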
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1546,7 +1546,7 @@ struct MatmulTransposeReshapePattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_transpose_reshape") {}

-  PDNode* operator()();
+  PDNode* operator()(const std::string& op_name);

PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc
@@ -23,7 +23,9 @@ namespace framework {
namespace ir {

MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
-  AddOpCompat(OpCompat("matmul"))
+  op_name_ = "matmul";
+
+  AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
@@ -89,7 +91,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(),
name_scope_);

-  mtrp();
+  mtrp(op_name_);

int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
@@ -98,7 +100,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
-    VLOG(4) << "handle matmul_transpose_reshape fuse";
+    VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
@@ -118,17 +120,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) {
-      VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
+      VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size;
return;
}
if (!supported_transpose_axis) {
-      VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
+      VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}";
return;
}
if (reshape_out_size != 3) {
-      VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
+      VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received "
<< reshape_out_size;
return;
@@ -152,7 +154,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_matmul_transpose_reshape_count
<< " MatmulTransposeReshape patterns";
<< " MatmulTransposeReshape patterns for " + op_name_ + " Op";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
}
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h
@@ -31,6 +31,7 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
protected:
void ApplyImpl(Graph* graph) const override;
const std::string name_scope_{"matmul_transpose_reshape_fuse"};
+  std::string op_name_;
};
} // namespace ir
} // namespace framework
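Because op_name_ is a protected member, a subclass can retarget the entire inherited ApplyImpl from its constructor. A condensed sketch of the mechanism (OpCompat checks elided; the full version is the new pass further below):

class MatmulV2TransposeReshapeMKLDNNPass
    : public MatmulTransposeReshapeMKLDNNPass {
 public:
  MatmulV2TransposeReshapeMKLDNNPass() {
    op_name_ = "matmul_v2";  // inherited pattern matching and logs now use matmul_v2
    // ...matmul_v2-specific AddOpCompat checks go here
  }
};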
paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"

namespace paddle {
namespace framework {
@@ -42,31 +42,37 @@ void SetOp(ProgramDesc *prog, const std::string &type,
op->SetAttr("transpose_X", true);
op->SetAttr("transpose_Y", true);
}
+  if (type == "matmul_v2") {
+    op->SetInput("Y", {inputs[1]});
+    op->SetAttr("use_mkldnn", true);
+    op->SetAttr("trans_x", true);
+    op->SetAttr("trans_y", true);
+  }
}

-ProgramDesc BuildProgramDesc() {
+ProgramDesc BuildProgramDesc(const std::string &op_name) {
ProgramDesc prog;
for (auto &v : std::initializer_list<std::string>(
{"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) {
auto *var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
}

-  SetOp(&prog, "matmul", {"a1", "a2"}, {"b"});
+  SetOp(&prog, op_name, {"a1", "a2"}, {"b"});
SetOp(&prog, "transpose2", {"b"}, {"c", "cx"});
SetOp(&prog, "reshape2", {"c"}, {"d", "dx"});
SetOp(&prog, "fc", {"d"}, {"e"});

return prog;
}

-void MainTest(const ProgramDesc &prog) {
+void MainTest(const ProgramDesc &prog, const std::string &op_name) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

int original_nodes_num = graph->Nodes().size();

auto pass =
-      PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass");
+      PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass");
graph.reset(pass->Apply(graph.release()));

int current_nodes_num = graph->Nodes().size();
@@ -75,7 +81,7 @@ void MainTest(const ProgramDesc &prog) {
for (auto *node : graph->Nodes()) {
if (node->IsOp()) {
auto *op = node->Op();
-      if (op->Type() == "matmul") {
+      if (op->Type() == op_name) {
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"),
std::vector<int>({4, 5, 6}));
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"),
@@ -85,12 +91,18 @@ void MainTest(const ProgramDesc &prog) {
}
}

-TEST(MatmulTransposeReshapeFusePass, matmul_inputs) {
-  auto prog = BuildProgramDesc();
-  MainTest(prog);
+TEST(MatmulTransposeReshapeFusePass, matmul_fuse_pass) {
+  auto prog = BuildProgramDesc("matmul");
+  MainTest(prog, "matmul");
}

+TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) {
+  auto prog = BuildProgramDesc("matmul_v2");
+  MainTest(prog, "matmul_v2");
+}
} // namespace ir
} // namespace framework
} // namespace paddle

USE_PASS(matmul_transpose_reshape_fuse_pass);
+USE_PASS(matmul_v2_transpose_reshape_fuse_pass);
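Note that one tester binary now exercises both passes: each USE_PASS above forces the corresponding pass registration to be linked in, which is why the cc_test dependency list in CMakeLists.txt gains matmul_v2_transpose_reshape_fuse_pass, and MainTest resolves which pass to run from the op name at runtime.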
paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.cc
@@ -0,0 +1,92 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() {
op_name_ = "matmul_v2";

AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();

AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();

AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass,
paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass);

REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h
@@ -0,0 +1,35 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {
class MatmulV2TransposeReshapeMKLDNNPass
: public MatmulTransposeReshapeMKLDNNPass {
public:
MatmulV2TransposeReshapeMKLDNNPass();
virtual ~MatmulV2TransposeReshapeMKLDNNPass() {}

protected:
const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
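For inference users no extra step is needed beyond enabling MKLDNN on the CPU path; a hedged sketch using the C++ analysis API (the model path is a placeholder):

#include "paddle/fluid/inference/api/paddle_inference_api.h"

paddle::AnalysisConfig config;
config.SetModel("/path/to/model");  // placeholder
config.EnableMKLDNN();  // applies CpuPassStrategy::EnableMKLDNN(), including the pass list above
auto predictor = paddle::CreatePaddlePredictor(config);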
8 changes: 8 additions & 0 deletions paddle/fluid/operators/compat/matmul_v2.pbtxt
@@ -39,4 +39,12 @@ extra {
name: "op_device"
type: STRING
}
+attrs {
+  name: "fused_reshape_Out"
+  type: INTS
+}
+attrs {
+  name: "fused_transpose_Out"
+  type: INTS
+}
}
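These two extra attributes are what the fuse writes onto the surviving matmul_v2 op — fused_reshape_Out carries the folded reshape2 shape and fused_transpose_Out the folded transpose2 axis, the same values the tester asserts above — so they are declared here, presumably to keep the fused op consistent with the op compat definition.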