
Commit

fix code style
csy0225 committed Oct 23, 2023
1 parent 748bb9d commit 958c7bb
Showing 10 changed files with 77 additions and 43 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -59,6 +59,10 @@ cc_library(
placement_pass_base
SRCS placement_pass_base.cc
DEPS pass)
cc_library(
quantize_pass_helper
SRCS quantize_pass_helper.cc
DEPS pass graph graph_helper)

cc_library(
coalesce_grad_tensor_pass
@@ -241,7 +245,8 @@ if(WITH_XPU)
xpu_graph_pattern_detector
SRCS xpu/xpu_graph_pattern_detector.cc
DEPS graph_pattern_detector)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils xpu_graph_pattern_detector)
set(XPU_PASS_DEPS quantize_pass_helper xpu_quant_utils xpu_pass_utils
xpu_graph_pattern_detector)
pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
@@ -19,7 +19,7 @@
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h"
#include "paddle/fluid/framework/ir/quantize_pass_helper.h"

namespace paddle {
namespace framework {
@@ -14,7 +14,7 @@ limitations under the License. */

#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h"
#include "paddle/fluid/framework/ir/quantize_pass_helper.h"

#include "glog/logging.h"

@@ -16,14 +16,13 @@

#include <string>

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/quantize_pass_helper.h"

namespace paddle {
namespace framework {
namespace ir {

static inline void SaveQuantInfoInTheGraph(
void SaveQuantInfoInTheGraph(
ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
@@ -37,10 +36,8 @@ static inline void SaveQuantInfoInTheGraph(
}
}

static inline std::unordered_map<std::string, std::vector<float>>
GetQuantInfoFromTheGraph(ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix) {
std::unordered_map<std::string, std::vector<float>> GetQuantInfoFromTheGraph(
ir::Graph* graph, const std::string& flag, const std::string& key_suffix) {
std::unordered_map<std::string, std::vector<float>> info_map;
const std::string suffix = "_" + key_suffix + "_" + flag;
if (graph->Has(flag)) {
@@ -57,7 +54,7 @@ GetQuantInfoFromTheGraph(ir::Graph* graph,
return info_map;
}

static inline bool AreScalesPresentForNodes(
bool AreScalesPresentForNodes(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
std::initializer_list<Node*> nodes) {
bool present = true;
@@ -69,13 +66,13 @@ static inline bool AreScalesPresentForNodes(
return present;
}

static inline float GetScaleValueForNode(
float GetScaleValueForNode(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
Node* node) {
return var_quant_scales->at(node->Name())[0];
}

static inline std::vector<float> GetScaleVecValueForNode(
std::vector<float> GetScaleVecValueForNode(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
Node* node) {
return var_quant_scales->at(node->Name());
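The hunks above drop `static inline` so each helper now has a single external-linkage definition in quantize_pass_helper.cc, shared by every pass that links the new library instead of re-instantiating a private inline copy per translation unit. A minimal sketch of the before/after linkage pattern — the names below are illustrative, not the real helpers:

```cpp
// Old style, as in quantize_related_pass_utils.h: a static inline function
// gives every .cc that includes the header its own copy.
static inline float HalfScaleOld(float raw) { return raw * 0.5f; }

// New style: the header declares the function once...
float HalfScale(float raw);
// ...and quantize_pass_helper.cc provides the one shared definition.
float HalfScale(float raw) { return raw * 0.5f; }

int main() {
  // Both behave identically; only the linkage differs.
  return (HalfScale(2.0f) == 1.0f && HalfScaleOld(2.0f) == 1.0f) ? 0 : 1;
}
```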
49 changes: 49 additions & 0 deletions paddle/fluid/framework/ir/quantize_pass_helper.h
@@ -0,0 +1,49 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

void SaveQuantInfoInTheGraph(
ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
const std::unordered_map<std::string, std::vector<float>>& info_map);

std::unordered_map<std::string, std::vector<float>> GetQuantInfoFromTheGraph(
ir::Graph* graph, const std::string& flag, const std::string& key_suffix);

bool AreScalesPresentForNodes(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
std::initializer_list<Node*> nodes);

float GetScaleValueForNode(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
Node* node);

std::vector<float> GetScaleVecValueForNode(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
Node* node);

} // namespace ir
} // namespace framework
} // namespace paddle
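With these declarations, any IR pass can persist quantization scales on the graph and read them back later. A minimal sketch of that round trip, assuming a `Node*` obtained from a matched pattern; the flag `"has_quant_info"`, the suffix `"var_quant_scales"`, and the function name are illustrative assumptions, not values fixed by this commit:

```cpp
#include <string>
#include <unordered_map>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/framework/ir/quantize_pass_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Illustrative only; not part of the commit.
void RoundTripScales(ir::Graph* graph, Node* in_node) {
  // Writer side: store a per-variable scale vector on the graph.
  std::unordered_map<std::string, std::vector<float>> scales;
  scales[in_node->Name()] = {0.5f};
  SaveQuantInfoInTheGraph(graph, "has_quant_info", "var_quant_scales", scales);

  // Reader side (possibly a later pass): fetch the map back, then query nodes.
  auto var_quant_scales =
      GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales");
  if (AreScalesPresentForNodes(&var_quant_scales, {in_node})) {
    float scale = GetScaleValueForNode(&var_quant_scales, in_node);
    std::vector<float> all = GetScaleVecValueForNode(&var_quant_scales, in_node);
    VLOG(4) << in_node->Name() << ": scale=" << scale << " of " << all.size();
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
```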
20 changes: 11 additions & 9 deletions paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
@@ -20,7 +20,7 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h"
#include "paddle/fluid/framework/ir/quantize_pass_helper.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
@@ -515,7 +515,6 @@ void Conv2dXPUFusePass::CreateFusionWeightsAndBias(
}
// Create fusion_bias_node
auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_conv_bias;
Node* fusion_bias_node = nullptr;
if (with_conv_bias) {
auto* ew_bias_add_y =
@@ -677,7 +676,7 @@
filter_ptr[i] *= scale_val_;
}
} else {
for (int i = 0; i < weight_scale.size(); i++) {
for (size_t i = 0; i < weight_scale.size(); i++) {
weight_scale[i] *= scale_val_;
}
}
@@ -877,12 +876,12 @@ void Conv2dXPUFusePass::CreateFusionOutputs(
platform::errors::InvalidArgument("conv node ptr can not be null"));
// output && output max
std::string conv2d_xpu_out_name;
Node* conv2d_out_op_node = nullptr;
Node* conv2d_out_var_node = nullptr;

auto* ew_branch_add =
GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add");
auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn");
auto* scale = GetNodeFromNodesMap(nodes_map, "scale", "scale");
auto* ew_bias_add =
GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add");
if (!act_type.empty()) {
@@ -898,7 +897,6 @@
act != nullptr,
true,
platform::errors::InvalidArgument("act node ptr can not be null"));
conv2d_out_op_node = act;
} else if (ew_branch_add) {
auto* ew_branch_add_out =
GetNodeFromNodesMap(nodes_map, "ew_branch_add", "ew_branch_add_out");
@@ -912,7 +910,14 @@
true,
platform::errors::InvalidArgument(
"ew_branch_add node ptr can not be null"));
conv2d_out_op_node = ew_branch_add;
} else if (scale) {
auto* scale_out = GetNodeFromNodesMap(nodes_map, "scale", "scale_out");
PADDLE_ENFORCE_EQ(scale_out != nullptr,
true,
platform::errors::InvalidArgument(
"scale_out node ptr can not be null"));
conv2d_xpu_out_name = scale_out->Name();
conv2d_out_var_node = scale_out;
} else if (bn) {
auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out");
PADDLE_ENFORCE_EQ(
@@ -921,7 +926,6 @@
platform::errors::InvalidArgument("bn_out node ptr can not be null"));
conv2d_xpu_out_name = bn_out->Name();
conv2d_out_var_node = bn_out;
conv2d_out_op_node = bn;
} else if (ew_bias_add) {
auto* ew_bias_add_out =
GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out");
Expand All @@ -931,7 +935,6 @@ void Conv2dXPUFusePass::CreateFusionOutputs(
"ew_bias_add_out node ptr can not be null"));
conv2d_xpu_out_name = ew_bias_add_out->Name();
conv2d_out_var_node = ew_bias_add_out;
conv2d_out_op_node = ew_bias_add;
} else {
auto* conv_out = GetNodeFromNodesMap(nodes_map, "conv", "conv_out");
PADDLE_ENFORCE_EQ(
@@ -945,7 +948,6 @@
conv != nullptr,
true,
platform::errors::InvalidArgument("conv node ptr can not be null"));
conv2d_out_op_node = conv;
}
(*fusion_nodes_map)["out"] = conv2d_out_var_node;

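In `CreateFusionOutputs`, the hunks above delete the write-only `conv2d_out_op_node` bookkeeping and add a branch so a trailing scale op's output can become the fused op's output. A compact sketch of the resulting selection order — the helper name is hypothetical, not code from the commit:

```cpp
// The fused conv2d_xpu output is the output of the last op in the matched
// chain, checked in this priority order; the scale branch is the new one.
const char* PickConv2dXpuOutKey(bool has_act, bool has_ew_branch_add,
                                bool has_scale, bool has_bn,
                                bool has_ew_bias_add) {
  if (has_act) return "act_out";
  if (has_ew_branch_add) return "ew_branch_add_out";
  if (has_scale) return "scale_out";  // branch added in this commit
  if (has_bn) return "bn_out";
  if (has_ew_bias_add) return "ew_bias_add_out";
  return "conv_out";
}
```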
17 changes: 1 addition & 16 deletions paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -19,7 +19,7 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h"
#include "paddle/fluid/framework/ir/quantize_pass_helper.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
@@ -381,7 +381,6 @@ void FcXPUFusePass::CreateFusionWeightsAndBias(
}
// Create fusion_bias_node
auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_bias;
Node* fusion_bias_node = nullptr;
if (with_bias) {
auto* ew_bias_add_bias =
@@ -390,8 +389,6 @@
true,
platform::errors::InvalidArgument(
"ew_bias_add_bias node ptr can not be null"));
auto* ew_bias_add_bias_t = scope->FindVar(ew_bias_add_bias->Name())
->GetMutable<phi::DenseTensor>();
PrepareBias(graph, scope, block, ew_bias_add_bias, &fusion_bias_node);
}

@@ -424,13 +421,6 @@

auto bn_bias_t =
scope->Var(bn_bias->Name())->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
filter_dims[0],
bn_bias_t->dims()[0],
platform::errors::InvalidArgument("the shape[%d] of bn bias tensor "
"must equal out_channel[%d] of conv",
bn_bias_t->dims()[0],
filter_dims[0]));
auto bn_scale_t =
scope->Var(bn_scale->Name())->GetMutable<phi::DenseTensor>();
auto bn_mean_t =
@@ -582,7 +572,6 @@ void FcXPUFusePass::CreateFusionOutputs(
platform::errors::InvalidArgument("mul node ptr can not be null"));
// output && output max
std::string fc_xpu_out_name;
Node* fc_out_op_node = nullptr;
Node* fc_out_var_node = nullptr;

auto* bn = GetNodeFromNodesMap(nodes_map, "bn", "bn");
@@ -597,7 +586,6 @@
platform::errors::InvalidArgument("act_out node ptr can not be null"));
fc_xpu_out_name = act_out->Name();
fc_out_var_node = act_out;
fc_out_op_node = act;
} else if (bn) {
auto* bn_out = GetNodeFromNodesMap(nodes_map, "bn", "bn_out");
PADDLE_ENFORCE_EQ(
@@ -606,7 +594,6 @@
platform::errors::InvalidArgument("bn_out node ptr can not be null"));
fc_xpu_out_name = bn_out->Name();
fc_out_var_node = bn_out;
fc_out_op_node = bn;
} else if (ew_bias_add) {
auto* ew_bias_add_out =
GetNodeFromNodesMap(nodes_map, "ew_bias_add", "ew_bias_add_out");
@@ -616,7 +603,6 @@
"ew_bias_add_out node ptr can not be null"));
fc_xpu_out_name = ew_bias_add_out->Name();
fc_out_var_node = ew_bias_add_out;
fc_out_op_node = ew_bias_add;
} else {
auto* mul_out = GetNodeFromNodesMap(nodes_map, "mul", "mul_out");
PADDLE_ENFORCE_EQ(
@@ -625,7 +611,6 @@
platform::errors::InvalidArgument("mul_out node ptr can not be null"));
fc_xpu_out_name = mul_out->Name();
fc_out_var_node = mul_out;
fc_out_op_node = mul;
}
(*fusion_nodes_map)["out"] = fc_out_var_node;

2 changes: 0 additions & 2 deletions paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
@@ -172,7 +172,6 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const {
GraphPatternDetector gpd;
patterns::LinkConv2dPattern pattern(
gpd.mutable_pattern(), name_scope_, with_branch);
auto* scope = param_scope();
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
@@ -232,7 +231,6 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::LinkFcPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto* scope = param_scope();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle LinkFcMax";
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/xpu/xpu_quantize_op_pass.cc
@@ -18,7 +18,7 @@
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/quantize_related_pass_utils.h"
#include "paddle/fluid/framework/ir/quantize_pass_helper.h"
#include "paddle/utils/string/pretty_log.h"

namespace paddle {
2 changes: 0 additions & 2 deletions paddle/fluid/framework/ir/xpu/xpu_quantize_squash_pass.cc
@@ -66,8 +66,6 @@ void XPUQuantizeSquashPass::DequantQuantSquash(
int found_dequant_quant_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
LOG(INFO) << "squash dequantize-quantize ops pair";

GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern);
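The only change here drops an unconditional per-match `LOG(INFO)` from the dequant→quant squash handler. A hedged sketch of the quieter pattern other handlers in this commit use (e.g. `VLOG(4) << "handle LinkFcMax"` in link_xpu_op_max_pass.cc), should tracing ever be wanted again:

```cpp
#include "glog/logging.h"

// Illustrative only: VLOG-gated tracing stays silent unless the process runs
// with glog verbosity >= 4 (e.g. GLOG_v=4), unlike the removed LOG(INFO).
void TraceSquash() { VLOG(4) << "squash dequantize-quantize ops pair"; }
```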
