Additional mask support on FA2 (PaddlePaddle#57276)
* Add additional mask support. Tested on FlashAttnKernel.

* Fix bug in fwd (temporarily).
Add mask support on bwd.
Unpadded kernel to be tested.

* Add unscale on padded kernel (see the sketch after this list).

* Add varlen mask.

* Remove redundant compute_scale_q.

* Remove redundant comment.
Fix CI: PADDLE_ENFORCE format.
Remove test case: return_softmax && dropout==0.

* Add mask type check.

* Update submodules.
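
The "unscale" mentioned above is the extra argument the updated flash_attn_bwd / flash_attn_varlen_bwd calls receive next to the softmax scale (see `std::sqrt(head_size)  // for unscale` and `1.0f / params.scale` in the diff below). A minimal, hypothetical sketch of the relationship, using an illustrative head_size of 128 rather than any value taken from the commit:

#include <cmath>
#include <cstdio>

int main() {
  // Illustrative head dimension; not taken from the commit.
  const int head_size = 128;
  // Softmax scale as computed in the kernels below: 1 / sqrt(head_size).
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  // The "unscale" value passed alongside it is simply the reciprocal,
  // i.e. sqrt(head_size) == 1.0f / scale.
  const float unscale = 1.0f / scale;
  std::printf("scale = %f, unscale = %f\n", scale, unscale);
  return 0;
}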
umiswing authored and AnnaTrainingG committed Nov 6, 2023
1 parent 7887923 commit c5920c1
Showing 5 changed files with 791 additions and 221 deletions.
215 changes: 109 additions & 106 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -21,17 +21,18 @@
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
#include "paddle/phi/kernels/reshape_kernel.h"

DECLARE_bool(cudnn_deterministic);
PD_DECLARE_bool(cudnn_deterministic);

namespace phi {

int get_num_split() {
// 0 for an internal heuristic, which is optimal
return FLAGS_cudnn_deterministic ? 1 : 0;
}

template <typename T, typename Context>
void FlashAttnUnpaddedGradKernel(const Context& ctx,
const DenseTensor& q,
@@ -42,6 +43,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
@@ -59,23 +61,17 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
const cudaStream_t stream = ctx.stream();

// q,k,v [total_*, num_heads, head_dim]

auto dims = q.dims();
const int64_t total_q = dims[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = dims[1];
const int head_size_og = dout.dims()[2];
const int head_size = dims[2];
const int total_k = k.dims()[0];
const int num_heads_k = k.dims()[1];

// TODO(umiswing): add deterministic in fa2.
// int num_splits = 0; // 0 for an internal heuristic, which is optimal
// if (FLAGS_cudnn_deterministic) {
// num_splits = 1;
// }
const int64_t total_q = dims[0];
const int64_t batch_size = cu_seqlens_q.numel() - 1;
const int64_t num_heads = dims[1];
const int64_t head_size_og = dout.dims()[2];
const int64_t head_size = dims[2];
const int64_t total_k = k.dims()[0];
const int64_t num_heads_k = k.dims()[1];

const bool zero_tensors = false;
int num_splits = get_num_split();

// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
@@ -96,49 +92,50 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
scale,
causal,
q.dtype(),
attn_mask,
seed_offset.data<int64_t>());

VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;

const bool succ =
phi::dynload::flash_attn_varlen_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
VLOG(10) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;

if (!succ) {
PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error()));
}
bool succ = phi::dynload::flash_attn_varlen_bwd(
dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
cu_seqlens_q.data<int32_t>(),
cu_seqlens_k.data<int32_t>(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
1.0f / params.scale,
params.causal,
params.is_bf16,
num_splits,
stream,
params.seed,
params.offset,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.mask_dims.data());
CheckFlashAttnStatus(succ);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
RaiseNotSupportedError();
#endif
}

@@ -150,6 +147,7 @@ void FlashAttnGradKernel(const Context& ctx,
const DenseTensor& out,
const DenseTensor& softmax_lse,
const DenseTensor& seed_offset,
const paddle::optional<DenseTensor>& attn_mask,
const DenseTensor& dout,
float dropout,
bool causal,
@@ -159,14 +157,17 @@
#ifdef PADDLE_WITH_FLASHATTN
// q,k,v [batch_size, seq_len, num_heads, head_dim]

auto dims = q.dims();
const int batch_size = dims[0];
const int seqlen_q = dims[1];
const int num_heads = dims[2];
const int head_size_og = dout.dims()[3];
const int head_size = dims[3];
const int seqlen_k = k.dims()[1];
const int num_heads_k = k.dims()[2];
const auto& dims = q.dims();
const int64_t batch_size = dims[0];
const int64_t seqlen_q = dims[1];
const int64_t num_heads = dims[2];
const int64_t head_size_og = dout.dims()[3];
const int64_t head_size = dims[3];
const int64_t seqlen_k = k.dims()[1];
const int64_t num_heads_k = k.dims()[2];

const int64_t total_q = batch_size * seqlen_q;
const int64_t total_k = batch_size * seqlen_k;

// TODO(umiswing): add shape check
PADDLE_ENFORCE_EQ(
Expand All @@ -175,8 +176,8 @@ void FlashAttnGradKernel(const Context& ctx,
phi::errors::InvalidArgument(
"flash_attn_bwd receive input with head_size_og == head_size"));

VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";
VLOG(10) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims()
<< "], v[" << v.dims() << "]";

const float scale = 1.0f / std::sqrt(head_size);

@@ -192,6 +193,7 @@
scale,
causal,
q.dtype(),
attn_mask,
seed_offset.data<int64_t>());

ctx.template Alloc<T>(dq);
@@ -200,46 +202,47 @@

cudaStream_t stream = ctx.stream();

VLOG(4) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;
VLOG(10) << "FlashAttn bwd seed: " << params.seed
<< ", offset: " << params.offset;

const bool succ = phi::dynload::flash_attn_bwd(dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
params.causal,
params.is_bf16,
stream,
params.seed,
params.offset);
int num_splits = get_num_split();

PADDLE_ENFORCE_EQ(
succ,
true,
phi::errors::External("Error in Flash-Attention-2, detail information is",
phi::dynload::flash_attn_error()));
bool succ = phi::dynload::flash_attn_bwd(
dout.data(),
q.data(),
k.data(),
v.data(),
out.data(),
params.softmax_d.data(),
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
params.max_seqlen_k,
params.seqlen_q_rounded,
params.seqlen_k_rounded,
params.num_heads,
params.num_heads_k,
params.head_size,
params.head_size_rounded,
params.dropout,
params.scale,
std::sqrt(head_size), // for unscale
params.causal,
params.is_bf16,
num_splits,
stream,
params.seed,
params.offset,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.mask_dims.data());
CheckFlashAttnStatus(succ);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
RaiseNotSupportedError();
#endif
}
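
As a reading aid for the new trailing arguments in the flash_attn_bwd / flash_attn_varlen_bwd calls above (the mask pointer and its dims), here is a toy, hedged sketch of the optional-mask plumbing pattern. Tensor, FakeBwd, and the 2x2 shape are stand-ins invented for illustration; they are not Paddle's DenseTensor or the dynloaded FlashAttention entry points.

#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Toy tensor stand-in (not phi::DenseTensor).
struct Tensor {
  std::vector<float> values;
  std::vector<int64_t> dims;
  const void* data() const { return values.data(); }
};

// Stand-in for a dynloaded C entry point such as flash_attn_bwd: the mask is
// a raw pointer that may be null, and its dims arrive as a plain int64_t*.
void FakeBwd(const void* attn_mask, const int64_t* mask_dims) {
  if (attn_mask == nullptr) {
    std::printf("no attention mask\n");
    return;
  }
  std::printf("mask dims: [%lld, %lld]\n",
              static_cast<long long>(mask_dims[0]),
              static_cast<long long>(mask_dims[1]));
}

int main() {
  // An optional additive mask, here a 2x2 block of zeros.
  std::optional<Tensor> attn_mask =
      Tensor{std::vector<float>(4, 0.0f), {2, 2}};

  // Mirror of the call-site pattern in the diff:
  // `attn_mask ? attn_mask->data() : nullptr` plus the dims buffer.
  FakeBwd(attn_mask ? attn_mask->data() : nullptr,
          attn_mask ? attn_mask->dims.data() : nullptr);
  return 0;
}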

