from enum import Enum
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from colossalai.kernel.kernel_loader import (
    FlashAttentionLoader,
    FlashAttentionWithCustomMaskLoader,
    FlashAttentionWithPaddingMaskLoader,
)

__all__ = [
    "AttnMaskType",
    "ColoAttention",
]


class AttnMaskType(Enum):
    CUSTOM = 0
    PADDED = 1
    CAUSAL = 2
    PADDED_CAUSAL = 3

def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    """Invert the mask tensor.

    Args:
        mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv].

    Returns:
        torch.Tensor: Inverted mask tensor.
    """
    inverted_mask = 1.0 - mask
    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)

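A minimal sketch of what `invert_mask` produces (not part of the original file; it assumes the module's imports are in scope): a multiplicative 0/1 mask becomes an additive mask, where valid positions turn into `0.0` and masked positions turn into the dtype's minimum value.

# Hypothetical usage sketch of invert_mask (illustration only).
mask = torch.tensor([[[[1.0, 1.0, 0.0]]]], dtype=torch.float16)  # [B=1, 1, Sq=1, Skv=3]
inverted = invert_mask(mask)
# Valid positions become 0.0; masked positions become torch.finfo(torch.float16).min,
# so adding `inverted` to the attention scores suppresses masked keys in the softmax.
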
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
    """Get padding information from a padding mask.

    Args:
        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S].

    Returns:
        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices).
    """
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return max_seqlen_in_batch, cu_seqlens, indices

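A small illustrative sketch (not part of the original file; the batch of two sequences with lengths 2 and 3 under a maximum length of 4 is an assumption) of what `get_pad_info` returns:

# Hypothetical example of get_pad_info (illustration only).
padding_mask = torch.tensor([[1, 1, 0, 0],
                             [1, 1, 1, 0]])  # [B=2, S=4], 1 = valid token, 0 = padding
max_seqlen, cu_seqlens, indices = get_pad_info(padding_mask)
# max_seqlen == 3                        -> longest unpadded sequence in the batch
# cu_seqlens == tensor([0, 2, 5])        -> cumulative sequence lengths, shape [B + 1]
# indices    == tensor([0, 1, 4, 5, 6])  -> positions of valid tokens in the flattened [B * S] mask
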
class ColoAttention:
    # these attrs are lazily initialized in the first call of the attention() method
    _flash_attn_func: Optional[Callable] = None
    _flash_attn_with_custom_mask_func: Optional[Callable] = None
    _flash_attn_with_padding_mask_func: Optional[Callable] = None

    @staticmethod
    def _init_flash_attn_func():
        if ColoAttention._flash_attn_func is None:
            ColoAttention._flash_attn_func = FlashAttentionLoader().load()
        if ColoAttention._flash_attn_with_custom_mask_func is None:
            ColoAttention._flash_attn_with_custom_mask_func = FlashAttentionWithCustomMaskLoader().load()
        if ColoAttention._flash_attn_with_padding_mask_func is None:
            ColoAttention._flash_attn_with_padding_mask_func = FlashAttentionWithPaddingMaskLoader().load()

    @staticmethod
    def prepare_attn_kwargs(
        shape_4d: Tuple[int],
        dtype: torch.dtype,
        device: torch.device,
        q_padding_mask: Optional[torch.Tensor] = None,
        kv_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Return a dictionary of keyword arguments for the attention function. It supports four mask types.
        1. custom mask: no padding mask and is_causal=False, returns {}, and users should handle the attention mask themselves.
        2. padded mask: receives a padding mask and is_causal=False, returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
        3. causal mask: no padding mask and is_causal=True, returns {attention_mask, attention_mask_type}.
        4. padded causal mask: receives a padding mask and is_causal=True, returns {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.

        Args:
            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv).
            dtype (torch.dtype): Dtype of the attention mask, generally should be ``hidden_states.dtype``.
            device (torch.device): Device of the attention mask, generally should be ``hidden_states.device``.
            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of the query. It should be a long or int tensor.
                The shape should be [B, Sq]. ``1`` means a valid token, and ``0`` means a padding token. Defaults to None.
            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of the key and value. It should be a long or int tensor.
                The shape should be [B, Skv]. ``1`` means a valid token, and ``0`` means a padding token.
                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
            is_causal (bool, optional): Whether to use a causal attention mask. Defaults to False.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of keyword arguments for the attention function.
        """
        if q_padding_mask is None and not is_causal:
            return {}
        assert len(shape_4d) == 4 and shape_4d[1] == 1
        b, _, s_q, s_kv = shape_4d
        outputs = {}
        if q_padding_mask is not None:
            if kv_padding_mask is None:
                kv_padding_mask = q_padding_mask
            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (b, s_kv)
            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
            outputs.update(
                {
                    "cu_seqlens_q": cu_seqlens_q,
                    "cu_seqlens_kv": cu_seqlens_kv,
                    "max_seqlen_q": max_seqlen_q,
                    "max_seqlen_kv": max_seqlen_kv,
                    "q_indices": q_indices,
                    "kv_indices": kv_indices,
                }
            )
            if is_causal:
                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
            else:
                outputs["attention_mask_type"] = AttnMaskType.PADDED
        else:
            assert is_causal
            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
        attention_mask = invert_mask(attention_mask).unsqueeze(1)
        outputs["attention_mask"] = attention_mask
        return outputs

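A hedged usage sketch of `prepare_attn_kwargs` for padded causal self-attention (the shapes, dtype, and padding mask below are illustrative assumptions, not from the original file):

# Hypothetical usage sketch of prepare_attn_kwargs (illustration only).
B, S = 2, 16
dtype, device = torch.float16, torch.device("cuda")
q_padding_mask = torch.ones(B, S, dtype=torch.long)  # [B, Sq], 1 = valid token
q_padding_mask[1, 12:] = 0                           # second sample has 4 padding tokens

attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, 1, S, S),                  # (B, 1, Sq, Skv)
    dtype,
    device,
    q_padding_mask=q_padding_mask,  # kv_padding_mask defaults to q_padding_mask
    is_causal=True,
)
# attn_kwargs["attention_mask_type"] is AttnMaskType.PADDED_CAUSAL, and the dict also carries
# attention_mask, cu_seqlens_*, max_seqlen_*, and *_indices for the padded flash-attention kernel.
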
    @staticmethod
    def attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        q_indices: Optional[torch.Tensor] = None,
        kv_indices: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """Flash Attention function. It supports four mask types.
        1. custom mask: receives attention_mask
        2. padded mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices
        3. causal mask: receives attention_mask, attention_mask_type
        4. padded causal mask: receives attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices

        Args:
            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D].
            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D].
            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D].
            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
                Shape should be [B+1]. Defaults to None.
            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
                Shape should be [B+1]. Defaults to None.
            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
            q_indices (Optional[torch.Tensor], optional): The indices of non-masked query tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            kv_indices (Optional[torch.Tensor], optional): The indices of non-masked key/value tokens from the flattened input sequence.
                Shape should be [NUM_TOKENS]. Defaults to None.
            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.

        Returns:
            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D].
        """
        ColoAttention._init_flash_attn_func()
        # known issue: sdpa does not support an attention mask that contains a whole row of masked tokens, which leads to NaN.
        # This case is usual when a padding mask is used and self-attention is performed,
        # thus we don't use sdpa when a padding mask is used.
        # sanity check
        if attention_mask is not None:
            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
                assert (
                    cu_seqlens_q is None
                    and cu_seqlens_kv is None
                    and max_seqlen_q is None
                    and max_seqlen_kv is None
                    and q_indices is None
                    and kv_indices is None
                )
                if attention_mask_type == AttnMaskType.CUSTOM:
                    assert not torch.all(attention_mask != 0, dim=-1).any()
            elif attention_mask_type in (AttnMaskType.PADDED, AttnMaskType.PADDED_CAUSAL):
                assert (
                    cu_seqlens_q is not None
                    and cu_seqlens_kv is not None
                    and max_seqlen_q is not None
                    and max_seqlen_kv is not None
                    and q_indices is not None
                    and kv_indices is not None
                )
        else:
            # if attention_mask is None, attention_mask_type should be the default value
            assert attention_mask_type == AttnMaskType.CUSTOM
        # kernel dispatch
        if attention_mask is not None and attention_mask_type == AttnMaskType.CUSTOM:
            attn_func = ColoAttention._flash_attn_with_custom_mask_func
        elif attention_mask_type in (AttnMaskType.PADDED, AttnMaskType.PADDED_CAUSAL):
            attn_func = ColoAttention._flash_attn_with_padding_mask_func
        else:
            attn_func = ColoAttention._flash_attn_func
        is_causal = attention_mask is not None and attention_mask_type in (
            AttnMaskType.CAUSAL,
            AttnMaskType.PADDED_CAUSAL,
        )
        return attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            scale=scale,
            attention_mask=attention_mask,
            is_causal=is_causal,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_kv=max_seqlen_kv,
            q_indices=q_indices,
            kv_indices=kv_indices,
        )
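
A hedged end-to-end sketch (the shapes, dtype, and device below are assumptions for illustration, not from the original file) showing how the kwargs built by `prepare_attn_kwargs` are forwarded to `ColoAttention.attention`:

# Hypothetical end-to-end sketch (illustration only): a self-attention module would
# typically build the kwargs once per forward pass and splat them into attention().
B, N, S, D = 2, 8, 16, 64
q = torch.randn(B, N, S, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, N, S, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, N, S, D, dtype=torch.float16, device="cuda")

attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, 1, S, S), dtype=q.dtype, device=q.device, is_causal=True
)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # out has shape [B, N, S, D]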