[XPU] Use xft::fused_multi_transformer_gpt in fused_multi_transformer_xpu kernel (PaddlePaddle#56921)
NALLEIN authored and jiahy0825 committed Oct 16, 2023
1 parent 00b5cef commit a24c3fb
Showing 7 changed files with 206 additions and 84 deletions.
10 changes: 10 additions & 0 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -346,6 +346,16 @@ void ComputeInterceptor::Run() {
    SendDataReadyToDownStream();
    // reply to upstream and decrease ready data
    ReplyCompletedToUpStream();
    // clear TensorArray
    auto vars_names = microbatch_scopes_[cur_scope_id_]->LocalVarNames();
    for (auto var_name : vars_names) {
      if (var_name == "feed" || var_name == "fetch") continue;
      auto* var = microbatch_scopes_[cur_scope_id_]->Var(var_name);
      if (var != nullptr && var->IsType<framework::LoDTensorArray>()) {
        auto* lod_tensor_arr = var->GetMutable<framework::LoDTensorArray>();
        lod_tensor_arr->clear();
      }
    }
  }
}
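For readers skimming the hunk above: the new block drops every LoDTensorArray in the finished microbatch's scope, so stale per-step tensors are not carried into the next microbatch. Below is a standalone sketch of the same logic; the helper name and include paths are assumptions, while the Scope/Variable calls are exactly the ones used in the diff.

#include "paddle/fluid/framework/lod_tensor_array.h"  // assumed header paths
#include "paddle/fluid/framework/scope.h"

// Hypothetical helper isolating the clearing step added above.
void ClearTensorArrays(paddle::framework::Scope* scope) {
  for (const auto& var_name : scope->LocalVarNames()) {
    // "feed"/"fetch" hold the program's I/O and must survive the microbatch.
    if (var_name == "feed" || var_name == "fetch") continue;
    auto* var = scope->Var(var_name);
    if (var != nullptr &&
        var->IsType<paddle::framework::LoDTensorArray>()) {
      // Drop the per-step tensors accumulated during this microbatch.
      var->GetMutable<paddle::framework::LoDTensorArray>()->clear();
    }
  }
}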

@@ -528,6 +528,44 @@ int FusedMultiTransformerXPUPass::FusedMultiTransformerXPUQuant(
    cast_tofp32_func("FFN1Bias");
    cast_tofp32_func("FFN2Bias");

    // Generate max_buffer: per_tensor_max and per_batch_max for kv_cache
    int layer_num = fused_mt->Op()->Input("QKVW").size();
    int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
    phi::DenseTensor max_buffer_tensor;
    max_buffer_tensor.set_type(phi::DataType::FLOAT32);
    int max_buffer_len = max_ptr_size * layer_num * 2;
    max_buffer_tensor.Resize({max_buffer_len});
    std::vector<float> ones_vec(max_buffer_len, 1.f);
    auto* cpu_ctx = static_cast<phi::CPUContext*>(
        platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
    memcpy(cpu_ctx->Alloc<float>(&max_buffer_tensor),
           ones_vec.data(),
           max_buffer_len * sizeof(float));

    size_t max_buffer_hash = HashTensor<float>(max_buffer_tensor);
    std::string max_buffer_name =
        "max_buffer_#" + std::to_string(max_buffer_hash);
    auto* max_buffer_node = FindNodeWithName(graph, max_buffer_name);
    if (max_buffer_node == nullptr) {
      // Create dst node
      // Update dst var_desc in block
      VarDesc dst_desc(max_buffer_name);
      dst_desc.SetPersistable(true);
      dst_desc.SetShape(vectorize(max_buffer_tensor.dims()));
      dst_desc.SetDataType(
          framework::TransToProtoVarType(max_buffer_tensor.dtype()));
      max_buffer_node = graph->CreateVarNode(&dst_desc);
      auto* block_dst_desc = block->Var(max_buffer_name);
      block_dst_desc->SetPersistable(dst_desc.Persistable());
      block_dst_desc->SetShape(dst_desc.GetShape());
      block_dst_desc->SetDataType(dst_desc.GetDataType());
      auto* max_buffer_var = scope->FindVar(max_buffer_name);
      if (max_buffer_var == nullptr) {
        Assign(max_buffer_tensor,
               scope->Var(max_buffer_name)->GetMutable<phi::DenseTensor>());
      }
    }
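    // Design note: keying the constant's name to the tensor's content hash
    // ("max_buffer_#<hash>") lets FindNodeWithName deduplicate it, so every
    // matched subgraph links to one shared persistable max_buffer var
    // instead of materializing a fresh tensor per match.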

    // Generate fused_multi_transformer_xpu op inplace
    fused_mt->RenameOp("fused_multi_transformer_xpu");
    framework::OpDesc* fused_mt_xpu_op_desc = fused_mt->Op();
@@ -542,6 +580,7 @@ int FusedMultiTransformerXPUPass::FusedMultiTransformerXPUQuant(
    fused_mt_xpu_op_desc->MutableInputs()->clear();
    fused_mt_xpu_op_desc->MutableOutputs()->clear();
    fused_mt_xpu_op_desc->SetInput("x", name_caches.at("X"));
    fused_mt_xpu_op_desc->SetInput("max_buffer", {max_buffer_name});
    fused_mt_xpu_op_desc->SetInput("ln_scale", name_caches.at("LnScale"));
    fused_mt_xpu_op_desc->SetInput("ln_bias", name_caches.at("LnBias"));
    fused_mt_xpu_op_desc->SetInput("qkv_bias", name_caches.at("QKVBias"));
@@ -613,7 +652,7 @@ int FusedMultiTransformerXPUPass::FusedMultiTransformerXPUQuant(
        IR_NODE_LINK_TO(node, fused_mt);
      }
    }

    IR_NODE_LINK_TO(max_buffer_node, fused_mt);
    found_subgraph_count++;
  };
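To make the sizing in the pass concrete: the buffer holds two maxima per layer (per_tensor_max and per_batch_max for the kv_cache, per the comment in the hunk), each padded to the XPU max-pointer width and seeded with 1.0. A minimal self-contained sketch of that initialization — the helper name is hypothetical, not part of this commit:

#include <vector>

// Hypothetical helper mirroring the pass's max_buffer sizing logic.
std::vector<float> MakeMaxBufferInit(int layer_num, int max_ptr_size) {
  // One per_tensor_max slot and one per_batch_max slot per layer,
  // each padded to max_ptr_size floats.
  const int max_buffer_len = max_ptr_size * layer_num * 2;
  return std::vector<float>(max_buffer_len, 1.0f);  // seeded with 1.f
}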

2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/fused_ops.yaml
@@ -157,7 +157,7 @@
  support_dygraph_mode : true

- op : fused_multi_transformer_xpu
  args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor gather_index, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id, int gather_axis)
  args : (Tensor x, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] qkvw, Tensor[] qkvw_max, Tensor[] qkv_bias, Tensor[] out_linear_w, Tensor[] out_linear_wmax, Tensor[] out_linear_bias, Tensor[] ffn_ln_scale, Tensor[] ffn_ln_bias, Tensor[] ffn1_weight, Tensor[] ffn1_weight_max, Tensor[] ffn1_bias, Tensor[] ffn2_weight, Tensor[] ffn2_weight_max, Tensor[] ffn2_bias, Tensor[] cache_kv, Tensor[] pre_caches, Tensor rotary_pos_emb, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor gather_index, Tensor max_buffer, bool pre_layer_norm, int rotary_emb_dims, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id, int gather_axis)
  output : Tensor(out), Tensor[](cache_kv_out){out_linear_w.size()}
  infer_meta :
    func : FusedMultiTransformerXpuInferMeta
12 changes: 12 additions & 0 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -418,6 +418,12 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::INT32,
                     phi::DataType::INT64,
                     phi::DataType::BOOL})},
      {"gather_inplace",
       XPUKernelSet({phi::DataType::FLOAT32,
                     phi::DataType::FLOAT16,
                     phi::DataType::INT32,
                     phi::DataType::INT64,
                     phi::DataType::BOOL})},
      {"gaussian_random",
       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"gelu_grad",
@@ -1030,6 +1036,12 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::FLOAT64,
                     phi::DataType::INT32,
                     phi::DataType::INT64})},
      {"mp_allreduce_sum",
       XPUKernelSet({phi::DataType::FLOAT16,
                     phi::DataType::FLOAT32,
                     phi::DataType::INT32})},
      {"c_embedding",
       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
  };

  return s_xpu2_kernels;
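The entries above follow the registry pattern get_kl2_ops() uses throughout: a static map from op name to the set of dtypes its XPU kernel accepts, built once and returned by reference. A minimal self-contained analogue — hypothetical names, with standard-library types standing in for XPUOpMap/XPUKernelSet:

#include <map>
#include <set>
#include <string>

enum class DataType { FLOAT32, FLOAT16, INT32, INT64, BOOL };

// Hypothetical stand-ins for XPUKernelSet / XPUOpMap.
using KernelSet = std::set<DataType>;
using OpMap = std::map<std::string, KernelSet>;

OpMap& GetKl2Ops() {
  // Built once on first call, returned by reference on every lookup.
  static OpMap s_kernels{
      {"gather_inplace",
       {DataType::FLOAT32, DataType::FLOAT16, DataType::INT32,
        DataType::INT64, DataType::BOOL}},
      {"mp_allreduce_sum",
       {DataType::FLOAT16, DataType::FLOAT32, DataType::INT32}},
  };
  return s_kernels;
}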
1 change: 1 addition & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -655,6 +655,7 @@ void FusedMultiTransformerXpuInferMeta(
    const MetaTensor& seq_lengths,
    const MetaTensor& src_mask,
    const MetaTensor& gather_index,
    const MetaTensor& max_buffer,
    bool pre_layer_norm,
    int rotary_emb_dims,
    float epsilon,
1 change: 1 addition & 0 deletions paddle/phi/infermeta/fusion.h
@@ -168,6 +168,7 @@ void FusedMultiTransformerXpuInferMeta(
    const MetaTensor& seq_lengths,
    const MetaTensor& src_mask,
    const MetaTensor& gather_index,
    const MetaTensor& max_buffer,
    bool pre_layer_norm,
    int rotary_emb_dims,
    float epsilon,