
Commit 79df20a

wooyeonlee0 authored and prashantgupta24 committed
[Speculative Decoding] Support draft model on different tensor-parallel size than target model (vllm-project#5414)
1 parent 65b7543 commit 79df20a

11 files changed: +388 -59 lines changed
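
In short: this commit adds a speculative_draft_tensor_parallel_size option so the draft model used for speculative decoding can run at a smaller tensor-parallel degree than the target model (for now only a draft TP of 1 is accepted; anything else raises an error in vllm/config.py below). A minimal usage sketch, assuming the keyword is forwarded through the LLM entrypoint the same way benchmarks/benchmark_latency.py forwards it below; model names are taken from the tests in this commit, and the snippet is illustrative, not part of the commit:

from vllm import LLM, SamplingParams

# Target model sharded across 4 GPUs; draft model kept at TP=1, the only
# draft TP value this commit accepts other than "inherit from target".
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=4,
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    speculative_draft_tensor_parallel_size=1,
    use_v2_block_manager=True,  # required for spec decode, per the tests below
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))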

.buildkite/test-pipeline.yaml (+2 -1)

@@ -54,7 +54,7 @@ steps:
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
-  - pytest -v -s spec_decode/e2e/test_integration_dist.py
+  - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

@@ -71,6 +71,7 @@ steps:
   # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

 - label: Engine Test
   mirror_hardwares: [amd]

benchmarks/benchmark_latency.py (+6)

@@ -25,6 +25,8 @@ def main(args: argparse.Namespace):
         model=args.model,
         speculative_model=args.speculative_model,
         num_speculative_tokens=args.num_speculative_tokens,
+        speculative_draft_tensor_parallel_size=\
+            args.speculative_draft_tensor_parallel_size,
         tokenizer=args.tokenizer,
         quantization=args.quantization,
         tensor_parallel_size=args.tensor_parallel_size,

@@ -127,6 +129,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--speculative-model', type=str, default=None)
    parser.add_argument('--num-speculative-tokens', type=int, default=None)
+    parser.add_argument('--speculative-draft-tensor-parallel-size',
+                        '-spec-draft-tp',
+                        type=int,
+                        default=None)
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
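
For context, a minimal sketch (not an excerpt from benchmark_latency.py) showing that the new long flag and its short alias both populate args.speculative_draft_tensor_parallel_size, which the benchmark then forwards to the LLM constructor; when the flag is omitted it stays None and the draft model later inherits the target model's tensor-parallel size (see vllm/config.py below):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--speculative-draft-tensor-parallel-size',
                    '-spec-draft-tp',
                    type=int,
                    default=None)

args = parser.parse_args(['-spec-draft-tp', '1'])
assert args.speculative_draft_tensor_parallel_size == 1

args = parser.parse_args([])  # flag omitted -> None -> inherit target TP
assert args.speculative_draft_tensor_parallel_size is None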
tests/spec_decode/e2e/test_integration_dist_tp2.py (+111, new file)

@@ -0,0 +1,111 @@
+"""Tests which cover integration of the speculative decoding framework with
+tensor parallelism.
+"""
+
+import pytest
+import torch
+
+from vllm.utils import is_hip
+
+from .conftest import run_greedy_equality_correctness_test
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "tensor_parallel_size": 2,
+
+        # Use AsyncLLM engine, so that the engine runs in its own process.
+        # Otherwise, since vLLM does not follow true SPMD, the test runner
+        # process will have both the engine and the rank0 worker. NCCL is not
+        # cleaned up properly, and its server host thread leaks, causing the
+        # second run of the test to fail with internal NCCL error.
+        "use_async": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+    },
+    {
+        "speculative_model": "[ngram]",
+        "num_speculative_tokens": 5,
+        "ngram_prompt_lookup_max": 3,
+    },
+])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
+                              batch_size: int, output_len: int):
+    """Verify greedy equality when tensor parallelism is used.
+    """
+    if is_hip():
+        pytest.skip("hip is not well-supported yet")
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        # Note this is repeated in the test body; to initialize a tokenizer.
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+        "tensor_parallel_size": 2,
+
+        # Use AsyncLLM engine, so that the engine runs in its own process.
+        # Otherwise, since vLLM does not follow true SPMD, the test runner
+        # process will have both the engine and the rank0 worker. NCCL is not
+        # cleaned up properly, and its server host thread leaks, causing the
+        # second run of the test to fail with internal NCCL error.
+        "use_async": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+        "speculative_draft_tensor_parallel_size": 1,
+    },
+])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
+                                            baseline_llm_generator,
+                                            batch_size: int):
+    """Verify spec decode works well with smaller tp for draft models.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=32,
+                                         force_output_len=True)

tests/spec_decode/e2e/test_integration_dist.py → tests/spec_decode/e2e/test_integration_dist_tp4.py (+18 -23, renamed)

@@ -5,24 +5,24 @@
 import pytest
 import torch

-from vllm.utils import is_hip
-
 from .conftest import run_greedy_equality_correctness_test


-@pytest.mark.skipif(torch.cuda.device_count() < 2,
-                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
+        # Use a small model for a fast test.
+        # Note this is repeated in the test body; to initialize a tokenizer.
         "model": "JackFram/llama-68m",

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,

         # Required for spec decode.
         "use_v2_block_manager": True,
-        "tensor_parallel_size": 2,
+        "tensor_parallel_size": 4,

         # Use AsyncLLM engine, so that the engine runs in its own process.
         # Otherwise, since vLLM does not follow true SPMD, the test runner

@@ -31,35 +31,30 @@
         # second run of the test to fail with internal NCCL error.
         "use_async": True,
     }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
     {
         "speculative_model": "JackFram/llama-68m",
-        "num_speculative_tokens": 3,
-    },
-    {
-        "speculative_model": "[ngram]",
         "num_speculative_tokens": 5,
-        "ngram_prompt_lookup_max": 3,
     },
 ])
-@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize(
-    "output_len",
+    "test_llm_kwargs",
     [
-        # Use smaller output len for fast test.
-        32,
+        #TODO(wooyeon): add spec_draft_dp=2 case
+        {
+            "speculative_draft_tensor_parallel_size": 1,
+        },
     ])
+@pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize("seed", [1])
-def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
-                              batch_size: int, output_len: int):
-    """Verify greedy equality when tensor parallelism is used.
+def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
+                                            baseline_llm_generator,
+                                            batch_size: int):
+    """Verify spec decode works well with smaller tp for draft models.
     """
-    if is_hip():
-        pytest.skip("hip is not well-supported yet")
     run_greedy_equality_correctness_test(baseline_llm_generator,
                                          test_llm_generator,
                                          batch_size,
-                                         max_output_len=output_len,
+                                         max_output_len=32,
                                          force_output_len=True)

vllm/config.py (+19 -5)

@@ -797,6 +797,7 @@ def maybe_create_spec_config(
         target_parallel_config: ParallelConfig,
         target_dtype: str,
         speculative_model: Optional[str],
+        speculative_draft_tensor_parallel_size: Optional[int],
         num_speculative_tokens: Optional[int],
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,

@@ -819,6 +820,8 @@ def maybe_create_spec_config(
             target_dtype (str): The data type used for the target model.
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
+            speculative_draft_tensor_parallel_size (Optional[int]): The degree
+                of the tensor parallelism for the draft model.
             num_speculative_tokens (Optional[int]): The number of speculative
                 tokens, if provided. Will default to the number in the draft
                 model config if present, otherwise is required.

@@ -939,7 +942,8 @@ def maybe_create_spec_config(

             draft_parallel_config = (
                 SpeculativeConfig.create_draft_parallel_config(
-                    target_parallel_config))
+                    target_parallel_config,
+                    speculative_draft_tensor_parallel_size))

             if num_speculative_tokens is None:
                 raise ValueError(

@@ -993,16 +997,26 @@ def _maybe_override_draft_max_model_len(

     @staticmethod
     def create_draft_parallel_config(
-            target_parallel_config: ParallelConfig) -> ParallelConfig:
+        target_parallel_config: ParallelConfig,
+        speculative_draft_tensor_parallel_size: Optional[int]
+    ) -> ParallelConfig:
         """Create a parallel config for use by the draft worker.

-        This is mostly a copy of the target parallel config. In the future the
-        draft worker can have a different parallel strategy, e.g. TP=1.
+        This is mostly a copy of the target parallel config, except the tp_size.
         """
+        if speculative_draft_tensor_parallel_size is None:
+            speculative_draft_tensor_parallel_size = \
+                target_parallel_config.tensor_parallel_size
+        elif speculative_draft_tensor_parallel_size != 1:
+            # TODO(wooyeon): allow tp values larger than 1
+            raise ValueError(
+                f"{speculative_draft_tensor_parallel_size=} cannot be "
+                f"other value than 1")
+
         draft_parallel_config = ParallelConfig(
             pipeline_parallel_size=target_parallel_config.
             pipeline_parallel_size,
-            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
+            tensor_parallel_size=speculative_draft_tensor_parallel_size,
             distributed_executor_backend=target_parallel_config.
             distributed_executor_backend,
             max_parallel_loading_workers=target_parallel_config.
vllm/distributed/parallel_state.py (+55 -21)

@@ -676,6 +676,28 @@ def get_world_group() -> GroupCoordinator:
     return _WORLD


+def init_world_group(ranks: List[int], local_rank: int,
+                     backend: str) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=[ranks],
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=False,
+        use_custom_allreduce=False,
+    )
+
+
+def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
+                              backend: str) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=group_ranks,
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=True,
+        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+    )
+
+
 _TP: Optional[GroupCoordinator] = None


@@ -764,13 +786,7 @@ def init_distributed_environment(
     global _WORLD
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
-        _WORLD = GroupCoordinator(
-            group_ranks=[ranks],
-            local_rank=local_rank,
-            torch_distributed_backend=backend,
-            use_pynccl=False,
-            use_custom_allreduce=False,
-        )
+        _WORLD = init_world_group(ranks, local_rank, backend)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size")

@@ -827,13 +843,8 @@ def initialize_model_parallel(
             range(i * tensor_model_parallel_size,
                   (i + 1) * tensor_model_parallel_size))
         group_ranks.append(ranks)
-    _TP = GroupCoordinator(
-        group_ranks=group_ranks,
-        local_rank=get_world_group().local_rank,
-        torch_distributed_backend=backend,
-        use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
-    )
+    _TP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank, backend)

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //

@@ -845,13 +856,8 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
-    _PP = GroupCoordinator(
-        group_ranks=group_ranks,
-        local_rank=get_world_group().local_rank,
-        torch_distributed_backend=backend,
-        use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
-    )
+    _PP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank, backend)


 def ensure_model_parallel_initialized(

@@ -887,6 +893,34 @@ def model_parallel_is_initialized():
     return (_TP is not None and _PP is not None)


+_TP_STATE_PATCHED = False
+
+
+@contextmanager
+def patch_tensor_parallel_group(tp_group: GroupCoordinator):
+    """Patch the tp group temporarily until this function ends.
+
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _TP_STATE_PATCHED
+    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
+
+    _TP_STATE_PATCHED = True
+    old_tp_group = get_tp_group()
+    global _TP
+    _TP = tp_group
+    try:
+        yield
+    finally:
+        # restore the original state
+        _TP_STATE_PATCHED = False
+        _TP = old_tp_group
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return get_tp_group().world_size