diff --git a/xformers/csrc/attention/cuda/fmha/kernel_backward.h b/xformers/csrc/attention/cuda/fmha/kernel_backward.h
index 408054f4c4..1efdde75a2 100644
--- a/xformers/csrc/attention/cuda/fmha/kernel_backward.h
+++ b/xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -830,9 +830,9 @@ struct AttentionBackwardKernel {
   };
   static void print_size() {
     // Field size
-#define FSZ(f) int((sizeof(((SharedStorage*)0)->f)))
+#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f)))

-    printf("Total smem: %d bytes\n", int(sizeof(SharedStorage)));
+    printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue)));
     printf("  persistent: %db\n", FSZ(persistent));
     printf("    mm_qk_k: %db\n", FSZ(persistent.mm_qk_k));
     printf("  p1: %db\n", FSZ(p1));
@@ -968,8 +968,8 @@ struct AttentionBackwardKernel {
     } p6;
   };
   static void print_size() {
-#define FIELD_SIZEOF(f) int((sizeof(((SharedStorage*)0)->f)))
-    printf("Total smem: %d bytes\n", int(sizeof(SharedStorage)));
+#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f)))
+    printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue)));
     printf("  persistent: %db\n", FIELD_SIZEOF(persistent));
     printf("  p1: %db\n", FIELD_SIZEOF(p1));
     printf("  p2: %db\n", FIELD_SIZEOF(p2));
diff --git a/xformers/csrc/attention/cuda/fmha/kernels/cutlassB.h b/xformers/csrc/attention/cuda/fmha/kernels/cutlassB.h
index 464e15ff68..06159a973e 100644
--- a/xformers/csrc/attention/cuda/fmha/kernels/cutlassB.h
+++ b/xformers/csrc/attention/cuda/fmha/kernels/cutlassB.h
@@ -69,7 +69,7 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f16_sm50(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm50(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm50);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm50);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm50);
@@ -154,7 +154,7 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm50(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f32_sm50(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm50(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm50);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm50);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm50);
@@ -271,7 +271,7 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f16_sm70(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm70(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm70);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm70);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm70);
@@ -364,7 +364,7 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm70(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f32_sm70(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm70(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm70);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm70);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm70);
@@ -481,7 +481,7 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f16_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f16_sm75(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm75(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm75);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm75);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k128_sm75);
@@ -574,7 +574,7 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f32_notaligned_64x64_k65536_dropout_sm75(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f32_sm75(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm75(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm75);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm75);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k128_sm75);
@@ -602,6 +602,10 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_bf16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p);
+__global__ void __launch_bounds__(
+    AttentionBackwardKernel::kNumThreads,
+    AttentionBackwardKernel::kMinBlocksPerSm)
+fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p);
 __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
@@ -643,9 +647,10 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_bf16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_bf16_sm80(T cb) {
+template <typename T> void dispatch_cutlassB_bf16_sm80(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k32_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k64_sm80);
+    if (cc == 86 || cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k96_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x128_k128_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_64x64_k128_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_bf16_aligned_128x64_k65536_sm80);
@@ -667,6 +672,10 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f16_aligned_64x64_k64_sm80(typename AttentionBackwardKernel::Params p);
+__global__ void __launch_bounds__(
+    AttentionBackwardKernel::kNumThreads,
+    AttentionBackwardKernel::kMinBlocksPerSm)
+fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p);
 __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
@@ -708,9 +717,10 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f16_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f16_sm80(T cb) {
+template <typename T> void dispatch_cutlassB_f16_sm80(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k32_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k64_sm80);
+    if (cc == 86 || cc == 89) cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k96_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x128_k128_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_64x64_k128_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f16_aligned_128x64_k65536_sm80);
@@ -773,7 +783,7 @@ __global__ void __launch_bounds__(
     AttentionBackwardKernel::kNumThreads,
     AttentionBackwardKernel::kMinBlocksPerSm)
 fmha_cutlassB_f32_aligned_64x64_k65536_dropout_sm80(typename AttentionBackwardKernel::Params p);
-template <typename T> void dispatch_cutlassB_f32_sm80(T cb) {
+template <typename T> void dispatch_cutlassB_f32_sm80(T cb, int cc) {
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k32_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_64x64_k64_sm80);
     cb(AttentionBackwardKernel(), fmha_cutlassB_f32_aligned_128x64_k128_sm80);
@@ -793,31 +803,31 @@ template void dispatch_cutlassB(T cb, int cc = 0) {
   if (std::is_same::value && 50 <= cc && cc < 70) {
-    dispatch_cutlassB_f16_sm50(cb);
+    dispatch_cutlassB_f16_sm50(cb, cc);
   }
   if (std::is_same::value && 50 <= cc && cc < 70) {
-    dispatch_cutlassB_f32_sm50(cb);
+    dispatch_cutlassB_f32_sm50(cb, cc);
   }
   if (std::is_same::value && 70 <= cc && cc < 75) {
-    dispatch_cutlassB_f16_sm70(cb);
+    dispatch_cutlassB_f16_sm70(cb, cc);
   }
   if (std::is_same::value && 70 <= cc && cc < 75) {
-    dispatch_cutlassB_f32_sm70(cb);
+    dispatch_cutlassB_f32_sm70(cb, cc);
   }
   if (std::is_same::value && 75 <= cc && cc < 80) {
-    dispatch_cutlassB_f16_sm75(cb);
+    dispatch_cutlassB_f16_sm75(cb, cc);
   }
   if (std::is_same::value && 75 <= cc && cc < 80) {
-    dispatch_cutlassB_f32_sm75(cb);
+    dispatch_cutlassB_f32_sm75(cb, cc);
   }
   if (std::is_same::value && 80 <= cc && cc < 90) {
-    dispatch_cutlassB_bf16_sm80(cb);
+    dispatch_cutlassB_bf16_sm80(cb, cc);
   }
   if (std::is_same::value && 80 <= cc && cc < 90) {
-    dispatch_cutlassB_f16_sm80(cb);
+    dispatch_cutlassB_f16_sm80(cb, cc);
   }
   if (std::is_same::value && 80 <= cc && cc < 90) {
-    dispatch_cutlassB_f32_sm80(cb);
+    dispatch_cutlassB_f32_sm80(cb, cc);
   }
 }
 #endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
diff --git a/xformers/csrc/attention/cuda/fmha/kernels/cutlassB_bf16_aligned_k96.cu b/xformers/csrc/attention/cuda/fmha/kernels/cutlassB_bf16_aligned_k96.cu
new file mode 100644
index 0000000000..48376625bd
--- /dev/null
+++ b/xformers/csrc/attention/cuda/fmha/kernels/cutlassB_bf16_aligned_k96.cu
@@ -0,0 +1,24 @@
+#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
+// This file is auto-generated. See "generate_kernels.py"
+#include "../kernel_backward.h"
+
+__global__ void __launch_bounds__(
+    AttentionBackwardKernel::kNumThreads,
+    AttentionBackwardKernel::kMinBlocksPerSm)
+fmha_cutlassB_bf16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p) {
+#ifdef __CUDA_ARCH__
+#if __CUDA_ARCH__ >= 800
+#if __CUDA_ARCH__ < 900
+  if (!p.advance_to_block()) {
+    return;
+  }
+  AttentionBackwardKernel::attention_kernel(p);
+  return;
+#endif
+#endif
+  printf(
+      "FATAL: kernel `fmha_cutlassB_bf16_aligned_128x64_k96_sm80` is for sm80-sm90, but was built for sm%d\n",
+      int(__CUDA_ARCH__ + 0) / 10);
+#endif
+}
+#endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
diff --git a/xformers/csrc/attention/cuda/fmha/kernels/cutlassB_f16_aligned_k96.cu b/xformers/csrc/attention/cuda/fmha/kernels/cutlassB_f16_aligned_k96.cu
new file mode 100644
index 0000000000..385a57825d
--- /dev/null
+++ b/xformers/csrc/attention/cuda/fmha/kernels/cutlassB_f16_aligned_k96.cu
@@ -0,0 +1,24 @@
+#ifndef XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD
+// This file is auto-generated. See "generate_kernels.py"
See "generate_kernels.py" +#include "../kernel_backward.h" + +__global__ void __launch_bounds__( + AttentionBackwardKernel::kNumThreads, + AttentionBackwardKernel::kMinBlocksPerSm) +fmha_cutlassB_f16_aligned_128x64_k96_sm80(typename AttentionBackwardKernel::Params p) { +#ifdef __CUDA_ARCH__ +#if __CUDA_ARCH__ >= 800 +#if __CUDA_ARCH__ < 900 + if (!p.advance_to_block()) { + return; + } + AttentionBackwardKernel::attention_kernel(p); + return; +#endif +#endif + printf( + "FATAL: kernel `fmha_cutlassB_f16_aligned_128x64_k96_sm80` is for sm80-sm90, but was built for sm%d\n", + int(__CUDA_ARCH__ + 0) / 10); +#endif +} +#endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_BACKWARD diff --git a/xformers/csrc/attention/cuda/fmha/kernels/cutlassF.h b/xformers/csrc/attention/cuda/fmha/kernels/cutlassF.h index 0813fede5b..63c87d06e1 100644 --- a/xformers/csrc/attention/cuda/fmha/kernels/cutlassF.h +++ b/xformers/csrc/attention/cuda/fmha/kernels/cutlassF.h @@ -17,7 +17,7 @@ __global__ void __launch_bounds__( AttentionKernel::kMinBlocksPerSm) fmha_cutlassF_bf16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); -template void dispatch_cutlassF_bf16_sm80(T cb) { +template void dispatch_cutlassF_bf16_sm80(T cb, int cc) { cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_64x64_rf_sm80); cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_32x128_rf_sm80); cb(AttentionKernel(), fmha_cutlassF_bf16_aligned_32x128_gmem_sm80); @@ -49,7 +49,7 @@ __global__ void __launch_bounds__( AttentionKernel::kMinBlocksPerSm) fmha_cutlassF_f16_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); -template void dispatch_cutlassF_f16_sm50(T cb) { +template void dispatch_cutlassF_f16_sm50(T cb, int cc) { cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm50); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm50); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm50); @@ -84,7 +84,7 @@ __global__ void __launch_bounds__( AttentionKernel::kMinBlocksPerSm) fmha_cutlassF_f16_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p); -template void dispatch_cutlassF_f16_sm70(T cb) { +template void dispatch_cutlassF_f16_sm70(T cb, int cc) { cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm70); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm70); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm70); @@ -119,7 +119,7 @@ __global__ void __launch_bounds__( AttentionKernel::kMinBlocksPerSm) fmha_cutlassF_f16_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p); -template void dispatch_cutlassF_f16_sm75(T cb) { +template void dispatch_cutlassF_f16_sm75(T cb, int cc) { cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm75); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm75); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm75); @@ -142,7 +142,7 @@ __global__ void __launch_bounds__( AttentionKernel::kMinBlocksPerSm) fmha_cutlassF_f16_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p); -template void dispatch_cutlassF_f16_sm80(T cb) { +template void dispatch_cutlassF_f16_sm80(T cb, int cc) { cb(AttentionKernel(), fmha_cutlassF_f16_aligned_64x64_rf_sm80); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_rf_sm80); cb(AttentionKernel(), fmha_cutlassF_f16_aligned_32x128_gmem_sm80); @@ -174,7 +174,7 @@ __global__ void __launch_bounds__( AttentionKernel::kMinBlocksPerSm) fmha_cutlassF_f32_notaligned_32x128_gmem_sm50(typename AttentionKernel::Params p); -template void dispatch_cutlassF_f32_sm50(T cb) { 
+template <typename T> void dispatch_cutlassF_f32_sm50(T cb, int cc) {
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm50);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm50);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm50);
@@ -209,7 +209,7 @@ __global__ void __launch_bounds__(
     AttentionKernel::kNumThreads,
     AttentionKernel::kMinBlocksPerSm)
 fmha_cutlassF_f32_notaligned_32x128_gmem_sm70(typename AttentionKernel::Params p);
-template <typename T> void dispatch_cutlassF_f32_sm70(T cb) {
+template <typename T> void dispatch_cutlassF_f32_sm70(T cb, int cc) {
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm70);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm70);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm70);
@@ -244,7 +244,7 @@ __global__ void __launch_bounds__(
     AttentionKernel::kNumThreads,
     AttentionKernel::kMinBlocksPerSm)
 fmha_cutlassF_f32_notaligned_32x128_gmem_sm75(typename AttentionKernel::Params p);
-template <typename T> void dispatch_cutlassF_f32_sm75(T cb) {
+template <typename T> void dispatch_cutlassF_f32_sm75(T cb, int cc) {
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm75);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm75);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm75);
@@ -267,7 +267,7 @@ __global__ void __launch_bounds__(
     AttentionKernel::kNumThreads,
     AttentionKernel::kMinBlocksPerSm)
 fmha_cutlassF_f32_aligned_32x128_gmem_sm80(typename AttentionKernel::Params p);
-template <typename T> void dispatch_cutlassF_f32_sm80(T cb) {
+template <typename T> void dispatch_cutlassF_f32_sm80(T cb, int cc) {
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_64x64_rf_sm80);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_rf_sm80);
     cb(AttentionKernel(), fmha_cutlassF_f32_aligned_32x128_gmem_sm80);
@@ -278,31 +278,31 @@ template void dispatch_cutlassF(T cb, int cc = 0) {
   if (std::is_same::value && 80 <= cc && cc < 90) {
-    dispatch_cutlassF_bf16_sm80(cb);
+    dispatch_cutlassF_bf16_sm80(cb, cc);
   }
   if (std::is_same::value && 50 <= cc && cc < 70) {
-    dispatch_cutlassF_f16_sm50(cb);
+    dispatch_cutlassF_f16_sm50(cb, cc);
   }
   if (std::is_same::value && 70 <= cc && cc < 75) {
-    dispatch_cutlassF_f16_sm70(cb);
+    dispatch_cutlassF_f16_sm70(cb, cc);
   }
   if (std::is_same::value && 75 <= cc && cc < 80) {
-    dispatch_cutlassF_f16_sm75(cb);
+    dispatch_cutlassF_f16_sm75(cb, cc);
   }
   if (std::is_same::value && 80 <= cc && cc < 90) {
-    dispatch_cutlassF_f16_sm80(cb);
+    dispatch_cutlassF_f16_sm80(cb, cc);
   }
   if (std::is_same::value && 50 <= cc && cc < 70) {
-    dispatch_cutlassF_f32_sm50(cb);
+    dispatch_cutlassF_f32_sm50(cb, cc);
   }
   if (std::is_same::value && 70 <= cc && cc < 75) {
-    dispatch_cutlassF_f32_sm70(cb);
+    dispatch_cutlassF_f32_sm70(cb, cc);
   }
   if (std::is_same::value && 75 <= cc && cc < 80) {
-    dispatch_cutlassF_f32_sm75(cb);
+    dispatch_cutlassF_f32_sm75(cb, cc);
   }
   if (std::is_same::value && 80 <= cc && cc < 90) {
-    dispatch_cutlassF_f32_sm80(cb);
+    dispatch_cutlassF_f32_sm80(cb, cc);
   }
 }
 #endif // XFORMERS_MEM_EFF_ATTENTION_DISABLE_FORWARD
diff --git a/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.py b/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.py
index 7593077239..6c832eb2b7 100644
--- a/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.py
+++ b/xformers/csrc/attention/cuda/fmha/kernels/generate_kernels.py
@@ -12,7 +12,7 @@
 import itertools
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Dict, List, Tuple, TypeVar
+from typing import Dict, List, Optional, Tuple, TypeVar

 DTYPES = {
     "f32": "float",
@@ -49,13 +49,13 @@ class FwdKernel:
     sort_index: Tuple[int, ...] = field(init=False, repr=False)
     aligned: bool
     dtype: str
-    sm: int
-    sm_max: int
+    sm_range: Tuple[int, int]
     q: int
     k: int
     single_value_iter: bool
     supports_dropout: bool = True
     supports_bias: bool = True
+    dispatch_cond: Optional[str] = None

     def __post_init__(self) -> None:
         # Set kernel selection priority
@@ -79,14 +79,14 @@ def _aligned_suffix(self) -> str:
     @property
     def name(self) -> str:
         acc = "rf" if self.single_value_iter else "gmem"
-        return f"fmha_cutlassF_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{acc}_sm{self.sm}"
+        return f"fmha_cutlassF_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{acc}_sm{self.sm_range[0]}"

     @property
     def cpp_class(self) -> str:
         template_args = ", ".join(
             [
                 DTYPES[self.dtype],
-                f"cutlass::arch::Sm{self.sm}",
+                f"cutlass::arch::Sm{self.sm_range[0]}",
                 "true" if self.aligned else "false",
                 str(self.q),
                 str(self.k),
@@ -107,8 +107,8 @@ def cpp_impl(self) -> str:
         return KERNEL_IMPL_TEMPLATE.format(
             CPP_CLASS=self.cpp_class,
             NAME=self.name,
-            SM=self.sm,
-            SM_MAX=self.sm_max,
+            SM=self.sm_range[0],
+            SM_MAX=self.sm_range[1],
         )

     @classmethod
@@ -131,8 +131,7 @@ def get_all(cls) -> List["FwdKernel"]:
                     cls(
                         aligned=aligned,
                         dtype=dtype,
-                        sm=sm,
-                        sm_max=sm_max,
+                        sm_range=(sm, sm_max),
                         q=q,
                         k=k,
                         single_value_iter=single_value_iter,
@@ -144,8 +143,7 @@
 @dataclass(order=True)
 class BwdKernel:
     sort_index: Tuple[int, ...] = field(init=False, repr=False)
-    sm: int
-    sm_max: int
+    sm_range: Tuple[int, int]
     dtype: str
     aligned: bool
     apply_dropout: bool
@@ -153,6 +151,7 @@ class BwdKernel:
     block_i: int
     block_j: int
     max_k: int
+    dispatch_cond: Optional[str] = None

     def __post_init__(self) -> None:
         # Set kernel selection priority
@@ -178,14 +177,14 @@ def name(self) -> str:
         dropout_suffix = "_dropout" if self.apply_dropout else ""
         return (
             f"fmha_cutlassB_{self.dtype}_{self._aligned_suffix}"
-            f"_{self.block_i}x{self.block_j}_k{self.max_k}{dropout_suffix}_sm{self.sm}"
+            f"_{self.block_i}x{self.block_j}_k{self.max_k}{dropout_suffix}_sm{self.sm_range[0]}"
         )

     @property
     def cpp_class(self) -> str:
         template_args = ", ".join(
             [
-                f"cutlass::arch::Sm{self.sm}",
+                f"cutlass::arch::Sm{self.sm_range[0]}",
                 DTYPES[self.dtype],
                 "true" if self.aligned else "false",
                 "true" if self.apply_dropout else "false",
@@ -208,8 +207,8 @@ def cpp_impl(self) -> str:
         return KERNEL_IMPL_TEMPLATE.format(
             CPP_CLASS=self.cpp_class,
             NAME=self.name,
-            SM=self.sm,
-            SM_MAX=self.sm_max,
+            SM=self.sm_range[0],
+            SM_MAX=self.sm_range[1],
         )

     @classmethod
@@ -243,8 +242,7 @@ def get_all(cls) -> List["BwdKernel"]:
                     cls(
                         aligned=aligned,
                         dtype=dtype,
-                        sm=sm,
-                        sm_max=sm_max,
+                        sm_range=(sm, sm_max),
                         apply_dropout=apply_dropout,
                         preload_mmas=preload_mmas,
                         block_i=bi,
@@ -252,6 +250,24 @@
                         max_k=max_k,
                     )
                 )
+        # Add some specialized kernels for stable diffusion BW (K=80)
+        # This is the only kernel that can keep the outputs on RF on
+        # Sm86/Sm89, so it's much faster than the 64x64 one
+        for dtype in ["f16", "bf16"]:
+            kernels.append(
+                cls(
+                    aligned=True,
+                    dtype=dtype,
+                    sm_range=(80, 90),
+                    apply_dropout=False,
+                    preload_mmas=True,
+                    block_i=128,
+                    block_j=64,
+                    max_k=96,
+                    # Sm80 has a faster kernel for this case
+                    dispatch_cond="cc == 86 || cc == 89",
+                )
+            )
         return kernels
@@ -278,7 +294,7 @@ def write_decl_impl(
     # Declaration of kernel functions
     for k in kernels:
         implfile_to_kernels[k.impl_group].append(k)
-        cat_to_kernels[(k.dtype, k.sm, k.sm_max)].append(k)
+        cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k)

     for (cat_dt, cat_sm, cat_sm_max), kernels in cat_to_kernels.items():
         declarations += f"// ======== {cat_dt} / sm{cat_sm} ========\n"
@@ -287,16 +303,17 @@
         )
         dispatch_category_fn = f"dispatch_{family_name}_{cat_dt}_sm{cat_sm}"
         declarations += (
-            f"\n\ntemplate <typename T> void {dispatch_category_fn}(T cb) {{\n"
-        )
-        declarations += "\n".join(
-            f"    cb({k.cpp_class}(), {k.name});" for k in kernels
+            f"\n\ntemplate <typename T> void {dispatch_category_fn}(T cb, int cc) {{\n"
         )
-        declarations += "\n}\n"
-        declarations += "\n"
+        for k in kernels:
+            _call = f"cb({k.cpp_class}(), {k.name});\n"
+            if k.dispatch_cond is not None:
+                _call = f"if ({k.dispatch_cond}) {_call}"
+            declarations += f"    {_call}"
+        declarations += "}\n\n"
         dispatch_all += f"""
   if (std::is_same::value && {cat_sm} <= cc && cc < {cat_sm_max}) {{
-    {dispatch_category_fn}(cb);
+    {dispatch_category_fn}(cb, cc);
   }}"""

     declarations += f"""
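For illustration, here is a minimal sketch (not part of the patch) of how a caller might drive the regenerated backward dispatcher, which now receives the compute capability so that the cc == 86 || cc == 89 guard above enables the new 128x64 k=96 kernels only on Sm86/Sm89. The dtype template argument, the include path, and the callback body are assumptions made for this sketch; only the (cb, cc) dispatcher shape comes from the generated headers in this diff.

// Sketch only (compile with nvcc): enumerate the f16 backward kernels offered
// on a given GPU. Assumes dispatch_cutlassB's first template parameter selects
// the dtype, as the std::is_same checks in the generated header suggest.
#include <cstdio>
#include "xformers/csrc/attention/cuda/fmha/kernels/cutlassB.h"

void list_f16_backward_kernels(int compute_capability) {
  auto cb = [](auto kernel, auto kernel_fn) {
    using Kernel = decltype(kernel);
    (void)kernel_fn; // the __global__ entry point; unused in this sketch
    std::printf("kernel: %d threads, %d min blocks per SM\n",
                int(Kernel::kNumThreads), int(Kernel::kMinBlocksPerSm));
  };
  // With compute_capability = 86 or 89 the sm80 dispatcher also offers
  // fmha_cutlassB_f16_aligned_128x64_k96_sm80; with 80 it does not.
  dispatch_cutlassB<cutlass::half_t>(cb, compute_capability);
}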