[XPU] XPU inference support int8 (#57258)
csy0225 authored Oct 26, 2023
1 parent a4a59c4 commit 57a14e2
Showing 38 changed files with 3,278 additions and 502 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
@@ -55,13 +55,13 @@ function(pass_library TARGET DEST)
${TARGET}
SRCS ${pass_library_DIR}/${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
quantize_helper ${pass_library_DEPS})
else()
cc_library(
${TARGET}
SRCS ${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
quantize_helper ${pass_library_DEPS})
endif()

# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
12 changes: 11 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -59,6 +59,10 @@ cc_library(
placement_pass_base
SRCS placement_pass_base.cc
DEPS pass)
cc_library(
quantize_helper
SRCS quantize_helper.cc
DEPS graph graph_helper)

cc_library(
coalesce_grad_tensor_pass
@@ -237,7 +241,11 @@ if(WITH_XPU)
xpu_pass_utils
SRCS xpu/pass_utils.cc
DEPS pass xpu_quant_utils)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
cc_library(
xpu_graph_pattern_detector
SRCS xpu/xpu_graph_pattern_detector.cc
DEPS graph_pattern_detector)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils xpu_graph_pattern_detector)
pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
@@ -247,6 +255,8 @@ if(WITH_XPU)
# pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_bias_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(xpu_quantize_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(xpu_quantize_squash_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
DEPS ${XPU_PASS_DEPS})
pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu
32 changes: 30 additions & 2 deletions paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -523,7 +523,6 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}

// when op_1 only supports cpu kernel. if op_2's input var is op_1's
// output var, then op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" &&
@@ -687,6 +686,16 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "quantize_linear" ||
GetOpOriginalType(op_desc->Type()) == "dequantize_linear") {
auto vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("ZeroPoint");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
}

@@ -733,6 +742,11 @@ bool AutoMixedPrecisionPass::OutputVarsNotConvert(
}

void AutoMixedPrecisionPass::SetVarPrecision() const {
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the scope "
"should not be null."));
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
@@ -749,7 +763,21 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_in_var_node)) continue;
if (InputVarsNotConvert(op_node, in_var_name)) continue;

// Check that the real tensor's dtype matches the variable's declared
// dtype: Paddle-Slim weights use an fp32 variable to save an int8 tensor.
if (real_in_var_node->Var()->Persistable()) {
auto* tensor = scope->Var(real_in_var_node->Name())
->GetMutable<phi::DenseTensor>();
if (framework::TransToProtoVarType(tensor->type()) !=
real_in_var_node->Var()->GetDataType()) {
VLOG(3) << "[AutoMixedPrecisionPass] variable "
<< real_in_var_node->Name() << "'s proto data type "
<< real_in_var_node->Var()->GetDataType()
<< " is different from real dense tensor "
<< framework::TransToProtoVarType(tensor->type());
continue;
}
}
if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
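The persistable-branch guard added above exists because Paddle-Slim may keep int8 data in a variable whose declared proto dtype is still FP32; rewriting such a variable's declared dtype to the low-precision type would corrupt the weights. A minimal sketch of the same check in isolation (the helper name and surrounding includes are illustrative, not part of the commit):

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/core/dense_tensor.h"

// Hypothetical helper: returns true only when the tensor actually stored in
// the scope has the same dtype that the VarDesc declares, so a pass may
// safely rewrite the declared dtype.
static bool DeclaredDtypeMatchesTensor(paddle::framework::Scope* scope,
                                       paddle::framework::ir::Node* var_node) {
  auto* tensor =
      scope->Var(var_node->Name())->GetMutable<phi::DenseTensor>();
  return paddle::framework::TransToProtoVarType(tensor->dtype()) ==
         var_node->Var()->GetDataType();
}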
12 changes: 11 additions & 1 deletion paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
@@ -19,6 +19,7 @@
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/quantize_helper.h"

namespace paddle {
namespace framework {
@@ -94,6 +95,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
scope,
platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantLinearOpPass should not be null."));
std::unordered_map<std::string, std::vector<float>> var_quant_scales{};

// Create pattern
patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
@@ -141,7 +144,11 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op();
any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);

if (!var_quant_scales.count(quantize_linear_op_x->Var()->Name())) {
var_quant_scales.insert(
std::make_pair(quantize_linear_op_x->Var()->Name(),
std::vector<float>({input_scale})));
}
// link x to any_op2
any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
@@ -161,6 +168,9 @@
};
gpd(graph, handler);
AddStatis(found_count);

SaveQuantInfoInTheGraph(
graph, "has_quant_info", "var_quant_scales", var_quant_scales);
}

} // namespace ir
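Besides folding away each quantize_linear/dequantize_linear pair, the pass now records the activation scale twice: as an `Input_scale_<var_name>` attribute on every consumer op, and in the graph-level `var_quant_scales` store written by SaveQuantInfoInTheGraph. A hedged sketch of reading the per-op attribute back in a later pass (the accessor name is illustrative):

#include <string>
#include "paddle/fluid/framework/op_desc.h"

// Hypothetical accessor: fetch the activation scale that
// DeleteQuantDequantLinearOpPass attached for input variable `x_name`.
// GetAttrIfExists returns 0.0f when the attribute is absent.
static float GetInputScale(const paddle::framework::OpDesc& op_desc,
                           const std::string& x_name) {
  return op_desc.GetAttrIfExists<float>("Input_scale_" + x_name);
}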
46 changes: 37 additions & 9 deletions paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/quantize_helper.h"

#include "glog/logging.h"

@@ -35,18 +36,20 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));

VLOG(3) << "Handle delete weight dequant linear op pass ...";
auto& scope = graph->Get<framework::Scope>(kParamScopeAttr);
bool is_int8 = false;

std::unordered_set<const Node*> nodes2rm;
std::unordered_map<std::string, std::vector<float>> var_quant_scales{};

for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->Type() == "dequantize_linear") {
Node *weight_var_node = nullptr, *calcu_op_node = nullptr,
*while_op_node = nullptr;
Node* weight_var_node = nullptr;
Node* calcu_op_node = nullptr;
Node* while_op_node = nullptr;
Node *dequantized_weight_var_node = nullptr, *scale_var_node = nullptr;
// 1. Judge whether this dequantize_linear op dequantizes a weight, and
// find weight_var_node/scale_var_node
@@ -59,9 +62,12 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
scale_var_node = input_node;
}
} else {
return;
break;
}
}
if (weight_var_node == nullptr || scale_var_node == nullptr) {
continue;
}
// 2. Find next_op_node
// For while op: delete its input which is related to the dequantized weight
// For calculation op: set the weight scale as its attribute
@@ -106,7 +112,7 @@
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The dtype of quantization scale must be FP32/16, "
"The dtype of quantization scale must be FP32/FP16, "
"but received %d, which is not supported.",
weight_scale_tensor->dtype()));
}
@@ -125,14 +131,34 @@

calcu_op_desc->SetAttr("weight_scale", weight_scale[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Delete Weight Dequant Linear Op Pass is not supported "
"for "
"per-channel quantization"));
std::vector<int64_t> weights_shape =
weight_var_node->Var()->GetShape();
quant_axis = quant_axis >= 0
? quant_axis
: quant_axis + weights_shape.size();
PADDLE_ENFORCE_EQ(
weight_scale_nums,
weights_shape[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis != -1, it means using per_channel "
"dequantization. In this situation, the number of "
"weight_scale should be equal with "
"weights_shape[quant_axis=%d]=%ld , but received "
"%d.",
quant_axis,
weights_shape[quant_axis],
weight_scale_nums));
calcu_op_desc->SetAttr("weight_scale", weight_scale);
}
if (!var_quant_scales.count(weight_var_node->Var()->Name())) {
var_quant_scales.insert(std::make_pair(
weight_var_node->Var()->Name(), weight_scale));
}

calcu_op_desc->RenameInput(
dequantized_weight_var_node->Var()->Name(),
weight_var_node->Var()->Name());
calcu_op_desc->Flush();
}
}
}
@@ -153,6 +179,8 @@
}

GraphSafeRemoveNodes(graph, nodes2rm);
SaveQuantInfoInTheGraph(
graph, "has_quant_info", "var_quant_scales", var_quant_scales);
graph->Set("enable_int8", new bool(is_int8));
}
} // namespace ir
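The per-channel branch added above replaces the old PADDLE_THROW: a negative quant_axis is first normalized by the weight's rank, and the number of scales must then equal the dimension that axis selects. For a conv weight of shape [64, 3, 3, 3], quant_axis = 0 (or equivalently -4) requires exactly 64 scales. A self-contained sketch of that arithmetic:

#include <cstdint>
#include <cstdio>
#include <vector>

// Re-statement of the per-channel validation: normalize a negative
// quant_axis from the back, then count scales along that dimension.
static bool ScalesMatchQuantAxis(const std::vector<int64_t>& weights_shape,
                                 int quant_axis,
                                 size_t weight_scale_nums) {
  if (quant_axis < 0) quant_axis += static_cast<int>(weights_shape.size());
  return weight_scale_nums == static_cast<size_t>(weights_shape[quant_axis]);
}

int main() {
  const std::vector<int64_t> conv_w = {64, 3, 3, 3};
  std::printf("%d\n", ScalesMatchQuantAxis(conv_w, 0, 64));   // 1: axis 0 has 64 channels
  std::printf("%d\n", ScalesMatchQuantAxis(conv_w, -4, 64));  // 1: -4 normalizes to 0
  std::printf("%d\n", ScalesMatchQuantAxis(conv_w, 1, 64));   // 0: axis 1 needs 3 scales
  return 0;
}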
14 changes: 14 additions & 0 deletions paddle/fluid/framework/ir/graph.h
@@ -412,6 +412,20 @@ class Graph {
return sub_graphs_.size();
}

std::vector<std::string> AttrNames() const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->AttrNames();
}
}
std::vector<std::string> res;
res.reserve(attrs_.size());
for (auto &attr : attrs_) {
res.push_back(attr.first);
}
return res;
}

private:
// TODO(levi): delete this interface after when we can convert all
// blocks into sub_graphs.
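The new AttrNames() accessor lets a pass discover graph attributes it did not set itself; the quantize helper below relies on it to find every `<var>_var_quant_scales_has_quant_info` entry by suffix. A hedged usage sketch (graph is assumed to be a valid ir::Graph*):

// Enumerate graph attributes and keep the ones carrying quant scales.
// The suffix follows SaveQuantInfoInTheGraph in quantize_helper.cc below.
const std::string suffix = "_var_quant_scales_has_quant_info";
for (const std::string& name : graph->AttrNames()) {
  if (name.size() > suffix.size() &&
      name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0) {
    const auto& scales = graph->Get<std::vector<float>>(name);
    VLOG(4) << name << " holds " << scales.size() << " scale(s)";
  }
}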
79 changes: 79 additions & 0 deletions paddle/fluid/framework/ir/quantize_helper.cc
@@ -0,0 +1,79 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/quantize_helper.h"

namespace paddle {
namespace framework {
namespace ir {

void SaveQuantInfoInTheGraph(
ir::Graph* graph,
const std::string& flag,
const std::string& key_suffix,
const std::unordered_map<std::string, std::vector<float>>& info_map) {
const std::string suffix = "_" + key_suffix + "_" + flag;
if (!graph->Has(flag)) {
graph->Set(flag, new bool(true));
}
for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
graph->Set(iter->first + suffix, new std::vector<float>(iter->second));
}
}

std::unordered_map<std::string, std::vector<float>> GetQuantInfoFromTheGraph(
ir::Graph* graph, const std::string& flag, const std::string& key_suffix) {
std::unordered_map<std::string, std::vector<float>> info_map;
const std::string suffix = "_" + key_suffix + "_" + flag;
if (graph->Has(flag)) {
std::vector<std::string> attr_names = graph->AttrNames();
for (auto fake_name : attr_names) {
size_t pos = fake_name.find(suffix);
if (pos != std::string::npos) {
std::string name = fake_name.substr(0, pos);
auto scales_vector = graph->Get<std::vector<float>>(fake_name);
info_map.insert(std::make_pair(name, scales_vector));
}
}
}
return info_map;
}

bool AreScalesPresentForNodes(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
std::initializer_list<Node*> nodes) {
bool present = true;
for (auto node : nodes) {
if (var_quant_scales->count(node->Name()) == 0) {
present = false;
}
}
return present;
}

float GetScaleValueForNode(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
Node* node) {
return var_quant_scales->at(node->Name())[0];
}

std::vector<float> GetScaleVecValueForNode(
std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
Node* node) {
return var_quant_scales->at(node->Name());
}

} // namespace ir
} // namespace framework
} // namespace paddle
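Taken together, these helpers flatten the scale map into individual graph attributes: one boolean flag attribute (here "has_quant_info") plus one std::vector<float> attribute per variable, named <var>_<key_suffix>_<flag>. A hedged round-trip sketch inside the paddle::framework::ir namespace (the variable name and scale values are illustrative):

// Save: writes bool attr "has_quant_info" and, e.g., an attribute named
// "conv2d_0.w_0_var_quant_scales_has_quant_info" holding the scales.
std::unordered_map<std::string, std::vector<float>> scales;
scales["conv2d_0.w_0"] = {0.12f, 0.08f};  // hypothetical per-channel scales
SaveQuantInfoInTheGraph(graph, "has_quant_info", "var_quant_scales", scales);

// Load: scans AttrNames() for the "_var_quant_scales_has_quant_info"
// suffix and rebuilds the same map.
auto loaded =
    GetQuantInfoFromTheGraph(graph, "has_quant_info", "var_quant_scales");
// loaded["conv2d_0.w_0"] == std::vector<float>({0.12f, 0.08f})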