From dd38054f1554ffa0cf0eaa3c526ab9786135c91b Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 27 Aug 2024 17:48:02 +0000 Subject: [PATCH] Async + multi-step support --- examples/offline_inference.py | 5 +- .../multi_step/test_correctness_async_llm.py | 10 +- vllm/core/scheduler.py | 4 +- vllm/engine/async_llm_engine.py | 89 ++++++++++----- vllm/engine/llm_engine.py | 102 +++++++++++++----- vllm/sequence.py | 4 +- vllm/worker/model_runner.py | 11 +- vllm/worker/multi_step_model_runner.py | 79 ++++++++++++-- vllm/worker/multi_step_worker.py | 8 ++ 9 files changed, 232 insertions(+), 80 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f..a39fd1f151e1 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,10 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="facebook/opt-125m", + num_scheduler_steps=8, + use_v2_block_manager=True, + disable_async_output_proc=False) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index ad99d70d7417..ac04be3d9a68 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio async def test_multi_step(example_prompts, model: str, tp_size: int, pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int): + num_scheduler_steps: int, num_prompts: int, + is_async: bool): prompts = example_prompts if len(prompts) < num_prompts: @@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] - # Disable output proc callback as its not supported - # with multi-step right now - ms_server_args += ["--disable-async-output-proc"] + if not is_async: + ms_server_args += ["--disable-async-output-proc"] + if eager_mode: ms_server_args.append("--enforce-eager") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 280d7b7e61e2..bb4b64cbcf0b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1106,9 +1106,7 @@ def schedule( common_computed_block_nums = [] # TODO: Combine multi-step and async postprocessor - allow_async_output_proc: bool = ( - self.use_async_output_proc - and not self.scheduler_config.is_multi_step) + allow_async_output_proc: bool = self.use_async_output_proc # Create input data structures. 
seq_group_metadata_list: List[SequenceGroupMetadata] = [] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3445b7084bbc..9315d74237a1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -279,19 +279,40 @@ async def step_async( scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + + ctx = self.scheduler_contexts[virtual_engine] + # skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): + + # Clear outputs on scheduler iteration start + ctx.request_outputs.clear() + + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # If current scheduler iteration has no async postprocessor, - # then we need first to drain the pending async postprocessor - # before moving forward - if not allow_async_output_proc and len(self.output_queue) > 0: - self._process_model_outputs(is_async=True) + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) + + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): @@ -304,9 +325,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -332,8 +350,10 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.output_proc_callback_fn = \ - self._process_model_outputs + execute_model_req.async_callback = self.async_callback_data[ + virtual_engine] + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step # Execute the model. output = await self.model_executor.execute_model_async( @@ -343,9 +363,10 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if len(self.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step - self._process_model_outputs(is_async=True) + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) output = [] # Finish the current step for all the sequence groups. 
@@ -354,25 +375,29 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - self.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: - self._process_model_outputs(is_async=False) + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=False) # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -381,9 +406,21 @@ async def step_async( self.do_tracing(scheduler_outputs) else: - self.request_outputs = [] + # Multi-step case + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(virtual_engine=virtual_engine, + is_async=True) + assert len(ctx.output_queue) == 0 - return self.request_outputs + return ctx.request_outputs async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 46db1f4aa3a2..88de9c906f8d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -38,10 +38,10 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceStatus, - AsyncCallbackData) +from vllm.sequence import (AsyncCallbackData, EmbeddingSequenceGroupOutput, + ExecuteModelRequest, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -91,9 +91,8 @@ class SchedulerOutputState: @dataclass class SchedulerContext: - output_queue: Deque[Tuple[List[SamplerOutput], - List[Tuple[ScheduledSequenceGroup, - SequenceGroupMetadata]], + output_queue: Deque[Tuple[Optional[List[SamplerOutput]], + List[SequenceGroupMetadata], SchedulerOutputs]] = field( default_factory=lambda: deque()) @@ -1233,8 +1232,11 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, virtual_engine: int, - is_async: bool) -> None: + def _process_model_outputs(self, + virtual_engine: int, + is_async: bool, + sampler_output: Optional[SamplerOutput] = None, + 
is_last_output: bool = False) -> None: """Apply the model output to the sequences in the scheduled seq groups. virtual_engine: The engine id to operate on @@ -1248,13 +1250,25 @@ def _process_model_outputs(self, virtual_engine: int, """ now = time.time() + is_multi_step = sampler_output is not None + ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] if len(ctx.output_queue) == 0: return None - (outputs, seq_group_metadata_list, - scheduler_outputs) = ctx.output_queue.popleft() + if is_multi_step: + # Async + multi-step case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue[0] + assert outputs is None + outputs = [sampler_output] + else: + # Async standard case + (outputs, seq_group_metadata_list, + scheduler_outputs) = ctx.output_queue.popleft() + + assert outputs is not None # Sanity check assert len(seq_group_metadata_list) == len( @@ -1313,7 +1327,11 @@ def _process_model_outputs(self, virtual_engine: int, self.output_processor.process_outputs(seq_group, output, is_async) - # Free the finished sequence groups. + # For async + multi-step, free finished seqs and create outputs + # only on the final step. + if is_multi_step and not is_last_output: + return + for scheduler in self.scheduler: scheduler.free_finished_seq_groups() @@ -1321,7 +1339,7 @@ def _process_model_outputs(self, virtual_engine: int, for i, _ in enumerate(seq_group_metadata_list): scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - if i in finished_before: + if not is_multi_step and i in finished_before: continue # Avoids double processing seq_group = scheduled_seq_group.seq_group @@ -1335,7 +1353,11 @@ def _process_model_outputs(self, virtual_engine: int, request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) - if is_async: + # For async + multi-step, do stats only on the last output. + # Otherwise, do stats if the execution is async + do_stats = is_multi_step or is_async + + if do_stats: # Log stats. self.do_log_stats(scheduler_outputs, outputs, finished_before) @@ -1438,6 +1460,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + ctx = self.scheduler_contexts[virtual_engine] # Skip the scheduler if there are any remaining steps in the seq groups. 
@@ -1453,11 +1479,22 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() + # Detect async + multi-step + use_async_and_multi_step = (self.scheduler_config.is_multi_step + and allow_async_output_proc) + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) + # For async + multi-step, init the queue + if use_async_and_multi_step: + assert len(ctx.output_queue) == 0 + assert seq_group_metadata_list is not None + ctx.output_queue.append( + (None, seq_group_metadata_list, scheduler_outputs)) + if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have @@ -1469,9 +1506,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -1498,6 +1532,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if allow_async_output_proc: execute_model_req.async_callback = self.async_callback_data[ virtual_engine] + execute_model_req.use_async_and_multi_step = \ + use_async_and_multi_step output = self.model_executor.execute_model( execute_model_req=execute_model_req) @@ -1509,7 +1545,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. 
- if len(ctx.output_queue) > 0: + if not use_async_and_multi_step and len(ctx.output_queue) > 0: assert not self.scheduler_config.is_multi_step self._process_model_outputs(virtual_engine=virtual_engine, is_async=True) @@ -1526,18 +1562,23 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() - # Add results to the output_queue - # (for async or non-async postprocessing) - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + if use_async_and_multi_step: + # For async + multi-step, clear the queue + ctx.output_queue.clear() + else: + # Add results to the output_queue + # (for async or non-async postprocessing) + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len(output) == 1, ("Multi step decoding does not work " - "with async output processing.") + if output and allow_async_output_proc: + assert len(output) == 1, ( + "Multi step decoding does not work " + "with async output processing.") - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path if not allow_async_output_proc: @@ -1551,7 +1592,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.do_tracing(scheduler_outputs) else: # Multi-step case - self.request_outputs = [] + if use_async_and_multi_step: + return [] + else: + ctx.request_outputs = [] if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6cc5a8b33e27..ee4795f40c17 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1301,6 +1301,7 @@ class ExecuteModelRequest( last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback async_callback: Optional[AsyncCallbackData] = None + use_async_and_multi_step: bool = False @property def is_first_multi_step(self) -> bool: @@ -1347,4 +1348,5 @@ def clone( finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) + async_callback=self.async_callback, + use_async_and_multi_step=self.use_async_and_multi_step) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index fbfc911b7cb0..1a4f58070527 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -41,8 +41,8 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata, AsyncCallbackData) +from vllm.sequence import (AsyncCallbackData, IntermediateTensors, + SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( @@ -91,6 +91,7 @@ 
class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[AsyncCallbackData] = None + use_async_and_multi_step: bool = False def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -1460,7 +1461,7 @@ def execute_model( func = model_input.async_callback.func kw_args = model_input.async_callback.kw_args func(**kw_args) - + # Sample the next token. output: SamplerOutput = self.model.sample( logits=logits, diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 521205eca05a..91bea68a9528 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -13,9 +14,9 @@ from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput) +from vllm.sequence import (AsyncCallbackData, CompletionSequenceGroupOutput, + IntermediateTensors, Logprob, SamplerOutput, + SequenceGroupMetadata, SequenceOutput) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -215,6 +216,49 @@ def prepare_model_input( ) return model_input + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: AsyncCallbackData): + output_proc_fn = output_proc_callback.func + output_proc_kw_args = output_proc_callback.kw_args + virtual_engine = output_proc_kw_args["virtual_engine"] + + for model_output in model_input.cached_outputs: + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + output_proc_fn(virtual_engine=virtual_engine, + is_async=False, + sampler_output=model_output.sampler_output) + + def _final_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: AsyncCallbackData): + assert model_input.frozen_model_input is not None + + if output_proc_callback is not None: + output_proc_fn = output_proc_callback.func + output_proc_kw_args = output_proc_callback.kw_args + virtual_engine = output_proc_kw_args["virtual_engine"] + + outputs = [] + for output_id in range(len(model_input.cached_outputs)): + is_last_output = output_id == len(model_input.cached_outputs) - 1 + + output = model_input.cached_outputs[output_id] + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + if model_input.frozen_model_input.use_async_and_multi_step: + output_proc_fn(virtual_engine=virtual_engine, + is_async=False, + sampler_output=output.sampler_output, + is_last_output=is_last_output) + + outputs.append(output.sampler_output) + + return outputs + @torch.inference_mode() def execute_model( self, @@ -271,6 +315,20 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) + output_proc_callback = None + if frozen_model_input.use_async_and_multi_step: + output_proc_callback = frozen_model_input.async_callback + async_callback = AsyncCallbackData( + self._async_process_outputs, { + "model_input": model_input, + "output_proc_callback": output_proc_callback + }) + + 
frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + assert frozen_model_input is not None + # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, kv_caches, @@ -301,9 +359,11 @@ def execute_model( output[0].logprobs = None # Pythonize the output if CPU is ahead and the previous step is # ready. - for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + if not frozen_model_input.use_async_and_multi_step: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) model_input.current_step += 1 @@ -316,11 +376,8 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = [] - for output in model_input.cached_outputs: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - outputs.append(output.sampler_output) + outputs = self._final_process_outputs(model_input, + output_proc_callback) return outputs # should be [SamplerOutput] diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 2ed77dd698f5..e0e421942f40 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,3 +1,4 @@ +import dataclasses from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -61,6 +62,13 @@ def _get_driver_input_and_broadcast( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback, + use_async_and_multi_step=execute_model_req. + use_async_and_multi_step) else: # on subsequent steps we reuse the worker input and model input multi_step_state = self.multi_step_states[virtual_engine]
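
Note (illustrative, not part of the patch): the sketch below is a minimal stand-in, written in plain Python with made-up names (Ctx, Engine, run_multi_step), for the control flow this change wires through vLLM. It mirrors, under simplified assumptions, the pattern introduced above: the per-virtual-engine scheduler context keeps an output queue that is seeded once per scheduler iteration, the multi-step model runner calls back into output processing after each forward pass so sampler outputs can be pythonized while the GPU keeps stepping, and request outputs are only materialized on the last step (compare _process_model_outputs(..., sampler_output=..., is_last_output=...) and _async_process_outputs/_final_process_outputs in the diff). It is not the vLLM API.

# Illustrative sketch only -- simplified, hypothetical names, not vLLM code.
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Optional, Tuple


@dataclass
class Ctx:
    """Stand-in for SchedulerContext: queued step outputs plus final results."""
    output_queue: Deque[Tuple[Optional[List[Any]], List[str]]] = field(
        default_factory=deque)
    request_outputs: List[str] = field(default_factory=list)


class Engine:
    """Stand-in for LLMEngine showing the async + multi-step callback flow."""

    def __init__(self) -> None:
        self.ctx = Ctx()

    def schedule(self, metadata: List[str]) -> None:
        # Async + multi-step: seed the queue once per scheduler iteration with
        # a placeholder (None) for the sampler outputs that arrive later.
        self.ctx.request_outputs.clear()
        assert len(self.ctx.output_queue) == 0
        self.ctx.output_queue.append((None, metadata))

    def process_model_outputs(self, sampler_output: Any,
                              is_last_output: bool) -> None:
        # Callback handed to the model runner: each step contributes one
        # sampler output; finished requests are materialized and the queue
        # is cleared only on the last step of the multi-step batch.
        _, metadata = self.ctx.output_queue[0]
        for req_id in metadata:
            self.ctx.request_outputs.append(f"{req_id}:{sampler_output}")
        if is_last_output:
            self.ctx.output_queue.clear()


def run_multi_step(engine: Engine, num_steps: int) -> List[str]:
    # Stand-in for MultiStepModelRunner.execute_model(): after every forward
    # pass it calls back into the engine; the final step flushes the queue.
    engine.schedule(metadata=["req-0", "req-1"])
    for step in range(num_steps):
        sampler_output = f"token@step{step}"
        engine.process_model_outputs(sampler_output,
                                     is_last_output=(step == num_steps - 1))
    return engine.ctx.request_outputs


if __name__ == "__main__":
    print(run_multi_step(Engine(), num_steps=4))

The reason freeing finished sequence groups and creating RequestOutput objects is deferred to the final step in the patch is to keep the Python-side work between GPU step launches minimal; only the last step of a multi-step batch pays the full output-processing cost.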