From 9483e72100f906d901a970090464bfee81196ad8 Mon Sep 17 00:00:00 2001
From: csy0225
Date: Wed, 20 Sep 2023 11:05:54 +0800
Subject: [PATCH] support fc_xpu int8

---
 .../auto_trans_quantize_op_precision_pass.cc  | 54 +++++++++----------
 .../framework/ir/xpu/fc_xpu_fuse_pass.cc      | 12 +++--
 .../framework/ir/xpu/link_xpu_op_max_pass.cc  | 15 +++++-
 paddle/fluid/framework/ir/xpu/quant_utils.cc  |  5 +-
 .../ir/xpu/reshape2_matmul_xpu_fuse_pass.cc   | 27 ++++++++++
 .../phi/kernels/fusion/xpu/fc_xpu_kernel.cc   |  4 +-
 6 files changed, 80 insertions(+), 37 deletions(-)

diff --git a/paddle/fluid/framework/ir/xpu/auto_trans_quantize_op_precision_pass.cc b/paddle/fluid/framework/ir/xpu/auto_trans_quantize_op_precision_pass.cc
index c8b4b7c040f7e1..9fec1091bd9a93 100644
--- a/paddle/fluid/framework/ir/xpu/auto_trans_quantize_op_precision_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/auto_trans_quantize_op_precision_pass.cc
@@ -44,7 +44,7 @@ class AutoTransQuantizeOpPrecisionPass : public FusePassBase {
   const std::string name_scope_{"auto_trans_quantize_op_precision_pass"};
 
   const std::unordered_set<std::string> support_fusion_quant_op_type_{
-      "conv2d_xpu"};
+      "conv2d_xpu", "fc_xpu"};
 };
 
 static inline Node* GetOpOutVarNodeByArgsName(ir::Graph* graph,
@@ -72,35 +72,33 @@ void AutoTransQuantizeOpPrecisionPass::FirstRound(ir::Graph* graph) const {
       bool enable_int8 = op_node->Op()->GetAttrIfExists<bool>("enable_int8");
       int out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
       if (enable_int8) {
-        if (op_type == "conv2d_xpu") {
-          auto* out_var_node =
-              GetOpOutVarNodeByArgsName(subgraph, op_node, "out");
-          PADDLE_ENFORCE_NOT_NULL(
-              out_var_node,
-              platform::errors::InvalidArgument(
-                  "out_var_node in graph cannot be nullptr."));
-          bool is_int8_out = true;
-          for (auto* next_op_node : out_var_node->outputs) {
-            auto next_op_type = next_op_node->Op()->Type();
-            bool is_next_op_support_int8 =
-                next_op_node->Op()->GetAttrIfExists<bool>("enable_int8") &&
-                ((support_fusion_quant_op_type_.find(next_op_type) !=
-                  support_fusion_quant_op_type_.end()));
-            if (!is_next_op_support_int8) {
-              is_int8_out = false;
-              break;
-            }
-          }
-          if (is_int8_out) {
-            op_node->Op()->SetAttr(
-                "out_dtype",
-                static_cast<int>(proto::VarType::Type::VarType_Type_INT8));
-            out_var_node->Var()->SetDataType(
-                proto::VarType::Type::VarType_Type_INT8);
-            VLOG(1) << "The out var node " << out_var_node->Name()
-                    << " is INT8";
+        auto* out_var_node =
+            GetOpOutVarNodeByArgsName(subgraph, op_node, "out");
+        PADDLE_ENFORCE_NOT_NULL(
+            out_var_node,
+            platform::errors::InvalidArgument(
+                "out_var_node in graph cannot be nullptr."));
+        bool is_int8_out = true;
+        for (auto* next_op_node : out_var_node->outputs) {
+          auto next_op_type = next_op_node->Op()->Type();
+          bool is_next_op_support_int8 =
+              next_op_node->Op()->GetAttrIfExists<bool>("enable_int8") &&
+              ((support_fusion_quant_op_type_.find(next_op_type) !=
+                support_fusion_quant_op_type_.end()));
+          if (!is_next_op_support_int8) {
+            is_int8_out = false;
+            break;
           }
         }
+        if (is_int8_out) {
+          op_node->Op()->SetAttr(
+              "out_dtype",
+              static_cast<int>(proto::VarType::Type::VarType_Type_INT8));
+          out_var_node->Var()->SetDataType(
+              proto::VarType::Type::VarType_Type_INT8);
+          VLOG(1) << "The out var node " << out_var_node->Name()
+                  << " is INT8";
+        }
       }
     }
   }
diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
index 5868db56270212..f087b7caf20ab9 100644
--- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -367,8 +367,7 @@ void FcXPUFusePass::CreateFusionWeightsAndBias(
   }
   // Get Weight scale in int8 scene
   std::vector<float> weight_scale =
-      mul->Op()->GetAttrIfExists<std::vector<float>>("Input_scale_" +
-                                                     mul_w->Name());
+      mul->Op()->GetAttrIfExists<std::vector<float>>("weight_scale");
   // Create fusion_bias_node
   auto filter_dims = filter_t->dims();
   bool has_bias = with_bn || with_bias;
@@ -754,8 +753,11 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     GET_IR_NODE(act);
     GET_IR_NODE(act_out);
     std::map<std::string, std::map<std::string, Node*>> nodes_map;
-    nodes_map.insert(
-        {"mul", {{"mul_x", mul_x}, {"mul_w", mul_w}, {"mul_out", mul_out}}});
+    nodes_map.insert({"mul",
+                      {{"mul", mul},
+                       {"mul_x", mul_x},
+                       {"mul_w", mul_w},
+                       {"mul_out", mul_out}}});
     nodes_map.insert({"ew_bias_add",
                       {{"ew_bias_add", add},
                        {"ew_bias_add_bias", bias},
@@ -785,7 +787,7 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
 
     bool enable_int8 = mul->Op()->GetAttrIfExists<bool>("enable_int8");
     std::string op_precision_str = enable_int8 ? "int8" : "fp32";
-    VLOG(4) << "FC fusion fuse pass is running on " << op_precision_str
+    VLOG(1) << "FC fusion fuse pass is running on " << op_precision_str
             << " precision!";
     auto* block = mul->Op()->Block();
     CreateFusionWeightsAndBias(graph,
diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
index 3a6d29f794d657..d9ab5448d0fdaf 100644
--- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
@@ -106,7 +106,13 @@ struct LinkFcPattern : public PatternBase {
 
 LinkFcPattern::LinkFcPattern(PDPattern* pattern, const std::string& name_scope)
     : PatternBase(pattern, name_scope, name_scope) {
-  auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op("fc_xpu");
+  auto* fusion_op = pattern->NewNode(fusion_op_repr())
+                        ->assert_is_op("fc_xpu")
+                        ->assert_more([&](Node* node) {
+                          bool enable_int8 =
+                              node->Op()->GetAttrIfExists<bool>("enable_int8");
+                          return !enable_int8;
+                        });
 
   auto* x = pattern->NewNode(x_repr())->assert_is_op_input("fc_xpu", "x");
   fusion_op->LinksFrom({x});
@@ -231,7 +237,12 @@ void LinkXPUOpMaxPass::LinkFcMax(ir::Graph* graph) const {
       auto preop_max_var_name = x_pre_op->Output("out_max");
       for (auto max_node : x->inputs[0]->outputs) {
         if (preop_max_var_name[0] == max_node->Name()) {
-          fusion_op_desc->SetInput("x_max", {max_node->Name()});
+          if (fusion_op_desc->HasInput("x_max")) {
+            auto x_max_old_name = fusion_op_desc->Input("x_max")[0];
+            fusion_op_desc->RenameInput(x_max_old_name, max_node->Name());
+          } else {
+            fusion_op_desc->SetInput("x_max", {max_node->Name()});
+          }
           IR_NODE_LINK_TO(max_node, fusion_op);
         }
       }
diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc
index ada4a4b9b6c2f1..90ca41f72958ef 100644
--- a/paddle/fluid/framework/ir/xpu/quant_utils.cc
+++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -64,9 +64,12 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) {
     case phi::DataType::FLOAT32:
      phi::TransposeKernel<float>(*cpu_ctx, *in, axis, out_ptr);
       break;
+    case phi::DataType::INT8:
+      phi::TransposeKernel<int8_t>(*cpu_ctx, *in, axis, out_ptr);
+      break;
     default:
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "Only support fp16 and fp32, but received dtype is %s.",
+          "Only support fp16/fp32/int8, but received dtype is %s.",
           phi::DataTypeToString(in->dtype())));
       break;
   }
diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc
index 8383501c30b8f9..fff3c4020b5447 100644
--- a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc
@@ -286,6 +286,33 @@ void MapMatmulV2ToMatmulXPUPass::MapMatmulV2ToMatmul(ir::Graph* graph) const {
     desc.SetAttr("transpose_X", matmul_v2->Op()->GetAttr("trans_x"));
     desc.SetAttr("transpose_Y", matmul_v2->Op()->GetAttr("trans_y"));
     desc.SetAttr("alpha", 1.0f);
+    if (matmul_v2->Op()->HasAttr("enable_int8")) {
+      desc.SetAttr("enable_int8", matmul_v2->Op()->GetAttr("enable_int8"));
+    }
+    if (matmul_v2->Op()->HasAttr("Input_scale_" + matmul_x->Name())) {
+      desc.SetAttr("Input_scale_" + matmul_x->Name(),
+                   matmul_v2->Op()->GetAttr("Input_scale_" + matmul_x->Name()));
+    }
+    if (matmul_v2->Op()->HasAttr("Input_scale_" + matmul_y->Name())) {
+      desc.SetAttr("Input_scale_" + matmul_y->Name(),
+                   matmul_v2->Op()->GetAttr("Input_scale_" + matmul_y->Name()));
+    }
+    if (matmul_v2->Op()->HasAttr("Input_scale_" + matmul_out->Name())) {
+      desc.SetAttr(
+          "Input_scale_" + matmul_out->Name(),
+          matmul_v2->Op()->GetAttr("Input_scale_" + matmul_out->Name()));
+    }
+    if (matmul_v2->Op()->HasAttr("weight_scale")) {
+      desc.SetAttr("weight_scale", matmul_v2->Op()->GetAttr("weight_scale"));
+    }
+    if (matmul_v2->Op()->HasAttr("weight_bit_length")) {
+      desc.SetAttr("weight_bit_length",
+                   matmul_v2->Op()->GetAttr("weight_bit_length"));
+    }
+    if (matmul_v2->Op()->HasAttr("weight_quant_axis")) {
+      desc.SetAttr("weight_quant_axis",
+                   matmul_v2->Op()->GetAttr("weight_quant_axis"));
+    }
     if (matmul_v2->Op()->HasAttr("use_mkldnn")) {
       desc.SetAttr("use_mkldnn", matmul_v2->Op()->GetAttr("use_mkldnn"));
     }
diff --git a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
index f2acd0893a6f70..eeb36a86eeec7d 100644
--- a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
+++ b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
@@ -133,7 +133,7 @@ void FcXPUKernel(const Context& ctx,
                  DenseTensor* out,
                  DenseTensor* out_max) {
   // Dont use template T param
-  VLOG(1) << "Kernel type: " << x.dtype() << "," << w.dtype() << " ,"
+  VLOG(1) << "Kernel type: " << x.dtype() << " ," << w.dtype() << " ,"
           << out_dtype;
   if (x.dtype() == DataType::FLOAT32) {
     // float32/float16 kernel
@@ -155,6 +155,8 @@ void FcXPUKernel(const Context& ctx,
       FC_XPU_KERNEL_IMPL(float, int8_t, float, int8_t);
     } else if (out_dtype == DataType::INT8) {
       FC_XPU_KERNEL_IMPL(float, int8_t, int8_t, int8_t);
+    } else if (out_dtype == DataType::FLOAT16) {
+      FC_XPU_KERNEL_IMPL(float, int8_t, dtype::float16, int8_t);
     } else {
       PADDLE_THROW(phi::errors::Unimplemented(
           "Not support x_dtype is %s, w_dtype is %s and out_dtype is "