-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add flash-attn #41
Add flash-attn #41
Conversation
Some tests with a 1B MQA model (santacoder's config: num_layers 24, num_heads 16, hidden_size 2048), bf16, on 1 A100 gpu. With flash-attn, this model can be trained with sequences of length up to Flash-attn is faster, especially for longer sequences:
|
# [sq, b, 1, hn] -> [sq, b, np, hn] | ||
key_layer = key_layer.expand((sq, b, np, hn)) | ||
value_layer = value_layer.expand((sq, b, np, hn)) | ||
q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That looks very bad. Megatron uses the s b
format precisely to avoid this kind of reshape. If FlashAttention uses b s
we should use that format instead. It should be OK to just comment the two conversions, at least without sequence parallelism (SP would need extra changes but we probably won't use it anyway) https://github.com/bigcode-project/Megatron-LM/blob/multi-query-attention/megatron/model/language_model.py#L240 https://github.com/bigcode-project/Megatron-LM/blob/multi-query-attention/megatron/model/gpt_model.py#L43
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are suggesting to use b s
through the whole transformer model?
I think that would require a big chunk of refactoring work, and also testing to make sure we are not breaking anything.
Looking at the nice performance improvements that flash-attn brings, I wouldn't take the risk of breaking everything else just to avoid a transpose here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the order only matters for attention (and sequence parallell), so it should just be about bypassing these two lines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The transposes have a big impact on memory usage and a moderate one on speed (I think) so it's quite important.
sq, b, np, hn = query_layer.size() | ||
# Expand kv to be compatible with flash-attn implementation | ||
# [sq, b, 1, hn] -> [sq, b, np, hn] | ||
key_layer = key_layer.expand((sq, b, np, hn)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if FlashAttention would work with just expand, that doesn't allocate new memory. If it were to work we would get the full benefits of FlashAttention for MQA. (I would expect it to enforce contiguous tensors but it's worth checking)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you asking whether if it would still work if we remove the call to .contiguous()
on the next line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would almost certainly not work (transposed tensors are much harder to deal with), but maybe if we do the expand after the transpose or skip the transpose altogether.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The remaining comments on eliminating unnecessary ops are not essential and can be looked into later.
Great job!
Flash-attention, based on NVIDIA#267
with support for MQA