Add attn_mask support for FlashAttnKernel. #55969

Merged: 14 commits, Aug 7, 2023
10 changes: 6 additions & 4 deletions paddle/phi/api/yaml/backward.yaml
@@ -818,8 +818,9 @@
   inplace : (out_grad -> x_grad)
 
 - backward_op : flash_attn_grad
-  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
+  forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
+  optional : attn_mask
Contributor: Shouldn't fixed_seed_offset also be added to optional?

Contributor (author): fixed_seed_offset is a pre-existing parameter with type const Tensor; it is not optional.

Contributor: The original way of writing it doesn't feel entirely reasonable either; let's keep it as-is for now.

Contributor: No problem after all: the backward op's input here is seed_offset, which is an output of the forward op, so it is required.
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
     func : FlashAttnGradInferMeta
@@ -829,8 +830,9 @@
     data_type: q
 
 - backward_op : flash_attn_unpadded_grad
-  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
+  forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
+  optional : attn_mask
   output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
   infer_meta :
     func : FlashAttnGradInferMeta
8 changes: 4 additions & 4 deletions paddle/phi/api/yaml/ops.yaml
@@ -910,9 +910,9 @@
   backward : fill_diagonal_tensor_grad
 
 - op : flash_attn
-  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
+  args : (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  optional : fixed_seed_offset
+  optional : fixed_seed_offset, attn_mask
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
@@ -923,9 +923,9 @@
   backward : flash_attn_grad
 
 - op : flash_attn_unpadded
-  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
-  optional : fixed_seed_offset
+  optional : fixed_seed_offset, attn_mask
   infer_meta :
     func : FlashAttnInferMeta
     param : [q, k, v]
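For context, names listed under optional in these YAML op definitions surface as paddle::optional<DenseTensor> parameters in the generated C++ kernel signatures (see the headers below), so callers may omit the mask entirely. A minimal caller-side sketch, using a hypothetical helper name that is not part of this PR:

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"

// Hypothetical helper: wrap a user-supplied mask, or produce an empty
// optional when no mask was given.
paddle::optional<phi::DenseTensor> MakeMaskArg(const phi::DenseTensor* user_mask) {
  if (user_mask == nullptr) {
    return paddle::none;  // kernel-side attn_mask then tests false
  }
  return paddle::optional<phi::DenseTensor>(*user_mask);
}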
2 changes: 2 additions & 0 deletions paddle/phi/kernels/flash_attn_grad_kernel.h
@@ -29,6 +29,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
                                  const DenseTensor& out,
                                  const DenseTensor& softmax_lse,
                                  const DenseTensor& seed_offset,
+                                 const paddle::optional<DenseTensor>& attn_mask,
                                  const DenseTensor& dout,
                                  int64_t max_seqlen_q,
                                  int64_t max_seqlen_k,
@@ -47,6 +48,7 @@ void FlashAttnGradKernel(const Context& ctx,
                          const DenseTensor& out,
                          const DenseTensor& softmax_lse,
                          const DenseTensor& seed_offset,
+                         const paddle::optional<DenseTensor>& attn_mask,
                          const DenseTensor& dout,
                          float dropout,
                          bool causal,
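Inside a kernel implementation, the usual pattern is to branch on whether the optional mask is present. A minimal sketch of that dispatch, assuming paddle::optional's get_ptr() accessor; the branch bodies describe intent rather than this PR's actual CUDA code:

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"

// Sketch only: how a grad kernel body might consume the new parameter.
void DispatchOnMask(const paddle::optional<phi::DenseTensor>& attn_mask) {
  const phi::DenseTensor* mask = attn_mask.get_ptr();  // nullptr when absent
  if (mask != nullptr) {
    // Masked path: apply the mask to the attention scores before softmax
    // and route to the mask-aware FlashAttention kernel.
  } else {
    // Unmasked path: pre-existing behavior, optionally with causal masking.
  }
}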
2 changes: 2 additions & 0 deletions paddle/phi/kernels/flash_attn_kernel.h
@@ -28,6 +28,7 @@ void FlashAttnUnpaddedKernel(
     const DenseTensor& cu_seqlens_q,
     const DenseTensor& cu_seqlens_k,
     const paddle::optional<DenseTensor>& fixed_seed_offset,
+    const paddle::optional<DenseTensor>& attn_mask,
     int64_t max_seqlen_q,
     int64_t max_seqlen_k,
     float scale,
@@ -47,6 +48,7 @@ void FlashAttnKernel(const Context& ctx,
                      const DenseTensor& k,
                      const DenseTensor& v,
                      const paddle::optional<DenseTensor>& fixed_seed_offset,
+                     const paddle::optional<DenseTensor>& attn_mask,
                      float dropout,
                      bool causal,
                      bool return_softmax,
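At the usage level, attention masks like this are conventionally additive: 0 keeps a position, -inf removes it from the softmax. Whether this PR adopts exactly that convention is not visible in the diff, so the host-side sketch below illustrates the concept rather than Paddle's contract:

#include <limits>
#include <vector>

// Illustration only: build a [seqlen_q, seqlen_k] additive padding mask
// that excludes keys at index >= valid_len for every query row.
std::vector<float> BuildAdditivePaddingMask(int seqlen_q, int seqlen_k,
                                            int valid_len) {
  const float kNegInf = -std::numeric_limits<float>::infinity();
  std::vector<float> mask(static_cast<size_t>(seqlen_q) * seqlen_k, 0.0f);
  for (int i = 0; i < seqlen_q; ++i) {
    for (int j = valid_len; j < seqlen_k; ++j) {
      mask[static_cast<size_t>(i) * seqlen_k + j] = kNegInf;  // masked out
    }
  }
  return mask;
}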