From 958c7bb9a38b3d9e2c1f410c2e7c03ad11147ff3 Mon Sep 17 00:00:00 2001 From: csy0225 Date: Mon, 23 Oct 2023 14:32:00 +0800 Subject: [PATCH] fix code style --- paddle/fluid/framework/ir/CMakeLists.txt | 7 ++- .../ir/delete_quant_dequant_linear_op_pass.cc | 2 +- .../delete_weight_dequant_linear_op_pass.cc | 2 +- ...d_pass_utils.h => quantize_pass_helper.cc} | 17 +++---- .../fluid/framework/ir/quantize_pass_helper.h | 49 +++++++++++++++++++ .../framework/ir/xpu/conv2d_xpu_fuse_pass.cc | 20 ++++---- .../framework/ir/xpu/fc_xpu_fuse_pass.cc | 17 +------ .../framework/ir/xpu/link_xpu_op_max_pass.cc | 2 - .../framework/ir/xpu/xpu_quantize_op_pass.cc | 2 +- .../ir/xpu/xpu_quantize_squash_pass.cc | 2 - 10 files changed, 77 insertions(+), 43 deletions(-) rename paddle/fluid/framework/ir/{quantize_related_pass_utils.h => quantize_pass_helper.cc} (82%) create mode 100644 paddle/fluid/framework/ir/quantize_pass_helper.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index bd9d40bde47026..47e7a9948856c3 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -59,6 +59,10 @@ cc_library( placement_pass_base SRCS placement_pass_base.cc DEPS pass) +cc_library( + quantize_pass_helper + SRCS quantize_pass_helper.cc + DEPS pass graph graph_helper) cc_library( coalesce_grad_tensor_pass @@ -241,7 +245,8 @@ if(WITH_XPU) xpu_graph_pattern_detector SRCS xpu/xpu_graph_pattern_detector.cc DEPS graph_pattern_detector) - set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils xpu_graph_pattern_detector) + set(XPU_PASS_DEPS quantize_pass_helper xpu_quant_utils xpu_pass_utils + xpu_graph_pattern_detector) pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index c36fd3d4ff2698..025cd0c2b7dddf 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -19,7 +19,7 @@ #include #include #include -#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h" +#include "paddle/fluid/framework/ir/quantize_pass_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc index 59f25483c110b1..e30ae85f71c027 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h" +#include "paddle/fluid/framework/ir/quantize_pass_helper.h" #include "glog/logging.h" diff --git a/paddle/fluid/framework/ir/quantize_related_pass_utils.h b/paddle/fluid/framework/ir/quantize_pass_helper.cc similarity index 82% rename from paddle/fluid/framework/ir/quantize_related_pass_utils.h rename to paddle/fluid/framework/ir/quantize_pass_helper.cc index 86f2160d31bc4d..9cbaf993be0709 100644 --- a/paddle/fluid/framework/ir/quantize_related_pass_utils.h +++ b/paddle/fluid/framework/ir/quantize_pass_helper.cc @@ -16,14 +16,13 @@ #include -#include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_pass_helper.h" namespace paddle { namespace framework { namespace ir { -static inline void SaveQuantInfoInTheGraph( +void SaveQuantInfoInTheGraph( ir::Graph* graph, const std::string& flag, const std::string& key_suffix, @@ -37,10 +36,8 @@ static inline void SaveQuantInfoInTheGraph( } } -static inline std::unordered_map> -GetQuantInfoFromTheGraph(ir::Graph* graph, - const std::string& flag, - const std::string& key_suffix) { +std::unordered_map> GetQuantInfoFromTheGraph( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix) { std::unordered_map> info_map; const std::string suffix = "_" + key_suffix + "_" + flag; if (graph->Has(flag)) { @@ -57,7 +54,7 @@ GetQuantInfoFromTheGraph(ir::Graph* graph, return info_map; } -static inline bool AreScalesPresentForNodes( +bool AreScalesPresentForNodes( std::unordered_map>* var_quant_scales, std::initializer_list nodes) { bool present = true; @@ -69,13 +66,13 @@ static inline bool AreScalesPresentForNodes( return present; } -static inline float GetScaleValueForNode( +float GetScaleValueForNode( std::unordered_map>* var_quant_scales, Node* node) { return var_quant_scales->at(node->Name())[0]; } -static inline std::vector GetScaleVecValueForNode( +std::vector GetScaleVecValueForNode( std::unordered_map>* var_quant_scales, Node* node) { return var_quant_scales->at(node->Name()); diff --git a/paddle/fluid/framework/ir/quantize_pass_helper.h b/paddle/fluid/framework/ir/quantize_pass_helper.h new file mode 100644 index 00000000000000..4876cd35a1cf3a --- /dev/null +++ b/paddle/fluid/framework/ir/quantize_pass_helper.h @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SaveQuantInfoInTheGraph( + ir::Graph* graph, + const std::string& flag, + const std::string& key_suffix, + const std::unordered_map>& info_map); + +std::unordered_map> GetQuantInfoFromTheGraph( + ir::Graph* graph, const std::string& flag, const std::string& key_suffix); + +bool AreScalesPresentForNodes( + std::unordered_map>* var_quant_scales, + std::initializer_list nodes); + +float GetScaleValueForNode( + std::unordered_map>* var_quant_scales, + Node* node); + +std::vector GetScaleVecValueForNode( + std::unordered_map>* var_quant_scales, + Node* node); + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc index 6fb76c5dbe4575..09037a0fd60eb9 100644 --- a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" -#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h" +#include "paddle/fluid/framework/ir/quantize_pass_helper.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -515,7 +515,6 @@ void Conv2dXPUFusePass::CreateFusionWeightsAndBias( } // Create fusion_bias_node auto filter_dims = filter_t->dims(); - bool has_bias = with_bn || with_conv_bias; Node* fusion_bias_node = nullptr; if (with_conv_bias) { auto* ew_bias_add_y = @@ -677,7 +676,7 @@ void Conv2dXPUFusePass::CreateFusionWeightsAndBias( filter_ptr[i] *= scale_val_; } } else { - for (int i = 0; i < weight_scale.size(); i++) { + for (size_t i = 0; i < weight_scale.size(); i++) { weight_scale[i] *= scale_val_; } } @@ -877,12 +876,12 @@ void Conv2dXPUFusePass::CreateFusionOutputs( platform::errors::InvalidArgument("conv node ptr can not be null")); // output && output max std::string conv2d_xpu_out_name; - Node* conv2d_out_op_node = nullptr; Node* conv2d_out_var_node = nullptr; auto* ew_branch_add = GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add"); auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); + auto* scale = GetNodeFromNodesMap(nodes_map, "scale", "scale"); auto* ew_bias_add = GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add"); if (!act_type.empty()) { @@ -898,7 +897,6 @@ void Conv2dXPUFusePass::CreateFusionOutputs( act != nullptr, true, platform::errors::InvalidArgument("act node ptr can not be null")); - conv2d_out_op_node = act; } else if (ew_branch_add) { auto* ew_branch_add_out = GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add_out"); @@ -912,7 +910,14 @@ void Conv2dXPUFusePass::CreateFusionOutputs( true, platform::errors::InvalidArgument( "ew_branch_add node ptr can not be null")); - conv2d_out_op_node = ew_branch_add; + } else if (scale) { + auto* scale_out = GetNodeFromNodesMap(nodes_map, "scale", "scale_out"); + PADDLE_ENFORCE_EQ(scale_out != nullptr, + true, + platform::errors::InvalidArgument( + "scale_out node ptr can not be null")); + conv2d_xpu_out_name = scale_out->Name(); + conv2d_out_var_node = scale_out; } else if (bn) { auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out"); PADDLE_ENFORCE_EQ( @@ -921,7 +926,6 @@ void Conv2dXPUFusePass::CreateFusionOutputs( platform::errors::InvalidArgument("bn_out node ptr can not be null")); conv2d_xpu_out_name = bn_out->Name(); conv2d_out_var_node = bn_out; - conv2d_out_op_node = bn; } else if (ew_bias_add) { auto* ew_bias_add_out = GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out"); @@ -931,7 +935,6 @@ void Conv2dXPUFusePass::CreateFusionOutputs( "ew_bias_add_out node ptr can not be null")); conv2d_xpu_out_name = ew_bias_add_out->Name(); conv2d_out_var_node = ew_bias_add_out; - conv2d_out_op_node = ew_bias_add; } else { auto* conv_out = GetNodeFromNodesMap(nodes_map, "conv", "conv_out"); PADDLE_ENFORCE_EQ( @@ -945,7 +948,6 @@ void Conv2dXPUFusePass::CreateFusionOutputs( conv != nullptr, true, platform::errors::InvalidArgument("conv node ptr can not be null")); - conv2d_out_op_node = conv; } (*fusion_nodes_map)["out"] = conv2d_out_var_node; diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc index 4e8a6d9d99c732..93ad3aec0d16aa 100644 --- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" -#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h" +#include "paddle/fluid/framework/ir/quantize_pass_helper.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -381,7 +381,6 @@ void FcXPUFusePass::CreateFusionWeightsAndBias( } // Create fusion_bias_node auto filter_dims = filter_t->dims(); - bool has_bias = with_bn || with_bias; Node* fusion_bias_node = nullptr; if (with_bias) { auto* ew_bias_add_bias = @@ -390,8 +389,6 @@ void FcXPUFusePass::CreateFusionWeightsAndBias( true, platform::errors::InvalidArgument( "ew_bias_add_bias node ptr can not be null")); - auto* ew_bias_add_bias_t = scope->FindVar(ew_bias_add_bias->Name()) - ->GetMutable(); PrepareBias(graph, scope, block, ew_bias_add_bias, &fusion_bias_node); } @@ -424,13 +421,6 @@ void FcXPUFusePass::CreateFusionWeightsAndBias( auto bn_bias_t = scope->Var(bn_bias->Name())->GetMutable(); - PADDLE_ENFORCE_EQ( - filter_dims[0], - bn_bias_t->dims()[0], - platform::errors::InvalidArgument("the shape[%d] of bn bias tensor " - "must equal out_channel[%d] of conv", - bn_bias_t->dims()[0], - filter_dims[0])); auto bn_scale_t = scope->Var(bn_scale->Name())->GetMutable(); auto bn_mean_t = @@ -582,7 +572,6 @@ void FcXPUFusePass::CreateFusionOutputs( platform::errors::InvalidArgument("mul node ptr can not be null")); // output && output max std::string fc_xpu_out_name; - Node* fc_out_op_node = nullptr; Node* fc_out_var_node = nullptr; auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn"); @@ -597,7 +586,6 @@ void FcXPUFusePass::CreateFusionOutputs( platform::errors::InvalidArgument("act_out node ptr can not be null")); fc_xpu_out_name = act_out->Name(); fc_out_var_node = act_out; - fc_out_op_node = act; } else if (bn) { auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out"); PADDLE_ENFORCE_EQ( @@ -606,7 +594,6 @@ void FcXPUFusePass::CreateFusionOutputs( platform::errors::InvalidArgument("bn_out node ptr can not be null")); fc_xpu_out_name = bn_out->Name(); fc_out_var_node = bn_out; - fc_out_op_node = bn; } else if (ew_bias_add) { auto* ew_bias_add_out = GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out"); @@ -616,7 +603,6 @@ void FcXPUFusePass::CreateFusionOutputs( "ew_bias_add_out node ptr can not be null")); fc_xpu_out_name = ew_bias_add_out->Name(); fc_out_var_node = ew_bias_add_out; - fc_out_op_node = ew_bias_add; } else { auto* mul_out = GetNodeFromNodesMap(nodes_map, "mul", "mul_out"); PADDLE_ENFORCE_EQ( @@ -625,7 +611,6 @@ void FcXPUFusePass::CreateFusionOutputs( platform::errors::InvalidArgument("mul_out node ptr can not be null")); fc_xpu_out_name = mul_out->Name(); fc_out_var_node = mul_out; - fc_out_op_node = mul; } (*fusion_nodes_map)["out"] = fc_out_var_node; diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc index bf03a2598726c6..9b552bac36f2d1 100644 --- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc +++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc @@ -172,7 +172,6 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const { GraphPatternDetector gpd; patterns::LinkConv2dPattern pattern( gpd.mutable_pattern(), name_scope_, with_branch); - auto* scope = param_scope(); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { @@ -232,7 +231,6 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const { GraphPatternDetector gpd; patterns::LinkFcPattern pattern(gpd.mutable_pattern(), name_scope_); int found_subgraph_count = 0; - auto* scope = param_scope(); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle LinkFcMax"; diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc index ebeb75763320e6..865464dcd7dca7 100644 --- a/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc @@ -18,7 +18,7 @@ #include #include -#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h" +#include "paddle/fluid/framework/ir/quantize_pass_helper.h" #include "paddle/utils/string/pretty_log.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc index 0e6fd9797c1774..6161293bf7fb76 100644 --- a/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc @@ -66,8 +66,6 @@ void XPUQuantizeSquashPass::DequantQuantSquash( int found_dequant_quant_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - LOG(INFO) << "squash dequantize-quantize ops pair"; - GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern); GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern); GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern);