Skip to content

Commit

Permalink
improve chunked prefill performance
Browse files Browse the repository at this point in the history
[Bugfix] Fix vllm-project#7592 vllm 0.5.4 enable_chunked_prefill throughput is slightly lower than 0.5.3~0.5.0. (vllm-project#7874)
  • Loading branch information
noooop authored and siddharth9820 committed Sep 30, 2024
1 parent 03e71a2 commit 5244bb4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
3 changes: 3 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def test_models_with_fp8_kv_cache(
pytest.skip(
"#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
)
if ((model, kv_cache_dtype, chunked_prefill_token_size) == (
"nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", "fp8_e4m3", 4)):
pytest.skip("flakey test, see: #7874 #8051")

max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
Expand Down
15 changes: 10 additions & 5 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,16 +1027,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:

# Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted)

# Update new running requests.
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
# By default, vLLM scheduler prioritizes prefills.
# Once chunked prefill is enabled,
# the policy is changed to prioritize decode requests.
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
self.running.extend([s.seq_group for s in prefills.seq_groups])

# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
Expand Down

0 comments on commit 5244bb4

Please sign in to comment.