From 725b0b25bd4209c4ab76d74e9de6b580eb726863 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 17:01:50 -0700 Subject: [PATCH 01/55] tmp Signed-off-by: Stephanie Wang --- vllm/attention/backends/abstract.py | 6 +- vllm/attention/backends/blocksparse_attn.py | 4 +- vllm/attention/backends/flash_attn.py | 4 +- vllm/attention/backends/flashinfer.py | 4 +- vllm/attention/backends/rocm_flash_attn.py | 4 +- vllm/attention/backends/torch_sdpa.py | 4 +- vllm/attention/backends/xformers.py | 4 +- vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 4 +- vllm/sequence.py | 120 +++++++++- vllm/worker/embedding_model_runner.py | 110 ++++------ vllm/worker/model_runner.py | 232 ++++++++++---------- vllm/worker/worker.py | 149 +++++++++---- vllm/worker/worker_base.py | 40 +++- 14 files changed, 435 insertions(+), 253 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 6396103bf5efa..40768532f59c2 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -21,9 +21,13 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadata": + def get_metadata_cls() -> Type["AttentionMetadata"]: raise NotImplementedError + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + @staticmethod @abstractmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index dce2b83615b7a..7b4578fcd8b9d 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -90,8 +90,8 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: return BlocksparseFlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata": - return BlocksparseFlashAttentionMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return BlocksparseFlashAttentionMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 070c074e511bc..76a43db8ca080 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata": - return FlashAttentionMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7b7959d257fac..535d30b55bc9d 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -22,8 +22,8 @@ def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl @staticmethod - def make_metadata(*args, **kwargs) -> "FlashInferMetadata": - return FlashInferMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e92e6c5e2dc8d..359bddae7e585 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -25,8 +25,8 @@ def get_impl_cls() 
-> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl @staticmethod - def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata": - return ROCmFlashAttentionMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return ROCmFlashAttentionMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9b50adec5244d..da8343e327424 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -23,8 +23,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl @staticmethod - def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": - return TorchSDPAMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 99a3e88bc07b6..4d108f4668f2e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -28,8 +28,8 @@ def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl @staticmethod - def make_metadata(*args, **kwargs) -> "XFormersMetadata": - return XFormersMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["AttentionMetadata"]: + return XFormersMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3ad201f4757ec..c32338d8b3b73 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -88,7 +88,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: def execute_model( self, execute_model_req: ExecuteModelRequest ) -> List[Union[SamplerOutput, PoolerOutput]]: - output = self.driver_worker.execute_model(execute_model_req) + model_input = self.driver_worker.prepare_model_input_local(execute_model_req) + output = self.driver_worker.execute_model(model_input) return output def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index bd1cac2ab9b5b..ad34796fabdf4 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -84,8 +84,8 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. 
""" - return self.driver_worker.execute_model( - execute_model_req=execute_model_req) + model_input = self.driver_worker.prepare_model_input(execute_model_req) + return self.driver_worker.execute_model(model_input) def _run_workers( self, diff --git a/vllm/sequence.py b/vllm/sequence.py index 2f27bf33b166e..6f9003fbc3e7f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,8 +2,9 @@ import copy import enum from abc import ABC, abstractmethod +import dataclasses from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Set, Any import torch @@ -14,6 +15,10 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.model_executor import SamplingMetadata + from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.multimodal import MultiModalData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -163,6 +168,8 @@ def get_num_computed_tokens(self) -> int: def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" + # TODO: Check who calls this and make sure it's synchronized across + # driver and workers. self._num_computed_tokens += num_new_computed_tokens assert self._num_computed_tokens <= self.get_len(), ( self._num_computed_tokens, self.get_len()) @@ -841,7 +848,8 @@ def __eq__(self, other: object): @dataclass class ExecuteModelRequest: - """The model execution request.""" + """The model execution request, containing CPU metadata only. The LLM + engine should create an instance of this class for each request batch.""" # The sequence group metadata list. seq_group_metadata_list: List[SequenceGroupMetadata] # Blocks to swap in. List of CPU -> GPU block number. @@ -867,3 +875,111 @@ def clone( num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, ) + +@dataclass(frozen=True) +class ModelInput: + """Local inputs to each worker's `execute_model` function. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelInput objects. 
+ """ + num_seq_groups: int = None + blocks_to_swap_in: torch.Tensor = None + blocks_to_swap_out: torch.Tensor = None + blocks_to_copy: torch.Tensor = None + + input_tokens: torch.Tensor = None + input_positions: torch.Tensor = None + seq_lens: List[int] = None + query_lens: List[int] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Set[LoRARequest] = None + multi_modal_kwargs: Dict[str, torch.Tensor] = None + slot_mapping: torch.Tensor = None + num_prefill_tokens: int = None + num_decode_tokens: int = None + num_prefills: int = None + sampling_metadata: Optional["SamplingMetadata"] = None + attn_metadata: Optional["AttentionMetadata"] = None + pooling_metadata: Optional["PoolingMetadata"] = None + + BROADCASTABLE_FIELDS: List[str] = ( + "num_seq_groups", + "blocks_to_swap_in", + "blocks_to_swap_out", + "blocks_to_copy", + "input_tokens", + "input_positions", + "lora_requests", + "lora_mapping", + "multi_modal_kwargs", + "num_prefill_tokens", + "num_decode_tokens", + "slot_mapping", + "num_prefills", + ) + + @classmethod + def _get_valid_kwargs(cls, + selected_token_indices: Optional[torch.Tensor] = None, + sampling_metadata: Optional["SamplingMetadata"] = None, + attn_backend: Optional["AttentionBackend"] = None, + attn_metadata: Optional["AttentionMetadata"] = None, **kwargs) -> Dict[str, Any]: + from vllm.model_executor import SamplingMetadata + if sampling_metadata is None: + if selected_token_indices is not None: + # Workers do not perform sampling. + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + kwargs["sampling_metadata"] = sampling_metadata + + if attn_metadata is None: + if attn_backend is not None: + + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = kwargs.get(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata( + **valid_attn_kwargs + ) + kwargs["attn_metadata"] = attn_metadata + + # Drop extra kwargs that may have been used to initialize other + # values. 
+ valid_kwargs = {} + for field in dataclasses.fields(cls): + val = kwargs.get(field.name, None) + if val is not None: + valid_kwargs[field.name] = val + return valid_kwargs + + + @classmethod + def new(cls, **kwargs) -> "ModelInput": + valid_kwargs = cls._get_valid_kwargs(**kwargs) + return cls(**valid_kwargs) + + def replace(self, **kwargs) -> "ModelInput": + valid_kwargs = self.__class__._get_valid_kwargs(**kwargs) + return dataclasses.replace(self, **valid_kwargs) + + def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = {} + for field in self.BROADCASTABLE_FIELDS: + val = getattr(self, field, None) + if val is not None: + tensor_dict[field] = val + + if self.sampling_metadata is not None: + tensor_dict["selected_token_indices"] = self.sampling_metadata.selected_token_indices + if self.attn_metadata is not None: + tensor_dict.update(self.attn_metadata.asdict_zerocopy()) + + return tensor_dict diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 465130d10e2f9..8203f148fd261 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -12,7 +12,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams -from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata, ModelInput from vllm.worker.model_runner import ModelRunner logger = init_logger(__name__) @@ -47,21 +47,17 @@ def __init__( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + model_input: ModelInput, kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: - (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) - if self.lora_config: - self.set_active_loras(lora_requests, lora_mapping) + self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. - prefill_meta = attn_metadata.prefill_metadata - decode_meta = attn_metadata.decode_metadata + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata if prefill_meta is None and decode_meta.use_cuda_graph: - graph_batch_size = input_tokens.shape[0] + graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model @@ -70,13 +66,13 @@ def execute_model( kv_caches = [None] * num_layers execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, + "input_ids": model_input.input_tokens, + "positions": model_input.input_positions, "kv_caches": kv_caches, - "attn_metadata": attn_metadata, + "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) + execute_model_kwargs.update({"image_input": model_input.multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. 
@@ -84,66 +80,40 @@ def execute_model( return None return self.model.pooler(hidden_states=hidden_states, - pooling_metadata=pooling_metadata) + pooling_metadata=model_input.pooling_metadata) def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: - if self.is_driver_worker: - assert seq_group_metadata_list is not None - # Prepare input tensors. - ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - _, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) - # Prepare PoolingMetadata - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - seq_lens) - - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "lora_requests": lora_requests, - "lora_mapping": lora_mapping, - "multi_modal_kwargs": multi_modal_kwargs, - "num_prefill_tokens": num_prefill_tokens, - "num_decode_tokens": num_decode_tokens, - "slot_mapping": slot_mapping, - "num_prefills": num_prefills, - } - if attn_metadata: - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) - else: - metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict.pop("input_tokens") - input_positions = metadata_dict.pop("input_positions") - lora_mapping = metadata_dict.pop("lora_mapping") - lora_requests = metadata_dict.pop("lora_requests") - multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - if metadata_dict: - attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - else: - attn_metadata = None - pooling_metadata = PoolingMetadata(seq_groups=None, - seq_data=None, - prompt_lens=None) - - return (input_tokens, input_positions, attn_metadata, pooling_metadata, - lora_requests, lora_mapping, multi_modal_kwargs) + ) -> ModelInput: + assert seq_group_metadata_list is not None + # Prepare input tensors. 
+ ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + _, + lora_mapping, + lora_requests, + multi_modal_kwargs, + slot_mapping, + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) + # Prepare PoolingMetadata + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + seq_lens) + + return ModelInput( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + pooling_metadata=pooling_metadata, + lora_requests=lora_requests, + lora_mapping=lora_mapping, + multi_modal_kwargs=multi_modal_kwargs) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c59288b4f73c6..a8af99ff91d0f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,7 +21,7 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, ModelInput from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) @@ -38,36 +38,36 @@ _NUM_WARMUP_ITERS = 2 -class ModelInput(NamedTuple): - input_tokens: torch.Tensor - input_positions: torch.Tensor - attn_metadata: Optional[AttentionMetadata] - seq_lens: List[int] - query_lens: List[int] - lora_mapping: Optional[LoRAMapping] - lora_requests: Set[LoRARequest] - multi_modal_kwargs: Dict[str, torch.Tensor] - slot_mapping: torch.Tensor - num_prefill_tokens: int - num_decode_tokens: int - num_prefills: int - - @classmethod - def empty(cls, device): - return ModelInput( - input_tokens=torch.empty(0, device=device), - input_positions=torch.empty(0, device=device), - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_mapping=None, - lora_requests=set(), - multi_modal_kwargs={}, - slot_mapping=torch.empty(0, device=device), - num_prefill_tokens=0, - num_decode_tokens=0, - num_prefills=0, - ) +#class ModelInput(NamedTuple): +# input_tokens: torch.Tensor +# input_positions: torch.Tensor +# attn_metadata: Optional[AttentionMetadata] +# seq_lens: List[int] +# query_lens: List[int] +# lora_mapping: Optional[LoRAMapping] +# lora_requests: Set[LoRARequest] +# multi_modal_kwargs: Dict[str, torch.Tensor] +# slot_mapping: torch.Tensor +# num_prefill_tokens: int +# num_decode_tokens: int +# num_prefills: int +# +# @classmethod +# def empty(cls, device): +# return ModelInput( +# input_tokens=torch.empty(0, device=device), +# input_positions=torch.empty(0, device=device), +# attn_metadata=None, +# seq_lens=[], +# query_lens=[], +# lora_mapping=None, +# lora_requests=set(), +# multi_modal_kwargs={}, +# slot_mapping=torch.empty(0, device=device), +# num_prefill_tokens=0, +# num_decode_tokens=0, +# num_prefills=0, +# ) class ModelRunner: @@ -280,7 +280,7 @@ def _prepare_model_input( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return ModelInput.empty(self.device) + return ModelInput() if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window + self.block_size - @@ -630,7 +630,11 @@ def _prepare_model_input( for k, v in multi_modal_kwargs_list.items() } - return ModelInput( + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) + + return ModelInput.new( 
input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, attn_metadata=attn_metadata, @@ -643,107 +647,104 @@ def _prepare_model_input( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, + sampling_metadata=sampling_metadata, ) - def prepare_input_tensors( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: - if self.is_driver_worker: - assert seq_group_metadata_list is not None - # Prepare input tensors. - ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - query_lens, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, self.device, - self.pin_memory) - - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "selected_token_indices": - sampling_metadata.selected_token_indices, - "lora_requests": lora_requests, - "lora_mapping": lora_mapping, - "multi_modal_kwargs": multi_modal_kwargs, - "num_prefill_tokens": num_prefill_tokens, - "num_decode_tokens": num_decode_tokens, - "slot_mapping": slot_mapping, - "num_prefills": num_prefills, - } - if attn_metadata: - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) - else: - metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict.pop("input_tokens") - input_positions = metadata_dict.pop("input_positions") - selected_token_indices = metadata_dict.pop( - "selected_token_indices") - lora_mapping = metadata_dict.pop("lora_mapping") - lora_requests = metadata_dict.pop("lora_requests") - multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - if metadata_dict: - attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - else: - attn_metadata = None - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - - return (input_tokens, input_positions, attn_metadata, - sampling_metadata, lora_requests, lora_mapping, - multi_modal_kwargs) + #def prepare_input_tensors( + # self, + # seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + #) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + # Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: + # if self.is_driver_worker: + # assert seq_group_metadata_list is not None + # # Prepare input tensors. 
+ # ( + # input_tokens, + # input_positions, + # attn_metadata, + # seq_lens, + # query_lens, + # lora_mapping, + # lora_requests, + # multi_modal_kwargs, + # slot_mapping, + # num_prefill_tokens, + # num_decode_tokens, + # num_prefills, + # ) = self._prepare_model_input(seq_group_metadata_list) + # sampling_metadata = SamplingMetadata.prepare( + # seq_group_metadata_list, seq_lens, query_lens, self.device, + # self.pin_memory) + + # metadata_dict = { + # "input_tokens": input_tokens, + # "input_positions": input_positions, + # "selected_token_indices": + # sampling_metadata.selected_token_indices, + # "lora_requests": lora_requests, + # "lora_mapping": lora_mapping, + # "multi_modal_kwargs": multi_modal_kwargs, + # "num_prefill_tokens": num_prefill_tokens, + # "num_decode_tokens": num_decode_tokens, + # "slot_mapping": slot_mapping, + # "num_prefills": num_prefills, + # } + # if attn_metadata: + # metadata_dict.update(attn_metadata.asdict_zerocopy()) + # broadcast_tensor_dict(metadata_dict, src=0) + # else: + # metadata_dict = broadcast_tensor_dict(src=0) + # input_tokens = metadata_dict.pop("input_tokens") + # input_positions = metadata_dict.pop("input_positions") + # selected_token_indices = metadata_dict.pop( + # "selected_token_indices") + # lora_mapping = metadata_dict.pop("lora_mapping") + # lora_requests = metadata_dict.pop("lora_requests") + # multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") + # if metadata_dict: + # attn_metadata = self.attn_backend.make_metadata( + # **metadata_dict) + # else: + # attn_metadata = None + # sampling_metadata = SamplingMetadata( + # seq_groups=None, + # selected_token_indices=selected_token_indices, + # categorized_sample_indices=None, + # num_prompts=0, + # ) + + # return (input_tokens, input_positions, attn_metadata, + # sampling_metadata, lora_requests, lora_mapping, + # multi_modal_kwargs) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + model_input: ModelInput, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, - lora_requests, lora_mapping, multi_modal_kwargs - ) = self.prepare_input_tensors(seq_group_metadata_list) - if self.lora_config: - self.set_active_loras(lora_requests, lora_mapping) + self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. - prefill_meta = attn_metadata.prefill_metadata - decode_meta = attn_metadata.decode_metadata + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata if prefill_meta is None and decode_meta.use_cuda_graph: - graph_batch_size = input_tokens.shape[0] + graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model hidden_states = model_executable( - input_ids=input_tokens, - positions=input_positions, + input_ids=model_input.input_tokens, + positions=model_input.input_positions, kv_caches=kv_caches, - attn_metadata=attn_metadata, - **multi_modal_kwargs, + attn_metadata=model_input.attn_metadata, + **model_input.multi_modal_kwargs, ) # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, model_input.sampling_metadata) # Only perform sampling in the driver worker. 
if not self.is_driver_worker: @@ -752,7 +753,7 @@ def execute_model( # Sample the next token. output = self.model.sample( logits=logits, - sampling_metadata=sampling_metadata, + sampling_metadata=model_input.sampling_metadata, ) return output @@ -829,7 +830,8 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - self.execute_model(seqs, kv_caches) + model_input = self._prepare_model_input(seqs) + self.execute_model(model_input, kv_caches) torch.cuda.synchronize() return diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 10411a2bf7a10..99387fad6f8c5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -15,7 +15,7 @@ set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput, ModelInput from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner @@ -209,38 +209,27 @@ def _warm_up_model(self) -> None: def cache_swap( self, - blocks_to_swap_in: torch.Tensor, - blocks_to_swap_out: torch.Tensor, - blocks_to_copy: torch.Tensor, + blocks_to_swap_in: Optional[torch.Tensor], + blocks_to_swap_out: Optional[torch.Tensor], + blocks_to_copy: Optional[torch.Tensor], ) -> None: # Issue cache operations. - if blocks_to_swap_in.numel() > 0: + if blocks_to_swap_in is not None: self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out.numel() > 0: + if blocks_to_swap_out is not None: self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy.numel() > 0: + if blocks_to_copy is not None: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[Union[SamplerOutput, PoolerOutput]]: - if not self.is_driver_worker: - self._execute_model_non_driver() - return [] - - if execute_model_req is None: - # This signals that there's no more requests to process for now. - # All workers are running infinite loop with broadcast_tensor_dict, - # and it stops the loop when the driver broadcasts an empty input. - # Send an empty input to notify all other workers to stop their - # execution loop. - broadcast_tensor_dict({}, src=0) - return [] - - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - num_seq_groups = len(seq_group_metadata_list) + def prepare_model_input_local( + self, + execute_model_req: ExecuteModelRequest) -> ModelInput: + model_input = self.model_runner._prepare_model_input( + execute_model_req.seq_group_metadata_list + ) + + num_seq_groups = len(execute_model_req.seq_group_metadata_list) # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. 
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, @@ -255,37 +244,105 @@ def execute_model( blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) - self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + return model_input.replace( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + + @torch.inference_mode() + def prepare_model_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> ModelInput: + if self.is_driver_worker: + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return None + + model_input = self.prepare_model_input_local(execute_model_req) + metadata_dict = model_input.as_broadcastable_tensor_dict() + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + if not metadata_dict: + return None + + model_input = ModelInput.new( + attn_backend=self.model_runner.attn_backend, + **metadata_dict) + return model_input + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInput, + ) -> List[Union[SamplerOutput, PoolerOutput]]: + #if not self.is_driver_worker: + # self._execute_model_non_driver() + # return [] + + #if execute_model_req is None: + # # This signals that there's no more requests to process for now. + # # All workers are running infinite loop with broadcast_tensor_dict, + # # and it stops the loop when the driver broadcasts an empty input. + # # Send an empty input to notify all other workers to stop their + # # execution loop. + # broadcast_tensor_dict({}, src=0) + # return [] + + #seq_group_metadata_list = execute_model_req.seq_group_metadata_list + #num_seq_groups = len(seq_group_metadata_list) + ## `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + ## they contain parameters to launch cudamemcpyasync. + #blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + # device="cpu", + # dtype=torch.int64).view(-1, 2) + #blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + # device="cpu", + # dtype=torch.int64).view(-1, 2) + ## `blocks_to_copy` is a gpu tensor. The src and tgt of + ## blocks to copy are in the same device, and `blocks_to_copy` + ## can be used directly within cuda kernels. 
+ #blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + # device=self.device, + # dtype=torch.int64).view(-1, 2) + #data: Dict[str, Any] = { + # "num_seq_groups": num_seq_groups, + # "blocks_to_swap_in": blocks_to_swap_in, + # "blocks_to_swap_out": blocks_to_swap_out, + # "blocks_to_copy": blocks_to_copy, + #} + #broadcast_tensor_dict(data, src=0) + + #self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + + #output = self.model_runner.execute_model(seq_group_metadata_list, + # self.gpu_cache) + + self.cache_swap( + model_input.blocks_to_swap_in, + model_input.blocks_to_swap_out, + model_input.blocks_to_copy) # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: + if model_input.num_seq_groups == 0: return [] - output = self.model_runner.execute_model(seq_group_metadata_list, + output = self.model_runner.execute_model(model_input, self.gpu_cache) # Worker only supports single-step execution. Wrap the output in a list # to conform to interface. return [output] - @torch.inference_mode() - def start_worker_execution_loop(self) -> None: - """Execute model loop in parallel worker. - - You can stop the loop by executing a driver worker with an empty output. - See `stop_remote_worker_execution_loop` for more details. - """ - while self._execute_model_non_driver(): - pass - def _execute_model_non_driver(self) -> bool: """Execute model in parallel worker. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 258f31de17d87..b504472dc32c3 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -2,10 +2,11 @@ import os from abc import ABC, abstractmethod from typing import Dict, List, Optional, Set, Tuple +import torch from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput, ModelInput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) @@ -46,11 +47,42 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise NotImplementedError + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + while True: + model_input = self.prepare_model_input(execute_model_req=None) + + if model_input is None: + return + + return self.execute_model(model_input) + + @abstractmethod + def prepare_model_input_local(self, execute_model_req: ExecuteModelRequest) -> ModelInput: + """ + Prepare a model execution request locally. This method is not allowed + to communicate with external devices. + """ + raise NotImplementedError + + @abstractmethod + def prepare_model_input(self, execute_model_req: Optional[ExecuteModelRequest] = None) -> ModelInput: + """ + Prepare a model execution request. Communication with other workers + may occur to produce the model input that should be passed to + execute_model. 
+ """ + raise NotImplementedError + @abstractmethod def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + self, + model_input: ModelInput) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" raise NotImplementedError From b74eb105d99b737be47bdd7796bd2f758a0022fc Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 17:25:15 -0700 Subject: [PATCH 02/55] fix Signed-off-by: Stephanie Wang --- vllm/worker/worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 99387fad6f8c5..b4ace3eb08fe4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -214,11 +214,11 @@ def cache_swap( blocks_to_copy: Optional[torch.Tensor], ) -> None: # Issue cache operations. - if blocks_to_swap_in is not None: + if blocks_to_swap_in is not None and blocks_to_swap_in.numel() > 0: self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out is not None: + if blocks_to_swap_out is not None and blocks_to_swap_out.numel() > 0: self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy is not None: + if blocks_to_copy is not None and blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() From 38b0ddf87b7c5fbfb23ab3e58f24bc994eb66cb5 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 17:56:33 -0700 Subject: [PATCH 03/55] ray and mp backends work Signed-off-by: Stephanie Wang --- vllm/executor/distributed_gpu_executor.py | 4 ++-- vllm/executor/multiproc_gpu_executor.py | 4 +++- vllm/executor/ray_gpu_executor.py | 8 +++++--- vllm/sequence.py | 1 - vllm/worker/worker_base.py | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index f7c608af1ad39..c4ed3007e2b82 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -79,7 +79,7 @@ def stop_remote_worker_execution_loop(self) -> None: if self.parallel_worker_tasks is None: return - self._driver_execute_model() + self._driver_execute_model(execute_model_req=None) parallel_worker_tasks = self.parallel_worker_tasks self.parallel_worker_tasks = None # Ensure that workers exit model loop cleanly @@ -117,7 +117,7 @@ def save_sharded_state( @abstractmethod def _driver_execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] ) -> List[SamplerOutput]: """Run execute_model in the driver worker. diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index ad34796fabdf4..0e49f04963881 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -77,7 +77,7 @@ def shutdown(self): def _driver_execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] ) -> List[SamplerOutput]: """Run execute_model in the driver worker. @@ -85,6 +85,8 @@ def _driver_execute_model( loop running in each of the remote workers. 
""" model_input = self.driver_worker.prepare_model_input(execute_model_req) + if model_input is None: + return None return self.driver_worker.execute_model(model_input) def _run_workers( diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index bed356d1b6e58..11e55d9fa5771 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -175,15 +175,17 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", def _driver_execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] ) -> List[SamplerOutput]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ - return self.driver_worker.execute_method("execute_model", - execute_model_req) + model_input = self.driver_worker.execute_method("prepare_model_input", execute_model_req) + if model_input is None: + return + return self.driver_worker.execute_method("execute_model", model_input) def _run_workers( self, diff --git a/vllm/sequence.py b/vllm/sequence.py index 6f9003fbc3e7f..c392e5c8b76fa 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -960,7 +960,6 @@ def _get_valid_kwargs(cls, valid_kwargs[field.name] = val return valid_kwargs - @classmethod def new(cls, **kwargs) -> "ModelInput": valid_kwargs = cls._get_valid_kwargs(**kwargs) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b504472dc32c3..3ce375e8a3ea8 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -60,7 +60,7 @@ def start_worker_execution_loop(self) -> None: if model_input is None: return - return self.execute_model(model_input) + self.execute_model(model_input) @abstractmethod def prepare_model_input_local(self, execute_model_req: ExecuteModelRequest) -> ModelInput: From 0d11e922b8fb0a433e3af51a538af5c66d96e5c3 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 19:57:15 -0700 Subject: [PATCH 04/55] embedding model runner works Signed-off-by: Stephanie Wang --- vllm/worker/embedding_model_runner.py | 29 ++-------- vllm/worker/model_runner.py | 83 ++++----------------------- vllm/worker/worker.py | 2 +- 3 files changed, 17 insertions(+), 97 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 8203f148fd261..d75c5eee261b6 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -82,38 +82,19 @@ def execute_model( return self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) - def prepare_input_tensors( + def prepare_model_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> ModelInput: assert seq_group_metadata_list is not None - # Prepare input tensors. 
- ( - input_tokens, - input_positions, - attn_metadata, - seq_lens, - _, - lora_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - num_prefill_tokens, - num_decode_tokens, - num_prefills, - ) = self._prepare_model_input(seq_group_metadata_list) + model_input = self._prepare_model_input_tensors(seq_group_metadata_list) # Prepare PoolingMetadata pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - seq_lens) + model_input.seq_lens) - return ModelInput( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, + return model_input.replace( pooling_metadata=pooling_metadata, - lora_requests=lora_requests, - lora_mapping=lora_mapping, - multi_modal_kwargs=multi_modal_kwargs) + ) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a8af99ff91d0f..4f9e0d32872eb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -225,7 +225,7 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_model_input( + def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> ModelInput: @@ -630,10 +630,6 @@ def _prepare_model_input( for k, v in multi_modal_kwargs_list.items() } - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, self.device, - self.pin_memory) - return ModelInput.new( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, @@ -647,75 +643,18 @@ def _prepare_model_input( num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, num_prefills=num_prefills, - sampling_metadata=sampling_metadata, ) - #def prepare_input_tensors( - # self, - # seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - #) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - # Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: - # if self.is_driver_worker: - # assert seq_group_metadata_list is not None - # # Prepare input tensors. 
- # ( - # input_tokens, - # input_positions, - # attn_metadata, - # seq_lens, - # query_lens, - # lora_mapping, - # lora_requests, - # multi_modal_kwargs, - # slot_mapping, - # num_prefill_tokens, - # num_decode_tokens, - # num_prefills, - # ) = self._prepare_model_input(seq_group_metadata_list) - # sampling_metadata = SamplingMetadata.prepare( - # seq_group_metadata_list, seq_lens, query_lens, self.device, - # self.pin_memory) - - # metadata_dict = { - # "input_tokens": input_tokens, - # "input_positions": input_positions, - # "selected_token_indices": - # sampling_metadata.selected_token_indices, - # "lora_requests": lora_requests, - # "lora_mapping": lora_mapping, - # "multi_modal_kwargs": multi_modal_kwargs, - # "num_prefill_tokens": num_prefill_tokens, - # "num_decode_tokens": num_decode_tokens, - # "slot_mapping": slot_mapping, - # "num_prefills": num_prefills, - # } - # if attn_metadata: - # metadata_dict.update(attn_metadata.asdict_zerocopy()) - # broadcast_tensor_dict(metadata_dict, src=0) - # else: - # metadata_dict = broadcast_tensor_dict(src=0) - # input_tokens = metadata_dict.pop("input_tokens") - # input_positions = metadata_dict.pop("input_positions") - # selected_token_indices = metadata_dict.pop( - # "selected_token_indices") - # lora_mapping = metadata_dict.pop("lora_mapping") - # lora_requests = metadata_dict.pop("lora_requests") - # multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs") - # if metadata_dict: - # attn_metadata = self.attn_backend.make_metadata( - # **metadata_dict) - # else: - # attn_metadata = None - # sampling_metadata = SamplingMetadata( - # seq_groups=None, - # selected_token_indices=selected_token_indices, - # categorized_sample_indices=None, - # num_prompts=0, - # ) - - # return (input_tokens, input_positions, attn_metadata, - # sampling_metadata, lora_requests, lora_mapping, - # multi_modal_kwargs) + + def prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> ModelInput: + model_input = self._prepare_model_input(seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) + return model_input.replace(sampling_metadata=sampling_metadata) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b4ace3eb08fe4..66b0859ce64d8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -225,7 +225,7 @@ def cache_swap( def prepare_model_input_local( self, execute_model_req: ExecuteModelRequest) -> ModelInput: - model_input = self.model_runner._prepare_model_input( + model_input = self.model_runner.prepare_model_input_tensors( execute_model_req.seq_group_metadata_list ) From 2cdc2180ee55a9a5359c34a01bf03f2be716c17a Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 20:26:38 -0700 Subject: [PATCH 05/55] GPU executor works Signed-off-by: Stephanie Wang --- vllm/sequence.py | 53 +++++++-------- vllm/worker/embedding_model_runner.py | 10 ++- vllm/worker/model_runner.py | 96 ++++++++++++++++----------- 3 files changed, 88 insertions(+), 71 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index c392e5c8b76fa..d68e121aa1981 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -882,6 +882,11 @@ class ModelInput: device-specific data. Different worker backends may have different methods of converting from the global ExecuteModelRequest produced by the LLM engine to the worker-local ModelInput objects. 
+ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. """ num_seq_groups: int = None blocks_to_swap_in: torch.Tensor = None @@ -920,53 +925,41 @@ class ModelInput: ) @classmethod - def _get_valid_kwargs(cls, - selected_token_indices: Optional[torch.Tensor] = None, - sampling_metadata: Optional["SamplingMetadata"] = None, + def _get_init_kwargs(cls, attn_backend: Optional["AttentionBackend"] = None, attn_metadata: Optional["AttentionMetadata"] = None, **kwargs) -> Dict[str, Any]: - from vllm.model_executor import SamplingMetadata - if sampling_metadata is None: - if selected_token_indices is not None: - # Workers do not perform sampling. - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - kwargs["sampling_metadata"] = sampling_metadata - if attn_metadata is None: + # Extract the fields used to create AttentionMetadata. if attn_backend is not None: - valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = kwargs.get(field.name, None) + val = kwargs.pop(field.name, None) if val is not None: valid_attn_kwargs[field.name] = val attn_metadata = attn_backend.make_metadata( **valid_attn_kwargs ) - kwargs["attn_metadata"] = attn_metadata + if attn_metadata is not None: + kwargs["attn_metadata"] = attn_metadata - # Drop extra kwargs that may have been used to initialize other - # values. - valid_kwargs = {} - for field in dataclasses.fields(cls): - val = kwargs.get(field.name, None) - if val is not None: - valid_kwargs[field.name] = val - return valid_kwargs + return kwargs @classmethod - def new(cls, **kwargs) -> "ModelInput": - valid_kwargs = cls._get_valid_kwargs(**kwargs) - return cls(**valid_kwargs) + def new(cls, clone: Optional["ModelInput"] = None, **kwargs) -> "ModelInput": + clone_kwargs = {} + if clone is not None: + for field in dataclasses.fields(clone): + val = getattr(clone, field.name) + if val is not None: + clone_kwargs[field.name] = val + clone_kwargs = cls._get_init_kwargs(**clone_kwargs) + + kwargs = cls._get_init_kwargs(**kwargs) + return cls(**clone_kwargs, **kwargs) def replace(self, **kwargs) -> "ModelInput": - valid_kwargs = self.__class__._get_valid_kwargs(**kwargs) + valid_kwargs = self.__class__._get_init_kwargs(**kwargs) return dataclasses.replace(self, **valid_kwargs) def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d75c5eee261b6..3a13540f3cfd0 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,4 +1,5 @@ from typing import Dict, List, Optional, Set, Tuple +from dataclasses import dataclass import torch @@ -18,6 +19,10 @@ logger = init_logger(__name__) +@dataclass(frozen=True) +class ModelInputWithPoolingMetadata(ModelInput): + pooling_metadata: Optional["SamplingMetadata"] = None + class EmbeddingModelRunner(ModelRunner): def __init__( @@ -88,11 +93,12 @@ def prepare_model_input_tensors( ) -> ModelInput: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors(seq_group_metadata_list) - # Prepare PoolingMetadata + # Prepare PoolingMetadata. 
pooling_metadata = self._prepare_pooling(seq_group_metadata_list, model_input.seq_lens) - return model_input.replace( + return ModelInputWithPoolingMetadata.new( + clone=model_input, pooling_metadata=pooling_metadata, ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4f9e0d32872eb..39e55135dde9c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,7 +1,8 @@ import time import warnings from collections import defaultdict -from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union, Any +from dataclasses import dataclass import numpy as np import torch @@ -38,37 +39,37 @@ _NUM_WARMUP_ITERS = 2 -#class ModelInput(NamedTuple): -# input_tokens: torch.Tensor -# input_positions: torch.Tensor -# attn_metadata: Optional[AttentionMetadata] -# seq_lens: List[int] -# query_lens: List[int] -# lora_mapping: Optional[LoRAMapping] -# lora_requests: Set[LoRARequest] -# multi_modal_kwargs: Dict[str, torch.Tensor] -# slot_mapping: torch.Tensor -# num_prefill_tokens: int -# num_decode_tokens: int -# num_prefills: int -# -# @classmethod -# def empty(cls, device): -# return ModelInput( -# input_tokens=torch.empty(0, device=device), -# input_positions=torch.empty(0, device=device), -# attn_metadata=None, -# seq_lens=[], -# query_lens=[], -# lora_mapping=None, -# lora_requests=set(), -# multi_modal_kwargs={}, -# slot_mapping=torch.empty(0, device=device), -# num_prefill_tokens=0, -# num_decode_tokens=0, -# num_prefills=0, -# ) - +@dataclass(frozen=True) +class ModelInputWithSamplingMetadata(ModelInput): + # Metadata for sampling outputs. + sampling_metadata: Optional["SamplingMetadata"] = None + + @classmethod + def _get_init_kwargs(cls, + selected_token_indices: Optional[torch.Tensor] = None, + sampling_metadata: Optional["SamplingMetadata"] = None, + **kwargs) -> Dict[str, Any]: + from vllm.model_executor import SamplingMetadata + if sampling_metadata is None: + if selected_token_indices is not None: + # Workers do not perform sampling. + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + if sampling_metadata is not None: + kwargs["sampling_metadata"] = sampling_metadata + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + + if self.sampling_metadata is not None: + tensor_dict["selected_token_indices"] = self.sampling_metadata.selected_token_indices + + return tensor_dict class ModelRunner: @@ -229,7 +230,9 @@ def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> ModelInput: - """Prepare the model input based on a given sequence group. + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. The API assumes seq_group_metadata_list is sorted by prefill -> decode. @@ -649,12 +652,27 @@ def _prepare_model_input_tensors( def prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInput: - model_input = self._prepare_model_input(seq_group_metadata_list) + ) -> ModelInputWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. 
+ + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + model_input = self._prepare_model_input_tensors(seq_group_metadata_list) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, self.device, - self.pin_memory) - return model_input.replace(sampling_metadata=sampling_metadata) + seq_group_metadata_list, model_input.seq_lens, + model_input.query_lens, self.device, self.pin_memory) + return ModelInputWithSamplingMetadata.new( + clone=model_input, + sampling_metadata=sampling_metadata) @torch.inference_mode() def execute_model( @@ -769,7 +787,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self._prepare_model_input(seqs) + model_input = self.prepare_model_input_tensors(seqs) self.execute_model(model_input, kv_caches) torch.cuda.synchronize() return From c728512cfd36bdfb50736e8b6a144a637944e5ca Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 20:29:47 -0700 Subject: [PATCH 06/55] remove comment Signed-off-by: Stephanie Wang --- vllm/worker/worker.py | 42 ------------------------------------------ 1 file changed, 42 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 66b0859ce64d8..a1f744df1598d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -285,48 +285,6 @@ def execute_model( self, model_input: ModelInput, ) -> List[Union[SamplerOutput, PoolerOutput]]: - #if not self.is_driver_worker: - # self._execute_model_non_driver() - # return [] - - #if execute_model_req is None: - # # This signals that there's no more requests to process for now. - # # All workers are running infinite loop with broadcast_tensor_dict, - # # and it stops the loop when the driver broadcasts an empty input. - # # Send an empty input to notify all other workers to stop their - # # execution loop. - # broadcast_tensor_dict({}, src=0) - # return [] - - #seq_group_metadata_list = execute_model_req.seq_group_metadata_list - #num_seq_groups = len(seq_group_metadata_list) - ## `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - ## they contain parameters to launch cudamemcpyasync. - #blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - # device="cpu", - # dtype=torch.int64).view(-1, 2) - #blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - # device="cpu", - # dtype=torch.int64).view(-1, 2) - ## `blocks_to_copy` is a gpu tensor. The src and tgt of - ## blocks to copy are in the same device, and `blocks_to_copy` - ## can be used directly within cuda kernels. 
- #blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - # device=self.device, - # dtype=torch.int64).view(-1, 2) - #data: Dict[str, Any] = { - # "num_seq_groups": num_seq_groups, - # "blocks_to_swap_in": blocks_to_swap_in, - # "blocks_to_swap_out": blocks_to_swap_out, - # "blocks_to_copy": blocks_to_copy, - #} - #broadcast_tensor_dict(data, src=0) - - #self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) - - #output = self.model_runner.execute_model(seq_group_metadata_list, - # self.gpu_cache) - self.cache_swap( model_input.blocks_to_swap_in, model_input.blocks_to_swap_out, From 2bf752b5ef6bad6ab8d25320b7dc7c101ea1824f Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 20:53:39 -0700 Subject: [PATCH 07/55] use the right ModelInput class Signed-off-by: Stephanie Wang --- vllm/sequence.py | 39 ++++++++++++++++++++++++--- vllm/worker/embedding_model_runner.py | 5 +++- vllm/worker/model_runner.py | 39 ++++----------------------- vllm/worker/worker.py | 3 ++- 4 files changed, 46 insertions(+), 40 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index d68e121aa1981..02b5971c723dc 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -904,9 +904,7 @@ class ModelInput: num_prefill_tokens: int = None num_decode_tokens: int = None num_prefills: int = None - sampling_metadata: Optional["SamplingMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None - pooling_metadata: Optional["PoolingMetadata"] = None BROADCASTABLE_FIELDS: List[str] = ( "num_seq_groups", @@ -969,9 +967,42 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: if val is not None: tensor_dict[field] = val - if self.sampling_metadata is not None: - tensor_dict["selected_token_indices"] = self.sampling_metadata.selected_token_indices if self.attn_metadata is not None: tensor_dict.update(self.attn_metadata.asdict_zerocopy()) return tensor_dict + + +@dataclass(frozen=True) +class ModelInputWithSamplingMetadata(ModelInput): + # Metadata for sampling outputs. + sampling_metadata: Optional["SamplingMetadata"] = None + + @classmethod + def _get_init_kwargs(cls, + selected_token_indices: Optional[torch.Tensor] = None, + sampling_metadata: Optional["SamplingMetadata"] = None, + **kwargs) -> Dict[str, Any]: + if sampling_metadata is None: + if selected_token_indices is not None: + from vllm.model_executor import SamplingMetadata + + # Workers do not perform sampling. 
+ sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + if sampling_metadata is not None: + kwargs["sampling_metadata"] = sampling_metadata + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + + if self.sampling_metadata is not None: + tensor_dict["selected_token_indices"] = self.sampling_metadata.selected_token_indices + + return tensor_dict + diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 3a13540f3cfd0..9d3fc1a6fbdcd 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -49,10 +49,13 @@ def __init__( is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) + def get_empty_model_input(self) -> ModelInputWithPoolingMetadata: + return ModelInputWithPoolingMetadata.new() + @torch.inference_mode() def execute_model( self, - model_input: ModelInput, + model_input: ModelInputWithPoolingMetadata, kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: if self.lora_config: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 39e55135dde9c..b30407dc09cec 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -22,7 +22,7 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, ModelInput +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, ModelInput, ModelInputWithSamplingMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) @@ -39,38 +39,6 @@ _NUM_WARMUP_ITERS = 2 -@dataclass(frozen=True) -class ModelInputWithSamplingMetadata(ModelInput): - # Metadata for sampling outputs. - sampling_metadata: Optional["SamplingMetadata"] = None - - @classmethod - def _get_init_kwargs(cls, - selected_token_indices: Optional[torch.Tensor] = None, - sampling_metadata: Optional["SamplingMetadata"] = None, - **kwargs) -> Dict[str, Any]: - from vllm.model_executor import SamplingMetadata - if sampling_metadata is None: - if selected_token_indices is not None: - # Workers do not perform sampling. 
- sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - if sampling_metadata is not None: - kwargs["sampling_metadata"] = sampling_metadata - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - - if self.sampling_metadata is not None: - tensor_dict["selected_token_indices"] = self.sampling_metadata.selected_token_indices - - return tensor_dict - class ModelRunner: def __init__( @@ -674,10 +642,13 @@ def prepare_model_input_tensors( clone=model_input, sampling_metadata=sampling_metadata) + def get_empty_model_input(self) -> ModelInputWithSamplingMetadata: + return ModelInputWithSamplingMetadata.new() + @torch.inference_mode() def execute_model( self, - model_input: ModelInput, + model_input: ModelInputWithSamplingMetadata, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: if self.lora_config: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a1f744df1598d..d4b6a1ddc5465 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -275,7 +275,8 @@ def prepare_model_input( if not metadata_dict: return None - model_input = ModelInput.new( + model_input = self.model_runner.get_empty_model_input() + model_input = model_input.new( attn_backend=self.model_runner.attn_backend, **metadata_dict) return model_input From f35a23f86735ab987aac8b9c4444fa77cf9bc325 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 21:05:41 -0700 Subject: [PATCH 08/55] CPU worker Signed-off-by: Stephanie Wang --- vllm/executor/cpu_executor.py | 3 +- vllm/worker/cpu_model_runner.py | 152 +++++++++++++++++++------------- vllm/worker/cpu_worker.py | 85 +++++++++++------- 3 files changed, 146 insertions(+), 94 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index a2212459f034e..d7d269a85665a 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -75,7 +75,8 @@ def initialize_cache(self, num_gpu_blocks: int, def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - output = self.driver_worker.execute_model(execute_model_req) + model_input = self.driver_worker.prepare_model_input_local(execute_model_req) + output = self.driver_worker.execute_model(model_input) return output def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index eaf43247d4fc5..3ec12ad8096fb 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import SamplerOutput, SequenceGroupMetadata, ModelInputWithSamplingMetadata from vllm.utils import make_tensor_with_pad logger = init_logger(__name__) @@ -270,86 +270,112 @@ def _prepare_decode( attn_metadata, ) - def prepare_input_tensors( + def prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Optional[Dict[str, torch.Tensor]]]: + ) -> ModelInputWithSamplingMetadata: multi_modal_kwargs = None - if self.is_driver_worker: - # NOTE: We assume that all sequences in 
the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - pin_memory=False) - # Broadcast the metadata. - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "selected_token_indices": - sampling_metadata.selected_token_indices, - } - metadata_dict.update(attn_metadata.asdict_zerocopy()) - broadcast_tensor_dict(metadata_dict, src=0) + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs + ) = self._prepare_prompt(seq_group_metadata_list) else: - metadata_dict = broadcast_tensor_dict(src=0) - input_tokens = metadata_dict.pop("input_tokens") - input_positions = metadata_dict.pop("input_positions") - selected_token_indices = metadata_dict.pop( - "selected_token_indices") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) - sampling_metadata = SamplingMetadata( - seq_groups=None, - seq_data=None, - seq_lens=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - generators=None, - ) - - return (input_tokens, input_positions, attn_metadata, - sampling_metadata, multi_modal_kwargs) + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + seq_lens = [] + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + pin_memory=False) + return ModelInputWithSamplingMetadata.new( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + ) + +# if self.is_driver_worker: +# # NOTE: We assume that all sequences in the group are all prompts or +# # all decodes. +# is_prompt = seq_group_metadata_list[0].is_prompt +# # Prepare input tensors. +# if is_prompt: +# (input_tokens, input_positions, attn_metadata, seq_lens, +# multi_modal_kwargs +# ) = self._prepare_prompt(seq_group_metadata_list) +# else: +# (input_tokens, input_positions, +# attn_metadata) = self._prepare_decode(seq_group_metadata_list) +# seq_lens = [] +# sampling_metadata = SamplingMetadata.prepare( +# seq_group_metadata_list, +# seq_lens, +# # query_lens is not needed if chunked prefill is not +# # supported. Since CPU worker doesn't support chunked prefill +# # just use seq_lens instead. +# seq_lens, +# self.device, +# pin_memory=False) +# # Broadcast the metadata. 
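A small sketch, with toy values, of the two assumptions stated above: the CPU runner expects a batch to be all prompts or all decodes, and since it has no chunked prefill, seq_lens can stand in for query_lens when building SamplingMetadata. Names and values here are illustrative only.

from typing import List

def batch_kind(is_prompt_flags: List[bool]) -> str:
    # Homogeneous batches only: inspecting the first entry is enough.
    assert all(flag == is_prompt_flags[0] for flag in is_prompt_flags)
    return "prefill" if is_prompt_flags[0] else "decode"

assert batch_kind([True, True]) == "prefill"
assert batch_kind([False, False]) == "decode"

# No chunked prefill: each prompt is processed whole, so the per-sequence
# query length equals the sequence length.
seq_lens = [5, 3]
query_lens = seq_lens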
+# metadata_dict = { +# "input_tokens": input_tokens, +# "input_positions": input_positions, +# "selected_token_indices": +# sampling_metadata.selected_token_indices, +# } +# metadata_dict.update(attn_metadata.asdict_zerocopy()) +# broadcast_tensor_dict(metadata_dict, src=0) +# else: +# metadata_dict = broadcast_tensor_dict(src=0) +# input_tokens = metadata_dict.pop("input_tokens") +# input_positions = metadata_dict.pop("input_positions") +# selected_token_indices = metadata_dict.pop( +# "selected_token_indices") +# attn_metadata = self.attn_backend.make_metadata(**metadata_dict) +# sampling_metadata = SamplingMetadata( +# seq_groups=None, +# seq_data=None, +# seq_lens=None, +# selected_token_indices=selected_token_indices, +# categorized_sample_indices=None, +# generators=None, +# ) +# +# return (input_tokens, input_positions, attn_metadata, +# sampling_metadata, multi_modal_kwargs) +# + def get_empty_model_input(self) -> ModelInputWithSamplingMetadata: + return ModelInputWithSamplingMetadata.new() @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: ModelInputWithSamplingMetadata, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, - multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) - model_executable = self.model execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, + "input_ids": model_input.input_tokens, + "positions": model_input.input_positions, "kv_caches": kv_caches, - "attn_metadata": attn_metadata, + "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) + execute_model_kwargs.update({"image_input": mmodel_input.multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, model_input.sampling_metadata) # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -358,6 +384,6 @@ def execute_model( # Sample the next token. 
output = self.model.sample( logits=logits, - sampling_metadata=sampling_metadata, + sampling_metadata=model_input.sampling_metadata, ) return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3ee394f9912e9..ce6afbbf1105c 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,7 +13,7 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput, ModelInput from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -263,43 +263,68 @@ def cache_copy( self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> List[SamplerOutput]: - - if execute_model_req is None: - seq_group_metadata_list = None - else: - seq_group_metadata_list = execute_model_req.seq_group_metadata_list + def prepare_model_input_local( + self, + execute_model_req: ExecuteModelRequest) -> ModelInput: + assert execute_model_req is not None + + model_input = self.model_runner.prepare_model_input_tensors( + execute_model_req.seq_group_metadata_list + ) + + num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) + blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + return model_input.replace( + num_seq_groups=num_seq_groups, + blocks_to_copy=blocks_to_copy, + ) + @torch.inference_mode() + def prepare_model_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> ModelInput: if self.is_driver_worker: - assert seq_group_metadata_list is not None - num_seq_groups: int = len(seq_group_metadata_list) - assert execute_model_req is not None - blocks_to_copy = execute_model_req.blocks_to_copy - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device="cpu", - dtype=torch.int64).view(-1, 2) - assert len(execute_model_req.blocks_to_swap_in) == 0 - assert len(execute_model_req.blocks_to_swap_out) == 0 - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_copy": execute_model_req.blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. 
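A standalone sketch of the stop protocol described in the comment above: the driver broadcasts one metadata dict per step, and an empty dict tells the non-driver workers to leave their loop. The deque stands in for broadcast_tensor_dict and the names are illustrative.

from collections import deque

channel = deque()  # stand-in for the broadcast between driver and workers

def driver_broadcast(metadata_dict):
    # In the real code this is broadcast_tensor_dict(metadata_dict, src=0).
    channel.append(metadata_dict)

def worker_execution_loop(execute_step):
    while True:
        metadata_dict = channel.popleft()  # broadcast_tensor_dict(src=0)
        if not metadata_dict:
            # An empty dict is the signal to stop the loop.
            return
        execute_step(metadata_dict)

driver_broadcast({"num_seq_groups": 1})
driver_broadcast({})  # no more requests: stop the remote workers
worker_execution_loop(lambda metadata: None)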
+ broadcast_tensor_dict({}, src=0) + return None + + model_input = self.prepare_model_input_local(execute_model_req) + metadata_dict = model_input.as_broadcastable_tensor_dict() + broadcast_tensor_dict(metadata_dict, src=0) else: - data = broadcast_tensor_dict(src=0) - num_seq_groups = data["num_seq_groups"] - blocks_to_copy = data["blocks_to_copy"] + metadata_dict = broadcast_tensor_dict(src=0) + if not metadata_dict: + return None + + model_input = self.model_runner.get_empty_model_input() + model_input = model_input.new( + attn_backend=self.model_runner.attn_backend, + **metadata_dict) + return model_input - self.cache_copy(blocks_to_copy) + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInput, + ) -> List[SamplerOutput]: + self.cache_copy(model_input.blocks_to_copy) # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: + if model_input.num_seq_groups == 0: return [] - output = self.model_runner.execute_model(seq_group_metadata_list, + output = self.model_runner.execute_model(model_input, self.cpu_cache) # CPU worker only supports single-step execution. From 11133fe92c92c54fe1e04d365b1c10a301d77955 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 21:11:45 -0700 Subject: [PATCH 09/55] remove commented Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 50 --------------------------------- 1 file changed, 50 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 3ec12ad8096fb..6e8d60b225e7d 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -303,56 +303,6 @@ def prepare_model_input_tensors( sampling_metadata=sampling_metadata, ) -# if self.is_driver_worker: -# # NOTE: We assume that all sequences in the group are all prompts or -# # all decodes. -# is_prompt = seq_group_metadata_list[0].is_prompt -# # Prepare input tensors. -# if is_prompt: -# (input_tokens, input_positions, attn_metadata, seq_lens, -# multi_modal_kwargs -# ) = self._prepare_prompt(seq_group_metadata_list) -# else: -# (input_tokens, input_positions, -# attn_metadata) = self._prepare_decode(seq_group_metadata_list) -# seq_lens = [] -# sampling_metadata = SamplingMetadata.prepare( -# seq_group_metadata_list, -# seq_lens, -# # query_lens is not needed if chunked prefill is not -# # supported. Since CPU worker doesn't support chunked prefill -# # just use seq_lens instead. -# seq_lens, -# self.device, -# pin_memory=False) -# # Broadcast the metadata. 
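The commented-out, hand-built metadata_dict being removed here is superseded by the ModelInput helpers added earlier in the series. Roughly, the two sides of the exchange now look like this; a sketch of the pattern, not the exact call sites.

def driver_prepare(model_runner, seq_group_metadata_list):
    # Driver side: build the full ModelInput, then flatten the broadcastable
    # fields plus attention/sampling metadata into a plain dict.
    model_input = model_runner.prepare_model_input_tensors(
        seq_group_metadata_list)
    return model_input, model_input.as_broadcastable_tensor_dict()

def worker_rebuild(model_runner, metadata_dict):
    # Non-driver side: start from the runner's empty input type and let
    # _get_init_kwargs reassemble AttentionMetadata / SamplingMetadata from
    # the flattened fields.
    empty = model_runner.get_empty_model_input()
    return empty.new(attn_backend=model_runner.attn_backend, **metadata_dict)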
-# metadata_dict = { -# "input_tokens": input_tokens, -# "input_positions": input_positions, -# "selected_token_indices": -# sampling_metadata.selected_token_indices, -# } -# metadata_dict.update(attn_metadata.asdict_zerocopy()) -# broadcast_tensor_dict(metadata_dict, src=0) -# else: -# metadata_dict = broadcast_tensor_dict(src=0) -# input_tokens = metadata_dict.pop("input_tokens") -# input_positions = metadata_dict.pop("input_positions") -# selected_token_indices = metadata_dict.pop( -# "selected_token_indices") -# attn_metadata = self.attn_backend.make_metadata(**metadata_dict) -# sampling_metadata = SamplingMetadata( -# seq_groups=None, -# seq_data=None, -# seq_lens=None, -# selected_token_indices=selected_token_indices, -# categorized_sample_indices=None, -# generators=None, -# ) -# -# return (input_tokens, input_positions, attn_metadata, -# sampling_metadata, multi_modal_kwargs) -# def get_empty_model_input(self) -> ModelInputWithSamplingMetadata: return ModelInputWithSamplingMetadata.new() From 174bdb14ef5c5e114d9b38940eb668db9c8ecd56 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 10 Jun 2024 21:17:53 -0700 Subject: [PATCH 10/55] lint Signed-off-by: Stephanie Wang --- vllm/executor/cpu_executor.py | 3 +- vllm/executor/distributed_gpu_executor.py | 10 ++-- vllm/executor/gpu_executor.py | 3 +- vllm/executor/multiproc_gpu_executor.py | 5 +- vllm/executor/ray_gpu_executor.py | 12 ++-- vllm/sequence.py | 68 ++++++++++++----------- vllm/worker/cpu_model_runner.py | 16 +++--- vllm/worker/cpu_worker.py | 22 +++----- vllm/worker/embedding_model_runner.py | 16 ++++-- vllm/worker/model_runner.py | 23 ++++---- vllm/worker/worker.py | 33 +++++------ vllm/worker/worker_base.py | 12 ++-- 12 files changed, 117 insertions(+), 106 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index d7d269a85665a..69e84f42d498f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -75,7 +75,8 @@ def initialize_cache(self, num_gpu_blocks: int, def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - model_input = self.driver_worker.prepare_model_input_local(execute_model_req) + model_input = self.driver_worker.prepare_model_input_local( + execute_model_req) output = self.driver_worker.execute_model(model_input) return output diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index c4ed3007e2b82..06b8bcdac1d19 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -116,13 +116,13 @@ def save_sharded_state( @abstractmethod def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] - ) -> List[SamplerOutput]: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: """Run execute_model in the driver worker. - Passing None will cause the driver to stop the model execution - loop running in each of the remote workers. + Passing None will cause the driver to stop the model execution loop + running in each of the remote workers. In this case, this method + returns None. Otherwise, this method returns the model output. 
""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index c32338d8b3b73..1ddd525fdeafa 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -88,7 +88,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: def execute_model( self, execute_model_req: ExecuteModelRequest ) -> List[Union[SamplerOutput, PoolerOutput]]: - model_input = self.driver_worker.prepare_model_input_local(execute_model_req) + model_input = self.driver_worker.prepare_model_input_local( + execute_model_req) output = self.driver_worker.execute_model(model_input) return output diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 0e49f04963881..ec337ec0e15f1 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -76,9 +76,8 @@ def shutdown(self): worker_monitor.close() def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] - ) -> List[SamplerOutput]: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 11e55d9fa5771..cb1628aba27be 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -174,17 +174,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_parallel_loading_workers) def _driver_execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] - ) -> List[SamplerOutput]: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ - model_input = self.driver_worker.execute_method("prepare_model_input", execute_model_req) + model_input = self.driver_worker.execute_method( + "prepare_model_input", execute_model_req) + if model_input is None: - return + return None + return self.driver_worker.execute_method("execute_model", model_input) def _run_workers( diff --git a/vllm/sequence.py b/vllm/sequence.py index 02b5971c723dc..f13b81bf356c1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -876,6 +876,7 @@ def clone( running_queue_size=self.running_queue_size, ) + @dataclass(frozen=True) class ModelInput: """Local inputs to each worker's `execute_model` function. 
May contain @@ -907,44 +908,46 @@ class ModelInput: attn_metadata: Optional["AttentionMetadata"] = None BROADCASTABLE_FIELDS: List[str] = ( - "num_seq_groups", - "blocks_to_swap_in", - "blocks_to_swap_out", - "blocks_to_copy", - "input_tokens", - "input_positions", - "lora_requests", - "lora_mapping", - "multi_modal_kwargs", - "num_prefill_tokens", - "num_decode_tokens", - "slot_mapping", - "num_prefills", - ) + "num_seq_groups", + "blocks_to_swap_in", + "blocks_to_swap_out", + "blocks_to_copy", + "input_tokens", + "input_positions", + "lora_requests", + "lora_mapping", + "multi_modal_kwargs", + "num_prefill_tokens", + "num_decode_tokens", + "slot_mapping", + "num_prefills", + ) @classmethod def _get_init_kwargs(cls, - attn_backend: Optional["AttentionBackend"] = None, - attn_metadata: Optional["AttentionMetadata"] = None, **kwargs) -> Dict[str, Any]: + attn_backend: Optional["AttentionBackend"] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + **kwargs) -> Dict[str, Any]: if attn_metadata is None: # Extract the fields used to create AttentionMetadata. if attn_backend is not None: valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): + for field in dataclasses.fields( + attn_backend.get_metadata_cls()): val = kwargs.pop(field.name, None) if val is not None: valid_attn_kwargs[field.name] = val - attn_metadata = attn_backend.make_metadata( - **valid_attn_kwargs - ) + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) if attn_metadata is not None: kwargs["attn_metadata"] = attn_metadata return kwargs @classmethod - def new(cls, clone: Optional["ModelInput"] = None, **kwargs) -> "ModelInput": + def new(cls, + clone: Optional["ModelInput"] = None, + **kwargs) -> "ModelInput": clone_kwargs = {} if clone is not None: for field in dataclasses.fields(clone): @@ -960,7 +963,8 @@ def replace(self, **kwargs) -> "ModelInput": valid_kwargs = self.__class__._get_init_kwargs(**kwargs) return dataclasses.replace(self, **valid_kwargs) - def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: tensor_dict = {} for field in self.BROADCASTABLE_FIELDS: val = getattr(self, field, None) @@ -979,7 +983,8 @@ class ModelInputWithSamplingMetadata(ModelInput): sampling_metadata: Optional["SamplingMetadata"] = None @classmethod - def _get_init_kwargs(cls, + def _get_init_kwargs( + cls, selected_token_indices: Optional[torch.Tensor] = None, sampling_metadata: Optional["SamplingMetadata"] = None, **kwargs) -> Dict[str, Any]: @@ -989,20 +994,21 @@ def _get_init_kwargs(cls, # Workers do not perform sampling. 
sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) if sampling_metadata is not None: kwargs["sampling_metadata"] = sampling_metadata return super()._get_init_kwargs(**kwargs) - def as_broadcastable_tensor_dict(self) -> Dict[str, Union[int, torch.Tensor]]: + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: tensor_dict = super().as_broadcastable_tensor_dict() if self.sampling_metadata is not None: - tensor_dict["selected_token_indices"] = self.sampling_metadata.selected_token_indices + tensor_dict[ + "selected_token_indices"] = self.sampling_metadata.selected_token_indices return tensor_dict - diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 6e8d60b225e7d..c11304f93bda9 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -297,11 +297,11 @@ def prepare_model_input_tensors( self.device, pin_memory=False) return ModelInputWithSamplingMetadata.new( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - ) + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + ) def get_empty_model_input(self) -> ModelInputWithSamplingMetadata: return ModelInputWithSamplingMetadata.new() @@ -320,12 +320,14 @@ def execute_model( "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update({"image_input": mmodel_input.multi_modal_input}) + execute_model_kwargs.update( + {"image_input": mmodel_input.multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. - logits = self.model.compute_logits(hidden_states, model_input.sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Only perform sampling in the driver worker. 
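On a non-driver worker, the SamplingMetadata rebuilt from the broadcast carries only selected_token_indices, which is enough to pick out the rows used for logits before sampling itself is skipped on non-driver workers below. A toy illustration; shapes and values are made up.

import torch

selected_token_indices = torch.tensor([4, 7])  # last position of each sequence
hidden_states = torch.randn(8, 16)             # (num_batched_tokens, hidden)
rows_for_logits = hidden_states.index_select(0, selected_token_indices)
assert rows_for_logits.shape == (2, 16)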
if not self.is_driver_worker: diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index ce6afbbf1105c..dce69438ae8fb 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -264,13 +264,11 @@ def cache_copy( @torch.inference_mode() def prepare_model_input_local( - self, - execute_model_req: ExecuteModelRequest) -> ModelInput: + self, execute_model_req: ExecuteModelRequest) -> ModelInput: assert execute_model_req is not None model_input = self.model_runner.prepare_model_input_tensors( - execute_model_req.seq_group_metadata_list - ) + execute_model_req.seq_group_metadata_list) num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) blocks_to_copy = execute_model_req.blocks_to_copy @@ -280,14 +278,14 @@ def prepare_model_input_local( assert len(execute_model_req.blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0 return model_input.replace( - num_seq_groups=num_seq_groups, - blocks_to_copy=blocks_to_copy, - ) + num_seq_groups=num_seq_groups, + blocks_to_copy=blocks_to_copy, + ) @torch.inference_mode() def prepare_model_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> ModelInput: if self.is_driver_worker: if execute_model_req is None: @@ -309,8 +307,7 @@ def prepare_model_input( model_input = self.model_runner.get_empty_model_input() model_input = model_input.new( - attn_backend=self.model_runner.attn_backend, - **metadata_dict) + attn_backend=self.model_runner.attn_backend, **metadata_dict) return model_input @torch.inference_mode() @@ -324,8 +321,7 @@ def execute_model( if model_input.num_seq_groups == 0: return [] - output = self.model_runner.execute_model(model_input, - self.cpu_cache) + output = self.model_runner.execute_model(model_input, self.cpu_cache) # CPU worker only supports single-step execution. return [output] diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 9d3fc1a6fbdcd..17a0e4beae194 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -23,6 +23,7 @@ class ModelInputWithPoolingMetadata(ModelInput): pooling_metadata: Optional["SamplingMetadata"] = None + class EmbeddingModelRunner(ModelRunner): def __init__( @@ -59,7 +60,8 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: if self.lora_config: - self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. prefill_meta = model_input.attn_metadata.prefill_metadata @@ -80,7 +82,8 @@ def execute_model( "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update({"image_input": model_input.multi_modal_input}) + execute_model_kwargs.update( + {"image_input": model_input.multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. @@ -95,15 +98,16 @@ def prepare_model_input_tensors( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> ModelInput: assert seq_group_metadata_list is not None - model_input = self._prepare_model_input_tensors(seq_group_metadata_list) + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list) # Prepare PoolingMetadata. 
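ModelInputWithPoolingMetadata follows the subclassing pattern the ModelInput docstring asks for: a runner with an extra step attaches its own metadata field and carries the shared tensors over with new(clone=...). A sketch of that pattern with a hypothetical field name.

import dataclasses
from typing import Optional

from vllm.sequence import ModelInput

@dataclasses.dataclass(frozen=True)
class ModelInputWithExtraMetadata(ModelInput):
    # Hypothetical runner-specific field, analogous to pooling_metadata.
    extra_metadata: Optional[object] = None

# Typical construction inside a runner (pseudo-usage):
#   base = self._prepare_model_input_tensors(seq_group_metadata_list)
#   model_input = ModelInputWithExtraMetadata.new(clone=base,
#                                                 extra_metadata=extra)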
pooling_metadata = self._prepare_pooling(seq_group_metadata_list, model_input.seq_lens) return ModelInputWithPoolingMetadata.new( - clone=model_input, - pooling_metadata=pooling_metadata, - ) + clone=model_input, + pooling_metadata=pooling_metadata, + ) def _prepare_pooling( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b30407dc09cec..7c076aad4ce1d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -22,7 +22,7 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, ModelInput, ModelInputWithSamplingMetadata +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, ModelInput, ModelInputWithSamplingMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) @@ -616,7 +616,6 @@ def _prepare_model_input_tensors( num_prefills=num_prefills, ) - def prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -634,13 +633,15 @@ def prepare_model_input_tensors( If cuda graph is required, this API automatically pads inputs. """ - model_input = self._prepare_model_input_tensors(seq_group_metadata_list) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, model_input.seq_lens, - model_input.query_lens, self.device, self.pin_memory) + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory) return ModelInputWithSamplingMetadata.new( - clone=model_input, - sampling_metadata=sampling_metadata) + clone=model_input, sampling_metadata=sampling_metadata) def get_empty_model_input(self) -> ModelInputWithSamplingMetadata: return ModelInputWithSamplingMetadata.new() @@ -652,7 +653,8 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: if self.lora_config: - self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. prefill_meta = model_input.attn_metadata.prefill_metadata @@ -672,7 +674,8 @@ def execute_model( ) # Compute the logits. - logits = self.model.compute_logits(hidden_states, model_input.sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Only perform sampling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d4b6a1ddc5465..dcf1055fd0e07 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -223,11 +223,9 @@ def cache_swap( @torch.inference_mode() def prepare_model_input_local( - self, - execute_model_req: ExecuteModelRequest) -> ModelInput: + self, execute_model_req: ExecuteModelRequest) -> ModelInput: model_input = self.model_runner.prepare_model_input_tensors( - execute_model_req.seq_group_metadata_list - ) + execute_model_req.seq_group_metadata_list) num_seq_groups = len(execute_model_req.seq_group_metadata_list) # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. 
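The scheduler hands these over as lists of (source, destination) block numbers; a standalone example of the tensor conversion used just above, with invented values.

import torch

blocks_to_swap_in = [(0, 3), (1, 5)]  # CPU -> GPU block number pairs
swap_in = torch.tensor(blocks_to_swap_in, device="cpu",
                       dtype=torch.int64).view(-1, 2)
assert swap_in.shape == (2, 2)

# An empty list still produces a well-formed (0, 2) tensor.
no_copies = torch.tensor([], device="cpu", dtype=torch.int64).view(-1, 2)
assert no_copies.shape == (0, 2)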
@@ -246,16 +244,16 @@ def prepare_model_input_local( dtype=torch.int64).view(-1, 2) return model_input.replace( - num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) @torch.inference_mode() def prepare_model_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> ModelInput: if self.is_driver_worker: if execute_model_req is None: @@ -277,8 +275,7 @@ def prepare_model_input( model_input = self.model_runner.get_empty_model_input() model_input = model_input.new( - attn_backend=self.model_runner.attn_backend, - **metadata_dict) + attn_backend=self.model_runner.attn_backend, **metadata_dict) return model_input @torch.inference_mode() @@ -286,17 +283,15 @@ def execute_model( self, model_input: ModelInput, ) -> List[Union[SamplerOutput, PoolerOutput]]: - self.cache_swap( - model_input.blocks_to_swap_in, - model_input.blocks_to_swap_out, - model_input.blocks_to_copy) + self.cache_swap(model_input.blocks_to_swap_in, + model_input.blocks_to_swap_out, + model_input.blocks_to_copy) # If there is no input, we don't need to execute the model. if model_input.num_seq_groups == 0: return [] - output = self.model_runner.execute_model(model_input, - self.gpu_cache) + output = self.model_runner.execute_model(model_input, self.gpu_cache) # Worker only supports single-step execution. Wrap the output in a list # to conform to interface. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 3ce375e8a3ea8..c82fce504b655 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -63,7 +63,8 @@ def start_worker_execution_loop(self) -> None: self.execute_model(model_input) @abstractmethod - def prepare_model_input_local(self, execute_model_req: ExecuteModelRequest) -> ModelInput: + def prepare_model_input_local( + self, execute_model_req: ExecuteModelRequest) -> ModelInput: """ Prepare a model execution request locally. This method is not allowed to communicate with external devices. @@ -71,7 +72,10 @@ def prepare_model_input_local(self, execute_model_req: ExecuteModelRequest) -> M raise NotImplementedError @abstractmethod - def prepare_model_input(self, execute_model_req: Optional[ExecuteModelRequest] = None) -> ModelInput: + def prepare_model_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> ModelInput: """ Prepare a model execution request. 
Communication with other workers may occur to produce the model input that should be passed to @@ -80,9 +84,7 @@ def prepare_model_input(self, execute_model_req: Optional[ExecuteModelRequest] = raise NotImplementedError @abstractmethod - def execute_model( - self, - model_input: ModelInput) -> List[SamplerOutput]: + def execute_model(self, model_input: ModelInput) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" raise NotImplementedError From c0e98caee1e2d25a24192c9c2563ec69cdbb1ad9 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 11 Jun 2024 11:01:54 -0700 Subject: [PATCH 11/55] Worker.execute_model vs execute_model_local Signed-off-by: Stephanie Wang --- vllm/executor/cpu_executor.py | 4 +--- vllm/executor/gpu_executor.py | 4 +--- vllm/executor/multiproc_gpu_executor.py | 5 +---- vllm/executor/ray_gpu_executor.py | 8 +------ vllm/worker/cpu_worker.py | 10 +++++---- vllm/worker/worker.py | 8 ++++--- vllm/worker/worker_base.py | 29 +++++++++++++++++-------- 7 files changed, 35 insertions(+), 33 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 69e84f42d498f..a2212459f034e 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -75,9 +75,7 @@ def initialize_cache(self, num_gpu_blocks: int, def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - model_input = self.driver_worker.prepare_model_input_local( - execute_model_req) - output = self.driver_worker.execute_model(model_input) + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 1ddd525fdeafa..3ad201f4757ec 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -88,9 +88,7 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: def execute_model( self, execute_model_req: ExecuteModelRequest ) -> List[Union[SamplerOutput, PoolerOutput]]: - model_input = self.driver_worker.prepare_model_input_local( - execute_model_req) - output = self.driver_worker.execute_model(model_input) + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index ec337ec0e15f1..c3c03a4db632c 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -83,10 +83,7 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ - model_input = self.driver_worker.prepare_model_input(execute_model_req) - if model_input is None: - return None - return self.driver_worker.execute_model(model_input) + return self.driver_worker.execute_model(execute_model_req) def _run_workers( self, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index cb1628aba27be..0079321e78ed1 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -181,13 +181,7 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. 
""" - model_input = self.driver_worker.execute_method( - "prepare_model_input", execute_model_req) - - if model_input is None: - return None - - return self.driver_worker.execute_method("execute_model", model_input) + return self.driver_worker.execute_method("execute_model", execute_model_req) def _run_workers( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index dce69438ae8fb..ec48241fbc4ae 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -285,8 +285,11 @@ def prepare_model_input_local( @torch.inference_mode() def prepare_model_input( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] ) -> ModelInput: + if self.parallel_config.tensor_parallel_size <= 1: + return self.prepare_model_input_local(execute_model_req) + if self.is_driver_worker: if execute_model_req is None: # This signals that there's no more requests to process for now. @@ -311,9 +314,8 @@ def prepare_model_input( return model_input @torch.inference_mode() - def execute_model( - self, - model_input: ModelInput, + def execute_model_local( + self, model_input: ModelInput, ) -> List[SamplerOutput]: self.cache_copy(model_input.blocks_to_copy) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index dcf1055fd0e07..7b567057300d4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -255,6 +255,9 @@ def prepare_model_input( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> ModelInput: + if self.parallel_config.tensor_parallel_size <= 1: + return self.prepare_model_input_local(execute_model_req) + if self.is_driver_worker: if execute_model_req is None: # This signals that there's no more requests to process for now. @@ -279,10 +282,9 @@ def prepare_model_input( return model_input @torch.inference_mode() - def execute_model( + def execute_model_local( self, - model_input: ModelInput, - ) -> List[Union[SamplerOutput, PoolerOutput]]: + model_input: ModelInput) -> List[Union[SamplerOutput, PoolerOutput]]: self.cache_swap(model_input.blocks_to_swap_in, model_input.blocks_to_swap_out, model_input.blocks_to_copy) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index c82fce504b655..fafdf7933c205 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -55,12 +55,9 @@ def start_worker_execution_loop(self) -> None: See `stop_remote_worker_execution_loop` for more details. """ while True: - model_input = self.prepare_model_input(execute_model_req=None) - - if model_input is None: - return - - self.execute_model(model_input) + output = self.execute_model(model_input) + if output is None: + return None @abstractmethod def prepare_model_input_local( @@ -74,7 +71,7 @@ def prepare_model_input_local( @abstractmethod def prepare_model_input( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] ) -> ModelInput: """ Prepare a model execution request. Communication with other workers @@ -84,9 +81,23 @@ def prepare_model_input( raise NotImplementedError @abstractmethod - def execute_model(self, model_input: ModelInput) -> List[SamplerOutput]: + def execute_model(self, execute_model_req: Optional[ExecuteModelRequest]) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences, unless no + sequences are provided. 
Communication with other workers + may occur to produce the model input that should be passed to + the model runner.""" + model_input = self.prepare_model_input(execute_model_req=execute_model_req) + if model_input is None: + return None + + return self.execute_model_local(model_input) + + @abstractmethod + def execute_model_local(self, model_input: ModelInput) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no - sequences are provided.""" + sequences are provided. This method is not allowed to communciate with + other workers. + """ raise NotImplementedError @abstractmethod From dccec959fa6d041524d8ee00f5b064ea347b7f6a Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 11 Jun 2024 11:25:01 -0700 Subject: [PATCH 12/55] lint Signed-off-by: Stephanie Wang --- vllm/executor/distributed_gpu_executor.py | 4 +- vllm/executor/executor_base.py | 4 +- vllm/executor/gpu_executor.py | 2 +- vllm/executor/ray_gpu_executor.py | 3 +- vllm/sequence.py | 112 +++++++++++----------- vllm/worker/cpu_model_runner.py | 6 +- vllm/worker/cpu_worker.py | 20 ++-- vllm/worker/embedding_model_runner.py | 12 +-- vllm/worker/model_runner.py | 7 +- vllm/worker/worker.py | 19 ++-- vllm/worker/worker_base.py | 19 ++-- 11 files changed, 106 insertions(+), 102 deletions(-) diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 06b8bcdac1d19..b18c703c0ae39 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 4d01939c2e38b..3020987db3f1e 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -69,8 +69,8 @@ def initialize_cache(self, num_gpu_blocks: int, @abstractmethod def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3ad201f4757ec..33d8d94e18153 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -87,7 +87,7 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: def execute_model( self, execute_model_req: ExecuteModelRequest - ) -> List[Union[SamplerOutput, PoolerOutput]]: + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: output = self.driver_worker.execute_model(execute_model_req) return output diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 0079321e78ed1..20468812e1e33 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -181,7 +181,8 @@ def _driver_execute_model( Passing None will cause the driver to stop the model execution loop running in each of the remote workers. 
""" - return self.driver_worker.execute_method("execute_model", execute_model_req) + return self.driver_worker.execute_method("execute_model", + execute_model_req) def _run_workers( self, diff --git a/vllm/sequence.py b/vllm/sequence.py index f13b81bf356c1..d6311a233496b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,10 +1,9 @@ """Sequence and its related classes.""" import copy +import dataclasses import enum from abc import ABC, abstractmethod -import dataclasses -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Set, Any +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union import torch @@ -16,14 +15,14 @@ if TYPE_CHECKING: from vllm.attention import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend + from vllm.lora.layers import LoRAMapping from vllm.model_executor import SamplingMetadata - from vllm.model_executor.pooling_metadata import PoolingMetadata - from vllm.multimodal import MultiModalData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -@dataclass +@dataclasses.dataclass class Logprob: """Infos for supporting OpenAI compatible logprobs and token ranks. @@ -86,7 +85,7 @@ class SequenceStage(enum.Enum): DECODE = enum.auto() -@dataclass +@dataclasses.dataclass class RequestMetrics: """Metrics associated with a request. @@ -398,7 +397,7 @@ def __repr__(self) -> str: f"num_blocks={len(self.logical_token_blocks)})") -@dataclass +@dataclasses.dataclass class SequenceGroupState: """Mutable state tied to a specific sequence group""" @@ -775,7 +774,7 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings -@dataclass +@dataclasses.dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. @@ -825,7 +824,7 @@ def __repr__(self) -> str: f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") -@dataclass +@dataclasses.dataclass class PoolerOutput: """The output from a pooling operation in the embedding model.""" outputs: List[EmbeddingSequenceGroupOutput] @@ -846,18 +845,21 @@ def __eq__(self, other: object): self.__class__) and self.outputs == other.outputs -@dataclass +@dataclasses.dataclass class ExecuteModelRequest: """The model execution request, containing CPU metadata only. The LLM engine should create an instance of this class for each request batch.""" # The sequence group metadata list. seq_group_metadata_list: List[SequenceGroupMetadata] # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_swap_in: List[Tuple[int, int]] = dataclasses.field( + default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_swap_out: List[Tuple[int, int]] = dataclasses.field( + default_factory=list) # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) + blocks_to_copy: List[Tuple[int, + int]] = dataclasses.field(default_factory=list) # The number of slots for lookahead decoding. num_lookahead_slots: int = 0 # The number of requests in the running queue. @@ -877,7 +879,7 @@ def clone( ) -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ModelInput: """Local inputs to each worker's `execute_model` function. May contain device-specific data. 
Different worker backends may have different methods @@ -889,25 +891,25 @@ class ModelInput: runners that run additional steps should subclass this method to add additional fields. """ - num_seq_groups: int = None - blocks_to_swap_in: torch.Tensor = None - blocks_to_swap_out: torch.Tensor = None - blocks_to_copy: torch.Tensor = None - - input_tokens: torch.Tensor = None - input_positions: torch.Tensor = None - seq_lens: List[int] = None - query_lens: List[int] = None + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Set[LoRARequest] = None - multi_modal_kwargs: Dict[str, torch.Tensor] = None - slot_mapping: torch.Tensor = None - num_prefill_tokens: int = None - num_decode_tokens: int = None - num_prefills: int = None + lora_requests: Optional[Set[LoRARequest]] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + slot_mapping: Optional[torch.Tensor] = None + num_prefill_tokens: Optional[int] = None + num_decode_tokens: Optional[int] = None + num_prefills: Optional[int] = None attn_metadata: Optional["AttentionMetadata"] = None - BROADCASTABLE_FIELDS: List[str] = ( + BROADCASTABLE_FIELDS = ( "num_seq_groups", "blocks_to_swap_in", "blocks_to_swap_out", @@ -928,17 +930,15 @@ def _get_init_kwargs(cls, attn_backend: Optional["AttentionBackend"] = None, attn_metadata: Optional["AttentionMetadata"] = None, **kwargs) -> Dict[str, Any]: - if attn_metadata is None: + if attn_metadata is None and attn_backend is not None: # Extract the fields used to create AttentionMetadata. - if attn_backend is not None: - valid_attn_kwargs = {} - for field in dataclasses.fields( - attn_backend.get_metadata_cls()): - val = kwargs.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = kwargs.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) if attn_metadata is not None: kwargs["attn_metadata"] = attn_metadata @@ -977,28 +977,28 @@ def as_broadcastable_tensor_dict( return tensor_dict -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ModelInputWithSamplingMetadata(ModelInput): # Metadata for sampling outputs. sampling_metadata: Optional["SamplingMetadata"] = None @classmethod - def _get_init_kwargs( + def _get_init_kwargs( # type: ignore cls, selected_token_indices: Optional[torch.Tensor] = None, sampling_metadata: Optional["SamplingMetadata"] = None, **kwargs) -> Dict[str, Any]: - if sampling_metadata is None: - if selected_token_indices is not None: - from vllm.model_executor import SamplingMetadata - - # Workers do not perform sampling. - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) + if sampling_metadata is None and selected_token_indices is not None: + from vllm.model_executor import SamplingMetadata + + # An empty SamplingMetadata to signal that the worker should skip + # sampling. 
+ sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) if sampling_metadata is not None: kwargs["sampling_metadata"] = sampling_metadata return super()._get_init_kwargs(**kwargs) @@ -1008,7 +1008,7 @@ def as_broadcastable_tensor_dict( tensor_dict = super().as_broadcastable_tensor_dict() if self.sampling_metadata is not None: - tensor_dict[ - "selected_token_indices"] = self.sampling_metadata.selected_token_indices + tensor_dict["selected_token_indices"] = ( + self.sampling_metadata.selected_token_indices) return tensor_dict diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index c11304f93bda9..4af5b09d0ce35 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -8,12 +8,12 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import SamplerOutput, SequenceGroupMetadata, ModelInputWithSamplingMetadata +from vllm.sequence import (ModelInputWithSamplingMetadata, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import make_tensor_with_pad logger = init_logger(__name__) @@ -321,7 +321,7 @@ def execute_model( } if self.vision_language_config: execute_model_kwargs.update( - {"image_input": mmodel_input.multi_modal_input}) + {"image_input": model_input.multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index ec48241fbc4ae..35511ff962fae 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.distributed @@ -13,7 +13,7 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput, ModelInput +from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -285,18 +285,17 @@ def prepare_model_input_local( @torch.inference_mode() def prepare_model_input( self, - execute_model_req: Optional[ExecuteModelRequest] - ) -> ModelInput: + execute_model_req: Optional[ExecuteModelRequest]) -> ModelInput: if self.parallel_config.tensor_parallel_size <= 1: return self.prepare_model_input_local(execute_model_req) if self.is_driver_worker: if execute_model_req is None: - # This signals that there's no more requests to process for now. - # All workers are running infinite loop with broadcast_tensor_dict, - # and it stops the loop when the driver broadcasts an empty input. - # Send an empty input to notify all other workers to stop their - # execution loop. + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the driver + # broadcasts an empty input. 
Send an empty input to notify all + # other workers to stop their execution loop. broadcast_tensor_dict({}, src=0) return None @@ -315,7 +314,8 @@ def prepare_model_input( @torch.inference_mode() def execute_model_local( - self, model_input: ModelInput, + self, + model_input: ModelInput, ) -> List[SamplerOutput]: self.cache_copy(model_input.blocks_to_copy) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 17a0e4beae194..1a289c16f91f3 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,21 +1,21 @@ -from typing import Dict, List, Optional, Set, Tuple from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import torch -from vllm.attention import AttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams -from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata, ModelInput +from vllm.sequence import (ModelInput, PoolerOutput, SequenceData, + SequenceGroupMetadata) from vllm.worker.model_runner import ModelRunner +if TYPE_CHECKING: + from vllm.model_executor import SamplingMetadata + logger = init_logger(__name__) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7c076aad4ce1d..2eae63799b983 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,8 +1,7 @@ import time import warnings from collections import defaultdict -from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union, Any -from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -12,7 +11,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict from vllm.distributed.communication_op import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -22,7 +20,8 @@ from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, ModelInput, ModelInputWithSamplingMetadata +from vllm.sequence import (ModelInput, ModelInputWithSamplingMetadata, + SamplerOutput, SequenceData, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7b567057300d4..9754f50be54ec 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, Union import torch import torch.distributed @@ -15,7 +15,8 @@ set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput, ModelInput +from vllm.sequence import 
(ExecuteModelRequest, ModelInput, PoolerOutput, + SamplerOutput) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner @@ -260,11 +261,11 @@ def prepare_model_input( if self.is_driver_worker: if execute_model_req is None: - # This signals that there's no more requests to process for now. - # All workers are running infinite loop with broadcast_tensor_dict, - # and it stops the loop when the driver broadcasts an empty input. - # Send an empty input to notify all other workers to stop their - # execution loop. + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the driver + # broadcasts an empty input. Send an empty input to notify all + # other workers to stop their execution loop. broadcast_tensor_dict({}, src=0) return None @@ -283,8 +284,8 @@ def prepare_model_input( @torch.inference_mode() def execute_model_local( - self, - model_input: ModelInput) -> List[Union[SamplerOutput, PoolerOutput]]: + self, model_input: ModelInput + ) -> List[Union[SamplerOutput, PoolerOutput]]: self.cache_swap(model_input.blocks_to_swap_in, model_input.blocks_to_swap_out, model_input.blocks_to_copy) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index fafdf7933c205..63cfa96cb3395 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -2,11 +2,12 @@ import os from abc import ABC, abstractmethod from typing import Dict, List, Optional, Set, Tuple + import torch from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput, ModelInput +from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) @@ -55,7 +56,7 @@ def start_worker_execution_loop(self) -> None: See `stop_remote_worker_execution_loop` for more details. """ while True: - output = self.execute_model(model_input) + output = self.execute_model(execute_model_req=None) if output is None: return None @@ -71,8 +72,7 @@ def prepare_model_input_local( @abstractmethod def prepare_model_input( self, - execute_model_req: Optional[ExecuteModelRequest] - ) -> ModelInput: + execute_model_req: Optional[ExecuteModelRequest]) -> ModelInput: """ Prepare a model execution request. Communication with other workers may occur to produce the model input that should be passed to @@ -80,20 +80,23 @@ def prepare_model_input( """ raise NotImplementedError - @abstractmethod - def execute_model(self, execute_model_req: Optional[ExecuteModelRequest]) -> List[SamplerOutput]: + def execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided. 
Communication with other workers may occur to produce the model input that should be passed to the model runner.""" - model_input = self.prepare_model_input(execute_model_req=execute_model_req) + model_input: Optional[ModelInput] = self.prepare_model_input( + execute_model_req=execute_model_req) if model_input is None: return None return self.execute_model_local(model_input) @abstractmethod - def execute_model_local(self, model_input: ModelInput) -> List[SamplerOutput]: + def execute_model_local(self, + model_input: ModelInput) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided. This method is not allowed to communciate with other workers. From dad94ba556a85553cb353302d2db3dc6f6b77576 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 11 Jun 2024 14:18:52 -0700 Subject: [PATCH 13/55] neuron model runner Signed-off-by: Stephanie Wang --- vllm/executor/neuron_executor.py | 3 +-- vllm/worker/neuron_model_runner.py | 39 +++++++++++++++++++----------- vllm/worker/neuron_worker.py | 27 ++++++++++++++------- 3 files changed, 44 insertions(+), 25 deletions(-) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index e7f0e887921b7..9a6e69a163c0a 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -55,8 +55,7 @@ def execute_model( assert execute_model_req.num_lookahead_slots == 0, ( "lookahead not supported for Neuron backend.") - output = self.driver_worker.execute_model( - execute_model_req.seq_group_metadata_list) + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a336be04e124f..197e61a3e1dc3 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,3 +1,4 @@ +import dataclasses from typing import List, Optional, Tuple import torch @@ -8,12 +9,22 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (ModelInputWithSamplingMetadata, SamplerOutput, + SequenceGroupMetadata) from vllm.utils import is_pin_memory_available, make_tensor_with_pad logger = init_logger(__name__) +@dataclasses.dataclass(frozen=True) +class ModelInputForNeuron(ModelInputWithSamplingMetadata): + input_block_ids: Optional[torch.Tensor] = None + + BROADCASTABLE_FIELDS = ( + ModelInputWithSamplingMetadata.BROADCASTABLE_FIELDS + + ("input_block_ids", )) + + class NeuronModelRunner: def __init__( @@ -139,10 +150,10 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - def prepare_input_tensors( + def prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: + ) -> ModelInputForNeuron: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
is_prompt = seq_group_metadata_list[0].is_prompt @@ -164,30 +175,30 @@ def prepare_input_tensors( self.device, self.pin_memory) - return (input_tokens, input_positions, input_block_ids, - sampling_metadata) + return ModelInputForNeuron(input_tokens=input_tokens, + input_positions=input_positions, + input_block_ids=input_block_ids, + sampling_metadata=sampling_metadata) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: ModelInputForNeuron, ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, input_block_ids, sampling_metadata - ) = self.prepare_input_tensors(seq_group_metadata_list) - hidden_states = self.model( - input_ids=input_tokens, - positions=input_positions, - input_block_ids=input_block_ids, + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, ) # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Sample the next token. output = self.model.sample( logits=logits, - sampling_metadata=sampling_metadata, + sampling_metadata=model_input.sampling_metadata, ) return output diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index d0e6aaed180e6..82d735233d7b3 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Tuple +from typing import List, Optional, Tuple import torch import torch.distributed @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -74,17 +74,26 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_cpu_blocks = num_cpu_blocks @torch.inference_mode() - def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[SamplerOutput]: - num_seq_groups = len(seq_group_metadata_list) + def prepare_model_input_local( + self, execute_model_req: ExecuteModelRequest) -> ModelInput: + model_input = self.model_runner.prepare_model_input_tensors( + execute_model_req.seq_group_metadata_list) + return model_input + + def prepare_model_input( + self, + execute_model_req: Optional[ExecuteModelRequest]) -> ModelInput: + assert execute_model_req is not None + return self.prepare_model_input_local(execute_model_req) + @torch.inference_mode() + def execute_model_local(self, + model_input: ModelInput) -> List[SamplerOutput]: # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: + if model_input.num_seq_groups == 0: return [] - output = self.model_runner.execute_model(seq_group_metadata_list) + output = self.model_runner.execute_model(model_input) # Neuron worker only supports single-step output. Wrap the output in a # list to conform to interface. 
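
Note: the worker_base refactor above splits execute_model() into a prepare-input step plus a worker-local execution step, and this patch moves the Neuron backend onto the same protocol. A minimal, self-contained sketch of that control flow (illustrative stand-ins only, not vLLM classes):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass(frozen=True)
    class _FakeModelInput:
        num_seq_groups: int = 0

    class _FakeWorker:
        def prepare_model_input(self, req: Optional[dict]) -> Optional[_FakeModelInput]:
            # Real workers tensorize the request here; None signals "stop the loop".
            if req is None:
                return None
            return _FakeModelInput(num_seq_groups=len(req.get("seq_groups", [])))

        def execute_model_local(self, model_input: _FakeModelInput) -> List[str]:
            # Worker-local step: model forward plus sampling, no cross-worker metadata sync.
            if model_input.num_seq_groups == 0:
                return []
            return ["sampler_output"]

        def execute_model(self, req: Optional[dict]) -> Optional[List[str]]:
            model_input = self.prepare_model_input(req)
            if model_input is None:
                return None
            return self.execute_model_local(model_input)

    print(_FakeWorker().execute_model({"seq_groups": ["g0"]}))  # -> ['sampler_output']
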
From fca606eac1e0cdec562df4d113204da293968975 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 11 Jun 2024 14:53:26 -0700 Subject: [PATCH 14/55] disallow distributed comms Signed-off-by: Stephanie Wang --- vllm/distributed/communication_op.py | 21 +++++++++++++++++++++ vllm/worker/worker.py | 1 + vllm/worker/worker_base.py | 11 ++++++++++- 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 2b38ec472de66..cf1e975a613c7 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -14,6 +14,22 @@ get_tp_pynccl_communicator) +@dataclass +class DistributedContext: + communication_allowed: bool = True + + @staticmethod + def get_current() -> "DistributedContext": + """ + Get the singleton context. + """ + global _default_context + return _default_context + + +_default_context: DistributedContext = DistributedContext() + + @dataclass class GraphCaptureContext: stream: torch.cuda.Stream @@ -235,6 +251,11 @@ def broadcast_tensor_dict( to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, dtypes). """ + ctx = DistributedContext.get_current() + if not ctx.communication_allowed: + raise RuntimeError( + "Control plane communication not allowed in current module") + # Bypass the function if we are using only 1 GPU. if (not torch.distributed.is_initialized() or torch.distributed.get_world_size(group=group) == 1): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9754f50be54ec..107fd22588e84 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -48,6 +48,7 @@ def __init__( is_driver_worker: bool = False, ) -> None: self.model_config = model_config + self.model_config.dtype = torch.float16 self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 63cfa96cb3395..8c9bb1d291544 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -5,6 +5,7 @@ import torch +from vllm.distributed import DistributedContext from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput @@ -92,7 +93,15 @@ def execute_model( if model_input is None: return None - return self.execute_model_local(model_input) + # Disallow control plane communication in worker-local code. 
+ comm_ctx = DistributedContext.get_current() + comm_allowed = comm_ctx.communication_allowed + comm_ctx.communication_allowed = False + try: + return_val = self.execute_model_local(model_input) + finally: + comm_ctx.communication_allowed = comm_allowed + return return_val @abstractmethod def execute_model_local(self, From 6ed3c2ae24d2eefd239a7f83e231d563d1b4781d Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 11 Jun 2024 17:21:14 -0700 Subject: [PATCH 15/55] disable communication Signed-off-by: Stephanie Wang --- vllm/distributed/communication_op.py | 24 +++++++++++++++++++++++- vllm/worker/cpu_worker.py | 4 +++- vllm/worker/neuron_worker.py | 3 +++ vllm/worker/worker.py | 4 +++- vllm/worker/worker_base.py | 14 ++++---------- 5 files changed, 36 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index cf1e975a613c7..a6e1586249789 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -30,6 +30,27 @@ def get_current() -> "DistributedContext": _default_context: DistributedContext = DistributedContext() +def disable_communication(fn): + """ + Helper decorator to disable control plane communication, i.e. + calling broadcast_tensor_dict will throw a RuntimeError. This can be used + to ensure that decorated code stays worker-local. + """ + + def wrapper(*args, **kwargs): + # Disallow control plane communication. + comm_ctx = DistributedContext.get_current() + original_comm_allowed = comm_ctx.communication_allowed + comm_ctx.communication_allowed = False + + try: + return fn(*args, **kwargs) + finally: + comm_ctx.communication_allowed = original_comm_allowed + + return wrapper + + @dataclass class GraphCaptureContext: stream: torch.cuda.Stream @@ -254,7 +275,8 @@ def broadcast_tensor_dict( ctx = DistributedContext.get_current() if not ctx.communication_allowed: raise RuntimeError( - "Control plane communication not allowed in current module") + "Control plane communication not allowed in functions decorated " + "with @disable_communication") # Bypass the function if we are using only 1 GPU. 
if (not torch.distributed.is_initialized() diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 35511ff962fae..2b10ff84ec3ac 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -8,7 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, +from vllm.distributed import (broadcast_tensor_dict, disable_communication, ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -263,6 +263,7 @@ def cache_copy( self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() + @disable_communication def prepare_model_input_local( self, execute_model_req: ExecuteModelRequest) -> ModelInput: assert execute_model_req is not None @@ -313,6 +314,7 @@ def prepare_model_input( return model_input @torch.inference_mode() + @disable_communication def execute_model_local( self, model_input: ModelInput, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 82d735233d7b3..7d46c741922a3 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -6,6 +6,7 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import disable_communication from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput from vllm.worker.neuron_model_runner import NeuronModelRunner @@ -74,6 +75,7 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_cpu_blocks = num_cpu_blocks @torch.inference_mode() + @disable_communication def prepare_model_input_local( self, execute_model_req: ExecuteModelRequest) -> ModelInput: model_input = self.model_runner.prepare_model_input_tensors( @@ -87,6 +89,7 @@ def prepare_model_input( return self.prepare_model_input_local(execute_model_req) @torch.inference_mode() + @disable_communication def execute_model_local(self, model_input: ModelInput) -> List[SamplerOutput]: # If there is no input, we don't need to execute the model. 
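
Note: the @disable_communication decorator applied in the hunks above flips a process-wide flag for the duration of the wrapped call, so broadcast_tensor_dict fails fast instead of silently synchronizing. A simplified, self-contained re-implementation of the same pattern (illustrative only, not the vLLM source):

    import functools
    from dataclasses import dataclass

    @dataclass
    class _Ctx:
        communication_allowed: bool = True

    _CTX = _Ctx()

    def disable_communication(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            prev = _CTX.communication_allowed
            _CTX.communication_allowed = False
            try:
                return fn(*args, **kwargs)
            finally:
                _CTX.communication_allowed = prev
        return wrapper

    def broadcast_tensor_dict(tensor_dict=None, src=0):
        if not _CTX.communication_allowed:
            raise RuntimeError(
                "Control plane communication not allowed in functions decorated "
                "with @disable_communication")
        return tensor_dict  # the real implementation would call torch.distributed

    @disable_communication
    def execute_model_local():
        # Any attempt to sync metadata from worker-local code now fails fast.
        broadcast_tensor_dict({})

    try:
        execute_model_local()
    except RuntimeError as exc:
        print(exc)

Restoring the previous value in a finally block keeps the flag correct even if the wrapped method raises, and lets decorated methods nest.
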
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 107fd22588e84..ee7a507eaeb81 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, +from vllm.distributed import (broadcast_tensor_dict, disable_communication, ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -224,6 +224,7 @@ def cache_swap( self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() + @disable_communication def prepare_model_input_local( self, execute_model_req: ExecuteModelRequest) -> ModelInput: model_input = self.model_runner.prepare_model_input_tensors( @@ -284,6 +285,7 @@ def prepare_model_input( return model_input @torch.inference_mode() + @disable_communication def execute_model_local( self, model_input: ModelInput ) -> List[Union[SamplerOutput, PoolerOutput]]: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8c9bb1d291544..0426653849d29 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -5,7 +5,7 @@ import torch -from vllm.distributed import DistributedContext +from vllm.distributed import disable_communication from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput @@ -62,6 +62,7 @@ def start_worker_execution_loop(self) -> None: return None @abstractmethod + @disable_communication def prepare_model_input_local( self, execute_model_req: ExecuteModelRequest) -> ModelInput: """ @@ -93,17 +94,10 @@ def execute_model( if model_input is None: return None - # Disallow control plane communication in worker-local code. 
- comm_ctx = DistributedContext.get_current() - comm_allowed = comm_ctx.communication_allowed - comm_ctx.communication_allowed = False - try: - return_val = self.execute_model_local(model_input) - finally: - comm_ctx.communication_allowed = comm_allowed - return return_val + return self.execute_model_local(model_input) @abstractmethod + @disable_communication def execute_model_local(self, model_input: ModelInput) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no From 1803e330275d02c22f4dfecdf7fafa6a3bb11331 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 11 Jun 2024 17:41:01 -0700 Subject: [PATCH 16/55] Update worker.py --- vllm/worker/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ee7a507eaeb81..95ed113c18778 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -48,7 +48,6 @@ def __init__( is_driver_worker: bool = False, ) -> None: self.model_config = model_config - self.model_config.dtype = torch.float16 self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config From dde799eab417a2e7acaff2439bc8df2c16b5be6d Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 12 Jun 2024 14:36:51 -0700 Subject: [PATCH 17/55] fix tests Signed-off-by: Stephanie Wang --- tests/worker/test_model_runner.py | 37 ++++++++++++++++---------- vllm/spec_decode/ngram_worker.py | 22 ++++++++++++++- vllm/spec_decode/spec_decode_worker.py | 25 ++++++++++++++++- vllm/worker/cpu_worker.py | 21 +++++++-------- vllm/worker/worker.py | 21 +++++++-------- 5 files changed, 88 insertions(+), 38 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 92de545acd53d..0cb57f2ef9407 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -58,7 +58,8 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len - model_input = model_runner._prepare_model_input(seq_group_metadata_list) + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata @@ -172,7 +173,8 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - model_input = model_runner._prepare_model_input(seq_group_metadata_list) + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) input_tokens, input_positions, attn_metadata, slot_mapping = ( model_input.input_tokens, model_input.input_positions, model_input.attn_metadata, model_input.slot_mapping) @@ -257,19 +259,21 @@ def test_empty_seq_group(): enforce_eager=False, ) seq_group_metadata_list = [] - model_input = model_runner._prepare_model_input(seq_group_metadata_list) + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) input_tokens, input_positions, attn_metadata, slot_mapping = ( model_input.input_tokens, model_input.input_positions, model_input.attn_metadata, model_input.slot_mapping, ) - assert len(input_tokens) == 0 - assert len(input_positions) == 0 + assert input_tokens is None + assert input_positions is None assert attn_metadata is None - assert len(slot_mapping) == 0 + assert slot_mapping is None - model_input = 
model_runner._prepare_model_input(seq_group_metadata_list) + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) (input_tokens, input_positions, attn_metadata, slot_mapping, return_seq_lens) = ( model_input.input_tokens, @@ -278,11 +282,11 @@ def test_empty_seq_group(): model_input.slot_mapping, model_input.seq_lens, ) - assert len(input_tokens) == 0 - assert len(input_positions) == 0 + assert input_tokens is None + assert input_positions is None assert attn_metadata is None - assert len(slot_mapping) == 0 - assert len(return_seq_lens) == 0 + assert slot_mapping is None + assert return_seq_lens is None @pytest.fixture @@ -350,8 +354,13 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): seq_group_metadata_list.append(seq_group_metadata) decode_metadata_list.append(seq_group_metadata) - (input_tokens, input_positions, attn_metadata, _, _, _, - _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + model_input = model_runner.prepare_model_input_tensors( + seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + ) prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata @@ -364,7 +373,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. - attn_metadata = model_runner._prepare_model_input( + attn_metadata = model_runner._prepare_model_input_tensors( seq_group_metadata_list).attn_metadata for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 33af588d0ba29..734474ba83ee4 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -3,7 +3,7 @@ import torch -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer @@ -161,3 +161,23 @@ def _raise_if_unsupported( execute_model_req.seq_group_metadata_list): raise NotImplementedError( "NGramWorker does not support beam search.") + + @torch.inference_mode() + def prepare_model_input_local( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + raise NotImplementedError("NGramWorker does not allow direct calls to " + "prepare_model_input_local") + + @torch.inference_mode() + def prepare_model_input( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> List[SamplerOutput]: + raise NotImplementedError("NGramWorker does not allow direct calls to " + "prepare_model_input") + + @torch.inference_mode() + def execute_model_local(self, + model_input: ModelInput) -> List[SamplerOutput]: + raise NotImplementedError("NGramWorker does not allow direct calls to " + "execute_model_local") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 45d9d5735efc6..9541bf4215e08 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -6,7 +6,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import 
RejectionSampler -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, +from vllm.sequence import (ExecuteModelRequest, ModelInput, SamplerOutput, SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -232,6 +232,29 @@ def initialize_cache(self, num_gpu_blocks: int, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + @torch.inference_mode() + def prepare_model_input_local( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + raise NotImplementedError( + "SpecDecodeWorker does not allow direct calls to " + "prepare_model_input_local") + + @torch.inference_mode() + def prepare_model_input( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> List[SamplerOutput]: + raise NotImplementedError( + "SpecDecodeWorker does not allow direct calls to " + "prepare_model_input") + + @torch.inference_mode() + def execute_model_local(self, + model_input: ModelInput) -> List[SamplerOutput]: + raise NotImplementedError( + "SpecDecodeWorker does not allow direct calls to " + "execute_model_local") + @torch.inference_mode() def execute_model( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 2b10ff84ec3ac..e6b8cc33c8ec7 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -287,22 +287,21 @@ def prepare_model_input_local( def prepare_model_input( self, execute_model_req: Optional[ExecuteModelRequest]) -> ModelInput: - if self.parallel_config.tensor_parallel_size <= 1: - return self.prepare_model_input_local(execute_model_req) - if self.is_driver_worker: if execute_model_req is None: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the driver - # broadcasts an empty input. Send an empty input to notify all - # other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) + if self.parallel_config.tensor_parallel_size > 1: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) return None model_input = self.prepare_model_input_local(execute_model_req) - metadata_dict = model_input.as_broadcastable_tensor_dict() - broadcast_tensor_dict(metadata_dict, src=0) + if self.parallel_config.tensor_parallel_size > 1: + metadata_dict = model_input.as_broadcastable_tensor_dict() + broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) if not metadata_dict: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ee7a507eaeb81..d60e9165cdc12 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -258,22 +258,21 @@ def prepare_model_input( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> ModelInput: - if self.parallel_config.tensor_parallel_size <= 1: - return self.prepare_model_input_local(execute_model_req) - if self.is_driver_worker: if execute_model_req is None: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the driver - # broadcasts an empty input. 
Send an empty input to notify all - # other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) + if self.parallel_config.tensor_parallel_size > 1: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) return None model_input = self.prepare_model_input_local(execute_model_req) - metadata_dict = model_input.as_broadcastable_tensor_dict() - broadcast_tensor_dict(metadata_dict, src=0) + if self.parallel_config.tensor_parallel_size > 1: + metadata_dict = model_input.as_broadcastable_tensor_dict() + broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) if not metadata_dict: From 039863143c23b5a7be881d0fcfa2b307e763287b Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 12 Jun 2024 16:53:11 -0700 Subject: [PATCH 18/55] update Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 95 +++++++++ vllm/model_input.py | 276 ++++++++++++++++++++++++++ vllm/worker/cpu_model_runner.py | 14 +- vllm/worker/cpu_worker.py | 9 +- vllm/worker/embedding_model_runner.py | 25 +-- vllm/worker/model_runner.py | 15 +- vllm/worker/neuron_model_runner.py | 14 +- vllm/worker/neuron_worker.py | 14 +- vllm/worker/worker.py | 14 +- vllm/worker/worker_base.py | 11 +- 10 files changed, 421 insertions(+), 66 deletions(-) create mode 100644 tests/worker/test_model_input.py create mode 100644 vllm/model_input.py diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py new file mode 100644 index 0000000000000..66c0fa26bc9fa --- /dev/null +++ b/tests/worker/test_model_input.py @@ -0,0 +1,95 @@ +import dataclasses +from typing import List, Tuple, Type + +import torch + +from vllm.attention import AttentionMetadata +from vllm.attention.backends.abstract import AttentionBackend +from vllm.model_executor import SamplingMetadata +from vllm.model_input import GPUModelInputWithSamplingMetadata + + +class MockAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + pass + + @staticmethod + def get_impl_cls(): + pass + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return AttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + pass + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + pass + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + pass + + +def test_gpu_model_input(): + sampling_metadata = SamplingMetadata( + ["seq_group"], + "selected_token_indices", + "categorized_sample_indices", + "num_prompts", + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + model_input = GPUModelInputWithSamplingMetadata.new( + num_seq_groups=10, + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata) + + assert isinstance(model_input, GPUModelInputWithSamplingMetadata) + + # Test round trip serialization. 
+ tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = GPUModelInputWithSamplingMetadata.new( + attn_backend=attn_backend, **tensor_dict) + assert isinstance(received_model_input, GPUModelInputWithSamplingMetadata) + + # Broadcast should not contain empty values. + for field in dataclasses.fields(model_input): + if getattr(model_input, field.name) is None: + assert field.name not in tensor_dict + # Broadcast should contain all non-empty fields defined by the developer + # for this input type. + for field in GPUModelInputWithSamplingMetadata.BROADCASTABLE_FIELDS: + if getattr(model_input, field) is not None: + assert field in tensor_dict + + # Check that received copy has correct values. + for field in dataclasses.fields(AttentionMetadata): + assert getattr(received_model_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # For sampling metadata, only selected_token_indices is copied. + assert (received_model_input.sampling_metadata.selected_token_indices == + sampling_metadata.selected_token_indices) + assert received_model_input.sampling_metadata.seq_groups is None diff --git a/vllm/model_input.py b/vllm/model_input.py new file mode 100644 index 0000000000000..9d4512405116e --- /dev/null +++ b/vllm/model_input.py @@ -0,0 +1,276 @@ +"""Worker-local model inputs. These define the inputs to different model +runners.""" +import dataclasses +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union + +import torch + +from vllm.lora.request import LoRARequest + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend + from vllm.lora.layers import LoRAMapping + from vllm.model_executor import SamplingMetadata + from vllm.model_executor.pooling_metadata import PoolingMetadata + + +def _init_attn_metadata_from_kwargs( + attn_backend: Optional["AttentionBackend"] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + **kwargs) -> Dict[str, Any]: + if attn_metadata is None and attn_backend is not None: + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = kwargs.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + if attn_metadata is not None: + kwargs["attn_metadata"] = attn_metadata + return kwargs + + +def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + attn_metadata: Optional["AttentionMetadata"]) -> None: + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + +def _init_sampling_metadata_from_kwargs( # type: ignore + selected_token_indices: Optional[torch.Tensor] = None, + sampling_metadata: Optional["SamplingMetadata"] = None, + **kwargs) -> Dict[str, Any]: + if sampling_metadata is None and selected_token_indices is not None: + from vllm.model_executor import SamplingMetadata + + # An empty SamplingMetadata to signal that the worker should skip + # sampling. 
+ sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + if sampling_metadata is not None: + kwargs["sampling_metadata"] = sampling_metadata + return kwargs + + +def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) + + +@dataclasses.dataclass(frozen=True) +class ModelInput: + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelInput objects. + + Model runners should inherit from this class and add their required fields. + For distributed executors, any fields that should be sent during a + broadcast op should also be added to the BROADCASTABLE_FIELDS. During + execution, these fields will be extracted from the source copy and + broadcasted to all workers using broadcast_tensor_dict. + + Some fields may have values that cannot be broadcasted with this method + because they require some special serialization/deserialization, e.g., a + Python class like SamplingMetadata. For these fields, override + as_broadcastable_tensor_dict to return the custom serialized values and + override _get_init_kwargs to perform the custom deserialization ( + GPUModelInput for an example). + """ + # Fields to broadcast to all workers from driver. The value must be + # broadcastable using broadcast_tensor_dict (i.e. either a tensor, or a + # Python primitive like int). During the broadcast, the listed fields will + # be extracted from the source copy and then passed to `new()` to create a + # copy on the destination(s). + BROADCASTABLE_FIELDS: Tuple[str, ...] = () + + @classmethod + def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]: + """ + Helper method to extract all dataclass fields from the given kwargs. + Override for fields that require some custom deserialization. + """ + return kwargs + + @classmethod + def new(cls, + clone: Optional["ModelInput"] = None, + **kwargs) -> "ModelInput": + """ + Create a new instance of this class. Copy fields from `clone` if + provided. Populate the new instance with the given kwargs. + """ + clone_kwargs = {} + if clone is not None: + for field in dataclasses.fields(clone): + val = getattr(clone, field.name) + if val is not None: + clone_kwargs[field.name] = val + clone_kwargs = cls._get_init_kwargs(**clone_kwargs) + + kwargs = cls._get_init_kwargs(**kwargs) + return cls(**clone_kwargs, **kwargs) + + def replace(self, **kwargs) -> "ModelInput": + """ + Replace current fields with fields in kwargs. + """ + valid_kwargs = self.__class__._get_init_kwargs(**kwargs) + return dataclasses.replace(self, **valid_kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} + for field in self.BROADCASTABLE_FIELDS: + val = getattr(self, field, None) + if val is not None: + tensor_dict[field] = val + + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class CPUModelInput(ModelInput): + """ + Used by the CPUModelRunner. 
+ """ + num_seq_groups: Optional[int] = None + blocks_to_copy: Optional[torch.Tensor] = None + + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + attn_metadata: Optional["AttentionMetadata"] = None + sampling_metadata: Optional["SamplingMetadata"] = None + + BROADCASTABLE_FIELDS: Tuple[str, ...] = ( + "num_seq_groups", + "blocks_to_copy", + "input_tokens", + "input_positions", + "multi_modal_kwargs", + ) + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = _init_attn_metadata_from_kwargs(**kwargs) + kwargs = _init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class GPUModelInput(ModelInput): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + attn_metadata: Optional["AttentionMetadata"] = None + + BROADCASTABLE_FIELDS: Tuple[str, ...] = ( + "num_seq_groups", + "blocks_to_swap_in", + "blocks_to_swap_out", + "blocks_to_copy", + "input_tokens", + "input_positions", + "lora_requests", + "lora_mapping", + "multi_modal_kwargs", + ) + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = _init_attn_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class GPUModelInputWithPoolingMetadata(GPUModelInput): + """ + Used by the EmbeddingModelRunner. + """ + pooling_metadata: Optional["PoolingMetadata"] = None + + +@dataclasses.dataclass(frozen=True) +class GPUModelInputWithSamplingMetadata(GPUModelInput): + """ + Used by the ModelRunner. 
+ """ + sampling_metadata: Optional["SamplingMetadata"] = None + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = _init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class ModelInputForNeuron(ModelInput): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + input_block_ids: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 4af5b09d0ce35..42df74802d472 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -11,9 +11,9 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.model_input import CPUModelInput from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import (ModelInputWithSamplingMetadata, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad logger = init_logger(__name__) @@ -273,7 +273,7 @@ def _prepare_decode( def prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInputWithSamplingMetadata: + ) -> CPUModelInput: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
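
Note: with the ModelInput hierarchy factored out into vllm/model_input.py above, the driver-to-worker handoff under tensor parallelism reduces to a serialize/broadcast/rebuild round trip. A hedged usage sketch (assumes an initialized distributed group, and that `model_runner`, `seq_group_metadata_list`, and the runner's `attn_backend` attribute already exist; variable names are illustrative):

    from vllm.distributed import broadcast_tensor_dict
    from vllm.model_input import GPUModelInputWithSamplingMetadata

    # Rank 0 (driver): build the input locally, then broadcast its tensorizable fields.
    model_input = model_runner.prepare_model_input_tensors(seq_group_metadata_list)
    broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(), src=0)

    # Other ranks: receive the dict and rebuild an equivalent worker-local copy.
    tensor_dict = broadcast_tensor_dict(src=0)
    model_input = GPUModelInputWithSamplingMetadata.new(
        attn_backend=model_runner.attn_backend, **tensor_dict)
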
@@ -296,20 +296,20 @@ def prepare_model_input_tensors( seq_lens, self.device, pin_memory=False) - return ModelInputWithSamplingMetadata.new( + return CPUModelInput.new( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, sampling_metadata=sampling_metadata, ) - def get_empty_model_input(self) -> ModelInputWithSamplingMetadata: - return ModelInputWithSamplingMetadata.new() + def get_empty_model_input(self) -> CPUModelInput: + return CPUModelInput.new() @torch.inference_mode() def execute_model( self, - model_input: ModelInputWithSamplingMetadata, + model_input: CPUModelInput, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: model_executable = self.model diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index e6b8cc33c8ec7..3a970acda702c 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,7 +13,8 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput +from vllm.model_input import CPUModelInput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -265,7 +266,7 @@ def cache_copy( @torch.inference_mode() @disable_communication def prepare_model_input_local( - self, execute_model_req: ExecuteModelRequest) -> ModelInput: + self, execute_model_req: ExecuteModelRequest) -> CPUModelInput: assert execute_model_req is not None model_input = self.model_runner.prepare_model_input_tensors( @@ -286,7 +287,7 @@ def prepare_model_input_local( @torch.inference_mode() def prepare_model_input( self, - execute_model_req: Optional[ExecuteModelRequest]) -> ModelInput: + execute_model_req: Optional[ExecuteModelRequest]) -> CPUModelInput: if self.is_driver_worker: if execute_model_req is None: if self.parallel_config.tensor_parallel_size > 1: @@ -316,7 +317,7 @@ def prepare_model_input( @disable_communication def execute_model_local( self, - model_input: ModelInput, + model_input: CPUModelInput, ) -> List[SamplerOutput]: self.cache_copy(model_input.blocks_to_copy) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 1a289c16f91f3..53d29d8bdf1c6 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch @@ -8,22 +7,14 @@ VisionLanguageConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.model_input import GPUModelInputWithPoolingMetadata from vllm.pooling_params import PoolingParams -from vllm.sequence import (ModelInput, PoolerOutput, SequenceData, - SequenceGroupMetadata) +from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner -if TYPE_CHECKING: - from vllm.model_executor import SamplingMetadata - logger = init_logger(__name__) -@dataclass(frozen=True) -class ModelInputWithPoolingMetadata(ModelInput): - pooling_metadata: Optional["SamplingMetadata"] = None - - class EmbeddingModelRunner(ModelRunner): def __init__( @@ -50,13 +41,13 @@ def __init__( is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) - 
def get_empty_model_input(self) -> ModelInputWithPoolingMetadata: - return ModelInputWithPoolingMetadata.new() + def get_empty_model_input(self) -> GPUModelInputWithPoolingMetadata: + return GPUModelInputWithPoolingMetadata.new() @torch.inference_mode() def execute_model( self, - model_input: ModelInputWithPoolingMetadata, + model_input: GPUModelInputWithPoolingMetadata, kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: if self.lora_config: @@ -96,7 +87,7 @@ def execute_model( def prepare_model_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> ModelInput: + ) -> GPUModelInputWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( seq_group_metadata_list) @@ -104,7 +95,7 @@ def prepare_model_input_tensors( pooling_metadata = self._prepare_pooling(seq_group_metadata_list, model_input.seq_lens) - return ModelInputWithPoolingMetadata.new( + return GPUModelInputWithPoolingMetadata.new( clone=model_input, pooling_metadata=pooling_metadata, ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2eae63799b983..ba5ac85ea3848 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -18,10 +18,11 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model +from vllm.model_input import GPUModelInput from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import (ModelInput, ModelInputWithSamplingMetadata, - SamplerOutput, SequenceData, SequenceGroupMetadata) +from vllm.sequence import (ModelInputWithSamplingMetadata, SamplerOutput, + SequenceData, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) @@ -196,7 +197,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInput: + ) -> GPUModelInput: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. 
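
Note: runner-specific inputs extend this by subclassing the closest existing ModelInput, adding their fields, and extending BROADCASTABLE_FIELDS so the new values survive the driver-to-worker broadcast (a later patch in this series converts the tuple into a broadcastable_fields property, but the shape is the same). A hypothetical subclass, purely for illustration; the class and field name are invented:

    import dataclasses
    from typing import Optional, Tuple

    import torch

    from vllm.model_input import GPUModelInputWithSamplingMetadata

    @dataclasses.dataclass(frozen=True)
    class HypotheticalDraftModelInput(GPUModelInputWithSamplingMetadata):
        # Extra per-step data; values must be tensors or primitives to broadcast.
        draft_token_ids: Optional[torch.Tensor] = None

        # Re-annotated so the dataclass picks up the extended default.
        BROADCASTABLE_FIELDS: Tuple[str, ...] = (
            GPUModelInputWithSamplingMetadata.BROADCASTABLE_FIELDS +
            ("draft_token_ids", ))
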
@@ -250,7 +251,7 @@ def _prepare_model_input_tensors( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return ModelInput() + return GPUModelInput() if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window + self.block_size - @@ -600,7 +601,7 @@ def _prepare_model_input_tensors( for k, v in multi_modal_kwargs_list.items() } - return ModelInput.new( + return GPUModelInput.new( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, attn_metadata=attn_metadata, @@ -609,10 +610,6 @@ def _prepare_model_input_tensors( lora_mapping=lora_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, ) def prepare_model_input_tensors( diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 197e61a3e1dc3..27b2d01e5e814 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,3 @@ -import dataclasses from typing import List, Optional, Tuple import torch @@ -9,22 +8,13 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.sequence import (ModelInputWithSamplingMetadata, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_input import ModelInputForNeuron +from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad logger = init_logger(__name__) -@dataclasses.dataclass(frozen=True) -class ModelInputForNeuron(ModelInputWithSamplingMetadata): - input_block_ids: Optional[torch.Tensor] = None - - BROADCASTABLE_FIELDS = ( - ModelInputWithSamplingMetadata.BROADCASTABLE_FIELDS + - ("input_block_ids", )) - - class NeuronModelRunner: def __init__( diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 7d46c741922a3..0880ba8df36a6 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -8,7 +8,8 @@ ParallelConfig, SchedulerConfig) from vllm.distributed import disable_communication from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput +from vllm.model_input import ModelInputForNeuron +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -77,21 +78,22 @@ def initialize_cache(self, num_gpu_blocks: int, @torch.inference_mode() @disable_communication def prepare_model_input_local( - self, execute_model_req: ExecuteModelRequest) -> ModelInput: + self, + execute_model_req: ExecuteModelRequest) -> ModelInputForNeuron: model_input = self.model_runner.prepare_model_input_tensors( execute_model_req.seq_group_metadata_list) return model_input def prepare_model_input( - self, - execute_model_req: Optional[ExecuteModelRequest]) -> ModelInput: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> ModelInputForNeuron: assert execute_model_req is not None return self.prepare_model_input_local(execute_model_req) @torch.inference_mode() @disable_communication - def execute_model_local(self, - model_input: ModelInput) -> List[SamplerOutput]: + def execute_model_local( + self, model_input: ModelInputForNeuron) -> List[SamplerOutput]: # If there is no input, we don't need to execute the 
model. if model_input.num_seq_groups == 0: return [] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d60e9165cdc12..fc07993d30482 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -15,8 +15,8 @@ set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import (ExecuteModelRequest, ModelInput, PoolerOutput, - SamplerOutput) +from vllm.model_input import GPUModelInput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner @@ -226,7 +226,7 @@ def cache_swap( @torch.inference_mode() @disable_communication def prepare_model_input_local( - self, execute_model_req: ExecuteModelRequest) -> ModelInput: + self, execute_model_req: ExecuteModelRequest) -> GPUModelInput: model_input = self.model_runner.prepare_model_input_tensors( execute_model_req.seq_group_metadata_list) @@ -255,9 +255,9 @@ def prepare_model_input_local( @torch.inference_mode() def prepare_model_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> ModelInput: + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> GPUModelInput: if self.is_driver_worker: if execute_model_req is None: if self.parallel_config.tensor_parallel_size > 1: @@ -286,7 +286,7 @@ def prepare_model_input( @torch.inference_mode() @disable_communication def execute_model_local( - self, model_input: ModelInput + self, model_input: GPUModelInput ) -> List[Union[SamplerOutput, PoolerOutput]]: self.cache_swap(model_input.blocks_to_swap_in, model_input.blocks_to_swap_out, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 0426653849d29..6cfaa3847ab78 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -66,8 +66,10 @@ def start_worker_execution_loop(self) -> None: def prepare_model_input_local( self, execute_model_req: ExecuteModelRequest) -> ModelInput: """ - Prepare a model execution request locally. This method is not allowed - to communicate with external devices. + Prepare a model execution request locally. This method may move data to + the worker's local device. It is not allowed to communicate with + other workers or devices. Subclasses should keep the + @disable_communication decorator to enforce this. """ raise NotImplementedError @@ -101,8 +103,9 @@ def execute_model( def execute_model_local(self, model_input: ModelInput) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no - sequences are provided. This method is not allowed to communciate with - other workers. + sequences are provided. This method is not allowed to communciate + metadata to other workers. Subclasses should keep the + @disable_communication decorator to enforce this. 
""" raise NotImplementedError From eef66230225dc27a1c9d6a12a57120a4e102fa27 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 12 Jun 2024 16:57:22 -0700 Subject: [PATCH 19/55] merge Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 3 ++- vllm/worker/model_runner.py | 2 +- vllm/worker/worker.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 4cd63e54df4d3..188eba395df4e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -319,7 +319,8 @@ def execute_model( "kv_caches": kv_caches, "attn_metadata": model_input.attn_metadata, } - if self.vision_language_config and model_input.multi_modal_input is not None: + if (self.vision_language_config + and model_input.multi_modal_input is not None): execute_model_kwargs.update(model_input.multi_modal_input) hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b552dbb88399b..71015297e129f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -19,8 +19,8 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_input import GPUModelInput from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_input import GPUModelInput from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams from vllm.sequence import (ModelInputWithSamplingMetadata, SamplerOutput, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8158e3dd80b68..b0d53ccc9ddf4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -15,8 +15,8 @@ set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.model_input import GPUModelInput from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_input import GPUModelInput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner From 3004ceb08db7c7706aa20246b0b26f111b53e0ab Mon Sep 17 00:00:00 2001 From: Stephanie Date: Wed, 12 Jun 2024 22:19:30 -0700 Subject: [PATCH 20/55] update Signed-off-by: Stephanie --- vllm/model_input.py | 61 ++++++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/vllm/model_input.py b/vllm/model_input.py index 9d4512405116e..371a06a490448 100644 --- a/vllm/model_input.py +++ b/vllm/model_input.py @@ -77,7 +77,7 @@ class ModelInput: Model runners should inherit from this class and add their required fields. For distributed executors, any fields that should be sent during a - broadcast op should also be added to the BROADCASTABLE_FIELDS. During + broadcast op should also be added to the broadcastable_fields. During execution, these fields will be extracted from the source copy and broadcasted to all workers using broadcast_tensor_dict. @@ -88,12 +88,17 @@ class ModelInput: override _get_init_kwargs to perform the custom deserialization ( GPUModelInput for an example). """ - # Fields to broadcast to all workers from driver. The value must be - # broadcastable using broadcast_tensor_dict (i.e. either a tensor, or a - # Python primitive like int). 
During the broadcast, the listed fields will - # be extracted from the source copy and then passed to `new()` to create a - # copy on the destination(s). - BROADCASTABLE_FIELDS: Tuple[str, ...] = () + + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + """ + Return fields to broadcast to all workers from driver. The value of + each field must be broadcastable using broadcast_tensor_dict (i.e. + either a tensor, or a Python primitive like int). During the broadcast, + the listed fields will be extracted from the source copy and then + passed to `new()` to create a copy on the destination(s). + """ + raise NotImplementedError() @classmethod def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]: @@ -136,7 +141,7 @@ def as_broadcastable_tensor_dict( custom deserialization. """ tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} - for field in self.BROADCASTABLE_FIELDS: + for field in self.broadcastable_fields: val = getattr(self, field, None) if val is not None: tensor_dict[field] = val @@ -159,13 +164,15 @@ class CPUModelInput(ModelInput): attn_metadata: Optional["AttentionMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None - BROADCASTABLE_FIELDS: Tuple[str, ...] = ( - "num_seq_groups", - "blocks_to_copy", - "input_tokens", - "input_positions", - "multi_modal_kwargs", - ) + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + return ( + "num_seq_groups", + "blocks_to_copy", + "input_tokens", + "input_positions", + "multi_modal_kwargs", + ) @classmethod def _get_init_kwargs( # type: ignore @@ -206,17 +213,19 @@ class GPUModelInput(ModelInput): attn_metadata: Optional["AttentionMetadata"] = None - BROADCASTABLE_FIELDS: Tuple[str, ...] = ( - "num_seq_groups", - "blocks_to_swap_in", - "blocks_to_swap_out", - "blocks_to_copy", - "input_tokens", - "input_positions", - "lora_requests", - "lora_mapping", - "multi_modal_kwargs", - ) + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + return ( + "num_seq_groups", + "blocks_to_swap_in", + "blocks_to_swap_out", + "blocks_to_copy", + "input_tokens", + "input_positions", + "lora_requests", + "lora_mapping", + "multi_modal_kwargs", + ) @classmethod def _get_init_kwargs( # type: ignore From 9380ed8408338882cf15adef08f8bcc3d9721dbe Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 13 Jun 2024 12:40:13 -0700 Subject: [PATCH 21/55] fix Signed-off-by: Stephanie Wang --- tests/worker/test_model_runner.py | 23 +++++++++-------------- vllm/distributed/communication_op.py | 1 + 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 0f369ca8b7c78..d3dcdda4946d8 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -65,7 +65,7 @@ def test_prepare_prompt(batch_size): input_positions = model_input.input_positions attn_metadata = model_input.attn_metadata return_seq_lens = model_input.seq_lens - slot_mapping = model_input.slot_mapping + slot_mapping = attn_metadata.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) @@ -178,7 +178,7 @@ def test_prepare_decode_cuda_graph(batch_size): seq_group_metadata_list) input_tokens, input_positions, attn_metadata, slot_mapping = ( model_input.input_tokens, model_input.input_positions, - model_input.attn_metadata, model_input.slot_mapping) + model_input.attn_metadata, model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) expected_bs = 
_get_graph_batch_size(len(seq_group_metadata_list)) @@ -262,31 +262,26 @@ def test_empty_seq_group(): seq_group_metadata_list = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( + input_tokens, input_positions, attn_metadata = ( model_input.input_tokens, model_input.input_positions, model_input.attn_metadata, - model_input.slot_mapping, ) assert input_tokens is None assert input_positions is None assert attn_metadata is None - assert slot_mapping is None model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata, slot_mapping, - return_seq_lens) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - model_input.slot_mapping, - model_input.seq_lens, - ) + (input_tokens, input_positions, attn_metadata, return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.seq_lens, + ) assert input_tokens is None assert input_positions is None assert attn_metadata is None - assert slot_mapping is None assert return_seq_lens is None diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 3195a52716005..0fbe7e1621905 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Any, Dict, Optional, Union import torch From 3c4de6d8d2b9be504d87ebfa51e8421ccc25f6ee Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 13 Jun 2024 13:32:59 -0700 Subject: [PATCH 22/55] fix Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 66c0fa26bc9fa..8ddb641b79c55 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -81,7 +81,7 @@ def test_gpu_model_input(): assert field.name not in tensor_dict # Broadcast should contain all non-empty fields defined by the developer # for this input type. - for field in GPUModelInputWithSamplingMetadata.BROADCASTABLE_FIELDS: + for field in model_input.broadcastable_fields: if getattr(model_input, field) is not None: assert field in tensor_dict From 5053f307cf2bba561cf71fa6d1d609964482ab64 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 13 Jun 2024 15:19:32 -0700 Subject: [PATCH 23/55] fix Signed-off-by: Stephanie Wang --- vllm/spec_decode/spec_decode_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3db25c544561e..49ddec5df0bbd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -387,9 +387,9 @@ def _run_non_driver_rank(self) -> bool: # We run the proposer once per lookahead slot. In the future we should # delegate how many times it runs to the proposer. 
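# (Descriptive note, assuming the worker API refactored in this series: in
# _run_non_driver_rank these calls run on follower ranks, where
# execute_model_req is always None. Passing None lets each proposer/scorer
# rank obtain its model input from the driver's broadcast_tensor_dict
# instead of from a local ExecuteModelRequest.)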
for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model() + self.proposer_worker.execute_model(execute_model_req=None) - self.scorer_worker.execute_model() + self.scorer_worker.execute_model(execute_model_req=None) return True @nvtx_range("spec_decode_worker._run_speculative_decoding_step") From db38556114be038f3b3655f04557c4319ec22253 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 13 Jun 2024 19:25:39 -0700 Subject: [PATCH 24/55] x Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 2 +- vllm/model_input.py | 285 ------------------------- vllm/sequence.py | 141 +----------- vllm/spec_decode/ngram_worker.py | 3 +- vllm/spec_decode/spec_decode_worker.py | 3 +- vllm/worker/cpu_model_runner.py | 6 +- vllm/worker/cpu_worker.py | 6 +- vllm/worker/embedding_model_runner.py | 11 +- vllm/worker/model_runner.py | 23 +- vllm/worker/neuron_model_runner.py | 2 +- vllm/worker/neuron_worker.py | 5 +- vllm/worker/worker.py | 4 +- vllm/worker/worker_base.py | 7 +- 13 files changed, 44 insertions(+), 454 deletions(-) delete mode 100644 vllm/model_input.py diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 8ddb641b79c55..922fb435df5ff 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -6,7 +6,7 @@ from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -from vllm.model_input import GPUModelInputWithSamplingMetadata +from vllm.worker.model_input import GPUModelInputWithSamplingMetadata class MockAttentionBackend(AttentionBackend): diff --git a/vllm/model_input.py b/vllm/model_input.py deleted file mode 100644 index 371a06a490448..0000000000000 --- a/vllm/model_input.py +++ /dev/null @@ -1,285 +0,0 @@ -"""Worker-local model inputs. These define the inputs to different model -runners.""" -import dataclasses -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union - -import torch - -from vllm.lora.request import LoRARequest - -if TYPE_CHECKING: - from vllm.attention import AttentionMetadata - from vllm.attention.backends.abstract import AttentionBackend - from vllm.lora.layers import LoRAMapping - from vllm.model_executor import SamplingMetadata - from vllm.model_executor.pooling_metadata import PoolingMetadata - - -def _init_attn_metadata_from_kwargs( - attn_backend: Optional["AttentionBackend"] = None, - attn_metadata: Optional["AttentionMetadata"] = None, - **kwargs) -> Dict[str, Any]: - if attn_metadata is None and attn_backend is not None: - # Extract the fields used to create AttentionMetadata. 
- valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = kwargs.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - if attn_metadata is not None: - kwargs["attn_metadata"] = attn_metadata - return kwargs - - -def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - attn_metadata: Optional["AttentionMetadata"]) -> None: - if attn_metadata is not None: - tensor_dict.update(attn_metadata.asdict_zerocopy()) - - -def _init_sampling_metadata_from_kwargs( # type: ignore - selected_token_indices: Optional[torch.Tensor] = None, - sampling_metadata: Optional["SamplingMetadata"] = None, - **kwargs) -> Dict[str, Any]: - if sampling_metadata is None and selected_token_indices is not None: - from vllm.model_executor import SamplingMetadata - - # An empty SamplingMetadata to signal that the worker should skip - # sampling. - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - if sampling_metadata is not None: - kwargs["sampling_metadata"] = sampling_metadata - return kwargs - - -def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - sampling_metadata: Optional["SamplingMetadata"]) -> None: - if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - sampling_metadata.selected_token_indices) - - -@dataclasses.dataclass(frozen=True) -class ModelInput: - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelInput objects. - - Model runners should inherit from this class and add their required fields. - For distributed executors, any fields that should be sent during a - broadcast op should also be added to the broadcastable_fields. During - execution, these fields will be extracted from the source copy and - broadcasted to all workers using broadcast_tensor_dict. - - Some fields may have values that cannot be broadcasted with this method - because they require some special serialization/deserialization, e.g., a - Python class like SamplingMetadata. For these fields, override - as_broadcastable_tensor_dict to return the custom serialized values and - override _get_init_kwargs to perform the custom deserialization ( - GPUModelInput for an example). - """ - - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - """ - Return fields to broadcast to all workers from driver. The value of - each field must be broadcastable using broadcast_tensor_dict (i.e. - either a tensor, or a Python primitive like int). During the broadcast, - the listed fields will be extracted from the source copy and then - passed to `new()` to create a copy on the destination(s). - """ - raise NotImplementedError() - - @classmethod - def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]: - """ - Helper method to extract all dataclass fields from the given kwargs. - Override for fields that require some custom deserialization. - """ - return kwargs - - @classmethod - def new(cls, - clone: Optional["ModelInput"] = None, - **kwargs) -> "ModelInput": - """ - Create a new instance of this class. Copy fields from `clone` if - provided. Populate the new instance with the given kwargs. 
- """ - clone_kwargs = {} - if clone is not None: - for field in dataclasses.fields(clone): - val = getattr(clone, field.name) - if val is not None: - clone_kwargs[field.name] = val - clone_kwargs = cls._get_init_kwargs(**clone_kwargs) - - kwargs = cls._get_init_kwargs(**kwargs) - return cls(**clone_kwargs, **kwargs) - - def replace(self, **kwargs) -> "ModelInput": - """ - Replace current fields with fields in kwargs. - """ - valid_kwargs = self.__class__._get_init_kwargs(**kwargs) - return dataclasses.replace(self, **valid_kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. Override for fields that require some - custom deserialization. - """ - tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} - for field in self.broadcastable_fields: - val = getattr(self, field, None) - if val is not None: - tensor_dict[field] = val - - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class CPUModelInput(ModelInput): - """ - Used by the CPUModelRunner. - """ - num_seq_groups: Optional[int] = None - blocks_to_copy: Optional[torch.Tensor] = None - - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - - attn_metadata: Optional["AttentionMetadata"] = None - sampling_metadata: Optional["SamplingMetadata"] = None - - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - return ( - "num_seq_groups", - "blocks_to_copy", - "input_tokens", - "input_positions", - "multi_modal_kwargs", - ) - - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = _init_attn_metadata_from_kwargs(**kwargs) - kwargs = _init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class GPUModelInput(ModelInput): - """ - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. 
- """ - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - - attn_metadata: Optional["AttentionMetadata"] = None - - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - return ( - "num_seq_groups", - "blocks_to_swap_in", - "blocks_to_swap_out", - "blocks_to_copy", - "input_tokens", - "input_positions", - "lora_requests", - "lora_mapping", - "multi_modal_kwargs", - ) - - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = _init_attn_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class GPUModelInputWithPoolingMetadata(GPUModelInput): - """ - Used by the EmbeddingModelRunner. - """ - pooling_metadata: Optional["PoolingMetadata"] = None - - -@dataclasses.dataclass(frozen=True) -class GPUModelInputWithSamplingMetadata(GPUModelInput): - """ - Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = _init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class ModelInputForNeuron(ModelInput): - """ - Used by the NeuronModelRunner. 
- """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - input_block_ids: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - sampling_metadata: Optional["SamplingMetadata"] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") diff --git a/vllm/sequence.py b/vllm/sequence.py index d6311a233496b..830cc0533af97 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,7 +3,7 @@ import dataclasses import enum from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -14,10 +14,6 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: - from vllm.attention import AttentionMetadata - from vllm.attention.backends.abstract import AttentionBackend - from vllm.lora.layers import LoRAMapping - from vllm.model_executor import SamplingMetadata from vllm.multimodal import MultiModalData from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -877,138 +873,3 @@ def clone( num_lookahead_slots=self.num_lookahead_slots, running_queue_size=self.running_queue_size, ) - - -@dataclasses.dataclass(frozen=True) -class ModelInput: - """Local inputs to each worker's `execute_model` function. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelInput objects. - - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. - """ - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - slot_mapping: Optional[torch.Tensor] = None - num_prefill_tokens: Optional[int] = None - num_decode_tokens: Optional[int] = None - num_prefills: Optional[int] = None - attn_metadata: Optional["AttentionMetadata"] = None - - BROADCASTABLE_FIELDS = ( - "num_seq_groups", - "blocks_to_swap_in", - "blocks_to_swap_out", - "blocks_to_copy", - "input_tokens", - "input_positions", - "lora_requests", - "lora_mapping", - "multi_modal_kwargs", - "num_prefill_tokens", - "num_decode_tokens", - "slot_mapping", - "num_prefills", - ) - - @classmethod - def _get_init_kwargs(cls, - attn_backend: Optional["AttentionBackend"] = None, - attn_metadata: Optional["AttentionMetadata"] = None, - **kwargs) -> Dict[str, Any]: - if attn_metadata is None and attn_backend is not None: - # Extract the fields used to create AttentionMetadata. 
- valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = kwargs.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - if attn_metadata is not None: - kwargs["attn_metadata"] = attn_metadata - - return kwargs - - @classmethod - def new(cls, - clone: Optional["ModelInput"] = None, - **kwargs) -> "ModelInput": - clone_kwargs = {} - if clone is not None: - for field in dataclasses.fields(clone): - val = getattr(clone, field.name) - if val is not None: - clone_kwargs[field.name] = val - clone_kwargs = cls._get_init_kwargs(**clone_kwargs) - - kwargs = cls._get_init_kwargs(**kwargs) - return cls(**clone_kwargs, **kwargs) - - def replace(self, **kwargs) -> "ModelInput": - valid_kwargs = self.__class__._get_init_kwargs(**kwargs) - return dataclasses.replace(self, **valid_kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = {} - for field in self.BROADCASTABLE_FIELDS: - val = getattr(self, field, None) - if val is not None: - tensor_dict[field] = val - - if self.attn_metadata is not None: - tensor_dict.update(self.attn_metadata.asdict_zerocopy()) - - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class ModelInputWithSamplingMetadata(ModelInput): - # Metadata for sampling outputs. - sampling_metadata: Optional["SamplingMetadata"] = None - - @classmethod - def _get_init_kwargs( # type: ignore - cls, - selected_token_indices: Optional[torch.Tensor] = None, - sampling_metadata: Optional["SamplingMetadata"] = None, - **kwargs) -> Dict[str, Any]: - if sampling_metadata is None and selected_token_indices is not None: - from vllm.model_executor import SamplingMetadata - - # An empty SamplingMetadata to signal that the worker should skip - # sampling. 
- sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - if sampling_metadata is not None: - kwargs["sampling_metadata"] = sampling_metadata - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - - if self.sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - self.sampling_metadata.selected_token_indices) - - return tensor_dict diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 734474ba83ee4..593aa6139921b 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -3,7 +3,8 @@ import torch -from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput +from vllm.worker.model_input import ModelInput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 49ddec5df0bbd..e8d708dcd189e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -7,7 +7,8 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (ExecuteModelRequest, ModelInput, SamplerOutput, +from vllm.worker.model_input import ModelInput +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 188eba395df4e..0cb5ff5ad3fff 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -11,10 +11,10 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_input import CPUModelInput from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad +from vllm.worker.model_input import CPUModelInput logger = init_logger(__name__) @@ -320,8 +320,8 @@ def execute_model( "attn_metadata": model_input.attn_metadata, } if (self.vision_language_config - and model_input.multi_modal_input is not None): - execute_model_kwargs.update(model_input.multi_modal_input) + and model_input.multi_modal_kwargs is not None): + execute_model_kwargs.update(model_input.multi_modal_kwargs) hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3a970acda702c..5048bf7182fc8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,10 +13,10 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_input import CPUModelInput from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner +from vllm.worker.model_input import 
CPUModelInput from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) @@ -286,8 +286,8 @@ def prepare_model_input_local( @torch.inference_mode() def prepare_model_input( - self, - execute_model_req: Optional[ExecuteModelRequest]) -> CPUModelInput: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[CPUModelInput]: if self.is_driver_worker: if execute_model_req is None: if self.parallel_config.tensor_parallel_size > 1: diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 53d29d8bdf1c6..d782bdce5fbe9 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -7,9 +7,9 @@ VisionLanguageConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.model_input import GPUModelInputWithPoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.worker.model_input import GPUModelInputWithPoolingMetadata from vllm.worker.model_runner import ModelRunner logger = init_logger(__name__) @@ -51,13 +51,17 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -73,8 +77,8 @@ def execute_model( "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update( - {"image_input": model_input.multi_modal_input}) + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + execute_model_kwargs.update({"image_input": multi_modal_kwargs}) hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. @@ -92,6 +96,7 @@ def prepare_model_input_tensors( model_input = self._prepare_model_input_tensors( seq_group_metadata_list) # Prepare PoolingMetadata. 
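# (Sketch of the intended pattern, not shown in this hunk: with the frozen
# dataclass design, the pooling metadata computed below is expected to be
# attached by cloning the base input rather than by mutating it in place,
# e.g. something like
#     return GPUModelInputWithPoolingMetadata.new(
#         clone=model_input, pooling_metadata=pooling_metadata)
# mirroring how ModelRunner builds GPUModelInputWithSamplingMetadata.)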
+ assert model_input.seq_lens is not None pooling_metadata = self._prepare_pooling(seq_group_metadata_list, model_input.seq_lens) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1bdd050028733..1c22041bcb19f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,13 +20,13 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_input import GPUModelInput from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sampling_params import SamplingParams -from vllm.sequence import (ModelInputWithSamplingMetadata, SamplerOutput, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) +from vllm.worker.model_input import (GPUModelInput, + GPUModelInputWithSamplingMetadata) logger = init_logger(__name__) @@ -627,7 +627,7 @@ def _prepare_model_input_tensors( def prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> ModelInputWithSamplingMetadata: + ) -> GPUModelInput: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -648,37 +648,42 @@ def prepare_model_input_tensors( model_input.query_lens, self.device, self.pin_memory) - return ModelInputWithSamplingMetadata.new( + return GPUModelInputWithSamplingMetadata.new( clone=model_input, sampling_metadata=sampling_metadata) - def get_empty_model_input(self) -> ModelInputWithSamplingMetadata: - return ModelInputWithSamplingMetadata.new() + def get_empty_model_input(self) -> GPUModelInput: + return GPUModelInputWithSamplingMetadata.new() @torch.inference_mode() def execute_model( self, - model_input: ModelInputWithSamplingMetadata, + model_input: GPUModelInputWithSamplingMetadata, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model + multi_modal_kwargs = model_input.multi_modal_kwargs or {} hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, - **model_input.multi_modal_kwargs, + **multi_modal_kwargs, ) # Compute the logits. 
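The broadcastable-fields contract used by these runners is easiest to see as a small end-to-end sketch. The following is illustrative only and is not part of any patch in this series; it assumes the GPUModelInputWithSamplingMetadata class and the broadcast_tensor_dict helper defined elsewhere in the series, and shows roughly how a tensor-parallel driver and a follower rank could exchange a prepared model input.

from vllm.distributed import broadcast_tensor_dict
from vllm.worker.model_input import GPUModelInputWithSamplingMetadata


def driver_send(model_input: GPUModelInputWithSamplingMetadata) -> None:
    # Only the fields named in broadcastable_fields are sent, plus the
    # flattened AttentionMetadata / SamplingMetadata entries added by
    # as_broadcastable_tensor_dict().
    broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(), src=0)


def follower_receive(attn_backend) -> GPUModelInputWithSamplingMetadata:
    # A follower rank rebuilds a frozen local copy with new(); passing the
    # attention backend lets _get_init_kwargs reconstruct AttentionMetadata
    # (and SamplingMetadata, via selected_token_indices) from the received
    # scalar and tensor entries.
    tensor_dict = broadcast_tensor_dict(src=0)
    return GPUModelInputWithSamplingMetadata.new(attn_backend=attn_backend,
                                                 **tensor_dict)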
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 27b2d01e5e814..61c0a2b3b1ee7 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -8,9 +8,9 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.model_input import ModelInputForNeuron from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.worker.model_input import ModelInputForNeuron logger = init_logger(__name__) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 0880ba8df36a6..3017dfe76ab8e 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -8,8 +8,8 @@ ParallelConfig, SchedulerConfig) from vllm.distributed import disable_communication from vllm.model_executor import set_random_seed -from vllm.model_input import ModelInputForNeuron from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.model_input import ModelInputForNeuron from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -82,7 +82,8 @@ def prepare_model_input_local( execute_model_req: ExecuteModelRequest) -> ModelInputForNeuron: model_input = self.model_runner.prepare_model_input_tensors( execute_model_req.seq_group_metadata_list) - return model_input + return model_input.replace(num_seq_groups=len( + execute_model_req.seq_group_metadata_list), ) def prepare_model_input( self, execute_model_req: Optional[ExecuteModelRequest] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b0d53ccc9ddf4..d96e85035e280 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -16,10 +16,10 @@ from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_input import GPUModelInput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.model_input import GPUModelInput from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase @@ -264,7 +264,7 @@ def prepare_model_input_local( def prepare_model_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> GPUModelInput: + ) -> Optional[GPUModelInput]: if self.is_driver_worker: if execute_model_req is None: if self.parallel_config.tensor_parallel_size > 1: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6cfaa3847ab78..37306debf14a3 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -8,7 +8,8 @@ from vllm.distributed import disable_communication from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, ModelInput, SamplerOutput +from vllm.worker.model_input import ModelInput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) @@ -75,8 +76,8 @@ def prepare_model_input_local( @abstractmethod def prepare_model_input( - self, - execute_model_req: Optional[ExecuteModelRequest]) -> ModelInput: + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> 
Optional[ModelInput]: """ Prepare a model execution request. Communication with other workers may occur to produce the model input that should be passed to From 456185d526a545a3e9f731b5d864c700a3e9f18e Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 13 Jun 2024 19:26:18 -0700 Subject: [PATCH 25/55] rm Signed-off-by: Stephanie Wang --- vllm/sequence.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 830cc0533af97..bd36083663526 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -163,8 +163,6 @@ def get_num_computed_tokens(self) -> int: def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" - # TODO: Check who calls this and make sure it's synchronized across - # driver and workers. self._num_computed_tokens += num_new_computed_tokens assert self._num_computed_tokens <= self.get_len(), ( self._num_computed_tokens, self.get_len()) From e8606523d981713f3a1ebbb8e02bc4e7006dec30 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 13 Jun 2024 19:29:06 -0700 Subject: [PATCH 26/55] lint Signed-off-by: Stephanie Wang --- vllm/spec_decode/ngram_worker.py | 2 +- vllm/spec_decode/spec_decode_worker.py | 2 +- vllm/worker/worker_base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 593aa6139921b..a0968f2f658bc 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -3,11 +3,11 @@ import torch -from vllm.worker.model_input import ModelInput from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.model_input import ModelInput from vllm.worker.worker_base import LoraNotSupportedWorkerBase diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e8d708dcd189e..4facd50a0c8cc 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -7,7 +7,6 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.worker.model_input import ModelInput from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer @@ -21,6 +20,7 @@ get_all_num_logprobs, get_all_seq_ids, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) +from vllm.worker.model_input import ModelInput from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 37306debf14a3..568cf1b395a32 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -8,10 +8,10 @@ from vllm.distributed import disable_communication from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.worker.model_input import ModelInput from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) +from vllm.worker.model_input import ModelInput logger = init_logger(__name__) From 3d4f242c44cf0a4ca9eaec657cc0308ff9a631cd Mon Sep 17 00:00:00 2001 
From: Stephanie Wang Date: Thu, 13 Jun 2024 22:19:17 -0700 Subject: [PATCH 27/55] add missing Signed-off-by: Stephanie Wang --- vllm/worker/model_input.py | 288 +++++++++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) create mode 100644 vllm/worker/model_input.py diff --git a/vllm/worker/model_input.py b/vllm/worker/model_input.py new file mode 100644 index 0000000000000..b85cf100513c3 --- /dev/null +++ b/vllm/worker/model_input.py @@ -0,0 +1,288 @@ +"""Worker-local model inputs. These define the inputs to different model +runners.""" +import dataclasses +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) + +import torch + +from vllm.lora.request import LoRARequest + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend + from vllm.lora.layers import LoRAMapping + from vllm.model_executor import SamplingMetadata + from vllm.model_executor.pooling_metadata import PoolingMetadata + + +def _init_attn_metadata_from_kwargs( + attn_backend: Optional["AttentionBackend"] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + **kwargs) -> Dict[str, Any]: + if attn_metadata is None and attn_backend is not None: + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = kwargs.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + if attn_metadata is not None: + kwargs["attn_metadata"] = attn_metadata + return kwargs + + +def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + attn_metadata: Optional["AttentionMetadata"]) -> None: + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + +def _init_sampling_metadata_from_kwargs( # type: ignore + selected_token_indices: Optional[torch.Tensor] = None, + sampling_metadata: Optional["SamplingMetadata"] = None, + **kwargs) -> Dict[str, Any]: + if sampling_metadata is None and selected_token_indices is not None: + from vllm.model_executor import SamplingMetadata + + # An empty SamplingMetadata to signal that the worker should skip + # sampling. + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + if sampling_metadata is not None: + kwargs["sampling_metadata"] = sampling_metadata + return kwargs + + +def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) + + +T = TypeVar('T', bound="ModelInput") + + +@dataclasses.dataclass(frozen=True) +class ModelInput: + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelInput objects. + + Model runners should inherit from this class and add their required fields. + For distributed executors, any fields that should be sent during a + broadcast op should also be added to the broadcastable_fields. 
During + execution, these fields will be extracted from the source copy and + broadcasted to all workers using broadcast_tensor_dict. + + Some fields may have values that cannot be broadcasted with this method + because they require some special serialization/deserialization, e.g., a + Python class like SamplingMetadata. For these fields, override + as_broadcastable_tensor_dict to return the custom serialized values and + override _get_init_kwargs to perform the custom deserialization ( + GPUModelInput for an example). + """ + + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + """ + Return fields to broadcast to all workers from driver. The value of + each field must be broadcastable using broadcast_tensor_dict (i.e. + either a tensor, or a Python primitive like int). During the broadcast, + the listed fields will be extracted from the source copy and then + passed to `new()` to create a copy on the destination(s). + """ + raise NotImplementedError() + + @classmethod + def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]: + """ + Helper method to extract all dataclass fields from the given kwargs. + Override for fields that require some custom deserialization. + """ + return kwargs + + @classmethod + def new(cls: Type[T], clone: Optional["ModelInput"] = None, **kwargs) -> T: + """ + Create a new instance of this class. Copy fields from `clone` if + provided. Populate the new instance with the given kwargs. + """ + clone_kwargs = {} + if clone is not None: + for field in dataclasses.fields(clone): + val = getattr(clone, field.name) + if val is not None: + clone_kwargs[field.name] = val + clone_kwargs = cls._get_init_kwargs(**clone_kwargs) + + kwargs = cls._get_init_kwargs(**kwargs) + return cls(**clone_kwargs, **kwargs) + + def replace(self: T, **kwargs) -> T: + """ + Replace current fields with fields in kwargs. + """ + valid_kwargs = self.__class__._get_init_kwargs(**kwargs) + return dataclasses.replace(self, **valid_kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} + for field in self.broadcastable_fields: + val = getattr(self, field, None) + if val is not None: + tensor_dict[field] = val + + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class CPUModelInput(ModelInput): + """ + Used by the CPUModelRunner. 
+ """ + num_seq_groups: Optional[int] = None + blocks_to_copy: Optional[torch.Tensor] = None + + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + attn_metadata: Optional["AttentionMetadata"] = None + sampling_metadata: Optional["SamplingMetadata"] = None + + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + return ( + "num_seq_groups", + "blocks_to_copy", + "input_tokens", + "input_positions", + "multi_modal_kwargs", + ) + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = _init_attn_metadata_from_kwargs(**kwargs) + kwargs = _init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class GPUModelInput(ModelInput): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + attn_metadata: Optional["AttentionMetadata"] = None + + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + return ( + "num_seq_groups", + "blocks_to_swap_in", + "blocks_to_swap_out", + "blocks_to_copy", + "input_tokens", + "input_positions", + "lora_requests", + "lora_mapping", + "multi_modal_kwargs", + ) + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = _init_attn_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class GPUModelInputWithPoolingMetadata(GPUModelInput): + """ + Used by the EmbeddingModelRunner. + """ + pooling_metadata: Optional["PoolingMetadata"] = None + + +@dataclasses.dataclass(frozen=True) +class GPUModelInputWithSamplingMetadata(GPUModelInput): + """ + Used by the ModelRunner. 
+ """ + sampling_metadata: Optional["SamplingMetadata"] = None + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = _init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + +@dataclasses.dataclass(frozen=True) +class ModelInputForNeuron(ModelInput): + """ + Used by the NeuronModelRunner. + """ + num_seq_groups: Optional[int] = None + + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + input_block_ids: Optional[torch.Tensor] = None + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") From 11304cb40ff59c1f9c221cdd659c4d06b7c56083 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 14 Jun 2024 14:17:33 -0700 Subject: [PATCH 28/55] revert Signed-off-by: Stephanie Wang --- vllm/sequence.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index bd36083663526..88dcd6c442578 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,8 +1,8 @@ """Sequence and its related classes.""" import copy -import dataclasses import enum from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -18,7 +18,7 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -@dataclasses.dataclass +@dataclass class Logprob: """Infos for supporting OpenAI compatible logprobs and token ranks. @@ -81,7 +81,7 @@ class SequenceStage(enum.Enum): DECODE = enum.auto() -@dataclasses.dataclass +@dataclass class RequestMetrics: """Metrics associated with a request. @@ -391,7 +391,7 @@ def __repr__(self) -> str: f"num_blocks={len(self.logical_token_blocks)})") -@dataclasses.dataclass +@dataclass class SequenceGroupState: """Mutable state tied to a specific sequence group""" @@ -768,7 +768,7 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings -@dataclasses.dataclass +@dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. @@ -818,7 +818,7 @@ def __repr__(self) -> str: f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") -@dataclasses.dataclass +@dataclass class PoolerOutput: """The output from a pooling operation in the embedding model.""" outputs: List[EmbeddingSequenceGroupOutput] @@ -839,21 +839,18 @@ def __eq__(self, other: object): self.__class__) and self.outputs == other.outputs -@dataclasses.dataclass +@dataclass class ExecuteModelRequest: """The model execution request, containing CPU metadata only. The LLM engine should create an instance of this class for each request batch.""" # The sequence group metadata list. seq_group_metadata_list: List[SequenceGroupMetadata] # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: List[Tuple[int, int]] = dataclasses.field( - default_factory=list) + blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) # Blocks to swap out. List of GPU -> CPU block number. 
- blocks_to_swap_out: List[Tuple[int, int]] = dataclasses.field( - default_factory=list) + blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) # Blocks to copy. Source to dest block. - blocks_to_copy: List[Tuple[int, - int]] = dataclasses.field(default_factory=list) + blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) # The number of slots for lookahead decoding. num_lookahead_slots: int = 0 # The number of requests in the running queue. From 99f532efb23045c7d4a36da242ce9b10f68bacaa Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Sat, 15 Jun 2024 15:13:11 -0700 Subject: [PATCH 29/55] refactor Signed-off-by: Stephanie Wang --- tests/worker/test_model_runner.py | 3 +- vllm/distributed/communication_op.py | 44 ------- vllm/worker/cpu_model_runner.py | 14 +- vllm/worker/cpu_worker.py | 105 ++++++--------- vllm/worker/embedding_model_runner.py | 21 ++- vllm/worker/model_input.py | 41 ++---- vllm/worker/model_runner.py | 176 ++++++++++++++------------ vllm/worker/model_runner_base.py | 45 +++++++ vllm/worker/neuron_model_runner.py | 12 +- vllm/worker/neuron_worker.py | 61 ++++----- vllm/worker/worker.py | 168 ++++++++---------------- vllm/worker/worker_base.py | 159 ++++++++++++++++------- vllm/worker/worker_input.py | 62 +++++++++ 13 files changed, 470 insertions(+), 441 deletions(-) create mode 100644 vllm/worker/model_runner_base.py create mode 100644 vllm/worker/worker_input.py diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index d3dcdda4946d8..9c1952851f8e1 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -351,8 +351,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): seq_group_metadata_list.append(seq_group_metadata) decode_metadata_list.append(seq_group_metadata) - model_input = model_runner.prepare_model_input_tensors( - seq_group_metadata_list) + model_input = model_runner.prepare_model_input(seq_group_metadata_list) (input_tokens, input_positions, attn_metadata) = ( model_input.input_tokens, model_input.input_positions, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 0fbe7e1621905..32394a07b00b9 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any, Dict, Optional, Union import torch @@ -7,43 +6,6 @@ from .parallel_state import get_tp_group -@dataclass -class DistributedContext: - communication_allowed: bool = True - - @staticmethod - def get_current() -> "DistributedContext": - """ - Get the singleton context. - """ - global _default_context - return _default_context - - -_default_context: DistributedContext = DistributedContext() - - -def disable_communication(fn): - """ - Helper decorator to disable control plane communication, i.e. - calling broadcast_tensor_dict will throw a RuntimeError. This can be used - to ensure that decorated code stays worker-local. - """ - - def wrapper(*args, **kwargs): - # Disallow control plane communication. 
- comm_ctx = DistributedContext.get_current() - original_comm_allowed = comm_ctx.communication_allowed - comm_ctx.communication_allowed = False - - try: - return fn(*args, **kwargs) - finally: - comm_ctx.communication_allowed = original_comm_allowed - - return wrapper - - def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" return get_tp_group().all_reduce(input_) @@ -65,12 +27,6 @@ def tensor_model_parallel_gather(input_: torch.Tensor, def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0): - ctx = DistributedContext.get_current() - if not ctx.communication_allowed: - raise RuntimeError( - "Control plane communication not allowed in functions decorated " - "with @disable_communication") - if not torch.distributed.is_initialized(): return tensor_dict return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0cb5ff5ad3fff..b6f2eca7e5909 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type import torch from torch import nn @@ -15,13 +15,14 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.model_input import CPUModelInput +from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) _PAD_SLOT_ID = -1 -class CPUModelRunner: +class CPUModelRunner(ModelRunnerBase[CPUModelInput]): def __init__( self, @@ -270,7 +271,11 @@ def _prepare_decode( attn_metadata, ) - def prepare_model_input_tensors( + @staticmethod + def model_input_cls() -> Type[CPUModelInput]: + return CPUModelInput + + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> CPUModelInput: @@ -303,9 +308,6 @@ def prepare_model_input_tensors( sampling_metadata=sampling_metadata, ) - def get_empty_model_input(self) -> CPUModelInput: - return CPUModelInput.new() - @torch.inference_mode() def execute_model( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 5048bf7182fc8..4bbfc548604cb 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -8,16 +8,17 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, disable_communication, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner -from vllm.worker.model_input import CPUModelInput -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase) +from vllm.worker.worker_input import WorkerInput logger = init_logger(__name__) @@ -111,7 +112,7 @@ def get_cache_block_size( return dtype_size * total -class 
CPUWorker(LoraNotSupportedWorkerBase): +class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a CPU socket. Each worker is associated with a single CPU socket. The worker is @@ -147,15 +148,15 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.vision_language_config = vision_language_config - self.is_driver_worker = is_driver_worker - if self.is_driver_worker: + self._is_driver_worker = is_driver_worker + if self._is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = CPUModelRunner( + self._model_runner = CPUModelRunner( model_config, parallel_config, scheduler_config, @@ -177,7 +178,7 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - self.model_runner.load_model() + self._model_runner.load_model() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of blocks available for the KV cache. @@ -248,7 +249,7 @@ def _init_cache_engine(self) -> None: self.parallel_config, self.device_config) self.cpu_cache = self.cache_engine.cpu_cache - self.model_runner.block_size = self.cache_engine.block_size + self._model_runner.block_size = self.cache_engine.block_size assert self.cpu_cache is not None @@ -256,22 +257,34 @@ def _init_cache_engine(self) -> None: for layer_cache in self.cpu_cache: layer_cache.fill_(0) - def cache_copy( + @property + def is_driver_worker(self) -> bool: + return self._is_driver_worker + + @property + def do_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def model_runner(self) -> ModelRunnerBase: + return self._model_runner + + @property + def kv_cache(self) -> Optional[List[torch.Tensor]]: + return self.cpu_cache + + def execute_worker( self, - blocks_to_copy: torch.Tensor, + worker_input: WorkerInput, ) -> None: - if blocks_to_copy.numel() > 0: - self.cache_engine.copy(blocks_to_copy) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine.copy(worker_input.blocks_to_copy) @torch.inference_mode() - @disable_communication - def prepare_model_input_local( - self, execute_model_req: ExecuteModelRequest) -> CPUModelInput: + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: assert execute_model_req is not None - - model_input = self.model_runner.prepare_model_input_tensors( - execute_model_req.seq_group_metadata_list) - num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, @@ -279,57 +292,11 @@ def prepare_model_input_local( dtype=torch.int64).view(-1, 2) assert len(execute_model_req.blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0 - return model_input.replace( + return WorkerInput( num_seq_groups=num_seq_groups, blocks_to_copy=blocks_to_copy, ) - @torch.inference_mode() - def prepare_model_input( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[CPUModelInput]: - if self.is_driver_worker: - if execute_model_req is None: - if self.parallel_config.tensor_parallel_size > 1: - # This signals that there's no more requests to 
process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - - model_input = self.prepare_model_input_local(execute_model_req) - if self.parallel_config.tensor_parallel_size > 1: - metadata_dict = model_input.as_broadcastable_tensor_dict() - broadcast_tensor_dict(metadata_dict, src=0) - else: - metadata_dict = broadcast_tensor_dict(src=0) - if not metadata_dict: - return None - - model_input = self.model_runner.get_empty_model_input() - model_input = model_input.new( - attn_backend=self.model_runner.attn_backend, **metadata_dict) - return model_input - - @torch.inference_mode() - @disable_communication - def execute_model_local( - self, - model_input: CPUModelInput, - ) -> List[SamplerOutput]: - self.cache_copy(model_input.blocks_to_copy) - - # If there is no input, we don't need to execute the model. - if model_input.num_seq_groups == 0: - return [] - - output = self.model_runner.execute_model(model_input, self.cpu_cache) - - # CPU worker only supports single-step execution. - return [output] - def init_distributed_environment(self) -> None: """Initialize the distributed environment.""" diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d782bdce5fbe9..c777d32bcfe5a 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type import torch @@ -10,12 +10,13 @@ from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.model_input import GPUModelInputWithPoolingMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase logger = init_logger(__name__) -class EmbeddingModelRunner(ModelRunner): +class EmbeddingModelRunner(GPUModelRunnerBase[GPUModelInputWithPoolingMetadata] + ): def __init__( self, @@ -41,9 +42,6 @@ def __init__( is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) - def get_empty_model_input(self) -> GPUModelInputWithPoolingMetadata: - return GPUModelInputWithPoolingMetadata.new() - @torch.inference_mode() def execute_model( self, @@ -88,7 +86,11 @@ def execute_model( return self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) - def prepare_model_input_tensors( + @staticmethod + def model_input_cls() -> Type[GPUModelInputWithPoolingMetadata]: + return GPUModelInputWithPoolingMetadata + + def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> GPUModelInputWithPoolingMetadata: @@ -100,10 +102,7 @@ def prepare_model_input_tensors( pooling_metadata = self._prepare_pooling(seq_group_metadata_list, model_input.seq_lens) - return GPUModelInputWithPoolingMetadata.new( - clone=model_input, - pooling_metadata=pooling_metadata, - ) + return model_input.replace(pooling_metadata=pooling_metadata) def _prepare_pooling( self, diff --git a/vllm/worker/model_input.py b/vllm/worker/model_input.py index b85cf100513c3..8abb5a72eaf14 100644 --- a/vllm/worker/model_input.py +++ b/vllm/worker/model_input.py @@ -105,29 +105,26 @@ def broadcastable_fields(self) -> Tuple[str, ...]: raise NotImplementedError() @classmethod - def 
_get_init_kwargs(cls, **kwargs) -> Dict[str, Any]: + def _get_init_kwargs(cls: Type[T], **kwargs) -> Dict[str, Any]: """ Helper method to extract all dataclass fields from the given kwargs. Override for fields that require some custom deserialization. """ - return kwargs + init_kwargs = {} + for field in dataclasses.fields(cls): + val = kwargs.get(field.name, None) + if val is not None: + init_kwargs[field.name] = val + return init_kwargs @classmethod - def new(cls: Type[T], clone: Optional["ModelInput"] = None, **kwargs) -> T: + def new(cls: Type[T], **kwargs) -> T: """ - Create a new instance of this class. Copy fields from `clone` if - provided. Populate the new instance with the given kwargs. + Create a new instance of this class. Populate the new instance with + the given kwargs. """ - clone_kwargs = {} - if clone is not None: - for field in dataclasses.fields(clone): - val = getattr(clone, field.name) - if val is not None: - clone_kwargs[field.name] = val - clone_kwargs = cls._get_init_kwargs(**clone_kwargs) - kwargs = cls._get_init_kwargs(**kwargs) - return cls(**clone_kwargs, **kwargs) + return cls(**kwargs) def replace(self: T, **kwargs) -> T: """ @@ -156,9 +153,6 @@ class CPUModelInput(ModelInput): """ Used by the CPUModelRunner. """ - num_seq_groups: Optional[int] = None - blocks_to_copy: Optional[torch.Tensor] = None - input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None @@ -169,8 +163,6 @@ class CPUModelInput(ModelInput): @property def broadcastable_fields(self) -> Tuple[str, ...]: return ( - "num_seq_groups", - "blocks_to_copy", "input_tokens", "input_positions", "multi_modal_kwargs", @@ -200,11 +192,6 @@ class GPUModelInput(ModelInput): runners that run additional steps should subclass this method to add additional fields. """ - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None @@ -218,10 +205,6 @@ class GPUModelInput(ModelInput): @property def broadcastable_fields(self) -> Tuple[str, ...]: return ( - "num_seq_groups", - "blocks_to_swap_in", - "blocks_to_swap_out", - "blocks_to_copy", "input_tokens", "input_positions", "lora_requests", @@ -276,8 +259,6 @@ class ModelInputForNeuron(ModelInput): """ Used by the NeuronModelRunner. 
""" - num_seq_groups: Optional[int] = None - input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None input_block_ids: Optional[torch.Tensor] = None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1c22041bcb19f..fd1941456bf95 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -2,7 +2,7 @@ import time import warnings from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple, Type, TypeVar, Union import numpy as np import torch @@ -27,6 +27,7 @@ is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_input import (GPUModelInput, GPUModelInputWithSamplingMetadata) +from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) @@ -40,8 +41,10 @@ ] _NUM_WARMUP_ITERS = 2 +TGPUModelInput = TypeVar('TGPUModelInput', bound="GPUModelInput") -class ModelRunner: + +class GPUModelRunnerBase(ModelRunnerBase[TGPUModelInput]): def __init__( self, @@ -209,7 +212,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> GPUModelInput: + ) -> TGPUModelInput: """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. @@ -263,7 +266,8 @@ def _prepare_model_input_tensors( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return GPUModelInput() + model_input_cls = self.model_input_cls() + return model_input_cls() if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window + self.block_size - @@ -613,7 +617,8 @@ def _prepare_model_input_tensors( for k, v in multi_modal_kwargs_list.items() } - return GPUModelInput.new( + model_input_cls = self.model_input_cls() + return model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, attn_metadata=attn_metadata, @@ -624,84 +629,6 @@ def _prepare_model_input_tensors( multi_modal_kwargs=multi_modal_kwargs, ) - def prepare_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> GPUModelInput: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - - If cuda graph is required, this API automatically pads inputs. 
- """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - self.pin_memory) - return GPUModelInputWithSamplingMetadata.new( - clone=model_input, sampling_metadata=sampling_metadata) - - def get_empty_model_input(self) -> GPUModelInput: - return GPUModelInputWithSamplingMetadata.new() - - @torch.inference_mode() - def execute_model( - self, - model_input: GPUModelInputWithSamplingMetadata, - kv_caches: List[torch.Tensor], - ) -> Optional[SamplerOutput]: - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[graph_batch_size] - else: - model_executable = self.model - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - **multi_modal_kwargs, - ) - - # Compute the logits. - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return None - - # Sample the next token. - output = self.model.sample( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - - return output - @torch.inference_mode() def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. @@ -774,7 +701,7 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - model_input = self.prepare_model_input_tensors(seqs) + model_input = self.prepare_model_input(seqs) self.execute_model(model_input, kv_caches) torch.cuda.synchronize() return @@ -903,6 +830,87 @@ def vocab_size(self) -> int: return self.model_config.get_vocab_size() +class ModelRunner(GPUModelRunnerBase[GPUModelInputWithSamplingMetadata]): + + @staticmethod + def model_input_cls() -> Type[GPUModelInputWithSamplingMetadata]: + return GPUModelInputWithSamplingMetadata + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> GPUModelInputWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. 
+ """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory) + return model_input.replace(sampling_metadata=sampling_metadata) + + @torch.inference_mode() + def execute_model( + self, + model_input: GPUModelInputWithSamplingMetadata, + kv_caches: List[torch.Tensor], + ) -> SamplerOutput: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + **multi_modal_kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + return output + + class CUDAGraphRunner: def __init__(self, model: nn.Module): diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py new file mode 100644 index 0000000000000..e618d40a53a7f --- /dev/null +++ b/vllm/worker/model_runner_base.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod +from typing import Generic, List, Optional, Type, TypeVar + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.model_input import ModelInput + +T = TypeVar('T', bound="ModelInput") + + +class ModelRunnerBase(ABC, Generic[T]): + """ + Model runner interface that abstracts a particular hardware and/or type of + model. Model execution may communicate data with model runners in other + processes, but it should not include control plane metadata communication. + """ + + @staticmethod + @abstractmethod + def model_input_cls() -> Type[T]: + raise NotImplementedError + + @abstractmethod + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> T: + """ + Prepare the inputs to ModelRunnerBase.execute_model from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @torch.inference_mode() + def execute_model( + self, + model_input: T, + kv_caches: Optional[List[torch.Tensor]], + ) -> Optional[SamplerOutput]: + """ + Execute the model on the given input. 
+ """ + raise NotImplementedError diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 61c0a2b3b1ee7..55faf54c9358a 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type import torch from torch import nn @@ -11,11 +11,12 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_input import ModelInputForNeuron +from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) -class NeuronModelRunner: +class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): def __init__( self, @@ -140,7 +141,11 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - def prepare_model_input_tensors( + @staticmethod + def model_input_cls() -> Type[ModelInputForNeuron]: + return ModelInputForNeuron + + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> ModelInputForNeuron: @@ -174,6 +179,7 @@ def prepare_model_input_tensors( def execute_model( self, model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, ) -> Optional[SamplerOutput]: hidden_states = self.model( input_ids=model_input.input_tokens, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 3017dfe76ab8e..d7d5f9ac58770 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -6,15 +6,16 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.distributed import disable_communication from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.worker.model_input import ModelInputForNeuron +from vllm.sequence import ExecuteModelRequest +from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase) +from vllm.worker.worker_input import WorkerInput -class NeuronWorker(LoraNotSupportedWorkerBase): +class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ @@ -36,15 +37,15 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = NeuronModelRunner(model_config, parallel_config, - scheduler_config, device_config) + self._model_runner = NeuronModelRunner(model_config, parallel_config, + scheduler_config, device_config) def init_device(self) -> None: # Set random seed. set_random_seed(self.model_config.seed) def load_model(self): - self.model_runner.load_model() + self._model_runner.load_model() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. 
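The new ModelRunnerBase interface above (vllm/worker/model_runner_base.py) pins down the worker-local contract that NeuronModelRunner and the other runners now implement: model_input_cls() names the ModelInput subclass the runner produces, prepare_model_input() may move data to the local device but must not communicate with other workers, and execute_model() only consumes the prepared input plus the kv caches. A minimal self-contained sketch of that shape, with hypothetical names (ToyModelInput, ToyModelRunner, the "tokens" key) and only a torch dependency, could look like the following; it illustrates the contract and is not code from this series.

import dataclasses
from typing import Dict, List, Optional, Type, Union

import torch


@dataclasses.dataclass(frozen=True)
class ToyModelInput:
    """Worker-local tensors for one batch (stand-in for a ModelInput subclass)."""
    input_tokens: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        # Only non-None fields are shipped over the control plane.
        return ({} if self.input_tokens is None else
                {"input_tokens": self.input_tokens})


class ToyModelRunner:
    """Stand-in for a ModelRunnerBase subclass: prepare_model_input builds
    device-local tensors without talking to other workers; execute_model only
    consumes the prepared input and the kv caches."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    @staticmethod
    def model_input_cls() -> Type[ToyModelInput]:
        return ToyModelInput

    def prepare_model_input(
            self, seq_group_metadata_list: List[dict]) -> ToyModelInput:
        # Flatten the (hypothetical) per-sequence token lists into one tensor,
        # analogous to the worker-local tensor preparation the real runners do.
        tokens = [t for meta in seq_group_metadata_list
                  for t in meta["tokens"]]
        return ToyModelInput(
            input_tokens=torch.tensor(tokens, device=self.device))

    @torch.inference_mode()
    def execute_model(self, model_input: ToyModelInput,
                      kv_caches: Optional[List[torch.Tensor]]) -> torch.Tensor:
        # Stand-in for the real forward pass.
        return model_input.input_tokens + 1


runner = ToyModelRunner()
model_input = runner.prepare_model_input([{"tokens": [1, 2]}, {"tokens": [3]}])
print(runner.execute_model(model_input, kv_caches=None))  # tensor([2, 3, 4])
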
@@ -75,35 +76,27 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - @torch.inference_mode() - @disable_communication - def prepare_model_input_local( - self, - execute_model_req: ExecuteModelRequest) -> ModelInputForNeuron: - model_input = self.model_runner.prepare_model_input_tensors( - execute_model_req.seq_group_metadata_list) - return model_input.replace(num_seq_groups=len( - execute_model_req.seq_group_metadata_list), ) + @property + def is_driver_worker(self) -> bool: + return True + + @property + def do_broadcast(self) -> bool: + return False - def prepare_model_input( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> ModelInputForNeuron: - assert execute_model_req is not None - return self.prepare_model_input_local(execute_model_req) + @property + def model_runner(self) -> ModelRunnerBase: + return self._model_runner + + @property + def kv_cache(self) -> Optional[List[torch.Tensor]]: + return None @torch.inference_mode() - @disable_communication - def execute_model_local( - self, model_input: ModelInputForNeuron) -> List[SamplerOutput]: - # If there is no input, we don't need to execute the model. - if model_input.num_seq_groups == 0: - return [] - - output = self.model_runner.execute_model(model_input) - - # Neuron worker only supports single-step output. Wrap the output in a - # list to conform to interface. - return [output] + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + return WorkerInput(num_seq_groups=len( + execute_model_req.seq_group_metadata_list), ) def get_cache_block_size_bytes(self) -> int: """Determine the size in bytes of a cache block. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d96e85035e280..fd246261ed0b1 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import List, Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, Type import torch import torch.distributed @@ -9,22 +9,22 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) -from vllm.distributed import (broadcast_tensor_dict, disable_communication, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner -from vllm.worker.model_input import GPUModelInput -from vllm.worker.model_runner import ModelRunner -from vllm.worker.worker_base import WorkerBase +from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.worker_base import LocalOrDistributedWorkerBase +from vllm.worker.worker_input import WorkerInput -class Worker(WorkerBase): +class Worker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a GPU. Each worker is associated with a single GPU. 
The worker is responsible for @@ -58,8 +58,8 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.load_config = load_config - self.is_driver_worker = is_driver_worker - if self.is_driver_worker: + self._is_driver_worker = is_driver_worker + if self._is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." if self.model_config.trust_remote_code: @@ -71,9 +71,10 @@ def __init__( assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") - ModelRunnerClass = (EmbeddingModelRunner if - self.model_config.embedding_mode else ModelRunner) - self.model_runner = ModelRunnerClass( + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + self._model_runner: GPUModelRunnerBase = ModelRunnerClass( model_config, parallel_config, scheduler_config, @@ -120,7 +121,7 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - self.model_runner.load_model() + self._model_runner.load_model() def save_sharded_state( self, @@ -128,7 +129,7 @@ def save_sharded_state( pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - self.model_runner.save_sharded_state( + self._model_runner.save_sharded_state( path, pattern=pattern, max_size=max_size, @@ -138,7 +139,7 @@ def save_tensorized_model( self, tensorizer_config: TensorizerConfig, ) -> None: - self.model_runner.save_tensorized_model( + self._model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) @torch.inference_mode() @@ -160,7 +161,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self.model_runner.profile_run() + self._model_runner.profile_run() # Calculate the number of blocks that can be allocated with the # profiled peak memory. @@ -181,8 +182,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() + if self._model_runner.lora_manager: + self._model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks @@ -211,32 +212,30 @@ def _init_cache_engine(self): def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) + self._model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. set_random_seed(self.model_config.seed) - def cache_swap( - self, - blocks_to_swap_in: Optional[torch.Tensor], - blocks_to_swap_out: Optional[torch.Tensor], - blocks_to_copy: Optional[torch.Tensor], - ) -> None: - # Issue cache operations. 
- if blocks_to_swap_in is not None and blocks_to_swap_in.numel() > 0: - self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out is not None and blocks_to_swap_out.numel() > 0: - self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy is not None and blocks_to_copy.numel() > 0: - self.cache_engine.copy(blocks_to_copy) + @property + def is_driver_worker(self) -> bool: + return self._is_driver_worker - @torch.inference_mode() - @disable_communication - def prepare_model_input_local( - self, execute_model_req: ExecuteModelRequest) -> GPUModelInput: - model_input = self.model_runner.prepare_model_input_tensors( - execute_model_req.seq_group_metadata_list) + @property + def do_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + @property + def model_runner(self) -> ModelRunnerBase: + return self._model_runner + + @property + def kv_cache(self) -> Optional[List[torch.Tensor]]: + return self.gpu_cache + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: num_seq_groups = len(execute_model_req.seq_group_metadata_list) # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. @@ -253,7 +252,7 @@ def prepare_model_input_local( device=self.device, dtype=torch.int64).view(-1, 2) - return model_input.replace( + return WorkerInput( num_seq_groups=num_seq_groups, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, @@ -261,85 +260,26 @@ def prepare_model_input_local( ) @torch.inference_mode() - def prepare_model_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[GPUModelInput]: - if self.is_driver_worker: - if execute_model_req is None: - if self.parallel_config.tensor_parallel_size > 1: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - - model_input = self.prepare_model_input_local(execute_model_req) - if self.parallel_config.tensor_parallel_size > 1: - metadata_dict = model_input.as_broadcastable_tensor_dict() - broadcast_tensor_dict(metadata_dict, src=0) - else: - metadata_dict = broadcast_tensor_dict(src=0) - if not metadata_dict: - return None - - model_input = self.model_runner.get_empty_model_input() - model_input = model_input.new( - attn_backend=self.model_runner.attn_backend, **metadata_dict) - return model_input - - @torch.inference_mode() - @disable_communication - def execute_model_local( - self, model_input: GPUModelInput - ) -> List[Union[SamplerOutput, PoolerOutput]]: - self.cache_swap(model_input.blocks_to_swap_in, - model_input.blocks_to_swap_out, - model_input.blocks_to_copy) - - # If there is no input, we don't need to execute the model. - if model_input.num_seq_groups == 0: - return [] - - output = self.model_runner.execute_model(model_input, self.gpu_cache) - - # Worker only supports single-step execution. Wrap the output in a list - # to conform to interface. - return [output] - - def _execute_model_non_driver(self) -> bool: - """Execute model in parallel worker. - - Returns True iff there are remaining sequences to process. 
- """ - assert not self.is_driver_worker - data = broadcast_tensor_dict(src=0) - if not data: - return False - - num_seq_groups = data.get("num_seq_groups", 0) - blocks_to_swap_in = data.get("blocks_to_swap_in") - blocks_to_swap_out = data.get("blocks_to_swap_out") - blocks_to_copy = data.get("blocks_to_copy") - self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) - - # If there is no input, we don't need to execute the model. - if num_seq_groups == 0: - return False - - self.model_runner.execute_model(None, self.gpu_cache) - return True + def execute_worker(self, worker_input: WorkerInput) -> None: + # Issue cache operations. + if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine.swap_in(worker_input.blocks_to_swap_in) + if (worker_input.blocks_to_swap_out is not None + and worker_input.blocks_to_swap_out.numel() > 0): + self.cache_engine.swap_out(worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine.copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) + return self._model_runner.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - return self.model_runner.remove_lora(lora_id) + return self._model_runner.remove_lora(lora_id) def list_loras(self) -> Set[int]: - return self.model_runner.list_loras() + return self._model_runner.list_loras() @property def max_model_len(self) -> int: @@ -347,7 +287,7 @@ def max_model_len(self) -> int: @property def vocab_size(self) -> int: - return self.model_runner.vocab_size + return self._model_runner.vocab_size def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 568cf1b395a32..a0ee53a1abe46 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -5,20 +5,23 @@ import torch -from vllm.distributed import disable_communication +from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) from vllm.worker.model_input import ModelInput +from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.worker_input import WorkerInput logger = init_logger(__name__) class WorkerBase(ABC): """Worker interface that allows vLLM to cleanly separate implementations for - different hardware. + different hardware. Also abstracts control plane communication, e.g., to + communicate request metadata to other workers. """ @abstractmethod @@ -63,51 +66,9 @@ def start_worker_execution_loop(self) -> None: return None @abstractmethod - @disable_communication - def prepare_model_input_local( - self, execute_model_req: ExecuteModelRequest) -> ModelInput: - """ - Prepare a model execution request locally. This method may move data to - the worker's local device. It is not allowed to communicate with - other workers or devices. Subclasses should keep the - @disable_communication decorator to enforce this. - """ - raise NotImplementedError - - @abstractmethod - def prepare_model_input( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[ModelInput]: - """ - Prepare a model execution request. 
Communication with other workers - may occur to produce the model input that should be passed to - execute_model. - """ - raise NotImplementedError - def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences, unless no - sequences are provided. Communication with other workers - may occur to produce the model input that should be passed to - the model runner.""" - model_input: Optional[ModelInput] = self.prepare_model_input( - execute_model_req=execute_model_req) - if model_input is None: - return None - - return self.execute_model_local(model_input) - - @abstractmethod - @disable_communication - def execute_model_local(self, - model_input: ModelInput) -> List[SamplerOutput]: - """Executes at least one model step on the given sequences, unless no - sequences are provided. This method is not allowed to communciate - metadata to other workers. Subclasses should keep the - @disable_communication decorator to enforce this. - """ raise NotImplementedError @abstractmethod @@ -145,6 +106,116 @@ def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") +class LocalOrDistributedWorkerBase(WorkerBase): + + @property + @abstractmethod + def is_driver_worker(self) -> bool: + """ + Used by the default `execute_model` to check whether this is the driver + worker. The driver worker is responsible for broadcasting request + inputs to other workers in its TP group. If WorkerBase subclass only + supports single-worker execution, then this method should return True. + """ + raise NotImplementedError + + @property + @abstractmethod + def do_broadcast(self) -> bool: + """ + Used by the default `execute_model` to check whether broadcast is + needed to transfer request inputs from the driver worker to other + workers in the TP group. If WorkerBase subclass only supports + single-worker execution, then this method should return False. + """ + raise NotImplementedError + + @property + @abstractmethod + def model_runner(self) -> ModelRunnerBase: + """ + Get the worker's model runner. Used by the default `execute_model`. If + the worker's model runner does not follow the ModelRunnerBase + interface, then this method should raise NotImplementedError. + """ + raise NotImplementedError + + @property + @abstractmethod + def kv_cache(self) -> Optional[List[torch.Tensor]]: + """ + Get the kv cache to pass to the worker's model runner. Used by the + default `execute_model`. If the worker's model runner does not follow + the ModelRunnerBase interface, then this method should raise + NotImplementedError. + """ + raise NotImplementedError + + @abstractmethod + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + """ + Prepare the inputs to WorkerBase.execute_worker from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @abstractmethod + def execute_worker(self, worker_input: WorkerInput) -> None: + """ + Process an execution request. 
+ """ + raise NotImplementedError + + def execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" + if self.is_driver_worker: + if execute_model_req is None: + if self.do_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelInput = self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list) + + if self.do_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + else: + assert self.do_broadcast + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + worker_input = WorkerInput.new(**broadcast_data) + model_input_cls = self.model_runner.model_input_cls() + model_input = model_input_cls.new(**broadcast_data) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + output = self.model_runner.execute_model(model_input, self.kv_cache) + # Worker only supports single-step execution. Wrap the output in a + # list to conform to interface. + return [output] + + class WorkerWrapperBase: """ The whole point of this class is to lazily initialize the worker. diff --git a/vllm/worker/worker_input.py b/vllm/worker/worker_input.py new file mode 100644 index 0000000000000..f010270f816e1 --- /dev/null +++ b/vllm/worker/worker_input.py @@ -0,0 +1,62 @@ +"""Worker-local model inputs. These define the inputs to different model +runners.""" +import dataclasses +from typing import Any, Dict, Optional, Type, Union + +import torch + + +@dataclasses.dataclass(frozen=True) +class WorkerInput: + """Local inputs to each worker. May contain device-specific data. Different + worker backends may have different methods of converting from the global + ExecuteModelRequest produced by the LLM engine to the worker-local + WorkerInput objects. + + Subclasses of WorkerBase should inherit from this class and add their + required fields. For distributed executors, any fields that should be sent + during a broadcast op should also be added to the broadcastable_fields. + During execution, these fields will be extracted from the source copy and + broadcasted to all workers using broadcast_tensor_dict. + """ + + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + + @classmethod + def _get_init_kwargs(cls: Type["WorkerInput"], **kwargs) -> Dict[str, Any]: + """ + Helper method to extract all dataclass fields from the given kwargs. + Override for fields that require some custom deserialization. 
+ """ + init_kwargs = {} + for field in dataclasses.fields(cls): + val = kwargs.get(field.name, None) + if val is not None: + init_kwargs[field.name] = val + return init_kwargs + + @classmethod + def new(cls: Type["WorkerInput"], **kwargs) -> "WorkerInput": + """ + Create a new instance of this class. Populate the new instance with + the given kwargs. + """ + kwargs = cls._get_init_kwargs(**kwargs) + return cls(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} + for field in dataclasses.fields(self): + val = getattr(self, field.name, None) + if val is not None: + tensor_dict[field.name] = val + + return tensor_dict From 797a7cf2308d4563cdd1bf603790512212897d8e Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Sat, 15 Jun 2024 15:19:48 -0700 Subject: [PATCH 30/55] doc Signed-off-by: Stephanie Wang --- vllm/worker/cpu_worker.py | 2 +- vllm/worker/model_runner.py | 6 ++++++ vllm/worker/neuron_worker.py | 2 +- vllm/worker/worker.py | 2 +- vllm/worker/worker_base.py | 14 ++++++++++---- 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 4bbfc548604cb..442ce1dc0edb8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -262,7 +262,7 @@ def is_driver_worker(self) -> bool: return self._is_driver_worker @property - def do_broadcast(self) -> bool: + def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @property diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fd1941456bf95..759f58e0b2802 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -45,6 +45,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TGPUModelInput]): + """ + Helper class for shared methods between GPU model runners. + """ def __init__( self, @@ -831,6 +834,9 @@ def vocab_size(self) -> int: class ModelRunner(GPUModelRunnerBase[GPUModelInputWithSamplingMetadata]): + """ + GPU model runner with sampling step. + """ @staticmethod def model_input_cls() -> Type[GPUModelInputWithSamplingMetadata]: diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index d7d5f9ac58770..ed478c0ac1db6 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -81,7 +81,7 @@ def is_driver_worker(self) -> bool: return True @property - def do_broadcast(self) -> bool: + def do_metadata_broadcast(self) -> bool: return False @property diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index fd246261ed0b1..95171be942ac3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -222,7 +222,7 @@ def is_driver_worker(self) -> bool: return self._is_driver_worker @property - def do_broadcast(self) -> bool: + def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @property diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a0ee53a1abe46..c8bac1f256559 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -107,6 +107,12 @@ def list_loras(self) -> Set[int]: class LocalOrDistributedWorkerBase(WorkerBase): + """ + Partial implementation of WorkerBase that has a default execute_model + definition to perform metadata transfer between workers when in distributed + mode. Subclasses of this interface should only need to implement + worker-local logic. 
+ """ @property @abstractmethod @@ -121,7 +127,7 @@ def is_driver_worker(self) -> bool: @property @abstractmethod - def do_broadcast(self) -> bool: + def do_metadata_broadcast(self) -> bool: """ Used by the default `execute_model` to check whether broadcast is needed to transfer request inputs from the driver worker to other @@ -175,7 +181,7 @@ def execute_model( sequences are provided.""" if self.is_driver_worker: if execute_model_req is None: - if self.do_broadcast: + if self.do_metadata_broadcast: # This signals that there's no more requests to process for # now. All workers are running infinite loop with # broadcast_tensor_dict, and it stops the loop when the @@ -189,13 +195,13 @@ def execute_model( model_input: ModelInput = self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list) - if self.do_broadcast: + if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() broadcast_data.update( model_input.as_broadcastable_tensor_dict()) broadcast_tensor_dict(broadcast_data, src=0) else: - assert self.do_broadcast + assert self.do_metadata_broadcast broadcast_data = broadcast_tensor_dict(src=0) if not broadcast_data: return None From 6ad251347114e38793b27d430413102334b01cbe Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Sat, 15 Jun 2024 15:28:24 -0700 Subject: [PATCH 31/55] revert spec decode and doc Signed-off-by: Stephanie Wang --- vllm/spec_decode/multi_step_worker.py | 10 +++---- vllm/spec_decode/ngram_worker.py | 27 +++---------------- vllm/spec_decode/spec_decode_worker.py | 36 +++++--------------------- vllm/spec_decode/util.py | 4 +-- vllm/worker/model_runner_base.py | 3 +++ vllm/worker/worker_base.py | 13 +++++----- 6 files changed, 26 insertions(+), 67 deletions(-) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index fe15ea33b5f36..668ceefe6175f 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,10 +1,10 @@ import copy import weakref -from typing import List, Tuple +from typing import Dict, List, Tuple import torch -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, SequenceGroupMetadata) from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase @@ -71,7 +71,7 @@ def sampler_output( sample_len) # Run model sample_len times. - model_outputs = [] + model_outputs: List[SamplerOutput] = [] for _ in range(sample_len): model_output = super().execute_model( execute_model_req=copied_execute_model_req) @@ -132,7 +132,7 @@ def _shallow_copy_inputs( # Shallow-copy the list of SequenceGroupMetadata. This allows us to # append tokens and change is_prompt without external side-effects. - new_seq_group_metadata_list = [] + new_seq_group_metadata_list: List[SequenceGroupMetadata] = [] for old_seq_group_metadata in seq_group_metadata_list: # We must shallow-copy seq_group_metadata as is_prompt could change. 
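The default execute_model added to LocalOrDistributedWorkerBase above merges the WorkerInput and ModelInput broadcastable dicts into one payload on the driver, ships it to the other ranks with broadcast_tensor_dict when do_metadata_broadcast is true, and signals shutdown of the non-driver loop with an empty dict; each rank then rebuilds its own input objects from that shared payload. A rough, self-contained sketch of the rebuild step, with a plain dict standing in for the broadcast and MiniWorkerInput/MiniModelInput as simplified, hypothetical stand-ins for the real dataclasses, might look like this:

import dataclasses
from typing import Any, Dict, Optional

import torch


def init_kwargs_for(cls, payload: Dict[str, Any]) -> Dict[str, Any]:
    # Same filtering idea as WorkerInput._get_init_kwargs: keep only the keys
    # that are non-None and correspond to fields of the target dataclass.
    names = {f.name for f in dataclasses.fields(cls)}
    return {k: v for k, v in payload.items() if k in names and v is not None}


@dataclasses.dataclass(frozen=True)
class MiniWorkerInput:
    num_seq_groups: Optional[int] = None
    blocks_to_copy: Optional[torch.Tensor] = None


@dataclasses.dataclass(frozen=True)
class MiniModelInput:
    input_tokens: Optional[torch.Tensor] = None


# Driver side: the two broadcastable dicts are merged into a single payload.
payload: Dict[str, Any] = {
    "num_seq_groups": 2,
    "blocks_to_copy": torch.tensor([[0, 1]], dtype=torch.int64),
    "input_tokens": torch.tensor([5, 7, 9]),
}

# Non-driver side: each input type picks out only the fields it owns.
worker_input = MiniWorkerInput(**init_kwargs_for(MiniWorkerInput, payload))
model_input = MiniModelInput(**init_kwargs_for(MiniModelInput, payload))
assert worker_input.num_seq_groups == 2
assert model_input.input_tokens.tolist() == [5, 7, 9]
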
@@ -140,7 +140,7 @@ def _shallow_copy_inputs( new_seq_group_metadata_list.append(seq_group_metadata) # We must shallow-copy seq_data as we will append token ids - new_seq_data = {} + new_seq_data: Dict[int, SequenceData] = {} for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): new_seq_data[seq_id] = copy.copy(old_seq_data) new_seq_data[ diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index a0968f2f658bc..23a3e1649914b 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -7,7 +7,6 @@ from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.model_input import ModelInput from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -49,7 +48,7 @@ def sampler_output( self, execute_model_req: ExecuteModelRequest, sample_len: int, - ) -> Tuple[Optional[List[SamplerOutput]], bool]: + ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]: """NGram match algo to pick proposal candidate. Returns the list of sampler output, one per SequenceGroupMetadata. @@ -59,8 +58,8 @@ def sampler_output( self._raise_if_unsupported(execute_model_req) has_spec_out = False - token_id_list = [] - token_prob_list = [] + token_id_list: List[Optional[torch.Tensor]] = [] + token_prob_list: List[Optional[torch.Tensor]] = [] for idx, seq_group_metadata in enumerate( execute_model_req.seq_group_metadata_list): seq_data = next(iter(seq_group_metadata.seq_data.values())) @@ -162,23 +161,3 @@ def _raise_if_unsupported( execute_model_req.seq_group_metadata_list): raise NotImplementedError( "NGramWorker does not support beam search.") - - @torch.inference_mode() - def prepare_model_input_local( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - raise NotImplementedError("NGramWorker does not allow direct calls to " - "prepare_model_input_local") - - @torch.inference_mode() - def prepare_model_input( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> List[SamplerOutput]: - raise NotImplementedError("NGramWorker does not allow direct calls to " - "prepare_model_input") - - @torch.inference_mode() - def execute_model_local(self, - model_input: ModelInput) -> List[SamplerOutput]: - raise NotImplementedError("NGramWorker does not allow direct calls to " - "execute_model_local") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 4facd50a0c8cc..03fad5663037b 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -7,8 +7,8 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, + SamplerOutput, SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -20,7 +20,6 @@ get_all_num_logprobs, get_all_seq_ids, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) -from vllm.worker.model_input import ModelInput from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -233,29 
+232,6 @@ def initialize_cache(self, num_gpu_blocks: int, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - @torch.inference_mode() - def prepare_model_input_local( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - raise NotImplementedError( - "SpecDecodeWorker does not allow direct calls to " - "prepare_model_input_local") - - @torch.inference_mode() - def prepare_model_input( - self, execute_model_req: Optional[ExecuteModelRequest] - ) -> List[SamplerOutput]: - raise NotImplementedError( - "SpecDecodeWorker does not allow direct calls to " - "prepare_model_input") - - @torch.inference_mode() - def execute_model_local(self, - model_input: ModelInput) -> List[SamplerOutput]: - raise NotImplementedError( - "SpecDecodeWorker does not allow direct calls to " - "execute_model_local") - @torch.inference_mode() def execute_model( self, @@ -388,9 +364,9 @@ def _run_non_driver_rank(self) -> bool: # We run the proposer once per lookahead slot. In the future we should # delegate how many times it runs to the proposer. for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model(execute_model_req=None) + self.proposer_worker.execute_model() - self.scorer_worker.execute_model(execute_model_req=None) + self.scorer_worker.execute_model() return True @nvtx_range("spec_decode_worker._run_speculative_decoding_step") @@ -540,13 +516,13 @@ def _create_output_sampler_list( topk_indices_by_step = topk_indices_by_step.tolist() # Construct the output on a per-step, per-sequence basis. - sampler_output_list = [] + sampler_output_list: List[SamplerOutput] = [] for step_index in range(num_steps): if all(token_id == -1 for token_id in accepted_token_ids_by_step[step_index]): break - step_output_token_ids = [] + step_output_token_ids: List[CompletionSequenceGroupOutput] = [] for sequence_index in range(batch_size): # Each sequence may have a different num_logprobs; retrieve it. num_logprobs = num_logprobs_per_seq[sequence_index] diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 60ed9d39eb8d6..9bbe3f8d16117 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -26,10 +26,10 @@ def get_all_num_logprobs( sequence. """ - all_num_logprobs = [] + all_num_logprobs: List[int] = [] for seq_group_metadata in seq_group_metadata_list: num_logprobs = seq_group_metadata.sampling_params.logprobs - if seq_group_metadata.sampling_params.logprobs is None: + if num_logprobs is None: num_logprobs = 0 all_num_logprobs.append(num_logprobs) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index e618d40a53a7f..08ef1dc11dca4 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -14,6 +14,9 @@ class ModelRunnerBase(ABC, Generic[T]): Model runner interface that abstracts a particular hardware and/or type of model. Model execution may communicate data with model runners in other processes, but it should not include control plane metadata communication. + + Each ModelRunnerBase subclass should define a corresponding ModelInput + subclass. 
""" @staticmethod diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index c8bac1f256559..3cbba1d7f643b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -108,10 +108,12 @@ def list_loras(self) -> Set[int]: class LocalOrDistributedWorkerBase(WorkerBase): """ - Partial implementation of WorkerBase that has a default execute_model + Partial implementation of WorkerBase that has a default `execute_model` definition to perform metadata transfer between workers when in distributed - mode. Subclasses of this interface should only need to implement - worker-local logic. + mode. Subclasses of this interface should use model runners that inherit + from ModelRunnerBase, and should only need to implement worker-local logic. + If custom control plane logic is needed to transfer metadata, or if the + model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. """ @property @@ -142,7 +144,7 @@ def model_runner(self) -> ModelRunnerBase: """ Get the worker's model runner. Used by the default `execute_model`. If the worker's model runner does not follow the ModelRunnerBase - interface, then this method should raise NotImplementedError. + interface, then inherit from WorkerBase instead. """ raise NotImplementedError @@ -152,8 +154,7 @@ def kv_cache(self) -> Optional[List[torch.Tensor]]: """ Get the kv cache to pass to the worker's model runner. Used by the default `execute_model`. If the worker's model runner does not follow - the ModelRunnerBase interface, then this method should raise - NotImplementedError. + the ModelRunnerBase interface, then inherit from WorkerBase instead. """ raise NotImplementedError From e10bacedbc0dbea3c681cda3282c93ee1100249f Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Sat, 15 Jun 2024 15:37:45 -0700 Subject: [PATCH 32/55] typing Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 922fb435df5ff..9c4c0acda64a7 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -13,11 +13,11 @@ class MockAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: - pass + raise NotImplementedError @staticmethod def get_impl_cls(): - pass + raise NotImplementedError @staticmethod def get_metadata_cls() -> Type["AttentionMetadata"]: @@ -30,7 +30,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - pass + raise NotImplementedError @staticmethod def swap_blocks( @@ -81,9 +81,9 @@ def test_gpu_model_input(): assert field.name not in tensor_dict # Broadcast should contain all non-empty fields defined by the developer # for this input type. - for field in model_input.broadcastable_fields: - if getattr(model_input, field) is not None: - assert field in tensor_dict + for field_name in model_input.broadcastable_fields: + if getattr(model_input, field_name, None) is not None: + assert field_name in tensor_dict # Check that received copy has correct values. 
for field in dataclasses.fields(AttentionMetadata): From ce087ae9375b466f5a03d2e544b7ae14553bab6f Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 17 Jun 2024 19:48:35 -0700 Subject: [PATCH 33/55] fix Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 8 +++++--- vllm/worker/embedding_model_runner.py | 8 +++++--- vllm/worker/model_runner.py | 14 +++++++------- vllm/worker/model_runner_base.py | 9 ++++++--- vllm/worker/neuron_model_runner.py | 5 ++--- vllm/worker/worker_base.py | 3 +-- 6 files changed, 26 insertions(+), 21 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b6f2eca7e5909..c1da7f6ed1131 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -271,9 +271,11 @@ def _prepare_decode( attn_metadata, ) - @staticmethod - def model_input_cls() -> Type[CPUModelInput]: - return CPUModelInput + def make_model_input(self, **kwargs) -> CPUModelInput: + return CPUModelInput.new( + attn_backend=self.attn_backend, + **kwargs, + ) def prepare_model_input( self, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index c777d32bcfe5a..e25af973343f4 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -86,9 +86,11 @@ def execute_model( return self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) - @staticmethod - def model_input_cls() -> Type[GPUModelInputWithPoolingMetadata]: - return GPUModelInputWithPoolingMetadata + def make_model_input(self, **kwargs) -> GPUModelInputWithPoolingMetadata: + return GPUModelInputWithPoolingMetadata.new( + attn_backend=self.attn_backend, + **kwargs, + ) def prepare_model_input( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e102eb5bb52e1..b336f7b8286c5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -269,8 +269,7 @@ def _prepare_model_input_tensors( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - model_input_cls = self.model_input_cls() - return model_input_cls() + return self.make_model_input() if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window + self.block_size - @@ -620,8 +619,7 @@ def _prepare_model_input_tensors( for k, v in multi_modal_kwargs_list.items() } - model_input_cls = self.model_input_cls() - return model_input_cls( + return self.make_model_input( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, attn_metadata=attn_metadata, @@ -838,9 +836,11 @@ class ModelRunner(GPUModelRunnerBase[GPUModelInputWithSamplingMetadata]): GPU model runner with sampling step. """ - @staticmethod - def model_input_cls() -> Type[GPUModelInputWithSamplingMetadata]: - return GPUModelInputWithSamplingMetadata + def make_model_input(self, **kwargs) -> GPUModelInputWithSamplingMetadata: + return GPUModelInputWithSamplingMetadata.new( + attn_backend=self.attn_backend, + **kwargs, + ) def prepare_model_input( self, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 08ef1dc11dca4..f9315c9908714 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Generic, List, Optional, Type, TypeVar +from typing import Generic, List, Optional, TypeVar import torch @@ -19,9 +19,12 @@ class ModelRunnerBase(ABC, Generic[T]): subclass. 
""" - @staticmethod @abstractmethod - def model_input_cls() -> Type[T]: + def make_model_input(self, **model_input_fields) -> T: + """ + Make an instance of a ModelInput from the given + fields. + """ raise NotImplementedError @abstractmethod diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 55faf54c9358a..be24badf72f9a 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -141,9 +141,8 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - @staticmethod - def model_input_cls() -> Type[ModelInputForNeuron]: - return ModelInputForNeuron + def make_model_input(self, **kwargs) -> ModelInputForNeuron: + return ModelInputForNeuron.new(**kwargs) def prepare_model_input( self, diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 78788f9eb5163..dda2e2f14e29f 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -208,8 +208,7 @@ def execute_model( return None worker_input = WorkerInput.new(**broadcast_data) - model_input_cls = self.model_runner.model_input_cls() - model_input = model_input_cls.new(**broadcast_data) + model_input = self.model_runner.make_model_input(**broadcast_data) self.execute_worker(worker_input) From 0e2acc4583fdd518aeb3524191bfa6686a74256c Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 17 Jun 2024 20:05:05 -0700 Subject: [PATCH 34/55] XPU worker and rename Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 11 +++---- vllm/worker/embedding_model_runner.py | 15 +++++----- vllm/worker/model_input.py | 41 +++++++++++++++++++++----- vllm/worker/model_runner.py | 21 +++++++------- vllm/worker/xpu_model_runner.py | 42 ++++++++++++++++----------- vllm/worker/xpu_worker.py | 12 ++++---- 6 files changed, 89 insertions(+), 53 deletions(-) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 9c4c0acda64a7..0663357ec62ac 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -6,7 +6,7 @@ from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -from vllm.worker.model_input import GPUModelInputWithSamplingMetadata +from vllm.worker.model_input import ModelInputForGPUWithSamplingMetadata class MockAttentionBackend(AttentionBackend): @@ -61,19 +61,20 @@ def test_gpu_model_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), ) - model_input = GPUModelInputWithSamplingMetadata.new( + model_input = ModelInputForGPUWithSamplingMetadata.new( num_seq_groups=10, sampling_metadata=sampling_metadata, attn_metadata=attn_metadata) - assert isinstance(model_input, GPUModelInputWithSamplingMetadata) + assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata) # Test round trip serialization. tensor_dict = model_input.as_broadcastable_tensor_dict() attn_backend = MockAttentionBackend() - received_model_input = GPUModelInputWithSamplingMetadata.new( + received_model_input = ModelInputForGPUWithSamplingMetadata.new( attn_backend=attn_backend, **tensor_dict) - assert isinstance(received_model_input, GPUModelInputWithSamplingMetadata) + assert isinstance(received_model_input, + ModelInputForGPUWithSamplingMetadata) # Broadcast should not contain empty values. 
for field in dataclasses.fields(model_input): diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index e25af973343f4..ad1761ecb4340 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -9,14 +9,14 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_input import GPUModelInputWithPoolingMetadata +from vllm.worker.model_input import ModelInputForGPUWithPoolingMetadata from vllm.worker.model_runner import GPUModelRunnerBase logger = init_logger(__name__) -class EmbeddingModelRunner(GPUModelRunnerBase[GPUModelInputWithPoolingMetadata] - ): +class EmbeddingModelRunner( + GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): def __init__( self, @@ -45,7 +45,7 @@ def __init__( @torch.inference_mode() def execute_model( self, - model_input: GPUModelInputWithPoolingMetadata, + model_input: ModelInputForGPUWithPoolingMetadata, kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: if self.lora_config: @@ -86,8 +86,9 @@ def execute_model( return self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) - def make_model_input(self, **kwargs) -> GPUModelInputWithPoolingMetadata: - return GPUModelInputWithPoolingMetadata.new( + def make_model_input(self, + **kwargs) -> ModelInputForGPUWithPoolingMetadata: + return ModelInputForGPUWithPoolingMetadata.new( attn_backend=self.attn_backend, **kwargs, ) @@ -95,7 +96,7 @@ def make_model_input(self, **kwargs) -> GPUModelInputWithPoolingMetadata: def prepare_model_input( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> GPUModelInputWithPoolingMetadata: + ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None model_input = self._prepare_model_input_tensors( seq_group_metadata_list) diff --git a/vllm/worker/model_input.py b/vllm/worker/model_input.py index 8abb5a72eaf14..dddaf79528fb9 100644 --- a/vllm/worker/model_input.py +++ b/vllm/worker/model_input.py @@ -90,7 +90,7 @@ class ModelInput: Python class like SamplingMetadata. For these fields, override as_broadcastable_tensor_dict to return the custom serialized values and override _get_init_kwargs to perform the custom deserialization ( - GPUModelInput for an example). + ModelInputForGPU for an example). """ @property @@ -155,10 +155,9 @@ class CPUModelInput(ModelInput): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - attn_metadata: Optional["AttentionMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None @property def broadcastable_fields(self) -> Tuple[str, ...]: @@ -185,7 +184,7 @@ def as_broadcastable_tensor_dict( @dataclasses.dataclass(frozen=True) -class GPUModelInput(ModelInput): +class ModelInputForGPU(ModelInput): """ This base class contains metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. 
Model @@ -198,9 +197,8 @@ class GPUModelInput(ModelInput): query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None lora_requests: Optional[Set[LoRARequest]] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None @property def broadcastable_fields(self) -> Tuple[str, ...]: @@ -226,7 +224,7 @@ def as_broadcastable_tensor_dict( @dataclasses.dataclass(frozen=True) -class GPUModelInputWithPoolingMetadata(GPUModelInput): +class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): """ Used by the EmbeddingModelRunner. """ @@ -234,7 +232,7 @@ class GPUModelInputWithPoolingMetadata(GPUModelInput): @dataclasses.dataclass(frozen=True) -class GPUModelInputWithSamplingMetadata(GPUModelInput): +class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): """ Used by the ModelRunner. """ @@ -267,3 +265,30 @@ class ModelInputForNeuron(ModelInput): def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + + +@dataclasses.dataclass(frozen=True) +class ModelInputForXPU(ModelInput): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_input: Optional[Dict[str, torch.Tensor]] = None + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = _init_attn_metadata_from_kwargs(**kwargs) + kwargs = _init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b336f7b8286c5..2b93481313433 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -25,8 +25,8 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_input import (GPUModelInput, - GPUModelInputWithSamplingMetadata) +from vllm.worker.model_input import (ModelInputForGPU, + ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) @@ -41,10 +41,10 @@ ] _NUM_WARMUP_ITERS = 2 -TGPUModelInput = TypeVar('TGPUModelInput', bound="GPUModelInput") +TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") -class GPUModelRunnerBase(ModelRunnerBase[TGPUModelInput]): +class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ Helper class for shared methods between GPU model runners. """ @@ -215,7 +215,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> TGPUModelInput: + ) -> TModelInputForGPU: """Helper method to prepare the model input based on a given sequence group. 
Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. @@ -831,13 +831,14 @@ def vocab_size(self) -> int: return self.model_config.get_vocab_size() -class ModelRunner(GPUModelRunnerBase[GPUModelInputWithSamplingMetadata]): +class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): """ GPU model runner with sampling step. """ - def make_model_input(self, **kwargs) -> GPUModelInputWithSamplingMetadata: - return GPUModelInputWithSamplingMetadata.new( + def make_model_input(self, + **kwargs) -> ModelInputForGPUWithSamplingMetadata: + return ModelInputForGPUWithSamplingMetadata.new( attn_backend=self.attn_backend, **kwargs, ) @@ -845,7 +846,7 @@ def make_model_input(self, **kwargs) -> GPUModelInputWithSamplingMetadata: def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> GPUModelInputWithSamplingMetadata: + ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -871,7 +872,7 @@ def prepare_model_input( @torch.inference_mode() def execute_model( self, - model_input: GPUModelInputWithSamplingMetadata, + model_input: ModelInputForGPUWithSamplingMetadata, kv_caches: List[torch.Tensor], ) -> SamplerOutput: if self.lora_config: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index f30de703e805d..a8538ecb1b844 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -14,6 +14,8 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata +from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.model_input import ModelInputForXPU logger = init_logger(__name__) @@ -24,7 +26,7 @@ ] -class XPUModelRunner: +class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): def __init__( self, @@ -134,11 +136,16 @@ def profile_run(self) -> None: torch.xpu.synchronize() return - def prepare_input_tensors( + def make_model_input(self, **kwargs) -> ModelInputForXPU: + return ModelInputForXPU.new( + attn_backend=self.attn_backend, + **kwargs, + ) + + def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Optional[torch.Tensor]]: + ) -> ModelInputForXPU: multi_modal_input = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or @@ -185,8 +192,11 @@ def prepare_input_tensors( num_prompts=0, ) - return (input_tokens, input_positions, attn_metadata, - sampling_metadata, multi_modal_input) + return self.make_model_input(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + multi_modal_input=multi_modal_input) def _prepare_decode( self, @@ -277,27 +287,25 @@ def _prepare_decode( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: ModelInputForXPU, kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata, - multi_modal_input - ) = self.prepare_input_tensors(seq_group_metadata_list) - model_executable = self.model execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, + "input_ids": 
model_input.input_tokens, + "positions": model_input.input_positions, "kv_caches": kv_caches, - "attn_metadata": attn_metadata, + "attn_metadata": model_input.attn_metadata, } if self.vision_language_config: - execute_model_kwargs.update({"image_input": multi_modal_input}) + execute_model_kwargs.update( + {"image_input": model_input.multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. - logits = self.model.compute_logits(hidden_states, sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -306,7 +314,7 @@ def execute_model( # Sample the next token. output = self.model.sample( logits=logits, - sampling_metadata=sampling_metadata, + sampling_metadata=model_input.sampling_metadata, ) return output diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 773ee9f8159e1..c03595bc9787f 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -47,7 +47,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, - is_driver_worker: bool = False, + _is_driver_worker: bool = False, ) -> None: assert device_config.device_type == "xpu" assert is_xpu() @@ -62,8 +62,8 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config - self.is_driver_worker = is_driver_worker - if self.is_driver_worker: + self._is_driver_worker = _is_driver_worker + if self._is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." self.vision_language_config = vision_language_config @@ -71,7 +71,7 @@ def __init__( assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") - self.model_runner = XPUModelRunner( # type: ignore + self._model_runner = XPUModelRunner( # type: ignore model_config, parallel_config, scheduler_config, @@ -80,7 +80,7 @@ def __init__( load_config=self.load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=is_driver_worker, + _is_driver_worker=_is_driver_worker, vision_language_config=vision_language_config, ) # Uninitialized cache engine. Will be initialized by @@ -123,7 +123,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self.model_runner.profile_run() + self._model_runner.profile_run() # Calculate the number of blocks that can be allocated with the # profiled peak memory. 
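For orientation, the sketch below shows the round trip that the ModelInput refactor above is building toward: the driver turns scheduler output into a device-ready input, broadcasts only its broadcastable fields, and non-driver workers rebuild an equivalent object from their local attention backend. This is an illustrative sketch, not a patch in this series; it mirrors the round-trip test in tests/worker/test_model_input.py, and the driver_side/worker_side helper names are made up for illustration (in the series itself this logic lives inside the worker's execute_model).

# Illustrative sketch only, using the APIs as of PATCH 34/55; not code
# from the series.
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.worker.model_input import ModelInputForGPUWithSamplingMetadata


def driver_side(model_runner, seq_group_metadata_list):
    # Driver only: build the full, worker-local model input.
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    # Ship the fields listed in broadcastable_fields plus the custom
    # attn/sampling metadata serialization from as_broadcastable_tensor_dict.
    broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(), src=0)
    return model_input


def worker_side(model_runner):
    # Non-driver workers: receive the tensor dict and reconstruct an
    # equivalent input; AttentionMetadata is re-created from the runner's
    # local attention backend by ModelInput.new().
    tensor_dict = broadcast_tensor_dict(src=0)
    return ModelInputForGPUWithSamplingMetadata.new(
        attn_backend=model_runner.attn_backend, **tensor_dict)

Either side can then call model_runner.execute_model(model_input, kv_caches) on the result, which is the execute_model split that the later patches in this series settle on.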
From d318ec8f56c6768a934e7738ca9deedd364b7684 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 17 Jun 2024 20:06:55 -0700 Subject: [PATCH 35/55] lint Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 2 +- vllm/worker/embedding_model_runner.py | 2 +- vllm/worker/model_runner.py | 2 +- vllm/worker/neuron_model_runner.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index c1da7f6ed1131..deba9d04e2fa8 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple import torch from torch import nn diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index ad1761ecb4340..4277d6db7afdc 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple import torch diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b93481313433..22f9687811d5e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -2,7 +2,7 @@ import time import warnings from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Dict, List, Optional, Set, Tuple, TypeVar, Union import numpy as np import torch diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index be24badf72f9a..a15c20a0e0daa 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple import torch from torch import nn From b48f783f127222cc90a8d4132fa97ba3a05b6569 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 17 Jun 2024 20:08:28 -0700 Subject: [PATCH 36/55] lint Signed-off-by: Stephanie Wang --- vllm/worker/xpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index a8538ecb1b844..ec2917f0e8832 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -13,9 +13,9 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad +from vllm.worker.model_input import ModelInputForXPU from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ModelRunnerBase -from vllm.worker.model_input import ModelInputForXPU logger = init_logger(__name__) From 30ac40002f13cdc8b9d7e1495b2b7735c629a4b8 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 18 Jun 2024 10:00:27 -0700 Subject: [PATCH 37/55] fix Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 11 ++++++----- vllm/worker/embedding_model_runner.py | 3 +++ vllm/worker/model_runner.py | 8 ++++---- vllm/worker/model_runner_base.py | 9 ++++++--- vllm/worker/neuron_model_runner.py | 6 +++++- vllm/worker/worker.py | 1 + vllm/worker/worker_base.py | 3 ++- vllm/worker/xpu_model_runner.py | 6 +++++- 8 files changed, 32 insertions(+), 15 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 
deba9d04e2fa8..71aafc6168ad8 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -271,11 +271,12 @@ def _prepare_decode( attn_metadata, ) - def make_model_input(self, **kwargs) -> CPUModelInput: - return CPUModelInput.new( - attn_backend=self.attn_backend, - **kwargs, - ) + def make_model_input(self, + make_attn_metadata: bool = False, + **kwargs) -> CPUModelInput: + if make_attn_metadata: + kwargs["attn_backend"] = self.attn_backend + return CPUModelInput.new(**kwargs, ) def prepare_model_input( self, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 4277d6db7afdc..16e765523871b 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -87,7 +87,10 @@ def execute_model( pooling_metadata=model_input.pooling_metadata) def make_model_input(self, + make_attn_metadata: bool = False, **kwargs) -> ModelInputForGPUWithPoolingMetadata: + if make_attn_metadata: + kwargs["attn_backend"] = self.attn_backend return ModelInputForGPUWithPoolingMetadata.new( attn_backend=self.attn_backend, **kwargs, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 22f9687811d5e..c3a158614e116 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -837,11 +837,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): """ def make_model_input(self, + make_attn_metadata: bool = False, **kwargs) -> ModelInputForGPUWithSamplingMetadata: - return ModelInputForGPUWithSamplingMetadata.new( - attn_backend=self.attn_backend, - **kwargs, - ) + if make_attn_metadata: + kwargs["attn_backend"] = self.attn_backend + return ModelInputForGPUWithSamplingMetadata.new(**kwargs, ) def prepare_model_input( self, diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index f9315c9908714..33dc2fb852b49 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -20,10 +20,13 @@ class ModelRunnerBase(ABC, Generic[T]): """ @abstractmethod - def make_model_input(self, **model_input_fields) -> T: + def make_model_input(self, + make_attn_metadata: bool = False, + **model_input_fields) -> T: """ - Make an instance of a ModelInput from the given - fields. + Make an instance of a ModelInput from the given fields. If + make_attn_metadata=True, then AttentionMetadata will be created from + fields extracted from model_input_fields. 
""" raise NotImplementedError diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a15c20a0e0daa..c1386376451b3 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -141,7 +141,11 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - def make_model_input(self, **kwargs) -> ModelInputForNeuron: + def make_model_input(self, + make_attn_metadata: bool = False, + **kwargs) -> ModelInputForNeuron: + if make_attn_metadata: + kwargs["attn_backend"] = self.attn_backend return ModelInputForNeuron.new(**kwargs) def prepare_model_input( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c4cf9dc9e5c68..d5bc8cd501407 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -49,6 +49,7 @@ def __init__( is_driver_worker: bool = False, ) -> None: self.model_config = model_config + self.model_config.dtype = torch.float16 self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index dda2e2f14e29f..b6031c3860c0a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -208,7 +208,8 @@ def execute_model( return None worker_input = WorkerInput.new(**broadcast_data) - model_input = self.model_runner.make_model_input(**broadcast_data) + model_input = self.model_runner.make_model_input( + make_attn_metadata=True, **broadcast_data) self.execute_worker(worker_input) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index ec2917f0e8832..7a7ce28b4c9ca 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -136,7 +136,11 @@ def profile_run(self) -> None: torch.xpu.synchronize() return - def make_model_input(self, **kwargs) -> ModelInputForXPU: + def make_model_input(self, + make_attn_metadata: bool = False, + **kwargs) -> ModelInputForXPU: + if make_attn_metadata: + kwargs["attn_backend"] = self.attn_backend return ModelInputForXPU.new( attn_backend=self.attn_backend, **kwargs, From 01688d507b83a7a1f26302aa7175da16be6d9ce2 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 18 Jun 2024 11:11:31 -0700 Subject: [PATCH 38/55] x Signed-off-by: Stephanie Wang --- vllm/worker/neuron_model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index c1386376451b3..6060a70c95cc9 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -144,8 +144,6 @@ def _prepare_decode( def make_model_input(self, make_attn_metadata: bool = False, **kwargs) -> ModelInputForNeuron: - if make_attn_metadata: - kwargs["attn_backend"] = self.attn_backend return ModelInputForNeuron.new(**kwargs) def prepare_model_input( From 7dbb646269f9c641a5a22b2aef1c8259c4047de1 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 18 Jun 2024 15:00:58 -0700 Subject: [PATCH 39/55] fix Signed-off-by: Stephanie Wang --- vllm/worker/worker_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b6031c3860c0a..dbdd024dcd024 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -67,7 +67,7 @@ def start_worker_execution_loop(self) -> None: @abstractmethod def execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] + self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> 
Optional[List[SamplerOutput]]: raise NotImplementedError @@ -176,7 +176,7 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] + self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" From d2e4c4123952adcbb07690ccb81cc505edb41159 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 19 Jun 2024 09:59:06 -0700 Subject: [PATCH 40/55] fix Signed-off-by: Stephanie Wang --- vllm/worker/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d5bc8cd501407..c4cf9dc9e5c68 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -49,7 +49,6 @@ def __init__( is_driver_worker: bool = False, ) -> None: self.model_config = model_config - self.model_config.dtype = torch.float16 self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config From 3e462538ce8436e0e97f02c0f86c7d7653eb9bf1 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 19 Jun 2024 10:57:45 -0700 Subject: [PATCH 41/55] lint Signed-off-by: Stephanie Wang --- vllm/worker/worker_base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index dbdd024dcd024..e6a277cbf1d1a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -67,7 +67,8 @@ def start_worker_execution_loop(self) -> None: @abstractmethod def execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] = None + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: raise NotImplementedError @@ -176,7 +177,8 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def execute_model( - self, execute_model_req: Optional[ExecuteModelRequest] = None + self, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" From 36dfce17fe2633ec28b046e0bf94098ec5be9ca9 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 21 Jun 2024 16:19:20 -0700 Subject: [PATCH 42/55] merge Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 41 +++- vllm/worker/embedding_model_runner.py | 12 +- vllm/worker/model_input.py | 279 +++++++------------------- vllm/worker/model_runner.py | 93 ++++++++- vllm/worker/neuron_model_runner.py | 20 +- vllm/worker/xpu_model_runner.py | 33 ++- 6 files changed, 253 insertions(+), 225 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 71aafc6168ad8..e92b17aa78f4f 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,5 +1,6 @@ from collections import defaultdict -from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -14,7 +15,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad -from vllm.worker.model_input import CPUModelInput +from vllm.worker.model_input import ModelInput from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) @@ 
-22,6 +23,42 @@ _PAD_SLOT_ID = -1 +@dataclass(frozen=True) +class CPUModelInput(ModelInput): + """ + Used by the CPUModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + return ( + "input_tokens", + "input_positions", + "multi_modal_kwargs", + ) + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = cls._init_attn_metadata_from_kwargs(**kwargs) + kwargs = cls._init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + self._add_attn_metadata_broadcastable_dict(tensor_dict, + self.attn_metadata) + self._add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + class CPUModelRunner(ModelRunnerBase[CPUModelInput]): def __init__( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 16e765523871b..1f69d76d673e7 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch @@ -9,12 +10,19 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_input import ModelInputForGPUWithPoolingMetadata -from vllm.worker.model_runner import GPUModelRunnerBase +from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU logger = init_logger(__name__) +@dataclass(frozen=True) +class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): + """ + Used by the EmbeddingModelRunner. + """ + pooling_metadata: Optional["PoolingMetadata"] = None + + class EmbeddingModelRunner( GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): diff --git a/vllm/worker/model_input.py b/vllm/worker/model_input.py index dddaf79528fb9..0e6017c00d9f2 100644 --- a/vllm/worker/model_input.py +++ b/vllm/worker/model_input.py @@ -1,73 +1,15 @@ """Worker-local model inputs. These define the inputs to different model runners.""" import dataclasses -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, TypeVar, + Union) import torch -from vllm.lora.request import LoRARequest - if TYPE_CHECKING: from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend - from vllm.lora.layers import LoRAMapping from vllm.model_executor import SamplingMetadata - from vllm.model_executor.pooling_metadata import PoolingMetadata - - -def _init_attn_metadata_from_kwargs( - attn_backend: Optional["AttentionBackend"] = None, - attn_metadata: Optional["AttentionMetadata"] = None, - **kwargs) -> Dict[str, Any]: - if attn_metadata is None and attn_backend is not None: - # Extract the fields used to create AttentionMetadata. 
- valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = kwargs.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - if attn_metadata is not None: - kwargs["attn_metadata"] = attn_metadata - return kwargs - - -def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - attn_metadata: Optional["AttentionMetadata"]) -> None: - if attn_metadata is not None: - tensor_dict.update(attn_metadata.asdict_zerocopy()) - - -def _init_sampling_metadata_from_kwargs( # type: ignore - selected_token_indices: Optional[torch.Tensor] = None, - sampling_metadata: Optional["SamplingMetadata"] = None, - **kwargs) -> Dict[str, Any]: - if sampling_metadata is None and selected_token_indices is not None: - from vllm.model_executor import SamplingMetadata - - # An empty SamplingMetadata to signal that the worker should skip - # sampling. - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - if sampling_metadata is not None: - kwargs["sampling_metadata"] = sampling_metadata - return kwargs - - -def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - sampling_metadata: Optional["SamplingMetadata"]) -> None: - if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - sampling_metadata.selected_token_indices) - T = TypeVar('T', bound="ModelInput") @@ -79,8 +21,8 @@ class ModelInput: of converting from the global ExecuteModelRequest produced by the LLM engine to the worker-local ModelInput objects. - Model runners should inherit from this class and add their required fields. - For distributed executors, any fields that should be sent during a + Model runners should define a ModelInput subclass and add their required + fields. For distributed executors, any fields that should be sent during a broadcast op should also be added to the broadcastable_fields. During execution, these fields will be extracted from the source copy and broadcasted to all workers using broadcast_tensor_dict. @@ -147,148 +89,71 @@ def as_broadcastable_tensor_dict( return tensor_dict - -@dataclasses.dataclass(frozen=True) -class CPUModelInput(ModelInput): - """ - Used by the CPUModelRunner. 
- """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - attn_metadata: Optional["AttentionMetadata"] = None - sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - return ( - "input_tokens", - "input_positions", - "multi_modal_kwargs", - ) - - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = _init_attn_metadata_from_kwargs(**kwargs) - kwargs = _init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class ModelInputForGPU(ModelInput): - """ - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - attn_metadata: Optional["AttentionMetadata"] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - return ( - "input_tokens", - "input_positions", - "lora_requests", - "lora_mapping", - "multi_modal_kwargs", - ) - - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = _init_attn_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): - """ - Used by the EmbeddingModelRunner. - """ - pooling_metadata: Optional["PoolingMetadata"] = None - - -@dataclasses.dataclass(frozen=True) -class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): - """ - Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = _init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - -@dataclasses.dataclass(frozen=True) -class ModelInputForNeuron(ModelInput): - """ - Used by the NeuronModelRunner. 
- """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - input_block_ids: Optional[torch.Tensor] = None - sampling_metadata: Optional["SamplingMetadata"] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") - - -@dataclasses.dataclass(frozen=True) -class ModelInputForXPU(ModelInput): - """ - Used by the NeuronModelRunner. - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - attn_metadata: Optional["AttentionMetadata"] = None - sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_input: Optional[Dict[str, torch.Tensor]] = None - - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = _init_attn_metadata_from_kwargs(**kwargs) - kwargs = _init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict + @staticmethod + def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + attn_metadata: Optional["AttentionMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + AttentionMetadata fields. + """ + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + @staticmethod + def _init_attn_metadata_from_kwargs( + attn_backend: Optional["AttentionBackend"] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + **kwargs) -> Dict[str, Any]: + """ + Helper method to initialize AttentionMetadata based on an + AttentionBackend and broadcastable AttentionMetadata fields. + """ + if attn_metadata is None and attn_backend is not None: + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = kwargs.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + if attn_metadata is not None: + kwargs["attn_metadata"] = attn_metadata + return kwargs + + @staticmethod + def _init_sampling_metadata_from_kwargs( # type: ignore + selected_token_indices: Optional[torch.Tensor] = None, + sampling_metadata: Optional["SamplingMetadata"] = None, + **kwargs) -> Dict[str, Any]: + """ + Helper method to initialize SamplingMetadata based on broadcastable + SamplingMetadata fields. + """ + if sampling_metadata is None and selected_token_indices is not None: + from vllm.model_executor import SamplingMetadata + + # An empty SamplingMetadata to signal that the worker should skip + # sampling. + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + if sampling_metadata is not None: + kwargs["sampling_metadata"] = sampling_metadata + return kwargs + + @staticmethod + def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + SamplingMetadata fields. 
+ """ + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 61112fdd34481..9f227214d75f0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -2,7 +2,8 @@ import time import warnings from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple, TypeVar, Union +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union import numpy as np import torch @@ -25,8 +26,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_input import (ModelInputForGPU, - ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_input import ModelInput from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) @@ -44,6 +44,71 @@ TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") +@dataclass(frozen=True) +class ModelInputForGPU(ModelInput): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + return ( + "input_tokens", + "input_positions", + "lora_requests", + "lora_mapping", + "multi_modal_kwargs", + ) + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = cls._init_attn_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + self._add_attn_metadata_broadcastable_dict(tensor_dict, + self.attn_metadata) + return tensor_dict + + +@dataclass(frozen=True) +class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + # Used for speculative decoding. We do not broadcast it because it is only + # used by the driver worker. + is_prompt: Optional[bool] = None + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = cls._init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + self._add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ Helper class for shared methods between GPU model runners. 
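The two dataclasses above, like CPUModelInput and ModelInputForXPU elsewhere in this patch, follow the same recipe: list the plainly-broadcastable fields in broadcastable_fields, and route AttentionMetadata and SamplingMetadata through the _init_* and _add_* helpers so they can be rebuilt on remote workers. A condensed sketch of that recipe for a hypothetical backend follows; MyBackendModelInput and its fields are illustrative only, not part of the series.

# Condensed, hypothetical example of the ModelInput subclassing recipe
# used in this patch; not code from the series.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch

from vllm.worker.model_input import ModelInput


@dataclass(frozen=True)
class MyBackendModelInput(ModelInput):
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None

    @property
    def broadcastable_fields(self) -> Tuple[str, ...]:
        # Only values that broadcast_tensor_dict can send as-is.
        return ("input_tokens", "input_positions")

    @classmethod
    def _get_init_kwargs(cls, **kwargs) -> Dict[str, Any]:
        # Rebuild the metadata objects from their broadcast pieces when
        # reconstructing the input on a remote worker.
        kwargs = cls._init_attn_metadata_from_kwargs(**kwargs)
        kwargs = cls._init_sampling_metadata_from_kwargs(**kwargs)
        return super()._get_init_kwargs(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = super().as_broadcastable_tensor_dict()
        self._add_attn_metadata_broadcastable_dict(tensor_dict,
                                                   self.attn_metadata)
        self._add_sampling_metadata_broadcastable_dict(
            tensor_dict, self.sampling_metadata)
        return tensor_dict

Because the dataclass is frozen, a runner extends a base input with later-stage metadata via replace(), as ModelRunner.prepare_model_input does in the next hunk.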
@@ -871,7 +936,9 @@ def prepare_model_input( model_input.query_lens, self.device, self.pin_memory) - return model_input.replace(sampling_metadata=sampling_metadata) + return model_input.replace( + sampling_metadata=sampling_metadata, + is_prompt=seq_group_metadata_list[0].is_prompt) @torch.inference_mode() def execute_model( @@ -880,28 +947,34 @@ def execute_model( kv_caches: List[torch.Tensor], ) -> SamplerOutput: if self.lora_config: - self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) # Currently cuda graph is only supported by the decode phase. assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None graph_batch_size = model_input.input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model + multi_modal_kwargs = model_input.multi_modal_kwargs or {} hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, - **model_input.multi_modal_kwargs, + **multi_modal_kwargs, ) # Compute the logits. - logits = self.model.compute_logits(hidden_states, model_input.sampling_metadata) + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) # Only perform sampling in the driver worker. if not self.is_driver_worker: @@ -915,10 +988,10 @@ def execute_model( if self.return_hidden_states: # we only need to pass hidden states of most recent token - assert seq_group_metadata_list is not None - if seq_group_metadata_list[0].is_prompt: + if model_input.is_prompt: + assert model_input.sampling_metadata is not None hidden_states = hidden_states.index_select( - 0, sampling_metadata.selected_token_indices) + 0, model_input.sampling_metadata.selected_token_indices) output.hidden_states = hidden_states return output diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 6060a70c95cc9..df742c8ae10b6 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Tuple +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -10,12 +11,27 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_input import ModelInputForNeuron +from vllm.worker.model_input import ModelInput from vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) +@dataclass(frozen=True) +class ModelInputForNeuron(ModelInput): + """ + Used by the NeuronModelRunner. 
+ """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + input_block_ids: Optional[torch.Tensor] = None + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + + class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): def __init__( diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 7a7ce28b4c9ca..63d9285075118 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Tuple +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -13,7 +14,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad -from vllm.worker.model_input import ModelInputForXPU +from vllm.worker.model_input import ModelInput from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ModelRunnerBase @@ -26,6 +27,34 @@ ] +@dataclass(frozen=True) +class ModelInputForXPU(ModelInput): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_input: Optional[Dict[str, torch.Tensor]] = None + + @classmethod + def _get_init_kwargs( # type: ignore + cls, **kwargs) -> Dict[str, Any]: + kwargs = cls._init_attn_metadata_from_kwargs(**kwargs) + kwargs = cls._init_sampling_metadata_from_kwargs(**kwargs) + return super()._get_init_kwargs(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = super().as_broadcastable_tensor_dict() + self._add_attn_metadata_broadcastable_dict(tensor_dict, + self.attn_metadata) + self._add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): def __init__( From dc2f1037d11234a62dc2f3a59c9eca09306bcd86 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 21 Jun 2024 16:21:00 -0700 Subject: [PATCH 43/55] x Signed-off-by: Stephanie Wang --- vllm/worker/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b52a506403709..0835366bf63c6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -289,7 +289,7 @@ def remove_lora(self, lora_id: int) -> bool: return self._model_runner.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) + return self._model_runner.pin_lora(lora_id) def list_loras(self) -> Set[int]: return self._model_runner.list_loras() From fbf074d689cfb69d7d33a5c195171c0af815c7ad Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 21 Jun 2024 17:40:36 -0700 Subject: [PATCH 44/55] x Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 2 +- vllm/spec_decode/mlp_speculator_worker.py | 3 +- vllm/worker/cpu_model_runner.py | 3 +- vllm/worker/cpu_worker.py | 3 +- vllm/worker/model_input.py | 159 ---------------------- vllm/worker/model_runner.py | 3 +- vllm/worker/model_runner_base.py | 155 
++++++++++++++++++++- vllm/worker/neuron_model_runner.py | 3 +- vllm/worker/neuron_worker.py | 3 +- vllm/worker/worker.py | 3 +- vllm/worker/worker_base.py | 63 ++++++++- vllm/worker/worker_input.py | 62 --------- vllm/worker/xpu_model_runner.py | 3 +- 13 files changed, 221 insertions(+), 244 deletions(-) delete mode 100644 vllm/worker/model_input.py delete mode 100644 vllm/worker/worker_input.py diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 0663357ec62ac..94f7bfa119c34 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -6,7 +6,7 @@ from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -from vllm.worker.model_input import ModelInputForGPUWithSamplingMetadata +from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata class MockAttentionBackend(AttentionBackend): diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 0926e13bedab1..6c1c8da57d188 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -7,7 +7,6 @@ SequenceGroupMetadata) from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase -from vllm.worker.model_runner import ModelInput class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): @@ -56,7 +55,7 @@ def _prepare_input_tensors( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, List[int], List[int]]: if not seq_group_metadata_list: - return ModelInput.empty(self.device) + return torch.empty(0, device=self.device), [], [] input_tokens: List[int] = [] seq_lens: List[int] = [] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index e92b17aa78f4f..2a233a2d2fe50 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -15,8 +15,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad -from vllm.worker.model_input import ModelInput -from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase logger = init_logger(__name__) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 442ce1dc0edb8..882c31ba3b04d 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -17,8 +17,7 @@ from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase) -from vllm.worker.worker_input import WorkerInput + LoraNotSupportedWorkerBase, WorkerInput) logger = init_logger(__name__) diff --git a/vllm/worker/model_input.py b/vllm/worker/model_input.py deleted file mode 100644 index 0e6017c00d9f2..0000000000000 --- a/vllm/worker/model_input.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Worker-local model inputs. 
These define the inputs to different model -runners.""" -import dataclasses -from typing import (TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, TypeVar, - Union) - -import torch - -if TYPE_CHECKING: - from vllm.attention import AttentionMetadata - from vllm.attention.backends.abstract import AttentionBackend - from vllm.model_executor import SamplingMetadata - -T = TypeVar('T', bound="ModelInput") - - -@dataclasses.dataclass(frozen=True) -class ModelInput: - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelInput objects. - - Model runners should define a ModelInput subclass and add their required - fields. For distributed executors, any fields that should be sent during a - broadcast op should also be added to the broadcastable_fields. During - execution, these fields will be extracted from the source copy and - broadcasted to all workers using broadcast_tensor_dict. - - Some fields may have values that cannot be broadcasted with this method - because they require some special serialization/deserialization, e.g., a - Python class like SamplingMetadata. For these fields, override - as_broadcastable_tensor_dict to return the custom serialized values and - override _get_init_kwargs to perform the custom deserialization ( - ModelInputForGPU for an example). - """ - - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - """ - Return fields to broadcast to all workers from driver. The value of - each field must be broadcastable using broadcast_tensor_dict (i.e. - either a tensor, or a Python primitive like int). During the broadcast, - the listed fields will be extracted from the source copy and then - passed to `new()` to create a copy on the destination(s). - """ - raise NotImplementedError() - - @classmethod - def _get_init_kwargs(cls: Type[T], **kwargs) -> Dict[str, Any]: - """ - Helper method to extract all dataclass fields from the given kwargs. - Override for fields that require some custom deserialization. - """ - init_kwargs = {} - for field in dataclasses.fields(cls): - val = kwargs.get(field.name, None) - if val is not None: - init_kwargs[field.name] = val - return init_kwargs - - @classmethod - def new(cls: Type[T], **kwargs) -> T: - """ - Create a new instance of this class. Populate the new instance with - the given kwargs. - """ - kwargs = cls._get_init_kwargs(**kwargs) - return cls(**kwargs) - - def replace(self: T, **kwargs) -> T: - """ - Replace current fields with fields in kwargs. - """ - valid_kwargs = self.__class__._get_init_kwargs(**kwargs) - return dataclasses.replace(self, **valid_kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. Override for fields that require some - custom deserialization. - """ - tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} - for field in self.broadcastable_fields: - val = getattr(self, field, None) - if val is not None: - tensor_dict[field] = val - - return tensor_dict - - @staticmethod - def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - attn_metadata: Optional["AttentionMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - AttentionMetadata fields. 
- """ - if attn_metadata is not None: - tensor_dict.update(attn_metadata.asdict_zerocopy()) - - @staticmethod - def _init_attn_metadata_from_kwargs( - attn_backend: Optional["AttentionBackend"] = None, - attn_metadata: Optional["AttentionMetadata"] = None, - **kwargs) -> Dict[str, Any]: - """ - Helper method to initialize AttentionMetadata based on an - AttentionBackend and broadcastable AttentionMetadata fields. - """ - if attn_metadata is None and attn_backend is not None: - # Extract the fields used to create AttentionMetadata. - valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = kwargs.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - if attn_metadata is not None: - kwargs["attn_metadata"] = attn_metadata - return kwargs - - @staticmethod - def _init_sampling_metadata_from_kwargs( # type: ignore - selected_token_indices: Optional[torch.Tensor] = None, - sampling_metadata: Optional["SamplingMetadata"] = None, - **kwargs) -> Dict[str, Any]: - """ - Helper method to initialize SamplingMetadata based on broadcastable - SamplingMetadata fields. - """ - if sampling_metadata is None and selected_token_indices is not None: - from vllm.model_executor import SamplingMetadata - - # An empty SamplingMetadata to signal that the worker should skip - # sampling. - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - if sampling_metadata is not None: - kwargs["sampling_metadata"] = sampling_metadata - return kwargs - - @staticmethod - def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - sampling_metadata: Optional["SamplingMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - SamplingMetadata fields. 
- """ - if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - sampling_metadata.selected_token_indices) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4003fdadcce28..861f9274d6d6e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -26,8 +26,7 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_input import ModelInput -from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase logger = init_logger(__name__) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 33dc2fb852b49..022ff8110a6b2 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -1,14 +1,165 @@ +import dataclasses from abc import ABC, abstractmethod -from typing import Generic, List, Optional, TypeVar +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, + Type, TypeVar, Union) import torch from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.worker.model_input import ModelInput + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend + from vllm.model_executor import SamplingMetadata T = TypeVar('T', bound="ModelInput") +@dataclasses.dataclass(frozen=True) +class ModelInput: + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelInput objects. + + Model runners should define a ModelInput subclass and add their required + fields. For distributed executors, any fields that should be sent during a + broadcast op should also be added to the broadcastable_fields. During + execution, these fields will be extracted from the source copy and + broadcasted to all workers using broadcast_tensor_dict. + + Some fields may have values that cannot be broadcasted with this method + because they require some special serialization/deserialization, e.g., a + Python class like SamplingMetadata. For these fields, override + as_broadcastable_tensor_dict to return the custom serialized values and + override _get_init_kwargs to perform the custom deserialization ( + ModelInputForGPU for an example). + """ + + @property + def broadcastable_fields(self) -> Tuple[str, ...]: + """ + Return fields to broadcast to all workers from driver. The value of + each field must be broadcastable using broadcast_tensor_dict (i.e. + either a tensor, or a Python primitive like int). During the broadcast, + the listed fields will be extracted from the source copy and then + passed to `new()` to create a copy on the destination(s). + """ + raise NotImplementedError() + + @classmethod + def _get_init_kwargs(cls: Type[T], **kwargs) -> Dict[str, Any]: + """ + Helper method to extract all dataclass fields from the given kwargs. + Override for fields that require some custom deserialization. + """ + init_kwargs = {} + for field in dataclasses.fields(cls): + val = kwargs.get(field.name, None) + if val is not None: + init_kwargs[field.name] = val + return init_kwargs + + @classmethod + def new(cls: Type[T], **kwargs) -> T: + """ + Create a new instance of this class. 
Populate the new instance with + the given kwargs. + """ + kwargs = cls._get_init_kwargs(**kwargs) + return cls(**kwargs) + + def replace(self: T, **kwargs) -> T: + """ + Replace current fields with fields in kwargs. + """ + valid_kwargs = self.__class__._get_init_kwargs(**kwargs) + return dataclasses.replace(self, **valid_kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} + for field in self.broadcastable_fields: + val = getattr(self, field, None) + if val is not None: + tensor_dict[field] = val + + return tensor_dict + + @staticmethod + def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + attn_metadata: Optional["AttentionMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + AttentionMetadata fields. + """ + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + @staticmethod + def _init_attn_metadata_from_kwargs( + attn_backend: Optional["AttentionBackend"] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + **kwargs) -> Dict[str, Any]: + """ + Helper method to initialize AttentionMetadata based on an + AttentionBackend and broadcastable AttentionMetadata fields. + """ + if attn_metadata is None and attn_backend is not None: + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = kwargs.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + if attn_metadata is not None: + kwargs["attn_metadata"] = attn_metadata + return kwargs + + @staticmethod + def _init_sampling_metadata_from_kwargs( # type: ignore + selected_token_indices: Optional[torch.Tensor] = None, + sampling_metadata: Optional["SamplingMetadata"] = None, + **kwargs) -> Dict[str, Any]: + """ + Helper method to initialize SamplingMetadata based on broadcastable + SamplingMetadata fields. + """ + if sampling_metadata is None and selected_token_indices is not None: + from vllm.model_executor import SamplingMetadata + + # An empty SamplingMetadata to signal that the worker should skip + # sampling. + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + if sampling_metadata is not None: + kwargs["sampling_metadata"] = sampling_metadata + return kwargs + + @staticmethod + def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + SamplingMetadata fields. 
+ """ + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) + + class ModelRunnerBase(ABC, Generic[T]): """ Model runner interface that abstracts a particular hardware and/or type of diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index df742c8ae10b6..3fda809dffb61 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -11,8 +11,7 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_input import ModelInput -from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase logger = init_logger(__name__) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index ed478c0ac1db6..2ae611f99a85b 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -11,8 +11,7 @@ from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase) -from vllm.worker.worker_input import WorkerInput + LoraNotSupportedWorkerBase, WorkerInput) class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0835366bf63c6..f0107c7b5c791 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -20,8 +20,7 @@ from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner_base import ModelRunnerBase -from vllm.worker.worker_base import LocalOrDistributedWorkerBase -from vllm.worker.worker_input import WorkerInput +from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput class Worker(LocalOrDistributedWorkerBase): diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b87420f044987..ff7562e06b901 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,7 +1,8 @@ +import dataclasses import importlib import os from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -11,9 +12,7 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_input import ModelInput -from vllm.worker.model_runner_base import ModelRunnerBase -from vllm.worker.worker_input import WorkerInput +from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase logger = init_logger(__name__) @@ -115,6 +114,62 @@ def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") +@dataclasses.dataclass(frozen=True) +class WorkerInput: + """Local inputs to each worker. May contain device-specific data. Different + worker backends may have different methods of converting from the global + ExecuteModelRequest produced by the LLM engine to the worker-local + WorkerInput objects. + + Subclasses of WorkerBase should inherit from this class and add their + required fields. 
For distributed executors, any fields that should be sent + during a broadcast op should also be added to the broadcastable_fields. + During execution, these fields will be extracted from the source copy and + broadcasted to all workers using broadcast_tensor_dict. + """ + + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + + @classmethod + def _get_init_kwargs(cls: Type["WorkerInput"], **kwargs) -> Dict[str, Any]: + """ + Helper method to extract all dataclass fields from the given kwargs. + Override for fields that require some custom deserialization. + """ + init_kwargs = {} + for field in dataclasses.fields(cls): + val = kwargs.get(field.name, None) + if val is not None: + init_kwargs[field.name] = val + return init_kwargs + + @classmethod + def new(cls: Type["WorkerInput"], **kwargs) -> "WorkerInput": + """ + Create a new instance of this class. Populate the new instance with + the given kwargs. + """ + kwargs = cls._get_init_kwargs(**kwargs) + return cls(**kwargs) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} + for field in dataclasses.fields(self): + val = getattr(self, field.name, None) + if val is not None: + tensor_dict[field.name] = val + + return tensor_dict + + class LocalOrDistributedWorkerBase(WorkerBase): """ Partial implementation of WorkerBase that has a default `execute_model` diff --git a/vllm/worker/worker_input.py b/vllm/worker/worker_input.py deleted file mode 100644 index f010270f816e1..0000000000000 --- a/vllm/worker/worker_input.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Worker-local model inputs. These define the inputs to different model -runners.""" -import dataclasses -from typing import Any, Dict, Optional, Type, Union - -import torch - - -@dataclasses.dataclass(frozen=True) -class WorkerInput: - """Local inputs to each worker. May contain device-specific data. Different - worker backends may have different methods of converting from the global - ExecuteModelRequest produced by the LLM engine to the worker-local - WorkerInput objects. - - Subclasses of WorkerBase should inherit from this class and add their - required fields. For distributed executors, any fields that should be sent - during a broadcast op should also be added to the broadcastable_fields. - During execution, these fields will be extracted from the source copy and - broadcasted to all workers using broadcast_tensor_dict. - """ - - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - - @classmethod - def _get_init_kwargs(cls: Type["WorkerInput"], **kwargs) -> Dict[str, Any]: - """ - Helper method to extract all dataclass fields from the given kwargs. - Override for fields that require some custom deserialization. - """ - init_kwargs = {} - for field in dataclasses.fields(cls): - val = kwargs.get(field.name, None) - if val is not None: - init_kwargs[field.name] = val - return init_kwargs - - @classmethod - def new(cls: Type["WorkerInput"], **kwargs) -> "WorkerInput": - """ - Create a new instance of this class. Populate the new instance with - the given kwargs. 
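As a quick illustration of the WorkerInput contract introduced above (a sketch with made-up values; only fields that are set survive the round trip):

    import torch

    # Driver side: build the worker input and serialize it.
    worker_input = WorkerInput.new(
        num_seq_groups=2,
        blocks_to_swap_in=torch.empty(0, dtype=torch.int64),
        blocks_to_copy=torch.tensor([[0, 1]], dtype=torch.int64),
    )
    payload = worker_input.as_broadcastable_tensor_dict()
    assert "blocks_to_swap_out" not in payload  # unset (None) fields are skipped

    # Worker side: rebuild the dataclass from the broadcast payload.
    rebuilt = WorkerInput.new(**payload)
    assert rebuilt.num_seq_groups == 2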
- """ - kwargs = cls._get_init_kwargs(**kwargs) - return cls(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. Override for fields that require some - custom deserialization. - """ - tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} - for field in dataclasses.fields(self): - val = getattr(self, field.name, None) - if val is not None: - tensor_dict[field.name] = val - - return tensor_dict diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 63d9285075118..764f3f9ece931 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -14,9 +14,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad -from vllm.worker.model_input import ModelInput from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata -from vllm.worker.model_runner_base import ModelRunnerBase +from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase logger = init_logger(__name__) From 660a8d51ae165290d5c2aa24d47bdd42db8ffe6e Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Sun, 23 Jun 2024 14:10:12 -0700 Subject: [PATCH 45/55] rename ModelInput -> ModelInputBase, override as_broadcastable_tensor_dict and new Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 82 ++++++++++-- vllm/worker/cpu_model_runner.py | 48 ++++--- vllm/worker/model_runner.py | 60 +++++---- vllm/worker/model_runner_base.py | 207 ++++++++++++----------------- vllm/worker/neuron_model_runner.py | 16 ++- vllm/worker/worker_base.py | 4 +- vllm/worker/xpu_model_runner.py | 43 +++--- 7 files changed, 256 insertions(+), 204 deletions(-) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 94f7bfa119c34..4737fab4004c9 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -6,6 +6,9 @@ from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.worker.embedding_model_runner import ( + ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata @@ -48,7 +51,7 @@ def copy_blocks( pass -def test_gpu_model_input(): +def test_model_runner_input(): sampling_metadata = SamplingMetadata( ["seq_group"], "selected_token_indices", @@ -62,7 +65,8 @@ def test_gpu_model_input(): slot_mapping=torch.zeros(1), ) model_input = ModelInputForGPUWithSamplingMetadata.new( - num_seq_groups=10, + input_tokens=torch.ones(10), + input_positions=torch.ones(10), sampling_metadata=sampling_metadata, attn_metadata=attn_metadata) @@ -73,20 +77,22 @@ def test_gpu_model_input(): attn_backend = MockAttentionBackend() received_model_input = ModelInputForGPUWithSamplingMetadata.new( attn_backend=attn_backend, **tensor_dict) + # Check that received copy has correct values. assert isinstance(received_model_input, ModelInputForGPUWithSamplingMetadata) - - # Broadcast should not contain empty values. - for field in dataclasses.fields(model_input): - if getattr(model_input, field.name) is None: - assert field.name not in tensor_dict - # Broadcast should contain all non-empty fields defined by the developer - # for this input type. 
- for field_name in model_input.broadcastable_fields: - if getattr(model_input, field_name, None) is not None: - assert field_name in tensor_dict - - # Check that received copy has correct values. + assert received_model_input.input_tokens is not None + assert ( + received_model_input.input_tokens == model_input.input_tokens).all() + assert received_model_input.input_positions is not None + assert (received_model_input.input_positions == model_input.input_positions + ).all() + assert received_model_input.multi_modal_kwargs is None + assert (received_model_input.multi_modal_kwargs == + model_input.multi_modal_kwargs) + assert received_model_input.lora_requests is None + assert received_model_input.lora_requests == model_input.lora_requests + assert received_model_input.lora_mapping is None + assert received_model_input.lora_mapping == model_input.lora_mapping for field in dataclasses.fields(AttentionMetadata): assert getattr(received_model_input.attn_metadata, field.name, None) == getattr(attn_metadata, field.name, None) @@ -94,3 +100,51 @@ def test_gpu_model_input(): assert (received_model_input.sampling_metadata.selected_token_indices == sampling_metadata.selected_token_indices) assert received_model_input.sampling_metadata.seq_groups is None + + +def test_embedding_model_runner_input(): + pooling_metadata = PoolingMetadata( + seq_groups=[[0]], + seq_data={}, + prompt_lens=[1], + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + model_input = ModelInputForGPUWithPoolingMetadata.new( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + pooling_metadata=pooling_metadata, + attn_metadata=attn_metadata) + + assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = ModelInputForGPUWithPoolingMetadata.new( + attn_backend=attn_backend, **tensor_dict) + # Check that received copy has correct values. + assert isinstance(received_model_input, + ModelInputForGPUWithPoolingMetadata) + assert received_model_input.input_tokens is not None + assert ( + received_model_input.input_tokens == model_input.input_tokens).all() + assert received_model_input.input_positions is not None + assert (received_model_input.input_positions == model_input.input_positions + ).all() + assert received_model_input.multi_modal_kwargs is None + assert (received_model_input.multi_modal_kwargs == + model_input.multi_modal_kwargs) + assert received_model_input.lora_requests is None + assert received_model_input.lora_requests == model_input.lora_requests + assert received_model_input.lora_mapping is None + assert received_model_input.lora_mapping == model_input.lora_mapping + for field in dataclasses.fields(AttentionMetadata): + assert getattr(received_model_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # Pooling metadata is not broadcast. 
+ assert received_model_input.pooling_metadata is None diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 2a233a2d2fe50..3002936f88c6e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,6 +1,6 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -15,7 +15,13 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad -from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase +from vllm.worker.model_runner_base import ( + ModelInputBase, ModelRunnerBase, _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, + _init_sampling_metadata_from_kwargs) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -23,7 +29,7 @@ @dataclass(frozen=True) -class CPUModelInput(ModelInput): +class CPUModelInput(ModelInputBase): """ Used by the CPUModelRunner. """ @@ -33,29 +39,29 @@ class CPUModelInput(ModelInput): sampling_metadata: Optional["SamplingMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - return ( + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = self._get_attrs([ "input_tokens", "input_positions", "multi_modal_kwargs", - ) + ]) + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = cls._init_attn_metadata_from_kwargs(**kwargs) - kwargs = cls._init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - self._add_attn_metadata_broadcastable_dict(tensor_dict, - self.attn_metadata) - self._add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict + def new(cls, + attn_backend: Optional["AttentionBackend"] = None, + selected_token_indices: Optional[torch.Tensor] = None, + **kwargs) -> "CPUModelInput": + if attn_backend is not None: + kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) + if selected_token_indices is not None: + kwargs = _init_sampling_metadata_from_kwargs( + selected_token_indices, **kwargs) + return cls(**kwargs) class CPUModelRunner(ModelRunnerBase[CPUModelInput]): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 861f9274d6d6e..43c08d8fce255 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -3,7 +3,8 @@ import warnings from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union +from typing import (TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -26,7 +27,13 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, 
is_pin_memory_available, make_tensor_with_pad) -from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase +from vllm.worker.model_runner_base import ( + ModelInputBase, ModelRunnerBase, _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, + _init_sampling_metadata_from_kwargs) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -44,7 +51,7 @@ @dataclass(frozen=True) -class ModelInputForGPU(ModelInput): +class ModelInputForGPU(ModelInputBase): """ This base class contains metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. Model @@ -60,28 +67,25 @@ class ModelInputForGPU(ModelInput): attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - return ( + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = self._get_attrs([ "input_tokens", "input_positions", "lora_requests", "lora_mapping", "multi_modal_kwargs", - ) + ]) + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = cls._init_attn_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - self._add_attn_metadata_broadcastable_dict(tensor_dict, - self.attn_metadata) - return tensor_dict + def new(cls: Type[TModelInputForGPU], + attn_backend: Optional["AttentionBackend"] = None, + **kwargs) -> TModelInputForGPU: + if attn_backend is not None: + kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) + return cls(**kwargs) @dataclass(frozen=True) @@ -94,19 +98,23 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): # used by the driver worker. 
is_prompt: Optional[bool] = None - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = cls._init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: tensor_dict = super().as_broadcastable_tensor_dict() - self._add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) return tensor_dict + @classmethod + def new(cls, + attn_backend: Optional["AttentionBackend"] = None, + selected_token_indices: Optional[torch.Tensor] = None, + **kwargs) -> "ModelInputForGPUWithSamplingMetadata": + if selected_token_indices is not None: + kwargs = _init_sampling_metadata_from_kwargs( + selected_token_indices, **kwargs) + return super().new(attn_backend, **kwargs) + class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 022ff8110a6b2..1f4db2cdd0139 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -1,7 +1,7 @@ import dataclasses from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, - Type, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, + TypeVar, Union) import torch @@ -12,69 +12,97 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -T = TypeVar('T', bound="ModelInput") +T = TypeVar('T', bound="ModelInputBase") + + +def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + attn_metadata: Optional["AttentionMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + AttentionMetadata fields. + """ + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + +def _init_attn_metadata_from_kwargs(attn_backend: "AttentionBackend", + **kwargs) -> Dict[str, Any]: + """ + Helper method to initialize AttentionMetadata based on an + AttentionBackend and broadcastable AttentionMetadata fields. + """ + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = kwargs.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + kwargs["attn_metadata"] = attn_metadata + return kwargs + + +def _init_sampling_metadata_from_kwargs( # type: ignore + selected_token_indices: torch.Tensor = None, + **kwargs) -> Dict[str, Any]: + """ + Helper method to initialize SamplingMetadata based on broadcastable + SamplingMetadata fields. + """ + from vllm.model_executor import SamplingMetadata + + # An empty SamplingMetadata to signal that the worker should skip + # sampling. + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + kwargs["sampling_metadata"] = sampling_metadata + return kwargs + + +def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Union[int, torch.Tensor]], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + SamplingMetadata fields. 
+ """ + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) @dataclasses.dataclass(frozen=True) -class ModelInput: +class ModelInputBase(ABC): """Local inputs to each worker's model runner. May contain device-specific data. Different worker backends may have different methods of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelInput objects. - - Model runners should define a ModelInput subclass and add their required - fields. For distributed executors, any fields that should be sent during a - broadcast op should also be added to the broadcastable_fields. During - execution, these fields will be extracted from the source copy and - broadcasted to all workers using broadcast_tensor_dict. - - Some fields may have values that cannot be broadcasted with this method - because they require some special serialization/deserialization, e.g., a - Python class like SamplingMetadata. For these fields, override - as_broadcastable_tensor_dict to return the custom serialized values and - override _get_init_kwargs to perform the custom deserialization ( - ModelInputForGPU for an example). - """ - - @property - def broadcastable_fields(self) -> Tuple[str, ...]: - """ - Return fields to broadcast to all workers from driver. The value of - each field must be broadcastable using broadcast_tensor_dict (i.e. - either a tensor, or a Python primitive like int). During the broadcast, - the listed fields will be extracted from the source copy and then - passed to `new()` to create a copy on the destination(s). - """ - raise NotImplementedError() + engine to the worker-local ModelInputBase objects. - @classmethod - def _get_init_kwargs(cls: Type[T], **kwargs) -> Dict[str, Any]: - """ - Helper method to extract all dataclass fields from the given kwargs. - Override for fields that require some custom deserialization. - """ - init_kwargs = {} - for field in dataclasses.fields(cls): - val = kwargs.get(field.name, None) - if val is not None: - init_kwargs[field.name] = val - return init_kwargs + Model runners that support multi-GPU execution should define a + ModelInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ @classmethod + @abstractmethod def new(cls: Type[T], **kwargs) -> T: """ Create a new instance of this class. Populate the new instance with the given kwargs. """ - kwargs = cls._get_init_kwargs(**kwargs) - return cls(**kwargs) + raise NotImplementedError def replace(self: T, **kwargs) -> T: """ Replace current fields with fields in kwargs. """ - valid_kwargs = self.__class__._get_init_kwargs(**kwargs) - return dataclasses.replace(self, **valid_kwargs) + return dataclasses.replace(self, **kwargs) def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -82,83 +110,22 @@ def as_broadcastable_tensor_dict( Extract broadcastable fields. Override for fields that require some custom deserialization. """ + raise NotImplementedError + + def _get_attrs(self, attrs: List[str]) -> Dict[str, Any]: + """ + Helper method to get a dictionary from attribute name to value. + Attributes whose values are None will not be added to the returned + dictionary. 
+ """ tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} - for field in self.broadcastable_fields: - val = getattr(self, field, None) + for attr in attrs: + val = getattr(self, attr, None) if val is not None: - tensor_dict[field] = val + tensor_dict[attr] = val return tensor_dict - @staticmethod - def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - attn_metadata: Optional["AttentionMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - AttentionMetadata fields. - """ - if attn_metadata is not None: - tensor_dict.update(attn_metadata.asdict_zerocopy()) - - @staticmethod - def _init_attn_metadata_from_kwargs( - attn_backend: Optional["AttentionBackend"] = None, - attn_metadata: Optional["AttentionMetadata"] = None, - **kwargs) -> Dict[str, Any]: - """ - Helper method to initialize AttentionMetadata based on an - AttentionBackend and broadcastable AttentionMetadata fields. - """ - if attn_metadata is None and attn_backend is not None: - # Extract the fields used to create AttentionMetadata. - valid_attn_kwargs = {} - for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = kwargs.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val - - attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - if attn_metadata is not None: - kwargs["attn_metadata"] = attn_metadata - return kwargs - - @staticmethod - def _init_sampling_metadata_from_kwargs( # type: ignore - selected_token_indices: Optional[torch.Tensor] = None, - sampling_metadata: Optional["SamplingMetadata"] = None, - **kwargs) -> Dict[str, Any]: - """ - Helper method to initialize SamplingMetadata based on broadcastable - SamplingMetadata fields. - """ - if sampling_metadata is None and selected_token_indices is not None: - from vllm.model_executor import SamplingMetadata - - # An empty SamplingMetadata to signal that the worker should skip - # sampling. - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - if sampling_metadata is not None: - kwargs["sampling_metadata"] = sampling_metadata - return kwargs - - @staticmethod - def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], - sampling_metadata: Optional["SamplingMetadata"]) -> None: - """ - Helper method to update tensor_dict with broadcastable - SamplingMetadata fields. - """ - if sampling_metadata is not None: - tensor_dict["selected_token_indices"] = ( - sampling_metadata.selected_token_indices) - class ModelRunnerBase(ABC, Generic[T]): """ @@ -166,7 +133,7 @@ class ModelRunnerBase(ABC, Generic[T]): model. Model execution may communicate data with model runners in other processes, but it should not include control plane metadata communication. - Each ModelRunnerBase subclass should define a corresponding ModelInput + Each ModelRunnerBase subclass should define a corresponding ModelInputBase subclass. """ @@ -175,7 +142,7 @@ def make_model_input(self, make_attn_metadata: bool = False, **model_input_fields) -> T: """ - Make an instance of a ModelInput from the given fields. If + Make an instance of a ModelInputBase from the given fields. If make_attn_metadata=True, then AttentionMetadata will be created from fields extracted from model_input_fields. 
""" diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 3fda809dffb61..51cf271627168 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -11,13 +11,13 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase +from vllm.worker.model_runner_base import ModelInputBase, ModelRunnerBase logger = init_logger(__name__) @dataclass(frozen=True) -class ModelInputForNeuron(ModelInput): +class ModelInputForNeuron(ModelInputBase): """ Used by the NeuronModelRunner. """ @@ -30,6 +30,10 @@ def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + @classmethod + def new(cls, **kwargs) -> "ModelInputForNeuron": + return cls(**kwargs) + class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): @@ -186,10 +190,10 @@ def prepare_model_input( self.device, self.pin_memory) - return ModelInputForNeuron(input_tokens=input_tokens, - input_positions=input_positions, - input_block_ids=input_block_ids, - sampling_metadata=sampling_metadata) + return ModelInputForNeuron.new(input_tokens=input_tokens, + input_positions=input_positions, + input_block_ids=input_block_ids, + sampling_metadata=sampling_metadata) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index ff7562e06b901..46f808dfcff3b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -12,7 +12,7 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase +from vllm.worker.model_runner_base import ModelInputBase, ModelRunnerBase logger = init_logger(__name__) @@ -258,7 +258,7 @@ def execute_model( worker_input: WorkerInput = self.prepare_worker_input( execute_model_req=execute_model_req) - model_input: ModelInput = self.model_runner.prepare_model_input( + model_input: ModelInputBase = self.model_runner.prepare_model_input( execute_model_req.seq_group_metadata_list) if self.do_metadata_broadcast: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 764f3f9ece931..ba13da5f3be17 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -15,7 +15,13 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata -from vllm.worker.model_runner_base import ModelInput, ModelRunnerBase +from vllm.worker.model_runner_base import ( + ModelInputBase, ModelRunnerBase, _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, + _init_sampling_metadata_from_kwargs) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -27,7 +33,7 @@ @dataclass(frozen=True) -class ModelInputForXPU(ModelInput): 
+class ModelInputForXPU(ModelInputBase): """ Used by the NeuronModelRunner. """ @@ -37,22 +43,29 @@ class ModelInputForXPU(ModelInput): sampling_metadata: Optional["SamplingMetadata"] = None multi_modal_input: Optional[Dict[str, torch.Tensor]] = None - @classmethod - def _get_init_kwargs( # type: ignore - cls, **kwargs) -> Dict[str, Any]: - kwargs = cls._init_attn_metadata_from_kwargs(**kwargs) - kwargs = cls._init_sampling_metadata_from_kwargs(**kwargs) - return super()._get_init_kwargs(**kwargs) - def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() - self._add_attn_metadata_broadcastable_dict(tensor_dict, - self.attn_metadata) - self._add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) + tensor_dict = self._get_attrs([ + "input_tokens", + "input_positions", + ]) + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) return tensor_dict + @classmethod + def new(cls, + attn_backend: Optional["AttentionBackend"] = None, + selected_token_indices: Optional[torch.Tensor] = None, + **kwargs) -> "ModelInputForXPU": + if attn_backend is not None: + kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) + if selected_token_indices is not None: + kwargs = _init_sampling_metadata_from_kwargs( + selected_token_indices, **kwargs) + return cls(**kwargs) + class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): From 8cca6348ee3ef60fe0b682e75efcabc01d08b3a9 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Sun, 23 Jun 2024 15:53:22 -0700 Subject: [PATCH 46/55] fixes Signed-off-by: Stephanie Wang --- vllm/worker/embedding_model_runner.py | 5 +--- vllm/worker/model_runner.py | 3 ++- vllm/worker/worker_base.py | 36 +++++++-------------------- vllm/worker/xpu_worker.py | 6 ++--- 4 files changed, 15 insertions(+), 35 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 1f69d76d673e7..4720259c55347 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -99,10 +99,7 @@ def make_model_input(self, **kwargs) -> ModelInputForGPUWithPoolingMetadata: if make_attn_metadata: kwargs["attn_backend"] = self.attn_backend - return ModelInputForGPUWithPoolingMetadata.new( - attn_backend=self.attn_backend, - **kwargs, - ) + return ModelInputForGPUWithPoolingMetadata.new(**kwargs, ) def prepare_model_input( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 43c08d8fce255..a18e09e1380c9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -950,7 +950,8 @@ def prepare_model_input( self.pin_memory) return model_input.replace( sampling_metadata=sampling_metadata, - is_prompt=seq_group_metadata_list[0].is_prompt) + is_prompt=seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 46f808dfcff3b..2c963e505b7d6 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -2,7 +2,7 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -116,16 +116,8 @@ def list_loras(self) -> Set[int]: @dataclasses.dataclass(frozen=True) class WorkerInput: - 
"""Local inputs to each worker. May contain device-specific data. Different - worker backends may have different methods of converting from the global - ExecuteModelRequest produced by the LLM engine to the worker-local - WorkerInput objects. - - Subclasses of WorkerBase should inherit from this class and add their - required fields. For distributed executors, any fields that should be sent - during a broadcast op should also be added to the broadcastable_fields. - During execution, these fields will be extracted from the source copy and - broadcasted to all workers using broadcast_tensor_dict. + """Local inputs to each worker. May contain device-specific data. These + fields should be broadcastable to other workers. """ num_seq_groups: Optional[int] = None @@ -134,32 +126,22 @@ class WorkerInput: blocks_to_copy: Optional[torch.Tensor] = None @classmethod - def _get_init_kwargs(cls: Type["WorkerInput"], **kwargs) -> Dict[str, Any]: + def new(cls: Type["WorkerInput"], **kwargs) -> "WorkerInput": """ - Helper method to extract all dataclass fields from the given kwargs. - Override for fields that require some custom deserialization. + Create a new instance of this class. Populate the new instance with + fields popped from the given kwargs. """ init_kwargs = {} for field in dataclasses.fields(cls): - val = kwargs.get(field.name, None) + val = kwargs.pop(field.name, None) if val is not None: init_kwargs[field.name] = val - return init_kwargs - - @classmethod - def new(cls: Type["WorkerInput"], **kwargs) -> "WorkerInput": - """ - Create a new instance of this class. Populate the new instance with - the given kwargs. - """ - kwargs = cls._get_init_kwargs(**kwargs) - return cls(**kwargs) + return cls(**init_kwargs) def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: """ - Extract broadcastable fields. Override for fields that require some - custom deserialization. + Extract broadcastable fields. """ tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} for field in dataclasses.fields(self): diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index c03595bc9787f..18e8ab72fad48 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -47,7 +47,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, - _is_driver_worker: bool = False, + is_driver_worker: bool = False, ) -> None: assert device_config.device_type == "xpu" assert is_xpu() @@ -62,7 +62,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config - self._is_driver_worker = _is_driver_worker + self._is_driver_worker = is_driver_worker if self._is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -80,7 +80,7 @@ def __init__( load_config=self.load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, - _is_driver_worker=_is_driver_worker, + is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, ) # Uninitialized cache engine. 
Will be initialized by From 0a25c19d64480f708fe48898c763f538fc3e5e29 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Sun, 23 Jun 2024 15:55:38 -0700 Subject: [PATCH 47/55] rename Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 5 +++-- vllm/worker/model_runner.py | 5 +++-- vllm/worker/model_runner_base.py | 14 +++++++------- vllm/worker/neuron_model_runner.py | 4 ++-- vllm/worker/worker_base.py | 7 ++++--- vllm/worker/xpu_model_runner.py | 5 +++-- 6 files changed, 22 insertions(+), 18 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 3002936f88c6e..35357e82fd2e4 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -16,7 +16,8 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( - ModelInputBase, ModelRunnerBase, _add_attn_metadata_broadcastable_dict, + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) @@ -29,7 +30,7 @@ @dataclass(frozen=True) -class CPUModelInput(ModelInputBase): +class CPUModelInput(ModelRunnerInputBase): """ Used by the CPUModelRunner. """ diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a18e09e1380c9..c4462bc964a8f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -28,7 +28,8 @@ from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available, make_tensor_with_pad) from vllm.worker.model_runner_base import ( - ModelInputBase, ModelRunnerBase, _add_attn_metadata_broadcastable_dict, + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) @@ -51,7 +52,7 @@ @dataclass(frozen=True) -class ModelInputForGPU(ModelInputBase): +class ModelInputForGPU(ModelRunnerInputBase): """ This base class contains metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. Model diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 1f4db2cdd0139..85a138dc3e478 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -12,7 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -T = TypeVar('T', bound="ModelInputBase") +T = TypeVar('T', bound="ModelRunnerInputBase") def _add_attn_metadata_broadcastable_dict( @@ -78,14 +78,14 @@ def _add_sampling_metadata_broadcastable_dict( @dataclasses.dataclass(frozen=True) -class ModelInputBase(ABC): +class ModelRunnerInputBase(ABC): """Local inputs to each worker's model runner. May contain device-specific data. Different worker backends may have different methods of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelInputBase objects. + engine to the worker-local ModelRunnerInputBase objects. Model runners that support multi-GPU execution should define a - ModelInputBase subclass, add their required fields, and specify how to + ModelRunnerInputBase subclass, add their required fields, and specify how to serialize/deserialize a ModelInput for broadcast between workers. """ @@ -133,8 +133,8 @@ class ModelRunnerBase(ABC, Generic[T]): model. 
Model execution may communicate data with model runners in other processes, but it should not include control plane metadata communication. - Each ModelRunnerBase subclass should define a corresponding ModelInputBase - subclass. + Each ModelRunnerBase subclass should define a corresponding + ModelRunnerInputBase subclass. """ @abstractmethod @@ -142,7 +142,7 @@ def make_model_input(self, make_attn_metadata: bool = False, **model_input_fields) -> T: """ - Make an instance of a ModelInputBase from the given fields. If + Make an instance of a ModelRunnerInputBase from the given fields. If make_attn_metadata=True, then AttentionMetadata will be created from fields extracted from model_input_fields. """ diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 51cf271627168..1481372906bc1 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -11,13 +11,13 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_runner_base import ModelInputBase, ModelRunnerBase +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase logger = init_logger(__name__) @dataclass(frozen=True) -class ModelInputForNeuron(ModelInputBase): +class ModelInputForNeuron(ModelRunnerInputBase): """ Used by the NeuronModelRunner. """ diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 2c963e505b7d6..8289fbb026c89 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -12,7 +12,7 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner_base import ModelInputBase, ModelRunnerBase +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase logger = init_logger(__name__) @@ -240,8 +240,9 @@ def execute_model( worker_input: WorkerInput = self.prepare_worker_input( execute_model_req=execute_model_req) - model_input: ModelInputBase = self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) if self.do_metadata_broadcast: broadcast_data = worker_input.as_broadcastable_tensor_dict() diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index ba13da5f3be17..847ef6c5e21fc 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -16,7 +16,8 @@ from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import ( - ModelInputBase, ModelRunnerBase, _add_attn_metadata_broadcastable_dict, + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) @@ -33,7 +34,7 @@ @dataclass(frozen=True) -class ModelInputForXPU(ModelInputBase): +class ModelInputForXPU(ModelRunnerInputBase): """ Used by the NeuronModelRunner. 
""" From 0b26877fb6e57c84d6265508065c1516a1a6f5b8 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 24 Jun 2024 10:48:53 -0700 Subject: [PATCH 48/55] fix Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 5 +++-- vllm/worker/model_runner.py | 5 +++-- vllm/worker/model_runner_base.py | 13 +++++++++++++ vllm/worker/neuron_model_runner.py | 5 ++++- vllm/worker/worker_base.py | 4 ++-- vllm/worker/xpu_model_runner.py | 5 +++-- 6 files changed, 28 insertions(+), 9 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 35357e82fd2e4..009f1d6a390d0 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -18,8 +18,8 @@ from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, - _init_sampling_metadata_from_kwargs) + _add_sampling_metadata_broadcastable_dict, _filter_valid_kwargs, + _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -62,6 +62,7 @@ def new(cls, if selected_token_indices is not None: kwargs = _init_sampling_metadata_from_kwargs( selected_token_indices, **kwargs) + kwargs = _filter_valid_kwargs(cls, kwargs) return cls(**kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c4462bc964a8f..b831bc7f3682d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -30,8 +30,8 @@ from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, - _init_sampling_metadata_from_kwargs) + _add_sampling_metadata_broadcastable_dict, _filter_valid_kwargs, + _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -86,6 +86,7 @@ def new(cls: Type[TModelInputForGPU], **kwargs) -> TModelInputForGPU: if attn_backend is not None: kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) + kwargs = _filter_valid_kwargs(cls, kwargs) return cls(**kwargs) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 85a138dc3e478..339de899243fa 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -77,6 +77,19 @@ def _add_sampling_metadata_broadcastable_dict( sampling_metadata.selected_token_indices) +def _filter_valid_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper method to filter the given kwargs to kwargs that + are valid for the given dataclass `cls`. + """ + init_kwargs = {} + for field in dataclasses.fields(cls): + val = kwargs.get(field.name, None) + if val is not None: + init_kwargs[field.name] = val + return init_kwargs + + @dataclasses.dataclass(frozen=True) class ModelRunnerInputBase(ABC): """Local inputs to each worker's model runner. 
May contain diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 1481372906bc1..a924356f60580 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -11,7 +11,9 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + _filter_valid_kwargs) logger = init_logger(__name__) @@ -32,6 +34,7 @@ def as_broadcastable_tensor_dict( @classmethod def new(cls, **kwargs) -> "ModelInputForNeuron": + kwargs = _filter_valid_kwargs(cls, kwargs) return cls(**kwargs) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8289fbb026c89..f18d4b093cd4a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -129,11 +129,11 @@ class WorkerInput: def new(cls: Type["WorkerInput"], **kwargs) -> "WorkerInput": """ Create a new instance of this class. Populate the new instance with - fields popped from the given kwargs. + fields from the given kwargs. """ init_kwargs = {} for field in dataclasses.fields(cls): - val = kwargs.pop(field.name, None) + val = kwargs.get(field.name, None) if val is not None: init_kwargs[field.name] = val return cls(**init_kwargs) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 847ef6c5e21fc..764c87eea3217 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -18,8 +18,8 @@ from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_kwargs, - _init_sampling_metadata_from_kwargs) + _add_sampling_metadata_broadcastable_dict, _filter_valid_kwargs, + _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -65,6 +65,7 @@ def new(cls, if selected_token_indices is not None: kwargs = _init_sampling_metadata_from_kwargs( selected_token_indices, **kwargs) + kwargs = _filter_valid_kwargs(cls, kwargs) return cls(**kwargs) From e7052d57fe5f1c0fd84557acd57581c0e7202ccc Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 24 Jun 2024 11:26:22 -0700 Subject: [PATCH 49/55] do not filter Nones Signed-off-by: Stephanie Wang --- vllm/worker/cpu_model_runner.py | 10 +++++----- vllm/worker/model_runner.py | 14 +++++++------- vllm/worker/model_runner_base.py | 14 -------------- vllm/worker/xpu_model_runner.py | 11 ++++++----- 4 files changed, 18 insertions(+), 31 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 009f1d6a390d0..a53b1a82041ce 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -42,11 +42,11 @@ class CPUModelInput(ModelRunnerInputBase): def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = self._get_attrs([ - "input_tokens", - "input_positions", - "multi_modal_kwargs", - ]) + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, + } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, 
self.sampling_metadata) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b831bc7f3682d..2bd5fd6bc2907 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -70,13 +70,13 @@ class ModelInputForGPU(ModelRunnerInputBase): def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = self._get_attrs([ - "input_tokens", - "input_positions", - "lora_requests", - "lora_mapping", - "multi_modal_kwargs", - ]) + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 339de899243fa..4d9e18cfe0135 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -125,20 +125,6 @@ def as_broadcastable_tensor_dict( """ raise NotImplementedError - def _get_attrs(self, attrs: List[str]) -> Dict[str, Any]: - """ - Helper method to get a dictionary from attribute name to value. - Attributes whose values are None will not be added to the returned - dictionary. - """ - tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} - for attr in attrs: - val = getattr(self, attr, None) - if val is not None: - tensor_dict[attr] = val - - return tensor_dict - class ModelRunnerBase(ABC, Generic[T]): """ diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 764c87eea3217..98f30bc73915c 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -46,10 +46,10 @@ class ModelInputForXPU(ModelRunnerInputBase): def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = self._get_attrs([ - "input_tokens", - "input_positions", - ]) + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, self.sampling_metadata) @@ -175,7 +175,8 @@ def profile_run(self) -> None: # Run the model with the dummy inputs. 
num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [None] * num_layers - self.execute_model(seqs, kv_caches) + model_input = self.prepare_model_input(seqs) + self.execute_model(model_input, kv_caches) torch.xpu.synchronize() return From df5551ff3c0023154d06266fef266800d76e6438 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 24 Jun 2024 11:29:38 -0700 Subject: [PATCH 50/55] dupe Signed-off-by: Stephanie Wang --- vllm/worker/model_runner.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2bd5fd6bc2907..af79e0e4deaf3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -102,7 +102,14 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = super().as_broadcastable_tensor_dict() + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, self.sampling_metadata) return tensor_dict @@ -115,7 +122,10 @@ def new(cls, if selected_token_indices is not None: kwargs = _init_sampling_metadata_from_kwargs( selected_token_indices, **kwargs) - return super().new(attn_backend, **kwargs) + if attn_backend is not None: + kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) + kwargs = _filter_valid_kwargs(cls, kwargs) + return cls(**kwargs) class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): From 6745b3b27e1a10533786951414f1f79942bb8685 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 24 Jun 2024 17:29:08 -0700 Subject: [PATCH 51/55] update Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 12 ++-- vllm/attention/backends/ipex_attn.py | 4 +- vllm/attention/backends/pallas.py | 4 +- vllm/worker/cpu_model_runner.py | 42 +++++++------ vllm/worker/embedding_model_runner.py | 18 +++--- vllm/worker/model_runner.py | 82 +++++++++++++----------- vllm/worker/model_runner_base.py | 90 ++++++++++++--------------- vllm/worker/neuron_model_runner.py | 34 +++++----- vllm/worker/worker_base.py | 42 +++++++------ vllm/worker/xpu_model_runner.py | 49 +++++++-------- 10 files changed, 194 insertions(+), 183 deletions(-) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 4737fab4004c9..e8fcb1aab5528 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -64,7 +64,7 @@ def test_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), ) - model_input = ModelInputForGPUWithSamplingMetadata.new( + model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), input_positions=torch.ones(10), sampling_metadata=sampling_metadata, @@ -75,8 +75,8 @@ def test_model_runner_input(): # Test round trip serialization. tensor_dict = model_input.as_broadcastable_tensor_dict() attn_backend = MockAttentionBackend() - received_model_input = ModelInputForGPUWithSamplingMetadata.new( - attn_backend=attn_backend, **tensor_dict) + received_model_input = ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend) # Check that received copy has correct values. 
assert isinstance(received_model_input, ModelInputForGPUWithSamplingMetadata) @@ -114,7 +114,7 @@ def test_embedding_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), ) - model_input = ModelInputForGPUWithPoolingMetadata.new( + model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), input_positions=torch.ones(10), pooling_metadata=pooling_metadata, @@ -125,8 +125,8 @@ def test_embedding_model_runner_input(): # Test round trip serialization. tensor_dict = model_input.as_broadcastable_tensor_dict() attn_backend = MockAttentionBackend() - received_model_input = ModelInputForGPUWithPoolingMetadata.new( - attn_backend=attn_backend, **tensor_dict) + received_model_input = ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend) # Check that received copy has correct values. assert isinstance(received_model_input, ModelInputForGPUWithPoolingMetadata) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index f09b24f2a0304..5114bfa6e1589 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -25,8 +25,8 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]: return IpexAttnBackendImpl @staticmethod - def make_metadata(*args, **kwargs) -> "IpexAttnMetadata": - return IpexAttnMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b203c5ec54c92..62b4a144fc443 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -16,8 +16,8 @@ def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: return PallasAttentionBackendImpl @staticmethod - def make_metadata(*args, **kwargs) -> "PallasMetadata": - return PallasMetadata(*args, **kwargs) + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata @staticmethod def get_kv_cache_shape( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index a53b1a82041ce..e3464c0d3900c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,6 +1,6 @@ from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch from torch import nn @@ -18,8 +18,9 @@ from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, _filter_valid_kwargs, - _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -53,17 +54,16 @@ def as_broadcastable_tensor_dict( return tensor_dict @classmethod - def new(cls, - attn_backend: Optional["AttentionBackend"] = None, - selected_token_indices: Optional[torch.Tensor] = None, - **kwargs) -> "CPUModelInput": + def from_broadcasted_tensor_dict( + cls: Type["CPUModelInput"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> "CPUModelInput": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) if attn_backend is not None: - kwargs = 
_init_attn_metadata_from_kwargs(attn_backend, **kwargs) - if selected_token_indices is not None: - kwargs = _init_sampling_metadata_from_kwargs( - selected_token_indices, **kwargs) - kwargs = _filter_valid_kwargs(cls, kwargs) - return cls(**kwargs) + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) class CPUModelRunner(ModelRunnerBase[CPUModelInput]): @@ -315,12 +315,14 @@ def _prepare_decode( attn_metadata, ) - def make_model_input(self, - make_attn_metadata: bool = False, - **kwargs) -> CPUModelInput: - if make_attn_metadata: - kwargs["attn_backend"] = self.attn_backend - return CPUModelInput.new(**kwargs, ) + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> CPUModelInput: + return CPUModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) def prepare_model_input( self, @@ -348,7 +350,7 @@ def prepare_model_input( seq_lens, self.device, pin_memory=False) - return CPUModelInput.new( + return CPUModelInput( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 4720259c55347..bf7dd9158671c 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch @@ -50,6 +50,8 @@ def __init__( is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) + self._model_input_cls : Type[TModelInputForGPU] = ModelInputForGPUWithPoolingMetadata + @torch.inference_mode() def execute_model( self, @@ -94,12 +96,14 @@ def execute_model( return self.model.pooler(hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) - def make_model_input(self, - make_attn_metadata: bool = False, - **kwargs) -> ModelInputForGPUWithPoolingMetadata: - if make_attn_metadata: - kwargs["attn_backend"] = self.attn_backend - return ModelInputForGPUWithPoolingMetadata.new(**kwargs, ) + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForGPUWithPoolingMetadata: + return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) def prepare_model_input( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index af79e0e4deaf3..b483ae8fd5512 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,9 +1,9 @@ +import dataclasses import gc import time import warnings from collections import defaultdict -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union) import numpy as np @@ -30,8 +30,9 @@ from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, _filter_valid_kwargs, - _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -51,7 +52,7 @@ TModelInputForGPU = TypeVar('TModelInputForGPU', 
bound="ModelInputForGPU") -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ModelInputForGPU(ModelRunnerInputBase): """ This base class contains metadata needed for the base model forward pass @@ -68,8 +69,7 @@ class ModelInputForGPU(ModelRunnerInputBase): attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, @@ -81,16 +81,18 @@ def as_broadcastable_tensor_dict( return tensor_dict @classmethod - def new(cls: Type[TModelInputForGPU], - attn_backend: Optional["AttentionBackend"] = None, - **kwargs) -> TModelInputForGPU: + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForGPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForGPU: if attn_backend is not None: - kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) - kwargs = _filter_valid_kwargs(cls, kwargs) - return cls(**kwargs) + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): """ Used by the ModelRunner. @@ -100,8 +102,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): # used by the driver worker. is_prompt: Optional[bool] = None - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, @@ -115,17 +116,16 @@ def as_broadcastable_tensor_dict( return tensor_dict @classmethod - def new(cls, - attn_backend: Optional["AttentionBackend"] = None, - selected_token_indices: Optional[torch.Tensor] = None, - **kwargs) -> "ModelInputForGPUWithSamplingMetadata": - if selected_token_indices is not None: - kwargs = _init_sampling_metadata_from_kwargs( - selected_token_indices, **kwargs) + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForGPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) if attn_backend is not None: - kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) - kwargs = _filter_valid_kwargs(cls, kwargs) - return cls(**kwargs) + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): @@ -206,6 +206,8 @@ def __init__( # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self._model_input_cls : Type[TModelInputForGPU] = ModelInputForGPUWithSamplingMetadata + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -357,7 +359,7 @@ def _prepare_model_input_tensors( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return self.make_model_input() + return self._model_input_cls() if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window + self.block_size - @@ -707,7 +709,7 @@ def _prepare_model_input_tensors( for k, v in multi_modal_kwargs_list.items() } - return self.make_model_input( + return self._model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, attn_metadata=attn_metadata, @@ -929,12 +931,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): GPU model runner with sampling step. """ - def make_model_input(self, - make_attn_metadata: bool = False, - **kwargs) -> ModelInputForGPUWithSamplingMetadata: - if make_attn_metadata: - kwargs["attn_backend"] = self.attn_backend - return ModelInputForGPUWithSamplingMetadata.new(**kwargs, ) + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForGPUWithSamplingMetadata: + return ( + ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) def prepare_model_input( self, @@ -960,10 +965,11 @@ def prepare_model_input( model_input.query_lens, self.device, self.pin_memory) - return model_input.replace( - sampling_metadata=sampling_metadata, - is_prompt=seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 4d9e18cfe0135..921404513d79c 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -1,7 +1,7 @@ import dataclasses from abc import ABC, abstractmethod from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, - TypeVar, Union) + TypeVar) import torch @@ -16,7 +16,7 @@ def _add_attn_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], + tensor_dict: Dict[str, Any], attn_metadata: Optional["AttentionMetadata"]) -> None: """ Helper method to update tensor_dict with broadcastable @@ -26,8 +26,10 @@ def _add_attn_metadata_broadcastable_dict( tensor_dict.update(attn_metadata.asdict_zerocopy()) -def _init_attn_metadata_from_kwargs(attn_backend: "AttentionBackend", - **kwargs) -> Dict[str, Any]: +def _init_attn_metadata_from_tensor_dict( + attn_backend: "AttentionBackend", + tensor_dict: Dict[str, Any], +) -> Dict[str, Any]: """ Helper method to initialize AttentionMetadata based on an AttentionBackend and broadcastable AttentionMetadata fields. @@ -35,38 +37,38 @@ def _init_attn_metadata_from_kwargs(attn_backend: "AttentionBackend", # Extract the fields used to create AttentionMetadata. 
valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = kwargs.pop(field.name, None) + val = tensor_dict.pop(field.name, None) if val is not None: valid_attn_kwargs[field.name] = val attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) - kwargs["attn_metadata"] = attn_metadata - return kwargs + tensor_dict["attn_metadata"] = attn_metadata + return tensor_dict -def _init_sampling_metadata_from_kwargs( # type: ignore - selected_token_indices: torch.Tensor = None, - **kwargs) -> Dict[str, Any]: +def _init_sampling_metadata_from_tensor_dict( # type: ignore + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: """ Helper method to initialize SamplingMetadata based on broadcastable SamplingMetadata fields. """ from vllm.model_executor import SamplingMetadata + selected_token_indices = tensor_dict.pop("selected_token_indices", None) # An empty SamplingMetadata to signal that the worker should skip # sampling. - sampling_metadata = SamplingMetadata( - seq_groups=None, - selected_token_indices=selected_token_indices, - categorized_sample_indices=None, - num_prompts=0, - ) - kwargs["sampling_metadata"] = sampling_metadata - return kwargs + if selected_token_indices is not None: + tensor_dict["sampling_metadata"] = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + return tensor_dict def _add_sampling_metadata_broadcastable_dict( - tensor_dict: Dict[str, Union[int, torch.Tensor]], + tensor_dict: Dict[str, Any], sampling_metadata: Optional["SamplingMetadata"]) -> None: """ Helper method to update tensor_dict with broadcastable @@ -77,19 +79,6 @@ def _add_sampling_metadata_broadcastable_dict( sampling_metadata.selected_token_indices) -def _filter_valid_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: - """ - Helper method to filter the given kwargs to kwargs that - are valid for the given dataclass `cls`. - """ - init_kwargs = {} - for field in dataclasses.fields(cls): - val = kwargs.get(field.name, None) - if val is not None: - init_kwargs[field.name] = val - return init_kwargs - - @dataclasses.dataclass(frozen=True) class ModelRunnerInputBase(ABC): """Local inputs to each worker's model runner. May contain @@ -102,29 +91,32 @@ class ModelRunnerInputBase(ABC): serialize/deserialize a ModelInput for broadcast between workers. """ - @classmethod - @abstractmethod - def new(cls: Type[T], **kwargs) -> T: - """ - Create a new instance of this class. Populate the new instance with - the given kwargs. - """ - raise NotImplementedError - def replace(self: T, **kwargs) -> T: """ Replace current fields with fields in kwargs. """ return dataclasses.replace(self, **kwargs) - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some custom deserialization. """ raise NotImplementedError + @classmethod + @abstractmethod + def from_broadcasted_tensor_dict( + cls: Type[T], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> T: + """ + Pop fields from the given tensor_dict and populate a new instance of + ModelRunnerInputBase. 
+ """ + raise NotImplementedError + class ModelRunnerBase(ABC, Generic[T]): """ @@ -137,13 +129,13 @@ class ModelRunnerBase(ABC, Generic[T]): """ @abstractmethod - def make_model_input(self, - make_attn_metadata: bool = False, - **model_input_fields) -> T: + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> T: """ - Make an instance of a ModelRunnerInputBase from the given fields. If - make_attn_metadata=True, then AttentionMetadata will be created from - fields extracted from model_input_fields. + Make an instance of a ModelRunnerInputBase from the broadcasted tensor + dict. """ raise NotImplementedError diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a924356f60580..fec2c97e73889 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -11,9 +11,10 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad -from vllm.worker.model_runner_base import (ModelRunnerBase, - ModelRunnerInputBase, - _filter_valid_kwargs) +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend logger = init_logger(__name__) @@ -33,9 +34,13 @@ def as_broadcastable_tensor_dict( raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") @classmethod - def new(cls, **kwargs) -> "ModelInputForNeuron": - kwargs = _filter_valid_kwargs(cls, kwargs) - return cls(**kwargs) + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForNeuron": + assert attn_backend is None + return cls.from_broadcasted_tensor_dict(tensor_dict) class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): @@ -163,10 +168,9 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - def make_model_input(self, - make_attn_metadata: bool = False, - **kwargs) -> ModelInputForNeuron: - return ModelInputForNeuron.new(**kwargs) + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron: + return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) def prepare_model_input( self, @@ -193,10 +197,10 @@ def prepare_model_input( self.device, self.pin_memory) - return ModelInputForNeuron.new(input_tokens=input_tokens, - input_positions=input_positions, - input_block_ids=input_block_ids, - sampling_metadata=sampling_metadata) + return ModelInputForNeuron(input_tokens=input_tokens, + input_positions=input_positions, + input_block_ids=input_block_ids, + sampling_metadata=sampling_metadata) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index f18d4b093cd4a..266155870d935 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -2,7 +2,7 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -126,28 +126,32 @@ class WorkerInput: blocks_to_copy: 
Optional[torch.Tensor] = None @classmethod - def new(cls: Type["WorkerInput"], **kwargs) -> "WorkerInput": + def from_broadcasted_tensor_dict( + cls: Type["WorkerInput"], + tensor_dict: Dict[str, Any], + ) -> "WorkerInput": """ - Create a new instance of this class. Populate the new instance with - fields from the given kwargs. + Pop fields from the given tensor_dict and populate a new instance of + WorkerInput. """ - init_kwargs = {} - for field in dataclasses.fields(cls): - val = kwargs.get(field.name, None) - if val is not None: - init_kwargs[field.name] = val - return cls(**init_kwargs) + return cls( + num_seq_groups=tensor_dict.pop("num_seq_groups"), + blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + ) def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: """ Extract broadcastable fields. """ - tensor_dict: Dict[str, Union[int, torch.Tensor]] = {} - for field in dataclasses.fields(self): - val = getattr(self, field.name, None) - if val is not None: - tensor_dict[field.name] = val + tensor_dict = { + "num_seq_groups": self.num_seq_groups, + "blocks_to_swap_in": self.blocks_to_swap_in, + "blocks_to_swap_out": self.blocks_to_swap_out, + "blocks_to_copy": self.blocks_to_copy, + } return tensor_dict @@ -255,9 +259,11 @@ def execute_model( if not broadcast_data: return None - worker_input = WorkerInput.new(**broadcast_data) - model_input = self.model_runner.make_model_input( - make_attn_metadata=True, **broadcast_data) + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + model_input = ( + self.model_runner. + make_model_input_from_broadcasted_tensor_dict(broadcast_data)) self.execute_worker(worker_input) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 98f30bc73915c..d9124a788a69d 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -18,8 +18,9 @@ from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, _filter_valid_kwargs, - _init_attn_metadata_from_kwargs, _init_sampling_metadata_from_kwargs) + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -56,17 +57,16 @@ def as_broadcastable_tensor_dict( return tensor_dict @classmethod - def new(cls, - attn_backend: Optional["AttentionBackend"] = None, - selected_token_indices: Optional[torch.Tensor] = None, - **kwargs) -> "ModelInputForXPU": + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForXPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForXPU": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) if attn_backend is not None: - kwargs = _init_attn_metadata_from_kwargs(attn_backend, **kwargs) - if selected_token_indices is not None: - kwargs = _init_sampling_metadata_from_kwargs( - selected_token_indices, **kwargs) - kwargs = _filter_valid_kwargs(cls, kwargs) - return cls(**kwargs) + tensor_dict = 
_init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): @@ -180,15 +180,12 @@ def profile_run(self) -> None: torch.xpu.synchronize() return - def make_model_input(self, - make_attn_metadata: bool = False, - **kwargs) -> ModelInputForXPU: - if make_attn_metadata: - kwargs["attn_backend"] = self.attn_backend - return ModelInputForXPU.new( + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU: + return (ModelInputForXPU.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=self.attn_backend, - **kwargs, - ) + )) def prepare_model_input( self, @@ -240,11 +237,11 @@ def prepare_model_input( num_prompts=0, ) - return self.make_model_input(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - multi_modal_input=multi_modal_input) + return ModelInputForXPU(input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + sampling_metadata=sampling_metadata, + multi_modal_input=multi_modal_input) def _prepare_decode( self, From ebae970c2f8986f895e5d0d57267c1e600f11611 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 24 Jun 2024 18:35:10 -0700 Subject: [PATCH 52/55] lint Signed-off-by: Stephanie Wang --- tests/worker/test_model_input.py | 10 ++++++---- vllm/worker/embedding_model_runner.py | 6 +++--- vllm/worker/model_runner.py | 5 +++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index e8fcb1aab5528..ae818ee360f19 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -75,8 +75,9 @@ def test_model_runner_input(): # Test round trip serialization. tensor_dict = model_input.as_broadcastable_tensor_dict() attn_backend = MockAttentionBackend() - received_model_input = ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend) + received_model_input = ( + ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) # Check that received copy has correct values. assert isinstance(received_model_input, ModelInputForGPUWithSamplingMetadata) @@ -125,8 +126,9 @@ def test_embedding_model_runner_input(): # Test round trip serialization. tensor_dict = model_input.as_broadcastable_tensor_dict() attn_backend = MockAttentionBackend() - received_model_input = ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=attn_backend) + received_model_input = ( + ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) # Check that received copy has correct values. 
assert isinstance(received_model_input, ModelInputForGPUWithPoolingMetadata) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index bf7dd9158671c..b9897ab2ace90 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -25,6 +25,8 @@ class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): class EmbeddingModelRunner( GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): + _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( + ModelInputForGPUWithPoolingMetadata) def __init__( self, @@ -50,8 +52,6 @@ def __init__( is_driver_worker=is_driver_worker, vision_language_config=vision_language_config) - self._model_input_cls : Type[TModelInputForGPU] = ModelInputForGPUWithPoolingMetadata - @torch.inference_mode() def execute_model( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b483ae8fd5512..9fdb2ea5dd4e4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -132,6 +132,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ Helper class for shared methods between GPU model runners. """ + _model_input_cls: Type[TModelInputForGPU] def __init__( self, @@ -206,8 +207,6 @@ def __init__( # Set after load_model. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - self._model_input_cls : Type[TModelInputForGPU] = ModelInputForGPUWithSamplingMetadata - def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -930,6 +929,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): """ GPU model runner with sampling step. """ + _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( + ModelInputForGPUWithSamplingMetadata) def make_model_input_from_broadcasted_tensor_dict( self, From 5763621d1db06523d0d3211b8a26e7462dfc4856 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 24 Jun 2024 22:54:37 -0700 Subject: [PATCH 53/55] revert Signed-off-by: Stephanie Wang --- vllm/worker/cpu_worker.py | 19 +++++------------- vllm/worker/neuron_worker.py | 16 ++++----------- vllm/worker/worker.py | 39 ++++++++++++++---------------------- vllm/worker/worker_base.py | 23 ++------------------- vllm/worker/xpu_worker.py | 8 ++++---- 5 files changed, 30 insertions(+), 75 deletions(-) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 882c31ba3b04d..30ee262c7a8b3 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -15,7 +15,6 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner -from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerInput) @@ -147,15 +146,15 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.vision_language_config = vision_language_config - self._is_driver_worker = is_driver_worker - if self._is_driver_worker: + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self._model_runner = CPUModelRunner( + self.model_runner: CPUModelRunner = CPUModelRunner( model_config, parallel_config, scheduler_config, @@ -177,7 +176,7 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - self._model_runner.load_model() + self.model_runner.load_model() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of blocks available for the KV cache. @@ -248,7 +247,7 @@ def _init_cache_engine(self) -> None: self.parallel_config, self.device_config) self.cpu_cache = self.cache_engine.cpu_cache - self._model_runner.block_size = self.cache_engine.block_size + self.model_runner.block_size = self.cache_engine.block_size assert self.cpu_cache is not None @@ -256,18 +255,10 @@ def _init_cache_engine(self) -> None: for layer_cache in self.cpu_cache: layer_cache.fill_(0) - @property - def is_driver_worker(self) -> bool: - return self._is_driver_worker - @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 - @property - def model_runner(self) -> ModelRunnerBase: - return self._model_runner - @property def kv_cache(self) -> Optional[List[torch.Tensor]]: return self.cpu_cache diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 2ae611f99a85b..307c107ddef71 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -8,7 +8,6 @@ ParallelConfig, SchedulerConfig) from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest -from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerInput) @@ -36,15 +35,16 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self._model_runner = NeuronModelRunner(model_config, parallel_config, - scheduler_config, device_config) + self.model_runner: NeuronModelRunner = NeuronModelRunner( + model_config, parallel_config, scheduler_config, device_config) + self.is_driver_worker = True def init_device(self) -> None: # Set random seed. set_random_seed(self.model_config.seed) def load_model(self): - self._model_runner.load_model() + self.model_runner.load_model() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. 
@@ -75,18 +75,10 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - @property - def is_driver_worker(self) -> bool: - return True - @property def do_metadata_broadcast(self) -> bool: return False - @property - def model_runner(self) -> ModelRunnerBase: - return self._model_runner - @property def kv_cache(self) -> Optional[List[torch.Tensor]]: return None diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f0107c7b5c791..e1944a4f1d636 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -19,7 +19,6 @@ from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput @@ -57,8 +56,8 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.load_config = load_config - self._is_driver_worker = is_driver_worker - if self._is_driver_worker: + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." if self.model_config.trust_remote_code: @@ -81,7 +80,7 @@ def __init__( ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if self.model_config.embedding_mode: ModelRunnerClass = EmbeddingModelRunner - self._model_runner: GPUModelRunnerBase = ModelRunnerClass( + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( model_config, parallel_config, scheduler_config, @@ -129,7 +128,7 @@ def init_device(self) -> None: set_random_seed(self.model_config.seed) def load_model(self): - self._model_runner.load_model() + self.model_runner.load_model() def save_sharded_state( self, @@ -137,7 +136,7 @@ def save_sharded_state( pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - self._model_runner.save_sharded_state( + self.model_runner.save_sharded_state( path, pattern=pattern, max_size=max_size, @@ -147,7 +146,7 @@ def save_tensorized_model( self, tensorizer_config: TensorizerConfig, ) -> None: - self._model_runner.save_tensorized_model( + self.model_runner.save_tensorized_model( tensorizer_config=tensorizer_config, ) @torch.inference_mode() @@ -169,7 +168,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self._model_runner.profile_run() + self.model_runner.profile_run() # Calculate the number of blocks that can be allocated with the # profiled peak memory. @@ -190,8 +189,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - if self._model_runner.lora_manager: - self._model_runner.remove_all_loras() + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks @@ -221,23 +220,15 @@ def _init_cache_engine(self): def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: - self._model_runner.capture_model(self.gpu_cache) + self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
set_random_seed(self.model_config.seed) - @property - def is_driver_worker(self) -> bool: - return self._is_driver_worker - @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 - @property - def model_runner(self) -> ModelRunnerBase: - return self._model_runner - @property def kv_cache(self) -> Optional[List[torch.Tensor]]: return self.gpu_cache @@ -282,16 +273,16 @@ def execute_worker(self, worker_input: WorkerInput) -> None: self.cache_engine.copy(worker_input.blocks_to_copy) def add_lora(self, lora_request: LoRARequest) -> bool: - return self._model_runner.add_lora(lora_request) + return self.model_runner.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - return self._model_runner.remove_lora(lora_id) + return self.model_runner.remove_lora(lora_id) def pin_lora(self, lora_id: int) -> bool: - return self._model_runner.pin_lora(lora_id) + return self.model_runner.pin_lora(lora_id) def list_loras(self) -> Set[int]: - return self._model_runner.list_loras() + return self.model_runner.list_loras() @property def max_model_len(self) -> int: @@ -299,7 +290,7 @@ def max_model_len(self) -> int: @property def vocab_size(self) -> int: - return self._model_runner.vocab_size + return self.model_runner.vocab_size def get_cache_block_size_bytes(self) -> int: """Get the size of the KV cache block size in bytes. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 266155870d935..1f1ce4e7b114b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -165,17 +165,8 @@ class LocalOrDistributedWorkerBase(WorkerBase): If custom control plane logic is needed to transfer metadata, or if the model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. """ - - @property - @abstractmethod - def is_driver_worker(self) -> bool: - """ - Used by the default `execute_model` to check whether this is the driver - worker. The driver worker is responsible for broadcasting request - inputs to other workers in its TP group. If WorkerBase subclass only - supports single-worker execution, then this method should return True. - """ - raise NotImplementedError + is_driver_worker: bool + model_runner: ModelRunnerBase @property @abstractmethod @@ -188,16 +179,6 @@ def do_metadata_broadcast(self) -> bool: """ raise NotImplementedError - @property - @abstractmethod - def model_runner(self) -> ModelRunnerBase: - """ - Get the worker's model runner. Used by the default `execute_model`. If - the worker's model runner does not follow the ModelRunnerBase - interface, then inherit from WorkerBase instead. - """ - raise NotImplementedError - @property @abstractmethod def kv_cache(self) -> Optional[List[torch.Tensor]]: diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 18e8ab72fad48..773ee9f8159e1 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -62,8 +62,8 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config - self._is_driver_worker = is_driver_worker - if self._is_driver_worker: + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
self.vision_language_config = vision_language_config @@ -71,7 +71,7 @@ def __init__( assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") - self._model_runner = XPUModelRunner( # type: ignore + self.model_runner = XPUModelRunner( # type: ignore model_config, parallel_config, scheduler_config, @@ -123,7 +123,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self._model_runner.profile_run() + self.model_runner.profile_run() # Calculate the number of blocks that can be allocated with the # profiled peak memory. From d16d5fec24beffcb8b8c0447f34b18da053d4803 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 25 Jun 2024 10:28:18 -0700 Subject: [PATCH 54/55] rm Signed-off-by: Stephanie Wang --- vllm/worker/model_runner_base.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 921404513d79c..9b1706035a33e 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -91,12 +91,6 @@ class ModelRunnerInputBase(ABC): serialize/deserialize a ModelInput for broadcast between workers. """ - def replace(self: T, **kwargs) -> T: - """ - Replace current fields with fields in kwargs. - """ - return dataclasses.replace(self, **kwargs) - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some From f6c62349c85e64a883c34ad6292dcb6c33c6d65d Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Tue, 25 Jun 2024 10:34:39 -0700 Subject: [PATCH 55/55] fix Signed-off-by: Stephanie Wang --- vllm/worker/embedding_model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index b9897ab2ace90..3c8dfa2c6d8df 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +import dataclasses from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -15,7 +15,7 @@ logger = init_logger(__name__) -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): """ Used by the EmbeddingModelRunner. @@ -117,7 +117,8 @@ def prepare_model_input( pooling_metadata = self._prepare_pooling(seq_group_metadata_list, model_input.seq_lens) - return model_input.replace(pooling_metadata=pooling_metadata) + return dataclasses.replace(model_input, + pooling_metadata=pooling_metadata) def _prepare_pooling( self,