[Paddle Inference] Fix mmha when src_mask is not equal to zero #57936
Your PR was submitted successfully. Thank you for contributing to this open-source project!
@@ -652,7 +652,7 @@ __global__ void masked_multihead_attention_kernel(
 #pragma unroll
 for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
   int jj = ii * params.max_seq_length + ti;
-  if (ti < act_time_step) {
+  if (ti < act_time_step + 1) {
There is no need to add 1 here.
@@ -674,7 +674,7 @@ __global__ void masked_multihead_attention_kernel(
 float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);

 // bool is_mask = false;
-if (ti < act_time_step && tid % THREADS_PER_KEY == 0) {
+if (ti < act_time_step + 1 && tid % THREADS_PER_KEY == 0) {
Could the +1-related logic be moved to line 613, guarded by a check? In most cases no mask is applied at the current position.
@@ -609,6 +609,11 @@ __global__ void masked_multihead_attention_kernel(
 // bi * (params.timestep + 1) + params.timestep];
 // qk += static_cast<float>(mask);
 qk *= params.inv_sqrt_dh;
+auto mask_bhi = params.mask_broadcast_num_heads ? bi : bhi;
This part can go inside the if branch, so that mask_bhi is computed only when an attn mask is present.
Done~
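The two review suggestions above (handle the current position's mask next to `qk *= params.inv_sqrt_dh`, and compute `mask_bhi` only when a mask exists) can be sketched on the host side as follows. This is a minimal illustration, not the code merged in the PR: `MaskParams`, `attn_mask`, `mask_length`, the mask layout, and `add_current_step_mask` are hypothetical names and assumptions. Only `mask_broadcast_num_heads`, `bi`, `bhi`, and `act_time_step` appear in the diff.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, simplified stand-in for the kernel's parameter struct.
struct MaskParams {
  const float* attn_mask;         // assumed: nullptr when no mask is supplied
  bool mask_broadcast_num_heads;  // named in the diff: one mask row per batch
  int mask_length;                // assumed: number of entries per mask row
};

// Adds the mask value for the current decoding step to an already scaled qk.
// bi is the batch index and bhi the flattened batch-head index, as in the diff.
inline float add_current_step_mask(float qk, const MaskParams& params,
                                   int bi, int bhi, int act_time_step) {
  if (params.attn_mask != nullptr) {
    assert(act_time_step >= 0 && act_time_step < params.mask_length);
    // Computed only inside the branch, as the reviewer suggested.
    const int mask_bhi = params.mask_broadcast_num_heads ? bi : bhi;
    const std::size_t idx =
        static_cast<std::size_t>(mask_bhi) * params.mask_length + act_time_step;
    qk += params.attn_mask[idx];
  }
  return qk;
}
```

With this shape, the loops over earlier time steps can keep the original `ti < act_time_step` bound, and the common no-mask case pays only for a single pointer check.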
…ePaddle#57936) * fix_mmha_scrmask * fix_mmha_scrmask * remove_mask_to_qk_smem_act_time_step * remove_add_! * mask_bhi
PR types
Others
PR changes
Others
Description
Fix the mmha (masked_multihead_attention) output diff when src_mask is not equal to zero.
Pcard-71502