
[xpu] fused_MT xpu kernel support #51571

Merged: 8 commits, Mar 20, 2023
2 changes: 2 additions & 0 deletions cmake/external/xpu.cmake
@@ -142,6 +142,8 @@ if(WITH_XPU_XFT)
message(STATUS "Compile with XPU XFT!")
add_definitions(-DPADDLE_WITH_XPU_XFT)

set(XPU_XFT_INC_DIR "${XPU_INC_DIR}/xft")
include_directories(${XPU_XFT_INC_DIR})
set(XPU_XFT_LIB "${XPU_LIB_DIR}/${XPU_XFT_LIB_NAME}")
endif()

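A brief, hedged sketch of what the new include path enables: sources compiled with PADDLE_WITH_XPU_XFT can include XFT headers without a path prefix, since ${XPU_INC_DIR}/xft is added to the include directories above. The header name below is hypothetical and only illustrates the mechanism.

// Illustrative only: PADDLE_WITH_XPU_XFT is defined by the cmake block above,
// and headers under ${XPU_INC_DIR}/xft resolve directly from the include path.
#ifdef PADDLE_WITH_XPU_XFT
#include "xft_api.h"  // hypothetical header located in ${XPU_INC_DIR}/xft
#endif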
6 changes: 6 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -235,6 +235,8 @@ if(WITH_XPU)
pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_isolated_node_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()

cc_library(
@@ -493,4 +495,8 @@ if(WITH_XPU)
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass)
cc_test(
test_fused_multi_transformer_xpu_quant_pass
SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
DEPS fused_multi_transformer_xpu_quant_pass)
endif()
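For context, a hedged sketch of the shape such a pass tester typically takes, assuming the usual PassRegistry pattern used by the other XPU pass tests; the body below is illustrative, not the contents of the new file.

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Illustrative only: fetch the registered pass by name and apply it to a graph
// built from a test program, then inspect the rewritten ops.
void ApplyFusedMtXpuQuantPass(paddle::framework::ir::Graph* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "fused_multi_transformer_xpu_quant_pass");
  pass->Apply(graph);
}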
@@ -75,7 +75,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
1,
{2, -1, 16, 1024, 64},
0);
auto* out = layers.fused_multi_transformer(x,
auto outs = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
@@ -93,7 +93,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
0.1,
1e-12);

x = out;
x = outs[0];
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
@@ -126,7 +126,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
for (int i = 0; i < num_layers; ++i) {
auto* shape_out = layers.shape(src_mask);
auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4});
auto* out = layers.fused_multi_transformer(x,
auto outs = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
@@ -145,7 +145,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
1e-12,
time_stamp);

x = out;
x = outs[0];
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto param_scope = CreateParamScope();
9 changes: 9 additions & 0 deletions paddle/fluid/framework/ir/node.h
@@ -151,6 +151,15 @@ class Node {
var_desc_->SetName(new_name);
}

void RenameOp(const std::string& new_name) {
PADDLE_ENFORCE_EQ(
type_ == Type::kOperation && op_desc_,
true,
platform::errors::InvalidArgument("Node must be type of operation."));
name_ = new_name;
op_desc_->SetType(new_name);
}

int DescOrder() const { return desc_order_; }

int GetVarNodeBlockId() const {
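A minimal sketch of how a pass could use the new RenameOp helper, for example to retarget a matched fused_multi_transformer node to its XPU variant; the surrounding function and op name are illustrative, not part of this diff.

#include "paddle/fluid/framework/ir/node.h"

// Illustrative only: RenameOp keeps the node name and its OpDesc type in sync,
// so a single call is enough to point the node at a different kernel.
void RetargetToXpuKernel(paddle::framework::ir::Node* fused_mt_node) {
  fused_mt_node->RenameOp("fused_multi_transformer_xpu");
}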
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/pass.cc
@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"delete_op_device_pass"};

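Listing the pass in support_subgraph_passes lets it run on control-flow sub-blocks as well as the main graph. Below is a hedged sketch of explicitly appending the pass when building an inference config, assuming the standard AnalysisConfig pass-builder API; it is not part of this diff and may be unnecessary if the XPU pass list already contains it.

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Illustrative only: append the new pass to an inference config's pass list.
void AppendFusedMtXpuQuantPass(paddle::AnalysisConfig* config) {
  config->pass_builder()->AppendPass("fused_multi_transformer_xpu_quant_pass");
}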
58 changes: 31 additions & 27 deletions paddle/fluid/framework/ir/pass_tester_helper.h
@@ -571,33 +571,35 @@ struct Layers {
return out;
}

VarDesc* fused_multi_transformer(VarDesc* x,
VarDesc* cache_kv,
VarDesc* src_mask,
VarDesc* qkv_w,
VarDesc* qkv_bias,
VarDesc* out_linear_w,
VarDesc* out_linear_bias,
VarDesc* ffn1_w,
VarDesc* ffn1_bias,
VarDesc* ffn2_w,
VarDesc* ffn2_bias,
VarDesc* ln_scale,
VarDesc* ln_bias,
VarDesc* ffn_ln_scale,
VarDesc* ffn_ln_bias,
float epsilon,
float dropout_rate,
VarDesc* time_stamp = nullptr,
VarDesc* qkv_out_scale = nullptr,
VarDesc* out_linear_out_scale = nullptr,
VarDesc* ffn1_out_scale = nullptr,
VarDesc* ffn2_out_scale = nullptr,
std::vector<float> qkv_in_scale = {},
std::vector<float> out_linear_in_scale = {},
std::vector<float> ffn1_in_scale = {},
std::vector<float> ffn2_in_scale = {}) {
std::vector<VarDesc*> fused_multi_transformer(
VarDesc* x,
VarDesc* cache_kv,
VarDesc* src_mask,
VarDesc* qkv_w,
VarDesc* qkv_bias,
VarDesc* out_linear_w,
VarDesc* out_linear_bias,
VarDesc* ffn1_w,
VarDesc* ffn1_bias,
VarDesc* ffn2_w,
VarDesc* ffn2_bias,
VarDesc* ln_scale,
VarDesc* ln_bias,
VarDesc* ffn_ln_scale,
VarDesc* ffn_ln_bias,
float epsilon,
float dropout_rate,
VarDesc* time_stamp = nullptr,
VarDesc* qkv_out_scale = nullptr,
VarDesc* out_linear_out_scale = nullptr,
VarDesc* ffn1_out_scale = nullptr,
VarDesc* ffn2_out_scale = nullptr,
std::vector<float> qkv_in_scale = {},
std::vector<float> out_linear_in_scale = {},
std::vector<float> ffn1_in_scale = {},
std::vector<float> ffn2_in_scale = {}) {
VarDesc* out = lod_tensor(unique_name());
VarDesc* cache_kv_out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
std::string op_type = qkv_out_scale ? "fused_multi_transformer_int8"
: "fused_multi_transformer";
@@ -623,6 +625,7 @@ struct Layers {
op->SetAttr("dropout_rate", dropout_rate);
op->SetAttr("epsilon", epsilon);
op->SetOutput("Out", {out->Name()});
op->SetOutput("CacheKVOut", {cache_kv_out->Name()});

if (time_stamp) {
op->SetInput("TimeStep", {time_stamp->Name()});
@@ -638,7 +641,8 @@ struct Layers {
op->SetAttr("ffn1_in_scale", ffn1_in_scale);
op->SetAttr("ffn2_in_scale", ffn2_in_scale);
}
return out;
std::vector<VarDesc*> outs = {out, cache_kv_out};
return outs;
}

VarDesc* dequantize_linear(VarDesc* x,
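A short usage sketch of the updated helper, mirroring the encoder_fp test above: the helper now returns a vector holding {Out, CacheKVOut}, so callers thread outs[0] into the next layer. Variable names are assumed to be in scope as in the test.

// Illustrative only: arguments follow the encoder_fp test body positionally.
auto outs = layers.fused_multi_transformer(
    x, cache_kv, src_mask, qkv_w, qkv_bias, out_linear_w, out_linear_bias,
    ffn1_w, ffn1_bias, ffn2_w, ffn2_bias, ln_scale, ln_bias, ffn_ln_scale,
    ffn_ln_bias, 0.1, 1e-12);
x = outs[0];                // "Out": hidden state fed to the next layer
auto* cache_out = outs[1];  // "CacheKVOut": newly exposed by this change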