Skip to content

Commit

Permalink
align to flash attention v2.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
sallyjunjun committed Jan 30, 2024
1 parent 8da532c commit 22f7c0c
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions internlm/model/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ def forward(self, qkv, causal=None, key_padding_mask=None):
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
if key_padding_mask is not None:
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask = torch.full(
(batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
)
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
causal_mask = torch.triu(
torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
Expand Down Expand Up @@ -86,29 +90,40 @@ def forward(self, q, kv, causal=None, key_padding_mask=None):
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, Sk)
"""
batch_size, seqlen_q = q.shape[0], q.shape[1]
causal = self.causal if causal is None else causal
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
if kv.shape[3] != q.shape[2]: # MQA/GQA
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
k, v = kv.unbind(dim=2)
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
if key_padding_mask is not None:
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask = torch.full(
(batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
)
padding_mask.masked_fill_(key_padding_mask, 0.0)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
if causal:
# "triu_tril_cuda_template" not implemented for 'BFloat16'
# So we have to construct the mask in float
causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
scores = scores + causal_mask.to(dtype=scores.dtype)
# causal mask needs to take into account the difference between seqlen_q and seqlen_k
row_idx = rearrange(
torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
)
col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
causal_mask = col_idx > row_idx + sk - seqlen_q
scores = scores.masked_fill(causal_mask, -10000.0)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention_drop = self.drop(attention)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
Expand Down Expand Up @@ -168,7 +183,9 @@ def _update_kv_cache(kv, inference_params, layer_idx):
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
kv[:, :, 1], "b s h d -> b h s d"
)
return kv


Expand Down

0 comments on commit 22f7c0c

Please sign in to comment.