[Bugfix] Fix async postprocessor in case of preemption (vllm-project#…
alexm-neuralmagic authored Sep 8, 2024
1 parent cfe712b commit 4ef41b8
Showing 4 changed files with 172 additions and 114 deletions.
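The core of the fix is in vllm/core/scheduler.py: when the scheduler must preempt a sequence group while async output processing is enabled, it now first drains any pending async postprocessing for that specific victim request (so the victim's async block-table frees are applied before the preemption frees), and if the drained output shows the victim has already finished, it simply frees the group instead of preempting it. The sketch below is a minimal, self-contained illustration of that ordering only; the Request and SimpleScheduler classes and their helpers are hypothetical stand-ins, not vLLM's actual SequenceGroup/Scheduler API.

    from collections import deque
    from typing import Callable, Deque, List, Optional


    class Request:
        """Hypothetical stand-in for a vLLM SequenceGroup."""

        def __init__(self, request_id: str) -> None:
            self.request_id = request_id
            self.finished = False

        def is_finished(self) -> bool:
            return self.finished


    class SimpleScheduler:
        """Illustrates drain-then-preempt ordering; not vLLM's Scheduler."""

        def __init__(self,
                     output_proc_callback: Optional[Callable[[str], None]] = None):
            self.running: Deque[Request] = deque()
            self.preempted: List[Request] = []
            self.freed: List[Request] = []
            # Callback that applies pending async postprocessing for one request.
            self.output_proc_callback = output_proc_callback

        def _preempt(self, req: Request) -> None:
            self.preempted.append(req)

        def _free(self, req: Request) -> None:
            self.freed.append(req)

        def preempt_one(self) -> None:
            """Preempt the lowest-priority running request: drain async output
            processing for the victim first, and free it instead of preempting
            if it already finished."""
            if not self.running:
                return
            victim = self.running.pop()

            # Drain pending async postprocessing for this request only, so its
            # block-table frees land before the preemption frees them again.
            if self.output_proc_callback is not None:
                self.output_proc_callback(victim.request_id)

            if victim.is_finished():
                # The drained output marked the request finished; just free it.
                self._free(victim)
            else:
                self._preempt(victim)


    if __name__ == "__main__":
        def fake_async_postproc(request_id: str) -> None:
            # Pretend the pending output finished request "b".
            if request_id == "b":
                req_b.finished = True

        sched = SimpleScheduler(output_proc_callback=fake_async_postproc)
        req_a, req_b = Request("a"), Request("b")
        sched.running.extend([req_a, req_b])

        sched.preempt_one()  # victim is "b": drained, found finished, freed
        sched.preempt_one()  # victim is "a": still running, so preempted
        print([r.request_id for r in sched.freed])      # ['b']
        print([r.request_id for r in sched.preempted])  # ['a']

Draining per request (output_proc_callback(request_id=...)) replaces the earlier approach of swapping self.running back to a saved copy and draining everything at once, which is what the removed orig_running code below did.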
87 changes: 47 additions & 40 deletions vllm/core/scheduler.py
@@ -537,13 +537,6 @@ def _schedule_running(
         preempted: List[SequenceGroup] = ret.preempted
         swapped_out: List[SequenceGroup] = ret.swapped_out
 
-        # NOTE(woosuk): Preemption happens only when there is no available slot
-        # to keep all the sequence groups in the RUNNING state.
-
-        # Store original running requests for the case of async + preemption
-        if self.use_async_output_proc:
-            orig_running = self.running.copy()
-
         running_queue = self.running
         assert len(self._async_stopped) == 0
         while running_queue:
@@ -552,6 +545,7 @@
                 seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
 
             if num_running_tokens == 0:
+                # No budget => Stop
                 break
 
             running_queue.popleft()
@@ -565,18 +559,8 @@
                 self._async_stopped.append(seq_group)
                 continue
 
-            # With async postprocessor, when preemption kicks in, we need
-            # first to drain the async postprocessor, so that all async
-            # block_table freeing is applied before the preemption freeing
-            # is applied.
-            if self.use_async_output_proc and not self._can_append_slots(
-                    seq_group):
-                tmp = self.running
-                self.running = orig_running
-                assert self.output_proc_callback is not None
-                self.output_proc_callback()
-                self.running = tmp
-
+            # NOTE(woosuk): Preemption happens only when there is no available
+            # slot to keep all the sequence groups in the RUNNING state.
             while not self._can_append_slots(seq_group):
                 budget.subtract_num_batched_tokens(seq_group.request_id,
                                                    num_running_tokens)
@@ -588,24 +572,43 @@
                         and seq_group.lora_int_id in curr_loras):
                     curr_loras.remove(seq_group.lora_int_id)
 
+                # Determine victim sequence
+                cont_loop = True
                 if running_queue:
-                    # Preempt the lowest-priority sequence groups.
+                    # Preempt the lowest-priority sequence group.
                     victim_seq_group = running_queue.pop()
+                else:
+                    # No other sequence group can be preempted.
+                    # Preempt the current sequence group.
+                    # Note: This is also where we stop this loop
+                    # (since there is nothing else to preempt)
+                    victim_seq_group = seq_group
+                    cont_loop = False
+
+                # With async postprocessor, before preempting a sequence
+                # we need to ensure it has no pending async postprocessor
+                do_preempt = True
+                if self.use_async_output_proc:
+                    assert self.output_proc_callback is not None
+                    self.output_proc_callback(
+                        request_id=victim_seq_group.request_id)
+
+                    # It may be that the async pending "victim_seq_group"
+                    # becomes finished, in which case we simply free it.
+                    if victim_seq_group.is_finished():
+                        self._free_finished_seq_group(victim_seq_group)
+                        do_preempt = False
+
+                # Do preemption
+                if do_preempt:
                     preempted_mode = self._preempt(victim_seq_group,
                                                    blocks_to_swap_out)
                     if preempted_mode == PreemptionMode.RECOMPUTE:
                         preempted.append(victim_seq_group)
                     else:
                         swapped_out.append(victim_seq_group)
-                else:
-                    # No other sequence groups can be preempted.
-                    # Preempt the current sequence group.
-                    preempted_mode = self._preempt(seq_group,
-                                                   blocks_to_swap_out)
-                    if preempted_mode == PreemptionMode.RECOMPUTE:
-                        preempted.append(seq_group)
-                    else:
-                        swapped_out.append(seq_group)
+
+                if not cont_loop:
                     break
             else:
                 self._append_slots(seq_group, blocks_to_copy)
@@ -1264,22 +1267,26 @@ def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
             if seq.is_finished():
                 self.free_seq(seq)
 
+    def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
+        if seq_group.is_finished():
+            # Free cross-attention block table, if it exists
+            self._free_seq_group_cross_attn_blocks(seq_group)
+
+            # Add the finished requests to the finished requests list.
+            # This list will be used to update the Mamba cache in the
+            # next step.
+            self._finished_requests_ids.append(seq_group.request_id)
+
+        # Free finished seqs
+        self._free_finished_seqs(seq_group)
+
     def free_finished_seq_groups(self) -> None:
         remaining: Deque[SequenceGroup] = deque()
         for seq_group in self.running:
-            if seq_group.is_finished():
-                # Free cross-attention block table, if it exists
-                self._free_seq_group_cross_attn_blocks(seq_group)
-                # Add the finished requests to the finished requests list.
-                # This list will be used to update the Mamba cache in the
-                # next step.
-                self._finished_requests_ids.append(seq_group.request_id)
-            else:
+            self._free_finished_seq_group(seq_group)
+            if not seq_group.is_finished():
                 remaining.append(seq_group)
-
-            # Free finished seqs
-            self._free_finished_seqs(seq_group)
 
         self.running = remaining
 
         # Handle async stopped sequence groups
24 changes: 12 additions & 12 deletions vllm/engine/async_llm_engine.py
@@ -342,17 +342,17 @@ async def step_async(
                     virtual_engine]
 
             # Execute the model.
-            output = await self.model_executor.execute_model_async(
+            outputs = await self.model_executor.execute_model_async(
                 execute_model_req)
 
             # we need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
             if self.scheduler_config.is_multi_step:
-                self._update_cached_scheduler_output(virtual_engine, output)
+                self._update_cached_scheduler_output(virtual_engine, outputs)
         else:
             if len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
-            output = []
+            outputs = []
 
         # Finish the current step for all the sequence groups.
         if self.scheduler_config.is_multi_step:
@@ -365,25 +365,25 @@
                 self.cached_scheduler_outputs[
                     virtual_engine] = SchedulerOutputState()
 
-            is_async = allow_async_output_proc
-            is_last_step = True
-            ctx.output_queue.append(
-                (output, seq_group_metadata_list, scheduler_outputs, is_async,
-                 is_last_step))
+            ctx.append_output(outputs=outputs,
+                              seq_group_metadata_list=seq_group_metadata_list,
+                              scheduler_outputs=scheduler_outputs,
+                              is_async=allow_async_output_proc,
+                              is_last_step=True)
 
-            if output and allow_async_output_proc:
+            if outputs and allow_async_output_proc:
                 assert len(
-                    output
+                    outputs
                 ) == 1, "Async postprocessor expects only a single output set"
                 self._advance_to_next_step(
-                    output[0], seq_group_metadata_list,
+                    outputs[0], seq_group_metadata_list,
                     scheduler_outputs.scheduled_seq_groups)
 
             if not allow_async_output_proc:
                 self._process_model_outputs(ctx=ctx)
 
                 # Log stats.
-                self.do_log_stats(scheduler_outputs, output)
+                self.do_log_stats(scheduler_outputs, outputs)
 
                 # Tracing
                 self.do_tracing(scheduler_outputs)
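On the engine side the change is mostly mechanical: the output variable becomes outputs, and the hand-built tuple pushed onto ctx.output_queue is replaced by a ctx.append_output(...) call with named fields. The definition of that helper is not part of this excerpt; the sketch below shows roughly what such a method could look like, with OutputData and SchedulerContext as assumed names inferred only from the call site above, not vLLM's actual definitions.

    from collections import deque
    from dataclasses import dataclass
    from typing import Any, Deque, List


    @dataclass
    class OutputData:
        # Field names mirror the keyword arguments at the call site above;
        # the real vLLM container may differ.
        outputs: List[Any]
        seq_group_metadata_list: List[Any]
        scheduler_outputs: Any
        is_async: bool
        is_last_step: bool


    class SchedulerContext:
        """Hypothetical sketch of the per-virtual-engine context object."""

        def __init__(self) -> None:
            self.output_queue: Deque[OutputData] = deque()

        def append_output(self, outputs: List[Any],
                          seq_group_metadata_list: List[Any],
                          scheduler_outputs: Any, is_async: bool,
                          is_last_step: bool) -> None:
            # Centralizing the append keeps every producer pushing the same,
            # fully-labeled record instead of a positional tuple.
            self.output_queue.append(
                OutputData(outputs=outputs,
                           seq_group_metadata_list=seq_group_metadata_list,
                           scheduler_outputs=scheduler_outputs,
                           is_async=is_async,
                           is_last_step=is_last_step))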
(Diffs for the remaining two changed files are not shown.)
