[Bugfix] Fix async postprocessor in case of preemption #8267

Merged · 8 commits · Sep 8, 2024
87 changes: 47 additions & 40 deletions vllm/core/scheduler.py
@@ -537,13 +537,6 @@ def _schedule_running(
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out

# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.

# Store original running requests for the case of async + preemption
if self.use_async_output_proc:
orig_running = self.running.copy()

running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@@ -552,6 +545,7 @@
seq_group, SequenceStatus.RUNNING, enable_chunking, budget)

if num_running_tokens == 0:
# No budget => Stop
break

running_queue.popleft()
@@ -565,18 +559,8 @@
self._async_stopped.append(seq_group)
continue

# With async postprocessor, when preemption kicks in, we need
# first to drain the async postprocessor, so that all async
# block_table freeing is applied before the preemption freeing
# is applied.
if self.use_async_output_proc and not self._can_append_slots(
seq_group):
tmp = self.running
self.running = orig_running
assert self.output_proc_callback is not None
self.output_proc_callback()
self.running = tmp

# NOTE(woosuk): Preemption happens only when there is no available
# slot to keep all the sequence groups in the RUNNING state.
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
@@ -588,24 +572,43 @@
and seq_group.lora_int_id in curr_loras):
curr_loras.remove(seq_group.lora_int_id)

# Determine victim sequence
cont_loop = True
if running_queue:
# Preempt the lowest-priority sequence groups.
# Preempt the lowest-priority sequence group.
victim_seq_group = running_queue.pop()
else:
# No other sequence group can be preempted.
# Preempt the current sequence group.
# Note: This is also where we stop this loop
# (since there is nothing else to preempt)
victim_seq_group = seq_group
cont_loop = False

# With async postprocessor, before preempting a sequence
# we need to ensure it has no pending async postprocessor
do_preempt = True
if self.use_async_output_proc:
assert self.output_proc_callback is not None
self.output_proc_callback(
request_id=victim_seq_group.request_id)

# It may be that the async pending "victim_seq_group"
# becomes finished, in which case we simply free it.
if victim_seq_group.is_finished():
self._free_finished_seq_group(victim_seq_group)
do_preempt = False

# Do preemption
if do_preempt:
preempted_mode = self._preempt(victim_seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(victim_seq_group)
else:
swapped_out.append(victim_seq_group)
else:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
preempted_mode = self._preempt(seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(seq_group)
else:
swapped_out.append(seq_group)

if not cont_loop:
break
else:
self._append_slots(seq_group, blocks_to_copy)
@@ -1264,22 +1267,26 @@ def _free_finished_seqs(self, seq_group: SequenceGroup) -> None:
if seq.is_finished():
self.free_seq(seq)

def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None:
if seq_group.is_finished():
# Free cross-attention block table, if it exists
self._free_seq_group_cross_attn_blocks(seq_group)

# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self._finished_requests_ids.append(seq_group.request_id)

# Free finished seqs
self._free_finished_seqs(seq_group)

def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
if seq_group.is_finished():
# Free cross-attention block table, if it exists
self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
self._finished_requests_ids.append(seq_group.request_id)
else:
self._free_finished_seq_group(seq_group)
if not seq_group.is_finished():
remaining.append(seq_group)

# Free finished seqs
self._free_finished_seqs(seq_group)

self.running = remaining

# Handle async stopped sequence groups
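Taken together, the scheduler change drops the old approach of draining the entire async output queue before any preemption and instead drains only the chosen victim's pending outputs, freeing the victim outright if those outputs turn out to have finished it. Below is a minimal, self-contained sketch of that control flow; the helper names (can_append_slots, output_proc_callback, and so on) mirror the diff above but are passed in as plain callables, and the real scheduler's token-budget and LoRA bookkeeping is omitted.

```python
from enum import Enum
from typing import Callable, Deque, List, Optional


class PreemptionMode(Enum):
    RECOMPUTE = 1
    SWAP = 2


class SeqGroup:
    """Toy stand-in for vLLM's SequenceGroup (illustrative only)."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self.finished = False

    def is_finished(self) -> bool:
        return self.finished


def make_room_for(
    seq_group: SeqGroup,
    running_queue: Deque[SeqGroup],
    can_append_slots: Callable[[SeqGroup], bool],
    output_proc_callback: Optional[Callable[..., None]],
    preempt: Callable[[SeqGroup], PreemptionMode],
    free_finished: Callable[[SeqGroup], None],
    preempted: List[SeqGroup],
    swapped_out: List[SeqGroup],
) -> None:
    """Free or preempt victims until seq_group fits in the KV cache."""
    while not can_append_slots(seq_group):
        # Pick the victim: the lowest-priority running group, else ourselves
        # (in which case this is also the last iteration).
        if running_queue:
            victim, last = running_queue.pop(), False
        else:
            victim, last = seq_group, True

        do_preempt = True
        if output_proc_callback is not None:
            # Drain only this victim's pending async outputs so that its
            # async block-table frees land before the preemption frees.
            output_proc_callback(request_id=victim.request_id)
            if victim.is_finished():
                # The drained outputs finished the victim: just free it.
                free_finished(victim)
                do_preempt = False

        if do_preempt:
            mode = preempt(victim)
            if mode == PreemptionMode.RECOMPUTE:
                preempted.append(victim)
            else:
                swapped_out.append(victim)

        if last:
            break
```

The extracted _free_finished_seq_group helper is what lets this preemption path reuse the same freeing logic that free_finished_seq_groups applies at the end of a step.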
24 changes: 12 additions & 12 deletions vllm/engine/async_llm_engine.py
@@ -342,17 +342,17 @@ async def step_async(
virtual_engine]

# Execute the model.
output = await self.model_executor.execute_model_async(
outputs = await self.model_executor.execute_model_async(
execute_model_req)

# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
self._update_cached_scheduler_output(virtual_engine, outputs)
else:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
output = []
outputs = []

# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@@ -365,25 +365,25 @@
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True)

if output and allow_async_output_proc:
if outputs and allow_async_output_proc:
assert len(
output
outputs
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
output[0], seq_group_metadata_list,
outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)

# Log stats.
self.do_log_stats(scheduler_outputs, output)
self.do_log_stats(scheduler_outputs, outputs)

# Tracing
self.do_tracing(scheduler_outputs)
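On the engine side, step_async now hands its results to ctx.append_output(...) instead of appending a bare five-element tuple to ctx.output_queue, and the local variable is renamed from output to outputs. A small sketch of what such a helper could look like, assuming it merely packages the same five fields the old code appended directly (the real scheduler context in vLLM may do more):

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Optional, Tuple


@dataclass
class ToySchedulerContext:
    """Illustrative stand-in for the per-virtual-engine scheduler context."""

    output_queue: Deque[Tuple[Any, ...]] = field(default_factory=deque)

    def append_output(self,
                      outputs: List[Any],
                      seq_group_metadata_list: Optional[List[Any]],
                      scheduler_outputs: Optional[Any],
                      is_async: bool,
                      is_last_step: bool) -> None:
        # Assumption: the helper just packages the same five fields that the
        # old code appended to output_queue as a bare tuple.
        self.output_queue.append((outputs, seq_group_metadata_list,
                                  scheduler_outputs, is_async, is_last_step))
```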