diff --git a/README.md b/README.md
index 2cef7adb..a9b1f75c 100644
--- a/README.md
+++ b/README.md
@@ -41,7 +41,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
 |✔️|✔️|✔️|✔️|
-I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, **Fully Shared QKV SMEM**, **Prefetch Q s2r**, Collective Store, etc. Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
+I have also implemented **FlashAttention-2** using pure MMA PTX instructions, which supports features such as Multi-Stages, Tile MMA, Tile Warp, Shared KV SMEM, **Fully Shared QKV SMEM**, **Prefetch Q s2r**, **Prefetch K/V g2s**, **QK Fine-grained Tiling**, Collective Store, etc. Please refer to [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for more details.
 ![flash-attn-mma](https://github.com/user-attachments/assets/6f66796d-44d5-4ec1-b224-af997bd152b2)
diff --git a/kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu b/kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu
index b57b24c5..0de6d278 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu
@@ -45,7 +45,8 @@
 // Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
 // 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
 // an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
-// extend D (head dimension) up to 1024. Performance is stay tuned for updates ~
+// extend D (head dimension) up to 1024. Performance optimizations are ongoing.
+// Stay tuned for updates ~
 template<
   const int kHeadDim,  // Headdim, 32,64,128
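
For context on the SRAM figures mentioned in the updated comment, here is a minimal, hypothetical sketch (not part of the PR) of the per-tile shared-memory arithmetic. It assumes fp16 elements, `kMmaAtomK = 16` (m16n8k16 MMA atom), and 64-row Q/K tiles; the struct and names are illustrative, not the kernel's actual layout.

```cuda
#include <cstdio>

// Illustrative per-tile shared-memory footprint, in bytes, for the QK
// fine-grained tiling scheme described in the comment above.
// Assumption: Q and K each keep a fixed 64 x kMmaAtomK fp16 tile.
template <int kHeadDim, int kMmaAtomK = 16>
struct SmemFootprint {
  static constexpr int kBytesPerHalf = 2;  // fp16 element size
  // Q/K tiles are constant with respect to kHeadDim.
  static constexpr int kQ = 64 * kMmaAtomK * kBytesPerHalf;
  static constexpr int kK = 64 * kMmaAtomK * kBytesPerHalf;
  // V tile is kMmaAtomK x kHeadDim: the only term that grows with d.
  static constexpr int kV = kMmaAtomK * kHeadDim * kBytesPerHalf;
  static constexpr int kTotal = kQ + kK + kV;  // O(kMmaAtomK * d) overall
};

int main() {
  // Even at d = 1024, a single-stage tile set stays far below typical
  // per-block shared-memory limits on recent GPUs.
  printf("d=64   : %d bytes\n", SmemFootprint<64>::kTotal);    // 2048 + 2048 + 2048  = 6144
  printf("d=128  : %d bytes\n", SmemFootprint<128>::kTotal);   // 2048 + 2048 + 4096  = 8192
  printf("d=1024 : %d bytes\n", SmemFootprint<1024>::kTotal);  // 2048 + 2048 + 32768 = 36864
  return 0;
}
```

This is only the single-stage, single-tile budget; multi-stage prefetching multiplies the K/V terms by the stage count, but the key point from the comment holds: Q/K SMEM stays constant while only the V term scales with the head dimension.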