Refactor Pass for fused_conv (#48848)
* refactor conv_activation_mkldnn_fuse_pass

* refactor conv_affine_channel_mkldnn_fuse_pass

* fix conv_activation_mkldnn_fuse_pass

* fix mkldnn unittest

* refactor int8_scale_calculation_mkldnn_pass and params_quantization_mkldnn_pass

* refactor conv_elementwise_add_mkldnn_fuse_pass

* fix quant

* refactor conv_bn_fuse_pass

* fix conv_bn_fuse_pass

* refactor depthwise_conv_bn_fuse_pass

* fix unittest

* fix conv_bn_fuse_pass

* remove redundant conv2d in params_quantization_mkldnn_pass

* fix params_quantization_mkldnn_pass_tester
zyfncg authored Dec 21, 2022 · 1 parent b881477 · commit 7f0eb2e
Showing 30 changed files with 352 additions and 67 deletions.
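The common thread across these diffs: whenever a oneDNN fuse pass is about to attach fused-only inputs (Bias, ResidualData) or fuse attributes to a plain convolution, it first retypes the op to fused_conv2d, and each pass registers an OpCompat schema for fused_conv2d so compatibility checks still pass. A minimal sketch of that shared step, with a hypothetical helper name (the real passes inline this, as the hunks below show):

#include "paddle/fluid/framework/op_desc.h"

// Hypothetical helper distilled from the hunks below; the passes inline this.
void RetypeConvForMkldnnFusion(paddle::framework::OpDesc* conv_op) {
  if (conv_op->Type() == "conv2d" || conv_op->Type() == "depthwise_conv2d") {
    // fused_conv2d accepts Bias/ResidualData inputs and fuse attributes
    // that plain conv2d's OpCompat schema does not.
    conv_op->SetType("fused_conv2d");
  }
}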
78 changes: 78 additions & 0 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -213,6 +213,43 @@ ConvBNFusePass::ConvBNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();

AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
@@ -361,6 +398,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
// with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add
if (fuse_option == FUSE_MKLDNN) {
if (conv->Op()->Type() == "conv2d" ||
conv->Op()->Type() == "depthwise_conv2d") {
conv->Op()->SetType("fused_conv2d");
}
auto input_names = conv->Op()->InputNames();
bool has_bias =
std::find(input_names.begin(), input_names.end(), "Bias") !=
@@ -818,6 +859,43 @@ DepthwiseConvBNFusePass::DepthwiseConvBNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}

} // namespace ir
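For reference, the arithmetic ConvBNFusePass folds into the (now fused_conv2d) convolution is the standard batch-norm fold; a minimal per-channel sketch, with names chosen here rather than taken from the pass:

#include <cmath>
#include <vector>

// Standard conv+BN fold for one output channel c (illustrative, not the
// pass's actual tensor code):
//   y = gamma * (conv(x, W) + b - mean) / sqrt(var + eps) + beta
// becomes conv(x, W') + b' with W' = W * s and b' = (b - mean) * s + beta,
// where s = gamma / sqrt(var + eps).
void FoldBnIntoConvChannel(std::vector<float>* filter_c, float* bias_c,
                           float gamma, float beta, float mean, float var,
                           float eps) {
  const float s = gamma / std::sqrt(var + eps);
  for (float& w : *filter_c) w *= s;
  *bias_c = (*bias_c - mean) * s + beta;
}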
22 changes: 12 additions & 10 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1826,20 +1826,20 @@ PDNode *patterns::ConvBias::operator()(
return eltwise_out_var;
}

-PDNode *patterns::Conv::operator()() {
-  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+PDNode *patterns::Conv::operator()(const std::string &conv_type) {
+  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type);

auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(conv_type, "Input");

auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
->assert_is_op_input("conv2d", "Filter");
->assert_is_op_input(conv_type, "Filter");

auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
->assert_is_op_output("conv2d", "Output");
->assert_is_op_output(conv_type, "Output");

conv_op->LinksFrom({input_var, filter_var}).LinksTo({output_var});
return output_var;
@@ -2658,10 +2658,12 @@ PDNode *patterns::ConvElementwiseadd::operator()(PDNode *conv_in) {
}

PDNode *patterns::ConvAffineChannel::operator()(
-    paddle::framework::ir::PDNode *conv_input, bool with_eltwise_add) {
+    paddle::framework::ir::PDNode *conv_input,
+    const std::string &conv_type,
+    bool with_eltwise_add) {
// Create Operators
-  conv_input->assert_is_op_input("conv2d", "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+  conv_input->assert_is_op_input(conv_type, "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);

PDNode *eltwise_op = nullptr;
if (with_eltwise_add) {
Expand All @@ -2676,11 +2678,11 @@ PDNode *patterns::ConvAffineChannel::operator()(
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
->assert_is_op_input(conv_type, "Filter");

auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
-    ->assert_is_only_output_of_op("conv2d");
+    ->assert_is_only_output_of_op(conv_type);

PDNode *eltwise_y_in_var = nullptr;
PDNode *eltwise_out_var = nullptr;
6 changes: 4 additions & 2 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
100755 → 100644
@@ -1044,7 +1044,7 @@ struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "convolution") {}

-  PDNode* operator()();
+  PDNode* operator()(const std::string& conv_type);

PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
@@ -1544,7 +1544,9 @@ struct ConvAffineChannel : public PatternBase {
ConvAffineChannel(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_affine_channel") {}

-  PDNode* operator()(PDNode* conv_input, bool with_eltwise_add);
+  PDNode* operator()(PDNode* conv_input,
+                     const std::string& conv_type,
+                     bool with_eltwise_add);

// declare operator node's name
PATTERN_DECL_NODE(conv);
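With the op type now a runtime parameter, one pattern definition serves both conv flavors. A sketch of how a caller might drive it, assuming the usual GraphPatternDetector boilerplate (the loop and names here are illustrative, not copied from a specific pass):

// Illustrative only; assumes Paddle's graph_pattern_detector.h is included.
void MatchConvForBothTypes(paddle::framework::ir::Graph* graph) {
  namespace ir = paddle::framework::ir;
  for (auto conv_type : {std::string("fused_conv2d"), std::string("conv2d")}) {
    ir::GraphPatternDetector gpd;
    ir::patterns::Conv conv_pattern(gpd.mutable_pattern(), "example_scope");
    conv_pattern(conv_type);  // binds conv_op/conv_input/... to this op type
    auto handler = [&](const ir::GraphPatternDetector::subgraph_t& subgraph,
                       ir::Graph* g) {
      // fuse or rewrite the matched convolution here
    };
    gpd(graph, handler);
  }
}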
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -26,7 +26,7 @@ using string::PrettyLogDetail;

void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
std::vector<std::string> conv_types = {"conv2d", "fused_conv2d"};
std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};

for (auto& act_type : act_types) {
FuseConvConcatAct(graph, act_type);
@@ -64,6 +64,10 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();

if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}

auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
@@ -91,8 +95,9 @@
AddStatis(found_conv_activation_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_conv_activation_count > 0) {
PrettyLogDetail("--- fused %d conv with %s activation",
PrettyLogDetail("--- fused %d %s with %s activation",
found_conv_activation_count,
conv_type,
act_type);
}
}
@@ -134,15 +139,20 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(

bool is_not_conv_mkldnn =
!(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
-  if (prev_op_nodes[0]->Op()->Type() != "conv2d" || is_not_conv_mkldnn) {
-    LOG(WARNING)
-        << "This fuse pass supports only conv2d (mkldnn) + activation.";
+  if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
+       prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
+      is_not_conv_mkldnn) {
+    LOG(WARNING) << "This fuse pass supports only conv2d(mkldnn) | "
+                    "fused_conv2d(mkldnn) + activation.";
return;
}
}

for (auto node : concat_inputs) {
OpDesc* conv_op = node->inputs[0]->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
OpDesc* act_op = activation_op->Op();

auto attr_map = phi::funcs::GetAttributeMap(act_type);
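The attribute hand-off shown above (via phi::funcs::GetAttributeMap) copies each activation attribute onto the fused conv under its fuse-prefixed name. A reduced sketch of that loop with a plain map stand-in; the example entry mentioned in the comment is illustrative, not the authoritative mapping:

#include <map>
#include <string>
#include "paddle/fluid/framework/op_desc.h"

// e.g. an activation's "beta" might map to a "fuse_beta" attribute on
// fused_conv2d (example entry only).
void CopyActAttrsOntoConv(paddle::framework::OpDesc* act_op,
                          paddle::framework::OpDesc* conv_op,
                          const std::map<std::string, std::string>& attr_map) {
  for (const auto& attrs : attr_map) {
    if (act_op->HasAttr(attrs.first)) {
      conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
    }
  }
}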
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -165,7 +165,8 @@ void MainTest(std::string activation) {
int conv_activation_count = 0;

for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+    if (node->IsOp() && (node->Op()->Type() == "conv2d" ||
+                         node->Op()->Type() == "fused_conv2d")) {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.cc
@@ -143,6 +143,44 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsStringIn({"NCHW", "AnyLayout"})
.End();

AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();

AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
.IsTensor()
@@ -177,6 +215,12 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
}

void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
FuseConvAffineChannel(graph, "conv2d");
FuseConvAffineChannel(graph, "fused_conv2d");
}

void ConvAffineChannelFusePass::FuseConvAffineChannel(
ir::Graph* graph, const std::string& conv_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
@@ -190,10 +234,10 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(conv_type, "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
-  conv_ac_pattern(conv_input, false /*with_eltwise_add*/);
+  conv_ac_pattern(conv_input, conv_type, false /*with_eltwise_add*/);

int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
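The affine_channel fold this pass performs mirrors the BN fold shown earlier: affine_channel computes out = conv_out * Scale + Bias per output channel, so it can be absorbed into the convolution's weights and bias. A per-channel sketch under that standard formulation (names are illustrative, not taken from the pass):

#include <vector>

// Per output channel c: (conv(x, W_c) + b_c) * scale_c + shift_c
// == conv(x, W_c * scale_c) + (b_c * scale_c + shift_c).
void FoldAffineChannelIntoConv(std::vector<float>* filter_c, float* bias_c,
                               float scale_c, float shift_c) {
  for (float& w : *filter_c) w *= scale_c;
  *bias_c = *bias_c * scale_c + shift_c;
}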
paddle/fluid/framework/ir/mkldnn/conv_affine_channel_mkldnn_fuse_pass.h
@@ -36,6 +36,8 @@ class ConvAffineChannelFusePass : public FusePassBase {

protected:
void ApplyImpl(ir::Graph*) const override;
void FuseConvAffineChannel(ir::Graph* graph,
const std::string& conv_type) const;
const std::string name_scope_{"conv_affine_channel_mkldnn_fuse"};
};

(Diffs for the remaining changed files are not shown.)