Add varlen support to AOTriton's Flash Attention #31
Conversation
…_q/k" (Note this is not the exact SQL statement)
I took a quick look and realized your implementation should work for the THD qkv_layout in NVTE.
How about the BSHD/BHSD layouts with a padding mask and cu_seqlen, as we discussed before?
It's in the Triton kernel but not exposed as a C++ API. They will be added in a separate PR.
Do you mind pointing me to the Triton source code for `padded_varlen`?
aotriton/tritonsrc/fwd_kernel.py Lines 73 to 85 in c8551b1
LGTM
Looking to add support for more layouts later, but that will be a different piece of work. This looks to meet the needs for varlen.
@@ -31,6 +31,25 @@ attn_fwd(T4 q, // batch_size x num_heads x seqlen_q x head_size
          bool is_causal,
          aotriton::Stream stream);

 hipError_t
 attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q := \sum_{i=0}^{b} s_i
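To make the `total_q := \sum_{i=0}^{b} s_i` comment above concrete, here is a minimal sketch (not code from this PR; the names and values are illustrative) of the packed-length bookkeeping a compact-varlen caller typically builds:

```python
import torch

# Hypothetical per-batch sequence lengths s_0 .. s_{b-1}
seqlens = torch.tensor([3, 7, 5], dtype=torch.int32)

# total_q is the sum of all per-sequence lengths, matching the
# "total_q := \sum s_i" comment in the signature above.
total_q = int(seqlens.sum())

# Cumulative sequence lengths (prefix sums with a leading 0), the usual
# "cu_seqlens" convention in varlen flash-attention front ends.
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
# cu_seqlens == tensor([0, 3, 10, 15]); sequence i occupies rows
# cu_seqlens[i]:cu_seqlens[i+1] along the packed total_q dimension.
```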
Fine for this PR but will we be adding some layout flag or defaulting to this for other layouts?
What do you mean by a layout flag? The current plan is for all inputs to be in BHSD layout.
If new layouts are needed, we will add new APIs instead of changing existing ones.
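For reference, a sketch (my own illustration, not code from this PR) of reordering a BSHD tensor into the BHSD layout the current plan assumes:

```python
import torch

batch, seqlen, num_heads, head_size = 2, 128, 8, 64

# BSHD: batch x seqlen x num_heads x head_size (a common framework layout)
q_bshd = torch.randn(batch, seqlen, num_heads, head_size)

# BHSD: batch x num_heads x seqlen x head_size (layout the API expects)
q_bhsd = q_bshd.transpose(1, 2).contiguous()
assert q_bhsd.shape == (batch, num_heads, seqlen, head_size)
```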
Varlen Flash Attention is implemented by two new APIs: `attn_fwd_compact_varlen` and `attn_bwd_compact_varlen`, with the same set of kernels. Check `include/aotriton/flash.h` for their details. Note: `b` is reserved; users should pass `TensorView<4>::get_null_tensor()` for this argument for now. Any other inputs are not supported nor tested. Use `torch.transpose` and `torch.unsqueeze` to match the API.
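As an illustration of the `torch.transpose`/`torch.unsqueeze` remark (a sketch under assumed shapes, not code from this PR): a caller holding already-packed tensors in `total_q x num_heads x head_size` order could match the `1 x num_heads x total_q x head_size` layout from the signature above like this:

```python
import torch

total_q, num_heads, head_size = 15, 8, 64

# Packed queries for all sequences, concatenated along the first dim.
q_packed = torch.randn(total_q, num_heads, head_size)

# Swap the sequence and head dims, then add the leading batch-of-1 dim
# to obtain 1 x num_heads x total_q x head_size.
q_api = torch.unsqueeze(torch.transpose(q_packed, 0, 1), 0)
assert q_api.shape == (1, num_heads, total_q, head_size)
```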