Add attn_mask support for FlashAttnKernel. #55969
Conversation
Your PR was submitted successfully. Thank you for contributing to this open-source project!
out = scaled_dot_product_attention(
    q, k, v, m, self.dropout, self.causal, fixed_seed_offset=None
)
scaled_dot_product_attention is called with the Tensor mask m, but attention_naive does not apply m.
@@ -293,6 +293,48 @@ def test_all(self):
            fetches_result[0], out_, rtol=5e-03, atol=1e-03
        )

    def test_dot_scale_product(self):
It would be better to implement a new unit-test class: several later unit tests use TestFlashAttentionAPI as a base class and override configurations such as shape and return_softmax, so with the current change all of them would end up testing the mask version.
Add a check like the following to skip unsupported CUDA versions and hardware:

if not core.is_compiled_with_cuda() or get_cuda_version() < 11030 or not is_sm_supported:
    pass
  }
} else {
  succ =
      phi::dynload::flash_attn_fwd(q.data(),
Please factor this code out into helper functions; this function is too long.
paddle/phi/api/yaml/backward.yaml
@@ -818,8 +818,9 @@
    inplace : (out_grad -> x_grad)

- backward_op : flash_attn_grad
    forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
    args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
    forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset,Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
Add a space after the comma in ",Tensor attn_mask".
    args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
    forward : flash_attn (Tensor q, Tensor k, Tensor v, Tensor fixed_seed_offset,Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
    args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
    optional : attn_mask
Shouldn't fixed_seed_offset also be added to optional?
fixed_seed_offset is a pre-existing parameter of type const Tensor, not optional.
The original form doesn't seem very reasonable either; let's keep it as-is for now.
No problem then. The backward input is seed_offset, which is an output of the forward pass and is required.
@@ -33,6 +34,24 @@ DECLARE_bool(cudnn_deterministic);

namespace phi {

// template <typename T, typename Context>
Remove the dead code.
SimleScaleWithMaskKernel<<<gpu_config.block_per_grid,
                           gpu_config.thread_per_block,
                           0,
                           ctx.stream()>>>(q_size, scale, q_ptr);
L163 defines a temporary Tensor. Shouldn't SimleScaleWithMaskKernel read the data of the original input q and, after scaling, write the result into q_?
The purpose of this function is to scale in place.
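For orientation only, an in-place scale kernel of this shape could look roughly like the sketch below (element type handling and launch configuration are assumptions; the actual kernel lives in this PR):

// Sketch: each thread rescales one element of q in place.
template <typename T>
__global__ void SimleScaleWithMaskKernel(int64_t numel, float scale, T* q_ptr) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    q_ptr[idx] = static_cast<T>(static_cast<float>(q_ptr[idx]) * scale);
  }
}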
auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1);
DenseTensor q_(q);
T* q_ptr = static_cast<T*>(q_.data<T>());
A Tensor must first be Resize'd to set its dimensions and then Alloc'd; otherwise no device memory is allocated.
int64_t q_size = total_q * num_heads * head_size;

auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, q_size, 1);
DenseTensor q_(q);
Rename q_ to scaled_q.
fixed_seed_offset=None,
return_softmax=False,
training=True,
rng_name="",
Only add the new training parameter; don't add the other parameters for now.
bool succ;

if (attn_mask.get_ptr()) {
  scale = 1.0f;
Add a PADDLE_ENFORCE check: when attn_mask is passed in, is_causal must not be true.
Mind the error-message style; use:
PADDLE_ENFORCE_NE(causal, true, phi::errors::InvalidArgument(....));
scale is passed in as a parameter, so the value of the input scale should not be modified. Instead, pass 1 for flash_attn's scale argument, and pass this input scale to the subsequent SimleScaleWithMaskKernel call.
int64_t q_size = total_q * num_heads * head_size;
DenseTensor scale_q;
scale_q.ShareDataWith(q).Resize({total_q, num_heads, head_size});
ShareDataWith is not appropriate here, because ComputeScaleQ effectively modifies the input's values, and it is unclear what problems that might cause. scaled_q should allocate its own buffer, and ComputeScaleQ should be changed to a non-inplace version.
// compute scale Q
ComputeScaleQ(ctx, q_size, scale_q.data<T>(), scale);

scale = 1.0f;
Do not modify the values of the function's input parameters.
}
DenseTensor workspace;
if (workspace_size > 0) {
  workspace = Empty<float>(ctx, {int64_t(workspace_size / sizeof(float))});
Use static_cast<int64_t> for the type conversion.
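That is, roughly:

workspace = Empty<float>(
    ctx, {static_cast<int64_t>(workspace_size / sizeof(float))});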
temp_rand_mask_dim.data() ? temp_rand_mask_dim.data() : nullptr,
nullptr);
PADDLE_ENFORCE_EQ(
    succ, true, phi::errors::External(phi::dynload::flash_attn_error()));
Wrap this in a function, and make the error message more informative, along the lines of "Error in Flash-Attention, detail information is xxx", where xxx is phi::dynload::flash_attn_error().
"attn_mask is not nullptr, causal can not be true")); | ||
|
||
int64_t q_size = total_q * num_heads * head_size; | ||
DenseTensor* scale_q = new DenseTensor; |
Define the Tensor directly with DenseTensor scale_q;.
attn_mask=None,
dropout_p=0.0,
is_causal=False,
training=True,
New parameters were added, so parameter documentation needs to be added at L453; you can refer to other APIs for the format.
    true,
    phi::errors::InvalidArgument(
        "attn_mask is not nullptr, causal can not be true"));
For reference, see the function below. The Flash-Attention path that supports a Tensor mask has certain restrictions, i.e. if (head_dim == 32 || head_dim == 64 || head_dim == 128), so a PADDLE_ENFORCE check is needed here as well (a sketch of such a check follows the excerpt):

Paddle/paddle/fluid/operators/fused/fused_gate_attention.h
Lines 203 to 217 in 9877fb8
bool CanUseFlashAttn() const {
#ifdef PADDLE_WITH_FLASHATTN
  if (!std::is_same<T, phi::dtype::bfloat16>::value &&
      !std::is_same<T, phi::dtype::float16>::value) {
    return false;
  }
  if (merge_qkv && batch_size == 1) {
    if (head_dim == 32 || head_dim == 64 || head_dim == 128) {
      return use_flash_attn;
    }
  }
#endif
  return false;
}
@@ -19,6 +19,74 @@

namespace phi {

template <typename T, typename Context>
This function declaration does not need to be added to this header file.
const int64_t* mask_dims);

template <typename T, typename Context>
void FlashAttnFwd(
This function declaration does not need to be added to this header file.
PADDLE_ENFORCE_EQ(succ,
                  true,
                  "Error in Flash-Attention, detail information is ",
The error message still needs an error type: phi::errors::External(...).
Local unit-test results (40GB A100, CUDA 11.8) are as follows:
LGTM
LGTM for skipIf
LGTM for new args
LGTM for yaml change
LGTM
The Doc-Preview CI has some issues and has kept failing; the cause hasn't been tracked down yet. Approving for now.
PR types
Others
PR changes
Others
Description
Pcard-70459
FlashAttnKernel now supports a Tensor-type attn_mask input.
#55758 already integrated FlashAttention-2 with support for the regular causal mask into the framework. Since a Tensor-type mask requires kernel-level enhancements to FlashAttention-2, this PR integrates mask support based on Flash-Attention-1 for now. The performance data is as follows: