Commit

Async + multi-step support
alexm-neuralmagic committed Aug 27, 2024
1 parent 9e8f61e commit dd38054
Showing 9 changed files with 232 additions and 80 deletions.
5 changes: 4 additions & 1 deletion examples/offline_inference.py
@@ -11,7 +11,10 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m",
num_scheduler_steps=8,
use_v2_block_manager=True,
disable_async_output_proc=False)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
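
For reference, a runnable version of the updated example follows. It is a sketch that assumes only the constructor flags shown in the hunk above (num_scheduler_steps, use_v2_block_manager, disable_async_output_proc) together with the standard vllm LLM and SamplingParams API; the prompt list is illustrative.

from vllm import LLM, SamplingParams

# Sample prompts (illustrative).
prompts = [
    "Hello, my name is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM with multi-step scheduling (8 decode steps per scheduler
# invocation), the v2 block manager, and async output processing enabled.
llm = LLM(model="facebook/opt-125m",
          num_scheduler_steps=8,
          use_v2_block_manager=True,
          disable_async_output_proc=False)

# Generate texts from the prompts and print them.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")
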
10 changes: 6 additions & 4 deletions tests/multi_step/test_correctness_async_llm.py
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):
num_scheduler_steps: int, num_prompts: int,
is_async: bool):

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if not is_async:
ms_server_args += ["--disable-async-output-proc"]

if eager_mode:
ms_server_args.append("--enforce-eager")

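
To illustrate the new parametrization (this helper is not part of the commit), the server arguments the test builds for each is_async and eager_mode combination can be sketched as follows; the DEFAULT_SERVER_ARGS value is a placeholder, not the list defined in the real test module.

from typing import List

# Placeholder; the real test defines its own DEFAULT_SERVER_ARGS.
DEFAULT_SERVER_ARGS: List[str] = ["--disable-log-requests"]

def build_ms_server_args(num_scheduler_steps: int,
                         is_async: bool,
                         eager_mode: bool) -> List[str]:
    """Mirror of how the test assembles multi-step server arguments."""
    args = DEFAULT_SERVER_ARGS + [
        "--num-scheduler-steps", f"{num_scheduler_steps}"
    ]
    # Async output processing is only disabled for the is_async=False cases;
    # the is_async=True cases exercise the new async + multi-step path.
    if not is_async:
        args += ["--disable-async-output-proc"]
    if eager_mode:
        args.append("--enforce-eager")
    return args

if __name__ == "__main__":
    print(build_ms_server_args(8, is_async=True, eager_mode=False))
    print(build_ms_server_args(8, is_async=False, eager_mode=True))
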
4 changes: 1 addition & 3 deletions vllm/core/scheduler.py
@@ -1106,9 +1106,7 @@ def schedule(
common_computed_block_nums = []

# TODO: Combine multi-step and async postprocessor
allow_async_output_proc: bool = (
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)
allow_async_output_proc: bool = self.use_async_output_proc

# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
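
In effect, the scheduler stops forcing async output processing off whenever multi-step scheduling is enabled; the decision now rests solely on the engine-level setting. A small, self-contained sketch of the before/after semantics of the flag (paraphrasing the hunk above, not the actual Scheduler class):

def allow_async_output_proc_before(use_async_output_proc: bool,
                                   is_multi_step: bool) -> bool:
    # Old behaviour: async output processing was forced off for multi-step.
    return use_async_output_proc and not is_multi_step

def allow_async_output_proc_after(use_async_output_proc: bool,
                                  is_multi_step: bool) -> bool:
    # New behaviour: the flag follows the engine-level setting alone.
    return use_async_output_proc

for multi_step in (False, True):
    print(f"is_multi_step={multi_step}: "
          f"before={allow_async_output_proc_before(True, multi_step)}, "
          f"after={allow_async_output_proc_after(True, multi_step)}")
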
89 changes: 63 additions & 26 deletions vllm/engine/async_llm_engine.py
@@ -279,19 +279,40 @@ async def step_async(
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):

# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# If current scheduler iteration has no async postprocessor,
# then we need first to drain the pending async postprocessor
# before moving forward
if not allow_async_output_proc and len(self.output_queue) > 0:
self._process_model_outputs(is_async=True)
# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
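
The ctx object used above comes from self.scheduler_contexts, which is not shown in this diff. As a rough illustration only (names and types here are assumptions, not vLLM's actual classes), the async + multi-step branch seeds a per-virtual-engine output queue with a placeholder entry at the start of each multi-step batch:

from collections import deque
from typing import Any, Deque, List, Optional, Tuple

class SchedulerContextSketch:
    """Simplified stand-in for the per-virtual-engine scheduler context."""

    def __init__(self) -> None:
        # Each queue entry pairs model output (None until the step has run)
        # with the scheduled metadata and scheduler outputs it belongs to.
        self.output_queue: Deque[Tuple[Optional[Any], Any, Any]] = deque()
        self.request_outputs: List[Any] = []

def start_async_multi_step_batch(ctx: SchedulerContextSketch,
                                 seq_group_metadata_list: Any,
                                 scheduler_outputs: Any) -> None:
    # Mirrors the "init the queue" branch above: the queue must be empty at
    # the start of a batch and is seeded with a placeholder entry whose
    # output slot is filled in as the multi-step batch executes.
    assert len(ctx.output_queue) == 0
    assert seq_group_metadata_list is not None
    ctx.output_queue.append((None, seq_group_metadata_list, scheduler_outputs))

ctx = SchedulerContextSketch()
start_async_multi_step_batch(ctx, ["seq group metadata"], "scheduler outputs")
print(ctx.output_queue)
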
@@ -304,9 +325,6 @@ async def step_async(
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
@@ -332,8 +350,10 @@ async def step_async(
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
execute_model_req.output_proc_callback_fn = \
self._process_model_outputs
execute_model_req.async_callback = self.async_callback_data[
virtual_engine]
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step

# Execute the model.
output = await self.model_executor.execute_model_async(
@@ -343,9 +363,10 @@ async def step_async(
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(self.output_queue) > 0:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(is_async=True)
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
output = []

# Finish the current step for all the sequence groups.
@@ -354,25 +375,29 @@ async def step_async(
seq_group.finish_step()

if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
# Clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

# Cache results in engine
self.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))

if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(is_async=False)
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)

# Log stats.
self.do_log_stats(scheduler_outputs, output)
@@ -381,9 +406,21 @@ async def step_async(
self.do_tracing(scheduler_outputs)

else:
self.request_outputs = []
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
assert len(ctx.output_queue) == 0

return self.request_outputs
return ctx.request_outputs

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
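
Taken together, the step_async changes route all pending results through the per-virtual-engine context: each iteration either queues its output for the async postprocessor or processes it synchronously, and once no unfinished requests remain the queue is drained before request outputs are returned. A condensed, self-contained sketch of that drain-then-return pattern (simplified names and types, not vLLM's actual classes):

from collections import deque
from typing import Any, Deque, List, Tuple

class EngineContextSketch:
    """Per-virtual-engine state: pending outputs plus finished results."""

    def __init__(self) -> None:
        self.output_queue: Deque[Tuple[Any, Any, Any]] = deque()
        self.request_outputs: List[Any] = []

def finish_step(ctx: EngineContextSketch,
                has_unfinished_requests: bool) -> List[Any]:
    # Mirrors the tail of step_async: when nothing is left to run, drain any
    # pending async postprocessing before handing results back to the caller.
    if not has_unfinished_requests:
        while ctx.output_queue:
            output, seq_group_metadata_list, scheduler_outputs = \
                ctx.output_queue.popleft()
            ctx.request_outputs.append((seq_group_metadata_list, output))
        assert len(ctx.output_queue) == 0
    return ctx.request_outputs

ctx = EngineContextSketch()
ctx.output_queue.append(("sampled tokens", "seq group metadata", "sched outs"))
print(finish_step(ctx, has_unfinished_requests=False))
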