Commit 1155423

[feature] refactor colo attention (#5462)

* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test

1 parent f2e8b9e commit 1155423

16 files changed: +534 -359 lines

colossalai/kernel/kernel_loader.py (+12 -4)

@@ -6,7 +6,7 @@
     CpuAdamX86Extension,
     FlashAttentionDaoCudaExtension,
     FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FusedOptimizerCudaExtension,
     LayerNormCudaExtension,
     MoeCudaExtension,
@@ -65,9 +65,9 @@ def load(self, ext_name: str = None):
         else:
             usable_exts = []
             for ext in exts:
-                if ext.is_hardware_available():
+                if ext.is_available():
                     # make sure the machine is compatible during kernel loading
-                    ext.assert_hardware_compatible()
+                    ext.assert_compatible()
                     usable_exts.append(ext)
 
         assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
@@ -106,4 +106,12 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
 
 
 class FlashAttentionLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
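
As a quick orientation (not part of the diff itself): load() keeps the registered extensions whose is_available() returns True, calls assert_compatible() on each while loading, and returns a concrete attention function. A minimal sketch of using the new loaders, assuming at least one registered backend (NPU, Dao flash-attn, or torch SDPA) is installed on the machine; otherwise the assertion shown in the hunk above fires.

from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithPaddingMaskLoader

# any available backend: NPU, Dao flash-attn, or torch SDPA
flash_attn_func = FlashAttentionLoader().load()

# the padding-mask registry deliberately leaves out SDPA (see attn.py below)
flash_attn_with_padding_mask_func = FlashAttentionWithPaddingMaskLoader().load()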

colossalai/shardformer/layer/__init__.py (+3)

@@ -1,3 +1,4 @@
+from .attn import AttnMaskType, ColoAttention
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
@@ -23,4 +24,6 @@
     "FusedRMSNorm",
     "FusedLinear1D_Col",
     "ParallelModule",
+    "AttnMaskType",
+    "ColoAttention",
 ]

colossalai/shardformer/layer/attn.py (+234)

@@ -0,0 +1,234 @@
+from enum import Enum
+from typing import Callable, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from colossalai.kernel.kernel_loader import (
+    FlashAttentionLoader,
+    FlashAttentionWithCustomMaskLoader,
+    FlashAttentionWithPaddingMaskLoader,
+)
+
+__all__ = [
+    "AttnMaskType",
+    "ColoAttention",
+]
+
+
+class AttnMaskType(Enum):
+    CUSTOM = 0
+    PADDED = 1
+    CAUSAL = 2
+    PADDED_CAUSAL = 3
+
+
+def invert_mask(mask: torch.Tensor) -> torch.Tensor:
+    """Invert the mask tensor.
+
+    Args:
+        mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv]
+
+    Returns:
+        torch.Tensor: Inverted mask tensor.
+    """
+    inverted_mask = 1.0 - mask
+    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)
+
+
+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
+    """Get padding information from a padding mask.
+
+    Args:
+        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
+
+    Returns:
+        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
+    """
+    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return max_seqlen_in_batch, cu_seqlens, indices
+
+
+class ColoAttention:
+    # these attrs are initialized in the first call of the attention() method
+    _flash_attn_func: Optional[Callable] = None
+    _flash_attn_with_custom_mask_func: Optional[Callable] = None
+    _flash_attn_with_padding_mask_func: Optional[Callable] = None
+
+    @staticmethod
+    def _init_flash_attn_func():
+        if ColoAttention._flash_attn_func is None:
+            ColoAttention._flash_attn_func = FlashAttentionLoader().load()
+        if ColoAttention._flash_attn_with_custom_mask_func is None:
+            ColoAttention._flash_attn_with_custom_mask_func = FlashAttentionWithCustomMaskLoader().load()
+        if ColoAttention._flash_attn_with_padding_mask_func is None:
+            ColoAttention._flash_attn_with_padding_mask_func = FlashAttentionWithPaddingMaskLoader().load()
+
+    @staticmethod
+    def prepare_attn_kwargs(
+        shape_4d: Tuple[int],
+        dtype: torch.dtype,
+        device: torch.device,
+        q_padding_mask: Optional[torch.Tensor] = None,
+        kv_padding_mask: Optional[torch.Tensor] = None,
+        is_causal: bool = False,
+    ) -> Dict[str, torch.Tensor]:
+        """Return a dictionary of keyword arguments for the attention function. It supports 4 mask types.
+        1. custom mask: no padding mask and is_causal=False, returns {}, users should handle the attention mask themselves.
+        2. padded mask: receives a padding mask and is_causal=False, returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+        3. causal mask: no padding mask and is_causal=True, returns {attention_mask, attention_mask_type}.
+        4. padded causal mask: receives a padding mask and is_causal=True, returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+
+        Args:
+            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
+            dtype (torch.dtype): Dtype of the attention mask, generally should be ``hidden_states.dtype``
+            device (torch.device): Device of the attention mask, generally should be ``hidden_states.device``
+            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of the query. It should be a long or int tensor.
+                The shape should be [B, Sq]. ``1`` means a valid token, and ``0`` means a padding token. Defaults to None.
+            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of the key and value. It should be a long or int tensor.
+                The shape should be [B, Skv]. ``1`` means a valid token, and ``0`` means a padding token.
+                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
+            is_causal (bool, optional): Whether to use a causal attention mask. Defaults to False.
+
+        Returns:
+            Dict[str, torch.Tensor]: Dictionary of keyword arguments for the attention function.
+        """
+        if q_padding_mask is None and not is_causal:
+            return {}
+        assert len(shape_4d) == 4 and shape_4d[1] == 1
+        b, _, s_q, s_kv = shape_4d
+        outputs = {}
+        if q_padding_mask is not None:
+            if kv_padding_mask is None:
+                kv_padding_mask = q_padding_mask
+            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (b, s_kv)
+            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
+            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
+            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            outputs.update(
+                {
+                    "cu_seqlens_q": cu_seqlens_q,
+                    "cu_seqlens_kv": cu_seqlens_kv,
+                    "max_seqlen_q": max_seqlen_q,
+                    "max_seqlen_kv": max_seqlen_kv,
+                    "q_indices": q_indices,
+                    "kv_indices": kv_indices,
+                }
+            )
+            if is_causal:
+                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
+                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+            else:
+                outputs["attention_mask_type"] = AttnMaskType.PADDED
+        else:
+            assert is_causal
+            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
+            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+        attention_mask = invert_mask(attention_mask).unsqueeze(1)
+        outputs["attention_mask"] = attention_mask
+        return outputs
+
+    @staticmethod
+    def attention(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
+        q_indices: Optional[torch.Tensor] = None,
+        kv_indices: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        scale: Optional[float] = None,
+    ) -> torch.Tensor:
+        """Flash Attention function. It supports 4 mask types.
+        1. custom mask: receives attention_mask
+        2. padded mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
+        3. causal mask: receives attention_mask, attention_mask_type
+        4. padded causal mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
+
+        Args:
+            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
+            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
+            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
+            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
+            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
+            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
+                of the sequences in the batch, used to index into q.
+                Shape should be [B+1]. Defaults to None.
+            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
+                of the sequences in the batch, used to index into kv.
+                Shape should be [B+1]. Defaults to None.
+            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
+            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
+            q_indices, kv_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened query and key/value sequences.
+                Shape should be [NUM_TOKENS]. Defaults to None.
+            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
+            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.
+
+        Returns:
+            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
+        """
+        ColoAttention._init_flash_attn_func()
+        # known issue: sdpa does not support an attention mask that contains a whole row of masked tokens, which leads to nan
+        # this case is usual when a padding mask is used and self attention is performed
+        # thus, we don't use sdpa when a padding mask is used
+        # sanity check
+        if attention_mask is not None:
+            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
+            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
+                assert (
+                    cu_seqlens_q is None
+                    and cu_seqlens_kv is None
+                    and max_seqlen_q is None
+                    and max_seqlen_kv is None
+                    and q_indices is None
+                    and kv_indices is None
+                )
+                if attention_mask_type == AttnMaskType.CUSTOM:
+                    assert not torch.all(attention_mask != 0, dim=-1).any()
+            elif attention_mask_type in (AttnMaskType.PADDED, AttnMaskType.PADDED_CAUSAL):
+                assert (
+                    cu_seqlens_q is not None
+                    and cu_seqlens_kv is not None
+                    and max_seqlen_q is not None
+                    and max_seqlen_kv is not None
+                    and q_indices is not None
+                    and kv_indices is not None
+                )
+        else:
+            # if attention_mask is None, attention_mask_type should be the default value
+            assert attention_mask_type == AttnMaskType.CUSTOM
+        # kernel dispatch
+        if attention_mask is not None and attention_mask_type == AttnMaskType.CUSTOM:
+            attn_func = ColoAttention._flash_attn_with_custom_mask_func
+        elif attention_mask_type in (AttnMaskType.PADDED, AttnMaskType.PADDED_CAUSAL):
+            attn_func = ColoAttention._flash_attn_with_padding_mask_func
+        else:
+            attn_func = ColoAttention._flash_attn_func
+        is_causal = attention_mask is not None and attention_mask_type in (
+            AttnMaskType.CAUSAL,
+            AttnMaskType.PADDED_CAUSAL,
+        )
+        return attn_func(
+            q,
+            k,
+            v,
+            dropout_p=dropout_p,
+            scale=scale,
+            attention_mask=attention_mask,
+            is_causal=is_causal,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
+            q_indices=q_indices,
+            kv_indices=kv_indices,
+        )
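
A minimal usage sketch of the new ColoAttention API introduced above; the tensor sizes and random inputs are illustrative only, and running it assumes a CUDA device plus at least one usable backend for the padded path (Dao flash-attn or the NPU kernel, since SDPA is excluded from the padding-mask registry).

import torch

from colossalai.shardformer.layer import AttnMaskType, ColoAttention

# illustrative sizes: batch 2, 8 heads, seq len 16, head dim 64
b, n, s, d = 2, 8, 16, 64
dtype, device = torch.float16, "cuda"

q = torch.rand(b, n, s, d, dtype=dtype, device=device)
k = torch.rand(b, n, s, d, dtype=dtype, device=device)
v = torch.rand(b, n, s, d, dtype=dtype, device=device)

# padding mask: 1 = valid token, 0 = padding; here the second sample has 4 padded positions
padding_mask = torch.ones(b, s, dtype=torch.int32, device=device)
padding_mask[1, -4:] = 0

# builds attention_mask, attention_mask_type, cu_seqlens_*, max_seqlen_* and *_indices in one call
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (b, 1, s, s), dtype=dtype, device=device, q_padding_mask=padding_mask, is_causal=True
)
assert attn_kwargs["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL

out = ColoAttention.attention(q, k, v, **attn_kwargs)  # output shape [B, N, Sq, D]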

extensions/README.md (+2 -2)

@@ -101,13 +101,13 @@ class MyExtension(_Extension):
         self._support_jit = True
         self.priority = 10
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         """
         Return if the required hardware can be found.
         """
         ...
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         """
         Check if the hardware required by the kernel is compatible.
         """

extensions/__init__.py (+3 -7)

@@ -1,9 +1,5 @@
 from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
-from .flash_attention import (
-    FlashAttentionDaoCudaExtension,
-    FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
-)
+from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension
 from .layernorm import LayerNormCudaExtension
 from .moe import MoeCudaExtension
 from .optimizer import FusedOptimizerCudaExtension
@@ -18,7 +14,7 @@
     ScaledMaskedSoftmaxCudaExtension,
     ScaledUpperTriangleMaskedSoftmaxCudaExtension,
     FlashAttentionDaoCudaExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FlashAttentionNpuExtension,
 ]
 
@@ -31,6 +27,6 @@
     "ScaledMaskedSoftmaxCudaExtension",
     "ScaledUpperTriangleMaskedSoftmaxCudaExtension",
     "FlashAttentionDaoCudaExtension",
-    "FlashAttentionXformersCudaExtension",
+    "FlashAttentionSdpaCudaExtension",
     "FlashAttentionNpuExtension",
 ]

extensions/base_extension.py (+2 -2)

@@ -58,13 +58,13 @@ def get_jit_extension_folder_path():
         return cache_directory
 
     @abstractmethod
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
        """
        Check if the hardware required by the kernel is available.
        """
 
    @abstractmethod
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
        """
        Check if the hardware required by the kernel is compatible.
        """

extensions/cpu_adam/cpu_adam_arm.py (+2 -2)

@@ -7,11 +7,11 @@ class CpuAdamArmExtension(_CppExtension):
     def __init__(self):
         super().__init__(name="cpu_adam_arm")
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         # only arm allowed
         return platform.machine() == "aarch64"
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         arch = platform.machine()
         assert (
             arch == "aarch64"
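
For orientation, a minimal sketch of a new extension written against the renamed hooks, modeled on the MyExtension stub in the README and the CpuAdamArmExtension changes above. The class is hypothetical (not part of this commit), the import path for _Extension is assumed from the repo layout, and __init__ plus the build-related methods are omitted.

import platform

from extensions.base_extension import _Extension  # assumed import path


class MyArmOnlyExtension(_Extension):
    """Hypothetical kernel extension that only runs on aarch64 machines."""

    def is_available(self) -> bool:
        # cheap probe: can the required hardware be found at all?
        return platform.machine() == "aarch64"

    def assert_compatible(self) -> None:
        # strict check performed while the kernel is being loaded
        arch = platform.machine()
        assert arch == "aarch64", f"[extension] this kernel requires aarch64 but got {arch}"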

extensions/cpu_adam/cpu_adam_x86.py (+4 -4)

@@ -8,15 +8,15 @@ class CpuAdamX86Extension(_CudaExtension):
     def __init__(self):
         super().__init__(name="cpu_adam_x86")
 
-    def is_hardware_available(self) -> bool:
-        return platform.machine() == "x86_64" and super().is_hardware_available()
+    def is_available(self) -> bool:
+        return platform.machine() == "x86_64" and super().is_available()
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         arch = platform.machine()
         assert (
             arch == "x86_64"
         ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
-        super().assert_hardware_compatible()
+        super().assert_compatible()
 
     # necessary 4 functions
     def sources_files(self):

extensions/cuda_extension.py (+2 -2)

@@ -22,7 +22,7 @@ def nvcc_flags(self) -> List[str]:
         This function should return a list of nvcc compilation flags for extensions.
         """
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         # cuda extension can only be built if cuda is available
         try:
             import torch
@@ -32,7 +32,7 @@ def is_hardware_available(self) -> bool:
             cuda_available = False
         return cuda_available
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         from torch.utils.cpp_extension import CUDA_HOME
 
         if not CUDA_HOME:
