Support custom mask for Triton attention #3317

Merged
3 commits merged on Feb 5, 2025
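Summary of the change: an optional flattened boolean custom_mask plus a per-sequence mask_offsets tensor are threaded through the Triton backend's forward_metadata into extend_attention_fwd, and the extend kernel gains a USE_CUSTOM_MASK branch that reads this mask instead of applying its built-in prefix/causal masking. The sketch below is not part of the diff; it is a minimal illustration with a hypothetical helper name, mirroring the unit test at the bottom of this page, of the layout the kernel expects: for each sequence, an extend_len x (prefix_len + extend_len) boolean matrix, flattened row-major and concatenated across the batch, with mask_offsets holding each sequence's starting element.

import torch

def build_prefix_causal_mask(prefix_lens: torch.Tensor, extend_lens: torch.Tensor):
    # Hypothetical helper (not in this PR): build the flattened bool mask and
    # offsets consumed by the new USE_CUSTOM_MASK path, reproducing the default
    # prefix + causal pattern. Any per-sequence pattern of the same shape works.
    seq_mask_lens = extend_lens * (prefix_lens + extend_lens)
    mask_offsets = torch.zeros(len(extend_lens) + 1, dtype=torch.int64, device="cuda")
    mask_offsets[1:] = torch.cumsum(seq_mask_lens, dim=0)
    custom_mask = torch.empty(int(seq_mask_lens.sum()), dtype=torch.bool, device="cuda")
    for i, (p, e) in enumerate(zip(prefix_lens.tolist(), extend_lens.tolist())):
        prefix_part = torch.ones(e, p, dtype=torch.bool)                # prefix tokens fully visible
        causal_part = torch.tril(torch.ones(e, e)).bool()               # causal within the extend part
        per_seq = torch.cat([prefix_part, causal_part], dim=1)          # shape (e, p + e)
        custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = per_seq.flatten().cuda()
    return custom_mask, mask_offsets

Passing custom_mask=None (as the backend still does in this diff) keeps the kernel's original causal behavior.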
20 changes: 16 additions & 4 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -91,6 +91,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

qo_indptr = None
custom_mask = None
mask_offsets = None
else:
kv_indptr[1 : bs + 1] = torch.cumsum(
forward_batch.extend_prefix_lens, dim=0
@@ -115,6 +116,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
mask_offsets = None

attn_logits = None
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
@@ -126,6 +128,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices,
qo_indptr,
custom_mask,
mask_offsets,
)

def init_cuda_graph_state(self, max_bs: int):
@@ -180,6 +183,7 @@ def init_forward_metadata_capture_cuda_graph(
kv_indices,
None,
None,
None,
)

def init_forward_metadata_replay_cuda_graph(
@@ -233,9 +237,15 @@ def forward_extend(
layer, forward_batch.out_cache_loc, k, v
)

_, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
self.forward_metadata
)
(
_,
max_extend_len,
kv_indptr,
kv_indices,
qo_indptr,
custom_mask,
mask_offsets,
) = self.forward_metadata
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
@@ -246,6 +256,8 @@ def forward_extend(
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_offsets,
max_extend_len,
layer.scaling,
layer.logit_cap,
@@ -271,7 +283,7 @@ def forward_decode(
else:
o = torch.empty_like(q)

attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -49,6 +49,8 @@ def _fwd_kernel(
qo_indptr,
kv_indptr,
kv_indices,
mask_ptr,
mask_offsets,
sm_scale,
kv_group_num,
stride_qbs,
@@ -71,6 +73,7 @@
BLOCK_DV: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
USE_CUSTOM_MASK: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
@@ -81,6 +84,10 @@
cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend

if USE_CUSTOM_MASK:
cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)

offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
@@ -152,7 +159,20 @@
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)

qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
if USE_CUSTOM_MASK:
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+ start_n
+ offs_n[None, :],
mask=(mask_m[:, None] & mask_n[None, :]),
other=0,
)
custom_mask &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(custom_mask, qk, float("-inf"))
else:
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -172,7 +192,7 @@

e_max = n_e_max

# stage 2: compute the trianlge part
# stage 2: compute the triangle part

cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -208,11 +228,25 @@
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)

mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causual &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causual, qk, float("-inf"))
if USE_CUSTOM_MASK:
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+ cur_seq_len_prefix
+ start_n
+ offs_n[None, :],
mask=(mask_m[:, None] & mask_n[None, :]),
other=0,
)
custom_mask &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(custom_mask, qk, float("-inf"))
else:
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causual &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causual, qk, float("-inf"))

n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
@@ -253,6 +287,8 @@ def extend_attention_fwd(
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_offsets,
max_len_extend,
sm_scale=None,
logit_cap=0.0,
@@ -308,6 +344,8 @@
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]

USE_CUSTOM_MASK = custom_mask is not None

grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1

@@ -325,6 +363,8 @@
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_offsets,
sm_scale,
kv_group_num,
q_extend.stride(0),
Expand All @@ -347,6 +387,7 @@ def extend_attention_fwd(
BLOCK_N=BLOCK_N,
Lq=Lq,
Lv=Lv,
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
num_warps=num_warps,
num_stages=num_stages,
**extra_kargs,
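For orientation: the USE_CUSTOM_MASK branch added above reads the mask bit for query row q_row of the extend part attending to key position k_col of the full (prefix + extend) sequence at flat index mask_offsets[seq] + q_row * (prefix_len + extend_len) + k_col; stage 1 covers the prefix keys (k_col < prefix_len) and stage 2 adds cur_seq_len_prefix to address the extend keys. A small reference sketch of that lookup (hypothetical helper, not part of the PR):

def custom_mask_bit(custom_mask, mask_offsets, seq, prefix_len, extend_len, q_row, k_col):
    # Sketch: the element the kernel loads when query row q_row of sequence seq
    # attends to key position k_col, with 0 <= k_col < prefix_len + extend_len.
    seq_len = prefix_len + extend_len
    return custom_mask[mask_offsets[seq] + q_row * seq_len + k_col]

With the prefix + causal mask built in the test below, this is True exactly when k_col < prefix_len or k_col - prefix_len <= q_row, which matches the kernel's default masking.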
43 changes: 43 additions & 0 deletions test/srt/test_triton_attention_kernels.py
@@ -89,6 +89,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
).normal_(mean=0.1, std=0.2)

o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
o_extend_mask = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
)
o_redundant = torch.empty(
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
)
@@ -98,6 +101,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

custom_mask = None
mask_offsets = None

extend_attention_fwd(
q_extend,
k_extend,
@@ -108,6 +114,42 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_offsets,
max_len_extend,
)

b_seq_mask_len = b_seq_len_extend * b_seq_len
custom_mask = torch.ones(
(b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
)
mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
for i in range(B):
causal_mask = (
torch.tril(
torch.ones(b_seq_len_extend[i], b_seq_len_extend[i]), diagonal=0
)
== 1
)
prefix_mask = torch.ones(
b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
)
mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten

extend_attention_fwd(
q_extend,
k_extend,
v_extend,
o_extend_mask,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_offsets,
max_len_extend,
)

@@ -124,6 +166,7 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
)

self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
self.assertTrue(torch.allclose(o_extend_mask, o_redundant, rtol=1e-2))

def test_extend_attention(self):

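Beyond the comparison against redundant_attention in the test above, a naive PyTorch reference that applies the same boolean mask can serve as an extra cross-check of the masked path. A minimal sketch for a single sequence, assuming q of shape [extend_len, num_heads, head_dim], k and v of shape [prefix_len + extend_len, num_heads, head_dim], equal Q/KV head counts, and a bool mask of shape [extend_len, prefix_len + extend_len]:

import torch

def naive_masked_attention(q, k, v, mask, sm_scale):
    # Sketch: reference masked attention for one sequence; mask[i, j] == True
    # means query i may attend to key j. Assumes every query row allows at
    # least one key, otherwise the softmax below yields NaN.
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * sm_scale
    scores = scores.masked_fill(~mask[None, :, :], float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float()).to(q.dtype)

With sm_scale = 1.0 / head_dim ** 0.5 and the prefix + causal mask from the test, this should agree with the corresponding rows of o_extend_mask up to the test's tolerance.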