[FA2] split-q + tiling-qk D=512 performance🎉 #179

Merged 5 commits on Dec 23, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -55,7 +55,7 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, these kernels can run faster than the official FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, these kernels can run faster than the official FA2/SDPA on some devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
```bash
2 changes: 1 addition & 1 deletion kernels/flash-attn/README.md
@@ -14,7 +14,7 @@

This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, these kernels can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION).
For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION).


- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
2 changes: 2 additions & 0 deletions kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu
@@ -0,0 +1,2 @@
// TODO: Manually apply SMEM swizzling instead of padding in
// Split-Q kernels to reduce bank conflicts.
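// Below is a minimal illustrative sketch (not part of this PR) of the kind of
// XOR-based swizzle the TODO above refers to, assuming a row-major half tile
// with 64 columns (128 bytes per row); the constants are assumptions, not the
// kernel's actual tile shape.
//
// Map (row, col) of a [rows][64] half tile to a swizzled linear SMEM offset.
// Each 8-half (16-byte) chunk index is XORed with the low 3 bits of the row,
// so the 8 rows touched by one ldmatrix load land in different bank groups
// instead of all hitting the same 4 banks (which padding otherwise avoids).
__device__ __forceinline__ int swizzle_offset_64x8(int row, int col) {
  int chunk          = col >> 3;            // 16-byte chunk within the row (0..7)
  int swizzled_chunk = chunk ^ (row & 7);   // permute chunks across rows
  return row * 64 + (swizzled_chunk << 3) + (col & 7);
}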
8 changes: 4 additions & 4 deletions kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu
@@ -42,10 +42,10 @@
// | warp_QP 6 | MMA 6 ... MMA 6 (x8) |
// | warp_QP 7 | MMA 7 ... MMA 7 (x8) |

// Fine grain tiling (MMA level) for Q, K the cause constant SRAM size 64*kMmaAtomK,
// and O(kMmaAtomK*d) SRAM complexity for V, thus, the SRAM complexity is O(kMmaAtomK*d).
// Thus, this kernel can extend D(headdim) to 1024. Performance is continuously being
// optimized. Stay tuned for updates ~
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance is continuously being
// optimized. Stay tuned for updates ~
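//
// A rough shared-memory sketch under illustrative assumptions (Br = Bc = 64
// rows/cols per block, kMmaAtomK = 16, half precision, ignoring multi-stage
// prefetch buffers):
//   SMEM(Q tile) = 64 * kMmaAtomK * sizeof(half) = 64 * 16 * 2 =  2 KB (independent of d)
//   SMEM(K tile) = 64 * kMmaAtomK * sizeof(half) = 64 * 16 * 2 =  2 KB (independent of d)
//   SMEM(V tile) = kMmaAtomK * d  * sizeof(half) = 16 * 512 * 2 = 16 KB for d = 512
// Only the V tile grows with d, giving the O(kMmaAtomK * d) total above and
// leaving room for d = 512 or even 1024 within the shared-memory budget.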

template<
const int kHeadDim, // Headdim, 32,64,128
16 changes: 10 additions & 6 deletions kernels/flash-attn/mma/flash_attn_mma_tiling_qkv.cu
@@ -1,6 +1,10 @@
// TODO: flash_attn_mma_stages_split_q_full_tiling_kernel
// fully tiling for headdim(d) while perform P@V, kMmaAtomK * (kMmaAtomN)
// NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV will increase with d.
// so, for large d, R_V will need more registers and cause performance down.
// We have to find a way to apply MMA level tiling for V(R_V) for large d.
// Also, R_O and R_D will bound by registers resources.
// TODO: Implement flash_attn_mma_stages_split_q_tiling_qkv_kernel
// Fully tile the head dimension (d) while performing P@V with dimensions kMmaAtomK * kMmaAtomN.
//
// NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV increases as d grows.
// As a result, for large values of d, R_V will require more registers, potentially
// leading to decreased performance. We need to find a way to apply MMA-level tiling
// for V (R_V) when d is large to mitigate this issue.
//
// Additionally, R_O and R_D are also constrained by register resources, which must
// be considered in the optimization process.
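//
// A rough per-thread register sketch under assumed fragment shapes: each
// fragment in R_V[kWarpTileHeadDimV][2] holds two 32-bit registers, so R_V
// alone costs 2 * kWarpTileHeadDimV registers per thread, and the R_O / R_D
// accumulators grow with kWarpTileHeadDimV in the same way. Since
// kWarpTileHeadDimV scales linearly with d, doubling d roughly doubles this
// register footprint and eventually limits occupancy.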