From 63289f49a13a35d9add8fc72dd2a633a5b211669 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Thu, 3 Aug 2023 22:03:50 +0800 Subject: [PATCH 01/13] add mask --- paddle/phi/api/yaml/backward.yaml | 4 +- paddle/phi/api/yaml/ops.yaml | 8 +- paddle/phi/kernels/flash_attn_kernel.h | 2 + paddle/phi/kernels/gpu/flash_attn_kernel.cu | 285 +++++++++++++----- .../paddle/nn/functional/flash_attention.py | 44 ++- test/legacy_test/test_flash_attention.py | 44 ++- 6 files changed, 309 insertions(+), 78 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 1dda103c85c42..7cc569bc07e50 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -818,7 +818,7 @@ inplace : (out_grad -> x_grad) - backward_op : flash_attn_grad - forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset,Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false) output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) infer_meta : @@ -829,7 +829,7 @@ data_type: q - backward_op : flash_attn_unpadded_grad - forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) infer_meta : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 78cbd7c65188d..ea8b78d869cb0 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -910,9 +910,9 @@ backward : fill_diagonal_tensor_grad - op : flash_attn - args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") + args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - optional : fixed_seed_offset + optional : fixed_seed_offset, attn_mask infer_meta : func : 
FlashAttnInferMeta param : [q, k, v] @@ -923,9 +923,9 @@ backward : flash_attn_grad - op : flash_attn_unpadded - args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") + args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - optional : fixed_seed_offset + optional : fixed_seed_offset , attn_mask infer_meta : func : FlashAttnInferMeta param : [q, k, v] diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h index 296e242026087..ec72d85a0babb 100644 --- a/paddle/phi/kernels/flash_attn_kernel.h +++ b/paddle/phi/kernels/flash_attn_kernel.h @@ -28,6 +28,7 @@ void FlashAttnUnpaddedKernel( const DenseTensor& cu_seqlens_q, const DenseTensor& cu_seqlens_k, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, @@ -47,6 +48,7 @@ void FlashAttnKernel(const Context& ctx, const DenseTensor& k, const DenseTensor& v, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, float dropout, bool causal, bool return_softmax, diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 714edf4be6f3c..6cfcc962dd709 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -16,6 +16,7 @@ #include "glog/logging.h" // For VLOG() #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" @@ -33,6 +34,24 @@ DECLARE_bool(cudnn_deterministic); namespace phi { +// template +// void ComputeScaleQ(const Context& ctx, const DenseTensor& q, int64_t numel, +// int64_t head_dim, T* q_ptr){ +// auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, +// 1); DenseTensor q_(q); T* q_ptr = static_cast(q_.data()); +// SimleScaleWithMaskKernel<<>>(q_size, scale, q_ptr); +// } + +template +__global__ void SimleScaleWithMaskKernel(int64_t numel, float scale, T* inout) { + CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { + inout[i] = static_cast(scale * static_cast(inout[i])); + } +} + template void FlashAttnUnpaddedKernel( const Context& ctx, @@ -42,6 +61,7 @@ void FlashAttnUnpaddedKernel( const DenseTensor& cu_seqlens_q, const DenseTensor& cu_seqlens_k, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, @@ -56,11 +76,13 @@ void FlashAttnUnpaddedKernel( DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN if (is_test) dropout = 0.0f; + // printf("welcom to FlashAttnUnpaddedKernel\n"); ctx.template Alloc(out); cudaStream_t stream = ctx.stream(); bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? 
true : false; + // printf("is_bf16\n", is_bf16); // q,k,v [total_*, num_heads, head_dim] @@ -134,76 +156,201 @@ void FlashAttnUnpaddedKernel( } uint64_t workspace_size; + bool succ; + // printf("welcome to flash_attention_with_mask\n"); + if (attn_mask.get_ptr()) { + // compute scale Q + // printf("compute scale Q start\n"); + int64_t q_size = total_q * num_heads * head_size; + + auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); + DenseTensor q_(q); + T* q_ptr = static_cast(q_.data()); + + SimleScaleWithMaskKernel<<>>(q_size, scale, q_ptr); + scale = 1.0f; + // printf("compute scale Q finish\n"); + std::vector temp_rand_mask_dim; + const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); + // const const int64_t* attn_mask_ptr = attn_mask.get_ptr(); + int64_t first_dim = 1; + const auto& origin_dims = attn_mask_ptr->dims(); + auto rank = origin_dims.size(); + for (int i = 0; i < rank - 3; i++) { + first_dim *= origin_dims[i]; + } + temp_rand_mask_dim = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + // printf("start to exec flash_attn_fwd_with_bias_and_mask\n"); + // succ =phi::dynload::flash_attn_fwd( + // q.data(), + // k.data(), + // v.data(), + // nullptr, // for calculation workspace size + // cu_seqlens_q.data(), + // cu_seqlens_k.data(), + // total_q, + // total_k, + // batch_size, + // num_heads, + // head_size, + // max_seqlen_q, + // max_seqlen_k, + // dropout, + // scale, + // zero_tensors, + // causal, + // is_bf16, + // num_splits, + // softmax_lse->data(), + // return_softmax ? softmax->data() : nullptr, + // nullptr, + // &workspace_size, + // stream, + // seed, + // offset); + succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + static_cast(q_ptr), + static_cast(k.data()), + static_cast(v.data()), + nullptr, // for calculation workspace size + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + is_bf16, + num_splits, + softmax_lse->data(), + nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + nullptr, + temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, + nullptr); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } - // TODO(kuizhiqing) pass allocation/empty func in capi to decouple - // calculate workspace size before execution - bool succ = - phi::dynload::flash_attn_fwd(q.data(), - k.data(), - v.data(), - nullptr, // for calculation workspace size - cu_seqlens_q.data(), - cu_seqlens_k.data(), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - softmax_lse->data(), - return_softmax ? 
softmax->data() : nullptr, - nullptr, - &workspace_size, - stream, - seed, - offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + } + succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + static_cast(q_ptr), + k.data(), + v.data(), + out->data(), // set out to nullptr to calculate workspace size + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + is_bf16, + num_splits, + softmax_lse->data(), + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + nullptr, + temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, + nullptr); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + } else { + succ = + phi::dynload::flash_attn_fwd(q.data(), + k.data(), + v.data(), + nullptr, // for calculation workspace size + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse->data(), + return_softmax ? softmax->data() : nullptr, + nullptr, + &workspace_size, + stream, + seed, + offset); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } - DenseTensor workspace; - if (workspace_size > 0) { - workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); - } + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + } - succ = phi::dynload::flash_attn_fwd( - q.data(), - k.data(), - v.data(), - out->data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - softmax_lse->data(), - return_softmax ? softmax->data() : nullptr, - workspace_size > 0 ? workspace.data() : nullptr, - &workspace_size, - stream, - seed, - offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + succ = phi::dynload::flash_attn_fwd( + q.data(), + k.data(), + v.data(), + out->data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse->data(), + return_softmax ? softmax->data() : nullptr, + workspace_size > 0 ? 
workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } } #endif @@ -215,6 +362,7 @@ void FlashAttnKernel(const Context& ctx, const DenseTensor& k, const DenseTensor& v, const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, float dropout, bool causal, bool return_softmax, @@ -226,7 +374,7 @@ void FlashAttnKernel(const Context& ctx, DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - + // printf("welcome to FlashAttnKernel\n"); auto dims = q.dims(); PADDLE_ENFORCE_EQ(dims.size(), 4, @@ -268,6 +416,7 @@ void FlashAttnKernel(const Context& ctx, cu_seqlens_q, cu_seqlens_k, fixed_seed_offset, + attn_mask, seq_len_q, seq_len_k, scale, diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index b36bd5d74ec7b..44ecf7144abb3 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -202,6 +202,7 @@ def flash_attention( key, value, fixed_seed_offset, + None, dropout, causal, return_softmax, @@ -358,6 +359,7 @@ def flash_attn_unpadded( cu_seqlens_q, cu_seqlens_k, fixed_seed_offset, + None, max_seqlen_q, max_seqlen_k, scale, @@ -408,7 +410,13 @@ def flash_attn_unpadded( def scaled_dot_product_attention( - query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + fixed_seed_offset=None, ): r""" The equation is: @@ -458,6 +466,36 @@ def scaled_dot_product_attention( >>> print(output) >>> # xdoctest: -SKIP """ - assert attn_mask is None, "attn_mask is not supported yet" - out, _ = flash_attention(query, key, value, dropout_p, is_causal) + if attn_mask is None: + out, _ = flash_attention(query, key, value, dropout_p, is_causal) + else: + # out, _ = _C_ops.flash_attn( + # query, + # key, + # value, + # fixed_seed_offset, + # attn_mask, + # dropout_p , + # is_causal, + # return_softmax = False, + # training = False, + # rng_name = "", + # ) + dropout = 0.0 + causal = False + return_softmax = False + training = True + rng_name = "" + out, _ = _C_ops.flash_attn( + query, + key, + value, + fixed_seed_offset, + attn_mask, + dropout, + causal, + return_softmax, + not training, + rng_name, + ) return out diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index 6bde691bd2f95..8d8707a959997 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -80,7 +80,7 @@ def attention_naive(q, k, v, causal=False): class TestFlashAttentionAPI(unittest.TestCase): def setUp(self): self.place = paddle.CUDAPlace(0) - self.shape = (2, 128, 8, 16) + self.shape = (2, 128, 8, 32) self.dtype = 'float16' self.dropout = 0.0 self.causal = False @@ -293,6 +293,48 @@ def test_all(self): fetches_result[0], out_, rtol=5e-03, atol=1e-03 ) + def test_dot_scale_product(self): + print( + f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}" + ) + query = np.random.random(self.shape) + key = np.random.random(self.shape) + value = np.random.random(self.shape) + + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + q_ = 
paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + k_ = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + v_ = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1]) + mask = np.random.random(mask_shape) + m = paddle.to_tensor( + mask, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + out = scaled_dot_product_attention( + q, k, v, m, self.dropout, self.causal, fixed_seed_offset=None + ) + out_ = attention_naive(q_, k_, v_, self.causal) + out.backward() + out_.backward() + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) + class TestFlashAttentionAPITest1(TestFlashAttentionAPI): def setUp(self): From faf07d7e8e423d3ac2319201b4bdd064d0486681 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Fri, 4 Aug 2023 18:55:46 +0800 Subject: [PATCH 02/13] add backword --- paddle/phi/api/yaml/backward.yaml | 6 +- paddle/phi/kernels/flash_attn_grad_kernel.h | 2 + .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 271 +++++++++++++----- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 35 --- .../paddle/nn/functional/flash_attention.py | 24 +- test/legacy_test/test_flash_attention.py | 39 ++- 6 files changed, 240 insertions(+), 137 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 7cc569bc07e50..b94a248a31c5c 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -819,7 +819,8 @@ - backward_op : flash_attn_grad forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset,Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false) + args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) + optional : attn_mask output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) infer_meta : func : FlashAttnGradInferMeta @@ -830,7 +831,8 @@ - backward_op : flash_attn_unpadded_grad forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) - args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) + args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false) + optional : attn_mask output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) infer_meta : func : FlashAttnGradInferMeta diff --git a/paddle/phi/kernels/flash_attn_grad_kernel.h b/paddle/phi/kernels/flash_attn_grad_kernel.h index 
ba3a6020e4545..ef5458f4708eb 100644 --- a/paddle/phi/kernels/flash_attn_grad_kernel.h +++ b/paddle/phi/kernels/flash_attn_grad_kernel.h @@ -29,6 +29,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, int64_t max_seqlen_q, int64_t max_seqlen_k, @@ -47,6 +48,7 @@ void FlashAttnGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, float dropout, bool causal, diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index b75f4b4aea4b8..cf18b77602426 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/flash_attn_grad_kernel.h" #include "glog/logging.h" // For VLOG() #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" @@ -31,6 +32,13 @@ DECLARE_bool(cudnn_deterministic); namespace phi { +template +__global__ void SimleScaleWithMaskKernel(int64_t numel, float scale, T* inout) { + CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { + inout[i] = static_cast(scale * static_cast(inout[i])); + } +} + template void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& q, @@ -41,6 +49,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, int64_t max_seqlen_q, int64_t max_seqlen_k, @@ -85,85 +94,189 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, DenseTensor dsoftmax = Empty(ctx, {batch_size, num_heads, seq_len_q}); uint64_t workspace_size; - - // calculate workspace size before execution - bool succ = phi::dynload::flash_attn_bwd( - q.data(), - k.data(), - v.data(), - dq->data(), - dk->data(), - dv->data(), - nullptr, // for calculation workspace size - dout.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - const_cast(softmax_lse.data()), - dsoftmax.data(), - nullptr, - &workspace_size, - stream, - seed, - offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + bool succ; + + if (attn_mask.get_ptr()) { + scale = 1.0f; + std::vector temp_rand_mask_dim; + const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); + int64_t first_dim = 1; + const auto& origin_dims = attn_mask_ptr->dims(); + auto rank = origin_dims.size(); + for (int i = 0; i < rank - 3; i++) { + first_dim *= origin_dims[i]; + } + temp_rand_mask_dim = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( + static_cast(q.data()), + static_cast(k.data()), + static_cast(v.data()), + static_cast(dq->data()), + static_cast(dk->data()), + static_cast(dv->data()), + nullptr, // set out to nullptr to calculate workspace size + dout.data(), + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + 
max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + is_bf16, + num_splits, + static_cast(softmax_lse.data()), + static_cast(dsoftmax.data()), + nullptr, + nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + nullptr, + temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, + nullptr); + PADDLE_ENFORCE_EQ( + succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + } + + succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( + static_cast(q.data()), + static_cast(k.data()), + static_cast(v.data()), + static_cast(dq->data()), + static_cast(dk->data()), + static_cast(dv->data()), + out.data(), // set out to nullptr to calculate workspace size + dout.data(), + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + is_bf16, + num_splits, + static_cast(softmax_lse.data()), + static_cast(dsoftmax.data()), + nullptr, + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + nullptr, + temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, + nullptr); + PADDLE_ENFORCE_EQ( + succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + + int64_t q_size = total_q * num_heads * head_size; + auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); + SimleScaleWithMaskKernel<<>>( + q_size, scale, static_cast(dq->data())); + } else { + // calculate workspace size before execution + succ = phi::dynload::flash_attn_bwd( + q.data(), + k.data(), + v.data(), + dq->data(), + dk->data(), + dv->data(), + nullptr, // for calculation workspace size + dout.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + const_cast(softmax_lse.data()), + dsoftmax.data(), + nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + + DenseTensor workspace; + if (workspace_size > 0) { + workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + } + + succ = phi::dynload::flash_attn_bwd( + q.data(), + k.data(), + v.data(), + dq->data(), + dk->data(), + dv->data(), + out.data(), + dout.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + const_cast(softmax_lse.data()), + dsoftmax.data(), + workspace_size > 0 ? 
workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } } - - DenseTensor workspace; - if (workspace_size > 0) { - workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); - } - - succ = phi::dynload::flash_attn_bwd( - q.data(), - k.data(), - v.data(), - dq->data(), - dk->data(), - dv->data(), - out.data(), - dout.data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - const_cast(softmax_lse.data()), - dsoftmax.data(), - workspace_size > 0 ? workspace.data() : nullptr, - &workspace_size, - stream, - seed, - offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } - #endif } @@ -175,6 +288,7 @@ void FlashAttnGradKernel(const Context& ctx, const DenseTensor& out, const DenseTensor& softmax_lse, const DenseTensor& seed_offset, + const paddle::optional& attn_mask, const DenseTensor& dout, float dropout, bool causal, @@ -221,6 +335,7 @@ void FlashAttnGradKernel(const Context& ctx, out, softmax_lse, seed_offset, + attn_mask, dout, seq_len_q, seq_len_k, diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 6cfcc962dd709..d62370df0c14c 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -76,13 +76,11 @@ void FlashAttnUnpaddedKernel( DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN if (is_test) dropout = 0.0f; - // printf("welcom to FlashAttnUnpaddedKernel\n"); ctx.template Alloc(out); cudaStream_t stream = ctx.stream(); bool is_bf16 = q.dtype() == DataType::BFLOAT16 ? true : false; - // printf("is_bf16\n", is_bf16); // q,k,v [total_*, num_heads, head_dim] @@ -157,10 +155,8 @@ void FlashAttnUnpaddedKernel( uint64_t workspace_size; bool succ; - // printf("welcome to flash_attention_with_mask\n"); if (attn_mask.get_ptr()) { // compute scale Q - // printf("compute scale Q start\n"); int64_t q_size = total_q * num_heads * head_size; auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); @@ -172,10 +168,8 @@ void FlashAttnUnpaddedKernel( 0, ctx.stream()>>>(q_size, scale, q_ptr); scale = 1.0f; - // printf("compute scale Q finish\n"); std::vector temp_rand_mask_dim; const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); - // const const int64_t* attn_mask_ptr = attn_mask.get_ptr(); int64_t first_dim = 1; const auto& origin_dims = attn_mask_ptr->dims(); auto rank = origin_dims.size(); @@ -186,34 +180,6 @@ void FlashAttnUnpaddedKernel( origin_dims[rank - 3], origin_dims[rank - 2], origin_dims[rank - 1]}; - // printf("start to exec flash_attn_fwd_with_bias_and_mask\n"); - // succ =phi::dynload::flash_attn_fwd( - // q.data(), - // k.data(), - // v.data(), - // nullptr, // for calculation workspace size - // cu_seqlens_q.data(), - // cu_seqlens_k.data(), - // total_q, - // total_k, - // batch_size, - // num_heads, - // head_size, - // max_seqlen_q, - // max_seqlen_k, - // dropout, - // scale, - // zero_tensors, - // causal, - // is_bf16, - // num_splits, - // softmax_lse->data(), - // return_softmax ? 
softmax->data() : nullptr, - // nullptr, - // &workspace_size, - // stream, - // seed, - // offset); succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( static_cast(q_ptr), static_cast(k.data()), @@ -374,7 +340,6 @@ void FlashAttnKernel(const Context& ctx, DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - // printf("welcome to FlashAttnKernel\n"); auto dims = q.dims(); PADDLE_ENFORCE_EQ(dims.size(), 4, diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 44ecf7144abb3..c4adf742b7614 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -417,6 +417,9 @@ def scaled_dot_product_attention( dropout_p=0.0, is_causal=False, fixed_seed_offset=None, + return_softmax=False, + training=True, + rng_name="", ): r""" The equation is: @@ -469,31 +472,14 @@ def scaled_dot_product_attention( if attn_mask is None: out, _ = flash_attention(query, key, value, dropout_p, is_causal) else: - # out, _ = _C_ops.flash_attn( - # query, - # key, - # value, - # fixed_seed_offset, - # attn_mask, - # dropout_p , - # is_causal, - # return_softmax = False, - # training = False, - # rng_name = "", - # ) - dropout = 0.0 - causal = False - return_softmax = False - training = True - rng_name = "" out, _ = _C_ops.flash_attn( query, key, value, fixed_seed_offset, attn_mask, - dropout, - causal, + dropout_p, + is_causal, return_softmax, not training, rng_name, diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index 8d8707a959997..b1f0856f9ddfc 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -57,6 +57,18 @@ def attention_naive(q, k, v, causal=False): return paddle.transpose(o, [0, 2, 1, 3]) +def attention_naive_with_mask(q, k, v, attn_bias): + qt = paddle.transpose(q, [0, 2, 1, 3]) + kt = paddle.transpose(k, [0, 2, 1, 3]) + vt = paddle.transpose(v, [0, 2, 1, 3]) + scale = 1.0 / np.sqrt(q.shape[-1]) + s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])) + s = paddle.scale(s, scale) + p = F.softmax(s + attn_bias) + o = paddle.matmul(p, vt) + return paddle.transpose(o, [0, 2, 1, 3]) + + is_sm75 = ( core.is_compiled_with_cuda() and paddle.device.cuda.get_device_capability()[0] == 7 @@ -80,7 +92,7 @@ def attention_naive(q, k, v, causal=False): class TestFlashAttentionAPI(unittest.TestCase): def setUp(self): self.place = paddle.CUDAPlace(0) - self.shape = (2, 128, 8, 32) + self.shape = (2, 128, 8, 16) self.dtype = 'float16' self.dropout = 0.0 self.causal = False @@ -293,10 +305,22 @@ def test_all(self): fetches_result[0], out_, rtol=5e-03, atol=1e-03 ) + +class TestFlashAttentionWithMaskAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 128, 8, 32) + self.dtype = 'float16' + self.dropout = 0.0 + self.causal = True + def test_dot_scale_product(self): print( - f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}" + f"Test flash attn mask case shape {self.shape} dtype {self.dtype} causal {self.causal}" ) + # test dynamic + paddle.disable_static() + query = np.random.random(self.shape) key = np.random.random(self.shape) value = np.random.random(self.shape) @@ -330,7 +354,7 @@ def test_dot_scale_product(self): out = scaled_dot_product_attention( q, k, v, m, self.dropout, self.causal, fixed_seed_offset=None ) - out_ = attention_naive(q_, k_, v_, self.causal) + out_ = 
attention_naive_with_mask(q_, k_, v_, m) out.backward() out_.backward() np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) @@ -410,5 +434,14 @@ def setUp(self): self.enable_mem_efficient = False +class TestFlashAttrnionWithMaskAPI(TestFlashAttentionWithMaskAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (8, 1024, 16, 128) + self.dtype = paddle.float16 + self.dropout = 0.0 + self.causal = True + + if __name__ == '__main__': unittest.main() From 43303fa796b8048baac80fa07887dc0874818bf1 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sat, 5 Aug 2023 16:01:25 +0800 Subject: [PATCH 03/13] add enforce info --- paddle/phi/api/yaml/backward.yaml | 2 +- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 40 +++++++++---------- .../paddle/nn/functional/flash_attention.py | 6 +-- test/legacy_test/test_flash_attention.py | 13 ++++-- 5 files changed, 34 insertions(+), 30 deletions(-) diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index b94a248a31c5c..5d2253790ae39 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -818,7 +818,7 @@ inplace : (out_grad -> x_grad) - backward_op : flash_attn_grad - forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset,Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) optional : attn_mask output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index cf18b77602426..97aef290f4785 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -95,8 +95,9 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, uint64_t workspace_size; bool succ; - if (attn_mask.get_ptr()) { + PADDLE_ENFORCE(causal != true, + "When attn_mask is not nullptr, causal can not be true"); scale = 1.0f; std::vector temp_rand_mask_dim; const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index d62370df0c14c..0675fb1725d66 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -34,17 +34,6 @@ DECLARE_bool(cudnn_deterministic); namespace phi { -// template -// void ComputeScaleQ(const Context& ctx, const DenseTensor& q, int64_t numel, -// int64_t head_dim, T* q_ptr){ -// auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, -// 1); DenseTensor q_(q); T* q_ptr = static_cast(q_.data()); -// SimleScaleWithMaskKernel<<>>(q_size, scale, q_ptr); -// } - template __global__ void SimleScaleWithMaskKernel(int64_t numel, float scale, T* inout) { CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { @@ -52,6 +41,15 @@ __global__ void SimleScaleWithMaskKernel(int64_t numel, float scale, T* 
inout) { } } +template +void ComputeScaleQ(const Context& ctx, int64_t numel, T* scale_q, float scale) { + auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1); + SimleScaleWithMaskKernel<<>>(numel, scale, scale_q); +} + template void FlashAttnUnpaddedKernel( const Context& ctx, @@ -156,17 +154,15 @@ void FlashAttnUnpaddedKernel( uint64_t workspace_size; bool succ; if (attn_mask.get_ptr()) { - // compute scale Q - int64_t q_size = total_q * num_heads * head_size; + PADDLE_ENFORCE(causal != true, + "When attn_mask is not nullptr, causal can not be true"); - auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); - DenseTensor q_(q); - T* q_ptr = static_cast(q_.data()); + int64_t q_size = total_q * num_heads * head_size; + DenseTensor scale_q; + scale_q.ShareDataWith(q).Resize({total_q, num_heads, head_size}); + // compute scale Q + ComputeScaleQ(ctx, q_size, scale_q.data(), scale); - SimleScaleWithMaskKernel<<>>(q_size, scale, q_ptr); scale = 1.0f; std::vector temp_rand_mask_dim; const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); @@ -181,7 +177,7 @@ void FlashAttnUnpaddedKernel( origin_dims[rank - 2], origin_dims[rank - 1]}; succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(q_ptr), + static_cast(scale_q.data()), static_cast(k.data()), static_cast(v.data()), nullptr, // for calculation workspace size @@ -218,7 +214,7 @@ void FlashAttnUnpaddedKernel( workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); } succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(q_ptr), + static_cast(scale_q.data()), k.data(), v.data(), out->data(), // set out to nullptr to calculate workspace size diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index c4adf742b7614..3a68be3c4b805 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -416,10 +416,7 @@ def scaled_dot_product_attention( attn_mask=None, dropout_p=0.0, is_causal=False, - fixed_seed_offset=None, - return_softmax=False, training=True, - rng_name="", ): r""" The equation is: @@ -472,6 +469,9 @@ def scaled_dot_product_attention( if attn_mask is None: out, _ = flash_attention(query, key, value, dropout_p, is_causal) else: + fixed_seed_offset = (None,) + return_softmax = False + rng_name = "" out, _ = _C_ops.flash_attn( query, key, diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index b1f0856f9ddfc..64d37e816cf00 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -306,13 +306,20 @@ def test_all(self): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" + "and device's compute capability must be 7.5 or 8.x", +) class TestFlashAttentionWithMaskAPI(unittest.TestCase): def setUp(self): self.place = paddle.CUDAPlace(0) self.shape = (2, 128, 8, 32) self.dtype = 'float16' self.dropout = 0.0 - self.causal = True + self.causal = False def test_dot_scale_product(self): print( @@ -352,7 +359,7 @@ def test_dot_scale_product(self): ) out = scaled_dot_product_attention( - q, k, v, m, self.dropout, self.causal, fixed_seed_offset=None + q, k, v, m, self.dropout, self.causal ) out_ = attention_naive_with_mask(q_, k_, v_, m) out.backward() @@ -440,7 +447,7 @@ def setUp(self): self.shape = (8, 1024, 16, 128) 
self.dtype = paddle.float16 self.dropout = 0.0 - self.causal = True + self.causal = False if __name__ == '__main__': From 99ff7fd4102b010fa1f4841b5852b9fc4243f617 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sat, 5 Aug 2023 20:46:50 +0800 Subject: [PATCH 04/13] update scale --- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 18 ++++--- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 47 ++++++++++++------- 2 files changed, 40 insertions(+), 25 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 97aef290f4785..bace388ddf555 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -96,9 +96,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, uint64_t workspace_size; bool succ; if (attn_mask.get_ptr()) { - PADDLE_ENFORCE(causal != true, - "When attn_mask is not nullptr, causal can not be true"); - scale = 1.0f; + PADDLE_ENFORCE_NE(causal, + true, + phi::errors::InvalidArgument( + "attn_mask is not nullptr, causal can not be true")); + float fa_with_mask_scale = 1.0f; std::vector temp_rand_mask_dim; const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); int64_t first_dim = 1; @@ -130,7 +132,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, max_seqlen_q, max_seqlen_k, dropout, - scale, + fa_with_mask_scale, zero_tensors, is_bf16, num_splits, @@ -151,7 +153,8 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, DenseTensor workspace; if (workspace_size > 0) { - workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); } succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( @@ -173,7 +176,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, max_seqlen_q, max_seqlen_k, dropout, - scale, + fa_with_mask_scale, zero_tensors, is_bf16, num_splits, @@ -239,7 +242,8 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, DenseTensor workspace; if (workspace_size > 0) { - workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); } succ = phi::dynload::flash_attn_bwd( diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 0675fb1725d66..79c1c99297387 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -35,19 +35,23 @@ DECLARE_bool(cudnn_deterministic); namespace phi { template -__global__ void SimleScaleWithMaskKernel(int64_t numel, float scale, T* inout) { +__global__ void SimleScaleWithMaskKernel(int64_t numel, + float scale, + const T* input, + T* ouput) { CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { - inout[i] = static_cast(scale * static_cast(inout[i])); + ouput[i] = static_cast(scale * static_cast(input[i])); } } template -void ComputeScaleQ(const Context& ctx, int64_t numel, T* scale_q, float scale) { +void ComputeScaleQ( + const Context& ctx, int64_t numel, float scale, const T* input, T* output) { auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1); SimleScaleWithMaskKernel<<>>(numel, scale, scale_q); + ctx.stream()>>>(numel, scale, input, output); } template @@ -151,19 +155,24 @@ void FlashAttnUnpaddedKernel( ctx.template Alloc(softmax); } - uint64_t workspace_size; + uint64_t workspace_size = 0; bool succ; + DenseTensor workspace; + if (attn_mask.get_ptr()) { - PADDLE_ENFORCE(causal != true, - "When attn_mask is not 
nullptr, causal can not be true"); + PADDLE_ENFORCE_NE(causal, + true, + phi::errors::InvalidArgument( + "attn_mask is not nullptr, causal can not be true")); int64_t q_size = total_q * num_heads * head_size; - DenseTensor scale_q; - scale_q.ShareDataWith(q).Resize({total_q, num_heads, head_size}); + DenseTensor* scale_q = new DenseTensor; + scale_q->Resize({total_q, num_heads, head_size}); + ctx.template Alloc(scale_q); // compute scale Q - ComputeScaleQ(ctx, q_size, scale_q.data(), scale); + ComputeScaleQ(ctx, q_size, scale, q.data(), scale_q->data()); - scale = 1.0f; + float fa_with_mask_scale = 1.0f; std::vector temp_rand_mask_dim; const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); int64_t first_dim = 1; @@ -177,7 +186,7 @@ void FlashAttnUnpaddedKernel( origin_dims[rank - 2], origin_dims[rank - 1]}; succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(scale_q.data()), + static_cast(scale_q->data()), static_cast(k.data()), static_cast(v.data()), nullptr, // for calculation workspace size @@ -191,7 +200,7 @@ void FlashAttnUnpaddedKernel( max_seqlen_q, max_seqlen_k, dropout, - scale, + fa_with_mask_scale, zero_tensors, is_bf16, num_splits, @@ -209,12 +218,12 @@ void FlashAttnUnpaddedKernel( PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); } - DenseTensor workspace; if (workspace_size > 0) { - workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); } succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(scale_q.data()), + static_cast(scale_q->data()), k.data(), v.data(), out->data(), // set out to nullptr to calculate workspace size @@ -228,7 +237,7 @@ void FlashAttnUnpaddedKernel( max_seqlen_q, max_seqlen_k, dropout, - scale, + fa_with_mask_scale, zero_tensors, is_bf16, num_splits, @@ -245,6 +254,7 @@ void FlashAttnUnpaddedKernel( if (!succ) { PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); } + delete scale_q; } else { succ = phi::dynload::flash_attn_fwd(q.data(), @@ -279,7 +289,8 @@ void FlashAttnUnpaddedKernel( DenseTensor workspace; if (workspace_size > 0) { - workspace = Empty(ctx, {int64_t(workspace_size / sizeof(float))}); + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); } succ = phi::dynload::flash_attn_fwd( From 130034c145910fd3004e48c1e75caabf29133962 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sun, 6 Aug 2023 15:09:09 +0800 Subject: [PATCH 05/13] integrate code --- paddle/phi/kernels/flash_attn_kernel.h | 68 ++++ .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 25 +- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 385 ++++++++++++------ .../paddle/nn/functional/flash_attention.py | 1 + test/legacy_test/test_flash_attention.py | 13 +- 5 files changed, 352 insertions(+), 140 deletions(-) diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h index ec72d85a0babb..21d589843eebd 100644 --- a/paddle/phi/kernels/flash_attn_kernel.h +++ b/paddle/phi/kernels/flash_attn_kernel.h @@ -19,6 +19,74 @@ namespace phi { +template +void FlashAttnFwdWithBiasAndMask( + const Context& ctx, + const void* + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void* + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void* + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void* + out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const int32_t* + cu_seqlens_q, 
// int32, batch_size+1, starting offset of each sequence + const int32_t* + cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q, + const int max_seqlen_k, + const float dropout, + const float scale, + const bool zero_tensors, + const bool is_bf16, + const int num_splits, // SMs per attention matrix, can be 1 + void* softmax_lse_ptr, // softmax log_sum_exp + cudaStream_t stream, + uint64_t seed, + uint64_t offset, + const void* attn_mask, + const int64_t* mask_dims); + +template +void FlashAttnFwd( + const Context& ctx, + const void* + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void* + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void* + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void* + out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void* + cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const void* + cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q, + const int max_seqlen_k, + const float dropout, + const float scale, + const bool zero_tensors, + const bool causal, + const bool is_bf16, + const int num_splits, // SMs per attention matrix, can be 1 + void* softmax_lse_ptr, // softmax log_sum_exp + const bool return_softmax, + cudaStream_t stream, + uint64_t seed, + uint64_t offset); + template void FlashAttnUnpaddedKernel( const Context& ctx, diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index bace388ddf555..ef2330cf4891b 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -100,6 +100,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, true, phi::errors::InvalidArgument( "attn_mask is not nullptr, causal can not be true")); + PADDLE_ENFORCE_NE( + head_size, + 32 || 64 || 128, + phi::errors::InvalidArgument( + "Currently, the mask only supports head_dim of 32, 64, and 128")); float fa_with_mask_scale = 1.0f; std::vector temp_rand_mask_dim; const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); @@ -148,8 +153,14 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, nullptr, temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ( - succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + // PADDLE_ENFORCE_EQ( + // succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + // PADDLE_ENFORCE_EQ( + // succ, true, "Error in Flash-Attention, detail information is ", + // phi::dynload::flash_attn_error()); DenseTensor workspace; if (workspace_size > 0) { @@ -192,8 +203,14 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, nullptr, temp_rand_mask_dim.data() ? 
temp_rand_mask_dim.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ( - succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + // PADDLE_ENFORCE_EQ( + // succ, true, phi::errors::External(phi::dynload::flash_attn_error())); + // PADDLE_ENFORCE_EQ( + // succ, true, "Error in Flash-Attention, detail information is ", + // phi::dynload::flash_attn_error()); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } int64_t q_size = total_q * num_heads * head_size; auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 79c1c99297387..a6a206d8770bf 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -54,6 +54,216 @@ void ComputeScaleQ( ctx.stream()>>>(numel, scale, input, output); } +template +void FlashAttnFwdWithBiasAndMask( + const Context& ctx, + const void* + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void* + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void* + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void* + out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const int32_t* + cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const int32_t* + cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q, + const int max_seqlen_k, + const float dropout, + const float scale, + const bool zero_tensors, + const bool is_bf16, + const int num_splits, // SMs per attention matrix, can be 1 + void* softmax_lse_ptr, // softmax log_sum_exp + cudaStream_t stream, + uint64_t seed, + uint64_t offset, + const void* attn_mask, + const int64_t* mask_dims) { + // to get workspace,these are temp variable + DenseTensor workspace; + uint64_t workspace_size = 0; + bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + q, + k, + v, + nullptr, // for calculation workspace size + cu_seqlens_q, + cu_seqlens_k, + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + is_bf16, + num_splits, + softmax_lse_ptr, + nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask, + nullptr, + mask_dims, + nullptr); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + + if (workspace_size > 0) { + workspace = Empty( + ctx, {static_cast((workspace_size) / sizeof(float))}); + } + succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + q, + k, + v, + out, // set out to nullptr to calculate workspace size + cu_seqlens_q, + cu_seqlens_k, + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + is_bf16, + num_splits, + softmax_lse_ptr, + workspace_size > 0 ? 
workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask, + nullptr, + mask_dims, + nullptr); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } +} + +template +void FlashAttnFwd( + const Context& ctx, + const void* + q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void* + k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void* + v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void* + out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void* + cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const void* + cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q, + const int max_seqlen_k, + const float dropout, + const float scale, + const bool zero_tensors, + const bool causal, + const bool is_bf16, + const int num_splits, // SMs per attention matrix, can be 1 + void* softmax_lse_ptr, // softmax log_sum_exp + const bool return_softmax, + cudaStream_t stream, + uint64_t seed, + uint64_t offset) { + DenseTensor workspace; + uint64_t workspace_size = 0; + bool succ = + phi::dynload::flash_attn_fwd(q, + k, + v, + nullptr, // for calculation workspace size + cu_seqlens_q, + cu_seqlens_k, + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse_ptr, + return_softmax ? softmax_lse_ptr : nullptr, + nullptr, + &workspace_size, + stream, + seed, + offset); + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } + if (workspace_size > 0) { + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); + } + + succ = phi::dynload::flash_attn_fwd( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_k, + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse_ptr, + return_softmax ? softmax_lse_ptr : nullptr, + workspace_size > 0 ? 
workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset); + + if (!succ) { + PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); + } +} + template void FlashAttnUnpaddedKernel( const Context& ctx, @@ -156,7 +366,6 @@ void FlashAttnUnpaddedKernel( } uint64_t workspace_size = 0; - bool succ; DenseTensor workspace; if (attn_mask.get_ptr()) { @@ -164,16 +373,21 @@ void FlashAttnUnpaddedKernel( true, phi::errors::InvalidArgument( "attn_mask is not nullptr, causal can not be true")); + PADDLE_ENFORCE_NE( + head_size, + 32 || 64 || 128, + phi::errors::InvalidArgument( + "Currently, the mask only supports head_dim of 32, 64, and 128")); int64_t q_size = total_q * num_heads * head_size; - DenseTensor* scale_q = new DenseTensor; - scale_q->Resize({total_q, num_heads, head_size}); - ctx.template Alloc(scale_q); + DenseTensor scale_q; + scale_q.Resize({total_q, num_heads, head_size}); + ctx.template Alloc(&scale_q); // compute scale Q - ComputeScaleQ(ctx, q_size, scale, q.data(), scale_q->data()); + ComputeScaleQ(ctx, q_size, scale, q.data(), scale_q.data()); float fa_with_mask_scale = 1.0f; - std::vector temp_rand_mask_dim; + std::vector rand_mask_dim; const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); int64_t first_dim = 1; const auto& origin_dims = attn_mask_ptr->dims(); @@ -181,15 +395,16 @@ void FlashAttnUnpaddedKernel( for (int i = 0; i < rank - 3; i++) { first_dim *= origin_dims[i]; } - temp_rand_mask_dim = {first_dim, - origin_dims[rank - 3], - origin_dims[rank - 2], - origin_dims[rank - 1]}; - succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(scale_q->data()), + rand_mask_dim = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + FlashAttnFwdWithBiasAndMask( + ctx, + static_cast(scale_q.data()), static_cast(k.data()), static_cast(v.data()), - nullptr, // for calculation workspace size + static_cast(out->data()), // for calculation workspace size static_cast(cu_seqlens_q.data()), static_cast(cu_seqlens_k.data()), total_q, @@ -204,126 +419,38 @@ void FlashAttnUnpaddedKernel( zero_tensors, is_bf16, num_splits, - softmax_lse->data(), - nullptr, - &workspace_size, + static_cast(softmax_lse->data()), stream, seed, offset, - attn_mask_ptr ? attn_mask_ptr->data() : nullptr, - nullptr, - temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, - nullptr); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } - - if (workspace_size > 0) { - workspace = Empty( - ctx, {static_cast(workspace_size / sizeof(float))}); - } - succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(scale_q->data()), - k.data(), - v.data(), - out->data(), // set out to nullptr to calculate workspace size - static_cast(cu_seqlens_q.data()), - static_cast(cu_seqlens_k.data()), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - fa_with_mask_scale, - zero_tensors, - is_bf16, - num_splits, - softmax_lse->data(), - workspace_size > 0 ? workspace.data() : nullptr, - &workspace_size, - stream, - seed, - offset, - attn_mask_ptr ? attn_mask_ptr->data() : nullptr, - nullptr, - temp_rand_mask_dim.data() ? 
temp_rand_mask_dim.data() : nullptr, - nullptr); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } - delete scale_q; + attn_mask_ptr->data(), + static_cast(rand_mask_dim.data())); } else { - succ = - phi::dynload::flash_attn_fwd(q.data(), - k.data(), - v.data(), - nullptr, // for calculation workspace size - cu_seqlens_q.data(), - cu_seqlens_k.data(), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - softmax_lse->data(), - return_softmax ? softmax->data() : nullptr, - nullptr, - &workspace_size, - stream, - seed, - offset); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } - - DenseTensor workspace; - if (workspace_size > 0) { - workspace = Empty( - ctx, {static_cast(workspace_size / sizeof(float))}); - } - - succ = phi::dynload::flash_attn_fwd( - q.data(), - k.data(), - v.data(), - out->data(), - cu_seqlens_q.data(), - cu_seqlens_k.data(), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - softmax_lse->data(), - return_softmax ? softmax->data() : nullptr, - workspace_size > 0 ? workspace.data() : nullptr, - &workspace_size, - stream, - seed, - offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + FlashAttnFwd(ctx, + q.data(), + k.data(), + v.data(), + out->data(), + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse->data(), + return_softmax, + stream, + seed, + offset); } #endif diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py index 3a68be3c4b805..b9077896da6c5 100644 --- a/python/paddle/nn/functional/flash_attention.py +++ b/python/paddle/nn/functional/flash_attention.py @@ -450,6 +450,7 @@ def scaled_dot_product_attention( not supported yet. dropout_p(float): The dropout ratio. is_causal(bool): Whether enable causal mode. + training(bool): Whether it is in the training phase Returns: out(Tensor): The attention tensor. 
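
A minimal usage sketch of the mask-aware Python entry point documented above may help while reading the rest of the diff. It is illustrative only: the exact public signature scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, training=True) and its export under paddle.nn.functional are assumptions, not taken from this patch, and a CUDA device is required.

    # Hypothetical usage sketch -- not part of this patch.
    import paddle
    import paddle.nn.functional as F

    # [batch_size, seq_len, num_heads, head_dim]; the masked kernel added in this
    # patch only accepts head_dim of 32, 64, or 128 and requires is_causal=False.
    q = paddle.randn([2, 128, 8, 64]).astype('float16')
    k = paddle.randn([2, 128, 8, 64]).astype('float16')
    v = paddle.randn([2, 128, 8, 64]).astype('float16')
    # Additive attention mask with at least 4 dims:
    # [batch_size, num_heads, seq_len_q, seq_len_k].
    mask = paddle.zeros([2, 8, 128, 128], dtype='float16')

    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, training=True
    )
    print(out.shape)  # [2, 128, 8, 64]

On machines without a suitable GPU (CUDA below 11.3 or an unsupported compute capability), this path is expected to be unavailable, mirroring the version and SM checks in the test changes that follow.
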
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index 64d37e816cf00..c5b99ac575fb1 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -306,13 +306,6 @@ def test_all(self): ) -@unittest.skipIf( - not core.is_compiled_with_cuda() - or get_cuda_version() < 11030 - or not is_sm_supported, - "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" - "and device's compute capability must be 7.5 or 8.x", -) class TestFlashAttentionWithMaskAPI(unittest.TestCase): def setUp(self): self.place = paddle.CUDAPlace(0) @@ -322,6 +315,12 @@ def setUp(self): self.causal = False def test_dot_scale_product(self): + if ( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or not is_sm_supported + ): + pass print( f"Test flash attn mask case shape {self.shape} dtype {self.dtype} causal {self.causal}" ) From 61fe34d1589df67ce8e100cae678f2b21a692da6 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sun, 6 Aug 2023 15:20:09 +0800 Subject: [PATCH 06/13] update enforce --- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 26 +++++++------------ paddle/phi/kernels/gpu/flash_attn_kernel.cu | 14 +++++----- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index ef2330cf4891b..4e6ed44faf3bc 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -153,14 +153,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, nullptr, temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, nullptr); - // PADDLE_ENFORCE_EQ( - // succ, true, phi::errors::External(phi::dynload::flash_attn_error())); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } - // PADDLE_ENFORCE_EQ( - // succ, true, "Error in Flash-Attention, detail information is ", - // phi::dynload::flash_attn_error()); + + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); DenseTensor workspace; if (workspace_size > 0) { @@ -203,14 +200,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, nullptr, temp_rand_mask_dim.data() ? 
temp_rand_mask_dim.data() : nullptr, nullptr); - // PADDLE_ENFORCE_EQ( - // succ, true, phi::errors::External(phi::dynload::flash_attn_error())); - // PADDLE_ENFORCE_EQ( - // succ, true, "Error in Flash-Attention, detail information is ", - // phi::dynload::flash_attn_error()); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); int64_t q_size = total_q * num_heads * head_size; auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index a6a206d8770bf..269f751eb6627 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -223,9 +223,10 @@ void FlashAttnFwd( stream, seed, offset); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error()); if (workspace_size > 0) { workspace = Empty( ctx, {static_cast(workspace_size / sizeof(float))}); @@ -259,9 +260,10 @@ void FlashAttnFwd( seed, offset); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error()); } template From fd4ce6ad9e376fc20734f4d68a0e13374b399f23 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sun, 6 Aug 2023 16:26:54 +0800 Subject: [PATCH 07/13] add enforce eq --- paddle/phi/kernels/flash_attn_kernel.h | 68 ---- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 21 +- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 371 ++++++------------ 3 files changed, 137 insertions(+), 323 deletions(-) diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h index 21d589843eebd..ec72d85a0babb 100644 --- a/paddle/phi/kernels/flash_attn_kernel.h +++ b/paddle/phi/kernels/flash_attn_kernel.h @@ -19,74 +19,6 @@ namespace phi { -template -void FlashAttnFwdWithBiasAndMask( - const Context& ctx, - const void* - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const void* - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const void* - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - void* - out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const int32_t* - cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence - const int32_t* - cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence - const int total_q, - const int total_k, - const int batch_size, - const int num_heads, - const int head_size, - const int max_seqlen_q, - const int max_seqlen_k, - const float dropout, - const float scale, - const bool zero_tensors, - const bool is_bf16, - const int num_splits, // SMs per attention matrix, can be 1 - void* softmax_lse_ptr, // softmax log_sum_exp - cudaStream_t stream, - uint64_t seed, - uint64_t offset, - const void* attn_mask, - const int64_t* mask_dims); - -template -void FlashAttnFwd( - const Context& ctx, - const void* - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const void* - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const void* - v, // total_k x num_heads x head_size, total_k := 
\sum_{i=0}^{b} s_i - void* - out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const void* - cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence - const void* - cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence - const int total_q, - const int total_k, - const int batch_size, - const int num_heads, - const int head_size, - const int max_seqlen_q, - const int max_seqlen_k, - const float dropout, - const float scale, - const bool zero_tensors, - const bool causal, - const bool is_bf16, - const int num_splits, // SMs per attention matrix, can be 1 - void* softmax_lse_ptr, // softmax log_sum_exp - const bool return_softmax, - cudaStream_t stream, - uint64_t seed, - uint64_t offset); - template void FlashAttnUnpaddedKernel( const Context& ctx, diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 4e6ed44faf3bc..e2189edf9060e 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -100,9 +100,10 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, true, phi::errors::InvalidArgument( "attn_mask is not nullptr, causal can not be true")); - PADDLE_ENFORCE_NE( - head_size, - 32 || 64 || 128, + bool flag = (head_size == 32 || head_size == 64 || head_size == 128); + PADDLE_ENFORCE_EQ( + flag, + true, phi::errors::InvalidArgument( "Currently, the mask only supports head_dim of 32, 64, and 128")); float fa_with_mask_scale = 1.0f; @@ -247,9 +248,10 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, seed, offset); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); DenseTensor workspace; if (workspace_size > 0) { @@ -289,9 +291,10 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, seed, offset); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); } #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 269f751eb6627..a1281f16e0e11 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -54,218 +54,6 @@ void ComputeScaleQ( ctx.stream()>>>(numel, scale, input, output); } -template -void FlashAttnFwdWithBiasAndMask( - const Context& ctx, - const void* - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const void* - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const void* - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - void* - out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const int32_t* - cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence - const int32_t* - cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence - const int total_q, - const int total_k, - const int batch_size, - const int num_heads, - const int head_size, - const int max_seqlen_q, - const int max_seqlen_k, - const float dropout, - const float scale, - const bool zero_tensors, - const bool is_bf16, - const int num_splits, // SMs per attention matrix, can be 1 - void* softmax_lse_ptr, // softmax log_sum_exp - cudaStream_t stream, - uint64_t seed, - uint64_t offset, - const 
void* attn_mask, - const int64_t* mask_dims) { - // to get workspace,these are temp variable - DenseTensor workspace; - uint64_t workspace_size = 0; - bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - q, - k, - v, - nullptr, // for calculation workspace size - cu_seqlens_q, - cu_seqlens_k, - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - is_bf16, - num_splits, - softmax_lse_ptr, - nullptr, - &workspace_size, - stream, - seed, - offset, - attn_mask, - nullptr, - mask_dims, - nullptr); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } - - if (workspace_size > 0) { - workspace = Empty( - ctx, {static_cast((workspace_size) / sizeof(float))}); - } - succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - q, - k, - v, - out, // set out to nullptr to calculate workspace size - cu_seqlens_q, - cu_seqlens_k, - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - is_bf16, - num_splits, - softmax_lse_ptr, - workspace_size > 0 ? workspace.data() : nullptr, - &workspace_size, - stream, - seed, - offset, - attn_mask, - nullptr, - mask_dims, - nullptr); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } -} - -template -void FlashAttnFwd( - const Context& ctx, - const void* - q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const void* - k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const void* - v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - void* - out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const void* - cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence - const void* - cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence - const int total_q, - const int total_k, - const int batch_size, - const int num_heads, - const int head_size, - const int max_seqlen_q, - const int max_seqlen_k, - const float dropout, - const float scale, - const bool zero_tensors, - const bool causal, - const bool is_bf16, - const int num_splits, // SMs per attention matrix, can be 1 - void* softmax_lse_ptr, // softmax log_sum_exp - const bool return_softmax, - cudaStream_t stream, - uint64_t seed, - uint64_t offset) { - DenseTensor workspace; - uint64_t workspace_size = 0; - bool succ = - phi::dynload::flash_attn_fwd(q, - k, - v, - nullptr, // for calculation workspace size - cu_seqlens_q, - cu_seqlens_k, - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - softmax_lse_ptr, - return_softmax ? softmax_lse_ptr : nullptr, - nullptr, - &workspace_size, - stream, - seed, - offset); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is", - phi::dynload::flash_attn_error()); - if (workspace_size > 0) { - workspace = Empty( - ctx, {static_cast(workspace_size / sizeof(float))}); - } - - succ = phi::dynload::flash_attn_fwd( - q, - k, - v, - out, - cu_seqlens_q, - cu_seqlens_k, - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - softmax_lse_ptr, - return_softmax ? softmax_lse_ptr : nullptr, - workspace_size > 0 ? 
workspace.data() : nullptr, - &workspace_size, - stream, - seed, - offset); - - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is", - phi::dynload::flash_attn_error()); -} - template void FlashAttnUnpaddedKernel( const Context& ctx, @@ -375,9 +163,10 @@ void FlashAttnUnpaddedKernel( true, phi::errors::InvalidArgument( "attn_mask is not nullptr, causal can not be true")); - PADDLE_ENFORCE_NE( - head_size, - 32 || 64 || 128, + bool flag = (head_size == 32 || head_size == 64 || head_size == 128); + PADDLE_ENFORCE_EQ( + flag, + true, phi::errors::InvalidArgument( "Currently, the mask only supports head_dim of 32, 64, and 128")); @@ -401,12 +190,12 @@ void FlashAttnUnpaddedKernel( origin_dims[rank - 3], origin_dims[rank - 2], origin_dims[rank - 1]}; - FlashAttnFwdWithBiasAndMask( - ctx, + + bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( static_cast(scale_q.data()), static_cast(k.data()), static_cast(v.data()), - static_cast(out->data()), // for calculation workspace size + nullptr, // for calculation workspace size static_cast(cu_seqlens_q.data()), static_cast(cu_seqlens_k.data()), total_q, @@ -421,38 +210,128 @@ void FlashAttnUnpaddedKernel( zero_tensors, is_bf16, num_splits, - static_cast(softmax_lse->data()), + softmax_lse->data(), + nullptr, + &workspace_size, stream, seed, offset, - attn_mask_ptr->data(), - static_cast(rand_mask_dim.data())); + attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + nullptr, + rand_mask_dim.data() ? rand_mask_dim.data() : nullptr, + nullptr); + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); + + if (workspace_size > 0) { + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); + } + succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( + static_cast(scale_q.data()), + k.data(), + v.data(), + out->data(), // set out to nullptr to calculate workspace size + static_cast(cu_seqlens_q.data()), + static_cast(cu_seqlens_k.data()), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + fa_with_mask_scale, + zero_tensors, + is_bf16, + num_splits, + softmax_lse->data(), + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset, + attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + nullptr, + rand_mask_dim.data() ? rand_mask_dim.data() : nullptr, + nullptr); + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); } else { - FlashAttnFwd(ctx, - q.data(), - k.data(), - v.data(), - out->data(), - static_cast(cu_seqlens_q.data()), - static_cast(cu_seqlens_k.data()), - total_q, - total_k, - batch_size, - num_heads, - head_size, - max_seqlen_q, - max_seqlen_k, - dropout, - scale, - zero_tensors, - causal, - is_bf16, - num_splits, - softmax_lse->data(), - return_softmax, - stream, - seed, - offset); + bool succ = + phi::dynload::flash_attn_fwd(q.data(), + k.data(), + v.data(), + nullptr, // for calculation workspace size + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse->data(), + return_softmax ? 
softmax->data() : nullptr, + nullptr, + &workspace_size, + stream, + seed, + offset); + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); + + if (workspace_size > 0) { + workspace = Empty( + ctx, {static_cast(workspace_size / sizeof(float))}); + } + + succ = phi::dynload::flash_attn_fwd( + q.data(), + k.data(), + v.data(), + out->data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + total_q, + total_k, + batch_size, + num_heads, + head_size, + max_seqlen_q, + max_seqlen_k, + dropout, + scale, + zero_tensors, + causal, + is_bf16, + num_splits, + softmax_lse->data(), + return_softmax ? softmax->data() : nullptr, + workspace_size > 0 ? workspace.data() : nullptr, + &workspace_size, + stream, + seed, + offset); + + PADDLE_ENFORCE_EQ(succ, + true, + "Error in Flash-Attention, detail information is ", + phi::dynload::flash_attn_error()); } #endif From 14a791196d08777413adcb0eabfcb4019426c391 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sun, 6 Aug 2023 16:31:23 +0800 Subject: [PATCH 08/13] add error type --- paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu | 8 ++++---- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index e2189edf9060e..cb59be0e209d9 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -158,7 +158,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); DenseTensor workspace; if (workspace_size > 0) { @@ -205,7 +205,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); int64_t q_size = total_q * num_heads * head_size; auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); @@ -251,7 +251,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); DenseTensor workspace; if (workspace_size > 0) { @@ -294,7 +294,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); } #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index a1281f16e0e11..d8e7faa630d6b 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -223,7 +223,7 @@ void FlashAttnUnpaddedKernel( PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); if (workspace_size > 0) { workspace = Empty( @@ -261,7 +261,7 @@ void FlashAttnUnpaddedKernel( PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); } else { bool succ 
= phi::dynload::flash_attn_fwd(q.data(), @@ -293,7 +293,7 @@ void FlashAttnUnpaddedKernel( PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); if (workspace_size > 0) { workspace = Empty( @@ -331,7 +331,7 @@ void FlashAttnUnpaddedKernel( PADDLE_ENFORCE_EQ(succ, true, "Error in Flash-Attention, detail information is ", - phi::dynload::flash_attn_error()); + phi::errors::External(phi::dynload::flash_attn_error())); } #endif From 70f8e75a7d89bdaa2bad672fa0ef3460826d2a31 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sun, 6 Aug 2023 17:47:28 +0800 Subject: [PATCH 09/13] update enforce --- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 36 +++++++++-------- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 40 +++++++++++-------- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index cb59be0e209d9..9211607bf7ef7 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -155,10 +155,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); DenseTensor workspace; if (workspace_size > 0) { @@ -202,10 +203,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, temp_rand_mask_dim.data() ? 
temp_rand_mask_dim.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); int64_t q_size = total_q * num_heads * head_size; auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); @@ -248,10 +250,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, seed, offset); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); DenseTensor workspace; if (workspace_size > 0) { @@ -291,10 +294,11 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, seed, offset); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); } #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index d8e7faa630d6b..422bde9e0f95c 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -174,6 +174,9 @@ void FlashAttnUnpaddedKernel( DenseTensor scale_q; scale_q.Resize({total_q, num_heads, head_size}); ctx.template Alloc(&scale_q); + // DenseTensor* scale_q = new DenseTensor; + // scale_q->Resize({total_q, num_heads, head_size}); + // ctx.template Alloc(scale_q); // compute scale Q ComputeScaleQ(ctx, q_size, scale, q.data(), scale_q.data()); @@ -220,10 +223,11 @@ void FlashAttnUnpaddedKernel( nullptr, rand_mask_dim.data() ? rand_mask_dim.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); if (workspace_size > 0) { workspace = Empty( @@ -258,10 +262,12 @@ void FlashAttnUnpaddedKernel( nullptr, rand_mask_dim.data() ? 
rand_mask_dim.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); + // delete scale_q; } else { bool succ = phi::dynload::flash_attn_fwd(q.data(), @@ -290,10 +296,11 @@ void FlashAttnUnpaddedKernel( stream, seed, offset); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); if (workspace_size > 0) { workspace = Empty( @@ -328,10 +335,11 @@ void FlashAttnUnpaddedKernel( seed, offset); - PADDLE_ENFORCE_EQ(succ, - true, - "Error in Flash-Attention, detail information is ", - phi::errors::External(phi::dynload::flash_attn_error())); + PADDLE_ENFORCE_EQ( + succ, + true, + phi::errors::External("Error in Flash-Attention, detail information is", + phi::dynload::flash_attn_error())); } #endif From ffeaa146189e3b8f638588c70d8540dcf8cc8de1 Mon Sep 17 00:00:00 2001 From: iosmers <1871465933@qq.com> Date: Sun, 6 Aug 2023 19:18:30 +0800 Subject: [PATCH 10/13] add test_flash_attention --- test/legacy_test/test_flash_attention.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py index c5b99ac575fb1..ec8ace12bbf48 100644 --- a/test/legacy_test/test_flash_attention.py +++ b/test/legacy_test/test_flash_attention.py @@ -306,6 +306,13 @@ def test_all(self): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11030 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" + "and device's compute capability must be 7.5 or 8.x", +) class TestFlashAttentionWithMaskAPI(unittest.TestCase): def setUp(self): self.place = paddle.CUDAPlace(0) From 1c2c592651d9e3d1b6ce19b67f379df966c4fcc0 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 7 Aug 2023 13:03:00 +0800 Subject: [PATCH 11/13] Polish codes and fix compiling errors. 
--- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 203 +++++------ paddle/phi/kernels/gpu/flash_attn_kernel.cu | 341 +++++++----------- paddle/phi/kernels/gpu/flash_attn_utils.h | 120 +++++- 3 files changed, 318 insertions(+), 346 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index f9ceee3004599..04ab44102dda8 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -15,7 +15,6 @@ #include "paddle/phi/kernels/flash_attn_grad_kernel.h" #include "glog/logging.h" // For VLOG() #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" @@ -33,13 +32,6 @@ DECLARE_bool(cudnn_deterministic); namespace phi { -template -__global__ void SimleScaleWithMaskKernel(int64_t numel, float scale, T* inout) { - CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { - inout[i] = static_cast(scale * static_cast(inout[i])); - } -} - template void FlashAttnUnpaddedGradImpl(const Context& ctx, const DenseTensor& q, @@ -74,44 +66,37 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, if (FLAGS_cudnn_deterministic) { num_splits = 1; } - bool zero_tensors = false; - - const int64_t* seed_offset_data = seed_offset.data(); - uint64_t seed = static_cast(seed_offset_data[0]); - uint64_t offset = static_cast(seed_offset_data[1]); - - VLOG(4) << "FlashAttn bwd seed: " << seed << ", offset: " << offset - << ", num_splits:" << num_splits; - - int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16; - DenseTensor dsoftmax = Empty(ctx, {batch_size, num_heads, seq_len_q}); - uint64_t workspace_size; - bool succ; PADDLE_ENFORCE_NE(causal, true, phi::errors::InvalidArgument( "attn_mask is not nullptr, causal can not be true")); - bool flag = (head_size == 32 || head_size == 64 || head_size == 128); + PADDLE_ENFORCE_EQ( - flag, + head_size == 32 || head_size == 64 || head_size == 128, true, - phi::errors::InvalidArgument( - "Currently, the mask only supports head_dim of 32, 64, and 128")); + phi::errors::InvalidArgument("The head_dim is expected to be either 32, " + "64, or 128, but recieved %d.", + head_size)); + + const int64_t* seed_offset_data = seed_offset.data(); + uint64_t seed = static_cast(seed_offset_data[0]); + uint64_t offset = static_cast(seed_offset_data[1]); + VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset + << ", num_splits:" << num_splits; + + int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16; + DenseTensor dsoftmax = Empty(ctx, {batch_size, num_heads, seqlen_q}); + + const DenseTensor* attn_mask_tensor = attn_mask.get_ptr(); + std::vector mask_dims = GetAttnMaskDims(attn_mask_tensor); + + bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16; float fa_with_mask_scale = 1.0f; - std::vector temp_rand_mask_dim; - const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); - int64_t first_dim = 1; - const auto& origin_dims = attn_mask_ptr->dims(); - auto rank = origin_dims.size(); - for (int i = 0; i < rank - 3; i++) { - first_dim *= origin_dims[i]; - } - temp_rand_mask_dim = {first_dim, - origin_dims[rank - 3], - origin_dims[rank - 2], - origin_dims[rank - 1]}; - succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( + bool fa_zero_tensors = false; + + uint64_t workspace_size; + bool succ = phi::dynload::flash_attn_bwd_with_bias_and_mask( static_cast(q.data()), static_cast(k.data()), 
static_cast(v.data()), @@ -131,8 +116,8 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, max_seqlen_k, dropout, fa_with_mask_scale, - zero_tensors, - is_bf16, + fa_zero_tensors, + fa_is_bf16, num_splits, static_cast(softmax_lse.data()), static_cast(dsoftmax.data()), @@ -142,16 +127,11 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, stream, seed, offset, - attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, nullptr, - temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, + mask_dims.data() ? mask_dims.data() : nullptr, nullptr); - - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention, detail information is", - phi::dynload::flash_attn_error())); + CheckFlashAttnStatus(succ); DenseTensor workspace; if (workspace_size > 0) { @@ -179,8 +159,8 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, max_seqlen_k, dropout, fa_with_mask_scale, - zero_tensors, - is_bf16, + fa_zero_tensors, + fa_is_bf16, num_splits, static_cast(softmax_lse.data()), static_cast(dsoftmax.data()), @@ -190,24 +170,14 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, stream, seed, offset, - attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, nullptr, - temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr, + mask_dims.data() ? mask_dims.data() : nullptr, nullptr); - - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention, detail information is", - phi::dynload::flash_attn_error())); + CheckFlashAttnStatus(succ); int64_t q_size = total_q * num_heads * head_size; - auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1); - SimleScaleWithMaskKernel<<>>( - q_size, scale, static_cast(dq->data())); + ComputeScaleQ(ctx, q_size, scale, dq->data(), dq->data()); } template @@ -275,8 +245,6 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, // num_splits = 1; // } - const bool zero_tensors = false; - // TODO(umiswing): add shape check PADDLE_ENFORCE_EQ( head_size_og, @@ -301,7 +269,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, VLOG(4) << "FlashAttn bwd seed: " << params.seed << ", offset: " << params.offset; - const bool succ = + bool succ = phi::dynload::flash_attn_varlen_bwd(dout.data(), q.data(), k.data(), @@ -332,14 +300,10 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, stream, params.seed, params.offset); - - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + CheckFlashAttnStatus(succ); } #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } @@ -361,14 +325,17 @@ void FlashAttnGradKernel(const Context& ctx, #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - auto dims = q.dims(); - const int batch_size = dims[0]; - const int seqlen_q = dims[1]; - const int num_heads = dims[2]; - const int head_size_og = dout.dims()[3]; - const int head_size = dims[3]; - const int seqlen_k = k.dims()[1]; - const int num_heads_k = k.dims()[2]; + const auto& dims = q.dims(); + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size_og = dout.dims()[3]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.dims()[1]; + const int64_t num_heads_k = k.dims()[2]; + + const int64_t total_q = batch_size * seqlen_q; + const int64_t total_k = 
batch_size * seqlen_k; // TODO(umiswing): add shape check PADDLE_ENFORCE_EQ( @@ -390,9 +357,9 @@ void FlashAttnGradKernel(const Context& ctx, DenseTensor cu_seqlens_q; DenseTensor cu_seqlens_k; ArangeNullaryKernel( - ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q); + ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q); ArangeNullaryKernel( - ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k); + ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k); FlashAttnUnpaddedGradKernel(ctx, q_t_s, @@ -405,8 +372,8 @@ void FlashAttnGradKernel(const Context& ctx, seed_offset, attn_mask, dout, - seq_len_q, - seq_len_k, + seqlen_q, + seqlen_k, scale, dropout, causal, @@ -437,44 +404,38 @@ void FlashAttnGradKernel(const Context& ctx, VLOG(4) << "FlashAttn bwd seed: " << params.seed << ", offset: " << params.offset; - const bool succ = phi::dynload::flash_attn_bwd(dout.data(), - q.data(), - k.data(), - v.data(), - out.data(), - params.softmax_d.data(), - softmax_lse.data(), - params.rng_state.data(), - dq->data(), - dk->data(), - dv->data(), - params.dq_accum.data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.is_bf16, - stream, - params.seed, - params.offset); - - PADDLE_ENFORCE_EQ(succ, - true, - phi::errors::External( - "Error in Flash-Attention-2, detail information is", - phi::dynload::flash_attn_error())); + bool succ = phi::dynload::flash_attn_bwd(dout.data(), + q.data(), + k.data(), + v.data(), + out.data(), + params.softmax_d.data(), + softmax_lse.data(), + params.rng_state.data(), + dq->data(), + dk->data(), + dv->data(), + params.dq_accum.data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + params.causal, + params.is_bf16, + stream, + params.seed, + params.offset); + CheckFlashAttnStatus(succ); } #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 9a35c30bb28ca..36bb5de170125 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -15,10 +15,7 @@ #include "paddle/phi/kernels/flash_attn_kernel.h" #include "glog/logging.h" // For VLOG() -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/data_type.h" -#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/flags.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" @@ -35,26 +32,6 @@ DECLARE_bool(cudnn_deterministic); namespace phi { -template -__global__ void SimleScaleWithMaskKernel(int64_t numel, - float scale, - const T* input, - T* ouput) { - CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { - ouput[i] = static_cast(scale * static_cast(input[i])); - } -} - -template -void ComputeScaleQ( - const Context& ctx, int64_t numel, float scale, const T* input, T* output) { - auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1); - SimleScaleWithMaskKernel<<>>(numel, scale, input, 
output); -} - template void FlashAttnWithMaskUnpaddedImpl( const Context& ctx, @@ -91,29 +68,23 @@ void FlashAttnWithMaskUnpaddedImpl( if (FLAGS_cudnn_deterministic) { num_splits = 1; } - bool zero_tensors = false; + PADDLE_ENFORCE_NE(causal, + true, + phi::errors::InvalidArgument( + "attn_mask is not nullptr, causal can not be true")); - uint64_t seed; - uint64_t offset; + PADDLE_ENFORCE_EQ( + head_size == 32 || head_size == 64 || head_size == 128, + true, + phi::errors::InvalidArgument("The head_dim is expected to be either 32, " + "64, or 128, but recieved %d.", + head_size)); - if (fixed_seed_offset.get_ptr()) { - const int64_t* fixed_seed_offset_data = - fixed_seed_offset.get_ptr()->data(); - seed = static_cast(fixed_seed_offset_data[0]); - offset = static_cast(fixed_seed_offset_data[1]); - } else { - uint64_t inc = batch_size * num_heads * 32; - std::pair seed_offset_pair; - if (rng_name != "") { - auto gen = phi::GetRandomSeedGenerator(rng_name); - seed_offset_pair = gen->IncrementOffset(inc); - } else { - auto* gen = ctx.GetGenerator(); - seed_offset_pair = gen->IncrementOffset(inc); - } - seed = seed_offset_pair.first; - offset = seed_offset_pair.second; - } + // Generate random state for dropout and save for recompute in grad. + auto seed_offset_pair = + GenerateRNGState(ctx, fixed_seed_offset, rng_name, batch_size, num_heads); + uint64_t seed = seed_offset_pair.first; + uint64_t offset = seed_offset_pair.second; VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset << ", num_splits:" << num_splits; @@ -123,65 +94,41 @@ void FlashAttnWithMaskUnpaddedImpl( seed_offset_data[0] = static_cast(seed); seed_offset_data[1] = static_cast(offset); - int64_t seq_len_q = ((max_seqlen_q + 16 - 1) / 16) * 16; + // Allocate memory for softmax_lse and softmax. + int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16; - softmax_lse->Resize({batch_size, num_heads, seq_len_q}); + softmax_lse->Resize({batch_size, num_heads, seqlen_q}); ctx.template Alloc(softmax_lse); if (return_softmax) { // may allocate more space than *max_seqlen_k* int64_t blocksize_c = head_size > 64 ? 
128 : 256; - int64_t seq_len_k = + int64_t seqlen_k = ((max_seqlen_k + blocksize_c - 1) / blocksize_c) * blocksize_c; if (max_seqlen_k <= 128) { - seq_len_k = 128; + seqlen_k = 128; } else if (max_seqlen_k <= 256) { - seq_len_k = 256; + seqlen_k = 256; } - softmax->Resize({batch_size, num_heads, seq_len_q, seq_len_k}); + softmax->Resize({batch_size, num_heads, seqlen_q, seqlen_k}); ctx.template Alloc(softmax); } - uint64_t workspace_size = 0; - DenseTensor workspace; - - PADDLE_ENFORCE_NE(causal, - true, - phi::errors::InvalidArgument( - "attn_mask is not nullptr, causal can not be true")); - bool flag = (head_size == 32 || head_size == 64 || head_size == 128); - PADDLE_ENFORCE_EQ( - flag, - true, - phi::errors::InvalidArgument( - "Currently, the mask only supports head_dim of 32, 64, and 128")); - + // Compute scale Q int64_t q_size = total_q * num_heads * head_size; - DenseTensor scale_q; - scale_q.Resize({total_q, num_heads, head_size}); - ctx.template Alloc(&scale_q); - // DenseTensor* scale_q = new DenseTensor; - // scale_q->Resize({total_q, num_heads, head_size}); - // ctx.template Alloc(scale_q); - // compute scale Q - ComputeScaleQ(ctx, q_size, scale, q.data(), scale_q.data()); + DenseTensor scaled_q = Empty(ctx, {total_q, num_heads, head_size}); + ComputeScaleQ(ctx, q_size, scale, q.data(), scaled_q.data()); + + const DenseTensor* attn_mask_tensor = attn_mask.get_ptr(); + std::vector mask_dims = GetAttnMaskDims(attn_mask_tensor); + bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16; float fa_with_mask_scale = 1.0f; - std::vector rand_mask_dim; - const DenseTensor* attn_mask_ptr = attn_mask.get_ptr(); - int64_t first_dim = 1; - const auto& origin_dims = attn_mask_ptr->dims(); - auto rank = origin_dims.size(); - for (int i = 0; i < rank - 3; i++) { - first_dim *= origin_dims[i]; - } - rand_mask_dim = {first_dim, - origin_dims[rank - 3], - origin_dims[rank - 2], - origin_dims[rank - 1]}; + bool fa_zero_tensors = false; + uint64_t workspace_size = 0; bool succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(scale_q.data()), + static_cast(scaled_q.data()), static_cast(k.data()), static_cast(v.data()), nullptr, // for calculation workspace size @@ -196,8 +143,8 @@ void FlashAttnWithMaskUnpaddedImpl( max_seqlen_k, dropout, fa_with_mask_scale, - zero_tensors, - is_bf16, + fa_zero_tensors, + fa_is_bf16, num_splits, softmax_lse->data(), nullptr, @@ -205,22 +152,19 @@ void FlashAttnWithMaskUnpaddedImpl( stream, seed, offset, - attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, nullptr, - rand_mask_dim.data() ? rand_mask_dim.data() : nullptr, + mask_dims.data() ? mask_dims.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention, detail information is", - phi::dynload::flash_attn_error())); + CheckFlashAttnStatus(succ); + DenseTensor workspace; if (workspace_size > 0) { workspace = Empty( ctx, {static_cast(workspace_size / sizeof(float))}); } succ = phi::dynload::flash_attn_fwd_with_bias_and_mask( - static_cast(scale_q.data()), + static_cast(scaled_q.data()), k.data(), v.data(), out->data(), // set out to nullptr to calculate workspace size @@ -235,8 +179,8 @@ void FlashAttnWithMaskUnpaddedImpl( max_seqlen_k, dropout, fa_with_mask_scale, - zero_tensors, - is_bf16, + fa_zero_tensors, + fa_is_bf16, num_splits, softmax_lse->data(), workspace_size > 0 ? 
workspace.data() : nullptr, @@ -244,15 +188,11 @@ void FlashAttnWithMaskUnpaddedImpl( stream, seed, offset, - attn_mask_ptr ? attn_mask_ptr->data() : nullptr, + attn_mask_tensor ? attn_mask_tensor->data() : nullptr, nullptr, - rand_mask_dim.data() ? rand_mask_dim.data() : nullptr, + mask_dims.data() ? mask_dims.data() : nullptr, nullptr); - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External("Error in Flash-Attention, detail information is", - phi::dynload::flash_attn_error())); + CheckFlashAttnStatus(succ); } template @@ -328,30 +268,29 @@ void FlashAttnUnpaddedKernel( // TODO(umiswing): add shape check - FlashAttnFwdParamsV2 params = - FlashAttnFwdParamsV2(ctx, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset.get_ptr(), - softmax, - softmax_lse, - seed_offset); + FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + softmax, + softmax_lse, + seed_offset); VLOG(4) << "FlashAttn fwd seed: " << params.seed << ", offset: " << params.offset; - const bool succ = phi::dynload::flash_attn_varlen_fwd( + bool succ = phi::dynload::flash_attn_varlen_fwd( q.data(), k.data(), v.data(), @@ -378,14 +317,10 @@ void FlashAttnUnpaddedKernel( stream, params.seed, params.offset); - if (!succ) { - PADDLE_THROW(phi::errors::External(phi::dynload::flash_attn_error())); - } + CheckFlashAttnStatus(succ); } #else - PADDLE_THROW( - phi::errors::Unimplemented("FlashAttention is unsupported, please check " - "the GPU compability and CUDA Version.")); + RaiseNotSupportedError(); #endif } @@ -407,87 +342,28 @@ void FlashAttnKernel(const Context& ctx, DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN // q,k,v [batch_size, seq_len, num_heads, head_dim] - auto dims = q.dims(); + const auto& dims = q.dims(); PADDLE_ENFORCE_EQ(dims.size(), 4, phi::errors::InvalidArgument( "flash_attn receive input with dim " "[batch_size, seq_len, num_heads, head_dim]")); - const int batch_size = dims[0]; - const int seqlen_q = dims[1]; - const int num_heads = dims[2]; - const int head_size = dims[3]; - const int seqlen_k = k.dims()[1]; - const int num_heads_k = k.dims()[2]; + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.dims()[1]; + const int64_t num_heads_k = k.dims()[2]; + + const int64_t total_q = batch_size * seqlen_q; + const int64_t total_k = batch_size * seqlen_k; // TODO(umiswing): Add check shape const float scale = 1.0f / std::sqrt(head_size); - if (!attn_mask.get_ptr()) { - FlashAttnFwdParamsV2 params = - FlashAttnFwdParamsV2(ctx, - batch_size, - seqlen_q, - seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset.get_ptr(), - softmax, - softmax_lse, - seed_offset); - - VLOG(4) << "FlashAttn fwd dims q[" << q.dims() << "], k[" << k.dims() - << "], v[" << v.dims() << "]"; - ctx.template Alloc(out); - - cudaStream_t stream = ctx.stream(); - - VLOG(4) << "FlashAttn fwd seed: " << params.seed - << ", offset: " << params.offset; - - bool succ = phi::dynload::flash_attn_fwd( - q.data(), - k.data(), - v.data(), - params.rng_state.data(), - out->data(), - 
params.return_softmax ? params.softmax->data() : nullptr, - params.softmax_lse->data(), - params.batch_size, - params.max_seqlen_q, - params.max_seqlen_k, - params.seqlen_q_rounded, - params.seqlen_k_rounded, - params.num_heads, - params.num_heads_k, - params.head_size, - params.head_size_rounded, - params.dropout, - params.scale, - params.causal, - params.return_softmax, - params.is_bf16, - stream, - params.seed, - params.offset); - - PADDLE_ENFORCE_EQ( - succ, - true, - phi::errors::External( - "Error in Flash-Attention-2, detail information is: %s", - phi::dynload::flash_attn_error())); - } else { + if (attn_mask.get_ptr()) { DenseTensor q_t_s, k_t_s, v_t_s; q_t_s.ShareDataWith(q).Resize({total_q, num_heads, head_size}); k_t_s.ShareDataWith(k).Resize({total_k, num_heads, head_size}); @@ -496,9 +372,9 @@ void FlashAttnKernel(const Context& ctx, DenseTensor cu_seqlens_q; DenseTensor cu_seqlens_k; ArangeNullaryKernel( - ctx, 0, (batch_size + 1) * seq_len_q, seq_len_q, &cu_seqlens_q); + ctx, 0, (batch_size + 1) * seqlen_q, seqlen_q, &cu_seqlens_q); ArangeNullaryKernel( - ctx, 0, (batch_size + 1) * seq_len_k, seq_len_k, &cu_seqlens_k); + ctx, 0, (batch_size + 1) * seqlen_k, seqlen_k, &cu_seqlens_k); FlashAttnUnpaddedKernel(ctx, q_t_s, @@ -508,8 +384,8 @@ void FlashAttnKernel(const Context& ctx, cu_seqlens_k, fixed_seed_offset, attn_mask, - seq_len_q, - seq_len_k, + seqlen_q, + seqlen_k, scale, dropout, causal, @@ -520,10 +396,63 @@ void FlashAttnKernel(const Context& ctx, softmax, softmax_lse, seed_offset); + } else { + FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + softmax, + softmax_lse, + seed_offset); + + VLOG(10) << "FlashAttn fwd dims: q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; + VLOG(10) << "FlashAttn fwd seed: " << params.seed + << ", offset: " << params.offset; + + ctx.template Alloc(out); + + cudaStream_t stream = ctx.stream(); + bool succ = phi::dynload::flash_attn_fwd( + q.data(), + k.data(), + v.data(), + params.rng_state.data(), + out->data(), + params.return_softmax ? 
params.softmax->data() : nullptr, + params.softmax_lse->data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.scale, + params.causal, + params.return_softmax, + params.is_bf16, + stream, + params.seed, + params.offset); + CheckFlashAttnStatus(succ); } #else - PADDLE_THROW(phi::errors::Unimplemented( - "FlashAttention is unsupported, please set use_flash_attn to false.")); + RaiseNotSupportedError(); #endif } diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h index 62d0f4ec95b37..e3988658db51f 100644 --- a/paddle/phi/kernels/gpu/flash_attn_utils.h +++ b/paddle/phi/kernels/gpu/flash_attn_utils.h @@ -14,8 +14,42 @@ #pragma once +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/enforce.h" + +#ifdef PADDLE_WITH_FLASHATTN +#include "paddle/phi/backends/dynload/flashattn.h" +#endif + namespace phi { +static std::pair GenerateRNGState( + const GPUContext& ctx, + const paddle::optional& fixed_seed_offset, + const std::string& rng_name, + const int64_t batch_size, + const int64_t num_heads) { + if (fixed_seed_offset.get_ptr()) { + const int64_t* fixed_seed_offset_data = + fixed_seed_offset.get_ptr()->data(); + uint64_t seed = static_cast(fixed_seed_offset_data[0]); + uint64_t offset = static_cast(fixed_seed_offset_data[1]); + return std::make_pair(seed, offset); + } else { + uint64_t inc = batch_size * num_heads * 32; + std::pair seed_offset_pair; + if (rng_name != "") { + auto gen = phi::GetRandomSeedGenerator(rng_name); + seed_offset_pair = gen->IncrementOffset(inc); + } else { + auto* gen = ctx.GetGenerator(); + seed_offset_pair = gen->IncrementOffset(inc); + } + return seed_offset_pair; + } +} + template struct FlashAttnFwdParamsV2 { int batch_size; @@ -55,7 +89,7 @@ struct FlashAttnFwdParamsV2 { const DataType q_dtype, const bool is_test, const std::string& rng_name, - const DenseTensor* const fixed_seed_offset_ptr, + const paddle::optional& fixed_seed_offset, DenseTensor* _softmax, DenseTensor* _softmax_lse, DenseTensor* _seed_offset) @@ -78,24 +112,11 @@ struct FlashAttnFwdParamsV2 { // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t // with the same size. 
rng_state = Empty(ctx, {2}); - if (fixed_seed_offset_ptr) { - const int64_t* fixed_seed_offset_data = - fixed_seed_offset_ptr->data(); - seed = static_cast(fixed_seed_offset_data[0]); - offset = static_cast(fixed_seed_offset_data[1]); - } else { - uint64_t inc = batch_size * num_heads * 32; - std::pair seed_offset_pair; - if (rng_name != "") { - auto gen = phi::GetRandomSeedGenerator(rng_name); - seed_offset_pair = gen->IncrementOffset(inc); - } else { - auto* gen = ctx.GetGenerator(); - seed_offset_pair = gen->IncrementOffset(inc); - } - seed = seed_offset_pair.first; - offset = seed_offset_pair.second; - } + + auto seed_offset_pair = GenerateRNGState( + ctx, fixed_seed_offset, rng_name, batch_size, num_heads); + seed = seed_offset_pair.first; + offset = seed_offset_pair.second; seed_offset->Resize({2}); int64_t* seed_offset_data = ctx.template HostAlloc(seed_offset); @@ -178,4 +199,65 @@ struct FlashAttnBwdParamsV2 { ctx, {batch_size, num_heads, seqlen_q_rounded, head_size_rounded}); } }; + +static void CheckFlashAttnStatus(const bool status) { + PADDLE_ENFORCE_EQ(status, + true, + phi::errors::External( + "Error in Flash-Attention, detail information is: %s", + phi::dynload::flash_attn_error())); +} + +static void RaiseNotSupportedError() { + PADDLE_THROW( + phi::errors::Unimplemented("FlashAttention is unsupported, please check " + "the GPU compability and CUDA Version.")); +} + +template +__global__ void SimleScaleKernel(const T* input, + int64_t numel, + float scale, + T* ouput) { + CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) { + ouput[i] = static_cast(scale * static_cast(input[i])); + } +} + +template +void ComputeScaleQ( + const Context& ctx, int64_t numel, float scale, const T* input, T* output) { + auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1); + SimleScaleKernel<<>>(input, numel, scale, output); +} + +static std::vector GetAttnMaskDims(const DenseTensor* attn_mask) { + std::vector mask_dim_4d; + if (attn_mask) { + const auto& origin_dims = attn_mask->dims(); + auto rank = origin_dims.size(); + PADDLE_ENFORCE_GE( + rank, + 4, + phi::errors::InvalidArgument( + "Teh number of dimenstions of attn_mask is expected to be greater " + "or equal to 4, but recieved %d. The shape of attn_mask is {%s}", + rank, + origin_dims)); + + int64_t first_dim = 1; + for (int i = 0; i < rank - 3; i++) { + first_dim *= origin_dims[i]; + } + mask_dim_4d = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + } + return mask_dim_4d; +} + } // namespace phi From 28a40ab563c5a14e9b55f0f644f1b2b227773b95 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 7 Aug 2023 13:12:33 +0800 Subject: [PATCH 12/13] Set num_splits to 0 for flash-attn with tensor mask. 
--- .../phi/kernels/gpu/flash_attn_grad_kernel.cu | 34 ++++++++----------- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 16 ++++----- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 04ab44102dda8..f7be0d625e2f2 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -62,11 +62,6 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, int64_t total_k = k.dims()[0]; int64_t batch_size = cu_seqlens_q.numel() - 1; - int num_splits = 0; // 0 for an internal heuristic, which is optimal - if (FLAGS_cudnn_deterministic) { - num_splits = 1; - } - PADDLE_ENFORCE_NE(causal, true, phi::errors::InvalidArgument( @@ -91,6 +86,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, const DenseTensor* attn_mask_tensor = attn_mask.get_ptr(); std::vector mask_dims = GetAttnMaskDims(attn_mask_tensor); + int fa_num_splits = 0; bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16; float fa_with_mask_scale = 1.0f; bool fa_zero_tensors = false; @@ -118,7 +114,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, fa_with_mask_scale, fa_zero_tensors, fa_is_bf16, - num_splits, + fa_num_splits, static_cast(softmax_lse.data()), static_cast(dsoftmax.data()), nullptr, @@ -161,7 +157,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, fa_with_mask_scale, fa_zero_tensors, fa_is_bf16, - num_splits, + fa_num_splits, static_cast(softmax_lse.data()), static_cast(dsoftmax.data()), nullptr, @@ -232,12 +228,12 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, dv); } else { const int64_t total_q = dims[0]; - const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = dims[1]; - const int head_size_og = dout.dims()[2]; - const int head_size = dims[2]; - const int total_k = k.dims()[0]; - const int num_heads_k = k.dims()[1]; + const int64_t batch_size = cu_seqlens_q.numel() - 1; + const int64_t num_heads = dims[1]; + const int64_t head_size_og = dout.dims()[2]; + const int64_t head_size = dims[2]; + const int64_t total_k = k.dims()[0]; + const int64_t num_heads_k = k.dims()[1]; // TODO(umiswing): add deterministic in fa2. 
// int num_splits = 0; // 0 for an internal heuristic, which is optimal @@ -266,8 +262,8 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, q.dtype(), seed_offset.data()); - VLOG(4) << "FlashAttn bwd seed: " << params.seed - << ", offset: " << params.offset; + VLOG(10) << "FlashAttn bwd seed: " << params.seed + << ", offset: " << params.offset; bool succ = phi::dynload::flash_attn_varlen_bwd(dout.data(), @@ -344,8 +340,8 @@ void FlashAttnGradKernel(const Context& ctx, phi::errors::InvalidArgument( "flash_attn_bwd receive input with head_size_og == head_size")); - VLOG(4) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims() - << "], v[" << v.dims() << "]"; + VLOG(10) << "FlashAttn bwd dims q[" << q.dims() << "], k[" << k.dims() + << "], v[" << v.dims() << "]"; const float scale = 1.0f / std::sqrt(head_size); if (attn_mask.get_ptr()) { @@ -401,8 +397,8 @@ void FlashAttnGradKernel(const Context& ctx, cudaStream_t stream = ctx.stream(); - VLOG(4) << "FlashAttn bwd seed: " << params.seed - << ", offset: " << params.offset; + VLOG(10) << "FlashAttn bwd seed: " << params.seed + << ", offset: " << params.offset; bool succ = phi::dynload::flash_attn_bwd(dout.data(), q.data(), diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 36bb5de170125..2f3922093eac3 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -64,10 +64,6 @@ void FlashAttnWithMaskUnpaddedImpl( int64_t total_k = k.dims()[0]; int64_t batch_size = cu_seqlens_q.numel() - 1; - int num_splits = 0; // 0 for an internal heuristic, which is optimal - if (FLAGS_cudnn_deterministic) { - num_splits = 1; - } PADDLE_ENFORCE_NE(causal, true, phi::errors::InvalidArgument( @@ -86,8 +82,7 @@ void FlashAttnWithMaskUnpaddedImpl( uint64_t seed = seed_offset_pair.first; uint64_t offset = seed_offset_pair.second; - VLOG(4) << "FlashAttn fwd seed: " << seed << ", offset: " << offset - << ", num_splits:" << num_splits; + VLOG(10) << "FlashAttn fwd seed: " << seed << ", offset: " << offset; seed_offset->Resize({2}); int64_t* seed_offset_data = ctx.template HostAlloc(seed_offset); @@ -122,6 +117,7 @@ void FlashAttnWithMaskUnpaddedImpl( const DenseTensor* attn_mask_tensor = attn_mask.get_ptr(); std::vector mask_dims = GetAttnMaskDims(attn_mask_tensor); + int fa_num_splits = 0; bool fa_is_bf16 = q.dtype() == DataType::BFLOAT16; float fa_with_mask_scale = 1.0f; bool fa_zero_tensors = false; @@ -145,7 +141,7 @@ void FlashAttnWithMaskUnpaddedImpl( fa_with_mask_scale, fa_zero_tensors, fa_is_bf16, - num_splits, + fa_num_splits, softmax_lse->data(), nullptr, &workspace_size, @@ -181,7 +177,7 @@ void FlashAttnWithMaskUnpaddedImpl( fa_with_mask_scale, fa_zero_tensors, fa_is_bf16, - num_splits, + fa_num_splits, softmax_lse->data(), workspace_size > 0 ? workspace.data() : nullptr, &workspace_size, @@ -287,8 +283,8 @@ void FlashAttnUnpaddedKernel( softmax_lse, seed_offset); - VLOG(4) << "FlashAttn fwd seed: " << params.seed - << ", offset: " << params.offset; + VLOG(10) << "FlashAttn fwd seed: " << params.seed + << ", offset: " << params.offset; bool succ = phi::dynload::flash_attn_varlen_fwd( q.data(), From ec4bcc473fed688a80f3546d8cd25a70e89fe3aa Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 7 Aug 2023 13:44:22 +0800 Subject: [PATCH 13/13] Fix the compiling error for non flash-attn case. 
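The compiling error this commit fixes presumably comes from the kernel bodies using helpers that are only declared when PADDLE_WITH_FLASHATTN is defined. The diff below includes flash_attn_utils.h unconditionally, wraps each kernel body and most of that header in the guard, and leaves RaiseNotSupportedError outside the guarded region so the #else branches always have a fallback to call. A minimal self-contained sketch of that guard layout follows; WITH_FA, NotSupported and Forward are illustrative stand-ins, not symbols from this patch.

    #include <cstdio>

    // Build with -DWITH_FA to emulate a binary that bundles the library.

    // Kept outside the guard on purpose: every build needs the fallback symbol.
    static void NotSupported() {
      std::puts("flash-attention is not compiled in; check GPU compatibility and CUDA version");
    }

    static void Forward() {
    #ifdef WITH_FA
      std::puts("calling into the flash-attention library");  // real work here
    #else
      NotSupported();
    #endif
    }

    int main() {
      Forward();
      return 0;
    }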
--- paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu | 13 ++++++------- paddle/phi/kernels/gpu/flash_attn_kernel.cu | 10 +++++----- paddle/phi/kernels/gpu/flash_attn_utils.h | 14 ++++++++------ 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index f7be0d625e2f2..de479cf9adfd2 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -21,12 +21,8 @@ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" - -#ifdef PADDLE_WITH_FLASHATTN -#include "paddle/phi/backends/dynload/flashattn.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" -#endif +#include "paddle/phi/kernels/reshape_kernel.h" DECLARE_bool(cudnn_deterministic); @@ -52,6 +48,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { +#ifdef PADDLE_WITH_FLASHATTN const cudaStream_t stream = ctx.stream(); auto dims = q.dims(); @@ -77,8 +74,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, const int64_t* seed_offset_data = seed_offset.data(); uint64_t seed = static_cast(seed_offset_data[0]); uint64_t offset = static_cast(seed_offset_data[1]); - VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset - << ", num_splits:" << num_splits; + VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset; int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16; DenseTensor dsoftmax = Empty(ctx, {batch_size, num_heads, seqlen_q}); @@ -174,6 +170,9 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx, int64_t q_size = total_q * num_heads * head_size; ComputeScaleQ(ctx, q_size, scale, dq->data(), dq->data()); +#else + RaiseNotSupportedError(); +#endif } template diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 2f3922093eac3..bcf8791d3c17f 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -21,12 +21,8 @@ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" -#include "paddle/phi/kernels/reshape_kernel.h" - -#ifdef PADDLE_WITH_FLASHATTN -#include "paddle/phi/backends/dynload/flashattn.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" -#endif +#include "paddle/phi/kernels/reshape_kernel.h" DECLARE_bool(cudnn_deterministic); @@ -54,6 +50,7 @@ void FlashAttnWithMaskUnpaddedImpl( DenseTensor* softmax, DenseTensor* softmax_lse, DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_FLASHATTN cudaStream_t stream = ctx.stream(); auto dims = q.dims(); @@ -189,6 +186,9 @@ void FlashAttnWithMaskUnpaddedImpl( mask_dims.data() ? 
mask_dims.data() : nullptr, nullptr); CheckFlashAttnStatus(succ); +#else + RaiseNotSupportedError(); +#endif } template diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h index e3988658db51f..00ba036df09ba 100644 --- a/paddle/phi/kernels/gpu/flash_attn_utils.h +++ b/paddle/phi/kernels/gpu/flash_attn_utils.h @@ -24,6 +24,7 @@ namespace phi { +#ifdef PADDLE_WITH_FLASHATTN static std::pair GenerateRNGState( const GPUContext& ctx, const paddle::optional& fixed_seed_offset, @@ -208,12 +209,6 @@ static void CheckFlashAttnStatus(const bool status) { phi::dynload::flash_attn_error())); } -static void RaiseNotSupportedError() { - PADDLE_THROW( - phi::errors::Unimplemented("FlashAttention is unsupported, please check " - "the GPU compability and CUDA Version.")); -} - template __global__ void SimleScaleKernel(const T* input, int64_t numel, @@ -259,5 +254,12 @@ static std::vector GetAttnMaskDims(const DenseTensor* attn_mask) { } return mask_dim_4d; } +#endif + +static void RaiseNotSupportedError() { + PADDLE_THROW( + phi::errors::Unimplemented("FlashAttention is unsupported, please check " + "the GPU compability and CUDA Version.")); +} } // namespace phi
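Taken as a whole, the mask plumbing this series adds folds every leading dimension of attn_mask into one, so any mask with rank of at least 4 reaches the flash-attn C API as a 4-D shape: the collapsed leading dimensions first, then typically num_heads, seqlen_q and seqlen_k. The standalone sketch below mirrors that folding from GetAttnMaskDims on a plain std::vector, without the DenseTensor plumbing; FoldMaskDimsTo4D is an illustrative name, not part of the patch.

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Fold [d0, d1, ..., H, Sq, Sk] into [d0 * d1 * ..., H, Sq, Sk].
    static std::vector<int64_t> FoldMaskDimsTo4D(const std::vector<int64_t>& dims) {
      if (dims.size() < 4) {
        throw std::invalid_argument("attn_mask needs at least 4 dimensions");
      }
      const size_t rank = dims.size();
      int64_t first_dim = 1;
      for (size_t i = 0; i + 3 < rank; ++i) {
        first_dim *= dims[i];
      }
      return {first_dim, dims[rank - 3], dims[rank - 2], dims[rank - 1]};
    }

    int main() {
      // A rank-5 mask of shape [2, 3, 16, 128, 128] folds to [6, 16, 128, 128].
      for (int64_t d : FoldMaskDimsTo4D({2, 3, 16, 128, 128})) {
        std::cout << d << " ";
      }
      std::cout << "\n";
      return 0;
    }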