Merge pull request #6061 from wangbluo/sp_fix
[sp] : fix the attention kernel for sp
wangbluo authored Sep 14, 2024
2 parents bdb125f + 827ef3e commit 37e3523
Showing 2 changed files with 42 additions and 10 deletions.
4 changes: 4 additions & 0 deletions colossalai/kernel/kernel_loader.py
@@ -119,6 +119,10 @@ class FlashAttentionLoader(KernelLoader):
     ]
 
 
+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
+
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
 
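For reference, the new loader is used like the other KernelLoader subclasses in this file: instantiate it and call load(), which resolves the first available extension in REGISTRY. A minimal usage sketch (illustrative only, not part of this diff; the variable name dao_flash_attn is hypothetical):

    from colossalai.kernel.kernel_loader import FlashAttentionDaoLoader

    # load() resolves the registered extension (FlashAttentionDaoCudaExtension)
    # and returns the flash-attention kernel it provides.
    dao_flash_attn = FlashAttentionDaoLoader().load()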
48 changes: 38 additions & 10 deletions colossalai/shardformer/layer/attn.py
@@ -8,6 +8,7 @@
 from einops import rearrange
 
 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
@@ -17,6 +18,8 @@
 
 from .utils import RingComm, get_half_index, split_varlen_zigzag
 
+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
@@ -77,6 +80,7 @@ def get_pad_info(
 
 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
 
     @staticmethod
     def _init_kernels_dispatch():
@@ -102,9 +106,11 @@ def _init_kernels_dispatch():
             torch.bfloat16: half_dispatch_map,
             torch.float32: float_dispatch_map,
         }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()
 
     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +119,19 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
+
+        if size >= MEMORY_BOUND:
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
+            return ColoAttention._flash_kernel_dispatch
+        else:
+            return ColoAttention._kernel_dispatch_map[dtype][mask_type]
 
     @staticmethod
     def prepare_attn_kwargs(
@@ -154,17 +167,22 @@ def prepare_attn_kwargs(
             return {}
         assert len(shape_4d) == 4 and shape_4d[1] == 1
         b, _, s_q, s_kv = shape_4d
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         outputs = {}
         if (q_padding_mask is None or q_padding_mask.bool().all()) and (
             kv_padding_mask is None or kv_padding_mask.bool().all()
         ):
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
-            if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
-            attention_mask = attention_mask.expand(b, s_q, s_kv)
+            if memory_size < MEMORY_BOUND:
+                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+                if s_q != 1:
+                    attention_mask.tril_(diagonal=0)
+                attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -177,7 +195,6 @@
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -190,10 +207,17 @@
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if s_q != 1:
-                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if memory_size < MEMORY_BOUND:
+                    if s_q != 1:
+                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
+                if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
@@ -278,8 +302,12 @@ def attention(
             assert attention_mask_type == AttnMaskType.CUSTOM
 
         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
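Taken together, the change estimates the memory footprint of a dense s_q by s_kv attention mask/score matrix and, once that estimate reaches MEMORY_BOUND (10 * 1e9 bytes), stops materializing the mask (an empty placeholder tensor is used instead) and dispatches causal and padded-causal cases to the Dao flash-attention kernel. A standalone sketch of that decision, with an illustrative helper name (use_dao_kernel is not part of the diff):

    import torch

    MEMORY_BOUND = 10 * 1e9  # bytes, same threshold as the patch

    def use_dao_kernel(s_q: int, s_kv: int, dtype: torch.dtype, is_causal: bool) -> bool:
        # Estimated bytes of one dense s_q x s_kv mask/score slice.
        element_size = torch.tensor([], dtype=dtype).element_size()
        memory_size = s_q * s_kv * element_size
        # Large causal workloads go to the Dao flash-attention kernel and skip
        # building an explicit attention mask; smaller ones (and custom masks)
        # keep the existing per-dtype dispatch map.
        return memory_size >= MEMORY_BOUND and is_causal

    # bfloat16 is 2 bytes per element: 131072**2 * 2 is about 34 GB, which is
    # above the threshold, while 32768**2 * 2 is about 2.1 GB, which is below it.
    print(use_dao_kernel(131072, 131072, torch.bfloat16, True))  # True
    print(use_dao_kernel(32768, 32768, torch.bfloat16, True))    # False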
