diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 1ae6c887aee027..0296c2afee76ff 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -21,17 +21,18 @@ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" - -#ifdef PADDLE_WITH_FLASHATTN -#include "paddle/phi/backends/dynload/flashattn.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" -#endif +#include "paddle/phi/kernels/reshape_kernel.h" -DECLARE_bool(cudnn_deterministic); +PD_DECLARE_bool(cudnn_deterministic); namespace phi { +int get_num_split() { + // 0 for an internal heuristic, which is optimal + return FLAGS_cudnn_deterministic ? 1 : 0; +} + template void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& q, @@ -42,6 +43,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, int64_t max_seqlen_q, int64_t max_seqlen_k, @@ -59,23 +61,17 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, const cudaStream_t stream = ctx.stream(); // q,k,v [total_*, num_heads, head_dim] - auto dims = q.dims(); - const int64_t total_q = dims[0]; - const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = dims[1]; - const int head_size_og = dout.dims()[2]; - const int head_size = dims[2]; - const int total_k = k.dims()[0]; - const int num_heads_k = k.dims()[1]; - // TODO(umiswing): add deterministic in fa2. - // int num_splits = 0; // 0 for an internal heuristic, which is optimal - // if (FLAGS_cudnn_deterministic) { - // num_splits = 1; - // } + const int64_t total_q = dims[0]; + const int64_t batch_size = cu_seqlens_q.numel() - 1; + const int64_t num_heads = dims[1]; + const int64_t head_size_og = dout.dims()[2]; + const int64_t head_size = dims[2]; + const int64_t total_k = k.dims()[0]; + const int64_t num_heads_k = k.dims()[1]; - const bool zero_tensors = false; + int num_splits = get_num_split(); // TODO(umiswing): add shape check PADDLE_ENFORCE_EQ( @@ -96,49 +92,50 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, scale, causal, q.dtype(), + attn_mask, seed_offset.data()); - VLOG(4) << "FlashAttn bwd seed: " << params.seed - << ", offset: " << params.offset; - - const bool succ = - phi::dynload::flash_attn_varlen_bwd(dout.data(), - q.data(), - k.data(), - v.data(), - out.data(), - params.softmax_d.data(), - softmax_lse.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - params.rng_state.data(), - dq->data(), - dk->data(), - dv->data(), - params.dq_accum.data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.is_bf16, - stream, - params.seed, - params.offset); + VLOG(10) << "FlashAttn bwd seed: " << params.seed + << ", offset: " << params.offset; - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + bool succ = phi::dynload::flash_attn_varlen_bwd( + dout.data(), + q.data(), + k.data(), + v.data(), + out.data(), + params.softmax_d.data(), + softmax_lse.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + params.rng_state.data(), + 
dq->data(), + dk->data(), + dv->data(), + params.dq_accum.data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + 1.0f / params.scale, + params.causal, + params.is_bf16, + num_splits, + stream, + params.seed, + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.mask_dims.data()); + CheckFlashAttnStatus(succ); #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } @@ -150,6 +147,7 @@ void FlashAttnGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, float dropout, bool causal, @@ -159,14 +157,17 @@ void FlashAttnGradKernel(const Context& ctx, #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - auto dims = q.dims(); - const int batch_size = dims[0]; - const int seqlen_q = dims[1]; - const int num_heads = dims[2]; - const int head_size_og = dout.dims()[3]; - const int head_size = dims[3]; - const int seqlen_k = k.dims()[1]; - const int num_heads_k = k.dims()[2]; + const auto& dims = q.dims(); + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size_og = dout.dims()[3]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.dims()[1]; + const int64_t num_heads_k = k.dims()[2]; + + const int64_t total_q = batch_size * seqlen_q; + const int64_t total_k = batch_size * seqlen_k; // TODO(umiswing): add shape check PADDLE_ENFORCE_EQ( @@ -175,8 +176,8 @@ void FlashAttnGradKernel(const Context& ctx, phi::errors::InvalidArgument( "flash_attn_bwd receive input with head_size_og == head_size")); - VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims() - << "], v[" << v.dims() << "]"; + VLOG(10) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; const float scale = 1.0f / std::sqrt(head_size); @@ -192,6 +193,7 @@ void FlashAttnGradKernel(const Context& ctx, scale, causal, q.dtype(), + attn_mask, seed_offset.data()); ctx.template Alloc(dq); @@ -200,46 +202,47 @@ void FlashAttnGradKernel(const Context& ctx, cudaStream_t stream = ctx.stream(); - VLOG(4) << "FlashAttn bwd seed: " << params.seed - << ", offset: " << params.offset; + VLOG(10) << "FlashAttn bwd seed: " << params.seed + << ", offset: " << params.offset; - const bool succ = phi::dynload::flash_attn_bwd(dout.data(), - q.data(), - k.data(), - v.data(), - out.data(), - params.softmax_d.data(), - softmax_lse.data(), - params.rng_state.data(), - dq->data(), - dk->data(), - dv->data(), - params.dq_accum.data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.is_bf16, - stream, - params.seed, - params.offset); + int num_splits = get_num_split(); - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention-2, detail information is", - phi::dynload::flash_attn_error())); + bool succ = phi::dynload::flash_attn_bwd( + dout.data(), + q.data(), + k.data(), + v.data(), + 
out.data(), + params.softmax_d.data(), + softmax_lse.data(), + params.rng_state.data(), + dq->data(), + dk->data(), + dv->data(), + params.dq_accum.data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + std::sqrt(head_size), // for unscale + params.causal, + params.is_bf16, + num_splits, + stream, + params.seed, + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.mask_dims.data()); + CheckFlashAttnStatus(succ); #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index e943b7bbf78519..aadae0f29c3427 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -15,22 +15,16 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" #include "glog/logging.h" // For VLOG() -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" - -#ifdef PADDLE_WITH_FLASHATTN -#include "paddle/phi/backends/dynload/flashattn.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" -#endif +#include "paddle/phi/kernels/reshape_kernel.h" -DECLARE_bool(cudnn_deterministic); +PD_DECLARE_bool(cudnn_deterministic); namespace phi { @@ -43,6 +37,7 @@ void FlashAttnUnpaddedKernel( const DenseTensor& cu_seqlens_q, const DenseTensor& cu_seqlens_k, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, @@ -56,13 +51,11 @@ void FlashAttnUnpaddedKernel( DenseTensor* softmax_lse, DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN - ctx.template Alloc(out); cudaStream_t stream = ctx.stream(); // q,k,v [total_*, num_heads, head_dim] - auto dims = q.dims(); PADDLE_ENFORCE_EQ( dims.size(), @@ -71,45 +64,39 @@ void FlashAttnUnpaddedKernel( "[total_seq_len, num_heads, head_dim]")); const int64_t total_q = dims[0]; - const int num_heads = dims[1]; - const int head_size = dims[2]; + const int64_t num_heads = dims[1]; + const int64_t head_size = dims[2]; - const int total_k = k.dims()[0]; - const int num_heads_k = k.dims()[1]; - const int batch_size = cu_seqlens_q.numel() - 1; - - // TODO(umiswing): add deterministic in fa2. 
- // int num_splits = 0; // 0 for an internal heuristic, which is optimal - // if (FLAGS_cudnn_deterministic) { - // num_splits = 1; - // } + const int64_t total_k = k.dims()[0]; + const int64_t num_heads_k = k.dims()[1]; + const int64_t batch_size = cu_seqlens_q.numel() - 1; // TODO(umiswing): add shape check - FlashAttnFwdParamsV2 params = - FlashAttnFwdParamsV2(ctx, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset.get_ptr(), - softmax, - softmax_lse, - seed_offset); - - VLOG(4) << "FlashAttn fwd seed: " << params.seed - << ", offset: " << params.offset; - - const bool succ = phi::dynload::flash_attn_varlen_fwd( + FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + attn_mask, + softmax, + softmax_lse, + seed_offset); + + VLOG(10) << "FlashAttn fwd seed: " << params.seed + << ", offset: " << params.offset; + + bool succ = phi::dynload::flash_attn_varlen_fwd( q.data(), k.data(), v.data(), @@ -130,19 +117,18 @@ void FlashAttnUnpaddedKernel( params.head_size_rounded, params.dropout, params.scale, + 1.0f / params.scale, params.causal, params.return_softmax, params.is_bf16, stream, params.seed, - params.offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.mask_dims.data()); + CheckFlashAttnStatus(succ); #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } @@ -152,6 +138,7 @@ void FlashAttnKernel(const Context& ctx, const DenseTensor& k, const DenseTensor& v, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, float dropout, bool causal, bool return_softmax, @@ -163,55 +150,56 @@ void FlashAttnKernel(const Context& ctx, DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - - auto dims = q.dims(); + const auto& dims = q.dims(); PADDLE_ENFORCE_EQ(dims.size(), 4, phi::errors::InvalidArgument( "flash_attn receive input with dim " "[batch_size, seq_len, num_heads, head_dim]")); - const int batch_size = dims[0]; - const int seqlen_q = dims[1]; - const int num_heads = dims[2]; - const int head_size = dims[3]; - const int seqlen_k = k.dims()[1]; - const int num_heads_k = k.dims()[2]; + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.dims()[1]; + const int64_t num_heads_k = k.dims()[2]; + + const int64_t total_q = batch_size * seqlen_q; + const int64_t total_k = batch_size * seqlen_k; // TODO(umiswing): Add check shape const float scale = 1.0f / std::sqrt(head_size); - FlashAttnFwdParamsV2 params = - FlashAttnFwdParamsV2(ctx, - batch_size, - seqlen_q, - seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset.get_ptr(), - softmax, - softmax_lse, - seed_offset); - - VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims() - << "], v[" << v.dims() << "]"; + FlashAttnFwdParamsV2 params = 
FlashAttnFwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + attn_mask, + softmax, + softmax_lse, + seed_offset); + + VLOG(10) << "FlashAttn fwd dims: q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; + VLOG(10) << "FlashAttn fwd seed: " << params.seed + << ", offset: " << params.offset; ctx.template Alloc(out); cudaStream_t stream = ctx.stream(); - VLOG(4) << "FlashAttn fwd seed: " << params.seed - << ", offset: " << params.offset; - bool succ = phi::dynload::flash_attn_fwd( q.data(), k.data(), @@ -231,21 +219,18 @@ void FlashAttnKernel(const Context& ctx, params.head_size_rounded, params.dropout, params.scale, + std::sqrt(head_size), // for unscale params.causal, params.return_softmax, params.is_bf16, stream, params.seed, - params.offset); - - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention-2, detail information is", - phi::dynload::flash_attn_error())); + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.mask_dims.data()); + CheckFlashAttnStatus(succ); #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h index 62d0f4ec95b37e..ea438014f43125 100644 --- a/paddle/phi/kernels/gpu/flash_attn_utils.h +++ b/paddle/phi/kernels/gpu/flash_attn_utils.h @@ -14,8 +14,69 @@ #pragma once +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/enforce.h" + +#ifdef PADDLE_WITH_FLASHATTN +#include "paddle/phi/backends/dynload/flashattn.h" +#endif + namespace phi { +#ifdef PADDLE_WITH_FLASHATTN +static std::pair GenerateRNGState( + const GPUContext& ctx, + const paddle::optional& fixed_seed_offset, + const std::string& rng_name, + const int64_t batch_size, + const int64_t num_heads) { + if (fixed_seed_offset.get_ptr()) { + const int64_t* fixed_seed_offset_data = + fixed_seed_offset.get_ptr()->data(); + uint64_t seed = static_cast(fixed_seed_offset_data[0]); + uint64_t offset = static_cast(fixed_seed_offset_data[1]); + return std::make_pair(seed, offset); + } else { + uint64_t inc = batch_size * num_heads * 32; + std::pair seed_offset_pair; + if (rng_name != "") { + auto gen = phi::GetRandomSeedGenerator(rng_name); + seed_offset_pair = gen->IncrementOffset(inc); + } else { + auto* gen = ctx.GetGenerator(); + seed_offset_pair = gen->IncrementOffset(inc); + } + return seed_offset_pair; + } +} + +static std::vector GetAttnMaskDims(const DenseTensor* attn_mask) { + std::vector mask_dim_4d; + if (attn_mask) { + const auto& origin_dims = attn_mask->dims(); + auto rank = origin_dims.size(); + PADDLE_ENFORCE_GE( + rank, + 4, + phi::errors::InvalidArgument( + "The number of dimenstions of attn_mask is expected to be greater " + "or equal to 4, but recieved %d. 
The shape of attn_mask is {%s}", + rank, + origin_dims)); + + int64_t first_dim = 1; + for (int i = 0; i < rank - 3; i++) { + first_dim *= origin_dims[i]; + } + mask_dim_4d = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + } + return mask_dim_4d; +} + template struct FlashAttnFwdParamsV2 { int batch_size; @@ -36,7 +97,9 @@ struct FlashAttnFwdParamsV2 { bool is_bf16; uint64_t seed; uint64_t offset; + std::vector mask_dims; DenseTensor rng_state; + const DenseTensor* attn_mask_tensor; DenseTensor* softmax; DenseTensor* softmax_lse; DenseTensor* seed_offset; @@ -55,7 +118,8 @@ struct FlashAttnFwdParamsV2 { const DataType q_dtype, const bool is_test, const std::string& rng_name, - const DenseTensor* const fixed_seed_offset_ptr, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, DenseTensor* _softmax, DenseTensor* _softmax_lse, DenseTensor* _seed_offset) @@ -71,31 +135,19 @@ struct FlashAttnFwdParamsV2 { return_softmax(_return_softmax), softmax(_softmax), softmax_lse(_softmax_lse), - seed_offset(_seed_offset) { + seed_offset(_seed_offset), + attn_mask_tensor(attn_mask.get_ptr()) { dropout = is_test ? 0.0f : _dropout; is_bf16 = q_dtype == DataType::BFLOAT16; // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t // with the same size. rng_state = Empty(ctx, {2}); - if (fixed_seed_offset_ptr) { - const int64_t* fixed_seed_offset_data = - fixed_seed_offset_ptr->data(); - seed = static_cast(fixed_seed_offset_data[0]); - offset = static_cast(fixed_seed_offset_data[1]); - } else { - uint64_t inc = batch_size * num_heads * 32; - std::pair seed_offset_pair; - if (rng_name != "") { - auto gen = phi::GetRandomSeedGenerator(rng_name); - seed_offset_pair = gen->IncrementOffset(inc); - } else { - auto* gen = ctx.GetGenerator(); - seed_offset_pair = gen->IncrementOffset(inc); - } - seed = seed_offset_pair.first; - offset = seed_offset_pair.second; - } + + auto seed_offset_pair = GenerateRNGState( + ctx, fixed_seed_offset, rng_name, batch_size, num_heads); + seed = seed_offset_pair.first; + offset = seed_offset_pair.second; seed_offset->Resize({2}); int64_t* seed_offset_data = ctx.template HostAlloc(seed_offset); @@ -111,10 +163,25 @@ struct FlashAttnFwdParamsV2 { ctx.template Alloc(softmax_lse); if (return_softmax) { + PADDLE_ENFORCE_EQ( + dropout > 0.0f, + true, + phi::errors::InvalidArgument( + "return_softmax is only supported when dropout > 0.0")); + softmax->Resize( {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}); ctx.template Alloc(softmax); } + + mask_dims = GetAttnMaskDims(attn_mask_tensor); + if (attn_mask) { + PADDLE_ENFORCE_EQ( + attn_mask->dtype(), + q_dtype, + phi::errors::InvalidArgument( + "attn_mask is expected to have the same data type with q.")); + } } }; @@ -134,9 +201,11 @@ struct FlashAttnBwdParamsV2 { bool is_bf16; uint64_t seed; uint64_t offset; + std::vector mask_dims; DenseTensor softmax_d; DenseTensor dq_accum; DenseTensor rng_state; + const DenseTensor* attn_mask_tensor; FlashAttnBwdParamsV2(const GPUContext& ctx, const int _batch_size, @@ -149,6 +218,7 @@ struct FlashAttnBwdParamsV2 { const float _scale, const bool _causal, const DataType q_dtype, + const paddle::optional& attn_mask, const int64_t* seed_offset_data) : batch_size(_batch_size), max_seqlen_q(_max_seqlen_q), @@ -158,7 +228,8 @@ struct FlashAttnBwdParamsV2 { head_size(_head_size), dropout(_dropout), scale(_scale), - causal(_causal) { + causal(_causal), + attn_mask_tensor(attn_mask.get_ptr()) { is_bf16 = q_dtype 
== DataType::BFLOAT16;
     seed = static_cast<uint64_t>(seed_offset_data[0]);
     offset = static_cast<uint64_t>(seed_offset_data[1]);
@@ -176,6 +247,52 @@ struct FlashAttnBwdParamsV2 {
     softmax_d = Empty<float>(ctx, {batch_size, num_heads, seqlen_q_rounded});
     dq_accum = Empty<float>(
         ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded});
+
+    mask_dims = GetAttnMaskDims(attn_mask_tensor);
+    if (attn_mask) {
+      PADDLE_ENFORCE_EQ(
+          attn_mask->dtype(),
+          q_dtype,
+          phi::errors::InvalidArgument(
+              "attn_mask is expected to have the same data type with q."));
+    }
   }
 };
+
+static void CheckFlashAttnStatus(const bool status) {
+  PADDLE_ENFORCE_EQ(status,
+                    true,
+                    phi::errors::External(
+                        "Error in Flash-Attention, detail information is: %s",
+                        phi::dynload::flash_attn_error()));
+}
+
+template <typename T>
+__global__ void SimpleScaleKernel(const T* input,
+                                  int64_t numel,
+                                  float scale,
+                                  T* output) {
+  CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) {
+    output[i] = static_cast<T>(scale * static_cast<float>(input[i]));
+  }
+}
+
+template <typename T, typename Context>
+void ComputeScaleQ(
+    const Context& ctx, int64_t numel, float scale, const T* input, T* output) {
+  auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1);
+  SimpleScaleKernel<T><<<gpu_config.block_per_grid,
+                         gpu_config.thread_per_block,
+                         0,
+                         ctx.stream()>>>(input, numel, scale, output);
+}
+
+#endif
+
+static void RaiseNotSupportedError() {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("FlashAttention is unsupported, please check "
+                                 "the GPU compatibility and CUDA version."));
+}
+
 }  // namespace phi
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
new file mode 100644
index 00000000000000..d78f4cf575a6db
--- /dev/null
+++ b/test/legacy_test/test_flash_attention.py
@@ -0,0 +1,464 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
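
For readers tracing the new mask plumbing: the GetAttnMaskDims helper added in flash_attn_utils.h above folds every leading dimension of a rank >= 4 attn_mask into a single batch-like dimension, so the kernel always receives a 4-D mask shape. The Python snippet below is an illustrative re-statement of that logic, not part of the patch; the name get_attn_mask_dims is made up for the sketch.

import math

def get_attn_mask_dims(mask_shape):
    # Mirrors the C++ GetAttnMaskDims helper above: collapse a rank >= 4
    # mask shape into the [first_dim, heads, seqlen_q, seqlen_k] layout
    # handed to the FlashAttention kernel.
    rank = len(mask_shape)
    assert rank >= 4, "attn_mask must have at least 4 dimensions"
    first_dim = math.prod(mask_shape[: rank - 3])  # fold all leading dims
    return [first_dim, mask_shape[-3], mask_shape[-2], mask_shape[-1]]

# Example: a [2, 3, 8, 128, 128] mask is described to the kernel as [6, 8, 128, 128].
assert get_attn_mask_dims([2, 3, 8, 128, 128]) == [6, 8, 128, 128]
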
+ +import os +import re +import unittest + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle import base +from paddle.base import core +from paddle.nn.functional.flash_attention import ( + flash_attention, + flash_attn_unpadded, + scaled_dot_product_attention, +) + + +def get_cuda_version(): + result = os.popen("nvcc --version").read() + regex = r'release (\S+),' + match = re.search(regex, result) + if match: + num = str(match.group(1)) + integer, decimal = num.split('.') + return int(integer) * 1000 + int(float(decimal) * 10) + else: + return -1 + + +def attention_naive(q, k, v, causal=False): + qt = paddle.transpose(q, [0, 2, 1, 3]) + kt = paddle.transpose(k, [0, 2, 1, 3]) + vt = paddle.transpose(v, [0, 2, 1, 3]) + scale = 1.0 / np.sqrt(q.shape[-1]) + s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])) + s = paddle.scale(s, scale) + p = ( + paddle.incubate.softmax_mask_fuse_upper_triangle(s) + if causal + else F.softmax(s) + ) + o = paddle.matmul(p, vt) + return paddle.transpose(o, [0, 2, 1, 3]) + + +def attention_naive_with_mask(q, k, v, attn_bias): + qt = paddle.transpose(q, [0, 2, 1, 3]) + kt = paddle.transpose(k, [0, 2, 1, 3]) + vt = paddle.transpose(v, [0, 2, 1, 3]) + scale = 1.0 / np.sqrt(q.shape[-1]) + s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])) + s = paddle.scale(s, scale) + p = F.softmax(s + attn_bias) + o = paddle.matmul(p, vt) + return paddle.transpose(o, [0, 2, 1, 3]) + + +is_sm8x = ( + core.is_compiled_with_cuda() + and paddle.device.cuda.get_device_capability()[0] == 8 + and paddle.device.cuda.get_device_capability()[1] >= 0 +) + +is_sm90 = ( + core.is_compiled_with_cuda() + and paddle.device.cuda.get_device_capability()[0] == 9 + and paddle.device.cuda.get_device_capability()[1] == 0 +) + +is_sm_supported = is_sm8x or is_sm90 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestFlashAttentionAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 128, 8, 16) + self.dtype = 'float16' + self.dropout = 0.0 + self.causal = False + self.return_softmax = False + self.use_sdp_kernel = False + self.use_sdp_api = False + + def test_unpadded(self): + print( + f"Test unpadded case shape {self.shape} dtype {self.dtype} causal {self.causal}" + ) + + paddle.disable_static() + + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + q_ = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + out_ = attention_naive(q_, q_, q_, self.causal) + + scale = 1.0 / np.sqrt(q.shape[-1]) + + bs = self.shape[0] + ms = self.shape[1] + nh = self.shape[2] + hd = self.shape[3] + cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32') + + qq = paddle.reshape(q, [bs * ms, nh, hd]) + out, _ = flash_attn_unpadded( + qq, + qq, + qq, + cu_q, + cu_q, + ms, + ms, + scale, + self.dropout, + self.causal, + self.return_softmax, + ) + out_ = paddle.reshape(out_, [bs * ms, nh, hd]) + + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) + + out.backward() + out_.backward() + + np.testing.assert_allclose( + q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03 + ) + + # test static + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + qs = 
paddle.static.data( + name="q", shape=self.shape, dtype=self.dtype + ) + + cu_q = paddle.arange(0, (bs + 1) * ms, ms, dtype='int32') + qs = paddle.reshape(qs, [bs * ms, nh, hd]) + + outs, softmax = flash_attn_unpadded( + qs, + qs, + qs, + cu_q, + cu_q, + ms, + ms, + scale, + self.dropout, + self.causal, + self.return_softmax, + ) + + exe = base.Executor(self.place) + fetches_result = exe.run( + feed={ + "q": query.astype('float16'), + "k": query.astype('float16'), + "v": query.astype('float16'), + }, + fetch_list=[outs], + ) + + np.testing.assert_allclose( + fetches_result[0], out_, rtol=5e-03, atol=1e-03 + ) + + def test_all(self): + print( + f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}" + ) + # test dynamic + paddle.disable_static() + + query = np.random.random(self.shape) + key = np.random.random(self.shape) + value = np.random.random(self.shape) + + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + q_ = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k_ = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v_ = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + if self.use_sdp_kernel: + with paddle.nn.functional.sdp_kernel( + enable_math=self.enable_math, + enable_flash=self.enable_flash, + enable_mem_efficient=self.enable_mem_efficient, + ): + if self.use_sdp_api: + out = scaled_dot_product_attention( + q, k, v, None, self.dropout, self.causal + ) + else: + out, _ = flash_attention( + q, k, v, self.dropout, self.causal, self.return_softmax + ) + + else: + out, _ = flash_attention( + q, k, v, self.dropout, self.causal, self.return_softmax + ) + out_ = attention_naive(q_, k_, v_, self.causal) + + out.backward() + out_.backward() + + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) + + self.assertEqual(q.grad.shape, q.shape) + self.assertEqual(q_.grad.shape, q.shape) + + np.testing.assert_allclose( + q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03 + ) + + # test static + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + qs = paddle.static.data( + name="q", shape=self.shape, dtype=self.dtype + ) + ks = paddle.static.data( + name="k", shape=self.shape, dtype=self.dtype + ) + vs = paddle.static.data( + name="v", shape=self.shape, dtype=self.dtype + ) + + if self.use_sdp_kernel: + with paddle.nn.functional.sdp_kernel( + enable_math=self.enable_math, + enable_flash=self.enable_flash, + enable_mem_efficient=self.enable_mem_efficient, + ): + if self.use_sdp_api: + outs = scaled_dot_product_attention( + qs, ks, vs, None, self.dropout, self.causal + ) + else: + outs, softmax = flash_attention( + qs, + ks, + vs, + self.dropout, + self.causal, + self.return_softmax, + ) + else: + outs, softmax = flash_attention( + qs, ks, vs, self.dropout, self.causal, self.return_softmax + ) + + exe = base.Executor(self.place) + fetches_result = exe.run( + feed={ + "q": query.astype('float16'), + "k": key.astype('float16'), + "v": value.astype('float16'), + }, + fetch_list=[outs], + ) + + np.testing.assert_allclose( + fetches_result[0], out_, rtol=5e-03, atol=1e-03 + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + 
"core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 7.5 or 8.x", +) +class TestFlashAttentionWithMaskAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 128, 8, 32) + self.dtype = 'float16' + self.dropout = 0.0 + self.causal = False + + def test_dot_scale_product(self): + # test dynamic + paddle.disable_static() + + query = np.random.random(self.shape) + key = np.random.random(self.shape) + value = np.random.random(self.shape) + + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + q_ = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k_ = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v_ = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1]) + mask = np.random.random(mask_shape) + m = paddle.to_tensor( + mask, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + out = scaled_dot_product_attention( + q, k, v, m, self.dropout, self.causal + ) + out_ = attention_naive_with_mask(q_, k_, v_, m) + out.backward() + out_.backward() + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) + + +class TestFlashAttentionAPITest1(TestFlashAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 128, 8, 16) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = False + self.return_softmax = False + self.use_sdp_kernel = False + + +class TestFlashAttentionAPITest2(TestFlashAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 256, 8, 16) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = False + self.return_softmax = False + self.use_sdp_kernel = False + + +class TestFlashAttentionAPITest3(TestFlashAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 512, 8, 16) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = True + self.return_softmax = False + self.use_sdp_kernel = False + + +class TestFlashAttentionAPITest4(TestFlashAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (8, 1024, 16, 128) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = False + self.return_softmax = False + self.use_sdp_kernel = False + + +class TestFlashAttentionAPITest5(TestFlashAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (8, 1024, 16, 256) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = False + self.return_softmax = False + self.use_sdp_kernel = False + + +class TestMathAttentionAPITest(TestFlashAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (8, 1024, 16, 128) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = False + self.return_softmax = False + self.use_sdp_kernel = True + self.use_sdp_api = False + self.enable_math = True + self.enable_flash = False + self.enable_mem_efficient = False + + +class TestSDPAttentionAPITest(TestFlashAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (8, 1024, 16, 128) + self.dtype = paddle.float16 + self.dropout = 0.0 + 
self.causal = False + self.return_softmax = False + self.use_sdp_kernel = True + self.use_sdp_api = True + self.enable_math = True + self.enable_flash = False + self.enable_mem_efficient = False + + +class TestFlashAttenionWithMaskAPITest(TestFlashAttentionWithMaskAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (8, 1024, 16, 128) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = False + + +if __name__ == '__main__': + unittest.main() diff --git a/third_party/flashattn b/third_party/flashattn new file mode 160000 index 00000000000000..b74460b385b691 --- /dev/null +++ b/third_party/flashattn @@ -0,0 +1 @@ +Subproject commit b74460b385b691d881ff2d3a1adbcefdcac574a3
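
As a usage summary, the sketch below distills the new attn_mask code path from TestFlashAttentionWithMaskAPI above: an additive [batch, 1, seqlen, seqlen] bias is passed to scaled_dot_product_attention and checked against a plain softmax-attention reference. This is a minimal sketch, assuming a CUDA build with FlashAttention enabled and an SM 8.x or 9.0 GPU (the same conditions the tests skip on); shapes, dtype, and tolerances follow the test.

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import scaled_dot_product_attention

batch, seqlen, num_heads, head_dim = 2, 128, 8, 32
shape = (batch, seqlen, num_heads, head_dim)
place = paddle.CUDAPlace(0)

q = paddle.to_tensor(np.random.random(shape), place=place, dtype='float16')
k = paddle.to_tensor(np.random.random(shape), place=place, dtype='float16')
v = paddle.to_tensor(np.random.random(shape), place=place, dtype='float16')

# Additive attention bias, broadcast over heads: [batch, 1, seqlen, seqlen].
mask = paddle.to_tensor(
    np.random.random((batch, 1, seqlen, seqlen)), place=place, dtype='float16'
)

# Positional arguments follow the call in the test: (q, k, v, attn_mask, dropout, causal).
out = scaled_dot_product_attention(q, k, v, mask, 0.0, False)

# Reference: plain attention with the same bias, as in attention_naive_with_mask above.
qt, kt, vt = (paddle.transpose(x, [0, 2, 1, 3]) for x in (q, k, v))
scores = paddle.scale(
    paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])), 1.0 / np.sqrt(head_dim)
)
ref = paddle.transpose(paddle.matmul(F.softmax(scores + mask), vt), [0, 2, 1, 3])

np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=5e-3, atol=1e-3)

Because the mask is added to the pre-softmax scores, it can serve either as a smooth bias or, with large negative entries, as a hard mask.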