Add varlen support to AOTriton's Flash Attention #31

Merged · 16 commits merged into main on Jun 20, 2024

Conversation

@xinyazhang (Collaborator) commented on Jun 14, 2024

Varlen Flash Attention is implemented through two new APIs, attn_fwd_compact_varlen and attn_bwd_compact_varlen, backed by the same set of kernels. See include/aotriton/flash.h for details.

Note:

  1. The bias tensor input b is reserved. Users should pass TensorView<4>::get_null_tensor() for this argument for now; any other input is neither supported nor tested.
  2. The varlen API still expects Rank-4 tensors, which unifies the code between the varlen and non-varlen paths. The API also expects the size of the first dimension (batch) to be exactly 1; extra batches will not be processed. This interface differs slightly from Tri Dao's implementation (https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L500-L502), and users are expected to use torch.transpose and torch.unsqueeze to match the API (see the sketch after this list).
  3. This PR also refactors the Triton kernel and unifies the compiled data type of the sequence-length-related arguments to int32_t.
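
To make note 2 concrete, here is a minimal PyTorch sketch of the reshaping, assuming a packed input in the (total_q, num_heads, head_size) layout used by Tri Dao's varlen API; the tensor names and sizes are made up for illustration and are not part of this PR:

```python
import torch

# Hypothetical packed varlen input: total_q tokens across all sequences,
# laid out as (total_q, num_heads, head_size).
total_q, num_heads, head_size = 128, 8, 64
q_packed = torch.randn(total_q, num_heads, head_size, dtype=torch.float16)

# attn_fwd_compact_varlen expects a Rank-4 tensor of shape
# (1, num_heads, total_q, head_size) with the batch dimension exactly 1,
# so add a leading batch dimension and swap the token and head dimensions.
q_compact = q_packed.unsqueeze(0).transpose(1, 2)
assert q_compact.shape == (1, num_heads, total_q, head_size)

# k and v presumably follow the same pattern with total_k tokens.
```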

@wangye805 (Contributor) left a comment:

I took a quick look and realized your implementation should work for the qkv_layout of THD in NVTE.

What about the BSHD/BHSD layouts with a padding mask and cu_seqlen, as we discussed before?

@xinyazhang (Collaborator, Author) commented:

> What about the BSHD/BHSD layouts with a padding mask and cu_seqlen, as we discussed before?

It's in the Triton kernel but not yet exposed as a C++ API. The padded variants will be added in a separate PR as attn_?wd_padded_varlen.

@wangye805 (Contributor) commented:

> > What about the BSHD/BHSD layouts with a padding mask and cu_seqlen, as we discussed before?
>
> It's in the Triton kernel but not yet exposed as a C++ API. The padded variants will be added in a separate PR as attn_?wd_padded_varlen.

Do you mind pointing me to the Triton source code for padded_varlen?

@xinyazhang (Collaborator, Author) commented:

> Do you mind pointing me to the Triton source code for padded_varlen?

else:  # < 0 for padded seqlen
    cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
    cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
    seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
    if start_m * BLOCK_M >= seqlen_q:
        return
    cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
    cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
    seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
    # Varlen, but padded to Rank 4 tensor
    cu_seqlens_q_start = 0
    cu_seqlens_k_start = 0
    batch_index = off_z
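
For readers unfamiliar with the cu_seqlen convention this branch relies on, here is a small PyTorch sketch (the lengths below are made up): cu_seqlens is the exclusive prefix sum of the per-batch sequence lengths, so each length is recovered from two adjacent entries, which is what the pair of tl.load calls above does.

```python
import torch

# Hypothetical per-batch sequence lengths for a padded batch of 3 sequences.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seqlens is the exclusive prefix sum: [0, 5, 8, 15].
cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

# Recover the length of batch off_z from adjacent entries, mirroring
# tl.load(cu_seqlens_q + off_z) and tl.load(cu_seqlens_q + off_z + 1).
off_z = 1
seqlen_q = int(cu_seqlens[off_z + 1] - cu_seqlens[off_z])  # 3
```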

@wangye805 (Contributor) left a comment:

LGTM

@groenenboomj (Contributor) left a comment:

Looking to add support for more layouts later, but that will be a different piece of work. This looks to meet the needs for varlen.

@@ -31,6 +31,25 @@ attn_fwd(T4 q, // batch_size x num_heads x seqlen_q x head_size
         bool is_causal,
         aotriton::Stream stream);

hipError_t
attn_fwd_compact_varlen(T4 q, // 1 x num_heads x total_q x head_size, total_q := \sum_{i=0}^{b} s_i
Contributor commented:

Fine for this PR, but will we be adding some layout flag or defaulting to this for other layouts?

@xinyazhang (Collaborator, Author) replied on Jun 19, 2024:

What do you mean by a layout flag? The current plan is to keep all inputs in BHSD layout.
If new layouts are needed, we will add new APIs instead of changing the existing ones.

@xinyazhang merged commit 88eae51 into main on Jun 20, 2024.