Replace matmul with matmul_v2 during oneDNN fuse passes (PaddlePaddle#49108)

* replace matmul with matmul_v2 in fuse passes

* Remove fusion logic from matmul

* removing fusion methods

* add proper name

* adjust namespaces
Silv3S authored Jan 3, 2023
1 parent 958b9f0 commit 2c444df
Showing 17 changed files with 690 additions and 1,304 deletions.
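
All five oneDNN fuse passes changed below apply the same conversion before attaching their fused attributes: the matched matmul op is retyped to matmul_v2, its transpose_X/transpose_Y attributes are translated to trans_x/trans_y, and alpha is carried over only when it is not 1.0 (matmul computes Out = alpha * op(X) * op(Y), so a non-unit scale must survive the conversion, presumably to be applied by the oneDNN matmul_v2 kernel). The snippet below is a minimal sketch of that mapping; ToyOpDesc and ConvertMatmulToMatmulV2 are hypothetical stand-ins introduced only to make the example self-contained and are not part of the patch, while the SetType/SetAttr/GetAttrIfExists calls mirror the pattern in the diff.

// Minimal, self-contained sketch (C++17) of the matmul -> matmul_v2 attribute
// translation used by the fuse passes below. ToyOpDesc is a hypothetical
// stand-in for framework::OpDesc, present only so the snippet compiles alone.
#include <iostream>
#include <map>
#include <string>
#include <variant>

struct ToyOpDesc {
  std::string type;
  std::map<std::string, std::variant<bool, float>> attrs;

  void SetType(const std::string& t) { type = t; }
  void SetAttr(const std::string& name, std::variant<bool, float> v) {
    attrs[name] = v;
  }
  std::variant<bool, float> GetAttr(const std::string& name) const {
    return attrs.at(name);
  }
  // Returns the float attribute if present, otherwise a neutral 1.0f.
  float GetAttrIfExists(const std::string& name) const {
    auto it = attrs.find(name);
    return it == attrs.end() ? 1.0f : std::get<float>(it->second);
  }
};

// Mirrors the block each pass inserts when the matched op type is "matmul".
void ConvertMatmulToMatmulV2(ToyOpDesc* matmul_op) {
  matmul_op->SetType("matmul_v2");
  matmul_op->SetAttr("trans_x", matmul_op->GetAttr("transpose_X"));
  matmul_op->SetAttr("trans_y", matmul_op->GetAttr("transpose_Y"));
  float matmul_alpha = matmul_op->GetAttrIfExists("alpha");
  if (matmul_alpha != 1.0f) {
    // Keep alpha only when it actually scales the result.
    matmul_op->SetAttr("alpha", matmul_alpha);
  }
}

int main() {
  ToyOpDesc op;
  op.SetType("matmul");
  op.SetAttr("transpose_X", true);
  op.SetAttr("transpose_Y", false);
  op.SetAttr("alpha", 0.5f);
  ConvertMatmulToMatmulV2(&op);
  std::cout << op.type << " trans_x=" << std::get<bool>(op.GetAttr("trans_x"))
            << " alpha=" << op.GetAttrIfExists("alpha") << "\n";
  return 0;
}
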
@@ -77,6 +77,16 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
                      ? "gelu_tanh"
                      : "gelu_erf";
     }
+
+    if (matmul_type == "matmul") {
+      matmul_op->SetType("matmul_v2");
+      matmul_op->SetAttr("trans_x", matmul_op->GetAttr("transpose_X"));
+      matmul_op->SetAttr("trans_y", matmul_op->GetAttr("transpose_Y"));
+      auto matmul_alpha = matmul_op->GetAttrIfExists<float>("alpha");
+      if (matmul_alpha != 1.0f) {
+        matmul_op->SetAttr("alpha", matmul_alpha);
+      }
+    }
     matmul_op->SetAttr("fuse_activation", act_type);
     matmul_op->SetOutput("Out", {activation_out->Name()});
 
@@ -65,6 +65,16 @@ void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
       return;
     }
 
+    if (matmul_type == "matmul") {
+      matmul->Op()->SetType("matmul_v2");
+      matmul->Op()->SetAttr("trans_x", matmul->Op()->GetAttr("transpose_X"));
+      matmul->Op()->SetAttr("trans_y", matmul->Op()->GetAttr("transpose_Y"));
+      auto matmul_alpha = matmul->Op()->GetAttrIfExists<float>("alpha");
+      if (matmul_alpha != 1.0f) {
+        matmul->Op()->SetAttr("alpha", matmul_alpha);
+      }
+    }
+
     matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
     matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
 
@@ -84,6 +84,15 @@ void MatmulTransposeReshapeMKLDNNPass::Fuse(
     }
 
     OpDesc *matmul_desc = matmul_op->Op();
+    if (matmul_type == "matmul") {
+      matmul_desc->SetType("matmul_v2");
+      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
+      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
+      auto matmul_alpha = matmul_desc->GetAttrIfExists<float>("alpha");
+      if (matmul_alpha != 1.0f) {
+        matmul_desc->SetAttr("alpha", matmul_alpha);
+      }
+    }
     matmul_desc->SetOutput("Out", {reshape_out->Name()});
     matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
     matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);
@@ -85,6 +85,17 @@ void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
       scale = *(scale_tensor->data<float>());
     }
 
+    if (op_type == "matmul") {
+      operator_op->Op()->SetType("matmul_v2");
+      operator_op->Op()->SetAttr("trans_x",
+                                 operator_op->Op()->GetAttr("transpose_X"));
+      operator_op->Op()->SetAttr("trans_y",
+                                 operator_op->Op()->GetAttr("transpose_Y"));
+      auto matmul_alpha = operator_op->Op()->GetAttrIfExists<float>("alpha");
+      if (matmul_alpha != 1.0f) {
+        operator_op->Op()->SetAttr("alpha", matmul_alpha);
+      }
+    }
     operator_op->Op()->SetAttr("fused_output_scale", scale);
     operator_op->Op()->SetOutput("Out", {scale_out->Name()});
 
@@ -123,6 +123,15 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
       return;
     }
 
+    if (matmul_type == "matmul") {
+      matmul_desc->SetType("matmul_v2");
+      matmul_desc->SetAttr("trans_x", matmul_desc->GetAttr("transpose_X"));
+      matmul_desc->SetAttr("trans_y", matmul_desc->GetAttr("transpose_Y"));
+      auto matmul_alpha = matmul_desc->GetAttrIfExists<float>("alpha");
+      if (matmul_alpha != 1.0f) {
+        matmul_desc->SetAttr("alpha", matmul_alpha);
+      }
+    }
     matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
     matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
     matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
@@ -97,7 +97,7 @@ void TestMain(const std::string& op_name, bool with_xshapes) {
   int removed = 8;  // 2* reshape, reshape_out, transpose, transpose_out
   if (with_xshapes) removed += 2;  // transpose_xshape, reshape_xshape
   EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
-  auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op();
+  auto* matmul_op_desc = GetOpNodes(graph, "matmul_v2").at(0)->Op();
 
   auto check = [&matmul_op_desc](std::string a) {
     std::string shape_str = "fused_reshape_" + a;
32 changes: 2 additions & 30 deletions paddle/fluid/operators/matmul_op.cc
@@ -345,26 +345,6 @@ class MatMulGradKernel : public framework::OpKernel<T> {
   }
 };
 
-framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
-                               std::string input_name) {
-  auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
-  auto axis =
-      ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
-  auto dim = ctx.GetInputDim(input_name);
-
-  PADDLE_ENFORCE_GT(dim.size(),
-                    0,
-                    platform::errors::InvalidArgument(
-                        "The Input(%s) has not been initialized properly. The "
-                        "shape of Input(%s) = [%s].",
-                        dim));
-
-  if (!shape.empty() && !axis.empty()) {
-    dim = dim.reshape(shape).transpose(axis);
-  }
-  return dim;
-}
-
 template <typename DeviceContext, typename T>
 class MatMulDoubleGradKernel : public framework::OpKernel<T> {
  public:
@@ -579,8 +559,8 @@ class MatMulOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
     OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "matmul");
 
-    auto dim_x = GetDimForInput(*context, "X");
-    auto dim_y = GetDimForInput(*context, "Y");
+    auto dim_x = context->GetInputDim("X");
+    auto dim_y = context->GetInputDim("Y");
 
 #ifdef PADDLE_WITH_MKLDNN
     // (jczaja): For NHWC execution output shape needs
@@ -681,14 +661,6 @@ class MatMulOp : public framework::OperatorWithKernel {
 
     framework::DDim ddim_out = phi::make_ddim(dim_out);
 
-#ifdef PADDLE_WITH_MKLDNN
-    auto shape = context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
-    auto axis = context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
-
-    if (!shape.empty() && !axis.empty()) {
-      ddim_out = ddim_out.transpose(axis).reshape(shape);
-    }
-#endif
     context->SetOutputDim("Out", ddim_out);
     context->ShareLoD("X", "Out");
   }