Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix #7592 vllm 0.5.4 enable_chunked_prefill throughput is slightly lower than 0.5.3~0.5.0. #7874

Merged
merged 2 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def test_models_with_fp8_kv_cache(
pytest.skip(
"#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
)
if ((model, kv_cache_dtype, chunked_prefill_token_size) == (
"nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", "fp8_e4m3", 4)):
pytest.skip("flakey test, see: #7874 #8051")

max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
Expand Down
15 changes: 10 additions & 5 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,16 +1027,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:

# Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted)

# Update new running requests.
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
# By default, vLLM scheduler prioritizes prefills.
# Once chunked prefill is enabled,
# the policy is changed to prioritize decode requests.
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
self.running.extend([s.seq_group for s in prefills.seq_groups])

# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
Expand Down
Loading