diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 58519f978d34..260cdb672cef 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -4,28 +4,44 @@ openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - -# Completion API -stream = False -completion = client.completions.create( - model=model, - prompt="A robot may not injure a human being", - echo=False, - n=2, - stream=stream, - logprobs=3) - -print("Completion results:") -if stream: - for c in completion: - print(c) -else: - print(completion) + +def get_prompts(n=1): + ps = ['A robot may not injure a human being'] + for i in range(1, n): + ps.append(' '.join(["hi!"] * i)) + + return ps + + +def main(): + client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, + ) + + models = client.models.list() + model = models.data[0].id + + prompts = get_prompts(50) + #print (f"{prompts}") + print(f"# PROMPTS : {len(prompts)}") + + # Completion API + stream = False + completion = client.completions.create(model=model, + prompt=prompts, + echo=False, + n=1, + stream=stream) + + print("Completion results:") + if stream: + for c in completion: + print(c) + else: + print(completion) + + +if __name__ == '__main__': + main() diff --git a/tests/basic_correctness/test_multi_step_chunked_prefill.py b/tests/basic_correctness/test_multi_step_chunked_prefill.py new file mode 100644 index 000000000000..d60fc1ecaed9 --- /dev/null +++ b/tests/basic_correctness/test_multi_step_chunked_prefill.py @@ -0,0 +1,75 @@ +# Test the AsyncLLMEngine with multi-step-decoding and chunked prefill + +from typing import List + +import pytest + +from ..utils import RemoteOpenAIServer + +MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", +] +NUM_SCHEDULER_STEPS = [8, 16] # Multi-step decoding steps +NUM_PROMPTS = [100] + +# TODO (varun) : Expand tests for multiple TP & PP +DEFAULT_SERVER_ARGS: List[str] = [ + "--disable-log-requests", + "--use-v2-block-manager", + "--worker-use-ray", + "--gpu-memory-utilization", + "0.90", + "--swap-space", + "16", + "--tensor-parallel-size", + "1", + "--pipeline-parallel-size", + "1", +] + + +async def completions_with_server_args(prompts: List[str], model_name: str, + server_cli_args: List[str]): + + outputs = None + with RemoteOpenAIServer(model_name, server_cli_args) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=150) + assert outputs is not None + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.asyncio +async def test_mutli_step_with_chunked_prefill(example_prompts, model: str, + num_scheduler_steps: int, + num_prompts: int): + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + server_args = DEFAULT_SERVER_ARGS + \ + ["--num-scheduler-steps", f"{num_scheduler_steps}"] + + ref_completions = await completions_with_server_args( + prompts, model, server_args) + test_completions 
= await completions_with_server_args( + prompts, model, server_args + ["--enable-chunked-prefill"]) + + def get_text_generations(completions): + return [x.text for x in completions.choices] + + ref_generations = get_text_generations(ref_completions) + test_generations = get_text_generations(test_completions) + assert ref_generations == test_generations diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 2126fafb2323..0244919152ca 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -10,6 +10,8 @@ from vllm.worker.embedding_model_runner import ( ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from vllm.worker.multi_step_model_runner import ( + MutableModelInputForGPUWithMultiStepMetadata) class MockAttentionBackend(AttentionBackend): @@ -154,3 +156,82 @@ def test_embedding_model_runner_input(): None) == getattr(attn_metadata, field.name, None) # Pooling metadata is not broadcast. assert received_model_input.pooling_metadata is None + + +def test_multi_step_model_runner_input(): + sampling_metadata = SamplingMetadata( + ["seq_group"], + "selected_token_indices", + "categorized_sample_indices", + "num_prompts", + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + frozen_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata) + + model_input = MutableModelInputForGPUWithMultiStepMetadata( + frozen_model_input=frozen_model_input, + is_last_step=True, + is_first_multi_step=False, + current_step=4, + last_sampled_token_ids=torch.ones((10, 1)), + is_multi_step=True, + num_queries=8, + num_seqs=5, + outputs=[], + ) + + assert isinstance(model_input, + MutableModelInputForGPUWithMultiStepMetadata) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = (MutableModelInputForGPUWithMultiStepMetadata. + from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) + + receieved_frozen_input = received_model_input.frozen_model_input + + # Check that received copy has correct values. + assert isinstance(received_model_input, + MutableModelInputForGPUWithMultiStepMetadata) + assert receieved_frozen_input.input_tokens is not None + assert (receieved_frozen_input.input_tokens == + frozen_model_input.input_tokens).all() + assert receieved_frozen_input.input_positions is not None + assert (receieved_frozen_input.input_positions == + frozen_model_input.input_positions).all() + assert receieved_frozen_input.multi_modal_kwargs is None + assert (frozen_model_input.multi_modal_kwargs == + frozen_model_input.multi_modal_kwargs) + assert receieved_frozen_input.lora_requests is None + assert (receieved_frozen_input.lora_requests == + frozen_model_input.lora_requests) + assert receieved_frozen_input.lora_mapping is None + assert ( + receieved_frozen_input.lora_mapping == frozen_model_input.lora_mapping) + for field in dataclasses.fields(AttentionMetadata): + assert getattr(receieved_frozen_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # For sampling metadata, only selected_token_indices is copied. 
+ assert (receieved_frozen_input.sampling_metadata.selected_token_indices == + sampling_metadata.selected_token_indices) + assert receieved_frozen_input.sampling_metadata.seq_groups is None + + # check non frozen fields + assert received_model_input.is_last_step == model_input.is_last_step + assert (received_model_input.is_first_multi_step == + model_input.is_first_multi_step) + assert received_model_input.current_step == model_input.current_step + assert (received_model_input.last_sampled_token_ids == + model_input.last_sampled_token_ids).all() + assert received_model_input.is_multi_step == model_input.is_multi_step diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6ed75a6e2ea6..f986c3400e12 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -981,6 +981,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: [s.seq_group for s in swapped_in.prefill_seq_groups]) # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) + + if self.scheduler_config.is_multi_step: + # It maybe the case that prefills are scheduled along + # with decodes. In that case update the multi-step state + # of all the scheduled sequences to perform just a single + # decoding step. + has_prefills = len(prefills.seq_groups) + \ + len(running_scheduled.prefill_seq_groups) + \ + len(swapped_in.prefill_seq_groups) > 0 + if has_prefills: + for sg in running_scheduled.decode_seq_groups: + sg.seq_group.init_multi_step(1) + for sg in swapped_in.decode_seq_groups: + sg.seq_group.init_multi_step(1) + return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + running_scheduled.prefill_seq_groups + @@ -1187,7 +1202,8 @@ def _append_slots( the new source and destination block indices for the appended slots. """ - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + num_lookahead_slots = self._get_num_lookahead_slots(\ + is_prefill=seq_group.is_prefill()) seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d99387542da1..9b1180a1ee0b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -865,12 +865,11 @@ def create_engine_config(self, ) -> EngineConfig: ) if self.num_scheduler_steps > 1: - raise NotImplementedError("Multi-step is not yet supported.") if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill: - raise ValueError("Chunked prefill is not supported with " + if not self.use_v2_block_manager: + raise ValueError("BlockSpaceManagerV2 is required for " "multi-step (--num-scheduler-steps > 1)") # make sure num_lookahead_slots is set the higher value depending on diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a28b20fcbbcd..2a5f6f87dabf 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,9 +1,11 @@ import asyncio import time +from dataclasses import dataclass from functools import partial from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) +import torch from transformers import PreTrainedTokenizer from typing_extensions import assert_never @@ -27,7 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence 
import ExecuteModelRequest, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -249,9 +252,24 @@ def has_new_requests(self): return not self._new_requests.empty() +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + last_output: Optional[SamplerOutput] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + + class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pipeline_parallel_size = \ + self.parallel_config.pipeline_parallel_size + self.cached_scheduler_outputs = [SchedulerOutputState() + ] * pipeline_parallel_size + async def step_async( self, virtual_engine: int ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: @@ -264,13 +282,39 @@ async def step_async( and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + if not self._has_remaining_steps(seq_group_metadata_list): + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + virtual_engine].schedule() + + if self.scheduler_config.is_multi_step and \ + self._remaining_steps(seq_group_metadata_list) > 1: + # cache the scheduler outputs for the next iteration if we have + # one. + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - # Execute the model. finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, @@ -279,15 +323,35 @@ async def step_async( virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + # Execute the model. 
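+            # (Note: with multi-step scheduling enabled, each call below runs
+            # a single decode step; the scheduler above is skipped until the
+            # cached batch has no remaining steps.)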
output = await self.model_executor.execute_model_async( execute_model_req) + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, output) else: output = [] - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + else: + request_outputs = [] # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -297,6 +361,68 @@ async def step_async( return request_outputs + def _remaining_steps( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> int: + if not self.scheduler_config.is_multi_step: + return 0 + + if not seq_group_metadata_list: + return 0 + + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. + ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError(("All running sequence groups should " + "have the same remaining steps.")) + + return ref_remaining_steps + + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + return self._remaining_steps(seq_group_metadata_list) > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs) -> None: + self.cached_scheduler_outputs[ + virtual_engine].seq_group_metadata_list = seq_group_metadata_list + self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ + scheduler_outputs + self.cached_scheduler_outputs[virtual_engine].last_output = None + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + async def 
stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 57b9e2b33b98..38d8fd91b6d4 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -70,13 +70,19 @@ def _get_create_worker_kwargs( distributed_init_method: Optional[str] = None) -> Dict: worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if self.speculative_config is None: - worker_kwargs.update(worker_module_name="vllm.worker.worker", - worker_class_name="Worker") - else: + + if self.scheduler_config.is_multi_step: + worker_kwargs.update( + worker_module_name="vllm.worker.multi_step_worker", + worker_class_name="MultiStepWorker") + elif self.speculative_config: worker_kwargs.update( worker_module_name="vllm.spec_decode.spec_decode_worker", worker_class_name="create_spec_worker") + else: + worker_kwargs.update(worker_module_name="vllm.worker.worker", + worker_class_name="Worker") + return worker_kwargs def _create_worker(self, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index fa3646012dd6..c55aef84159e 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -87,6 +87,9 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]: if self.speculative_config is not None: worker_module_name = "vllm.spec_decode.spec_decode_worker" worker_class_name = "create_spec_worker" + elif self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_worker" + worker_class_name = "MultiStepWorker" else: worker_module_name = "vllm.worker.worker" worker_class_name = "Worker" diff --git a/vllm/sequence.py b/vllm/sequence.py index b83e345235cd..c7849507476e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -8,7 +8,6 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union, cast) -import numpy import torch from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs @@ -997,7 +996,7 @@ class SamplerOutput: # On-device tensor containing the sampled token ids. sampled_token_ids: Optional[torch.Tensor] = None - sampled_token_ids_numpy: Optional[numpy.ndarray] = None + sampled_token_ids_cpu: Optional[torch.Tensor] = None # Spec decode metrics populated by workers. 
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None @@ -1156,9 +1155,7 @@ def is_last_step(self) -> bool: # steps assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] - num_steps = first_seq_group.state.num_steps - current_step = first_seq_group.state.current_step - return num_steps - current_step == 1 + return first_seq_group.state.remaining_steps == 1 @property def current_step(self) -> int: diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 46ac16b504bf..90c39407d726 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -14,7 +14,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -T = TypeVar('T', bound="ModelRunnerInputBase") +T = TypeVar('T', bound="BroadcastableModelInput") def _add_attn_metadata_broadcastable_dict( @@ -81,18 +81,26 @@ def _add_sampling_metadata_broadcastable_dict( sampling_metadata.selected_token_indices) -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(ABC): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. +def _init_frozen_model_input_from_tensor_dict( + frozen_model_input_cls: Type["ModelRunnerInputBase"], + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: """ + Helper method to initialize a frozen ModelInput based on broadcastable + """ + valid_tensor_kwargs = {} + for field in dataclasses.fields(frozen_model_input_cls): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_tensor_kwargs[field.name] = val + + frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) + tensor_dict["frozen_model_input"] = frozen_model_input + return tensor_dict + +class BroadcastableModelInput(ABC): + + @abstractmethod def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some @@ -109,11 +117,25 @@ def from_broadcasted_tensor_dict( ) -> T: """ Pop fields from the given tensor_dict and populate a new instance of - ModelRunnerInputBase. + BroadcastableModelInput. """ raise NotImplementedError +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(BroadcastableModelInput): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. + + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + pass + + class ModelRunnerInputBuilderBase(ABC, Generic[T]): """A builder to create ModelRunnerInputBase objects. 
""" diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py new file mode 100644 index 000000000000..e8b46a03d842 --- /dev/null +++ b/vllm/worker/multi_step_model_runner.py @@ -0,0 +1,516 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata +except ModuleNotFoundError: + # vllm_flash_attn is not installed, use the identical ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) + +import torch + +from vllm import _custom_ops as ops +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceOutput) +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + _init_frozen_model_input_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +from ..model_executor.model_loader.tensorizer import TensorizerConfig + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +@dataclass +class ModelOutput: + """The output of a single model forward pass. + + The sampler_output_ready_event is set when the tensors in + sampler_output are ready (the model+sampler forward pass has + completed). We use the event to synchronize the GPU->CPU transfer, + which we want to only run when the data has been written to the + GPU tensors. Until the event is ready, the tensors in sampler_output + will have garbage data. + + There are two scenarios: + 1. The output tensors are ready and we can pythonize them immediately. + 2. The output tensors are not ready and we need to wait for the event to be + ready. + """ + sampler_output: SamplerOutput + sampler_output_ready_event: torch.cuda.Event + sampled_token_ids: Optional[torch.Tensor] = None + pythonized: bool = False + + def pythonize( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output. Blocking.""" + if not self.pythonized: + self._pythonize_sampler_output_wait_on_event( + input_metadata, copy_stream, pinned_sampled_token_buffer) + self.pythonized = True + + def maybe_pythonize( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output if ready, else return None. Non-blocking.""" + if not self.pythonized: + self.pythonized = self._pythonize_sampler_output_if_event_ready( + input_metadata, copy_stream, pinned_sampled_token_buffer) + + def _pythonize_sampler_output_wait_on_event( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """ + Block until the forward pass for the output is ready and pythonize the + output. 
+ """ + assert self.sampled_token_ids is not None + self.sampler_output_ready_event.synchronize() + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids) + + def _pythonize_sampler_output_if_event_ready( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> bool: + """ + Check if the forward pass for this output is finished and only pythonize + the output if it is. + """ + if self.sampler_output_ready_event.query(): + assert self.sampled_token_ids is not None + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids) + return True + return False + + +@dataclass(frozen=False) +class MutableModelInputForGPUWithMultiStepMetadata(BroadcastableModelInput): + # actual frozen model input dataclass passed to _base_model_runner + frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None + + # list of model outputs for each step, may not be all pythonized + outputs: List[ModelOutput] = field(default_factory=list) + + # used to pass sampled token ids from the last step to the current step for + # TP workers. Used to append to end of outputs and used by advance_step + last_sampled_token_ids: Optional[torch.Tensor] = None + current_step: int = 0 + is_multi_step: bool = True + is_last_step: bool = False + is_first_multi_step: bool = False + # ping-pong data structures for multi-step to wait on the previous step + step_cuda_events: List[torch.cuda.Event] = field( + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + num_seqs: int = -1 + num_queries: int = -1 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + assert self.frozen_model_input is not None + tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + new_tensor_dict = { + 'last_sampled_token_ids': self.last_sampled_token_ids, + 'current_step': self.current_step, + 'is_multi_step': self.is_multi_step, + 'is_last_step': self.is_last_step, + 'is_first_multi_step': self.is_first_multi_step, + 'num_seqs': self.num_seqs, + 'num_queries': self.num_queries, + } + tensor_dict.update(new_tensor_dict) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "MutableModelInputForGPUWithMultiStepMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + tensor_dict = _init_frozen_model_input_from_tensor_dict( + ModelInputForGPUWithSamplingMetadata, tensor_dict) + + return cls(**tensor_dict) + + def record_step_event(self, current_stream: torch.cuda.Stream): + # record the event for the current step so that the next step can sync + # on it. We modulo by 2 to keep the events in a circular buffer and + # support any attn backends that may be supported in the future. ie + # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. 
+ self.step_cuda_events[self.current_step & 1] = \ + torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step & 1].record(current_stream) + + def wait_previous_step(self): + # These cuda events are an explicit synchronization to ensure that + # advance_step() (for other attn backends that may be supported in the + # future) do not clobber any data structures that is also used by any + # enqueued forwards steps. For distributed case, only a single event is + # needed, but for single GPU case, since we can let the CPU run much + # further ahead, two events allow us to overlap the advance_step with + # the previous forward (ie using two DecodeWrappers for flashinfer + # backend) + self.step_cuda_events[(self.current_step + 1) & 1].wait() + + def add_sampler_output(self, + sampler_output: SamplerOutput, + sampled_token_ids: Optional[torch.Tensor] = None): + self.outputs.append( + ModelOutput(sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, + pythonized=False)) + + +# MutableModelInputForGPUWithMultiStepMetadata is not subclass of +# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step +# metadata +# mypy: disable-error-code=type-var +class MultiStepModelRunner( + GPUModelRunnerBase[MutableModelInputForGPUWithMultiStepMetadata]): + # mypy: enable-error-code=type-var + + def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) + + # uses the base model runner to execute the model and wraps it with + # multi-step logic + self._base_model_runner: GPUModelRunnerBase = base_model_runner + + self.is_multi_step = self.scheduler_config.is_multi_step + # used to copy tensors from GPU to CPU asynchronously + self._copy_stream = torch.cuda.Stream() + self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any] + ) -> MutableModelInputForGPUWithMultiStepMetadata: + model_input = (MutableModelInputForGPUWithMultiStepMetadata. 
+ from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> MutableModelInputForGPUWithMultiStepMetadata: + frozen_model_input = self._base_model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) + + model_input = MutableModelInputForGPUWithMultiStepMetadata( + frozen_model_input=frozen_model_input, + num_seqs=len(frozen_model_input.seq_lens), + num_queries=len(frozen_model_input.query_lens), + ) + return model_input + + @torch.inference_mode() + def execute_model( + self, + model_input: MutableModelInputForGPUWithMultiStepMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + """ + Execute the model for a single step and update multi-step + metadata + """ + assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # path for warm up runs + if not model_input.is_multi_step: + return self._base_model_runner.execute_model( + frozen_model_input, kv_caches, intermediate_tensors, num_steps) + + # make sure we skip the sampler on the lask rank and only pythonize + # if CPU is ahead. + if self.is_driver_worker and get_pp_group().is_last_rank: + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( + True) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from last iteration + # - also maybe pythonize if CPU is ahead of GPU + + current_stream = torch.cuda.current_stream() + if not model_input.is_first_multi_step: + # Explicitly block on the previous step's forward to make sure we + # don't clobber any GPU tensors still in use. + # This is not needed for flashattn backend, but for other attn + # backends such as flashinfer that performs extra CPU operations on + # input metadata we may need to synchronize any CPU operations that + # might clobber enqueued forwards. (prevents CPU from running too + # far ahead if needed) + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.outputs[-1].sampler_output) + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(current_stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. 
May be able to be combined with the step event + output_ready_event = torch.cuda.Event() + output_ready_event.record(current_stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_cpu = output[ + 0].sampled_token_ids.cpu() + model_input.outputs.append( + ModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False)) + # make sure we dont try to serialize any GPU tensors + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + # Pythonize the output if CPU is ahead and the previous step is + # ready. + for model_output in model_input.outputs: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = [] + for output in model_input.outputs: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + return outputs + + # should be [SamplerOutput] + return output + + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, + num_queries): + assert isinstance(attn_metadata, FlashAttentionMetadata) + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert attn_metadata.use_cuda_graph + + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + assert attn_metadata.num_decode_tokens == num_seqs + assert attn_metadata.slot_mapping.shape == (num_seqs, ) + + assert len(attn_metadata.seq_lens) == num_seqs + assert attn_metadata.seq_lens_tensor.shape == (num_seqs, ) + assert attn_metadata.max_query_len == 1 + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens) + + assert attn_metadata.query_start_loc.shape == (num_queries + 1, ) + assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, ) + + assert attn_metadata.context_lens_tensor.shape == (num_queries, ) + + assert attn_metadata.block_tables.shape[0] == num_seqs + + # Update query lengths. 
Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + attn_metadata.seq_lens[i] += 1 + attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens) + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step( + self, model_input: MutableModelInputForGPUWithMultiStepMetadata, + out: SamplerOutput + ) -> MutableModelInputForGPUWithMultiStepMetadata: + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + assert num_seqs > 0 + assert num_queries > 0 + assert num_seqs >= num_queries + + attn_metadata = frozen_model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=frozen_model_input.input_tokens, + sampled_token_ids=model_input.outputs[-1].sampled_token_ids, + input_positions=frozen_model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + if frozen_model_input.seq_lens is not None: + for i in range(num_queries): + frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i] + + return model_input + + def load_model(self) -> None: + return self._base_model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +def _pythonize_sampler_output( + model_input: MutableModelInputForGPUWithMultiStepMetadata, + output: SamplerOutput, pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor) -> SamplerOutput: + """ This function is only called when the output tensors are ready. 
+ See ModelOutput + """ + + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + # samples generation should have been skipped + assert not output.outputs + + # dont use num-queries as some of the sequence's may not need sampling. + # Like, chunked prefill seqs. + n_sampled_token_ids = sampled_token_ids.shape[0] + pinned_buffer = pinned_sampled_token_buffer[:n_sampled_token_ids] + + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + sampling_metadata = frozen_model_input.sampling_metadata + + sample_result_it = iter(samples_list) + for seq_group in sampling_metadata.seq_groups: + seq_outputs: List[SequenceOutput] = [] + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + if seq_group.do_sample: + sample_result = next(sample_result_it) + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + for parent_id, next_token_id in zip(parent_ids, next_token_ids): + # TODO(will): support logprobs + # Hard coded logprob + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + {next_token_id: Logprob(logprob=42)})) + output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 000000000000..b14da2747629 --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,193 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.multi_step_model_runner import ( + MultiStepModelRunner, MutableModelInputForGPUWithMultiStepMetadata) +from vllm.worker.worker import Worker, WorkerInput + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: MutableModelInputForGPUWithMultiStepMetadata + + +class MultiStepWorker(Worker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + base_model_runner = self.model_runner + # for multi-step model, wrap the model runner with MultiStepModelRunner + self.model_runner = MultiStepModelRunner( + base_model_runner, + base_model_runner.model_config, + base_model_runner.parallel_config, + base_model_runner.scheduler_config, + base_model_runner.device_config, + base_model_runner.cache_config, + load_config=base_model_runner.load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=base_model_runner.is_driver_worker, + prompt_adapter_config=base_model_runner.prompt_adapter_config, + multimodal_config=base_model_runner.multimodal_config, + ) + + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput]: + """ + Get the driver input and broadcast it to other workers. 
+ """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + # on first step we prepare the worker input and model input normally + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: MutableModelInputForGPUWithMultiStepMetadata = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + else: + # on subsequent steps we reuse the worker input and model input + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + # clear the cached decode metadata so that it can be recomputed on + # the workers + frozen_model_input.attn_metadata._cached_decode_metadata = None + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + if not is_first_multi_step: + # we broadcast the last sampled token ids to all TP workers so they + # can update their model input metadata in-place. + self._prepare_last_sampled_token_ids_for_tp_workers( + execute_model_req=execute_model_req, model_input=model_input) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + return model_input, worker_input + + def _prepare_last_sampled_token_ids_for_tp_workers( + self, + execute_model_req: ExecuteModelRequest, + model_input: MutableModelInputForGPUWithMultiStepMetadata, + ) -> None: + """ + Prepare the last sampled token ids for TP workers. If it's the last + PP rank, then the last sampled token ids are already in the model_input. + If it is NOT the last PP rank, then we need to get the last sampled + token that is cached in the execute_model_req. + """ + if get_pp_group().is_last_rank: + assert model_input.outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = ( + execute_model_req.last_sampled_token_ids.cuda()) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. + # TODO(will) we could reuse the sampled token ids tensor from + # the previous step instead. 
+ for output in model_input.outputs[:-1]: + output.sampled_token_ids = None + assert model_input.outputs[-1].sampled_token_ids is not None + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[MutableModelInputForGPUWithMultiStepMetadata, + WorkerInput]]: + """ + Depending on the current state of the request and multi step worker, + this method may skip the normal _prepare_model_input and + _prepare_worker_input methods and instead used cached values. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + virtual_engine = execute_model_req.virtual_engine + model_input, worker_input = self._get_driver_input_and_broadcast( + execute_model_req) + assert isinstance(model_input, + MutableModelInputForGPUWithMultiStepMetadata) + if execute_model_req.is_first_multi_step: + # cache the worker input and model input for the next steps + self.multi_step_states[virtual_engine] = MultiStepState( + worker_input=worker_input, model_input=model_input) + # if TP workers + else: + broadcast_data = self._get_worker_input_from_broadcast() + # if the driver has sent an empty input, we should stop the worker + # loop + if broadcast_data is None: + return None + model_input, worker_input = broadcast_data + assert isinstance(model_input, + MutableModelInputForGPUWithMultiStepMetadata) + virtual_engine = worker_input.virtual_engine + if model_input.is_first_multi_step: + pass + # TODO(will) Can cache the worker input and model input for the + # next steps. See below for details + else: + # TODO(will) possible to also cache and reuse the cached worker + # input and model input. The idea is essentially the delta + # optimization for model_inputs. 
Where the TP workers can cache + # the model input states and we only broadcast the delta need + # for the next step (sampled_token_ids from the previous step) + + assert isinstance( + model_input, MutableModelInputForGPUWithMultiStepMetadata) + # we need to update the last sampled token ids in the model + # input for the workers so that they can run inplace + # advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 905052d1a951..9fddc863548e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -16,7 +16,9 @@ SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.worker.model_runner_base import (BroadcastableModelInput, + ModelRunnerBase, + ModelRunnerInputBase) logger = init_logger(__name__) @@ -220,7 +222,7 @@ def execute_worker(self, worker_input: WorkerInput) -> None: raise NotImplementedError def _get_worker_input_from_broadcast( - self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]: + self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast assert not self.is_driver_worker @@ -237,7 +239,7 @@ def _get_worker_input_from_broadcast( def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[ModelRunnerInputBase, WorkerInput]: + ) -> Tuple[BroadcastableModelInput, WorkerInput]: """ Get the driver input and broadcast it to other workers. """ assert self.is_driver_worker @@ -259,7 +261,7 @@ def _get_driver_input_and_broadcast( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]: + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: """ Prepare the inputs to ModelRunner and workers. """