Commit

Fix test with alibi and cache_leftpad
tridao committed Jul 23, 2024
1 parent 4488ace commit 2995636
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/test_flash_attn.py
@@ -27,7 +27,7 @@
 
 
 def attn_bias_from_alibi_slopes(
-    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
+    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
 ):
     batch, nheads = slopes.shape
     device = slopes.device
@@ -37,6 +37,10 @@ def attn_bias_from_alibi_slopes(
     else:
         row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
         col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+        if key_leftpad is not None:
+            key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+            col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+            col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
         sk = (
             seqlen_k
             if key_padding_mask is None
@@ -1993,7 +1997,7 @@ def test_flash_attn_kvcache(
     if alibi:
         alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
         attn_bias = attn_bias_from_alibi_slopes(
-            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
+            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
         )
     else:
         alibi_slopes, attn_bias = None, None
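
For reference, a minimal standalone sketch (not part of the commit) of what the new key_leftpad branch computes: key positions that fall inside the left-padded region of the cache are given the huge column index 2**32, so they end up with an extremely large relative distance (and hence an effectively masking bias) in the reference ALiBi computation, while the remaining keys are re-indexed to start at 0. The seqlen_k and key_leftpad values below are made up for illustration.

    import torch
    from einops import rearrange, repeat

    seqlen_k = 6
    key_leftpad = torch.tensor([0, 2])  # per-batch count of left-padded key slots

    col_idx = torch.arange(seqlen_k, dtype=torch.long)
    key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
    col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
    # Padded positions get index 2**32; real keys are shifted to start at 0.
    col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    print(col_idx[1, 0, 0])  # tensor([4294967296, 4294967296, 0, 1, 2, 3])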
