
Commit 5a3ab2a

comaniac authored and Alvant committed
[Spec Decode] Introduce DraftModelRunner (vllm-project#5799)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 4f99718 commit 5a3ab2a

15 files changed (+258, -37 lines)

tests/spec_decode/test_multi_step_worker.py (+3)

@@ -7,6 +7,7 @@
 
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=TP1DraftModelRunner,
     )
     worker = create_worker(
         Worker,
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=TP1DraftModelRunner,
     )
 
     worker = create_worker(

tests/spec_decode/utils.py (+4 -1)

@@ -14,6 +14,7 @@
                             SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker import Worker
 
 T = TypeVar("T", bound=Worker)
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
                   num_gpu_blocks: int,
                   seed: int,
                   is_driver_worker: bool = True,
-                  enforce_eager: bool = True) -> T:
+                  enforce_eager: bool = True,
+                  model_runner_cls: Optional[ModelRunner] = None) -> T:
     engine_args = EngineArgs(
         model=model_name,
         seed=seed,
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
         rank=0,
         distributed_init_method=distributed_init_method,
         is_driver_worker=is_driver_worker,
+        model_runner_cls=model_runner_cls,
     )
 
     worker.init_device()

vllm/sequence.py (+3)

@@ -880,6 +880,8 @@ class ExecuteModelRequest:
     running_queue_size: int = 0
     # Optional hidden states from prior step.
    previous_hidden_states: Optional[HiddenStates] = None
+    # The number of forward steps to run.
+    num_steps: int = 1
 
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -893,4 +895,5 @@ def clone(
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
+            num_steps=self.num_steps,
         )
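
For orientation, the new field is consumed by the multi-step proposer changed later in this commit: the request is cloned as before and num_steps is set to the number of speculative tokens. A minimal sketch (variable names are illustrative):

    # Hypothetical call site: request k draft tokens in one call instead of
    # issuing k separate single-step requests.
    draft_req = execute_model_req.clone(copied_seq_group_metadata_list)
    draft_req.num_steps = k  # defaults to 1, so existing callers are unaffected
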
vllm/spec_decode/draft_model_runner.py (new file, +170)

@@ -0,0 +1,170 @@
+from typing import List, Optional
+
+import torch
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
+from vllm.logger import init_logger
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
+                                      ModelRunner)
+
+logger = init_logger(__name__)
+
+
+class TP1DraftModelRunner(ModelRunner):
+    """Specialized model runner for speculative decoding draft model.
+    Since the draft model always execute k forward passes consecutively to
+    generate k speculative tokens in a single speculative decoding step,
+    we could get rid of most CPU-GPU synchronization and data transfer
+    overheads by keeping model input and output tensors on GPU all the time.
+
+    This runner is still under development so there's no performance gain
+    at this moment. Currently we adopt a temporary solution that caches the
+    seq_group_metadata_list for multi-step execution, so that we can
+    leverage existing prepare_model_input to be compatible with the current
+    execution flow, but we plan to remove this cache and avoid calling
+    prepare_model_input in execute_model at all.
+
+    The detail development plan includes:
+    1. Use "update_model_input" to update existing model_input without
+       creating a new one.
+    2. Improve the performance of "update_model_input" with a GPU kernel.
+    3. Support TP > 1 (this requires some designs because we do not expect
+       any broadcasting inside execute_model).
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
+        is_driver_worker: bool = False,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
+        return_hidden_states: bool = False,
+    ):
+        if return_hidden_states:
+            raise ValueError(
+                "return_hidden_states is not supported for TP1DraftModelRunner."
+            )
+
+        super().__init__(
+            model_config=model_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            cache_config=cache_config,
+            load_config=load_config,
+            lora_config=lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker,
+            vision_language_config=vision_language_config,
+            return_hidden_states=return_hidden_states,
+        )
+
+        # TODO: Remove this cache when we are able to update model_input
+        # directly in advance_step.
+        self.cached_seq_group_metadata_list: Optional[
+            List[SequenceGroupMetadata]] = None
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """A temporary solution that caches the seq_group_metadata_list
+        for multi-step execution.
+        TODO: In-place update model_input and remove this function.
+        """
+        self.cached_seq_group_metadata_list = seq_group_metadata_list
+        return super().prepare_model_input(seq_group_metadata_list)
+
+    def update_model_input(
+            self, model_input: ModelInputForGPUWithSamplingMetadata,
+            last_output: SamplerOutput
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """Prepare the model inputs for the next step.
+        TODO: In-place update model_input instead of calling
+        prepare_model_input.
+        """
+
+        # Append the output token to the sequence data.
+        assert self.cached_seq_group_metadata_list is not None
+        for seq_group_metadata, sequence_group_outputs in zip(
+                self.cached_seq_group_metadata_list, last_output.outputs):
+            seq_group_metadata.is_prompt = False
+
+            for seq_output in sequence_group_outputs.samples:
+                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                token_id = seq_output.output_token
+                token_logprob = seq_output.logprobs[token_id]
+
+                seq.append_token_id(token_id, token_logprob.logprob)
+                seq.update_num_computed_tokens(1)
+
+        return self.prepare_model_input(self.cached_seq_group_metadata_list)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        # Since we do not broadcast data inside execute_model anymore,
+        # we need to figure out the best way to support TP > 1 in this
+        # case, because we will at least need to broadcast the sampled
+        # tokens to all workers.
+        if not self.is_driver_worker:
+            raise ValueError("TP1DraftModelRunner only supports TP=1.")
+
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        outputs: List[SamplerOutput] = []
+        for step in range(num_steps):
+            # Currently cuda graph is only supported by the decode phase.
+            assert model_input.attn_metadata is not None
+            prefill_meta = model_input.attn_metadata.prefill_metadata
+            decode_meta = model_input.attn_metadata.decode_metadata
+            if prefill_meta is None and decode_meta.use_cuda_graph:
+                assert model_input.input_tokens is not None
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = self.graph_runners[graph_batch_size]
+            else:
+                model_executable = self.model
+
+            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+            hidden_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                **multi_modal_kwargs,
+            )
+
+            # Compute the logits.
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+
+            # Sample the next token.
+            outputs.append(
+                self.model.sample(
+                    logits=logits,
+                    sampling_metadata=model_input.sampling_metadata,
+                ))
+
+            # Prepare the inputs for the next step.
+            if step != num_steps - 1:
+                model_input = self.update_model_input(model_input, outputs[-1])
+
+        return outputs
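
For context, a rough sketch of the call pattern this runner enables, based on the signatures above (the driver code and variable names here are illustrative; the real call site is MultiStepWorker.sampler_output below):

    # Hypothetical driver: prepare inputs once, then run k draft steps in a
    # single execute_model call; update_model_input() feeds each sampled token
    # back into the next step without leaving the worker.
    model_input = draft_runner.prepare_model_input(seq_group_metadata_list)
    sampler_outputs = draft_runner.execute_model(
        model_input=model_input,
        kv_caches=kv_caches,
        num_steps=k,
    )
    assert len(sampler_outputs) == k  # one SamplerOutput per draft step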

vllm/spec_decode/multi_step_worker.py (+16 -13)

@@ -6,6 +6,7 @@
 
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
                            SequenceGroupMetadata)
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
@@ -67,22 +68,24 @@ def sampler_output(
         copied_execute_model_req = execute_model_req.clone(
             copied_seq_group_metadata_list)
 
-        # Assert enough KV space for sample_len tokens per sequence.
-        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
-                                     sample_len)
-
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        for _ in range(sample_len):
-            model_output: List[SamplerOutput] = super().execute_model(
+        if isinstance(self.model_runner, TP1DraftModelRunner):
+            copied_execute_model_req.num_steps = sample_len
+            model_outputs = self.execute_model(
                 execute_model_req=copied_execute_model_req)
-            assert (len(model_output) == 1
-                    ), "composing multistep workers not supported"
-            model_output = model_output[0]
-
-            self._append_new_tokens(model_output,
-                                    copied_seq_group_metadata_list)
-            model_outputs.append(model_output)
+        else:
+            # TODO: Remove this branch once DraftModelRunner supports TP>1.
+            for _ in range(sample_len):
+                model_output: List[SamplerOutput] = super().execute_model(
+                    execute_model_req=copied_execute_model_req)
+                assert (len(model_output) == 1
+                        ), "composing multistep workers not supported"
+                model_output = model_output[0]
+
+                self._append_new_tokens(model_output,
+                                        copied_seq_group_metadata_list)
+                model_outputs.append(model_output)
 
         return model_outputs, True
 
vllm/spec_decode/spec_decode_worker.py (+3)

@@ -11,6 +11,7 @@
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
                            get_all_seq_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
@@ -117,6 +118,8 @@ def create_worker(
         draft_tp = draft_parallel_config.tensor_parallel_size
         target_tp = scorer_worker.parallel_config.tensor_parallel_size
 
+        if draft_tp == 1:
+            draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
         proposer_worker = MultiStepWorker(**draft_worker_kwargs)
         proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
             proposer_worker, draft_tp, target_tp)
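
For reference, this path is exercised whenever speculative decoding is enabled and the draft model runs with tensor parallel size 1. A rough usage sketch (model choices are illustrative and the argument names reflect vLLM at the time of this commit; they may differ in other versions):

    from vllm import LLM

    # Hypothetical configuration: opt-6.7b as the target model, opt-125m as
    # the draft model proposing 5 speculative tokens per step.
    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )
    outputs = llm.generate("The future of AI is")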

vllm/worker/cpu_model_runner.py (+8 -3)

@@ -351,7 +351,12 @@ def execute_model(
         self,
         model_input: CPUModelInput,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "CPU worker does not support multi-step execution.")
+
         model_executable = self.model
         execute_model_kwargs = {
             "input_ids": model_input.input_tokens,
@@ -371,11 +376,11 @@ def execute_model(
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []
 
         # Sample the next token.
         output = self.model.sample(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
-        return output
+        return [output]

vllm/worker/embedding_model_runner.py (+11 -4)

@@ -57,7 +57,12 @@ def execute_model(
         self,
         model_input: ModelInputForGPUWithPoolingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[PoolerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[PoolerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "EmbeddingModelRunner does not support multi-step execution.")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -91,10 +96,12 @@ def execute_model(
 
         # Only perform pooling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []
 
-        return self.model.pooler(hidden_states=hidden_states,
-                                 pooling_metadata=model_input.pooling_metadata)
+        return [
+            self.model.pooler(hidden_states=hidden_states,
+                              pooling_metadata=model_input.pooling_metadata)
+        ]
 
     def make_model_input_from_broadcasted_tensor_dict(
         self,

vllm/worker/model_runner.py (+7 -3)

@@ -959,7 +959,11 @@ def execute_model(
         self,
         model_input: ModelInputForGPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> SamplerOutput:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError("num_steps > 1 is not supported in ModelRunner")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -992,7 +996,7 @@ def execute_model(
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
-            return None
+            return []
 
         # Sample the next token.
         output: SamplerOutput = self.model.sample(
@@ -1011,7 +1015,7 @@ def execute_model(
 
         output.hidden_states = hidden_states
 
-        return output
+        return [output]
 
 
 class CUDAGraphRunner:

vllm/worker/model_runner_base.py (+2 -1)

@@ -150,7 +150,8 @@ def execute_model(
         self,
         model_input: T,
         kv_caches: Optional[List[torch.Tensor]],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
         """
         Execute the model on the given input.
         """
