[Paddle Inference] Fix mmha when src_mask is not equal to zero #57936

Merged
carryyu merged 5 commits into PaddlePaddle:develop on Oct 11, 2023

Conversation

xiaoxiaohehe001
Contributor

@xiaoxiaohehe001 xiaoxiaohehe001 commented Oct 8, 2023

PR types

Others

PR changes

Others

Description

Fix the mmha (masked multi-head attention) output diff when src_mask is not equal to zero.
Pcard-71502
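
For context, a minimal sketch of the failure mode this PR addresses (illustrative only; apply_src_mask is a hypothetical helper, not the merged kernel code, though the names follow the diffs below):

// Hypothetical helper, not the actual Paddle kernel code.
// qk is the attention logit for key position ti; act_time_step is the
// index of the current decoding step; attn_mask carries the src_mask values.
__device__ __forceinline__ float apply_src_mask(float qk,
                                                const float* attn_mask,
                                                int mask_offset,
                                                int ti,
                                                int act_time_step) {
  // The current position (ti == act_time_step) must also receive the mask
  // when src_mask is nonzero; a `ti < act_time_step` bound skips it, which
  // caused the diff this PR fixes (hence `ti < act_time_step + 1` below).
  if (attn_mask != nullptr && ti < act_time_step + 1) {
    qk += attn_mask[mask_offset + ti];
  }
  return qk;
}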

@paddle-bot

paddle-bot bot commented Oct 8, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -652,7 +652,7 @@ __global__ void masked_multihead_attention_kernel(
 #pragma unroll
     for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
       int jj = ii * params.max_seq_length + ti;
-      if (ti < act_time_step) {
+      if (ti < act_time_step + 1) {
Contributor

The +1 is not needed here.

@@ -674,7 +674,7 @@ __global__ void masked_multihead_attention_kernel(
       float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);

       // bool is_mask = false;
-      if (ti < act_time_step && tid % THREADS_PER_KEY == 0) {
+      if (ti < act_time_step + 1 && tid % THREADS_PER_KEY == 0) {
Contributor

Could the logic related to this +1 be moved to line 613 behind an added check? In most cases, no mask is added at the current position.
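
A sketch of what this suggestion might look like (an assumption, not the merged code; the indexing mirrors the commented-out lines in the @@ -609 hunk below, and params.attn_mask is an assumed field name):

// Assumed refactor per the review: keep the history loop bounded by
// act_time_step and add the current position's mask once, near line 613,
// so the common unmasked path pays no extra cost.
qk *= params.inv_sqrt_dh;
if (params.attn_mask) {
  const int mask_bhi = params.mask_broadcast_num_heads ? bi : bhi;
  // mask entry for the current step; layout mirrors the commented-out code
  auto mask = params.attn_mask[mask_bhi * (params.timestep + 1) +
                               params.timestep];
  qk += static_cast<float>(mask);
}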

@xiaoxiaohehe001 xiaoxiaohehe001 changed the title [Paddle Inference] Fix mmha when src_mask is not equal to zero. [Paddle Inference] Fix mmha when src_mask is not equal to zero Oct 10, 2023
@carryyu carryyu self-requested a review October 10, 2023 08:00
carryyu previously approved these changes Oct 10, 2023
@@ -609,6 +609,11 @@ __global__ void masked_multihead_attention_kernel(
       // bi * (params.timestep + 1) + params.timestep];
       // qk += static_cast<float>(mask);
       qk *= params.inv_sqrt_dh;
+      auto mask_bhi = params.mask_broadcast_num_heads ? bi : bhi;
Contributor

This part can be moved inside the if branch, so that mask_bhi is only computed when there is an attn mask.
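
A minimal sketch of the requested change (assuming the enclosing branch tests params.attn_mask, as the comment implies; not the verbatim merged code):

if (params.attn_mask) {
  // Compute mask_bhi only on the masked path, per the review comment.
  const auto mask_bhi = params.mask_broadcast_num_heads ? bi : bhi;
  // ... index attn_mask with mask_bhi and add the mask value to qk ...
}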

Contributor Author

Done~

@carryyu carryyu self-requested a review October 10, 2023 08:02
@carryyu carryyu merged commit d78ec37 into PaddlePaddle:develop Oct 11, 2023
27 checks passed
Frida-a pushed a commit to Frida-a/Paddle that referenced this pull request Oct 14, 2023
…ePaddle#57936)

* fix_mmha_scrmask
* fix_mmha_scrmask
* remove_mask_to_qk_smem_act_time_step
* remove_add_!
* mask_bhi
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 16, 2023
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023