diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 9006b7150aa..19f4960c5ac 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -115,7 +115,7 @@ def from_cli_args(cls, args: argparse.Namespace): ) -def load_model(server_args, tp_rank): +def load_model(server_args, tp_rank, sp_rank: int = 0): suppress_other_loggers() rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None @@ -130,6 +130,8 @@ def load_model(server_args, tp_rank): gpu_id=tp_rank, tp_rank=tp_rank, tp_size=server_args.tp_size, + sp_rank=sp_rank, + sp_size=server_args.sp_size, nccl_port=28888, server_args=server_args, ) @@ -206,6 +208,8 @@ def extend(reqs, model_runner): req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, tree_cache=None, + sp_size=model_runner.sp_size, + sp_rank=model_runner.sp_rank, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND) @@ -225,11 +229,12 @@ def correctness_test( server_args, bench_args, tp_rank, + sp_rank=0, ): rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None # Load the model - model_runner, tokenizer = load_model(server_args, tp_rank) + model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank) # Prepare inputs input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) @@ -336,11 +341,12 @@ def latency_test( server_args, bench_args, tp_rank, + sp_rank=0, ): rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None # Load the model - model_runner, tokenizer = load_model(server_args, tp_rank) + model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank) # Prepare inputs for warm up reqs = prepare_synthetic_inputs_for_latency_test( @@ -458,16 +464,18 @@ def main(server_args, bench_args): ) if server_args.tp_size == 1: - work_func(server_args, bench_args, 0) + work_func(server_args, bench_args, 0, 0) else: workers = [] for tp_rank in range(server_args.tp_size): + sp_rank = tp_rank % server_args.sp_size proc = multiprocessing.Process( target=work_func, args=( server_args, bench_args, tp_rank, + sp_rank, ), ) proc.start() diff --git a/python/sglang/srt/layers/parallel_utils/__init__.py b/python/sglang/srt/layers/parallel_utils/__init__.py new file mode 100644 index 00000000000..f8104e1d30d --- /dev/null +++ b/python/sglang/srt/layers/parallel_utils/__init__.py @@ -0,0 +1 @@ +from .parallel_state import * diff --git a/python/sglang/srt/layers/parallel_utils/parallel_state.py b/python/sglang/srt/layers/parallel_utils/parallel_state.py new file mode 100644 index 00000000000..4c1a05f0724 --- /dev/null +++ b/python/sglang/srt/layers/parallel_utils/parallel_state.py @@ -0,0 +1,96 @@ +from typing import List, Optional + +import torch +from vllm.distributed import initialize_model_parallel as vllm_initialize_model_parallel +from vllm.distributed.parallel_state import ( + GroupCoordinator, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_world_group, + init_model_parallel_group, +) + +_SP: Optional[GroupCoordinator] = None + + +def get_sp_group(): + assert _SP is not None, "sequence parallel group is not initialized" + return _SP + + +def init_sequence_parallel_group( + group_ranks: List[List[int]], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=True, + ) + + +def 
initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + sequence_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups and sequence parallel groups. + + For sequence parallelism, we partition SP groups within a TP group, and assign + gpus with adjacent ranks to the same SP group. For example, with TP size 8 + and SP size 2, we have 1 TP group and 4 SP groups: + SP groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + Their KV TP rank: + [ 0, 0], [ 1, 1], [ 2, 2], [ 3, 3] + Given that we replicate KV heads within the same seq parallel group, we also say that + the KV TP size is 4 (8//2), and gpus in each SP group have KV-tp rank from 0 to 3. + """ + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + global _SP + assert _SP is None, "sequence parallel group is already initialized" + group_ranks = [] + for i in range(num_sequence_parallel_groups): + ranks = list( + range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + ) + group_ranks.append(ranks) + _SP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend) + + vllm_initialize_model_parallel( + tensor_model_parallel_size, pipeline_model_parallel_size, backend + ) + + +def sequence_parallel_is_initialized(): + return _SP is not None + + +def get_sequence_parallel_world_size(): + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + return get_sp_group().rank_in_group + + +def get_sequence_parallel_global_rank(): + return get_tensor_model_parallel_rank() + + +# NOTE: For sequence parallelism, we partition Q tensors along the head dimension. +# But K/V tensors are partitioned along the head dimension in TP and partitioned +# along the sequence dimensions in SP. Therefore, their TP size and rank is adjusted +# accordingly as below. +def get_kv_tensor_model_parallel_world_size(): + return get_tensor_model_parallel_world_size() // get_sequence_parallel_world_size() + + +def get_kv_tensor_model_parallel_rank(): + return get_tensor_model_parallel_rank() // get_sequence_parallel_world_size() diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 91735a1b810..9ec35cc70a6 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -20,10 +20,12 @@ import torch from flashinfer.cascade import merge_state from torch import nn +from torch.distributed import P2POp, batch_isend_irecv, irecv, isend from sglang.global_config import global_config from sglang.srt.layers.decode_attention import decode_attention_fwd from sglang.srt.layers.extend_attention import extend_attention_fwd +from sglang.srt.layers.parallel_utils import get_sp_group from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.model_runner import global_server_args_dict @@ -64,6 +66,11 @@ def __init__( self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): + if input_metadata.sp_size > 1: + raise NotImplementedError( + "Sequence parallel is not supported with Triton backend." 
+ ) + if self.qk_head_dim != self.v_head_dim: o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) else: @@ -93,6 +100,11 @@ def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): return o def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): + if input_metadata.sp_size > 1: + raise NotImplementedError( + "Sequence parallel is not supported with Triton backend." + ) + if self.qk_head_dim != self.v_head_dim: o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) else: @@ -117,6 +129,8 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): return o def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): + if input_metadata.sp_size > 1: + return self.seq_parallel_extend_forward_flashinfer(q, k, v, input_metadata) # using two wrappers is unnecessary in the current PR, but are prepared for future PRs prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged if self.sliding_window_size != -1: @@ -171,6 +185,8 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): return o.view(-1, self.tp_q_head_num * self.head_dim) def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): + if input_metadata.sp_size > 1: + return self.seq_parallel_decode_forward_flashinfer(q, k, v, input_metadata) decode_wrapper = input_metadata.flashinfer_decode_wrapper if self.sliding_window_size != -1: decode_wrapper = decode_wrapper[0] @@ -191,6 +207,257 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): return o.view(-1, self.tp_q_head_num * self.head_dim) + def launch_sp_comm_ops( + self, kv_to_recv, kv_to_send, from_rank, to_rank, my_rank, sp_size, itr + ): + # Interleaving workers for send and recv to avoid deadlock + def _send_first(): + flags = [None for _ in range(sp_size)] + for _rank in range(sp_size): + _next = _rank + flag = True + while flags[_next] is None: + flags[_next] = flag + _next = (_next + itr) % sp_size + flag = not flag + return flags[my_rank] + + def _send(handles, group): + if my_rank != to_rank: + to_global_rank = group.first_rank + to_rank + for t in kv_to_send: + handles.append( + P2POp( + op=isend, + tensor=t, + peer=to_global_rank, + group=group.device_group, + ) + ) + + def _recv(handles, group): + if my_rank != from_rank: + from_global_rank = group.first_rank + from_rank + for t in kv_to_recv: + handles.append( + P2POp( + op=irecv, + tensor=t, + peer=from_global_rank, + group=group.device_group, + ) + ) + + handles = [] + reqs = [] + sp_group = get_sp_group() + + if _send_first(): + _send(handles, sp_group) + _recv(handles, sp_group) + else: + _recv(handles, sp_group) + _send(handles, sp_group) + if handles: + reqs = batch_isend_irecv(handles) + return reqs + + def wait_sp_comm_ops(self, reqs): + for req in reqs: + req.wait() + + def seq_parallel_extend_forward_flashinfer( + self, q, k, v, input_metadata: InputMetadata + ): + """Here we adopted a unique parallelization strategy. 
+ For each SP worker, we have either (1) QKV of entire sequences: + q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim] + k tensor: [padded_total_num_tokens, k_head_num, head_dim] + v tensor: [padded_total_num_tokens, v_head_num, head_dim] + Or (2) Q of entire sequences and KV of the current SP shard: + q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim] + k tensor: [padded_sp_shard_num_tokens, k_head_num, head_dim] + v tensor: [padded_sp_shard_num_tokens, v_head_num, head_dim] + + Case (1) saves cross-SP-worker communication, while case (2) saves computation + to get K and V for entire sequences but need computation in SP attn. + """ + + def append_merge_shard(shard_list, o, s): + if len(shard_list) == 0: + shard_list.append((o, s)) + else: + o_prev, s_prev = shard_list[-1] + o, s = merge_state(o_prev, s_prev, o, s) + shard_list[-1] = (o, s) + + sp_rank = input_metadata.sp_rank + sp_size = input_metadata.sp_size + num_shards = num_iters = sp_size + sp_shard_size = (q.shape[0] + sp_size - 1) // sp_size + assert k.shape[0] == v.shape[0] and ( + k.shape[0] == q.shape[0] or k.shape[0] == sp_shard_size + ), "Invalid K and V partition in sequence parallel." + + qs = [] + for i in range(num_shards): + qs.append(q[sp_shard_size * i : sp_shard_size * (i + 1)]) + need_comm = k.shape[0] == sp_shard_size # Case 2. + + owned_sids = [sp_rank] + kv_shards = [None for _ in range(num_shards)] + output_shards = [[] for _ in range(num_shards)] + + if need_comm: # We have already got sharded K and V. + local_k = k.contiguous().view(-1, self.tp_k_head_num, self.head_dim) + local_v = v.contiguous().view(-1, self.tp_v_head_num, self.head_dim) + for i in range(sp_size): + if i == sp_rank: + kv_shards[i] = (local_k, local_v) + else: # reserve space for kv tensors received from other peers + kv_shards[i] = ( + torch.empty_like(local_k), + torch.empty_like(local_v), + ) + else: # We need to manually shard K and V. + for i in range(num_shards): + k_shard = k[sp_shard_size * i : sp_shard_size * (i + 1)] + v_shard = v[sp_shard_size * i : sp_shard_size * (i + 1)] + kv_shards[i] = ( + k_shard.contiguous().view(-1, self.tp_k_head_num, self.head_dim), + v_shard.contiguous().view(-1, self.tp_v_head_num, self.head_dim), + ) + local_k, local_v = kv_shards[sp_rank] + + # For communication + to_rank = sp_rank # which SP worker to send my sequence KV shard to. + from_rank = sp_rank # which SP worker to receive the sequence KV shard from. + sid = sp_rank # start from the worker's own shard + for itr in range(num_iters): + to_rank = (to_rank + 1) % sp_size + from_rank = (from_rank - 1) % sp_size + if need_comm: # Launch async communication operations + comm_reqs = self.launch_sp_comm_ops( + kv_shards[from_rank], + kv_shards[sp_rank], + from_rank, + to_rank, + sp_rank, + sp_size, + itr, + ) + q_shard = qs[sid] + k_shard, v_shard = kv_shards[sid] + # Self attention within the SP shard. + attn_wrapper = ( # Only the last SP shard needs a mask. + input_metadata.flashinfer_prefill_wrapper_sp_causal + if sid == sp_size - 1 + else input_metadata.flashinfer_prefill_wrapper_ragged + ) + o, s = attn_wrapper.forward_return_lse( + q_shard.contiguous().view(-1, self.tp_q_head_num, self.head_dim), + k_shard.contiguous().view(-1, self.tp_k_head_num, self.head_dim), + v_shard.contiguous().view(-1, self.tp_v_head_num, self.head_dim), + causal=True, + sm_scale=self.scaling, + logits_soft_cap=self.logit_cap, + ) + append_merge_shard(output_shards[sid], o, s) + # Cross SP shard attention. 
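+            # (illustrative) e.g. with SP=4 on sp_rank 1, `sid` visits shards in the
+            # order 1, 0, 3, 2 across the four iterations, so every causally required
+            # (Q shard, KV shard) pair is computed exactly once.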
+ # NOTE: below schedule is for load balancing. Basically, at iteration i, + # (i starting from 0), each SP worker will run i paged attentions. + for existing_sid in owned_sids: + if existing_sid == sid: + continue + # Due to the causal nature of the attention, swap pids if necessary. + i, j = ( + (existing_sid, sid) if existing_sid > sid else (sid, existing_sid) + ) + q_shard = qs[i] + k_shard, v_shard = kv_shards[j] + attn_wrapper = ( # Only the last SP shard needs a mask. + input_metadata.flashinfer_prefill_wrapper_sp_full + if i == sp_size - 1 + else input_metadata.flashinfer_prefill_wrapper_ragged + ) + o, s = attn_wrapper.forward_return_lse( + q_shard.contiguous().view(-1, self.tp_q_head_num, self.head_dim), + k_shard.contiguous().view(-1, self.tp_k_head_num, self.head_dim), + v_shard.contiguous().view(-1, self.tp_v_head_num, self.head_dim), + causal=False, + sm_scale=self.scaling, + logits_soft_cap=self.logit_cap, + ) + append_merge_shard(output_shards[i], o, s) + + if need_comm: # Wait for async communication to complete. + self.wait_sp_comm_ops(comm_reqs) + if sp_rank != from_rank: + owned_sids.append(from_rank) + sid = from_rank + + # Concat all output shards along the sequence dimension. + os = [o for shard_list in output_shards for o, _ in shard_list] + o = torch.cat(os, dim=0) + + self.store_kv_cache(local_k, local_v, input_metadata) + + if input_metadata.total_num_tokens >= global_config.layer_sync_threshold: + torch.cuda.synchronize() + + return o.view(-1, self.tp_q_head_num * self.head_dim) + + # TODO(yifan): check if flashinfer seq_parallel is broken after the rebase + def seq_parallel_decode_forward_flashinfer( + self, q, k, v, input_metadata: InputMetadata + ): + sp_size = input_metadata.sp_size + sp_rank = input_metadata.sp_rank + total_num_heads = self.tp_q_head_num * sp_size + + sp_offset = input_metadata.sp_local_token_offset + sp_len = input_metadata.sp_local_token_length + sp_slice = slice(sp_offset, sp_offset + sp_len) + cache_k = k[sp_slice] + cache_v = v[sp_slice] + self.store_kv_cache(cache_k, cache_v, input_metadata) + + # Convert Q back by gathering all TP heads. + q = q.contiguous().view(-1, self.tp_q_head_num, self.head_dim) + gathered_q = get_sp_group().all_gather(q.view(1, *q.shape), dim=0) + q = torch.empty_like(gathered_q).view(-1, total_num_heads, self.head_dim) + for i in range(sp_size): + idxs = _get_sequence_parallel_head_idxes( + total_num_heads, self.tp_k_head_num, i, sp_size + ) + q[:, idxs] = gathered_q[i] + + o, s = input_metadata.flashinfer_decode_wrapper.forward_return_lse( + q.contiguous().view(-1, total_num_heads, self.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), + sm_scale=self.scaling, + logits_soft_cap=self.logit_cap, + ) + + # TODO: in fact we can use all-to-all to gather the output and state here + # to collect only q head shards that are needed by the current SP worker. + # All-to-all will save communication and `merge_state` computation. + os = get_sp_group().all_gather(o.view(1, *o.shape), dim=0) + ss = get_sp_group().all_gather(s.view(1, *s.shape), dim=0) + for i in range(sp_size): + if i != sp_rank: + o, s = merge_state(os[i], ss[i], o, s) + + # TODO: consequently, if we use all-to-all rather than all-gather, we don't + # need to partition the output again along the head dimension. + # Partition the output again along the head dimension. 
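+        # (illustrative) e.g. with 8 total Q heads, 2 KV heads and SP=2:
+        #   _get_sequence_parallel_head_idxes(8, 2, 0, 2) == [0, 1, 4, 5]
+        #   _get_sequence_parallel_head_idxes(8, 2, 1, 2) == [2, 3, 6, 7]
+        # i.e. each SP rank keeps a contiguous slice of every GQA group.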
+ idxs = _get_sequence_parallel_head_idxes( + total_num_heads, self.tp_k_head_num, sp_rank, sp_size + ) + o = o[:, idxs] + + return o.view(-1, self.tp_q_head_num * self.head_dim) + def forward(self, q, k, v, input_metadata: InputMetadata): if k is not None: assert v is not None @@ -206,3 +473,15 @@ def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): input_metadata.token_to_kv_pool.set_kv_buffer( self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v ) + + +def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size): + group_size = total_num_heads // num_kv_heads + shard_num_heads = group_size // sp_size + + idxes = [ + group_size * i + sp_rank * shard_num_heads + j + for i in range(num_kv_heads) + for j in range(0, shard_num_heads) + ] + return idxes diff --git a/python/sglang/srt/layers/sp_linear.py b/python/sglang/srt/layers/sp_linear.py new file mode 100644 index 00000000000..f8757738aee --- /dev/null +++ b/python/sglang/srt/layers/sp_linear.py @@ -0,0 +1,503 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/linear.py#L1 +import logging +from typing import Dict, Iterable, Optional, Tuple + +import torch +from torch.nn.parameter import Parameter +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, + adjust_marlin_shard, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +from sglang.srt.layers.parallel_utils import ( + get_kv_tensor_model_parallel_rank, + get_kv_tensor_model_parallel_world_size, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, +) + +logger = logging.getLogger(__name__) + + +def adjust_bitsandbytes_shard( + param: Parameter, kv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str +) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = kv_offsets["total"] + orig_offset, orig_size = kv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + +def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): + """For fused modules (KV) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + kv_idxs = {"k": 0, "v": 1} + + if isinstance(shard_id, str): + shard_id = kv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +class QKVParallelLinear(torch.nn.Module): + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + # q projection can be naively tensor parallelized. 
However, to adapt to + # GQA, we need to manually partition q heads for sequence parallelism. + # See _get_sequence_parallel_head_idxes() for details. + self.q_proj = ColumnSeqParallelLinear( + hidden_size, + head_size, + total_num_heads, + total_num_kv_heads, + bias, + skip_bias_add, + params_dtype, + quant_config, + f"{prefix}.q_proj", + ) + # kv projection needs both tensor and sequence parallelization + self.kv_proj = KVSeqParallelLinear( + hidden_size, + head_size, + total_num_heads, + total_num_kv_heads, + bias, + skip_bias_add, + params_dtype, + quant_config, + f"{prefix}.kv_proj", + ) + self.hidden_size = hidden_size + self.kv_size = self.kv_proj.num_kv_heads * self.kv_proj.head_size + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, _ = self.q_proj(hidden_states) + kv, _ = self.kv_proj(hidden_states) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + return q, k, v + + +class ColumnSeqParallelLinear(ColumnParallelLinear): + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + self.hidden_size = hidden_size + self.head_size = head_size + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = total_num_heads + self.num_heads = divide(total_num_heads, tp_size) + # num_kv_heads is used for tracking the number of groups in GQA. + kv_tp_size = get_kv_tensor_model_parallel_world_size() + if kv_tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(kv_tp_size, self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, kv_tp_size) + self.num_kv_head_replicas = 1 + + input_size = self.hidden_size + # NOTE: here we use total_num_heads to make the parent class happy because + # it expects pure tensor parallelism along the heads dimension. output_size + # here is the total size of all TP and SP workers. 
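+        # (illustrative) e.g. 64 Q heads, 8 KV heads, head_size 128, TP=8, SP=2: the
+        # parent class leaves 8 heads' worth of rows per rank; weight_loader below
+        # first narrows to the 16-head block of its KV-TP rank (KV-TP size 4), then
+        # keeps this SP rank's half of each GQA group (8 heads).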
+ output_size = self.total_num_heads * self.head_size + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + ): + kv_tp_rank = get_kv_tensor_model_parallel_rank() + kv_tp_size = get_kv_tensor_model_parallel_world_size() + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + # Load TP weight shard + tp_shard_size = shard_size * sp_size + start_idx = kv_tp_rank * tp_shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, tp_shard_size) + # Load SP weight shard + tp_num_heads = self.total_num_heads // kv_tp_size + idxes = torch.tensor( + _get_sequence_parallel_head_idxes( + tp_num_heads, self.num_kv_heads, sp_rank, sp_size + ), + dtype=torch.int32, + ) + weight_shape = loaded_weight.shape + tp_shard_shape = _reshape_dimension( + weight_shape, output_dim, [tp_num_heads, self.head_size] + ) + sp_shard_shape = _reshape_dimension(weight_shape, output_dim, [shard_size]) + loaded_weight = ( + loaded_weight.reshape(tp_shard_shape) + .index_select(output_dim, idxes) + .contiguous() + .view(sp_shard_shape) + ) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/linear.py#L422 +class KVSeqParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's KV transformation. + + Linear layers for the linear transformation of the key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. 
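+
+    Example (illustrative): with 8 KV heads, TP size 8 and SP size 2, the KV-TP
+    size is 4, so each KV-TP rank holds 2 KV heads; the two GPUs of an SP group
+    load the same KV projection weights and shard the sequence between them.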
+ """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + self.hidden_size = hidden_size + self.head_size = head_size + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + kv_tp_size = get_kv_tensor_model_parallel_world_size() + if kv_tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(kv_tp_size, self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, kv_tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + # NOTE: here we use tp_size to make the parent class happy because it + # expects pure tensor parallelism along the num_heads dimension. + tp_size = get_tensor_model_parallel_world_size() + output_size = 2 * self.num_kv_heads * tp_size * self.head_size + self.output_sizes = [ + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). + if get_sequence_parallel_world_size() > 1: + raise NotImplementedError( + "Fused weight loading is not supported in SP." + ) + if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ( + "k", + 0, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + self.total_num_kv_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + kv_tp_rank = get_kv_tensor_model_parallel_rank() + assert loaded_shard_id in ["k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "k": + shard_offset = 0 + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = self.num_kv_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
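+            # (illustrative) e.g. 4-bit values packed into int32 give pack_factor == 8,
+            # so a 1024-wide shard along the packed dim maps to 128 packed elements.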
+ packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + orig_kv_offsets = { + "k": ( + 0, + self.num_kv_heads * self.head_size, + ), + "v": ( + self.num_kv_heads * self.head_size, + self.num_kv_heads * self.head_size, + ), + "total": ( + 2 * self.num_kv_heads * self.head_size, + 0, + ), + } + shard_size, shard_offset = adjust_bitsandbytes_shard( + param, orig_kv_offsets, loaded_shard_id + ) + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + shard_id = kv_tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions." + ) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowSeqParallelLinear(RowParallelLinear): + """TODO: add doc string.""" + + def __init__( + self, + input_size: int, + output_size: int, + total_num_heads: int, + num_kv_heads: int, + head_dim: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + input_size, + output_size, + bias, + input_is_parallel, + skip_bias_add, + params_dtype, + reduce_results, + quant_config, + prefix, + ) + self.total_num_heads = total_num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + kv_tp_rank = get_kv_tensor_model_parallel_rank() + kv_tp_size = get_kv_tensor_model_parallel_world_size() + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + + input_dim = getattr(param, "input_dim", None) + param_data = param.data + if input_dim is not None: + shard_size = param_data.shape[input_dim] + # Load TP weight shard + tp_shard_size = shard_size * sp_size + start_idx = kv_tp_rank * tp_shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, tp_shard_size) + # Load SP weight shard + tp_num_heads = self.total_num_heads // kv_tp_size + idxes = torch.tensor( + _get_sequence_parallel_head_idxes( + tp_num_heads, self.num_kv_heads, sp_rank, sp_size + ) + ) + weight_shape = loaded_weight.shape + tp_shard_shape = _reshape_dimension( + weight_shape, input_dim, [tp_num_heads, self.head_dim] + ) + sp_shard_shape = _reshape_dimension(weight_shape, input_dim, [shard_size]) + loaded_weight = ( + loaded_weight.reshape(tp_shard_shape) + 
.index_select(input_dim, idxes) + .contiguous() + .view(sp_shard_shape) + ) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size): + group_size = total_num_heads // num_kv_heads + shard_num_heads = group_size // sp_size + + idxes = [ + group_size * i + sp_rank * shard_num_heads + j + for i in range(num_kv_heads) + for j in range(0, shard_num_heads) + ] + return idxes + + +def _reshape_dimension(shape: Tuple[int], dim_idx: int, new_dims: Iterable[int]): + if isinstance(new_dims, int): + new_dims = (new_dims,) + if not isinstance(shape, tuple): + raise TypeError("shape must be a tuple") + if not isinstance(new_dims, (list, tuple)): + raise TypeError("new_dims must be a list or a tuple") + if dim_idx < 0 or dim_idx >= len(shape): + raise IndexError("dim_idx out of range") + + new_shape = shape[:dim_idx] + tuple(new_dims) + shape[dim_idx + 1 :] + return new_shape diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c80cf2e2723..48c1dcb0e20 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -21,11 +21,18 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Union +import numpy as np import torch from sglang.global_config import global_config from sglang.srt.constrained import RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap +from sglang.srt.managers.seq_parallel_layout import ( + seq_parallel_decode_indices, + seq_parallel_input_ids_decode, + seq_parallel_input_ids_extend, + seq_parallel_local_len_extend, +) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -348,8 +355,22 @@ class ScheduleBatch: return_logprob: bool = False top_logprobs_nums: List[int] = None + # Sequence Parallel params + sp_size: int = None + sp_rank: int = None + prefill_extend_lens: np.ndarray = None + sp_decode_local_lens: np.ndarray = None + @classmethod - def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): + def init_new( + cls, + reqs, + req_to_token_pool, + token_to_kv_pool, + tree_cache, + sp_size: int = 1, + sp_rank: int = 0, + ): return_logprob = any(req.return_logprob for req in reqs) return cls( @@ -358,6 +379,8 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): token_to_kv_pool=token_to_kv_pool, tree_cache=tree_cache, return_logprob=return_logprob, + sp_size=sp_size, + sp_rank=sp_rank, ) def batch_size(self): @@ -402,8 +425,24 @@ def prepare_for_extend(self, vocab_size: int): extend_num_tokens = sum(len(ids) for ids in input_ids) seq_lens = [] + if self.sp_size == 1: + flatten_input_ids = sum(input_ids, []) + else: + flatten_input_ids = seq_parallel_input_ids_extend( + input_ids, self.sp_size, bs + ) + # Allocate memory req_pool_indices_cpu = self.alloc_req_slots(bs) + if self.sp_size > 1: + ext_lens = np.asarray( + [len(req.fill_ids) - len(req.prefix_indices) for req in reqs] + ) + extend_local_token_nums = seq_parallel_local_len_extend( + self.sp_rank, self.sp_size, ext_lens + ) + self.prefill_extend_lens = ext_lens + 
extend_num_tokens = int(np.sum(extend_local_token_nums)) out_cache_loc = self.alloc_token_slots(extend_num_tokens) pt = 0 @@ -418,6 +457,11 @@ def prepare_for_extend(self, vocab_size: int): :pre_len ] = req.prefix_indices + if self.sp_size > 1: + ext_len = extend_local_token_nums[i] + # Prefix are stored elsewhere and not affected by the layout of + # **this** request. + seq_len = pre_len + ext_len self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( out_cache_loc[pt : pt + ext_len] ) @@ -425,7 +469,7 @@ def prepare_for_extend(self, vocab_size: int): # Set fields with torch.device("cuda"): - self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32) + self.input_ids = torch.tensor(flatten_input_ids, dtype=torch.int32) self.req_pool_indices = torch.tensor(req_pool_indices_cpu) self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32) self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64) @@ -636,13 +680,38 @@ def prepare_for_decode(self, input_ids=None): self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") self.seq_lens.add_(1) + if self.sp_size > 1: + seq_lens_cpu = self.seq_lens.cpu().numpy() + input_ids = seq_parallel_input_ids_decode( + input_ids, self.sp_size, seq_lens_cpu + ) + self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") + # Alloc mem bs = self.batch_size() + if self.sp_size > 1: + sp_local_indices = seq_parallel_decode_indices( + self.sp_rank, self.sp_size, seq_lens_cpu + ) + bs = len(sp_local_indices) + self.out_cache_loc = self.alloc_token_slots(bs) - self.req_to_token_pool.req_to_token[ - self.req_pool_indices, self.seq_lens - 1 - ] = self.out_cache_loc + if self.sp_size > 1: + # With SP, reqs are partitioned across SP workers so we need to use + # decode_local_lens instead of seq_lens when preparing KV cache. 
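+            # (illustrative) e.g. with SP=2 and (already incremented) seq_lens [5, 6],
+            # request 0 (5 % 2 == 1) stores its new token on SP rank 1 and request 1 on
+            # SP rank 0, so only the owned requests' req_to_token entries are updated
+            # below.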
+ bs = self.batch_size() + sp_decode_local_lens = self._sp_decode_local_len(range(bs)) + self.sp_decode_local_lens = torch.from_numpy(sp_decode_local_lens) + local_req_indices = self.req_pool_indices[sp_local_indices] + local_lens_cpu = sp_decode_local_lens[sp_local_indices] + self.req_to_token_pool.req_to_token[ + local_req_indices, local_lens_cpu - 1 + ] = self.out_cache_loc + else: + self.req_to_token_pool.req_to_token[ + self.req_pool_indices, self.seq_lens - 1 + ] = self.out_cache_loc self.sampling_info.update_regex_vocab_mask(self) @@ -665,6 +734,8 @@ def filter_batch(self, unfinished_indices: List[int]): self.out_cache_loc = None self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices] self.return_logprob = any(req.return_logprob for req in self.reqs) + if self.sp_size > 1: + self.prefill_extend_lens = self.prefill_extend_lens[new_indices] self.sampling_info.filter(unfinished_indices, new_indices) @@ -686,6 +757,10 @@ def merge(self, other: "ScheduleBatch"): self.out_cache_loc = None self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) + if self.sp_size > 1: + self.prefill_extend_lens = np.concatenate( + [self.prefill_extend_lens, other.prefill_extend_lens] + ) def check_sample_results(self, sample_output: SampleOutput): if not torch.all(sample_output.success): @@ -701,3 +776,36 @@ def check_sample_results(self, sample_output: SampleOutput): sample_output.batch_next_token_ids = batch_next_token_ids return sample_output.batch_next_token_ids + + def _sp_decode_local_len(self, local_req_indices: np.ndarray): + """ + Args: + local_req_indices(np.ndarray): 1D int array indexing selected + requests that stores KV-Cache on this SP rank. + Returns: + local_len(np.ndarray): 1D int array, describing the local KV cache + length on this SP rank, for selected request indices. + """ + sp_size = self.sp_size + + extend_lens = self.prefill_extend_lens[local_req_indices] + cur_lens = self.seq_lens.cpu().numpy()[local_req_indices] + decode_lens = cur_lens - extend_lens + + extend_chunk_size = np.ceil(extend_lens / sp_size).astype(np.int32) + if self.sp_rank != sp_size - 1: + extend_size = extend_chunk_size + else: + extend_size = extend_lens - extend_chunk_size * (sp_size - 1) + # note that sp_len (as well as decode_lens) already increased 1. + # NOTE: for decoding tokens, assume there's no prefix, they are located: + # dec token 0 = all token [extend_lens] = stored at extend_lens % sp + # decode token i = stored at (extend_lens + i) % sp + # Hence, for the remainder tokens, they are stored at extend_lens % sp, + # extend_lens % sp + 1, ... 
+ # For example, if sp = 4, extend lens = 6, the first decode remainder + # token is at rank 3 (7 % 4) + decode_extra_tok_offset = (self.sp_rank - extend_lens - 1) % sp_size + decode_extra_tok = decode_extra_tok_offset < (decode_lens % sp_size) + decode_size = decode_lens // sp_size + decode_extra_tok + return extend_size + decode_size diff --git a/python/sglang/srt/managers/seq_parallel_layout.py b/python/sglang/srt/managers/seq_parallel_layout.py new file mode 100644 index 00000000000..279e77f7c4a --- /dev/null +++ b/python/sglang/srt/managers/seq_parallel_layout.py @@ -0,0 +1,302 @@ +"""Util functions for sequence parallel layout and runtime metadata.""" + +import itertools +from typing import TYPE_CHECKING, Sequence, Union + +import numpy as np +import torch + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.model_executor.forward_batch_info import InputMetadata + from sglang.srt.model_executor.model_runner import ModelRunner + + +#### Offset of a sequence parallel shard under the sequence parallel layout. +def _seq_parallel_offset_extend(sp_rank, sp_size, extend_seq_lens: np.ndarray): + return np.sum(np.ceil(extend_seq_lens / sp_size).astype(np.int32)) * sp_rank + + +def _seq_parallel_offset_decode(sp_rank, sp_size, seq_lens: np.ndarray): + return np.sum((seq_lens % sp_size) < sp_rank) + + +#### Indices from sequence parallel layout to normal layout +def _sp_to_normal_indices_extend(sp_size, extend_seq_lens: np.ndarray): + """ + Indices from the Sequence Parallel layout (padded) to the normal layout. + """ + sp_seq_lens = np.ceil(extend_seq_lens / sp_size).astype(np.int32) + sp_len = np.sum(sp_seq_lens) + sp_seq_offset = np.concatenate( + [np.asarray([0], dtype=np.int32), np.cumsum(sp_seq_lens[:-1])] + ) + sp_arange = np.arange(sp_size).reshape(-1, 1) + indices = [] + for i in range(len(extend_seq_lens)): + sp_idx = np.arange(sp_seq_lens[i]).reshape(1, -1).repeat(sp_size, axis=0) + sp_idx = (sp_idx + sp_seq_offset[i] + sp_len * sp_arange).reshape(-1) + sp_idx = sp_idx[: extend_seq_lens[i]] + indices.append(sp_idx) + indices = np.concatenate(indices) + return indices + + +def _sp_to_normal_indices_decode(sp_size, seq_lens: np.ndarray): + """ + Indices from the Sequence Parallel layout (padded) to the normal layout. + """ + req_sp_rank = seq_lens % sp_size + sp_rank_size = [np.sum(req_sp_rank == r) for r in range(sp_size)] + req_sp_offset = np.cumsum(np.asarray([0] + sp_rank_size[:-1])) + req_sp_offset = req_sp_offset[req_sp_rank] + for sp_rank in range(sp_size): + local_reqs = req_sp_rank == sp_rank + req_sp_index = np.cumsum(local_reqs) - 1 + req_sp_offset += req_sp_index * local_reqs # mask out reqs not here. + return req_sp_offset + + +#### From normal layout to sequence parallel layout. 
Only for debug purpose +def _debug_normal_to_sp_indices_decode(sp_size, seq_lens): + """(Debug only) Indices from normal layout to the SP layout (padded).""" + indices = [ + seq_parallel_decode_indices(sp_rank, sp_size, seq_lens) + for sp_rank in range(sp_size) + ] + indices = [(np.arange(len(idxs)), idxs) for idxs in indices] + return indices + + +def _debug_normal_to_sp_indices_extend(sp_size, seq_lens): + """(Debug only) Indices from normal layout to the SP layout (padded).""" + indices = [] + sp_seq_lens = np.ceil(seq_lens / sp_size).astype(np.int32) + seq_offset = np.concatenate( + [np.asarray([0], dtype=np.int32), np.cumsum(seq_lens[:-1])] + ) + sp_seq_offset = np.concatenate( + [np.asarray([0], dtype=np.int32), np.cumsum(sp_seq_lens[:-1])] + ) + for sp_rank in range(sp_size): + start_idx = seq_offset + sp_seq_lens * sp_rank + end_idx = np.minimum(seq_offset + sp_seq_lens * (sp_rank + 1), seq_lens) + normal_layout_idx = np.concatenate( + [np.arange(start_idx[i], end_idx[i]) for i in range(len(seq_lens))] + ) + if sp_rank == sp_size - 1: + length = end_idx - start_idx + target_layout_idx = np.concatenate( + [ + np.arange(sp_seq_offset[i], sp_seq_offset[i] + length[i]) + for i in range(len(seq_lens)) + ] + ) + else: + target_layout_idx = np.arange(len(normal_layout_idx)) + indices.append((target_layout_idx, normal_layout_idx)) + return indices + + +def _debug_normal_to_sp(indices, output_tensor, tensor): + """ + Use the indices generated above to translate from a normal layout to a + SP layout (padded). Due to the padding, `output_tensor`'s shape is different + from the input `tensor`'s. + """ + for idxs in indices: + output_tensor[idxs] = tensor + output_tensor = output_tensor.contiguous() + return output_tensor + + +#### Padding +def seq_parallel_pad_zeros( + indices: torch.Tensor, seq_lens, sp_size: int, only_last_shard: bool = False +): + """ + Add padding zeros to SP-layout indices (must be a 1D tensor) so that the last + SP shard will have its sequences padded after each sequence and all SP shards + can have the same length. + + This function is used to (1) adjust the positions tensor to align input_ids with + their positions during positional encoding and (2) adjust the output cache location + to write KV cache of padded tokens to slot 0 (reserved for dummy output). + """ + sp_seq_lens = np.ceil(seq_lens / sp_size).astype(np.int32) + last_sp_seq_lens = seq_lens - sp_seq_lens * (sp_size - 1) + padded_num_tokens = np.sum(sp_seq_lens).astype(np.int32) + if only_last_shard: + padded_indices = torch.zeros( + padded_num_tokens, dtype=indices.dtype, device=indices.device + ) + padded_stt = stt = 0 + else: + padded_indices = torch.zeros( + sp_size * padded_num_tokens, dtype=indices.dtype, device=indices.device + ) + # All non-last shards do not need padding and hence can be copied. 
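+        # (illustrative) e.g. SP=2, seq_lens [3, 5]: sp_seq_lens = [2, 3], so rank 0's
+        # 5 indices are copied verbatim, while the last shard keeps [1, 2] real indices
+        # per request and pads each request to [2, 3] entries with zeros.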
+ padded_stt = padded_num_tokens * (sp_size - 1) + stt = padded_stt + padded_indices[:padded_stt] = indices[:stt] + + bs = seq_lens.size + for i in range(bs): + padded_end = padded_stt + sp_seq_lens[i] + end = stt + last_sp_seq_lens[i] + padded_indices[padded_stt : padded_stt + last_sp_seq_lens[i]] = indices[stt:end] + padded_stt = padded_end + stt = end + return padded_indices + + +def _get_num_padding_tokens(sp_size, extend_seq_lens: np.ndarray): + """Get the number of tokens padded for SP.""" + padded_size = np.ceil(extend_seq_lens / sp_size).astype(np.int32) + return sp_size * padded_size - extend_seq_lens + + +#### Get length/indices of sequence parallel local tokens within a batch +def seq_parallel_local_len_extend( + sp_rank, sp_size, extend_seq_lens: Union[int, np.ndarray] +): + """Get the number of tokens in this SP. Padding is not considered.""" + padded_size = np.ceil(extend_seq_lens / sp_size).astype(np.int32) + return ( + padded_size + if sp_rank != sp_size - 1 + else extend_seq_lens - (sp_size - 1) * padded_size + ) + + +def seq_parallel_extend_local_token_slice(sp_rank, sp_size, seq_len: int): + """Get the SP local slice for a single request's extended input ids.""" + start = int(np.ceil(seq_len / sp_size) * sp_rank) + length = seq_parallel_local_len_extend(sp_rank, sp_size, seq_len) + return slice(start, start + length) + + +def seq_parallel_decode_indices(sp_rank, sp_size, seq_lens: np.ndarray): + """Get Indices from the normal layout to the sequence parallel layout.""" + return np.nonzero((seq_lens % sp_size) == sp_rank)[0] + + +#### Transpose to sequence parallel layout +def seq_parallel_input_ids_extend( + input_ids: Sequence[Sequence[int]], sp_size: int, bs: int +): + # Note: The flatten input ids with Sequence Parallel is in form of: + # [req_0_sp_0, req_1_sp_0, ... req_n_sp_0, + # req_0_sp_1, req_1_sp_1, ..., req_n_sp_1, + # ... + # req_0_sp_m, req_0_padding, req_1_sp_m, req_1_padding, ...] + # ] + # The padding is for collection primitives which needs each candidate to + # have the same size. Since we don't expect too many requests in SP, + # the extra compute caused by this is affordable. + flatten_input_ids = [[] for _ in range(sp_size)] + num_padding_tokens = _get_num_padding_tokens( + sp_size, np.asarray([len(ids) for ids in input_ids]) + ) + for i in range(bs): + for sp_rank in range(sp_size): + ids = input_ids[i] + local_slice = seq_parallel_extend_local_token_slice( + sp_rank, sp_size, len(ids) + ) + flatten_input_ids[sp_rank].extend(ids[local_slice]) + flatten_input_ids[-1].extend([0] * num_padding_tokens[i]) + flatten_input_ids = list(itertools.chain(*flatten_input_ids)) + return flatten_input_ids + + +def seq_parallel_input_ids_decode( + input_ids: Sequence[int], sp_size: int, seq_lens: np.ndarray +): + input_indices_sp = [[] for _ in range(sp_size)] + # NOTE: in the extend phase, we evenly do sequence partition on extended + # tokens (extend_len). However, since prefix lens is cleaned, we instead + # use the whole sequence length (seq_lens) for the round-robin KV-cache. 
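+    # (illustrative) e.g. SP=2, seq_lens [5, 6, 7]: rank 0 owns request 1 (6 % 2 == 0)
+    # and rank 1 owns requests 0 and 2, so the flattened SP-layout ids are
+    # input_ids[[1, 0, 2]].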
+ for sp_rank in range(sp_size): + indices = seq_parallel_decode_indices(sp_rank, sp_size, seq_lens) + input_indices_sp[sp_rank].extend(indices) + flatten_input_indices = list(itertools.chain(*input_indices_sp)) + flatten_input_ids = np.asarray(input_ids)[flatten_input_indices] + return flatten_input_ids + + +#### Handle metadata +def init_sequence_parallel_args( + model_runner: "ModelRunner", batch: "ScheduleBatch", forward_mode +): + from sglang.srt.model_executor.forward_batch_info import ForwardMode + + sp_rank = model_runner.sp_rank + sp_size = model_runner.sp_size + seq_lens = batch.seq_lens + extend_seq_lens_cpu = batch.prefill_extend_lens + num_tokens = batch.input_ids.numel() + if sp_size > 1: + # During the runtime, we should use positions[local_token_indices] + # to get positions for each SP shard. + if forward_mode == ForwardMode.DECODE: + seq_lens_cpu = seq_lens.cpu().numpy() + sp_to_normal_indices = _sp_to_normal_indices_decode(sp_size, seq_lens_cpu) + sp_local_token_length = seq_parallel_decode_indices( + sp_rank, sp_size, seq_lens_cpu + ).size + sp_local_token_offset = _seq_parallel_offset_decode( + sp_rank, sp_size, seq_lens_cpu + ) + # Convert positions to SP layout and add padding zeros. + normal_to_sp_indices = np.argsort(sp_to_normal_indices) + # positions = positions[normal_to_sp_indices] + else: + sp_to_normal_indices = _sp_to_normal_indices_extend( + sp_size, extend_seq_lens_cpu + ) + sp_local_token_length = seq_parallel_local_len_extend( + sp_rank, sp_size, extend_seq_lens_cpu + ) + sp_local_token_offset = _seq_parallel_offset_extend( + sp_rank, sp_size, extend_seq_lens_cpu + ) + # Convert positions to SP layout and add padding zeros. + normal_to_sp_indices = np.argsort(sp_to_normal_indices) + # positions = positions[normal_to_sp_indices] + # positions = seq_parallel_pad_zeros(positions, extend_seq_lens_cpu, sp_size) + # Add padding zeros to out_cache_loc and write KV of padded tokens that may + # exist in the last SP shard to slot 0 (reserved for dummy output). 
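+        # (illustrative) e.g. SP=4 and an extend length of 10: the last rank owns only
+        # 1 real token of the 3-token shard width, so 2 pad entries pointing at slot 0
+        # are appended to its out_cache_loc for that request.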
+ if sp_rank == sp_size - 1: + batch.out_cache_loc = seq_parallel_pad_zeros( + batch.out_cache_loc, extend_seq_lens_cpu, sp_size, True + ) + else: + sp_to_normal_indices = np.arange(num_tokens) + normal_to_sp_indices = np.arange(num_tokens) + sp_local_token_length = num_tokens + sp_local_token_offset = 0 + + _debug_normal_to_sp_metadata = None + if False and sp_size > 1: + if forward_mode == ForwardMode.DECODE: + _debug_normal_to_sp_metadata = _debug_normal_to_sp_indices_decode( + sp_size, seq_lens_cpu + ) + else: + _debug_normal_to_sp_metadata = _debug_normal_to_sp_indices_extend( + sp_size, extend_seq_lens_cpu + ) + + init_args = { + "sp_size": sp_size, + "sp_rank": sp_rank, + "sp_to_normal_indices": sp_to_normal_indices, + "sp_local_token_length": sp_local_token_length, + "sp_local_token_offset": sp_local_token_offset, + "_debug_normal_to_sp_metadata": _debug_normal_to_sp_metadata, + "flashinfer_prefill_wrapper_sp_full": model_runner.flashinfer_prefill_wrapper_sp_full, + "flashinfer_prefill_wrapper_sp_causal": model_runner.flashinfer_prefill_wrapper_sp_causal, + } + aux_args = {"normal_to_sp_indices": normal_to_sp_indices} + return init_args, aux_args diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 4459213b02f..38d73910726 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -194,6 +194,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable): seq_lens = self.seq_lens[:bs] position_ids_offsets = self.position_ids_offsets[:bs] out_cache_loc = self.out_cache_loc[:bs] + # TODO (yonghao): fix parameter initialization below. + normal_to_sp_indices = None + sp_decode_local_lens = torch.ceil(seq_lens / self.model_runner.sp_size).to( + torch.int32 + ) + sp_local_token_offset = 0 + sp_local_token_length = torch.sum(sp_decode_local_lens).to(torch.int32) # FlashInfer inputs if not _grouped_size_compiled_for_decode_kernels( @@ -237,6 +244,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable): seq_lens, None, flashinfer_decode_wrapper, + normal_to_sp_indices=normal_to_sp_indices, + sp_decode_local_lens=sp_decode_local_lens, ) # Run and capture @@ -254,6 +263,10 @@ def run_once(): top_logprobs_nums=0, positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), flashinfer_decode_wrapper=flashinfer_decode_wrapper, + sp_rank=self.model_runner.sp_rank, + sp_size=self.model_runner.sp_size, + sp_local_token_offset=sp_local_token_offset, + sp_local_token_length=sp_local_token_length, ) return forward(input_ids, input_metadata.positions, input_metadata) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a443b113d44..824f0586015 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -26,6 +26,11 @@ import triton.language as tl from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.seq_parallel_layout import ( + init_sequence_parallel_args, + seq_parallel_local_len_extend, + seq_parallel_pad_zeros, +) from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool if TYPE_CHECKING: @@ -90,6 +95,18 @@ class InputMetadata: flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None flashinfer_use_ragged: bool = False + # NOTE: for sequence 
parallel, we need dedicated kernels for cross-shard attn. + # Especially, we need custom masks for the last SP shard which may contain padding tokens. + flashinfer_prefill_wrapper_sp_full: "BatchPrefillWithRaggedKVCacheWrapper" = None + flashinfer_prefill_wrapper_sp_causal: "BatchPrefillWithRaggedKVCacheWrapper" = None + + # For Sequence Parallel + sp_rank: int = None + sp_size: int = None + sp_to_normal_indices: np.ndarray = None + sp_local_token_length: int = None + sp_local_token_offset: int = None + _debug_normal_to_sp_metadata: Optional[List[np.ndarray]] = None def init_multimuldal_info(self, batch: ScheduleBatch): reqs = batch.reqs @@ -97,7 +114,7 @@ def init_multimuldal_info(self, batch: ScheduleBatch): self.image_sizes = [r.image_sizes for r in reqs] self.image_offsets = [r.image_offsets for r in reqs] - def compute_positions(self, batch: ScheduleBatch): + def compute_positions(self, batch: ScheduleBatch, normal_to_sp_indices): position_ids_offsets = batch.position_ids_offsets if self.forward_mode == ForwardMode.DECODE: @@ -137,6 +154,9 @@ def compute_positions(self, batch: ScheduleBatch): # Positions should be in long type self.positions = self.positions.to(torch.int64) + update_positions_for_seq_parallel( + self, normal_to_sp_indices, batch.prefill_extend_lens + ) def compute_extend_infos(self, batch: ScheduleBatch): if self.forward_mode == ForwardMode.DECODE: @@ -173,6 +193,9 @@ def from_schedule_batch( batch: ScheduleBatch, forward_mode: ForwardMode, ): + sp_args, aux_args = init_sequence_parallel_args( + model_runner, batch, forward_mode + ) ret = cls( forward_mode=forward_mode, sampling_info=batch.sampling_info, @@ -184,11 +207,12 @@ def from_schedule_batch( out_cache_loc=batch.out_cache_loc, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, + **sp_args, ) ret.sampling_info.prepare_penalties() - ret.compute_positions(batch) + ret.compute_positions(batch, aux_args["normal_to_sp_indices"]) ret.compute_extend_infos(batch) @@ -208,12 +232,17 @@ def from_schedule_batch( if not model_runner.server_args.disable_flashinfer: if ( forward_mode != ForwardMode.DECODE - and int(torch.sum(ret.seq_lens)) > 4096 + and (int(torch.sum(ret.seq_lens)) > 4096 or ret.sp_size > 1) and model_runner.sliding_window_size is None ): + # NOTE: SP requires the ragged kernel regardless of the sequence length. 
flashinfer_use_ragged = True ret.init_flashinfer_handlers( - model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged + model_runner, + batch.prefix_lens_cpu, + flashinfer_use_ragged, + aux_args["normal_to_sp_indices"], + batch.sp_decode_local_lens, ) return ret @@ -236,6 +265,8 @@ def init_flashinfer_handlers( model_runner, prefix_lens_cpu, flashinfer_use_ragged, + normal_to_sp_indices, + sp_decode_local_lens, ): if self.forward_mode == ForwardMode.DECODE: prefix_lens = None @@ -249,6 +280,8 @@ def init_flashinfer_handlers( self.seq_lens, prefix_lens, flashinfer_use_ragged=flashinfer_use_ragged, + normal_to_sp_indices=normal_to_sp_indices, + sp_decode_local_lens=sp_decode_local_lens, ) ( @@ -308,10 +341,16 @@ def update_flashinfer_indices( prefix_lens, flashinfer_decode_wrapper=None, flashinfer_use_ragged=False, + normal_to_sp_indices=None, + sp_decode_local_lens=None, ): """Init auxiliary variables for FlashInfer attention backend.""" num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size - num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) + # NOTE (yifan): we partitioned K and V along both TP and SP dimensions. + # And here tp_size represents KV-TP size * SP size. + num_kv_heads = model_runner.model_config.get_num_kv_heads( + model_runner.tp_size // model_runner.sp_size + ) head_dim = model_runner.model_config.head_dim batch_size = len(req_pool_indices) @@ -321,6 +360,28 @@ def update_flashinfer_indices( else: paged_kernel_lens = seq_lens + sp_size = model_runner.sp_size + if forward_mode == ForwardMode.DECODE: + # With SP, reqs may have been reordered so we track them here. + if normal_to_sp_indices is not None: + req_ids = normal_to_sp_indices.tolist() + else: + req_ids = list(range(batch_size)) + paged_kernel_lens = seq_lens if sp_size == 1 else sp_decode_local_lens + else: + extend_lens = seq_lens - prefix_lens + # With SP, we use different kernels for sequences that are not evenly partitioned + # across SP workers. Here seq_lens works for most SP workers that do not need + # masks, and we initiaize kernels with masks separately below. + seq_lens = torch.ceil(seq_lens / sp_size).to(torch.int32) + prefix_lens = torch.ceil(prefix_lens / sp_size).to(torch.int32) + req_ids = list(range(batch_size)) + + if sp_size > 1: + req_pool_indices = req_pool_indices[req_ids].contiguous() + paged_kernel_lens = paged_kernel_lens[req_ids].contiguous() + paged_kernel_lens = paged_kernel_lens.to(req_pool_indices.device) + kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) @@ -338,6 +399,9 @@ def update_flashinfer_indices( kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") if forward_mode == ForwardMode.DECODE: + # For decode, we replicate the current token across SP workers and hence + # each SP worker will have all q heads. + num_qo_heads *= model_runner.sp_size # CUDA graph uses different flashinfer_decode_wrapper if flashinfer_decode_wrapper is None: flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper @@ -381,7 +445,63 @@ def update_flashinfer_indices( head_dim, 1, ) + if ( + sp_size > 1 and forward_mode != ForwardMode.DECODE + ): # Sequence parallel enabled, initialize SP kernels with custom masks. + # NOTE (yifan): here we assume that when sequence parallel is enabled, + # prefix_lens are always 0s, and we will use flashinfer paged attn kernel + # for cross-SP-shard attn computation. 
If later prefix_lens can be non-0s, ( + # e.g., extend phases with SP), we will need a dedicate paged attn kernel + # wrapper for cross-SP-shard attn. + if torch.sum(prefix_lens) != 0: + raise ValueError( + "Prefix caching with sequence parallelism is not supported." + ) + + # Prepare masks. + sp_size = sp_size + extend_lens_cpu = extend_lens.cpu().numpy() + padded_extend_lens = seq_parallel_local_len_extend( + 0, sp_size, extend_lens_cpu + ) + last_extend_lens = seq_parallel_local_len_extend( + sp_size - 1, sp_size, extend_lens_cpu + ) + qo_len = (seq_lens - prefix_lens).cpu().tolist() + full_mask_arr = [] + causal_mask_arr = [] + for i in range(batch_size): + full_mask_i = torch.full((qo_len[i], qo_len[i]), False, device="cuda") + full_mask_i[: last_extend_lens[i], : padded_extend_lens[i]] = True + full_mask_arr.append(full_mask_i.flatten()) + causal_mask_i = torch.tril(full_mask_i, diagonal=0) + causal_mask_arr.append(causal_mask_i.flatten()) + full_mask = torch.cat(full_mask_arr, dim=0) + causal_mask = torch.cat(causal_mask_arr, dim=0) + + # Cross-SP-shard extend part -- masked for the last SP shard which may have + # padding tokens. For the othe shards, we can simply use the ragged kernel. + model_runner.flashinfer_prefill_wrapper_sp_causal.end_forward() + model_runner.flashinfer_prefill_wrapper_sp_causal.begin_forward( + qo_indptr, + qo_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + custom_mask=causal_mask, + ) + + model_runner.flashinfer_prefill_wrapper_sp_full.end_forward() + model_runner.flashinfer_prefill_wrapper_sp_full.begin_forward( + qo_indptr, + qo_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + custom_mask=full_mask, + ) else: + assert model_runner.sp_size == 1, "SP with sliding window not supported" # window attention use paged only kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") for wrapper_id in range(2): @@ -451,3 +571,20 @@ def update_flashinfer_indices( head_dim, 1, ) + + +def update_positions_for_seq_parallel( + input_metadata: InputMetadata, normal_to_sp_indices, extend_seq_lens +): + sp_size = input_metadata.sp_size + if sp_size == 1: + return + + positions = input_metadata.positions + + if input_metadata.forward_mode == ForwardMode.DECODE: + positions = positions[normal_to_sp_indices] + else: + positions = positions[normal_to_sp_indices] + positions = seq_parallel_pad_zeros(positions, extend_seq_lens, sp_size) + input_metadata.positions = positions diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3d3e0cde9d1..8fff310aaf0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -36,7 +36,6 @@ from vllm.distributed import ( get_tp_group, init_distributed_environment, - initialize_model_parallel, set_custom_all_reduce, ) from vllm.distributed.parallel_state import in_the_same_node_as @@ -45,6 +44,7 @@ from sglang.global_config import global_config from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.parallel_utils import initialize_model_parallel from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( @@ -78,6 +78,8 @@ def __init__( tp_size: int, nccl_port: int, server_args: ServerArgs, + sp_rank: int = 0, + sp_size: int = 1, ): # Parse args self.model_config = model_config @@ -85,6 +87,8 @@ def __init__( self.gpu_id = gpu_id 
self.tp_rank = tp_rank self.tp_size = tp_size + self.sp_rank = sp_rank + self.sp_size = sp_size self.nccl_port = nccl_port self.server_args = server_args self.is_multimodal_model = is_multimodal_model( @@ -137,7 +141,11 @@ def init_torch_distributed(self): local_rank=self.gpu_id, distributed_init_method=nccl_init_method, ) - initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + initialize_model_parallel( + tensor_model_parallel_size=self.tp_size, + sequence_parallel_size=self.sp_size, + ) + self.tp_group = get_tp_group() min_per_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 ) @@ -321,14 +329,18 @@ def profile_max_num_token(self, total_gpu_memory: int): self.model_config.attention_arch == AttentionArch.MLA and self.server_args.enable_mla ): + # FIXME: temporarily disable SP with MLA + assert self.sp_size == 1, "sequence parallel with MLA not supported" cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * self.model_config.num_hidden_layers * torch._utils._element_size(self.kv_cache_dtype) ) else: + kv_tp_size = self.tp_size // self.sp_size + head_num = self.model_config.get_num_kv_heads(kv_tp_size) cell_size = ( - self.model_config.get_num_kv_heads(self.tp_size) + head_num * self.model_config.head_dim * self.model_config.num_hidden_layers * 2 @@ -346,6 +358,11 @@ def init_memory_pool( max_num_reqs: int = None, max_total_tokens: int = None, ): + if self.tp_size % self.sp_size != 0: + raise ValueError( + f"Invalid sequence parallel configuration. tp_size={self.tp_size} " + f"must be divisible by sp_size={self.sp_size}" + ) if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype elif self.server_args.kv_cache_dtype == "fp8_e5m2": @@ -389,6 +406,8 @@ def init_memory_pool( self.model_config.attention_arch == AttentionArch.MLA and self.server_args.enable_mla ): + # FIXME: temporarily disable SP with MLA + assert self.sp_size == 1, "sequence parallel with MLA not supported" self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, @@ -400,10 +419,11 @@ def init_memory_pool( # FIXME: temporarily only Triton MLA is supported self.server_args.disable_flashinfer = True else: + kv_tp_size = self.tp_size // self.sp_size self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, - head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_num=self.model_config.get_num_kv_heads(kv_tp_size), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, ) @@ -430,6 +450,9 @@ def init_flashinfer(self): self.flashinfer_prefill_wrapper_ragged = None self.flashinfer_prefill_wrapper_paged = None self.flashinfer_decode_wrapper = None + # NOTE: for sequence parallel, we need to use a dedicated kernel for cross-shard attn. + self.flashinfer_prefill_wrapper_sp_full = None + self.flashinfer_prefill_wrapper_sp_causal = None return if not _grouped_size_compiled_for_decode_kernels( @@ -440,7 +463,10 @@ def init_flashinfer(self): else: use_tensor_cores = False + self.flashinfer_prefill_wrapper_sp_full = None + self.flashinfer_prefill_wrapper_sp_causal = None if self.sliding_window_size is None: + # FIXME: missing SP info here. self.flashinfer_workspace_buffer = torch.empty( global_config.flashinfer_workspace_size, dtype=torch.uint8, @@ -459,6 +485,17 @@ def init_flashinfer(self): "NHD", use_tensor_cores=use_tensor_cores, ) + if self.sp_size > 1: # Sequence parallel enabled. 
+ self.flashinfer_prefill_wrapper_sp_full = ( + BatchPrefillWithRaggedKVCacheWrapper( + self.flashinfer_workspace_buffer, "NHD" + ) + ) + self.flashinfer_prefill_wrapper_sp_causal = ( + BatchPrefillWithRaggedKVCacheWrapper( + self.flashinfer_workspace_buffer, "NHD" + ) + ) else: self.flashinfer_workspace_buffer = torch.empty( global_config.flashinfer_workspace_size, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 926d87db8b7..d11b91cd12d 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, - QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -40,8 +39,10 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.parallel_utils import get_kv_tensor_model_parallel_world_size from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.sampler import Sampler +from sglang.srt.layers.sp_linear import QKVParallelLinear, RowSeqParallelLinear from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -100,20 +101,28 @@ def __init__( ) -> None: super().__init__() self.hidden_size = hidden_size + # This is KV_TP_SIZE * SP_SIZE tp_size = get_tensor_model_parallel_world_size() + # This is the KV-TP size + kv_tp_size = get_kv_tensor_model_parallel_world_size() + # Sequence parallel size self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 + # num_heads is partitioned by both TP and SP so here use tp_size which + # represents the total TP x SP parallelism. self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition + # num_kv_heads is partitioned only by TP so here use kv_tp_size which + # represents the KV-TP parallelism. + if self.total_num_kv_heads >= kv_tp_size: + # Number of KV heads is greater than KV-TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_kv_heads % kv_tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert kv_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // kv_tp_size) # MistralConfig has an optional head_dim introduced by Mistral-Nemo self.head_dim = getattr( config, "head_dim", self.hidden_size // self.total_num_heads @@ -133,9 +142,12 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) - self.o_proj = RowParallelLinear( + self.o_proj = RowSeqParallelLinear( self.total_num_heads * self.head_dim, hidden_size, + self.total_num_heads, + self.num_kv_heads, + self.head_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", @@ -163,8 +175,7 @@ def forward( hidden_states: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k, v = self.qkv_proj(hidden_states) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, input_metadata) output, _ = self.o_proj(attn_output) @@ -315,6 +326,13 @@ def forward( input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) + if input_metadata.sp_size > 1: + # TODO: instead of a GPU indexing, sample under SP layout and parse + # sampling result back to normal layout + hidden_states = hidden_states[ + input_metadata.sp_to_normal_indices + ].contiguous() + input_ids = input_ids[input_metadata.sp_to_normal_indices].contiguous() logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) @@ -324,11 +342,14 @@ def forward( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), + ("qkv_proj.kv_proj", "k_proj", "k"), + ("qkv_proj.kv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + renamed_params_mapping = [ + # (param_name, weight_name) + ("qkv_proj.q_proj", "q_proj"), ] params_dict = self.param_dict @@ -357,6 +378,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + for param_name, weight_name in renamed_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + break param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8a56c02e162..90093c5b2f5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -58,6 +58,7 @@ class ServerArgs: # Other runtime options tp_size: int = 1 + sp_size: int = 1 stream_interval: int = 1 random_seed: Optional[int] = None @@ -304,6 +305,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tp_size, help="The tensor parallelism size.", ) + parser.add_argument( + "--sp-size", + type=int, + default=ServerArgs.sp_size, + help="The sequence parallelism size.", + ) parser.add_argument( "--stream-interval", type=int, diff --git a/test/srt/test_seq_parallel_attn_kernel.py b/test/srt/test_seq_parallel_attn_kernel.py new file mode 100644 index 00000000000..a2895422afe --- /dev/null +++ b/test/srt/test_seq_parallel_attn_kernel.py @@ -0,0 +1,233 @@ +import multiprocessing +import random + +import torch +from vllm.distributed import init_distributed_environment + +from sglang.srt.layers.parallel_utils import initialize_model_parallel +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.model_executor.model_runner import InputMetadata + +NUM_HEADS = 32 +HEAD_DIM = 128 +SCALING = 1 +NUM_KV_HEADS = 8 +LAYER_ID = 0 +LOGIT_CAP = -1 + + +BATCH_SIZE = 1 +QO_LEN = 128 +KV_LEN = 128 + + +def gen_qkv(rank: int = 0, sp_size: int = 1): + torch.manual_seed(42) + random.seed(42) + q = torch.randn(BATCH_SIZE * QO_LEN, NUM_HEADS, HEAD_DIM).cuda().half() + k = torch.randn(BATCH_SIZE * KV_LEN, NUM_KV_HEADS, HEAD_DIM).cuda().half() + v = torch.randn(BATCH_SIZE * KV_LEN, NUM_KV_HEADS, HEAD_DIM).cuda().half() + + # num_heads_per_partition = NUM_HEADS // sp_size + # q = q[ + # :, :, num_heads_per_partition * rank : num_heads_per_partition * (rank + 1) + # ].contiguous() + # kv_len_per_partition = KV_LEN // sp_size + # k = k[ + # :, kv_len_per_partition * rank : kv_len_per_partition * (rank + 1) + # ].contiguous() + # v = v[ + # :, kv_len_per_partition * rank : kv_len_per_partition * (rank + 1) + # ].contiguous() + + return q, k, v + + +def get_input_metadata(sp_size: int = 1, tp_size: int = 1): + from flashinfer import ( + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + ) + + input_metadata = InputMetadata( + forward_mode=None, + batch_size=BATCH_SIZE, + total_num_tokens=BATCH_SIZE * QO_LEN, + req_pool_indices=None, + seq_lens=None, + positions=None, + req_to_token_pool=None, + token_to_kv_pool=None, + out_cache_loc=None, + extend_seq_lens=None, + extend_start_loc=None, + extend_no_prefix=True, + return_logprob=None, + top_logprobs_nums=None, + flashinfer_prefill_wrapper_ragged=None, + flashinfer_prefill_wrapper_paged=None, + flashinfer_decode_wrapper=None, + sp_size=sp_size, + ) + + workspace_buffer = torch.empty( + 2, 128 * 1024 * 1024, dtype=torch.int8, device="cuda" + ) + + input_metadata.flashinfer_prefill_wrapper_ragged = ( + BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer[0], "NHD") + ) + input_metadata.flashinfer_prefill_wrapper_paged = ( + BatchPrefillWithPagedKVCacheWrapper(workspace_buffer[1], "NHD") + ) + + num_qo_heads = NUM_HEADS // sp_size + num_kv_heads 
= NUM_KV_HEADS + qo_len_per_iter = QO_LEN // sp_size + kv_len_per_partition = KV_LEN // sp_size + + qo_indptr = torch.arange(0, BATCH_SIZE + 1).cuda().int() * qo_len_per_iter + kv_indptr = torch.arange(0, BATCH_SIZE + 1).cuda().int() * kv_len_per_partition + input_metadata.flashinfer_prefill_wrapper_ragged.end_forward() + input_metadata.flashinfer_prefill_wrapper_ragged.begin_forward( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + HEAD_DIM, + ) + + # cached part + kv_indices = torch.arange(0, BATCH_SIZE * kv_len_per_partition).cuda().int() + kv_last_page_len = torch.full((BATCH_SIZE,), 1, dtype=torch.int32).cuda() + input_metadata.flashinfer_prefill_wrapper_paged.end_forward() + input_metadata.flashinfer_prefill_wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + HEAD_DIM, + 1, + ) + + return input_metadata + + +def sp_worker(rank: int = 0, sp_size: int = 1, tp_size: int = 1): + torch.manual_seed(42) + random.seed(42) + + def init_comm(): + nccl_init_method = f"tcp://127.0.0.1:28888" + init_distributed_environment( + backend="nccl", + world_size=tp_size, + rank=rank, + local_rank=rank, + distributed_init_method=nccl_init_method, + ) + initialize_model_parallel( + tensor_model_parallel_size=tp_size, sequence_parallel_size=sp_size + ) + torch.cuda.set_device(rank) + + init_comm() + + def init_attention(): + attention = RadixAttention( + num_heads=NUM_HEADS // sp_size, + head_dim=HEAD_DIM, + scaling=SCALING, + num_kv_heads=NUM_KV_HEADS, + layer_id=LAYER_ID, + logit_cap=LOGIT_CAP, + ) + return attention + + attn = init_attention() + print("SP worker", rank, "initialized on", torch.cuda.current_device()) + + # Computation + input_metadata = get_input_metadata(sp_size=sp_size, tp_size=tp_size) + q, k, v = gen_qkv(rank, sp_size) + qs, ks, vs = [], [], [] + q_head_idxes = _get_sequence_parallel_head_idxes( + NUM_HEADS, NUM_KV_HEADS, rank, sp_size + ) + print(rank, q_head_idxes) + for i in range(sp_size): + qs.append( + q[(QO_LEN // sp_size) * i : (QO_LEN // sp_size) * (i + 1), q_head_idxes] + ) + ks.append(k[(KV_LEN // sp_size) * i : (KV_LEN // sp_size) * (i + 1)]) + vs.append(v[(KV_LEN // sp_size) * i : (KV_LEN // sp_size) * (i + 1)]) + + output = attn.seq_parallel_extend_forward_flashinfer(qs, ks, vs, input_metadata) + + o_truth = reference_attn() + o_truth = ( + o_truth.contiguous() + .view(-1, NUM_HEADS, HEAD_DIM)[:, q_head_idxes] + .view(-1, NUM_HEADS // sp_size * HEAD_DIM) + ) + + print("SP worker", rank, "results:") + print("Mean: ", torch.mean(torch.abs(output - o_truth))) + print("Max: ", torch.max(torch.abs(output - o_truth))) + assert torch.allclose(output, o_truth, rtol=1e-2, atol=1e-3) + + +def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size): + group_num = num_kv_heads + group_size = total_num_heads // num_kv_heads + shard_num_heads = group_size // sp_size + idxes = [ + group_size * i + sp_rank * shard_num_heads + j + for i in range(group_num) + for j in range(0, shard_num_heads) + ] + return idxes + + +def reference_attn(): + torch.manual_seed(42) + random.seed(42) + + attn = RadixAttention( + num_heads=NUM_HEADS, + head_dim=HEAD_DIM, + scaling=SCALING, + num_kv_heads=NUM_KV_HEADS, + layer_id=LAYER_ID, + logit_cap=LOGIT_CAP, + ) + + input_metadata = get_input_metadata() + q, k, v = gen_qkv() + + return attn.extend_forward_flashinfer(q, k, v, input_metadata) + + +def main(): + sp_size = 2 + tp_size = 2 + + multiprocessing.set_start_method("spawn", force=True) + 
sp_procs = [] + for rank in range(1, sp_size): + sp_proc = multiprocessing.Process( + target=sp_worker, args=(rank, sp_size, tp_size) + ) + sp_proc.start() + sp_procs.append(sp_proc) + + output = sp_worker(0, sp_size, tp_size) + + for sp_proc in sp_procs: + sp_proc.join() + + +if __name__ == "__main__": + main() diff --git a/test/srt/test_seq_parallel_attn_kernel_simple.py b/test/srt/test_seq_parallel_attn_kernel_simple.py new file mode 100644 index 00000000000..1ff578b8f07 --- /dev/null +++ b/test/srt/test_seq_parallel_attn_kernel_simple.py @@ -0,0 +1,279 @@ +import pytest +import torch +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, +) +from flashinfer.cascade import merge_state +from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + +from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention +from sglang.srt.layers.token_attention import token_attention_fwd + +flashinfer_prefill_wrapper_ragged = None +flashinfer_prefill_wrapper_paged = None +flashinfer_decode_wrapper = None + + +def get_next_partition_id(curr_partition_id, num_partitions): + assert curr_partition_id < num_partitions + return (curr_partition_id - 1) % num_partitions + + +def get_sp_prev_local_rank(rank, num_partitions): + return (rank - 1) % num_partitions + + +def get_sp_next_local_rank(rank, num_partitions): + return (rank + 1) % num_partitions + + +def append_merge_partition(partition_list, o, s): + if len(partition_list) == 0: + partition_list.append((o, s)) + else: + o_prev, s_prev = partition_list[-1] + o, s = merge_state(o_prev, s_prev, o, s) + partition_list[-1] = (o, s) + + +def seq_parallel_attn( + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + q, + k, + v, + rank: int, + sp_size: int, +): + """Simulate a sequence parallel attention kernel. It takes full Q, K, and V + with simulated communication. TODO: replace with actual communication. + """ + num_partitions = sp_size + num_iters = sp_size + # NOTE: we assume sequence length is divisible by num_partitions + qo_len_per_iter = qo_len // num_iters + kv_len_per_partition = kv_len // num_partitions + + qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len_per_iter + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len_per_partition + flashinfer_prefill_wrapper_ragged.end_forward() + flashinfer_prefill_wrapper_ragged.begin_forward( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) + + kv_indices = torch.arange(0, batch_size * kv_len_per_partition).to(0).int() + kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0) + flashinfer_prefill_wrapper_paged.end_forward() + flashinfer_prefill_wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) + + local_k, local_v = ( + k[:, rank * kv_len_per_partition : (rank + 1) * kv_len_per_partition] + .contiguous() + .view(-1, num_kv_heads, head_dim), + v[:, rank * kv_len_per_partition : (rank + 1) * kv_len_per_partition] + .contiguous() + .view(-1, num_kv_heads, head_dim), + ) + k_partition, v_partition = local_k, local_v + + owned_pids = [rank] + owned_partitions = [None for _ in range(num_partitions)] + owned_partitions[rank] = (local_k, local_v) + o_partitions = [[] for _ in range(num_partitions)] + + to_rank = rank # which SP worker to send my sequence KV partition to. 
+ from_rank = rank # which SP worker to receive the sequence KV partition from. + + pid = rank # start from the worker's own partition + for _ in range(num_iters): + # TODO: send-recv communication here + to_rank = get_sp_next_local_rank(to_rank, num_partitions) + # send_to(to_rank, k, v) + q_partition = q[:, pid * qo_len_per_iter : (pid + 1) * qo_len_per_iter] + k_partition, v_partition = owned_partitions[pid] + # Ragged attention computation for self attention within the partition + o, s = flashinfer_prefill_wrapper_ragged.forward_return_lse( + q_partition.contiguous().view(-1, num_qo_heads, head_dim), + k_partition.contiguous().view(-1, num_kv_heads, head_dim), + v_partition.contiguous().view(-1, num_kv_heads, head_dim), + ) + append_merge_partition(o_partitions[pid], o, s) + # Paged attention computation for cross partition attention + # NOTE: below schedule is for load balancing + for existing_pid in owned_pids: + if existing_pid == pid: + continue + i, j = (existing_pid, pid) if existing_pid > pid else (pid, existing_pid) + q_data = q[:, i * qo_len_per_iter : (i + 1) * qo_len_per_iter] + kv_data = torch.stack(owned_partitions[j], dim=1) + o, s = flashinfer_prefill_wrapper_paged.forward_return_lse( + q_data.contiguous().view(-1, num_qo_heads, head_dim), + kv_data, + causal=False, + ) + append_merge_partition(o_partitions[i], o, s) + + # TODO: send-recv communication here + from_rank = get_sp_prev_local_rank(from_rank, num_partitions) + # recv_from(from_rank, k, v) + pid = from_rank + kv_recved = ( + k[:, pid * kv_len_per_partition : (pid + 1) * kv_len_per_partition] + .contiguous() + .view(-1, num_kv_heads, head_dim), + v[:, pid * kv_len_per_partition : (pid + 1) * kv_len_per_partition] + .contiguous() + .view(-1, num_kv_heads, head_dim), + ) + owned_pids.append(pid) + owned_partitions[pid] = kv_recved + + # Reshape all o tensors so that we can concatenate along the sequence dimension + # we must have len(partition_list) == 1 here + os = [ + o.view(batch_size, qo_len_per_iter, num_qo_heads, head_dim) + for partition_list in o_partitions + for o, _ in partition_list + ] + o = torch.cat(os, dim=1).view( + -1, num_qo_heads, head_dim + ) # restore the original shape + return o + + +@pytest.mark.parametrize("batch_size", [12, 37, 67]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("qo_len", [37, 17]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [32, 4]) +@pytest.mark.parametrize("head_dim", [128]) +def test_seq_parallel_prefill( + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + rank: int = 0, + sp_size: int = 2, +): + init_flashinfer(num_qo_heads, num_kv_heads) + + q = torch.randn(batch_size, qo_len, num_qo_heads, head_dim).to(0).half() + k = torch.randn(batch_size, kv_len, num_kv_heads, head_dim).to(0).half() + v = torch.randn(batch_size, kv_len, num_kv_heads, head_dim).to(0).half() + + def reference_impl_ragged(): + qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + + flashinfer_prefill_wrapper_ragged.end_forward() + flashinfer_prefill_wrapper_ragged.begin_forward( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) + o = flashinfer_prefill_wrapper_ragged.forward( + q.contiguous().view(-1, num_qo_heads, head_dim), + k.contiguous().view(-1, num_kv_heads, head_dim), + v.contiguous().view(-1, num_kv_heads, head_dim), + ) + flashinfer_prefill_wrapper_ragged.end_forward() + return o + + def 
reference_impl_paged(): + qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + total_tokens = kv_len * batch_size + + kv_data = torch.zeros(total_tokens, 2, num_kv_heads, head_dim).to(0).half() + kv_data[:, 0] = k.contiguous().view(-1, num_kv_heads, head_dim) + kv_data[:, 1] = v.contiguous().view(-1, num_kv_heads, head_dim) + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + kv_indices = torch.arange(0, total_tokens).to(0).int() + kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0) + + flashinfer_prefill_wrapper_paged.end_forward() + flashinfer_prefill_wrapper_paged.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + 1, + ) + o = flashinfer_prefill_wrapper_paged.forward( + q.contiguous().view(-1, num_qo_heads, head_dim), kv_data + ) + flashinfer_prefill_wrapper_paged.end_forward() + return o + + o_sp = seq_parallel_attn( + batch_size, + kv_len, + qo_len, + num_kv_heads, + num_qo_heads, + head_dim, + q, + k, + v, + rank=1, + sp_size=4, + ) + o_truth = reference_impl_paged() + + print("Mean: ", torch.mean(torch.abs(o_sp - o_truth))) + print("Max: ", torch.max(torch.abs(o_sp - o_truth))) + assert torch.allclose(o_sp, o_truth, rtol=1e-2, atol=1e-3) + + +def init_flashinfer(num_attention_heads, num_kv_heads): + if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads): + use_tensor_cores = True + else: + use_tensor_cores = False + + workspace_buffer = torch.empty( + 3, 128 * 1024 * 1024, dtype=torch.int8, device="cuda" + ) + + global flashinfer_prefill_wrapper_ragged, flashinfer_prefill_wrapper_paged, flashinfer_decode_wrapper + + flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer[0], "NHD" + ) + flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer[1], "NHD" + ) + flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer[2], "NHD", use_tensor_cores=use_tensor_cores + ) + + +if __name__ == "__main__": + test_seq_parallel_prefill(12, 128, 128, 8, 8, 128, rank=3, sp_size=4) + test_seq_parallel_prefill(12, 4096, 4096, 8, 8, 128, rank=4, sp_size=8) + test_seq_parallel_prefill(12, 1024, 1024, 32, 32, 128, rank=1, sp_size=2) diff --git a/test/srt/test_sp_comm_group.py b/test/srt/test_sp_comm_group.py new file mode 100644 index 00000000000..17d0a45fd9a --- /dev/null +++ b/test/srt/test_sp_comm_group.py @@ -0,0 +1,70 @@ +import multiprocessing +import random + +import torch +from vllm.distributed import init_distributed_environment + +from sglang.srt.layers.parallel_utils import get_sp_group, initialize_model_parallel + +NUM_TOKENS = 3 +NUM_KV_HEADS = 2 +HEAD_DIM = 4 + + +def gen_kv(rank: int = 0, sp_size: int = 1): + torch.manual_seed(42) + random.seed(42) + k = torch.randn(NUM_TOKENS, NUM_KV_HEADS, HEAD_DIM).cuda().half() + v = torch.randn(NUM_TOKENS, NUM_KV_HEADS, HEAD_DIM).cuda().half() + + return k, v + + +def sp_worker(rank: int = 0, sp_size: int = 1, tp_size: int = 1): + torch.manual_seed(42) + random.seed(42) + + nccl_init_method = f"tcp://127.0.0.1:28888" + init_distributed_environment( + backend="nccl", + world_size=tp_size, + rank=rank, + local_rank=rank, + distributed_init_method=nccl_init_method, + ) + initialize_model_parallel( + tensor_model_parallel_size=tp_size, sequence_parallel_size=sp_size + ) + torch.cuda.set_device(rank) + print("SP worker", rank, "initialized on", torch.cuda.current_device()) + + k, v = gen_kv(rank, sp_size) + + ks = 
get_sp_group().all_gather(k.view(1, *k.shape), dim=0) + vs = get_sp_group().all_gather(v.view(1, *v.shape), dim=0) + + print("SP worker", rank, "all-gathered ks", ks) + print("SP worker", rank, "all-gathered vs", vs) + + +def main(): + sp_size = 2 + tp_size = 2 + + multiprocessing.set_start_method("spawn", force=True) + sp_procs = [] + for rank in range(1, sp_size): + sp_proc = multiprocessing.Process( + target=sp_worker, args=(rank, sp_size, tp_size) + ) + sp_proc.start() + sp_procs.append(sp_proc) + + sp_worker(0, sp_size, tp_size) + + for sp_proc in sp_procs: + sp_proc.join() + + +if __name__ == "__main__": + main() diff --git a/test/srt/test_sp_decode_attn.py b/test/srt/test_sp_decode_attn.py new file mode 100644 index 00000000000..084c95985d9 --- /dev/null +++ b/test/srt/test_sp_decode_attn.py @@ -0,0 +1,191 @@ +import multiprocessing +import random + +import torch +from flashinfer import BatchDecodeWithPagedKVCacheWrapper, merge_state +from vllm.distributed import init_distributed_environment + +from sglang.srt.layers.parallel_utils import get_sp_group, initialize_model_parallel + +NUM_HEADS = 32 +HEAD_DIM = 128 +SCALING = 1 +NUM_KV_HEADS = 8 +LAYER_ID = 0 +LOGIT_CAP = -1 + + +BATCH_SIZE = 3 +SEQ_LENS = [16, 64, 128] + + +def gen_qkv(sp_rank: int = 0, sp_size: int = 1): + torch.manual_seed(42) + random.seed(42) + + q = torch.randn(BATCH_SIZE, NUM_HEADS, HEAD_DIM).cuda().half() + total_num_context_tokens = sum(SEQ_LENS) + kv_cache = ( + torch.randn(total_num_context_tokens, 2, NUM_KV_HEADS, HEAD_DIM).cuda().half() + ) + + if sp_size > 1: + q_head_idxes = _get_sequence_parallel_head_idxes( + NUM_HEADS, NUM_KV_HEADS, sp_rank, sp_size + ) + q = q[:, q_head_idxes].contiguous() + + sp_kv_cache = ( + torch.empty(total_num_context_tokens // sp_size, 2, NUM_KV_HEADS, HEAD_DIM) + .cuda() + .half() + ) + sp_stt, stt = 0, 0 + for i in range(BATCH_SIZE): + seq_len = SEQ_LENS[i] + sp_seq_len = seq_len // sp_size + + sp_end = sp_stt + sp_seq_len + end = stt + seq_len + + sp_kv_cache[sp_stt:sp_end] = kv_cache[ + stt + sp_rank * sp_seq_len : stt + (sp_rank + 1) * sp_seq_len + ] + sp_stt = sp_end + stt = end + kv_cache = sp_kv_cache + + return q, kv_cache + + +def init_flashinfer(sp_size: int = 1, tp_size: int = 1): + + workspace_buffer = torch.empty( + 1, 128 * 1024 * 1024, dtype=torch.int8, device="cuda" + ) + + flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer[0], "NHD" + ) + + num_qo_heads = NUM_HEADS + num_kv_heads = NUM_KV_HEADS + + seq_lens = torch.tensor(SEQ_LENS, dtype=torch.int32, device="cuda") + seq_lens = seq_lens // sp_size + total_num_context_tokens = sum(SEQ_LENS) // sp_size + + kv_indptr = torch.zeros((BATCH_SIZE + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(seq_lens, dim=0) + kv_indices = torch.arange( + total_num_context_tokens, dtype=torch.int32, device="cuda" + ) + kv_last_page_len = torch.ones((BATCH_SIZE,), dtype=torch.int32, device="cuda") + + flashinfer_decode_wrapper.end_forward() + flashinfer_decode_wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + HEAD_DIM, + 1, + ) + + return flashinfer_decode_wrapper + + +def sp_worker(rank: int = 0, sp_size: int = 1, tp_size: int = 1): + torch.manual_seed(42) + random.seed(42) + + def init_comm(): + nccl_init_method = f"tcp://127.0.0.1:28888" + init_distributed_environment( + backend="nccl", + world_size=tp_size, + rank=rank, + local_rank=rank, + distributed_init_method=nccl_init_method, + ) + initialize_model_parallel( + 
tensor_model_parallel_size=tp_size, sequence_parallel_size=sp_size + ) + torch.cuda.set_device(rank) + + init_comm() + + print("SP worker", rank, "initialized on", torch.cuda.current_device()) + + decode_wrapper = init_flashinfer(sp_size=sp_size, tp_size=tp_size) + q, kv_cache = gen_qkv(rank, sp_size) + + gathered_q = get_sp_group().all_gather(q.view(1, *q.shape), dim=0) + q = torch.empty_like(gathered_q).view(-1, NUM_HEADS, HEAD_DIM) + + for i in range(sp_size): + idxes = _get_sequence_parallel_head_idxes(NUM_HEADS, NUM_KV_HEADS, i, sp_size) + q[:, idxes] = gathered_q[i] + + # Computation + o, s = decode_wrapper.forward_return_lse(q, kv_cache) + + os = get_sp_group().all_gather(o.view(1, *o.shape), dim=0) + ss = get_sp_group().all_gather(s.view(1, *s.shape), dim=0) + for i in range(sp_size): + if i != rank: + o, s = merge_state(os[i], ss[i], o, s) + output = o + + o_truth = reference_attn() + + print("SP worker", rank, "results:") + print("Mean: ", torch.mean(torch.abs(output - o_truth))) + print("Max: ", torch.max(torch.abs(output - o_truth))) + assert torch.allclose(output, o_truth, rtol=1e-2, atol=1e-3) + + +def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size): + group_num = num_kv_heads + group_size = total_num_heads // num_kv_heads + shard_num_heads = group_size // sp_size + idxes = [ + group_size * i + sp_rank * shard_num_heads + j + for i in range(group_num) + for j in range(0, shard_num_heads) + ] + return idxes + + +def reference_attn(): + torch.manual_seed(42) + random.seed(42) + + decode_wrapper = init_flashinfer() + q, kv_cache = gen_qkv() + + return decode_wrapper.forward(q, kv_cache) + + +def main(): + sp_size = 2 + tp_size = 2 + + multiprocessing.set_start_method("spawn", force=True) + sp_procs = [] + for rank in range(1, sp_size): + sp_proc = multiprocessing.Process( + target=sp_worker, args=(rank, sp_size, tp_size) + ) + sp_proc.start() + sp_procs.append(sp_proc) + + sp_worker(0, sp_size, tp_size) + + for sp_proc in sp_procs: + sp_proc.join() + + +if __name__ == "__main__": + main()
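For reference, the decode-path test above splits the KV cache across SP ranks, lets each rank attend over its local slice, and then folds the per-rank partial results together with flashinfer's merge_state. The following is a minimal, CPU-only PyTorch sketch of that merge rule (log-sum-exp weighting of partial outputs), assuming a single decode token and one KV head per query head; naive_partial_attention and merge_attention_states are names introduced only for this illustration and are not part of the patch.

import torch


def naive_partial_attention(q, k, v):
    # q: [num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim].
    # Returns the attention output and the per-head log-sum-exp of the scores.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("hd,lhd->hl", q, k) * scale       # [num_heads, kv_len]
    lse = torch.logsumexp(scores, dim=-1)                   # [num_heads]
    out = torch.einsum("hl,lhd->hd", torch.softmax(scores, dim=-1), v)
    return out, lse


def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # Same rule as merge_state: weight each partial output by the softmax mass
    # its KV partition contributed.
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return o_a * w_a + o_b * w_b, lse


if __name__ == "__main__":
    torch.manual_seed(0)
    num_heads, head_dim, kv_len, sp_size = 4, 8, 32, 2
    q = torch.randn(num_heads, head_dim)
    k = torch.randn(kv_len, num_heads, head_dim)
    v = torch.randn(kv_len, num_heads, head_dim)

    # Ground truth: attention over the full KV sequence on a single worker.
    o_full, _ = naive_partial_attention(q, k, v)

    # SP-style: each rank attends over a contiguous KV slice, then the partial
    # (output, lse) states are merged pairwise.
    chunk = kv_len // sp_size
    o, lse = naive_partial_attention(q, k[:chunk], v[:chunk])
    for r in range(1, sp_size):
        ks, vs = k[r * chunk:(r + 1) * chunk], v[r * chunk:(r + 1) * chunk]
        o_r, lse_r = naive_partial_attention(q, ks, vs)
        o, lse = merge_attention_states(o, lse, o_r, lse_r)

    assert torch.allclose(o, o_full, atol=1e-5)

Because the merge is associative and commutative, the partial states can be combined in any order, which is why the decode test can all-gather (o, lse) pairs from every SP rank and fold them in one by one.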
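The query-head layout produced by _get_sequence_parallel_head_idxes in the tests above can also be checked in isolation. A small sketch with assumed sizes (8 query heads grouped over 4 KV heads, sp_size 2): each SP rank takes the same relative slice out of every GQA group, so the query heads are disjoint across ranks while the KV heads stay replicated within the SP group.

def sp_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size):
    group_size = total_num_heads // num_kv_heads   # query heads per KV head (GQA group)
    shard = group_size // sp_size                  # query heads per SP rank within a group
    return [
        g * group_size + sp_rank * shard + j
        for g in range(num_kv_heads)
        for j in range(shard)
    ]


shards = [sp_head_idxes(8, 4, r, 2) for r in range(2)]
print(shards)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
# Every query head appears on exactly one SP rank.
assert sorted(h for s in shards for h in s) == list(range(8))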