diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py new file mode 100644 index 0000000000000..f3ec24e7bee3e --- /dev/null +++ b/tests/core/test_num_computed_tokens_update.py @@ -0,0 +1,81 @@ +import pytest + +from tests.conftest import VllmRunner +from tests.core.utils import create_dummy_prompt +from vllm.engine.llm_engine import LLMEngine +from vllm.platforms import current_platform +from vllm.sequence import SequenceGroup + +MODEL = "JackFram/llama-160m" + + +def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup): + scheduler = engine.scheduler[0] + scheduler.add_seq_group(seq_group) + + +@pytest.mark.parametrize("num_scheduler_steps", [1, 8]) +@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +def test_num_computed_tokens_update(num_scheduler_steps: int, + enable_chunked_prefill: bool, + enforce_eager: bool): + + is_multi_step = num_scheduler_steps > 1 + is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill + + if is_multi_step_chunked_prefill and current_platform.is_rocm(): + pytest.skip("Multi-step with Chunked-Prefill does not support " + "rocm_flash_attn backend") + + # Make a vllm engine + runner = VllmRunner(model_name=MODEL, + gpu_memory_utilization=0.7, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + enable_chunked_prefill=enable_chunked_prefill, + enforce_eager=enforce_eager) + engine: LLMEngine = runner.model.llm_engine + + # In multi-step + chunked-prefill there is no separate single prompt step. + # What is scheduled will run for num_scheduler_steps always. + num_prompt_steps = num_scheduler_steps \ + if is_multi_step_chunked_prefill else 1 + + num_output_tokens_list = [4, 8, 12, 15, 16, 17] + + # Create sequence and add to engine + prompt_len = 10 + + for req_idx, num_output_tokens in enumerate(num_output_tokens_list): + seq, seq_group = create_dummy_prompt(request_id=str(req_idx), + prompt_length=prompt_len, + min_tokens=num_output_tokens, + max_tokens=num_output_tokens) + add_seq_group_to_engine(engine, seq_group) + + assert seq.data.get_num_computed_tokens() == 0 + + for _ in range(num_prompt_steps): + # prompt steps + engine.step() + + if not seq.is_finished(): + prompt_num_computed_tokens = seq.data.get_num_computed_tokens() + # Test correctness of num_computed_tokens after the prompt steps + assert prompt_num_computed_tokens == \ + prompt_len + num_prompt_steps - 1 + + decode_step_counter = 0 + while not seq.is_finished(): + # Test correctness of num_computed_tokens after the decode steps + assert seq.data.get_num_computed_tokens( + ) == prompt_num_computed_tokens + decode_step_counter + for _ in range(num_scheduler_steps): + # decode step + engine.step() + decode_step_counter += 1 + + # Test correctness of num_computed_tokens after the sequence finish. + assert seq.data.get_num_computed_tokens( + ) == prompt_len + num_output_tokens - 1 diff --git a/tests/core/utils.py b/tests/core/utils.py index 40d8f51fc186e..1e4332268c2f3 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -16,6 +16,8 @@ def create_dummy_prompt( use_beam_search: bool = False, best_of: int = 1, prompt_tokens: Optional[List[int]] = None, + min_tokens: int = 0, + max_tokens: int = 16, ) -> Tuple[Sequence, SequenceGroup]: if not block_size: block_size = prompt_length @@ -36,7 +38,9 @@ def create_dummy_prompt( arrival_time=time.time(), sampling_params=SamplingParams( use_beam_search=use_beam_search, - best_of=best_of), + best_of=best_of, + max_tokens=max_tokens, + min_tokens=min_tokens), lora_request=lora_request) return prompt, seq_group diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index fb5cd11ec033a..7456aab8b8d2a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -191,12 +191,22 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata - def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int): + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with rocm_flash_attn yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + # When using cudagraph, the num_seqs is padded to the next captured # batch sized, but num_queries tracks the actual number of requests in # the batch. For --enforce-eager mode, num_seqs == num_queries diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d6258c6413d87..62fb0aa5f859f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -965,6 +965,45 @@ def _process_sequence_group_outputs( return + def _update_num_computed_tokens_for_multi_step_prefill( + self, seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, + is_first_step_output: Optional[bool]): + """ + This function updates num_computed_tokens for prompt sequences + when Multi-Step is enabled. + + seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group_meta: Metadata of the given SequenceGroup. + is_first_step_output: Optional[bool] - + When available, is_first_step_output indicates if the appended + output token is the output of the first-step in multi-step. + A value of None indicates that outputs from all steps in + in multi-step are submitted in a single burst. + """ + + assert self.scheduler_config.is_multi_step + + if not seq_group_meta.is_prompt: + # num_computed_token updates for multi-step decodes happen after + # the tokens are appended to the sequence. + return + + do_update: bool = False + if self.scheduler_config.chunked_prefill_enabled: + # In multi-step + chunked-prefill case, the prompt sequences + # that are scheduled are fully processed in the first step. + do_update = is_first_step_output is None or is_first_step_output + else: + # Normal multi-step decoding case. In this case prompt-sequences + # are actually single-stepped. Always update in this case. + assert seq_group.state.num_steps == 1 + do_update = True + + if do_update: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) + def _process_model_outputs(self, ctx: SchedulerContext, request_id: Optional[str] = None) -> None: @@ -975,64 +1014,6 @@ def _process_model_outputs(self, request_id: If provided, then only this request is going to be processed """ - def update_prefill_num_computed_tokens( - seq_group: SequenceGroup, - seq_group_meta: SequenceGroupMetadata, num_outputs: int, - is_first_step_output: Optional[bool]) -> None: - """ - When multi-step and chunked-prefill are enabled together, the - prefill sequence scheduled for multi-step execution turn into - decodes in the first step itself. This function accounts - for that conversion. - - seq_group: SequenceGroup - A prefill seq_group - seq_group_meta: SequenceGroupMetadata - Metadata of the given - prefill seq_group - num_outputs: int - number of output tokens being processed for the - given seq_group - is_first_step_output: Optional[bool] - - If multi-step is enabled and num_outputs is 1, this value - indicates if this outputs belongs to the first step in the - multi-step. - If multi-step is enabled and num_outputs > 1, this value - must be None, as num_outputs > 1 indicates that outputs from - all the steps in multi-step are submitted in a single burst. - When multi-step is disabled, this value is always True. - """ - - assert seq_group_meta.is_prompt - - token_chunk_size = seq_group_meta.token_chunk_size - - if num_outputs == 1: - assert is_first_step_output is not None - - if seq_group_meta.state.num_steps == 1: - assert is_first_step_output is True - seq_group.update_num_computed_tokens(token_chunk_size) - return - - # multi-step prefill is only supported when multi-step is - # enabled with chunked prefill - assert self.scheduler_config.is_multi_step and \ - self.scheduler_config.chunked_prefill_enabled - if is_first_step_output is True: - # This sequence is a prompt during the first step only. - seq_group.update_num_computed_tokens(token_chunk_size) - return - - assert is_first_step_output is None - - # multi-step prefill is only supported when multi-step is - # enabled with chunked prefill. Outputs from all the steps are - # submitted in a single burst. - assert self.scheduler_config.is_multi_step and \ - self.scheduler_config.chunked_prefill_enabled - assert num_outputs == seq_group_meta.state.num_steps, \ - f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa - # This sequence is a prompt during the first step only. - seq_group.update_num_computed_tokens(token_chunk_size) - now = time.time() if len(ctx.output_queue) == 0: @@ -1093,7 +1074,7 @@ def update_prefill_num_computed_tokens( seq_group_meta = seq_group_metadata_list[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - seq_group = scheduled_seq_group.seq_group + seq_group: SequenceGroup = scheduled_seq_group.seq_group if seq_group.is_finished(): finished_before.append(i) @@ -1104,14 +1085,14 @@ def update_prefill_num_computed_tokens( else: output = [outputs_by_sequence_group[0][i]] - if not is_async and seq_group_meta.is_prompt: - # Updates for all decodes happen when we actually append the - # token ids to the seq in process_outputs. - update_prefill_num_computed_tokens(seq_group, seq_group_meta, - len(output), - is_first_step_output) - elif not is_async: - seq_group.update_num_computed_tokens(1) + if not is_async: + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_meta, is_first_step_output) + else: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) if outputs: for o in outputs: @@ -1135,16 +1116,8 @@ def update_prefill_num_computed_tokens( else: self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - output_token_num = self.output_processor.process_outputs( + self.output_processor.process_outputs( seq_group, output, is_async) - if self.speculative_config: - # We -1 here because we always - # (w/o speculative decoding) add the number of - # computed tokens by one in the decoding phase. - # Therefore, we remove that one token that - # is already added. - seq_group.update_num_computed_tokens(output_token_num - - 1) if seq_group.is_finished(): finished_now.append(i) @@ -1253,20 +1226,15 @@ def _advance_to_next_step( if seq_group.is_finished(): continue - if seq_group_metadata.is_prompt: - if self.scheduler_config.is_multi_step and \ - self.scheduler_config.chunked_prefill_enabled: - # Prompts are scheduled in multi-step only when - # chunking is enabled. These prompts turn into - # decodes after the very first step. Therefore, - # we skip the update to the num_computed_tokens - # here. - seq_group.update_num_computed_tokens(1) - else: - seq_group.update_num_computed_tokens( - seq_group_metadata.token_chunk_size) + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_metadata, + seq_group.state.num_steps == 1) else: - seq_group.update_num_computed_tokens(1) + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) + if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( "Async output processor expects a single sample" @@ -1276,7 +1244,15 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) + + if self.scheduler_config.is_multi_step: + is_prefill_append = seq.data.get_num_uncomputed_tokens( + ) == 0 + seq.append_token_id(sample.output_token, sample.logprobs) + if not is_prefill_append: + seq_group.update_num_computed_tokens(1) + else: + seq.append_token_id(sample.output_token, sample.logprobs) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 554880a3cc438..50adaf4e59188 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, List, Optional +from typing import Callable, List from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -58,14 +58,10 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool) -> Optional[int]: + is_async: bool) -> None: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. - - Return the number of new tokens generated in the sequence group. - The returned value is optional because it is only used for - speculative decoding mqa scorer. """ pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index f35b1ba9c2bdd..47de3656ca892 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, List, Optional +from typing import Callable, List from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( @@ -69,7 +69,7 @@ def _log_prompt_logprob_unsupported_warning_once(): def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool = False) -> Optional[int]: + is_async: bool = False) -> None: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. It supports greater than @@ -84,10 +84,6 @@ def process_outputs(self, tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) - - Returns: - The number of tokens appended to the sequence. This is optional - because only speculative decode uses this return value. """ # Sequences can be in RUNNING or FINISHED_ABORTED state # once scheduled, as a sequence is moved to FINSIHED_ABORTED @@ -110,7 +106,6 @@ def process_outputs(self, # was already appended, so we only need to do the rest of the # postprocessor: Detokenization + stopping logic self._process_decode_and_stop(seq, sequence_group.sampling_params) - return None else: # Standard multi-step case @@ -126,8 +121,8 @@ def process_outputs(self, ] assert valid_samples - return self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) def _process_decode_and_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: @@ -145,7 +140,7 @@ def _process_decode_and_stop(self, seq: Sequence, def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], - sampling_params: SamplingParams) -> int: + sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples] @@ -168,6 +163,7 @@ def _process_seq_outputs(self, seq: Sequence, output_token_ids = output_token_ids[:i + 1] break + is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 # Incrementally append tokens to the sequence, as if we had only one new # token. for output_token_id, output_logprob in zip(output_token_ids, @@ -177,8 +173,14 @@ def _process_seq_outputs(self, seq: Sequence, logprobs=output_logprob, ) + if is_prefill_sampled_token: + is_prefill_sampled_token = False + else: + # Update num_computed_tokens iff the sampled token is not from + # a prefill step. + seq.data.update_num_computed_tokens(1) + self._process_decode_and_stop(seq, sampling_params) if seq.is_finished(): break - return len(output_token_ids)