[FA2] split-q + tiling-qk D=512 performance🎉 (#179)
* Update README.md

* Update README.md

* Update flash_attn_mma_tiling_qk.cu

* Update flash_attn_mma_tiling_qkv.cu

* Update flash_attn_mma_swizzle_qkv.cu
DefTruth authored Dec 23, 2024
1 parent 0a74947 commit 75d3744
Showing 5 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -55,7 +55,7 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, this implementation can run faster than the official FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, this implementation can run faster than the official FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
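
For context on what the TFLOPS figures above mean in wall-clock terms, the sketch below applies the common forward-pass FLOP estimate for attention (roughly 4·B·H·N²·D, covering Q@Kᵀ and P@V with softmax ignored) to the B=1, H=8, N=8192, D=64 example. It is an illustrative calculation, not output from this commit's benchmarks.

```cuda
// Rough FLOP count for the example above, assuming forward-pass FLOPs
// ~= 4 * B * H * N^2 * D (Q@K^T plus P@V, softmax and masking ignored).
#include <cstdio>

int main() {
  constexpr double B = 1, H = 8, N = 8192, D = 64;
  constexpr double flops = 4.0 * B * H * N * N * D;   // ~1.37e11 FLOPs
  constexpr double tflops = 55.0;                     // the throughput quoted above
  const double ms = flops / (tflops * 1e12) * 1e3;    // ~2.5 ms per forward pass
  printf("FLOPs: %.3e -> ~%.2f ms at %.0f TFLOPS\n", flops, ms, tflops);
  return 0;
}
```
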
2 changes: 1 addition & 1 deletion kernels/flash-attn/README.md
@@ -14,7 +14,7 @@

This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, it can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION).
For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION).


- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
2 changes: 2 additions & 0 deletions kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu
@@ -0,0 +1,2 @@
// TODO: Manually apply SMEM swizzling instead of padding in
// Split-Q kernels to reduce bank conflicts.
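
As a rough illustration of the idea behind that TODO (not this repository's actual layout), a typical alternative to padding is to XOR the shared-memory column index with low bits of the row index, so that threads in a warp touch different banks. A minimal sketch, with the tile geometry assumed rather than taken from the kernel:

```cuda
// Hypothetical XOR-based swizzle (illustration only, not this kernel's actual
// layout). Offsets are in units of 16-byte (8 x half) vectors; with 8 vectors
// per 64-element row, XOR-ing the vector column with the low 3 bits of the row
// spreads accesses across all 32 banks without the SMEM cost of row padding.
#include <cstdio>

__host__ __device__ __forceinline__ int swizzle_offset(int row, int vec_col) {
  constexpr int kVecsPerRow = 8;                 // 64 halves / 8 halves per vector
  int swizzled_col = vec_col ^ (row & (kVecsPerRow - 1));
  return row * kVecsPerRow + swizzled_col;       // offset in 16-byte vectors
}

int main() {
  // Threads loading the same logical vector column from 8 consecutive rows
  // (the ldmatrix pattern) land in 8 different bank groups after swizzling.
  for (int row = 0; row < 8; ++row) {
    printf("row %d, vec_col 0 -> swizzled offset %d\n", row, swizzle_offset(row, 0));
  }
  return 0;
}
```
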
8 changes: 4 additions & 4 deletions kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu
@@ -42,10 +42,10 @@
// | warp_QP 6 | MMA 6 ... MMA 6 (x8) |
// | warp_QP 7 | MMA 7 ... MMA 7 (x8) |

// Fine grain tiling (MMA level) for Q, K the cause constant SRAM size 64*kMmaAtomK,
// and O(kMmaAtomK*d) SRAM complexity for V, thus, the SRAM complexity is O(kMmaAtomK*d).
// Thus, this kernel can extend D(headdim) to 1024. Performance is continuously being
// optimized. Stay tuned for updates ~
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance is continuously being optimized. Stay tuned for updates ~
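
To make the comment above concrete, here is a back-of-the-envelope SMEM budget for D = 512, assuming kMmaAtomK = 16 (the K of an m16n8k16 HMMA atom), 64-row Q/K tiles as the "64 * kMmaAtomK" above suggests, and fp16 storage; the exact figures in the kernel will differ once padding/swizzling and multi-stage buffering are counted.

```cuda
// Per-stage shared-memory estimate for the QK fine-grained tiling scheme,
// under assumed tile sizes (kMmaAtomK = 16, 64-row Q/K tiles, fp16 storage).
#include <cstdio>

int main() {
  constexpr int kMmaAtomK    = 16;    // assumed MMA atom K (m16n8k16)
  constexpr int kTileRows    = 64;    // Q/K rows per block, per the comment above
  constexpr int kHeadDim     = 512;   // D
  constexpr int kBytesPerElt = 2;     // half precision

  constexpr int q_smem = kTileRows * kMmaAtomK * kBytesPerElt;  //  2 KB, constant in D
  constexpr int k_smem = kTileRows * kMmaAtomK * kBytesPerElt;  //  2 KB, constant in D
  constexpr int v_smem = kMmaAtomK * kHeadDim  * kBytesPerElt;  // 16 KB, O(kMmaAtomK * d)
  printf("Q+K+V SMEM per stage: %d KB\n", (q_smem + k_smem + v_smem) / 1024);  // ~20 KB
  return 0;
}
```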

template<
const int kHeadDim, // Headdim, 32,64,128
16 changes: 10 additions & 6 deletions kernels/flash-attn/mma/flash_attn_mma_tiling_qkv.cu
@@ -1,6 +1,10 @@
// TODO: flash_attn_mma_stages_split_q_full_tiling_kernel
// fully tiling for headdim(d) while perform P@V, kMmaAtomK * (kMmaAtomN)
// NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV will increase with d.
// so, for large d, R_V will need more registers and cause performance down.
// We have to find a way to apply MMA level tiling for V(R_V) for large d.
// Also, R_O and R_D will bound by registers resources.
// TODO: Implement flash_attn_mma_stages_split_q_tiling_qkv_kernel
// Fully tile the head dimension (d) while performing P@V with dimensions kMmaAtomK * kMmaAtomN.
//
// NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV increases as d grows.
// As a result, for large values of d, R_V will require more registers, potentially
// leading to decreased performance. We need to find a way to apply MMA-level tiling
// for V (R_V) when d is large to mitigate this issue.
//
// Additionally, R_O and R_D are also constrained by register resources, which must
// be considered in the optimization process.
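
A rough sense of why large d is a problem for R_V, under an assumed mapping kWarpTileHeadDimV = d / kMmaAtomN with kMmaAtomN = 8 and each R_V element occupying one 32-bit register (two halves); this is not the kernel's exact layout, only an illustration of the scaling:

```cuda
// Estimated per-thread registers consumed by R_V[kWarpTileHeadDimV][2] as the
// head dimension grows, under the assumed mapping kWarpTileHeadDimV = d / 8.
// The hardware cap is 255 registers per thread, and R_O/R_D/addressing also
// need registers, so large d quickly forces spills or lower occupancy.
#include <cstdio>

int main() {
  constexpr int kMmaAtomN = 8;   // assumed MMA atom N
  const int head_dims[] = {64, 128, 256, 512};
  for (int d : head_dims) {
    int warp_tile_v = d / kMmaAtomN;   // kWarpTileHeadDimV (assumed)
    int rv_regs     = warp_tile_v * 2; // R_V[kWarpTileHeadDimV][2]
    printf("d=%4d -> ~%3d regs/thread for R_V alone\n", d, rv_regs);
  }
  return 0;
}
```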
