Sequence Parallel #1041

Closed
wants to merge 44 commits
Commits
f498ad1
add sp index
ZYHowell Jul 19, 2024
4a807ec
add clone for rope as it's in-place
ZYHowell Jul 19, 2024
285348c
add decode mask for sp
ZYHowell Jul 19, 2024
5b2a048
insert to prepare batch
ZYHowell Jul 19, 2024
f8b8dbc
add sp size and rank args
ZYHowell Jul 20, 2024
de61f42
update sequence parallel layout
ZYHowell Jul 20, 2024
073b9dc
minor bug fix to pass sp=1 test
ZYHowell Jul 20, 2024
9599131
give local indices to help with position ids; prepare for only record…
ZYHowell Jul 21, 2024
152666f
minor fix
ZYHowell Jul 21, 2024
50436b7
add sp layout to normal layout
ZYHowell Jul 21, 2024
8f8db37
move sp layout transform tool to inputmetadata
ZYHowell Jul 21, 2024
fd49bf4
add debug flatten to sp
ZYHowell Jul 21, 2024
1785ebf
update name and doc string
ZYHowell Jul 21, 2024
c8d850b
bug fix for the sp=1 case
ZYHowell Jul 22, 2024
8f46fee
fix prefix lens None
ZYHowell Jul 22, 2024
73af1b6
fix debug mode indices
ZYHowell Jul 22, 2024
1afdae2
runnable but only first two decode tok cor
ZYHowell Jul 22, 2024
2e41f46
fix early exit for decode with SP
ZYHowell Jul 22, 2024
dd2382d
format
ZYHowell Jul 22, 2024
a11bc61
Merge pull request #1 from ivanium/pr-sp-rope
ivanium Jul 22, 2024
4b8203a
Update sp layout (#3)
ZYHowell Jul 28, 2024
98c1154
Sequence parallel prefill attention kernel (#2)
ivanium Jul 31, 2024
1695aed
Sequence Parallel Decode Attn Kernel (#5)
ivanium Aug 8, 2024
639e716
fix [infer_batch]: fix _get_decode_local_lens and use it to initializ…
ivanium Aug 12, 2024
2ab47a2
all changes
ZYHowell Aug 15, 2024
9f0f3eb
merge
ZYHowell Aug 15, 2024
4930fb9
Merge remote-tracking branch 'origin/main' into pr-rebase
ZYHowell Aug 15, 2024
4de47a0
fix lint and test import
ZYHowell Aug 15, 2024
f23b4cf
fix [sp_linear, llama2]: add prefix support for SP linear layers. Fix…
ivanium Aug 24, 2024
beb7494
fix [forward_batch_info]: prefix_lens can be None in decode phase so …
ivanium Aug 24, 2024
6f9bb27
fix [radix_attention]: remove duplicated sp extend/decode functions
ivanium Aug 24, 2024
e5878f5
fix [radix_attention]: kv_data -> get_kv_buffer()
ivanium Aug 24, 2024
14fc6e2
minor bug fix
ZYHowell Aug 25, 2024
69716be
fix out cache loc
ZYHowell Aug 27, 2024
a8a3d55
fix [forward_batch_info]: init flashinfer ragged kernel correctly for SP
ivanium Aug 31, 2024
86fdc23
fix [cuda_graph_runner]: add missing SP parameters to CUDA graph init…
ivanium Sep 2, 2024
f1ad3ee
Merge pull request #9 from ivanium/pr-rebase
ZYHowell Sep 2, 2024
3b92e9d
Merge remote-tracking branch 'upstream/main' into pr-rebase-2561ed0
ZYHowell Sep 2, 2024
ef22804
minor fix
ZYHowell Sep 4, 2024
daaef64
contiguous and to device for triton kernel
ZYHowell Sep 4, 2024
66ecf83
fix paged kernel lens before kv indptr
ZYHowell Sep 4, 2024
8b3662a
Merge pull request #10 from ivanium/pr-rebase-2561ed0
ivanium Sep 6, 2024
0828662
Sequence parallel prefill attention kernel (#2)
ivanium Sep 7, 2024
0175ca2
Merge pull request #11 from ivanium/pr-rebase-ab4a83b
ZYHowell Sep 9, 2024
python/sglang/bench_latency.py (16 changes: 12 additions & 4 deletions)
@@ -115,7 +115,7 @@ def from_cli_args(cls, args: argparse.Namespace):
)


def load_model(server_args, tp_rank):
def load_model(server_args, tp_rank, sp_rank: int = 0):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

@@ -130,6 +130,8 @@ def load_model(server_args, tp_rank):
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
sp_rank=sp_rank,
sp_size=server_args.sp_size,
nccl_port=28888,
server_args=server_args,
)
@@ -206,6 +208,8 @@ def extend(reqs, model_runner):
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
sp_size=model_runner.sp_size,
sp_rank=model_runner.sp_rank,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
@@ -225,11 +229,12 @@ def correctness_test(
server_args,
bench_args,
tp_rank,
sp_rank=0,
):
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)

# Prepare inputs
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -336,11 +341,12 @@ def latency_test(
server_args,
bench_args,
tp_rank,
sp_rank=0,
):
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)

# Prepare inputs for warm up
reqs = prepare_synthetic_inputs_for_latency_test(
@@ -458,16 +464,18 @@ def main(server_args, bench_args):
)

if server_args.tp_size == 1:
work_func(server_args, bench_args, 0)
work_func(server_args, bench_args, 0, 0)
else:
workers = []
for tp_rank in range(server_args.tp_size):
sp_rank = tp_rank % server_args.sp_size
proc = multiprocessing.Process(
target=work_func,
args=(
server_args,
bench_args,
tp_rank,
sp_rank,
),
)
proc.start()
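As context for the change to main() above, here is a minimal, standalone sketch of how each benchmark worker's sp_rank pairs with its tp_rank under the tp_rank % sp_size mapping; the tp_size and sp_size values are illustrative assumptions, not taken from this PR:

# Illustrative sketch only: mirrors the sp_rank assignment in main() above.
# tp_size and sp_size are assumed example values, not taken from this PR.
tp_size, sp_size = 8, 2

for tp_rank in range(tp_size):
    sp_rank = tp_rank % sp_size    # rank within the sequence-parallel group
    sp_group = tp_rank // sp_size  # which SP group the worker falls into
    print(f"tp_rank={tp_rank} -> sp_group={sp_group}, sp_rank={sp_rank}")

# With tp_size=8 and sp_size=2, workers 0-1 form SP group 0, workers 2-3 form
# SP group 1, and so on, each with local sp_rank 0 or 1.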
python/sglang/srt/layers/parallel_utils/__init__.py (1 change: 1 addition & 0 deletions)
@@ -0,0 +1 @@
from .parallel_state import *
python/sglang/srt/layers/parallel_utils/parallel_state.py (96 changes: 96 additions & 0 deletions)
@@ -0,0 +1,96 @@
from typing import List, Optional

import torch
from vllm.distributed import initialize_model_parallel as vllm_initialize_model_parallel
from vllm.distributed.parallel_state import (
GroupCoordinator,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_world_group,
init_model_parallel_group,
)

_SP: Optional[GroupCoordinator] = None


def get_sp_group():
assert _SP is not None, "sequence parallel group is not initialized"
return _SP


def init_sequence_parallel_group(
group_ranks: List[List[int]], local_rank: int, backend: str
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=True,
)


def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
sequence_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups and sequence parallel groups.

For sequence parallelism, we partition each TP group into SP groups and assign
GPUs with adjacent ranks to the same SP group. For example, with TP size 8
and SP size 2, we have 1 TP group and 4 SP groups:
SP groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
Their KV TP rank:
[ 0, 0], [ 1, 1], [ 2, 2], [ 3, 3]
Since we replicate KV heads within an SP group, the KV TP size is 4 (8 // 2), and
all GPUs in the same SP group share one KV TP rank, ranging from 0 to 3 across groups.
"""
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)

num_sequence_parallel_groups: int = world_size // sequence_parallel_size
global _SP
assert _SP is None, "sequence parallel group is already initialized"
group_ranks = []
for i in range(num_sequence_parallel_groups):
ranks = list(
range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
)
group_ranks.append(ranks)
_SP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend)

vllm_initialize_model_parallel(
tensor_model_parallel_size, pipeline_model_parallel_size, backend
)


def sequence_parallel_is_initialized():
return _SP is not None


def get_sequence_parallel_world_size():
return get_sp_group().world_size


def get_sequence_parallel_rank():
return get_sp_group().rank_in_group


def get_sequence_parallel_global_rank():
return get_tensor_model_parallel_rank()


# NOTE: For sequence parallelism, we partition Q tensors along the head dimension.
# K/V tensors, however, are partitioned along the head dimension in TP and along the
# sequence dimension in SP. Therefore, their TP size and rank are adjusted
# accordingly as below.
def get_kv_tensor_model_parallel_world_size():
return get_tensor_model_parallel_world_size() // get_sequence_parallel_world_size()


def get_kv_tensor_model_parallel_rank():
return get_tensor_model_parallel_rank() // get_sequence_parallel_world_size()
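
To make the layout above concrete, the following self-contained sketch reproduces the group_ranks construction from initialize_model_parallel and the KV-TP helper formulas for the TP size 8, SP size 2 example in the docstring. It uses plain Python only; the world size and the single-TP-group assumption are illustrative and do not require a torch.distributed launch:

# Illustrative sketch: re-derives the SP groups and KV-TP ranks described in the
# docstring above, using assumed values (world_size=8, sequence_parallel_size=2).
world_size = 8
sequence_parallel_size = 2

num_sequence_parallel_groups = world_size // sequence_parallel_size
group_ranks = [
    list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
    for i in range(num_sequence_parallel_groups)
]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# KV-TP size and rank as computed by get_kv_tensor_model_parallel_world_size/rank,
# assuming a single TP group that spans all 8 GPUs:
tp_world_size = world_size
kv_tp_size = tp_world_size // sequence_parallel_size  # 4
kv_tp_ranks = [tp_rank // sequence_parallel_size for tp_rank in range(tp_world_size)]
print(kv_tp_size, kv_tp_ranks)  # 4 [0, 0, 1, 1, 2, 2, 3, 3]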