@@ -6,7 +6,7 @@
     CpuAdamX86Extension,
     FlashAttentionDaoCudaExtension,
     FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FusedOptimizerCudaExtension,
     LayerNormCudaExtension,
     MoeCudaExtension,
@@ -65,9 +65,9 @@ def load(self, ext_name: str = None):
         else:
             usable_exts = []
             for ext in exts:
-                if ext.is_hardware_available():
+                if ext.is_available():
                     # make sure the machine is compatible during kernel loading
-                    ext.assert_hardware_compatible()
+                    ext.assert_compatible()
                     usable_exts.append(ext)
 
         assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
 
 
 class FlashAttentionLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+    REGISTRY = [
+        FlashAttentionNpuExtension,
+        FlashAttentionDaoCudaExtension,
+        FlashAttentionSdpaCudaExtension,
+    ]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionSdpaCudaExtension]
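
For context, a minimal usage sketch of the refactored loaders. Only the loader classes, their registries, and the `load()` selection loop shown in the hunks above come from this diff; the import path `colossalai.kernel.kernel_loader` and the `pick_flash_attention_kernel` helper (with its `mask_type` parameter) are assumptions added purely to illustrate why separate per-mask loaders exist.

```python
# Hedged sketch: the helper below is hypothetical; only the loader classes
# and the KernelLoader.load() selection loop appear in this diff.
from typing import Optional

from colossalai.kernel.kernel_loader import (  # assumed import path
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
    FlashAttentionWithPaddingMaskLoader,
)


def pick_flash_attention_kernel(mask_type: Optional[str] = None):
    """Choose a loader by mask type, then let KernelLoader.load() walk its
    REGISTRY and pick the first extension that reports is_available() and
    passes assert_compatible() on the current machine (per the hunk above)."""
    if mask_type == "padding":
        # NPU or Dao CUDA kernels; the SDPA kernel is excluded here.
        loader = FlashAttentionWithPaddingMaskLoader()
    elif mask_type == "custom":
        # NPU or SDPA CUDA kernels; the Dao kernel is excluded here.
        loader = FlashAttentionWithCustomMaskLoader()
    else:
        loader = FlashAttentionLoader()
    return loader.load()  # raises AssertionError if no kernel is usable
```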