-
-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling #9038
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pytest | ||
|
||
from tests.conftest import VllmRunner | ||
from tests.core.utils import create_dummy_prompt | ||
from vllm.engine.llm_engine import LLMEngine | ||
from vllm.sequence import SequenceGroup | ||
|
||
MODEL = "JackFram/llama-160m" | ||
|
||
|
||
def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup): | ||
scheduler = engine.scheduler[0] | ||
scheduler.add_seq_group(seq_group) | ||
|
||
|
||
@pytest.mark.parametrize("num_scheduler_steps", [1, 8]) | ||
@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) | ||
@pytest.mark.parametrize("enforce_eager", [False, True]) | ||
def test_num_computed_tokens_update(num_scheduler_steps: int, | ||
enable_chunked_prefill: bool, | ||
enforce_eager: bool): | ||
|
||
# Make a vllm engine | ||
runner = VllmRunner(model_name=MODEL, | ||
gpu_memory_utilization=0.7, | ||
use_v2_block_manager=True, | ||
num_scheduler_steps=num_scheduler_steps, | ||
enable_chunked_prefill=enable_chunked_prefill, | ||
enforce_eager=enforce_eager) | ||
engine: LLMEngine = runner.model.llm_engine | ||
|
||
is_multi_step = num_scheduler_steps > 1 | ||
is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill | ||
# In multi-step + chunked-prefill there is no separate single prompt step. | ||
# What is scheduled will run for num_scheduler_steps always. | ||
num_prompt_steps = num_scheduler_steps \ | ||
if is_multi_step_chunked_prefill else 1 | ||
|
||
num_output_tokens_list = [4, 8, 12, 15, 16, 17] | ||
|
||
# Create sequence and add to engine | ||
prompt_len = 10 | ||
|
||
for req_idx, num_output_tokens in enumerate(num_output_tokens_list): | ||
seq, seq_group = create_dummy_prompt(request_id=str(req_idx), | ||
prompt_length=prompt_len, | ||
min_tokens=num_output_tokens, | ||
max_tokens=num_output_tokens) | ||
add_seq_group_to_engine(engine, seq_group) | ||
|
||
assert seq.data.get_num_computed_tokens() == 0 | ||
|
||
for _ in range(num_prompt_steps): | ||
# prompt steps | ||
engine.step() | ||
|
||
if not seq.is_finished(): | ||
assert seq.data.get_num_computed_tokens( | ||
) == prompt_len + num_prompt_steps - 1 | ||
|
||
prompt_num_computed_tokens = seq.data.get_num_computed_tokens() | ||
|
||
decode_step_counter = 0 | ||
while not seq.is_finished(): | ||
assert seq.data.get_num_computed_tokens( | ||
) == prompt_num_computed_tokens + decode_step_counter | ||
for _ in range(num_scheduler_steps): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. QQ: why do we need the for loop here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when Also, multi-step doesn't provide any guarantees that output processing will happen every step. The only guarantee is that after the completion of |
||
# decode step | ||
engine.step() | ||
decode_step_counter += 1 | ||
|
||
assert seq.data.get_num_computed_tokens( | ||
) == prompt_len + num_output_tokens - 1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -965,6 +965,45 @@ def _process_sequence_group_outputs( | |
|
||
return | ||
|
||
def _update_num_computed_tokens_for_multi_step_prefill( | ||
self, seq_group: SequenceGroup, | ||
seq_group_meta: SequenceGroupMetadata, | ||
is_first_step_output: Optional[bool]): | ||
""" | ||
This function updates num_computed_tokens for prompt sequences | ||
when Multi-Step is enabled. | ||
|
||
seq_group: SequenceGroup to update the num_computed_tokens for. | ||
seq_group_meta: Metadata of the given SequenceGroup. | ||
is_first_step_output: Optional[bool] - | ||
When available, is_first_step_output indicates if the appended | ||
output token is the output of the first-step in multi-step. | ||
A value of None indicates that outputs from all steps in | ||
in multi-step are submitted in a single burst. | ||
""" | ||
|
||
assert self.scheduler_config.is_multi_step | ||
|
||
if not seq_group_meta.is_prompt: | ||
# num_computed_token updates for multi-step decodes happen after | ||
# the tokens are appended to the sequence. | ||
return | ||
|
||
do_update: bool = False | ||
if self.scheduler_config.chunked_prefill_enabled: | ||
# In multi-step + chunked-prefill case, the prompt sequences | ||
# that are scheduled are fully processed in the first step. | ||
do_update = is_first_step_output is None or is_first_step_output | ||
else: | ||
# Normal multi-step decoding case. In this case prompt-sequences | ||
# are actually single-stepped. Always update in this case. | ||
assert seq_group.state.num_steps == 1 | ||
do_update = True | ||
|
||
if do_update: | ||
seq_group.update_num_computed_tokens( | ||
seq_group_meta.token_chunk_size) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the cleanup, the logic is much cleaner now! |
||
|
||
def _process_model_outputs(self, | ||
ctx: SchedulerContext, | ||
request_id: Optional[str] = None) -> None: | ||
|
@@ -975,64 +1014,6 @@ def _process_model_outputs(self, | |
request_id: If provided, then only this request is going to be processed | ||
""" | ||
|
||
def update_prefill_num_computed_tokens( | ||
seq_group: SequenceGroup, | ||
seq_group_meta: SequenceGroupMetadata, num_outputs: int, | ||
is_first_step_output: Optional[bool]) -> None: | ||
""" | ||
When multi-step and chunked-prefill are enabled together, the | ||
prefill sequence scheduled for multi-step execution turn into | ||
decodes in the first step itself. This function accounts | ||
for that conversion. | ||
|
||
seq_group: SequenceGroup - A prefill seq_group | ||
seq_group_meta: SequenceGroupMetadata - Metadata of the given | ||
prefill seq_group | ||
num_outputs: int - number of output tokens being processed for the | ||
given seq_group | ||
is_first_step_output: Optional[bool] - | ||
If multi-step is enabled and num_outputs is 1, this value | ||
indicates if this outputs belongs to the first step in the | ||
multi-step. | ||
If multi-step is enabled and num_outputs > 1, this value | ||
must be None, as num_outputs > 1 indicates that outputs from | ||
all the steps in multi-step are submitted in a single burst. | ||
When multi-step is disabled, this value is always True. | ||
""" | ||
|
||
assert seq_group_meta.is_prompt | ||
|
||
token_chunk_size = seq_group_meta.token_chunk_size | ||
|
||
if num_outputs == 1: | ||
assert is_first_step_output is not None | ||
|
||
if seq_group_meta.state.num_steps == 1: | ||
assert is_first_step_output is True | ||
seq_group.update_num_computed_tokens(token_chunk_size) | ||
return | ||
|
||
# multi-step prefill is only supported when multi-step is | ||
# enabled with chunked prefill | ||
assert self.scheduler_config.is_multi_step and \ | ||
self.scheduler_config.chunked_prefill_enabled | ||
if is_first_step_output is True: | ||
# This sequence is a prompt during the first step only. | ||
seq_group.update_num_computed_tokens(token_chunk_size) | ||
return | ||
|
||
assert is_first_step_output is None | ||
|
||
# multi-step prefill is only supported when multi-step is | ||
# enabled with chunked prefill. Outputs from all the steps are | ||
# submitted in a single burst. | ||
assert self.scheduler_config.is_multi_step and \ | ||
self.scheduler_config.chunked_prefill_enabled | ||
assert num_outputs == seq_group_meta.state.num_steps, \ | ||
f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa | ||
# This sequence is a prompt during the first step only. | ||
seq_group.update_num_computed_tokens(token_chunk_size) | ||
|
||
now = time.time() | ||
|
||
if len(ctx.output_queue) == 0: | ||
|
@@ -1093,7 +1074,7 @@ def update_prefill_num_computed_tokens( | |
seq_group_meta = seq_group_metadata_list[i] | ||
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] | ||
|
||
seq_group = scheduled_seq_group.seq_group | ||
seq_group: SequenceGroup = scheduled_seq_group.seq_group | ||
|
||
if seq_group.is_finished(): | ||
finished_before.append(i) | ||
|
@@ -1104,14 +1085,14 @@ def update_prefill_num_computed_tokens( | |
else: | ||
output = [outputs_by_sequence_group[0][i]] | ||
|
||
if not is_async and seq_group_meta.is_prompt: | ||
# Updates for all decodes happen when we actually append the | ||
# token ids to the seq in process_outputs. | ||
update_prefill_num_computed_tokens(seq_group, seq_group_meta, | ||
len(output), | ||
is_first_step_output) | ||
elif not is_async: | ||
seq_group.update_num_computed_tokens(1) | ||
if not is_async: | ||
if self.scheduler_config.is_multi_step: | ||
# Updates happen only if the sequence is prefill | ||
self._update_num_computed_tokens_for_multi_step_prefill( | ||
seq_group, seq_group_meta, is_first_step_output) | ||
else: | ||
seq_group.update_num_computed_tokens( | ||
seq_group_meta.token_chunk_size) | ||
|
||
if outputs: | ||
for o in outputs: | ||
|
@@ -1135,16 +1116,8 @@ def update_prefill_num_computed_tokens( | |
else: | ||
self.output_processor.process_prompt_logprob(seq_group, output) | ||
if seq_group_meta.do_sample: | ||
output_token_num = self.output_processor.process_outputs( | ||
self.output_processor.process_outputs( | ||
seq_group, output, is_async) | ||
if self.speculative_config: | ||
# We -1 here because we always | ||
# (w/o speculative decoding) add the number of | ||
# computed tokens by one in the decoding phase. | ||
# Therefore, we remove that one token that | ||
# is already added. | ||
seq_group.update_num_computed_tokens(output_token_num - | ||
1) | ||
|
||
if seq_group.is_finished(): | ||
finished_now.append(i) | ||
|
@@ -1253,20 +1226,15 @@ def _advance_to_next_step( | |
if seq_group.is_finished(): | ||
continue | ||
|
||
if seq_group_metadata.is_prompt: | ||
if self.scheduler_config.is_multi_step and \ | ||
self.scheduler_config.chunked_prefill_enabled: | ||
# Prompts are scheduled in multi-step only when | ||
# chunking is enabled. These prompts turn into | ||
# decodes after the very first step. Therefore, | ||
# we skip the update to the num_computed_tokens | ||
# here. | ||
seq_group.update_num_computed_tokens(1) | ||
else: | ||
seq_group.update_num_computed_tokens( | ||
seq_group_metadata.token_chunk_size) | ||
if self.scheduler_config.is_multi_step: | ||
# Updates happen only if the sequence is prefill | ||
self._update_num_computed_tokens_for_multi_step_prefill( | ||
seq_group, seq_group_metadata, | ||
seq_group.state.num_steps == 1) | ||
else: | ||
seq_group.update_num_computed_tokens(1) | ||
seq_group.update_num_computed_tokens( | ||
seq_group_metadata.token_chunk_size) | ||
|
||
if seq_group_metadata.do_sample: | ||
assert len(sequence_group_outputs.samples) == 1, ( | ||
"Async output processor expects a single sample" | ||
|
@@ -1276,7 +1244,15 @@ def _advance_to_next_step( | |
|
||
assert len(seq_group.seqs) == 1 | ||
seq = seq_group.seqs[0] | ||
seq.append_token_id(sample.output_token, sample.logprobs) | ||
|
||
if self.scheduler_config.is_multi_step: | ||
is_prefill_append = seq.data.get_num_uncomputed_tokens( | ||
) == 0 | ||
seq.append_token_id(sample.output_token, sample.logprobs) | ||
if not is_prefill_append: | ||
seq_group.update_num_computed_tokens(1) | ||
else: | ||
seq.append_token_id(sample.output_token, sample.logprobs) | ||
|
||
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: | ||
"""Performs one decoding iteration and returns newly generated results. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: merge this call with line 58 so we only call get_num_computed_tokens once.