diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7b2caf4973589..b7002e75c9ef5 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,3 +1,5 @@ +import dataclasses +import weakref from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -17,7 +19,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, @@ -32,16 +34,17 @@ @dataclass(frozen=True) -class CPUModelInput(ModelRunnerInputBase): +class ModelInputForCPU(ModelRunnerInputBase): """ - Used by the CPUModelRunner. + Base class contains metadata needed for the base model forward pass on CPU """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None - sampling_metadata: Optional["SamplingMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -51,88 +54,96 @@ def as_broadcastable_tensor_dict( "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) + return tensor_dict @classmethod def from_broadcasted_tensor_dict( - cls: Type["CPUModelInput"], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None - ) -> "CPUModelInput": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + cls: Type["ModelInputForCPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> "ModelInputForCPU": if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) -class CPUModelRunner(ModelRunnerBase[CPUModelInput]): +@dataclass(frozen=True) +class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - *args, - **kwargs, - ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - # Currently, CPU worker doesn't support chunked prefill. - assert self.scheduler_config.chunked_prefill_enabled is False - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config - self.load_config = load_config - self.is_driver_worker = is_driver_worker + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict - self.device = self.device_config.device + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForCPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), - self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), - self.model_config.get_sliding_window(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - ) - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - .create_input_mapper(self.model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) +class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): - # Lazy initialization. - self.model: nn.Module # Set after init_Model + def __init__(self, + runner: "CPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper - if self.model_config.is_encoder_decoder_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) - def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + def build(self) -> ModelInputForCPU: + multi_modal_kwargs = None + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = [] + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens=seq_lens, + query_lens=seq_lens, + ) def _prepare_prompt( self, @@ -165,8 +176,7 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: + if (mm_data := seq_group_metadata.multi_modal_data): mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) @@ -302,56 +312,130 @@ def _prepare_decode( attn_metadata, ) + +class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): + _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( + ModelInputForCPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. + assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ + .create_input_mapper(self.model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + if self.model_config.is_encoder_decoder_model: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + + def load_model(self) -> None: + self.model = get_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], - ) -> CPUModelInput: - return CPUModelInput.from_broadcasted_tensor_dict( + ) -> ModelInputForCPU: + return ModelInputForCPU.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + + return builder.build() # type: ignore + def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> CPUModelInput: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - pin_memory=False, - generators=self.get_generators(finished_requests_ids)) - return CPUModelInput( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs, - ) + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) @torch.no_grad() def execute_model( self, - model_input: CPUModelInput, + model_input: ModelInputForCPUWithSamplingMetadata, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1,