[FA2] split-q + tiling-qk D=512 performance🎉 (#179)
* Update README.md
* Update README.md
* Update flash_attn_mma_tiling_qk.cu
* Update flash_attn_mma_tiling_qkv.cu
* Update flash_attn_mma_swizzle_qkv.cu
Showing 5 changed files with 18 additions and 12 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
+ // TODO: Manually apply SMEM swizzling instead of padding in | ||
+ // Split-Q kernels to reduce bank conflicts. |
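The TODO above contrasts padding with swizzling. As a rough illustration only, the following host-side C++ model (it also compiles as CUDA C++) counts how many of 8 consecutive rows land on the same shared-memory banks when each row reads one 16-byte chunk from the same logical column. The tile width (128 bytes per row), the 16-byte pad, and the XOR pattern `chunk ^ (row % 8)` are assumptions chosen for the sketch, not values taken from the kernels in this commit.

```cpp
#include <cassert>

// Hypothetical tile layout, not taken from the kernels in this commit.
constexpr int kBanks = 32;       // shared memory: 32 banks of 4 bytes each
constexpr int kChunkBytes = 16;  // one 16-byte chunk (8 halfs) per access
constexpr int kRowBytes = 128;   // 64 halfs per row, i.e. 8 chunks per row
constexpr int kPadBytes = 16;    // padding variant: 8 extra halfs per row

// First bank touched by a 16-byte-aligned chunk at this byte offset.
constexpr int first_bank(int byte_offset) { return (byte_offset / 4) % kBanks; }

// No padding, no swizzle: every row starts on the same bank.
constexpr int offset_plain(int row, int chunk) {
  return row * kRowBytes + chunk * kChunkBytes;
}

// Padding: each row is shifted by 16 bytes, at the cost of extra SMEM.
constexpr int offset_padded(int row, int chunk) {
  return row * (kRowBytes + kPadBytes) + chunk * kChunkBytes;
}

// XOR swizzle: permute the chunk index within the row, no extra SMEM.
constexpr int offset_swizzled(int row, int chunk) {
  return row * kRowBytes + (chunk ^ (row % 8)) * kChunkBytes;
}

// Largest number of rows (out of 8) whose chunk starts on the same bank.
template <typename OffsetFn>
int max_conflict(OffsetFn offset, int chunk) {
  assert(chunk >= 0 && chunk < kRowBytes / kChunkBytes);
  int hits[kBanks] = {0};
  int worst = 0;
  for (int row = 0; row < 8; ++row) {
    int bank = first_bank(offset(row, chunk));
    ++hits[bank];
    if (hits[bank] > worst) worst = hits[bank];
  }
  return worst;
}
```

In this model, `max_conflict(offset_plain, c)` reports an 8-way conflict for any column, while both the padded and the swizzled layouts report 1. The swizzled layout reaches that without the 16 extra bytes per row, which is the motivation the TODO points at.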
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
- // TODO: flash_attn_mma_stages_split_q_full_tiling_kernel | ||
- // fully tiling for headdim(d) while perform P@V, kMmaAtomK * (kMmaAtomN) | ||
- // NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV will increase with d. | ||
- // so, for large d, R_V will need more registers and cause performance down. | ||
- // We have to find a way to apply MMA level tiling for V(R_V) for large d. | ||
- // Also, R_O and R_D will bound by registers resources. | ||
+ // TODO: Implement flash_attn_mma_stages_split_q_tiling_qkv_kernel | ||
+ // Fully tile the head dimension (d) while performing P@V with dimensions kMmaAtomK * kMmaAtomN. | ||
+ // | ||
+ // NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV increases as d grows. | ||
+ // As a result, for large values of d, R_V will require more registers, potentially | ||
+ // leading to decreased performance. We need to find a way to apply MMA-level tiling | ||
+ // for V (R_V) when d is large to mitigate this issue. | ||
+ // | ||
+ // Additionally, R_O and R_D are also constrained by register resources, which must | ||
+ // be considered in the optimization process. |
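To make the NOTE concrete, here is a small back-of-the-envelope sketch of how the `R_V` register count could scale with `d`. The `[2]` per entry comes from the `R_V[kWarpTileHeadDimV][2]` declaration in the comment. The relation `kWarpTileHeadDimV = d / kMmaAtomN` with `kMmaAtomN = 8` is an assumption made for illustration and does not appear in this diff, and both function names are hypothetical.

```cpp
// Assumed for illustration: N extent of one MMA atom (e.g. m16n8k16).
constexpr int kMmaAtomN = 8;

// 32-bit registers per thread for R_V when the whole head dimension is
// held at once: R_V[kWarpTileHeadDimV][2], with the assumed relation
// kWarpTileHeadDimV = d / kMmaAtomN.
constexpr int r_v_regs_untiled(int d) {
  return (d / kMmaAtomN) * 2;
}

// With MMA-level tiling over d, only one [2]-register fragment of V
// needs to be live per MMA step, independent of d.
constexpr int r_v_regs_tiled() {
  return 2;
}

static_assert(r_v_regs_untiled(64) == 16, "d = 64");
static_assert(r_v_regs_untiled(512) == 128, "d = 512");
static_assert(r_v_regs_tiled() == 2, "constant in d");
```

Under these assumptions, `R_V` alone would take 128 registers per thread at d = 512, which is a large share of the 255-register per-thread limit before `R_O` and `R_D` are counted. That matches the concern in the NOTE, although the actual figures depend on the kernel's real tiling constants.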