From bc34937d68e9715d8416457539fb528301cf6269 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 25 Jun 2024 15:25:52 -0700
Subject: [PATCH] [Hardware][TPU] Refactor TPU backend (#5831)

---
 vllm/executor/tpu_executor.py   | 58 +++++++++++++++++++++------------
 vllm/worker/tpu_model_runner.py |  4 +++
 vllm/worker/tpu_worker.py       | 35 +++++++++++++-------
 3 files changed, 65 insertions(+), 32 deletions(-)

diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py
index 5ed00e137410..7fe5349c987a 100644
--- a/vllm/executor/tpu_executor.py
+++ b/vllm/executor/tpu_executor.py
@@ -1,4 +1,4 @@
-from typing import List, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import torch
 
@@ -26,29 +26,45 @@ def _init_executor(self) -> None:
             self.model_config.dtype = torch.bfloat16
 
         # Instantiate the worker and load the model to the device.
-        self._init_worker()
-
-    def _init_worker(self):
-        from vllm.worker.tpu_worker import TPUWorker
+        self.driver_worker = self._create_worker()
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
 
-        assert self.parallel_config.world_size == 1, (
-            "TPUExecutor currently only supports a single TPU chip.")
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = TPUWorker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            self.cache_config,
-            self.load_config,
-            self.vision_language_config,
-            local_rank=0,
-            rank=0,
+    def _get_worker_kwargs(
+        self,
+        local_rank: int = 0,
+        rank: int = 0,
+        distributed_init_method: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            load_config=self.load_config,
+            local_rank=local_rank,
+            rank=rank,
             distributed_init_method=distributed_init_method,
+            vision_language_config=self.vision_language_config,
+            is_driver_worker=rank == 0,
         )
-        self.driver_worker.init_device()
-        self.driver_worker.load_model()
+
+    def _create_worker(
+        self,
+        local_rank: int = 0,
+        rank: int = 0,
+        distributed_init_method: Optional[str] = None,
+    ):
+        from vllm.worker.tpu_worker import TPUWorker
+
+        worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
+                                                     distributed_init_method))
+        return worker
 
     def initialize_cache(
         self,
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 5003d3b0ca44..2d8fffe5ac16 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -33,6 +33,7 @@ def __init__(
         cache_config: CacheConfig,
         load_config: LoadConfig,
         vision_language_config: Optional[VisionLanguageConfig] = None,
+        is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -41,6 +42,7 @@ def __init__(
         self.cache_config = cache_config
         self.load_config = load_config
         self.vision_language_config = vision_language_config
+        self.is_driver_worker = is_driver_worker
 
         self.block_size = self.cache_config.block_size
         self.max_num_blocks_per_seq = (self.model_config.max_model_len //
@@ -373,6 +375,8 @@ def _execute_model(
         inputs = self.prepare_inputs(seq_group_metadata_list)
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
                                     *inputs[2:])
+        if not self.is_driver_worker:
+            return []
         next_token_ids = next_token_ids.cpu().tolist()
 
         i = 0
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 04576015dadb..828bb89d70ba 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -34,6 +34,7 @@ def __init__(
         local_rank: int,
         rank: int,
         distributed_init_method: str,
+        is_driver_worker: bool,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -45,6 +46,7 @@ def __init__(
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
 
         assert self.device_config.device_type == "tpu"
         if self.cache_config.cache_dtype == "auto":
@@ -53,10 +55,14 @@ def __init__(
             self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 self.cache_config.cache_dtype]
 
-        self.model_runner = TPUModelRunner(model_config, parallel_config,
-                                           scheduler_config, device_config,
-                                           cache_config, load_config,
-                                           vision_language_config)
+        self.model_runner = TPUModelRunner(model_config,
+                                           parallel_config,
+                                           scheduler_config,
+                                           device_config,
+                                           cache_config,
+                                           load_config,
+                                           vision_language_config,
+                                           is_driver_worker=is_driver_worker)
 
     def init_device(self) -> None:
         os.environ["PJRT_DEVICE"] = "TPU"
@@ -175,16 +181,13 @@ def get_cache_block_size_bytes(self) -> int:
 
     def execute_model(
         self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
+        execute_model_req: Optional[ExecuteModelRequest] = None,
     ) -> List[SamplerOutput]:
-        if execute_model_req is None:
-            return []
-
-        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-        num_seq_groups = len(seq_group_metadata_list)
-        if num_seq_groups == 0:
+        if not self.is_driver_worker:
+            self._execute_model_non_driver()
             return []
 
+        assert execute_model_req is not None
         # Currently, TPUWorker does not support swapping.
         # TODO(woosuk): Support block copying.
         assert len(execute_model_req.blocks_to_swap_in) == 0, (
@@ -193,6 +196,16 @@ def execute_model(
             "Swapping is not supported for the TPU backend.")
         assert len(execute_model_req.blocks_to_copy) == 0
 
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        assert len(seq_group_metadata_list) > 0
         output = self.model_runner.execute_model(seq_group_metadata_list,
                                                  self.tpu_cache)
         return [output]
+
+    def start_worker_execution_loop(self) -> None:
+        while self._execute_model_non_driver():
+            pass
+
+    def _execute_model_non_driver(self) -> bool:
+        self.model_runner.execute_model(None, self.tpu_cache)
+        return True
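
Below is a minimal, hypothetical sketch (not part of the patch; DummyRunner and DummyWorker are illustrative stand-ins, not vLLM classes) of the driver/non-driver split this refactor introduces: the rank-0 driver worker receives the request and returns sampler output, while non-driver ranks run a stripped-down step with no inputs and no outputs, optionally spinning in start_worker_execution_loop().

# Hypothetical, simplified sketch -- mimics only the control flow added above.
from typing import Any, List, Optional


class DummyRunner:
    def execute_model(self, seq_group_metadata_list: Optional[Any],
                      kv_caches: Any) -> List[int]:
        # Stand-in for the model runner: only the driver's call site
        # consumes the returned token ids; non-driver calls discard them.
        return [42]


class DummyWorker:
    def __init__(self, rank: int) -> None:
        # As in the patch, rank 0 acts as the driver worker.
        self.is_driver_worker = rank == 0
        self.model_runner = DummyRunner()
        self.kv_caches: Any = None

    def execute_model(self, req: Optional[Any] = None) -> List[Any]:
        if not self.is_driver_worker:
            # Non-driver ranks run one step and report no output.
            self._execute_model_non_driver()
            return []
        assert req is not None
        return [self.model_runner.execute_model(req, self.kv_caches)]

    def start_worker_execution_loop(self) -> None:
        # Loops as long as _execute_model_non_driver reports more work
        # (always True here, matching the patch's placeholder behavior).
        while self._execute_model_non_driver():
            pass

    def _execute_model_non_driver(self) -> bool:
        self.model_runner.execute_model(None, self.kv_caches)
        return True


if __name__ == "__main__":
    driver = DummyWorker(rank=0)
    print(driver.execute_model(req={"seq_groups": []}))  # -> [[42]]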