Support MLA for DeepSeek-V2 with Triton - step 1 (#905)
ispobock authored Aug 4, 2024
1 parent f4d9953 commit e1eae1f
Showing 10 changed files with 439 additions and 78 deletions.
Empty file modified benchmark/gsm8k/download_data.sh (mode 100644 → 100755)
66 changes: 59 additions & 7 deletions python/sglang/srt/layers/extend_attention.py
@@ -57,6 +57,8 @@ def _fwd_kernel(
stride_buf_vh,
stride_req_to_tokens_b,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
offs_m = tl.arange(0, BLOCK_M)
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend

offs_q = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
@@ -85,10 +89,20 @@
)
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)

if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
offs_qpe = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
+ offs_dpe[None, :]
)
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)

# stage1: compute scores with prefix
offs_n = tl.arange(0, BLOCK_N)

acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

@@ -110,6 +124,18 @@

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)
qk *= sm_scale

if logit_cap > 0:
@@ -125,7 +151,7 @@
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_d[None, :]
+ offs_dv[None, :]
)
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
@@ -150,6 +176,21 @@

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)

if BLOCK_DPE > 0:
offs_kpe = (
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
* stride_kbs
+ cur_kv_head * stride_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Extend + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)

qk *= sm_scale

if logit_cap > 0:
@@ -169,7 +210,7 @@
offs_v = (
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :]
+ offs_dv[None, :]
)
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
@@ -181,7 +222,7 @@
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
+ offs_dv[None, :]
)
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
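
For MLA the query/key head dimension (576) is wider than the value head dimension (512), so the kernel takes the dot product over the first BLOCK_DMODEL dims and, when BLOCK_DPE > 0, adds a second dot product over the trailing rotary ("pe") dims before applying sm_scale; the accumulator and output follow BLOCK_DV. A plain-PyTorch sketch of that score computation, for illustration only (the helper name is made up and this is not how the kernel is invoked):

import torch

def mla_scores_reference(q: torch.Tensor, k: torch.Tensor, sm_scale: float,
                         nope_dim: int = 512, rope_dim: int = 64) -> torch.Tensor:
    # q: [M, nope_dim + rope_dim], k: [N, nope_dim + rope_dim] -> scores: [M, N]
    # Mirrors the kernel's qk = dot(q, k) + dot(qpe, kpe), then qk *= sm_scale.
    assert q.shape[-1] == k.shape[-1] == nope_dim + rope_dim
    q_nope, q_pe = q[:, :nope_dim], q[:, nope_dim:]
    k_nope, k_pe = k[:, :nope_dim], k[:, nope_dim:]
    return (q_nope @ k_nope.T + q_pe @ k_pe.T) * sm_scale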

@@ -217,8 +258,17 @@ def extend_attention_fwd(
o_extend.shape[-1],
)

assert Lq == Lk and Lk == Lv and Lv == Lo
assert Lq in {16, 32, 64, 128, 256}
assert Lq == Lk and Lv == Lo
assert Lq in {16, 32, 64, 128, 256, 576}
assert Lv in {16, 32, 64, 128, 256, 512}

if Lq == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
else:
BLOCK_DMODEL = Lq
BLOCK_DPE = 0
BLOCK_DV = Lv

if CUDA_CAPABILITY[0] >= 8:
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
BLOCK_DMODEL=Lq,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
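
To summarize the dispatch in extend_attention_fwd: a 576-dim MLA query/key head is split into a 512-dim latent part (BLOCK_DMODEL) and a 64-dim rotary part (BLOCK_DPE), while the accumulator and output use the 512-dim value head (BLOCK_DV). A minimal standalone sketch of the same selection logic (the helper name is illustrative, not part of the patch):

def select_mla_blocks(qk_head_dim: int, v_head_dim: int):
    # Same dispatch as in extend_attention_fwd above.
    assert qk_head_dim in {16, 32, 64, 128, 256, 576}
    assert v_head_dim in {16, 32, 64, 128, 256, 512}
    if qk_head_dim == 576:
        block_dmodel, block_dpe = 512, 64  # latent dims, rotary dims
    else:
        block_dmodel, block_dpe = qk_head_dim, 0
    return block_dmodel, block_dpe, v_head_dim  # BLOCK_DMODEL, BLOCK_DPE, BLOCK_DV

# DeepSeek-V2 MLA: qk_head_dim=576, v_head_dim=512 -> (512, 64, 512)
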
31 changes: 22 additions & 9 deletions python/sglang/srt/layers/radix_attention.py
@@ -38,16 +38,22 @@ def __init__(
num_kv_heads: int,
layer_id: int,
logit_cap: int = -1,
v_head_dim: int = -1,
):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim
self.qk_head_dim = head_dim
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
self.scaling = scaling
self.layer_id = layer_id

if not global_server_args_dict.get("disable_flashinfer", False):
if (
not global_server_args_dict.get("disable_flashinfer", False)
and self.qk_head_dim == self.v_head_dim
):
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
else:
@@ -57,13 +63,17 @@ def __init__(
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)

self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, self.tp_q_head_num, self.head_dim),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +92,17 @@ def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
return o

def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)

token_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
o.view(-1, self.tp_q_head_num, self.head_dim),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
@@ -160,8 +173,8 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
return o.view(-1, self.tp_q_head_num * self.head_dim)

def forward(self, q, k, v, input_metadata: InputMetadata):
k = k.view(-1, self.tp_k_head_num, self.head_dim)
v = v.view(-1, self.tp_v_head_num, self.head_dim)
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

if input_metadata.forward_mode == ForwardMode.EXTEND:
return self.extend_forward(q, k, v, input_metadata)
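
Because MLA's value head dim differs from its query/key head dim, the Triton paths above allocate the output with the value width instead of reusing q's shape. A self-contained sketch of that allocation rule (function and variable names are illustrative):

import torch

def alloc_attention_output(q: torch.Tensor, tp_q_head_num: int,
                           qk_head_dim: int, v_head_dim: int) -> torch.Tensor:
    # Same rule as extend_forward_triton / decode_forward_triton above:
    # output rows match q, but each head is v_head_dim wide.
    if qk_head_dim != v_head_dim:
        return q.new_empty((q.shape[0], tp_q_head_num * v_head_dim))
    return torch.empty_like(q)

# e.g. 16 query heads per rank with DeepSeek-V2-like MLA dims (assumed values):
q = torch.randn(4, 16 * 576)
o = alloc_attention_output(q, tp_q_head_num=16, qk_head_dim=576, v_head_dim=512)
assert o.shape == (4, 16 * 512)

The same qk_head_dim == v_head_dim check in __init__ is what keeps MLA layers on the Triton path rather than FlashInfer.
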
30 changes: 28 additions & 2 deletions python/sglang/srt/layers/token_attention.py
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
att_stride_h,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(

off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe

offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@
other=0.0,
).to(REDUCE_TRITON_TYPE)
att_value = tl.sum(q[None, :] * k, 1)
if BLOCK_DPE > 0:
qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
offs_buf_kpe = (
k_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[None, :]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=offs_n_new[:, None] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
att_value += tl.sum(qpe[None, :] * kpe, 1)
att_value *= sm_scale

if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256}
assert Lk in {16, 32, 64, 128, 256, 576}

if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
else:
BLOCK_DMODEL = Lk
BLOCK_DPE = 0

batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -220,7 +245,8 @@ def _token_att_m_fwd(
k_buffer.stride(1),
att_out.stride(0),
kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_N=BLOCK,
logit_cap=logit_cap,
num_warps=num_warps,
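
In the decode kernel each query is a single token, so the stage-1 score for every cached position is an element-wise product summed over the latent dims plus the rotary dims, then scaled. A short PyTorch reference of that per-token computation (names are illustrative; the kernel additionally applies the optional logit cap):

import torch

def decode_stage1_scores(q: torch.Tensor, k_cache: torch.Tensor,
                         sm_scale: float, nope_dim: int = 512) -> torch.Tensor:
    # q: [576], k_cache: [num_cached_tokens, 576] -> [num_cached_tokens]
    att = (q[None, :nope_dim] * k_cache[:, :nope_dim]).sum(1)
    att += (q[None, nope_dim:] * k_cache[:, nope_dim:]).sum(1)
    return att * sm_scale
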
7 changes: 4 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
@@ -29,7 +29,7 @@
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@
"disable_flashinfer": False,
"disable_flashinfer_sampling": False,
"attention_reduce_in_fp32": False,
"enable_mla": False,
}


@@ -289,7 +290,7 @@ class Batch:
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
token_to_kv_pool: BaseTokenToKVPool
tree_cache: RadixCache

# Batched arguments to model runner
@@ -780,7 +781,7 @@ class InputMetadata:
seq_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
token_to_kv_pool: BaseTokenToKVPool

# For extend
extend_seq_lens: torch.Tensor
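
The scheduler hunks swap the concrete TokenToKVPool annotation for BaseTokenToKVPool and add an enable_mla server flag, so an MLA-specific KV-cache layout can be selected without touching the batch code; the pool implementations themselves are not shown on this page. A hedged sketch of the interface the attention code above relies on (the abstract class body and the flag-driven selection are assumptions for illustration, not taken from the patch):

from abc import ABC, abstractmethod
import torch

class BaseTokenToKVPool(ABC):
    # Only the methods used by radix_attention.py above are sketched here.
    @abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor: ...

    @abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor: ...

# Illustrative selection (class names below are hypothetical):
# pool_cls = MLATokenToKVPool if server_args.enable_mla else TokenToKVPool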