
Additional mask support on FA2 #57276

Merged 13 commits into develop on Sep 24, 2023
Conversation

@umiswing (Member) commented Sep 13, 2023

PR types

Performance optimization

PR changes

OPs

Description

Pcard-70459
To update mask support for FA (FlashAttention); depends on PaddlePaddle/flash-attention#19.

paddle-bot commented Sep 13, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

paddle-bot commented Sep 13, 2023

✅ This PR's description meets the template requirements!
Please wait for other CI results.

umiswing changed the title from "[WIP] Additional mask support on FA2" to "Additional mask support on FA2" on Sep 21, 2023
params.seed,
params.offset,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.mask_dims.data());

Contributor commented:

For the mask_dims argument passed here, does this need a check on params.attn_mask_tensor?


Member Author (umiswing) replied:

> For the mask_dims argument passed here, does this need a check on params.attn_mask_tensor?

In the fa2 C API, mask_dims is only accessed when attn_mask is not nullptr, so functionally no extra check is needed. Note, however, that when attn_mask_tensor is nullptr, mask_dims is an empty vector, and the C++ standard does not define what data() returns for an empty vector: depending on the compiler's implementation it may be nullptr or an arbitrary value.
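
A minimal sketch of the behaviour described above; `call_fa2_fwd` is a hypothetical stand-in for the FA2 C API entry point, not the actual Paddle or flash-attention symbol:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the FA2 C API call: mask_dims is only read
// when attn_mask is non-null, mirroring the guard described above.
void call_fa2_fwd(const void* attn_mask, const int64_t* mask_dims) {
  if (attn_mask != nullptr) {
    // Only on this path is mask_dims dereferenced.
    std::cout << "first mask dim: " << mask_dims[0] << "\n";
  }
}

int main() {
  std::vector<int64_t> mask_dims;  // empty when attn_mask_tensor is nullptr
  // Calling data() on an empty vector is valid, but the returned pointer is
  // unspecified (it may be nullptr or some non-dereferenceable value) and
  // must not be read; the null attn_mask guard keeps this safe.
  call_fa2_fwd(/*attn_mask=*/nullptr, mask_dims.data());
  return 0;
}
```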

auto rank = origin_dims.size();
PADDLE_ENFORCE_GE(
rank,
4,

Contributor commented:

If rank is already 4, the dimension transformation at L68 - L75 is unnecessary. That logic was originally added because the mask in the protein model is 5-D and needs to be transformed to 4-D.
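
A sketch of the reshape this comment refers to; `FoldTo4D` and the fold-leading-dims-into-batch rule are assumptions for illustration, not the actual Paddle implementation:

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Fold a mask of rank > 4 (e.g. the 5-D protein case) into 4-D by merging
// all leading dimensions into the batch dimension. Assumes rank >= 4, as
// enforced by the PADDLE_ENFORCE_GE check in the quoted snippet.
std::vector<int64_t> FoldTo4D(const std::vector<int64_t>& origin_dims) {
  if (origin_dims.size() == 4) {
    return origin_dims;  // already 4-D: the transform can be skipped
  }
  const int64_t batch = std::accumulate(origin_dims.begin(),
                                        origin_dims.end() - 3,
                                        static_cast<int64_t>(1),
                                        std::multiplies<int64_t>());
  const size_t rank = origin_dims.size();
  return {batch, origin_dims[rank - 3], origin_dims[rank - 2],
          origin_dims[rank - 1]};
}
```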


@Xreki (Contributor) left a comment:

LGTM and great work~

Xreki merged commit 19a8f0a into PaddlePaddle:develop on Sep 24, 2023
27 checks passed
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 26, 2023
* Add addition mask support. Tested on FlashAttnKernel.

* Fix bug in fwd (temporarily).
Add masked support on bwd.
Unpadded kernel to be tested.

* Add unscale on padded kernel.

* Add varlen mask.

* Remove redundant compute_scale_q

* Remove redundant comment.
Fix ci: PADDLE_ENFORCE format.
Remove test case: return_softmax && dropout==0

* Add mask type check.

* Update submodules.
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 26, 2023
Frida-a pushed a commit to Frida-a/Paddle that referenced this pull request Oct 14, 2023
jiahy0825 pushed a commit to jiahy0825/Paddle that referenced this pull request Oct 16, 2023
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Nov 6, 2023
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Nov 6, 2023
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
3 participants