Additional mask support on FA2 #57276
Conversation
Add masked support on bwd. Unpadded kernel to be tested.
Your PR was submitted successfully. Thank you for your contribution to this open source project!
✅ This PR's description meets the template requirements!
Fix ci: PADDLE_ENFORCE format. Remove test case: return_softmax && dropout==0
params.seed,
params.offset,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.mask_dims.data());
Where `mask_dims` is passed here, do we need to add a check on `params.attn_mask_tensor`?
In the fa2 C API, `mask_dims` is only accessed when `attn_mask` is not `nullptr`, so functionally no extra check is needed. One thing to note, though: when `attn_mask_tensor` is `nullptr`, `mask_dims` is an empty `vector`. The C++ standard does not specify what `data()` returns for an empty `vector`; depending on the compiler's implementation, it may be `nullptr` or an arbitrary value.
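To make the point concrete, here is a minimal standalone sketch, not the PR's actual code (`Params` and `MaskDimsOrNull` are hypothetical names), showing how a caller can guarantee a `nullptr` when no mask is present instead of relying on what `data()` happens to return for an empty vector:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the real parameter struct.
struct Params {
  const void* attn_mask_tensor = nullptr;  // null when no mask is supplied
  std::vector<int64_t> mask_dims;          // empty when there is no mask
};

const int64_t* MaskDimsOrNull(const Params& params) {
  // data() on an empty std::vector may be nullptr or any non-dereferenceable
  // pointer, so guard on the mask itself to pass an explicit nullptr.
  return params.attn_mask_tensor ? params.mask_dims.data() : nullptr;
}
```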
auto rank = origin_dims.size();
PADDLE_ENFORCE_GE(
    rank,
    4,
If `rank` is already 4, then the dimension transformation at L68 - L75 is unnecessary. That logic was originally added because the mask in the protein model is 5-D and has to be reshaped to 4-D.
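For illustration, here is a hedged sketch of the kind of 5-D to 4-D fold being discussed; the exact layout of the protein model's mask is an assumption here, and `FoldTo4D` is a hypothetical helper, not code from the PR:

```cpp
#include <cstdint>
#include <vector>

// Assumed 5-D layout: [batch, extra, num_heads, seq_q, seq_k].
// The extra leading dimension is folded into the batch dimension so the mask
// becomes the 4-D shape [batch * extra, num_heads, seq_q, seq_k].
std::vector<int64_t> FoldTo4D(const std::vector<int64_t>& origin_dims) {
  if (origin_dims.size() == 4) {
    return origin_dims;  // already 4-D, no transformation needed
  }
  return {origin_dims[0] * origin_dims[1], origin_dims[2], origin_dims[3],
          origin_dims[4]};
}
```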
LGTM and great work~
* Add addition mask support. Tested on FlashAttnKernel.
* Fix bug in fwd (temporarily). Add masked support on bwd. Unpadded kernel to be tested.
* Add unscale on padded kernel.
* Add varlen mask.
* Remove redundant compute_scale_q
* Remove redundant comment. Fix ci: PADDLE_ENFORCE format. Remove test case: return_softmax && dropout==0
* Add mask type check.
* Update submodules.
PR types
Performance optimization
PR changes
OPs
Description
Pcard-70459
To update support for Mask FA, this depends on PaddlePaddle/flash-attention#19.