Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU] XPU inference supports int8 #57258

Merged
merged 16 commits into from
Oct 26, 2023
4 changes: 2 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ function(pass_library TARGET DEST)
${TARGET}
SRCS ${pass_library_DIR}/${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
quantize_helper ${pass_library_DEPS})
else()
cc_library(
${TARGET}
SRCS ${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
quantize_helper ${pass_library_DEPS})
endif()

# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ cc_library(
placement_pass_base
SRCS placement_pass_base.cc
DEPS pass)
cc_library(
quantize_helper
SRCS quantize_helper.cc
DEPS graph graph_helper)

cc_library(
coalesce_grad_tensor_pass
Expand Down Expand Up @@ -237,7 +241,11 @@ if(WITH_XPU)
xpu_pass_utils
SRCS xpu/pass_utils.cc
DEPS pass xpu_quant_utils)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
cc_library(
xpu_graph_pattern_detector
SRCS xpu/xpu_graph_pattern_detector.cc
DEPS graph_pattern_detector)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils xpu_graph_pattern_detector)
pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
Expand All @@ -247,6 +255,8 @@ if(WITH_XPU)
# pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_bias_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(xpu_quantize_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(xpu_quantize_squash_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
DEPS ${XPU_PASS_DEPS})
pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu
Expand Down
32 changes: 30 additions & 2 deletions paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,6 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}

// when op_1 only support cpu kernel. if op_2's intput var is op_1's
// output var, then op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" &&
Expand Down Expand Up @@ -687,6 +686,16 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "quantize_linear" ||
GetOpOriginalType(op_desc->Type()) == "dequantize_linear") {
auto vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("ZeroPoint");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
}

Expand Down Expand Up @@ -733,6 +742,11 @@ bool AutoMixedPrecisionPass::OutputVarsNotConvert(
}

void AutoMixedPrecisionPass::SetVarPrecision() const {
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the scope "
"should not be null."));
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
Expand All @@ -749,7 +763,21 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_in_var_node)) continue;
if (InputVarsNotConvert(op_node, in_var_name)) continue;

// Judge the real tensor is same to variable, Paddle-Slim weight use
// fp32 variable to save int8 tensor.
if (real_in_var_node->Var()->Persistable()) {
auto* tensor = scope->Var(real_in_var_node->Name())
->GetMutable<phi::DenseTensor>();
if (framework::TransToProtoVarType(tensor->type()) !=
real_in_var_node->Var()->GetDataType()) {
VLOG(3) << "[AutoMixedPrecisionPass] variable "
<< real_in_var_node->Name() << "'s proto data type "
<< real_in_var_node->Var()->GetDataType()
<< " is different from real dense tensor "
<< framework::TransToProtoVarType(tensor->type());
continue;
}
}
if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/quantize_helper.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -94,6 +95,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
scope,
platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantLinearOpPass should not be null."));
std::unordered_map<std::string, std::vector<float>> var_quant_scales{};

// Create pattern
patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
Expand Down Expand Up @@ -141,7 +144,11 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op();
any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);

if (!var_quant_scales.count(quantize_linear_op_x->Var()->Name())) {
var_quant_scales.insert(
std::make_pair(quantize_linear_op_x->Var()->Name(),
std::vector<float>({input_scale})));
}
// link x to any_op2
any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
Expand All @@ -161,6 +168,9 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
};
gpd(graph, handler);
AddStatis(found_count);

SaveQuantInfoInTheGraph(
graph, "has_quant_info", "var_quant_scales", var_quant_scales);
}

} // namespace ir
Expand Down
46 changes: 37 additions & 9 deletions paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/quantize_helper.h"

#include "glog/logging.h"

Expand All @@ -35,18 +36,20 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
true,
platform::errors::InvalidArgument(
"Graph must have kParamScopeAttr attribute."));

VLOG(3) << "Handle delete weight dequant linear op pass ...";
auto& scope = graph->Get<framework::Scope>(kParamScopeAttr);
bool is_int8 = false;

std::unordered_set<const Node*> nodes2rm;
std::unordered_map<std::string, std::vector<float>> var_quant_scales{};

for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->Type() == "dequantize_linear") {
Node *weight_var_node = nullptr, *calcu_op_node = nullptr,
*while_op_node = nullptr;
Node* weight_var_node = nullptr;
Node* calcu_op_node = nullptr;
Node* while_op_node = nullptr;
Node *dequantized_weight_var_node = nullptr, *scale_var_node = nullptr;
// 1. Judge whether for dequant weight and find
// weight_var_node/scale_var_node
Expand All @@ -59,9 +62,12 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
scale_var_node = input_node;
}
} else {
return;
break;
}
}
if (weight_var_node == nullptr || scale_var_node == nullptr) {
continue;
}
// 2. Find next_op_node
// For while op: delete its input which is related to dequantized
// For calculation op: set weight scale as their attributes
Expand Down Expand Up @@ -106,7 +112,7 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The dtype of quantization scale must be FP32/16, "
"The dtype of quantization scale must be FP32/FP16, "
"but received %d, which is not supported.",
weight_scale_tensor->dtype()));
}
Expand All @@ -125,14 +131,34 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {

calcu_op_desc->SetAttr("weight_scale", weight_scale[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Delete Weight Dequant Linear Op Pass is not supported "
"for "
"per-channel quantization"));
std::vector<int64_t> weights_shape =
weight_var_node->Var()->GetShape();
quant_axis = quant_axis >= 0
? quant_axis
: quant_axis + weights_shape.size();
PADDLE_ENFORCE_EQ(
weight_scale_nums,
weights_shape[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis != -1, it means using per_channel "
"dequantization. In this situation, the number of "
"weight_scale should be equal with "
"weights_shape[quant_axis=%d]=%ld , but received "
"%d.",
quant_axis,
weights_shape[quant_axis],
weight_scale_nums));
calcu_op_desc->SetAttr("weight_scale", weight_scale);
}
if (!var_quant_scales.count(weight_var_node->Var()->Name())) {
var_quant_scales.insert(std::make_pair(
weight_var_node->Var()->Name(), weight_scale));
}

calcu_op_desc->RenameInput(
dequantized_weight_var_node->Var()->Name(),
weight_var_node->Var()->Name());
calcu_op_desc->Flush();
}
}
}
Expand All @@ -153,6 +179,8 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
}

GraphSafeRemoveNodes(graph, nodes2rm);
SaveQuantInfoInTheGraph(
graph, "has_quant_info", "var_quant_scales", var_quant_scales);
graph->Set("enable_int8", new bool(is_int8));
}
} // namespace ir
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/ir/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,20 @@ class Graph {
return sub_graphs_.size();
}

std::vector<std::string> AttrNames() const {
if (FLAGS_convert_all_blocks) {
if (IsMainGraph()) {
return GetSubGraph(0)->AttrNames();
}
}
std::vector<std::string> res;
res.reserve(attrs_.size());
for (auto &attr : attrs_) {
res.push_back(attr.first);
}
return res;
}

private:
// TODO(levi): delete this interface after when we can convert all
// blocks into sub_graphs.
Expand Down
79 changes: 79 additions & 0 deletions paddle/fluid/framework/ir/quantize_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/quantize_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Stores per-variable quantization scales as graph attributes. Each entry of
// info_map is saved under the key "<var_name>_<key_suffix>_<flag>", and a
// boolean attribute named `flag` marks that quant info is present at all.
//
// NOTE: Graph::Set enforces that an attribute is registered only once, so a
// duplicate registration aborts the process. Several passes (e.g. both the
// quant/dequant and the weight-dequant delete passes) may save scales for the
// same variable on the same graph, so each per-key Set is guarded; the first
// recorded scale wins.
void SaveQuantInfoInTheGraph(
    ir::Graph* graph,
    const std::string& flag,
    const std::string& key_suffix,
    const std::unordered_map<std::string, std::vector<float>>& info_map) {
  const std::string suffix = "_" + key_suffix + "_" + flag;
  if (!graph->Has(flag)) {
    graph->Set(flag, new bool(true));
  }
  for (const auto& item : info_map) {
    const std::string attr_name = item.first + suffix;
    if (!graph->Has(attr_name)) {
      graph->Set(attr_name, new std::vector<float>(item.second));
    }
  }
}

// Reads back the per-variable quantization scales previously stored by
// SaveQuantInfoInTheGraph. Returns an empty map when the marker attribute
// `flag` is absent.
//
// Fix: the original used fake_name.find(suffix), which matches the suffix
// anywhere in the attribute name — a variable whose name itself contains the
// suffix string would be truncated at the wrong position. We now accept only
// names that truly END with the suffix. The loop also iterates by const
// reference instead of copying each attribute name.
std::unordered_map<std::string, std::vector<float>> GetQuantInfoFromTheGraph(
    ir::Graph* graph, const std::string& flag, const std::string& key_suffix) {
  std::unordered_map<std::string, std::vector<float>> info_map;
  const std::string suffix = "_" + key_suffix + "_" + flag;
  if (!graph->Has(flag)) {
    return info_map;
  }
  for (const std::string& attr_name : graph->AttrNames()) {
    if (attr_name.size() <= suffix.size()) continue;
    const size_t pos = attr_name.size() - suffix.size();
    if (attr_name.compare(pos, suffix.size(), suffix) != 0) continue;
    const std::string var_name = attr_name.substr(0, pos);
    info_map.emplace(var_name, graph->Get<std::vector<float>>(attr_name));
  }
  return info_map;
}

// Returns true iff every node in `nodes` has a recorded quantization scale in
// `var_quant_scales` (keyed by the node's variable name).
//
// Fix: return false as soon as one node is missing a scale — the original
// kept a flag and scanned the remaining nodes after the answer was known.
bool AreScalesPresentForNodes(
    std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
    std::initializer_list<Node*> nodes) {
  for (auto node : nodes) {
    if (var_quant_scales->count(node->Name()) == 0) {
      return false;
    }
  }
  return true;
}

// Returns the per-tensor scale (first recorded value) for this node's
// variable. Throws std::out_of_range if no scale was recorded.
float GetScaleValueForNode(
    std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
    Node* node) {
  const auto& scales = var_quant_scales->at(node->Name());
  return scales.front();
}

// Returns a copy of all scale values recorded for this node's variable
// (per-channel scales keep one value per channel). Throws std::out_of_range
// if no scale was recorded.
std::vector<float> GetScaleVecValueForNode(
    std::unordered_map<std::string, std::vector<float>>* var_quant_scales,
    Node* node) {
  const auto& var_name = node->Name();
  return var_quant_scales->at(var_name);
}

} // namespace ir
} // namespace framework
} // namespace paddle
Loading