Support unpadded GQA FlashAttention (#60610)
* add unpadded gqa fa

* update FA commit
sneaxiy authored Jan 12, 2024
1 parent 51d97a6 commit 0a55857
Showing 2 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cmake/external/flashattn.cmake
@@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
-set(FLASHATTN_TAG fd6890c7ef6e53380b9eddc0a12b5acc641eb57d)
+set(FLASHATTN_TAG 5fc132ac11e78d26471ca09e5ba0cd817c3424d8)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
27 changes: 25 additions & 2 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -97,6 +97,23 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
VLOG(10) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;

+bool is_mha = (num_heads == num_heads_k);
+
+void* dk_data = nullptr;
+void* dv_data = nullptr;
+phi::DenseTensor dk_expanded, dv_expanded;
+if (is_mha) {
+  dk_data = ctx.template Alloc<T>(dk);
+  dv_data = ctx.template Alloc<T>(dv);
+} else {
+  std::initializer_list<int64_t> dk_dv_shape = {
+      total_k, num_heads_k, num_heads / num_heads_k, head_size};
+  dk_expanded.Resize(dk_dv_shape);
+  dv_expanded.Resize(dk_dv_shape);
+  dk_data = ctx.template Alloc<T>(&dk_expanded);
+  dv_data = ctx.template Alloc<T>(&dv_expanded);
+}

bool succ = phi::dynload::flash_attn_varlen_bwd(
dout.data(),
q.data(),
@@ -109,8 +126,8 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
-dk->data(),
-dv->data(),
+dk_data,
+dv_data,
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
@@ -133,6 +150,12 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.mask_dims.data());
CheckFlashAttnStatus(succ);

+if (!is_mha) {
+  phi::SumKernel<T, Context>(ctx, dk_expanded, {2}, dk->type(), false, dk);
+  phi::SumKernel<T, Context>(ctx, dv_expanded, {2}, dv->type(), false, dv);
+}

#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
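For reference, the change above handles GQA (num_heads != num_heads_k) by letting the FlashAttention backward write per-query-head key/value gradients into expanded buffers of shape [total_k, num_heads_k, num_heads / num_heads_k, head_size], then summing over axis 2 to obtain the final [total_k, num_heads_k, head_size] gradients. Below is a minimal standalone sketch of that reduction; the helper name ReduceGqaGrad, the std::vector storage, and the row-major layout are illustrative assumptions, not Paddle APIs.

#include <cstdint>
#include <vector>

// Collapse dk_expanded with logical shape
// [total_k, num_heads_k, group, head_size] (group = num_heads / num_heads_k)
// into dk with shape [total_k, num_heads_k, head_size] by summing over the
// group dimension, mirroring the SumKernel(..., {2}, ...) calls above.
std::vector<float> ReduceGqaGrad(const std::vector<float>& dk_expanded,
                                 int64_t total_k,
                                 int64_t num_heads_k,
                                 int64_t group,
                                 int64_t head_size) {
  std::vector<float> dk(total_k * num_heads_k * head_size, 0.0f);
  for (int64_t t = 0; t < total_k; ++t) {
    for (int64_t h = 0; h < num_heads_k; ++h) {
      for (int64_t g = 0; g < group; ++g) {
        const float* src = dk_expanded.data() +
                           ((t * num_heads_k + h) * group + g) * head_size;
        float* dst = dk.data() + (t * num_heads_k + h) * head_size;
        for (int64_t d = 0; d < head_size; ++d) {
          dst[d] += src[d];  // accumulate gradients of all query heads sharing this KV head
        }
      }
    }
  }
  return dk;
}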
