|
| 1 | +from typing import List, Optional |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, |
| 6 | + ModelConfig, ParallelConfig, SchedulerConfig, |
| 7 | + VisionLanguageConfig) |
| 8 | +from vllm.logger import init_logger |
| 9 | +from vllm.sequence import SamplerOutput, SequenceGroupMetadata |
| 10 | +from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, |
| 11 | + ModelRunner) |
| 12 | + |
| 13 | +logger = init_logger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +class DraftModelRunner(ModelRunner): |
| 17 | + |
| 18 | + def __init__( |
| 19 | + self, |
| 20 | + model_config: ModelConfig, |
| 21 | + parallel_config: ParallelConfig, |
| 22 | + scheduler_config: SchedulerConfig, |
| 23 | + device_config: DeviceConfig, |
| 24 | + cache_config: CacheConfig, |
| 25 | + load_config: LoadConfig, |
| 26 | + lora_config: Optional[LoRAConfig], |
| 27 | + kv_cache_dtype: Optional[str] = "auto", |
| 28 | + is_driver_worker: bool = False, |
| 29 | + vision_language_config: Optional[VisionLanguageConfig] = None, |
| 30 | + return_hidden_states: bool = False, |
| 31 | + ): |
| 32 | + if return_hidden_states: |
| 33 | + raise ValueError( |
| 34 | + "return_hidden_states is not supported for DraftModelRunner.") |
| 35 | + |
| 36 | + super().__init__( |
| 37 | + model_config=model_config, |
| 38 | + parallel_config=parallel_config, |
| 39 | + scheduler_config=scheduler_config, |
| 40 | + device_config=device_config, |
| 41 | + cache_config=cache_config, |
| 42 | + load_config=load_config, |
| 43 | + lora_config=lora_config, |
| 44 | + kv_cache_dtype=kv_cache_dtype, |
| 45 | + is_driver_worker=is_driver_worker, |
| 46 | + vision_language_config=vision_language_config, |
| 47 | + return_hidden_states=return_hidden_states, |
| 48 | + ) |
| 49 | + |
| 50 | + # TODO: Remove this cache when we are able to update model_input |
| 51 | + # directly in advance_step. |
| 52 | + self.cached_seq_group_metadata_list: Optional[ |
| 53 | + List[SequenceGroupMetadata]] = None |
| 54 | + |
| 55 | + def prepare_model_input( |
| 56 | + self, |
| 57 | + seq_group_metadata_list: List[SequenceGroupMetadata], |
| 58 | + ) -> ModelInputForGPUWithSamplingMetadata: |
| 59 | + """A temporary solution that caches the seq_group_metadata_list |
| 60 | + for multi-step execution. |
| 61 | + TODO: In-place update model_input and remove this function. |
| 62 | + """ |
| 63 | + self.cached_seq_group_metadata_list = seq_group_metadata_list |
| 64 | + return super().prepare_model_input(seq_group_metadata_list) |
| 65 | + |
| 66 | + def advance_step( |
| 67 | + self, model_input: ModelInputForGPUWithSamplingMetadata, |
| 68 | + last_output: SamplerOutput |
| 69 | + ) -> ModelInputForGPUWithSamplingMetadata: |
| 70 | + """Prepare the model inputs for the next step. |
| 71 | + TODO: In-place update model_input instead of calling |
| 72 | + prepare_model_input. |
| 73 | + """ |
| 74 | + |
| 75 | + # Append the output token to the sequence data. |
| 76 | + assert self.cached_seq_group_metadata_list is not None |
| 77 | + for seq_group_metadata, sequence_group_outputs in zip( |
| 78 | + self.cached_seq_group_metadata_list, last_output.outputs): |
| 79 | + seq_group_metadata.is_prompt = False |
| 80 | + |
| 81 | + for seq_output in sequence_group_outputs.samples: |
| 82 | + seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] |
| 83 | + |
| 84 | + token_id = seq_output.output_token |
| 85 | + token_logprob = seq_output.logprobs[token_id] |
| 86 | + |
| 87 | + seq.append_token_id(token_id, token_logprob.logprob) |
| 88 | + seq.update_num_computed_tokens(1) |
| 89 | + |
| 90 | + return self.prepare_model_input(self.cached_seq_group_metadata_list) |
| 91 | + |
| 92 | + @torch.inference_mode() |
| 93 | + def execute_model( |
| 94 | + self, |
| 95 | + model_input: ModelInputForGPUWithSamplingMetadata, |
| 96 | + kv_caches: List[torch.Tensor], |
| 97 | + num_steps: int = 1, |
| 98 | + ) -> Optional[List[SamplerOutput]]: |
| 99 | + # Since we do not broadcast data inside execute_model anymore, |
| 100 | + # we need to figure out the best way to support TP > 1 in this |
| 101 | + # case, because we will at least need to broadcast the sampled |
| 102 | + # tokens to all workers. |
| 103 | + if not self.is_driver_worker: |
| 104 | + raise ValueError("DraftModelRunner only supports TP=1 for now.") |
| 105 | + |
| 106 | + if self.lora_config: |
| 107 | + assert model_input.lora_requests is not None |
| 108 | + assert model_input.lora_mapping is not None |
| 109 | + self.set_active_loras(model_input.lora_requests, |
| 110 | + model_input.lora_mapping) |
| 111 | + |
| 112 | + outputs: List[SamplerOutput] = [] |
| 113 | + for step in range(num_steps): |
| 114 | + # Currently cuda graph is only supported by the decode phase. |
| 115 | + assert model_input.attn_metadata is not None |
| 116 | + prefill_meta = model_input.attn_metadata.prefill_metadata |
| 117 | + decode_meta = model_input.attn_metadata.decode_metadata |
| 118 | + if prefill_meta is None and decode_meta.use_cuda_graph: |
| 119 | + assert model_input.input_tokens is not None |
| 120 | + graph_batch_size = model_input.input_tokens.shape[0] |
| 121 | + model_executable = self.graph_runners[graph_batch_size] |
| 122 | + else: |
| 123 | + model_executable = self.model |
| 124 | + |
| 125 | + multi_modal_kwargs = model_input.multi_modal_kwargs or {} |
| 126 | + hidden_states = model_executable( |
| 127 | + input_ids=model_input.input_tokens, |
| 128 | + positions=model_input.input_positions, |
| 129 | + kv_caches=kv_caches, |
| 130 | + attn_metadata=model_input.attn_metadata, |
| 131 | + **multi_modal_kwargs, |
| 132 | + ) |
| 133 | + |
| 134 | + # Compute the logits. |
| 135 | + logits = self.model.compute_logits(hidden_states, |
| 136 | + model_input.sampling_metadata) |
| 137 | + |
| 138 | + # Sample the next token. |
| 139 | + outputs.append( |
| 140 | + self.model.sample( |
| 141 | + logits=logits, |
| 142 | + sampling_metadata=model_input.sampling_metadata, |
| 143 | + )) |
| 144 | + |
| 145 | + # Prepare the inputs for the next step. |
| 146 | + if step != num_steps - 1: |
| 147 | + model_input = self.advance_step(model_input, outputs[-1]) |
| 148 | + |
| 149 | + return outputs |
0 commit comments