
Commit 547588e

done

committed Jun 26, 2024
1 parent c54269d, commit 547588e

15 files changed: +221 -58 lines
 

‎tests/spec_decode/test_multi_step_worker.py

+3
@@ -7,6 +7,7 @@
 
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
+from vllm.spec_decode.draft_model_runner import DraftModelRunner
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=DraftModelRunner,
     )
     worker = create_worker(
         Worker,
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
         block_size,
         num_gpu_blocks,
         seed,
+        model_runner_cls=DraftModelRunner,
     )
 
     worker = create_worker(

‎tests/spec_decode/utils.py

+4 -1
@@ -14,6 +14,7 @@
                            SequenceOutput)
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker import Worker
 
 T = TypeVar("T", bound=Worker)
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
                   num_gpu_blocks: int,
                   seed: int,
                   is_driver_worker: bool = True,
-                  enforce_eager: bool = True) -> T:
+                  enforce_eager: bool = True,
+                  model_runner_cls: Optional[ModelRunner] = None) -> T:
     engine_args = EngineArgs(
         model=model_name,
         seed=seed,
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
         rank=0,
         distributed_init_method=distributed_init_method,
         is_driver_worker=is_driver_worker,
+        model_runner_cls=model_runner_cls,
     )
 
     worker.init_device()

‎vllm/sequence.py

+3
@@ -903,6 +903,8 @@ class ExecuteModelRequest:
     running_queue_size: int = 0
     # Optional hidden states from prior step.
     previous_hidden_states: Optional[HiddenStates] = None
+    # The number of forward steps to run.
+    num_steps: int = 1
 
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -916,4 +918,5 @@ def clone(
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
+            num_steps=self.num_steps,
         )
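The new field rides along on the request object, so existing call sites keep working. A minimal sketch, assuming a seq_group_metadata_list prepared elsewhere by the scheduler; only ExecuteModelRequest, num_steps, and clone() come from this diff:

from vllm.sequence import ExecuteModelRequest

# seq_group_metadata_list is assumed to exist; its construction is not part of this commit.
request = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    num_steps=4,  # ask the runner to take four forward steps in one call
)
copied = request.clone(seq_group_metadata_list)
assert copied.num_steps == 4  # clone() now carries the field over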
‎vllm/spec_decode/draft_model_runner.py (new file)

+149
@@ -0,0 +1,149 @@
+from typing import List, Optional
+
+import torch
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         VisionLanguageConfig)
+from vllm.logger import init_logger
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
+                                      ModelRunner)
+
+logger = init_logger(__name__)
+
+
+class DraftModelRunner(ModelRunner):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
+        is_driver_worker: bool = False,
+        vision_language_config: Optional[VisionLanguageConfig] = None,
+        return_hidden_states: bool = False,
+    ):
+        if return_hidden_states:
+            raise ValueError(
+                "return_hidden_states is not supported for DraftModelRunner.")
+
+        super().__init__(
+            model_config=model_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            cache_config=cache_config,
+            load_config=load_config,
+            lora_config=lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker,
+            vision_language_config=vision_language_config,
+            return_hidden_states=return_hidden_states,
+        )
+
+        # TODO: Remove this cache when we are able to update model_input
+        # directly in advance_step.
+        self.cached_seq_group_metadata_list: Optional[
+            List[SequenceGroupMetadata]] = None
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """A temporary solution that caches the seq_group_metadata_list
+        for multi-step execution.
+        TODO: In-place update model_input and remove this function.
+        """
+        self.cached_seq_group_metadata_list = seq_group_metadata_list
+        return super().prepare_model_input(seq_group_metadata_list)
+
+    def advance_step(
+            self, model_input: ModelInputForGPUWithSamplingMetadata,
+            last_output: SamplerOutput
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """Prepare the model inputs for the next step.
+        TODO: In-place update model_input instead of calling
+        prepare_model_input.
+        """
+
+        # Append the output token to the sequence data.
+        assert self.cached_seq_group_metadata_list is not None
+        for seq_group_metadata, sequence_group_outputs in zip(
+                self.cached_seq_group_metadata_list, last_output.outputs):
+            seq_group_metadata.is_prompt = False
+
+            for seq_output in sequence_group_outputs.samples:
+                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                token_id = seq_output.output_token
+                token_logprob = seq_output.logprobs[token_id]
+
+                seq.append_token_id(token_id, token_logprob.logprob)
+                seq.update_num_computed_tokens(1)
+
+        return self.prepare_model_input(self.cached_seq_group_metadata_list)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        # Since we do not broadcast data inside execute_model anymore,
+        # we need to figure out the best way to support TP > 1 in this
+        # case, because we will at least need to broadcast the sampled
+        # tokens to all workers.
+        if not self.is_driver_worker:
+            raise ValueError("DraftModelRunner only supports TP=1 for now.")
+
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        outputs: List[SamplerOutput] = []
+        for step in range(num_steps):
+            # Currently cuda graph is only supported by the decode phase.
+            assert model_input.attn_metadata is not None
+            prefill_meta = model_input.attn_metadata.prefill_metadata
+            decode_meta = model_input.attn_metadata.decode_metadata
+            if prefill_meta is None and decode_meta.use_cuda_graph:
+                assert model_input.input_tokens is not None
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = self.graph_runners[graph_batch_size]
+            else:
+                model_executable = self.model
+
+            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+            hidden_states = model_executable(
+                input_ids=model_input.input_tokens,
+                positions=model_input.input_positions,
+                kv_caches=kv_caches,
+                attn_metadata=model_input.attn_metadata,
+                **multi_modal_kwargs,
+            )
+
+            # Compute the logits.
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+
+            # Sample the next token.
+            outputs.append(
+                self.model.sample(
+                    logits=logits,
+                    sampling_metadata=model_input.sampling_metadata,
+                ))
+
+            # Prepare the inputs for the next step.
+            if step != num_steps - 1:
+                model_input = self.advance_step(model_input, outputs[-1])
+
+        return outputs
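Taken together, the runner caches the seq_group_metadata_list once, appends each sampled token in advance_step, and rebuilds the model input between steps. A driver-side sketch, assuming a constructed DraftModelRunner instance `runner`, a `seq_group_metadata_list`, and allocated `kv_caches` (none of which are set up in this diff):

# Assumed to exist: runner (DraftModelRunner), seq_group_metadata_list, kv_caches.
model_input = runner.prepare_model_input(seq_group_metadata_list)
outputs = runner.execute_model(
    model_input=model_input,
    kv_caches=kv_caches,
    num_steps=4,
)
# One SamplerOutput per step; between steps the runner calls advance_step()
# internally to append the new token and re-prepare the inputs.
assert len(outputs) == 4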

‎vllm/spec_decode/multi_step_worker.py

+3 -38
@@ -67,22 +67,10 @@ def sampler_output(
         copied_execute_model_req = execute_model_req.clone(
             copied_seq_group_metadata_list)
 
-        # Assert enough KV space for sample_len tokens per sequence.
-        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
-                                     sample_len)
-
         # Run model sample_len times.
-        model_outputs: List[SamplerOutput] = []
-        for _ in range(sample_len):
-            model_output: List[SamplerOutput] = super().execute_model(
-                execute_model_req=copied_execute_model_req)
-            assert (len(model_output) == 1
-                    ), "composing multistep workers not supported"
-            model_output = model_output[0]
-
-            self._append_new_tokens(model_output,
-                                    copied_seq_group_metadata_list)
-            model_outputs.append(model_output)
+        copied_execute_model_req.num_steps = sample_len
+        model_outputs: List[SamplerOutput] = self.execute_model(
+            execute_model_req=copied_execute_model_req)
 
         return model_outputs, True
 
@@ -96,29 +84,6 @@ def get_spec_proposals(
 
         return self._proposer.get_spec_proposals(execute_model_req)
 
-    @staticmethod
-    def _append_new_tokens(
-            model_output: List[SamplerOutput],
-            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
-        """Given model output from a single run, append the tokens to the
-        sequences. This is normally done outside of the worker, but it is
-        required if the worker is to perform multiple forward passes.
-        """
-        for seq_group_metadata, sequence_group_outputs in zip(
-                seq_group_metadata_list, model_output):
-            seq_group_metadata.is_prompt = False
-
-            for seq_output in sequence_group_outputs.samples:
-                # NOTE: Beam search is not supported, so we can assume that
-                # parent_seq_id == seq_id.
-                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
-
-                token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
-
-                seq.append_token_id(token_id, token_logprob.logprob)
-                seq.update_num_computed_tokens(1)
-
     @staticmethod
     def _shallow_copy_inputs(
         seq_group_metadata_list: List[SequenceGroupMetadata]
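Net effect on sampler_output: the per-step Python loop and the worker-side _append_new_tokens helper disappear, and token appending now happens inside DraftModelRunner.advance_step. A condensed sketch of the new path; the enclosing method signature, sample_len, and copied_seq_group_metadata_list are assumed to be as in the existing worker code:

# Inside MultiStepWorker.sampler_output (condensed from the hunks above).
copied_execute_model_req = execute_model_req.clone(
    copied_seq_group_metadata_list)
# A single call replaces the old `for _ in range(sample_len)` loop.
copied_execute_model_req.num_steps = sample_len
model_outputs = self.execute_model(
    execute_model_req=copied_execute_model_req)
assert len(model_outputs) == sample_len  # one SamplerOutput per draft step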

‎vllm/spec_decode/spec_decode_worker.py

+4
@@ -11,6 +11,7 @@
                            HiddenStates, SamplerOutput, SequenceGroupMetadata,
                            get_all_seq_ids)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from vllm.spec_decode.draft_model_runner import DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.metrics import AsyncMetricsCollector
@@ -117,6 +118,9 @@ def create_worker(
             draft_tp = draft_parallel_config.tensor_parallel_size
             target_tp = scorer_worker.parallel_config.tensor_parallel_size
 
+            # DraftModelRunner only supports TP=1 for now.
+            if draft_tp == 1:
+                draft_worker_kwargs["model_runner_cls"] = DraftModelRunner
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
             proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                 proposer_worker, draft_tp, target_tp)

‎vllm/worker/cpu_model_runner.py

+7 -2
@@ -362,7 +362,12 @@ def execute_model(
         self,
         model_input: CPUModelInput,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "CPU worker does not support multi-step execution.")
+
         model_executable = self.model
         execute_model_kwargs = {
             "input_ids": model_input.input_tokens,
@@ -389,4 +394,4 @@ def execute_model(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
-        return output
+        return [output]

‎vllm/worker/embedding_model_runner.py

+10 -3
@@ -57,7 +57,12 @@ def execute_model(
         self,
         model_input: ModelInputForGPUWithPoolingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> Optional[PoolerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[PoolerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "EmbeddingModelRunner does not support multi-step execution.")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -93,8 +98,10 @@ def execute_model(
         if not self.is_driver_worker:
             return None
 
-        return self.model.pooler(hidden_states=hidden_states,
-                                 pooling_metadata=model_input.pooling_metadata)
+        return [
+            self.model.pooler(hidden_states=hidden_states,
+                              pooling_metadata=model_input.pooling_metadata)
+        ]
 
     def make_model_input_from_broadcasted_tensor_dict(
         self,

‎vllm/worker/model_runner.py

+6 -2
@@ -977,7 +977,11 @@ def execute_model(
         self,
         model_input: ModelInputForGPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
-    ) -> SamplerOutput:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError("num_steps > 1 is not supported in ModelRunner")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -1026,7 +1030,7 @@ def execute_model(
                     0, model_input.sampling_metadata.selected_token_indices)
             output.hidden_states = hidden_states
 
-        return output
+        return [output]
 
 
 class CUDAGraphRunner:

‎vllm/worker/model_runner_base.py

+2 -1
@@ -150,7 +150,8 @@ def execute_model(
         self,
         model_input: T,
         kv_caches: Optional[List[torch.Tensor]],
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.
        """

‎vllm/worker/neuron_model_runner.py

+7 -2
@@ -207,7 +207,12 @@ def execute_model(
         self,
         model_input: ModelInputForNeuron,
         kv_caches: Optional[List[torch.Tensor]] = None,
-    ) -> Optional[SamplerOutput]:
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        if num_steps > 1:
+            raise ValueError(
+                "NeuronModelRunner does not support multi-step execution.")
+
         hidden_states = self.model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
@@ -223,7 +228,7 @@ def execute_model(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
-        return output
+        return [output]
 
     @property
     def vocab_size(self) -> int:
