support fc_xpu int8
csy0225 committed Sep 20, 2023
1 parent 26e125d commit 9483e72
Showing 6 changed files with 80 additions and 37 deletions.
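In short: the auto-trans quantize-op precision pass now treats fc_xpu the same way as conv2d_xpu when deciding whether a fused op may emit INT8 output; the FC fuse pass reads weight scales from a fixed "weight_scale" attribute; the max-linking pass skips int8 fc_xpu ops and re-wires "x_max" instead of blindly overwriting it; Transpose2D gains an int8 branch; the reshape2+matmul fuse pass forwards quantization attributes onto the lowered matmul; and the fc_xpu kernel adds an fp16-output path for int8 weights.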
@@ -44,7 +44,7 @@ class AutoTransQuantizeOpPrecisionPass : public FusePassBase {

const std::string name_scope_{"auto_trans_quantize_op_precision_pass"};
const std::unordered_set<std::string> support_fusion_quant_op_type_{
"conv2d_xpu"};
"conv2d_xpu", "fc_xpu"};
};

static inline Node* GetOpOutVarNodeByArgsName(ir::Graph* graph,
@@ -72,35 +72,33 @@ void AutoTransQuantizeOpPrecisionPass::FirstRound(ir::Graph* graph) const {
bool enable_int8 = op_node->Op()->GetAttrIfExists<bool>("enable_int8");
int out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
if (enable_int8) {
if (op_type == "conv2d_xpu") {
auto* out_var_node =
GetOpOutVarNodeByArgsName(subgraph, op_node, "out");
PADDLE_ENFORCE_NOT_NULL(
out_var_node,
platform::errors::InvalidArgument(
"out_var_node in graph cannot be nullptr."));
bool is_int8_out = true;
for (auto* next_op_node : out_var_node->outputs) {
auto next_op_type = next_op_node->Op()->Type();
bool is_next_op_support_int8 =
next_op_node->Op()->GetAttrIfExists<bool>("enable_int8") &&
((support_fusion_quant_op_type_.find(next_op_type) !=
support_fusion_quant_op_type_.end()));
if (!is_next_op_support_int8) {
is_int8_out = false;
break;
}
}
if (is_int8_out) {
op_node->Op()->SetAttr(
"out_dtype",
static_cast<int>(proto::VarType::Type::VarType_Type_INT8));
out_var_node->Var()->SetDataType(
proto::VarType::Type::VarType_Type_INT8);
VLOG(1) << "The out var node " << out_var_node->Name()
<< " is INT8";
auto* out_var_node =
GetOpOutVarNodeByArgsName(subgraph, op_node, "out");
PADDLE_ENFORCE_NOT_NULL(
out_var_node,
platform::errors::InvalidArgument(
"out_var_node in graph cannot be nullptr."));
bool is_int8_out = true;
for (auto* next_op_node : out_var_node->outputs) {
auto next_op_type = next_op_node->Op()->Type();
bool is_next_op_support_int8 =
next_op_node->Op()->GetAttrIfExists<bool>("enable_int8") &&
((support_fusion_quant_op_type_.find(next_op_type) !=
support_fusion_quant_op_type_.end()));
if (!is_next_op_support_int8) {
is_int8_out = false;
break;
}
}
if (is_int8_out) {
op_node->Op()->SetAttr(
"out_dtype",
static_cast<int>(proto::VarType::Type::VarType_Type_INT8));
out_var_node->Var()->SetDataType(
proto::VarType::Type::VarType_Type_INT8);
VLOG(1) << "The out var node " << out_var_node->Name()
<< " is INT8";
}
}
}
}
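The effect of this hunk is to drop the conv2d_xpu-only guard: every op in support_fusion_quant_op_type_ now runs the same downstream check. A minimal distillation of that rule, with a hypothetical helper name (Node and OpDesc as in paddle::framework::ir):

// Hypothetical sketch, not the literal pass code: a fused op may write
// INT8 output only when every consumer of its "out" var is itself an
// int8-enabled op from the supported fusion set.
bool AllConsumersTakeInt8(ir::Node* out_var_node,
                          const std::unordered_set<std::string>& quant_ops) {
  for (ir::Node* next_op_node : out_var_node->outputs) {
    bool next_supports_int8 =
        next_op_node->Op()->GetAttrIfExists<bool>("enable_int8") &&
        quant_ops.count(next_op_node->Op()->Type()) > 0;
    if (!next_supports_int8) return false;  // one fp consumer blocks int8 out
  }
  return true;
}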
12 changes: 7 additions & 5 deletions paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -367,8 +367,7 @@ void FcXPUFusePass::CreateFusionWeightsAndBias(
}
// Get Weight scale in int8 scene
std::vector<float> weight_scale =
mul->Op()->GetAttrIfExists<std::vector<float>>("Input_scale_" +
mul_w->Name());
mul->Op()->GetAttrIfExists<std::vector<float>>("weight_scale");
// Create fusion_bias_node
auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_bias;
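Note that the "weight_scale" attribute read here is the one the reshape2_matmul_xpu_fuse_pass change below forwards from matmul_v2 onto the lowered matmul, replacing the earlier per-tensor "Input_scale_" + mul_w->Name() key.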
@@ -754,8 +753,11 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
GET_IR_NODE(act);
GET_IR_NODE(act_out);
std::map<std::string, std::map<std::string, Node*>> nodes_map;
nodes_map.insert(
{"mul", {{"mul_x", mul_x}, {"mul_w", mul_w}, {"mul_out", mul_out}}});
nodes_map.insert({"mul",
{{"mul", mul},
{"mul_x", mul_x},
{"mul_w", mul_w},
{"mul_out", mul_out}}});
nodes_map.insert({"ew_bias_add",
{{"ew_bias_add", add},
{"ew_bias_add_bias", bias},
@@ -785,7 +787,7 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,

bool enable_int8 = mul->Op()->GetAttrIfExists<bool>("enable_int8");
std::string op_precision_str = enable_int8 ? "int8" : "fp32";
VLOG(4) << "FC fusion fuse pass is running on " << op_precision_str
VLOG(1) << "FC fusion fuse pass is running on " << op_precision_str
<< " precision!";
auto* block = mul->Op()->Block();
CreateFusionWeightsAndBias(graph,
15 changes: 13 additions & 2 deletions paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
@@ -106,7 +106,13 @@ struct LinkFcPattern : public PatternBase {

LinkFcPattern::LinkFcPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op("fc_xpu");
auto* fusion_op = pattern->NewNode(fusion_op_repr())
->assert_is_op("fc_xpu")
->assert_more([&](Node* node) {
bool enable_int8 =
node->Op()->GetAttrIfExists<bool>("enable_int8");
return !enable_int8;
});
auto* x = pattern->NewNode(x_repr())->assert_is_op_input("fc_xpu", "x");

fusion_op->LinksFrom({x});
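The added assert_more filter means only non-quantized fc_xpu ops match this pattern, so the max-linking below never touches an int8 fc_xpu, presumably because its input maxes are already fixed by quantization scales at fusion time.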
@@ -231,7 +237,12 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const {
auto preop_max_var_name = x_pre_op->Output("out_max");
for (auto max_node : x->inputs[0]->outputs) {
if (preop_max_var_name[0] == max_node->Name()) {
fusion_op_desc->SetInput("x_max", {max_node->Name()});
if (fusion_op_desc->HasInput("x_max")) {
auto x_max_old_name = fusion_op_desc->Input("x_max")[0];
fusion_op_desc->RenameInput(x_max_old_name, max_node->Name());
} else {
fusion_op_desc->SetInput("x_max", {max_node->Name()});
}
IR_NODE_LINK_TO(max_node, fusion_op);
}
}
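The new guard swaps the existing "x_max" argument in place rather than appending a second one when the op is already wired. A hedged sketch of the two OpDesc calls this relies on, with standalone toy names:

// Toy illustration of the OpDesc input-wiring semantics used above.
framework::OpDesc desc;
desc.SetType("fc_xpu");
desc.SetInput("x_max", {"max_var_a"});       // binds the arg slot to var names
desc.RenameInput("max_var_a", "max_var_b");  // swaps one var name in place
// desc.Input("x_max") is now {"max_var_b"}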
5 changes: 4 additions & 1 deletion paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -64,9 +64,12 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) {
case phi::DataType::FLOAT32:
phi::TransposeKernel<float>(*cpu_ctx, *in, axis, out_ptr);
break;
case phi::DataType::INT8:
phi::TransposeKernel<int8_t>(*cpu_ctx, *in, axis, out_ptr);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support fp16 and fp32, but received dtype is %s.",
"Only support fp16/fp32/int8, but received dtype is %s.",
phi::DataTypeToString(in->dtype())));
break;
}
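With the new case label, the same 2-D transpose utility used on fp16/fp32 weights now also accepts int8 tensors. A hedged usage sketch (tensor contents assumed to be filled elsewhere):

// Sketch: transposing an int8 weight through the shared utility; this
// now dispatches phi::TransposeKernel<int8_t> instead of throwing.
phi::DenseTensor weight;        // assume an int8 [k, n] tensor filled earlier
phi::DenseTensor weight_trans;  // receives the [n, k] result
Transpose2D(&weight, &weight_trans);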
27 changes: 27 additions & 0 deletions paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc
@@ -286,6 +286,33 @@ void MapMatmulV2ToMatmulXPUPass::MapMatmulV2ToMatmul(ir::Graph* graph) const {
desc.SetAttr("transpose_X", matmul_v2->Op()->GetAttr("trans_x"));
desc.SetAttr("transpose_Y", matmul_v2->Op()->GetAttr("trans_y"));
desc.SetAttr("alpha", 1.0f);
if (matmul_v2->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2->Op()->GetAttr("enable_int8"));
}
if (matmul_v2->Op()->HasAttr("Input_scale_" + matmul_x->Name())) {
desc.SetAttr("Input_scale_" + matmul_x->Name(),
matmul_v2->Op()->GetAttr("Input_scale_" + matmul_x->Name()));
}
if (matmul_v2->Op()->HasAttr("Input_scale_" + matmul_y->Name())) {
desc.SetAttr("Input_scale_" + matmul_y->Name(),
matmul_v2->Op()->GetAttr("Input_scale_" + matmul_y->Name()));
}
if (matmul_v2->Op()->HasAttr("Input_scale_" + matmul_out->Name())) {
desc.SetAttr(
"Input_scale_" + matmul_out->Name(),
matmul_v2->Op()->GetAttr("Input_scale_" + matmul_out->Name()));
}
if (matmul_v2->Op()->HasAttr("weight_scale")) {
desc.SetAttr("weight_scale", matmul_v2->Op()->GetAttr("weight_scale"));
}
if (matmul_v2->Op()->HasAttr("weight_bit_length")) {
desc.SetAttr("weight_bit_length",
matmul_v2->Op()->GetAttr("weight_bit_length"));
}
if (matmul_v2->Op()->HasAttr("weight_quant_axis")) {
desc.SetAttr("weight_quant_axis",
matmul_v2->Op()->GetAttr("weight_quant_axis"));
}
if (matmul_v2->Op()->HasAttr("use_mkldnn")) {
desc.SetAttr("use_mkldnn", matmul_v2->Op()->GetAttr("use_mkldnn"));
}
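This block forwards every quantization-related attribute from the matched matmul_v2 onto the replacement matmul desc so that fc_xpu_fuse_pass can pick them up later. The fixed-name copies could be condensed as below (hypothetical helper, not in the commit; the per-tensor Input_scale_* keys still need the var names appended):

// Hypothetical condensation of the repeated HasAttr/SetAttr pairs above.
auto copy_attr_if_present = [&](const std::string& name) {
  if (matmul_v2->Op()->HasAttr(name)) {
    desc.SetAttr(name, matmul_v2->Op()->GetAttr(name));
  }
};
for (const char* name : {"enable_int8", "weight_scale", "weight_bit_length",
                         "weight_quant_axis", "use_mkldnn"}) {
  copy_attr_if_present(name);
}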
4 changes: 3 additions & 1 deletion paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
@@ -133,7 +133,7 @@ void FcXPUKernel(const Context& ctx,
DenseTensor* out,
DenseTensor* out_max) {
// Don't use template T param
VLOG(1) << "Kernel type: " << x.dtype() << "," << w.dtype() << " ,"
VLOG(1) << "Kernel type: " << x.dtype() << " ," << w.dtype() << " ,"
<< out_dtype;
if (x.dtype() == DataType::FLOAT32) {
// float32/float16 kernel
@@ -155,6 +155,8 @@
FC_XPU_KERNEL_IMPL(float, int8_t, float, int8_t);
} else if (out_dtype == DataType::INT8) {
FC_XPU_KERNEL_IMPL(float, int8_t, int8_t, int8_t);
} else if (out_dtype == DataType::FLOAT16) {
FC_XPU_KERNEL_IMPL(float, int8_t, dtype::float16, int8_t);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Not support x_dtype is %s, w_dtype is %s and out_dtype is "