From e1eae1fd15ed8e125ddcd18d0193ae8529c0c309 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Mon, 5 Aug 2024 01:40:33 +0800
Subject: [PATCH] Support MLA for DeepSeek-V2 with Triton - step 1 (#905)

---
 benchmark/gsm8k/download_data.sh              |   0
 python/sglang/srt/layers/extend_attention.py  |  66 +++++-
 python/sglang/srt/layers/radix_attention.py   |  31 ++-
 python/sglang/srt/layers/token_attention.py   |  30 ++-
 python/sglang/srt/managers/schedule_batch.py  |   7 +-
 python/sglang/srt/mem_cache/memory_pool.py    |  89 ++++++--
 python/sglang/srt/model_config.py             |  11 +
 .../sglang/srt/model_executor/model_runner.py |  63 ++++--
 python/sglang/srt/models/deepseek_v2.py       | 214 ++++++++++++++++--
 python/sglang/srt/server_args.py              |   6 +
 10 files changed, 439 insertions(+), 78 deletions(-)
 mode change 100644 => 100755 benchmark/gsm8k/download_data.sh

diff --git a/benchmark/gsm8k/download_data.sh b/benchmark/gsm8k/download_data.sh
old mode 100644
new mode 100755
diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py
index ca6f86cea2b..7398895d62d 100644
--- a/python/sglang/srt/layers/extend_attention.py
+++ b/python/sglang/srt/layers/extend_attention.py
@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@ def _fwd_kernel(
         + cur_head * stride_qh
         + offs_d[None, :]
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

@@ -110,6 +124,18 @@ def _fwd_kernel(

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale

         if logit_cap > 0:
@@ -125,7 +151,7 @@ def _fwd_kernel(
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@ def _fwd_kernel(

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
         qk *= sm_scale

         if logit_cap > 0:
@@ -169,7 +210,7 @@ def _fwd_kernel(
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@ def _fwd_kernel(
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])

@@ -217,8 +258,17 @@ def extend_attention_fwd(
         o_extend.shape[-1],
     )

-    assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv

     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 45b80b8f23e..784f0df3450 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -38,16 +38,22 @@ def __init__(
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id

-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +63,17 @@ def __init__(
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +92,17 @@ def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         return o

     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +173,8 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py
index a792b7f3aee..ab6e7ba7727 100644
--- a/python/sglang/srt/layers/token_attention.py
+++ b/python/sglang/srt/layers/token_attention.py
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index ed4c37a1746..5ebf12e30d4 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -29,7 +29,7 @@
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }

@@ -289,7 +290,7 @@ class Batch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache

     # Batched arguments to model runner
@@ -780,7 +781,7 @@ class InputMetadata:
     seq_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool

     # For extend
     extend_seq_lens: torch.Tensor
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index fa38ee41c61..761b668bd86 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -57,32 +57,18 @@ def clear(self):
         self.can_use_mem_size = len(self.mem_state)


-class TokenToKVPool:
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
@@ -90,15 +76,6 @@ def __init__(
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)

@@ -139,3 +116,67 @@ def clear(self):

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py
index 8a431b0a1b4..ed496515cd3 100644
--- a/python/sglang/srt/model_config.py
+++ b/python/sglang/srt/model_config.py
@@ -13,6 +13,7 @@
 limitations under the License.
 """

+from enum import IntEnum, auto
 from typing import Optional

 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@
 from sglang.srt.hf_transformers_utils import get_config, get_context_length


+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ def __init__(
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA

         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index e65ea788866..5aa6de55099 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -47,7 +47,12 @@
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +91,7 @@ def __init__(
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )

@@ -193,15 +199,23 @@ def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = (
-            head_num
-            * head_dim
-            * self.model_config.num_hidden_layers
-            * 2
-            * torch._utils._element_size(self.dtype)
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -241,13 +255,28 @@ def init_memory_pool(
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=self.dtype,
-            head_num=self.model_config.get_num_kv_heads(self.tp_size),
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementation, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 4cc37c388a9..bc31d89ae72 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -45,6 +45,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.model_runner import InputMetadata


@@ -312,6 +313,165 @@ def forward(
         return output


+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):

     def __init__(
@@ -326,22 +486,44 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.self_attn = DeepseekV2Attention(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            layer_id=layer_id,
-        )
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = DeepseekV2AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = DeepseekV2Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
         if (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5114d99aabe..53aaca977e8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -80,6 +80,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
    enable_p2p_check: bool = False
+    enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False

@@ -393,6 +394,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--enable-mla",
+            action="store_true",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",