From 1eb9991cd19bfd388eadf49b06318fa2579a5eab Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 9 Aug 2024 16:13:04 -0700 Subject: [PATCH 01/13] [Core] Add engine option to return only deltas or final output The LLMEngine and AsyncLLMEngine APIs will currently return/stream cumulative outputs for all sequences at every step. This is more data than needed for LLM.generate or the OpenAI server APIs: - For LLM.generate and non-streaming APIs we only need the final output - For streaming APIs we only require deltas This PR adds an `output_kind` parameter to SamplingParams with an enum value of either CUMULATIVE, DELTA, or FINAL_ONLY. It will reduce the number of objects that need to be constructed at each step, and the amount of data to be serialized to return to the newly-decoupled front-end API process. --- vllm/engine/llm_engine.py | 28 ++++++-- vllm/entrypoints/llm.py | 16 ++--- vllm/entrypoints/openai/protocol.py | 7 +- vllm/entrypoints/openai/serving_chat.py | 21 +++--- vllm/entrypoints/openai/serving_completion.py | 18 +++-- vllm/outputs.py | 71 +++++++++++++------ vllm/sampling_params.py | 16 ++++- vllm/sequence.py | 7 +- 8 files changed, 121 insertions(+), 63 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1191d0c66044..4441138e6115 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -34,7 +34,7 @@ RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, @@ -1182,11 +1182,23 @@ def _process_model_outputs( output_by_sequence_group = create_output_by_sequence_group( output, num_seq_groups=len(scheduled_seq_groups)) + # seq_id to (output token count, text len) + # only for delta output seq groups + previous_output_lens: Dict[int, Tuple[int, int]] = {} + # Update the scheduled sequence groups with the model outputs. 
for scheduled_seq_group, outputs, seq_group_meta in zip( scheduled_seq_groups, output_by_sequence_group, seq_group_metadata_list): - seq_group = scheduled_seq_group.seq_group + seq_group: SequenceGroup = scheduled_seq_group.seq_group + if seq_group.sampling_params.output_kind == RequestOutputKind.DELTA: + text_buffer_length = ( + seq_group.sampling_params.output_text_buffer_length) + for seq in seq_group.seqs: + previous_output_lens[seq.seq_id] = ( + seq.get_output_len(), + seq.get_output_text_to_return_len(text_buffer_length)) + seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) if output is not None and len(output) > 0: @@ -1223,11 +1235,15 @@ def _process_model_outputs( for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) + request_output = RequestOutputFactory.create( + seq_group, previous_output_lens) + if request_output: + request_outputs.append(request_output) for seq_group in ignored_seq_groups: - request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) + if seq_group.sampling_params.output_kind == ( + RequestOutputKind.CUMULATIVE): + request_output = RequestOutputFactory.create(seq_group) + request_outputs.append(request_output) return request_outputs def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 175f418a1294..2a8ae13401c4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -16,7 +16,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_kwargs @@ -547,14 +547,12 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") - if isinstance(params, list): - params = [ - self._add_guided_processor(param, guided_options) - if isinstance(param, SamplingParams) else param - for param in params - ] - elif isinstance(params, SamplingParams): - params = self._add_guided_processor(params, guided_options) + for sp in params if isinstance(params, list) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_processor(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7da3002b283f..2191f4d311e3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -12,7 +12,8 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, + SamplingParams) from vllm.utils import random_uuid # torch is mocked during docs generation, @@ -275,6 +276,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode='before') @@ -461,6 +464,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode="before") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2167b967b14b..0f79c9fe3a82 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -213,7 +213,6 @@ async def chat_completion_stream_generator( # Send response for each token for each request.n (index) num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices @@ -301,25 +300,20 @@ async def chat_completion_stream_generator( if finish_reason_sent[i]: continue - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - if request.logprobs and request.top_logprobs is not None: - assert out_logprobs is not None, ( + assert output.logprobs is not None, ( "Did not output logprobs") logprobs = self._create_chat_logprobs( - token_ids=delta_token_ids, - top_logprobs=out_logprobs, + token_ids=output.token_ids, + top_logprobs=output.logprobs, tokenizer=tokenizer, num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + delta_text = output.text + previous_num_tokens[i] += len(output.token_ids) if request.tool_choice and type( request.tool_choice @@ -398,10 +392,11 @@ async def chat_completion_stream_generator( if (request.stream_options and request.stream_options.include_usage): + completion_tokens = previous_num_tokens[i] final_usage = UsageInfo( prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, ) final_usage_chunk = ChatCompletionStreamResponse( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index f4c91ce04684..35057967ea38 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -221,7 +221,7 @@ async def completion_stream_generator( tokenizer: PreTrainedTokenizer, 
) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices * num_prompts + previous_text_lens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts @@ -251,11 +251,9 @@ async def completion_stream_generator( has_echoed[i] = True else: # return just the delta - delta_text = output.text[len(previous_texts[i]):] - delta_token_ids = output.token_ids[ - previous_num_tokens[i]:] - out_logprobs = output.logprobs[previous_num_tokens[ - i]:] if output.logprobs else None + delta_text = output.text + delta_token_ids = output.token_ids + out_logprobs = output.logprobs if request.logprobs is not None: assert out_logprobs is not None, ( @@ -265,13 +263,13 @@ async def completion_stream_generator( top_logprobs=out_logprobs, num_output_top_logprobs=request.logprobs, tokenizer=tokenizer, - initial_text_offset=len(previous_texts[i]), + initial_text_offset=previous_text_lens[i], ) else: logprobs = None - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) finish_reason = output.finish_reason stop_reason = output.stop_reason @@ -293,7 +291,7 @@ async def completion_stream_generator( if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): prompt_tokens = len(res.prompt_token_ids) - completion_tokens = len(output.token_ids) + completion_tokens = previous_num_tokens[i] usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, diff --git a/vllm/outputs.py b/vllm/outputs.py index e091b576f597..a521e9b5b13a 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,10 +1,11 @@ import time from dataclasses import dataclass -from typing import List, Optional +from typing import Dict, List, Optional from typing import Sequence as GenericSequence -from typing import Union +from typing import Tuple, Union from vllm.lora.request import LoRARequest +from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceStatus) @@ -113,19 +114,28 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": - if seq_group.sampling_params is None: + def from_seq_group( + cls, + seq_group: SequenceGroup, + prior_output_lens: Dict[int, Tuple[int, int]], + ) -> Optional["RequestOutput"]: + sampling_params = seq_group.sampling_params + if sampling_params is None: raise ValueError( "Sampling parameters are missing for a CompletionRequest.") + finished = seq_group.is_finished() + if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY: + return None + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs else: # Get the top-n sequences. 
- n = seq_group.sampling_params.n - if seq_group.sampling_params.use_beam_search: + n = sampling_params.n + if sampling_params.use_beam_search: sorting_key = lambda seq: seq.get_beam_search_score( - seq_group.sampling_params.length_penalty) + sampling_params.length_penalty) else: sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) @@ -135,18 +145,34 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # NOTE: We need omit logprobs here explicitly because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. - include_logprobs = seq_group.sampling_params.logprobs is not None - text_buffer_length = seq_group.sampling_params.output_text_buffer_length - outputs = [ - CompletionOutput( - seqs.index(seq), - seq.get_output_text_to_return(text_buffer_length), - seq.data._output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - seq.output_logprobs if include_logprobs else None, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) for seq in top_n_seqs - ] + include_logprobs = sampling_params.logprobs is not None + text_buffer_length = sampling_params.output_text_buffer_length + + outputs = [] + for seq in top_n_seqs: + output_token_ids = seq.data._output_token_ids + output_logprobs = seq.output_logprobs if include_logprobs else None + output_text = seq.get_output_text_to_return(text_buffer_length) + + # Truncate if only deltas are requested + prior_out_token_len, prior_text_len = prior_output_lens.get( + seq.seq_id, (0, 0)) + if prior_out_token_len: + output_token_ids = output_token_ids[prior_out_token_len:] + if output_logprobs: + output_logprobs = output_logprobs[prior_out_token_len:] + #TODO get deta directly from incremental detokenization to + # avoid re-slicing + if prior_text_len: + output_text = output_text[prior_text_len:] + + outputs.append( + CompletionOutput( + seqs.index(seq), output_text, output_token_ids, + seq.get_cumulative_logprob() if include_logprobs else None, + output_logprobs, + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason)) # Every sequence in the sequence group should have the same prompt. 
prompt = seq_group.prompt @@ -154,7 +180,6 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": encoder_prompt = seq_group.encoder_prompt encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs - finished = seq_group.is_finished() finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) return cls(seq_group.request_id, @@ -230,10 +255,12 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create(seq_group): + def create(seq_group, + previous_output_lens: Dict[int, Tuple[int, int]] = {}): # noqa # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: - return RequestOutput.from_seq_group(seq_group) + return RequestOutput.from_seq_group(seq_group, + previous_output_lens) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 1c1e5f16b517..3de19af13bc0 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,6 +1,6 @@ """Sampling parameters for text generation.""" import copy -from enum import IntEnum +from enum import Enum, IntEnum from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Union @@ -32,6 +32,15 @@ class SamplingType(IntEnum): to sample from.""" +class RequestOutputKind(Enum): + # Return entire output so far in every RequestOutput + CUMULATIVE = 0 + # Return only deltas in each RequestOutput + DELTA = 1 + # Do not return intermediate RequestOuputs + FINAL_ONLY = 2 + + class SamplingParams: """Sampling parameters for text generation. @@ -139,6 +148,7 @@ def __init__( spaces_between_special_tokens: bool = True, logits_processors: Optional[List[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -180,6 +190,7 @@ def __init__( self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output self.truncate_prompt_tokens = truncate_prompt_tokens + self.output_kind = output_kind # Number of characters to hold back for stop string evaluation # until sequence is finished. if self.stop and not include_stop_str_in_output: @@ -256,6 +267,9 @@ def _verify_args(self) -> None: raise ValueError( "stop strings are only supported when detokenize is True. " "Set detokenize=True to use stop.") + if self.best_of != self.n and self.output_kind == ( + RequestOutputKind.DELTA): + raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_beam_search(self) -> None: if self.best_of == 1: diff --git a/vllm/sequence.py b/vllm/sequence.py index 7349bc6f13bd..b6ee1dd829d3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -386,12 +386,17 @@ def prompt_adapter_id(self) -> int: return self.prompt_adapter_request.prompt_adapter_id \ if self.prompt_adapter_request else 0 - def get_output_text_to_return(self, buffer_length: int): + def get_output_text_to_return(self, buffer_length: int) -> str: # We return the full output text if the sequence is finished. 
truncate = buffer_length and not self.is_finished() return self.output_text[:-buffer_length] if truncate else ( self.output_text) + def get_output_text_to_return_len(self, buffer_length: int) -> int: + if not buffer_length or self.is_finished(): + return len(self.output_text) + return max(0, len(self.output_text) - buffer_length) + def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size From 9bc3fdd161e152119c908d9319bb87a7716bccc3 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 9 Aug 2024 18:42:48 -0700 Subject: [PATCH 02/13] Fixes --- vllm/engine/llm_engine.py | 7 ++++--- vllm/outputs.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4441138e6115..ec2f01344048 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1191,9 +1191,10 @@ def _process_model_outputs( scheduled_seq_groups, output_by_sequence_group, seq_group_metadata_list): seq_group: SequenceGroup = scheduled_seq_group.seq_group - if seq_group.sampling_params.output_kind == RequestOutputKind.DELTA: - text_buffer_length = ( - seq_group.sampling_params.output_text_buffer_length) + params = seq_group.sampling_params + if params is not None and (params.output_kind + == RequestOutputKind.DELTA): + text_buffer_length = params.output_text_buffer_length for seq in seq_group.seqs: previous_output_lens[seq.seq_id] = ( seq.get_output_len(), diff --git a/vllm/outputs.py b/vllm/outputs.py index a521e9b5b13a..d57b7c88c69d 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -124,7 +124,8 @@ def from_seq_group( raise ValueError( "Sampling parameters are missing for a CompletionRequest.") finished = seq_group.is_finished() - if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY: + if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( + not finished): return None seqs = seq_group.get_seqs() From ef2e59fddd434457f3f9598398cd721251a24979 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 10 Aug 2024 08:44:57 -0700 Subject: [PATCH 03/13] Fix ignored sequence case --- vllm/engine/llm_engine.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ec2f01344048..c1a513cb9fa9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1241,9 +1241,19 @@ def _process_model_outputs( if request_output: request_outputs.append(request_output) for seq_group in ignored_seq_groups: - if seq_group.sampling_params.output_kind == ( - RequestOutputKind.CUMULATIVE): - request_output = RequestOutputFactory.create(seq_group) + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA): + if not seq_group.is_finished(): + continue + # Ignored seq groups have no delta, but we must still return + # an "empty" RequestOutput when finished + for seq in seq_group.seqs: + previous_output_lens[seq.seq_id] = (seq.get_output_len(), + seq.output_text) + request_output = RequestOutputFactory.create( + seq_group, previous_output_lens) + if request_output: request_outputs.append(request_output) return request_outputs From dc1f3f236f72e7fca1b47bda8d09aa226a21b417 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 12 Aug 2024 18:24:17 -0700 Subject: [PATCH 04/13] Also exclude prompt details in subsequent outputs in delta mode --- vllm/engine/llm_engine.py | 20 ++++++++++++++++---- vllm/outputs.py | 29 ++++++++++++++++++++--------- 2 
files changed, 36 insertions(+), 13 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c1a513cb9fa9..97161aa15ddd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1182,9 +1182,12 @@ def _process_model_outputs( output_by_sequence_group = create_output_by_sequence_group( output, num_seq_groups=len(scheduled_seq_groups)) - # seq_id to (output token count, text len) + # seq_id to (output token count, text len), # only for delta output seq groups previous_output_lens: Dict[int, Tuple[int, int]] = {} + # Seq groups whose outputs should not have prompt details included, + # only applies to delta output seq groups + exclude_prompt_seq_group_ids = set() # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs, seq_group_meta in zip( @@ -1196,8 +1199,13 @@ def _process_model_outputs( == RequestOutputKind.DELTA): text_buffer_length = params.output_text_buffer_length for seq in seq_group.seqs: + output_len = seq.get_output_len() + if output_len: + # Exclude the prompt if the seq group already has + # completion tokens + exclude_prompt_seq_group_ids.add(seq_group.request_id) previous_output_lens[seq.seq_id] = ( - seq.get_output_len(), + output_len, seq.get_output_text_to_return_len(text_buffer_length)) seq_group.update_num_computed_tokens( @@ -1236,23 +1244,27 @@ def _process_model_outputs( for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) + include_prompt = seq_group.request_id not in ( + exclude_prompt_seq_group_ids) request_output = RequestOutputFactory.create( - seq_group, previous_output_lens) + seq_group, previous_output_lens, include_prompt) if request_output: request_outputs.append(request_output) for seq_group in ignored_seq_groups: params = seq_group.sampling_params + include_prompt = True if params is not None and params.output_kind == ( RequestOutputKind.DELTA): if not seq_group.is_finished(): continue # Ignored seq groups have no delta, but we must still return # an "empty" RequestOutput when finished + include_prompt = False for seq in seq_group.seqs: previous_output_lens[seq.seq_id] = (seq.get_output_len(), seq.output_text) request_output = RequestOutputFactory.create( - seq_group, previous_output_lens) + seq_group, previous_output_lens, include_prompt) if request_output: request_outputs.append(request_output) return request_outputs diff --git a/vllm/outputs.py b/vllm/outputs.py index d57b7c88c69d..78143ed43329 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -93,7 +93,7 @@ def __init__( self, request_id: str, prompt: Optional[str], - prompt_token_ids: List[int], + prompt_token_ids: Optional[List[int]], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -118,6 +118,7 @@ def from_seq_group( cls, seq_group: SequenceGroup, prior_output_lens: Dict[int, Tuple[int, int]], + include_prompt: bool = True, ) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: @@ -176,11 +177,18 @@ def from_seq_group( seq.stop_reason)) # Every sequence in the sequence group should have the same prompt. 
- prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs + if include_prompt: + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + else: + prompt = None + prompt_token_ids = None + encoder_prompt = None + encoder_prompt_token_ids = None + prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) return cls(seq_group.request_id, @@ -256,12 +264,15 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create(seq_group, - previous_output_lens: Dict[int, Tuple[int, int]] = {}): # noqa + def create( + seq_group, + previous_output_lens: Dict[int, Tuple[int, int]] = {}, # noqa + include_prompt: bool = True): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: return RequestOutput.from_seq_group(seq_group, - previous_output_lens) + previous_output_lens, + include_prompt) From 9d35a00cd7d395eadc2577b2faf932d26139073c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 13 Aug 2024 11:27:46 -0700 Subject: [PATCH 05/13] Fix prompt token counts in streaming cases --- vllm/entrypoints/openai/serving_chat.py | 40 +++++++++---------- vllm/entrypoints/openai/serving_completion.py | 5 ++- 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0f79c9fe3a82..cce7535e6d1b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -216,8 +216,13 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + try: async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). @@ -239,11 +244,11 @@ async def chat_completion_stream_generator( model=model_name) if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) - usage = UsageInfo(prompt_tokens=prompt_tokens, - completion_tokens=0, - total_tokens=prompt_tokens) + if request.stream_options.continuous_usage_stats: + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) chunk.usage = usage else: chunk.usage = None @@ -279,12 +284,10 @@ async def chat_completion_stream_generator( request.stream_options.include_usage): if (request.stream_options. 
continuous_usage_stats): - prompt_tokens = len( - res.prompt_token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=prompt_tokens) + total_tokens=num_prompt_tokens) chunk.usage = usage else: chunk.usage = None @@ -342,13 +345,12 @@ async def chat_completion_stream_generator( model=model_name) if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -359,7 +361,6 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" else: # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, @@ -374,13 +375,12 @@ async def chat_completion_stream_generator( model=model_name) if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -394,9 +394,9 @@ async def chat_completion_stream_generator( and request.stream_options.include_usage): completion_tokens = previous_num_tokens[i] final_usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, ) final_usage_chunk = ChatCompletionStreamResponse( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 35057967ea38..ef5044789f14 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -224,9 +224,12 @@ async def completion_stream_generator( previous_text_lens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts try: async for prompt_idx, res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) for output in res.outputs: i = output.index + prompt_idx * num_choices @@ -290,7 +293,7 @@ async def completion_stream_generator( and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(res.prompt_token_ids) + prompt_tokens = num_prompt_tokens[prompt_idx] completion_tokens = previous_num_tokens[i] usage = UsageInfo( prompt_tokens=prompt_tokens, From b7ff44e82cec206d096e327c62eb25d160f7edd7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 14 Aug 2024 15:10:40 -0700 Subject: [PATCH 06/13] Simplification suggestion from @joerunde --- vllm/engine/llm_engine.py | 18 +++--------------- vllm/outputs.py | 6 +++--- 2 files changed, 6 insertions(+), 
18 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 97161aa15ddd..54dfe929a008 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1185,9 +1185,6 @@ def _process_model_outputs( # seq_id to (output token count, text len), # only for delta output seq groups previous_output_lens: Dict[int, Tuple[int, int]] = {} - # Seq groups whose outputs should not have prompt details included, - # only applies to delta output seq groups - exclude_prompt_seq_group_ids = set() # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs, seq_group_meta in zip( @@ -1199,13 +1196,8 @@ def _process_model_outputs( == RequestOutputKind.DELTA): text_buffer_length = params.output_text_buffer_length for seq in seq_group.seqs: - output_len = seq.get_output_len() - if output_len: - # Exclude the prompt if the seq group already has - # completion tokens - exclude_prompt_seq_group_ids.add(seq_group.request_id) previous_output_lens[seq.seq_id] = ( - output_len, + seq.get_output_len(), seq.get_output_text_to_return_len(text_buffer_length)) seq_group.update_num_computed_tokens( @@ -1244,27 +1236,23 @@ def _process_model_outputs( for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - include_prompt = seq_group.request_id not in ( - exclude_prompt_seq_group_ids) request_output = RequestOutputFactory.create( - seq_group, previous_output_lens, include_prompt) + seq_group, previous_output_lens) if request_output: request_outputs.append(request_output) for seq_group in ignored_seq_groups: params = seq_group.sampling_params - include_prompt = True if params is not None and params.output_kind == ( RequestOutputKind.DELTA): if not seq_group.is_finished(): continue # Ignored seq groups have no delta, but we must still return # an "empty" RequestOutput when finished - include_prompt = False for seq in seq_group.seqs: previous_output_lens[seq.seq_id] = (seq.get_output_len(), seq.output_text) request_output = RequestOutputFactory.create( - seq_group, previous_output_lens, include_prompt) + seq_group, previous_output_lens) if request_output: request_outputs.append(request_output) return request_outputs diff --git a/vllm/outputs.py b/vllm/outputs.py index 78143ed43329..0f4eae8d691c 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -118,7 +118,6 @@ def from_seq_group( cls, seq_group: SequenceGroup, prior_output_lens: Dict[int, Tuple[int, int]], - include_prompt: bool = True, ) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: @@ -151,6 +150,7 @@ def from_seq_group( text_buffer_length = sampling_params.output_text_buffer_length outputs = [] + include_prompt = True for seq in top_n_seqs: output_token_ids = seq.data._output_token_ids output_logprobs = seq.output_logprobs if include_logprobs else None @@ -160,6 +160,7 @@ def from_seq_group( prior_out_token_len, prior_text_len = prior_output_lens.get( seq.seq_id, (0, 0)) if prior_out_token_len: + include_prompt = False output_token_ids = output_token_ids[prior_out_token_len:] if output_logprobs: output_logprobs = output_logprobs[prior_out_token_len:] @@ -274,5 +275,4 @@ def create( return EmbeddingRequestOutput.from_seq_group(seq_group) else: return RequestOutput.from_seq_group(seq_group, - previous_output_lens, - include_prompt) + previous_output_lens) From 34df9bd23ac4cd01c91042e734ca903b063374f1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 14 Aug 2024 
21:07:43 -0700 Subject: [PATCH 07/13] Make tests more robust --- tests/entrypoints/openai/test_chat.py | 8 +++++--- tests/utils.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index c96d602b6343..0b5e876ce609 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -6,6 +6,7 @@ import jsonschema import openai # use the official client for correctness check import pytest +import pytest_asyncio import torch from openai import BadRequestError @@ -46,9 +47,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811 yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as cl: + yield cl @pytest.mark.asyncio diff --git a/tests/utils.py b/tests/utils.py index 697bf7d93c36..1e1863338737 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -133,6 +133,7 @@ def get_async_client(self): return openai.AsyncOpenAI( base_url=self.url_for("v1"), api_key=self.DUMMY_API_KEY, + max_retries=0, ) From 45fd069c3d9f3eba98e0b2b08ca047d831170848 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 27 Aug 2024 11:37:18 -0700 Subject: [PATCH 08/13] Post-merge wip --- vllm/engine/llm_engine.py | 70 +++++++++---------- vllm/entrypoints/llm.py | 1 + vllm/entrypoints/openai/serving_chat.py | 1 + vllm/entrypoints/openai/serving_completion.py | 7 +- vllm/outputs.py | 2 +- 5 files changed, 41 insertions(+), 40 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1a39587ffbc0..603cff84c39d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1253,7 +1253,6 @@ def _process_model_outputs(self, # seq.get_output_len(), # seq.get_output_text_to_return_len(text_buffer_length)) - (outputs, seq_group_metadata_list, scheduler_outputs) = self.output_queue.popleft() @@ -1263,35 +1262,34 @@ def _process_model_outputs(self, # Organize outputs by [step][sequence group] instead of # [sequence group][step]. 
- if len(outputs) > 1: + + if len(outputs) == 1: + outputs_by_sequence_group = outputs[0] + else: outputs_by_sequence_group = create_output_by_sequence_group( outputs, num_seq_groups=len(seq_group_metadata_list)) - else: - outputs_by_sequence_group = outputs - - finished_before: List[int] = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + finished_before: Set[int] = set() + for i, (scheduled_seq_group, seq_group_meta, output) in enumerate( + zip(seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups, + outputs_by_sequence_group)): seq_group = scheduled_seq_group.seq_group if seq_group.is_finished(): - finished_before.append(i) + finished_before.add(i) continue - if len(outputs) > 1: - output = outputs_by_sequence_group[i] - else: - output = [outputs_by_sequence_group[0][i]] + if not isinstance(output, GenericSequence): + output = [output] if not is_async: seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - if outputs: + if seq_group.metrics is not None: for o in outputs: - if (isinstance(o, SamplerOutput) - and seq_group.metrics is not None): + if isinstance(o, SamplerOutput): if seq_group.metrics.model_forward_time is not None: seq_group.metrics.model_forward_time += ( o.model_forward_time) @@ -1319,18 +1317,18 @@ def _process_model_outputs(self, scheduler.free_finished_seq_groups() # Create the outputs. - for i, _ in enumerate(seq_group_metadata_list): - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group if i in finished_before: continue # Avoids double processing - seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, previous_output_lens) - if request_output: - self.request_outputs.append(request_output) + request_output = RequestOutputFactory.create( + seq_group, previous_output_lens) + if request_output: + self.request_outputs.append(request_output) for seq_group in scheduler_outputs.ignored_seq_groups: params = seq_group.sampling_params if params is not None and params.output_kind == ( @@ -1354,10 +1352,9 @@ def _process_model_outputs(self, # Tracing self.do_tracing(scheduler_outputs) - return None - + @staticmethod def _advance_to_next_step( - self, output: List[SamplerOutput], + output: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: """Given model output from a single run, append the tokens to the @@ -1385,7 +1382,6 @@ def _advance_to_next_step( seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. 
@@ -1528,15 +1524,15 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.output_queue.append( (output, seq_group_metadata_list, scheduler_outputs)) - if output and allow_async_output_proc: - assert len(output) == 1, ("Multi step decoding does not work " - "with async output processing.") - - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if allow_async_output_proc: + if output: + assert len(output) == 1, ("Multi step decoding does not work " + "with async output processing.") - if not allow_async_output_proc: + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + else: self._process_model_outputs(is_async=False) # Log stats. @@ -1632,7 +1628,7 @@ def remove_logger(self, logger_name: str) -> None: def do_log_stats(self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> None: + finished_before: Optional[Set[int]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: stats = self._get_stats(scheduler_outputs, model_output, @@ -1643,7 +1639,7 @@ def do_log_stats(self, def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs], model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> Stats: + finished_before: Optional[Set[int]] = None) -> Stats: """Get Stats to be Logged to Prometheus. Args: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 74914c774530..02beaad4ca5a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -693,6 +693,7 @@ def _run_engine( if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput + assert output.prompt_token_ids is not None total_in_toks += len(output.prompt_token_ids) in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5f5264152627..00feb1e554e7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -486,6 +486,7 @@ async def chat_completion_full_generator( full_message = last_msg_content + choice.message.content choice.message.content = full_message + assert final_res.prompt_token_ids is not None num_prompt_tokens = len(final_res.prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 00d4b9f42399..42142efb5f23 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -249,6 +249,7 @@ async def completion_stream_generator( assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_token_ids is not None assert prompt_text is not None # only return the prompt delta_text = prompt_text @@ -257,6 +258,7 @@ async def completion_stream_generator( has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): + assert prompt_token_ids is not None assert prompt_text is not None assert prompt_logprobs is not None # echo the prompt and first token @@ -359,6 +361,7 @@ def request_output_to_completion_response( for final_res in final_res_batch: prompt_token_ids = final_res.prompt_token_ids + assert prompt_token_ids is not None 
prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt @@ -414,9 +417,9 @@ def request_output_to_completion_response( ) choices.append(choice_data) + num_generated_tokens += len(output.token_ids) + num_prompt_tokens += len(prompt_token_ids) - num_generated_tokens += sum( - len(output.token_ids) for output in final_res.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/vllm/outputs.py b/vllm/outputs.py index 0f4eae8d691c..dfd403975ea6 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -268,7 +268,7 @@ class RequestOutputFactory: def create( seq_group, previous_output_lens: Dict[int, Tuple[int, int]] = {}, # noqa - include_prompt: bool = True): + ): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: From 28433659dde3bdb5d7b29ae42b650eedfd19246b Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 10 Sep 2024 09:54:19 -0700 Subject: [PATCH 09/13] Fix delta computation, remove unrelated changes --- vllm/engine/llm_engine.py | 91 +++++++++++++++------------------------ vllm/outputs.py | 44 ++++++++----------- vllm/sequence.py | 6 +++ 3 files changed, 59 insertions(+), 82 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cae15c119829..a382b3401b95 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -96,7 +96,7 @@ class OutputData(NamedTuple): scheduler_outputs: SchedulerOutputs is_async: bool is_last_step: bool - skip: Set[int] + skip: List[int] class SchedulerContext: @@ -119,7 +119,7 @@ def append_output(self, outputs: List[SamplerOutput], scheduler_outputs=scheduler_outputs, is_async=is_async, is_last_step=is_last_step, - skip=set())) + skip=[])) class LLMEngine: @@ -1279,9 +1279,6 @@ def _process_model_outputs(self, if len(ctx.output_queue) == 0: return None - # seq_id to (output token count, text len), - # only for delta output seq groups - previous_output_lens: Dict[int, Tuple[int, int]] = {} # Get pending async postprocessor if request_id: @@ -1299,19 +1296,19 @@ def _process_model_outputs(self, # Organize outputs by [step][sequence group] instead of # [sequence group][step]. 
- if len(outputs) == 1: - outputs_by_sequence_group = outputs[0] - else: + if len(outputs) > 1: outputs_by_sequence_group = create_output_by_sequence_group( outputs, num_seq_groups=len(seq_group_metadata_list)) + else: + outputs_by_sequence_group = outputs # Determine the requests we need to operate on if request_id: - indices = None + indices = [] for i, seq_group_meta in enumerate(seq_group_metadata_list): if seq_group_meta.request_id == request_id: assert i not in skip # Cannot be called twice - indices = (i, ) + indices.append(i) break # If the request_id was not found, then it means that @@ -1321,45 +1318,35 @@ def _process_model_outputs(self, return else: indices = range(len(seq_group_metadata_list)) # type: ignore - assert isinstance(indices, Iterable) - finished_before: Set[int] = set() - finished_now: Set[int] = set() + finished_before: List[int] = [] + finished_now: List[int] = [] for i in indices: if i in skip: continue seq_group_meta = seq_group_metadata_list[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - output = outputs_by_sequence_group[i] seq_group = scheduled_seq_group.seq_group if seq_group.is_finished(): - finished_before.add(i) + finished_before.append(i) continue - if not isinstance(output, GenericSequence): - output = [output] - - #TODO I don't think this is the correct place to obtain the - # previous output lens anymore, given async output processing - params = seq_group.sampling_params - if params is not None and (params.output_kind - == RequestOutputKind.DELTA): - text_buffer_length = params.output_text_buffer_length - for seq in seq_group.seqs: - previous_output_lens[seq.seq_id] = ( - seq.get_output_len(), - seq.get_output_text_to_return_len(text_buffer_length)) + if len(outputs) > 1: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] if not is_async: seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - if seq_group.metrics is not None: + if outputs: for o in outputs: - if isinstance(o, SamplerOutput): + if (isinstance(o, SamplerOutput) + and seq_group.metrics is not None): if seq_group.metrics.model_forward_time is not None: seq_group.metrics.model_forward_time += ( o.model_forward_time) @@ -1382,7 +1369,7 @@ def _process_model_outputs(self, seq_group, output, is_async) if seq_group.is_finished(): - finished_now.add(i) + finished_now.append(i) # Generate outputs for the requests that finished this iteration for i in finished_now: @@ -1390,8 +1377,7 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, previous_output_lens) + request_output = RequestOutputFactory.create(seq_group) if request_output: ctx.request_outputs.append(request_output) @@ -1399,7 +1385,7 @@ def _process_model_outputs(self, # and invoke the request output callback (if there was final output) if request_id: assert len(indices) == 1 - skip.add(indices[0]) + skip.append(indices[0]) if (finished_now and self.process_request_outputs_callback is not None): @@ -1430,24 +1416,17 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create( - seq_group, previous_output_lens) + request_output = RequestOutputFactory.create(seq_group) if request_output: ctx.request_outputs.append(request_output) for seq_group in scheduler_outputs.ignored_seq_groups: params = seq_group.sampling_params if 
params is not None and params.output_kind == ( - RequestOutputKind.DELTA): - if not seq_group.is_finished(): - continue - # Ignored seq groups have no delta, but we must still return - # an "empty" RequestOutput when finished - for seq in seq_group.seqs: - previous_output_lens[seq.seq_id] = (seq.get_output_len(), - seq.output_text) - request_output = RequestOutputFactory.create( - seq_group, previous_output_lens) + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + + request_output = RequestOutputFactory.create(seq_group) if request_output: ctx.request_outputs.append(request_output) @@ -1467,9 +1446,10 @@ def _process_model_outputs(self, # Tracing self.do_tracing(scheduler_outputs) - @staticmethod + return None + def _advance_to_next_step( - output: List[SamplerOutput], + self, output: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: """Given model output from a single run, append the tokens to the @@ -1477,7 +1457,7 @@ def _advance_to_next_step( required if the worker is to perform async forward pass to next step. """ for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ - zip(seq_group_metadata_list, output, scheduled_seq_groups): + zip(seq_group_metadata_list, output, scheduled_seq_groups): seq_group = scheduled_seq_group.seq_group if seq_group.is_finished(): @@ -1655,17 +1635,16 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: is_async=allow_async_output_proc, is_last_step=True) - if allow_async_output_proc: - if outputs: - assert len(outputs) == 1, ( - "Async postprocessor expects only a single output set") + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( + "Async postprocessor expects only a single output set") self._advance_to_next_step( outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path - else: + if not allow_async_output_proc: self._process_model_outputs(ctx=ctx) # Log stats. @@ -1769,7 +1748,7 @@ def remove_logger(self, logger_name: str) -> None: def do_log_stats(self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[Set[int]] = None) -> None: + finished_before: Optional[List[int]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: stats = self._get_stats(scheduler_outputs, model_output, @@ -1780,7 +1759,7 @@ def do_log_stats(self, def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs], model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[Set[int]] = None) -> Stats: + finished_before: Optional[List[int]] = None) -> Stats: """Get Stats to be Logged to Prometheus. 
Args: diff --git a/vllm/outputs.py b/vllm/outputs.py index dfd403975ea6..f1b3a138f6de 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,8 +1,8 @@ import time from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import List, Optional from typing import Sequence as GenericSequence -from typing import Tuple, Union +from typing import Union from vllm.lora.request import LoRARequest from vllm.sampling_params import RequestOutputKind @@ -114,11 +114,8 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group( - cls, - seq_group: SequenceGroup, - prior_output_lens: Dict[int, Tuple[int, int]], - ) -> Optional["RequestOutput"]: + def from_seq_group(cls, + seq_group: SequenceGroup) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: raise ValueError( @@ -148,26 +145,25 @@ def from_seq_group( # logprobs are not requested. include_logprobs = sampling_params.logprobs is not None text_buffer_length = sampling_params.output_text_buffer_length + deltas = sampling_params.output_kind == RequestOutputKind.DELTA outputs = [] include_prompt = True for seq in top_n_seqs: - output_token_ids = seq.data._output_token_ids - output_logprobs = seq.output_logprobs if include_logprobs else None output_text = seq.get_output_text_to_return(text_buffer_length) + output_logprobs = seq.output_logprobs if include_logprobs else None - # Truncate if only deltas are requested - prior_out_token_len, prior_text_len = prior_output_lens.get( - seq.seq_id, (0, 0)) - if prior_out_token_len: - include_prompt = False - output_token_ids = output_token_ids[prior_out_token_len:] + if deltas: + output_tokens_ids = seq.data.last_appended_tokens + seq.data.last_appended_tokens = [] + last_text_offset = seq.data.last_output_text_offset + new_text_len = len(output_text) + output_text = output_text[last_text_offset:] + seq.data.last_output_text_offset = new_text_len if output_logprobs: - output_logprobs = output_logprobs[prior_out_token_len:] - #TODO get deta directly from incremental detokenization to - # avoid re-slicing - if prior_text_len: - output_text = output_text[prior_text_len:] + output_logprobs = output_logprobs[-len(output_tokens_ids):] + else: + output_token_ids = seq.data._output_token_ids outputs.append( CompletionOutput( @@ -265,14 +261,10 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create( - seq_group, - previous_output_lens: Dict[int, Tuple[int, int]] = {}, # noqa - ): + def create(seq_group): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: - return RequestOutput.from_seq_group(seq_group, - previous_output_lens) + return RequestOutput.from_seq_group(seq_group) diff --git a/vllm/sequence.py b/vllm/sequence.py index 0e0302a07180..1c2560f7d04e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -165,6 +165,10 @@ class SequenceData(msgspec.Struct, # is called. 
_new_appended_tokens: List[int] = msgspec.field(default_factory=list) + # TODO: Add doc + last_appended_tokens: List[int] = [] + last_output_text_offset: int = 0 + def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" @@ -220,6 +224,8 @@ def output_token_ids_array(self) -> array: return self._output_token_ids def append_token_id(self, token_id: int, logprob: float) -> None: + self.last_appended_tokens.append(token_id) + self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) From a045dff20adba173831873263b587e1f5fe7515a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 10 Sep 2024 14:48:01 -0700 Subject: [PATCH 10/13] Address Alex's comments, fix include_prompt logic Also avoid appending delta token ids to sequences in cases they aren't needed. --- vllm/engine/llm_engine.py | 5 ++++- vllm/engine/output_processor/multi_step.py | 5 ++++- vllm/engine/output_processor/single_step.py | 6 ++++-- vllm/entrypoints/openai/serving_chat.py | 3 ++- vllm/outputs.py | 14 +++++++++--- vllm/sequence.py | 24 +++++++++++---------- 6 files changed, 38 insertions(+), 19 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a382b3401b95..cf1dc98e234b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1472,10 +1472,13 @@ def _advance_to_next_step( " (i.e sampling_params.n == 1 and no " "sampling_params.best_of > 1)") sample = sequence_group_outputs.samples[0] + track_delta = seq_group.sampling_params.output_kind == ( + RequestOutputKind.DELTA) assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + track_delta) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b..ad7b6e42220f 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -8,7 +8,7 @@ single_step_process_prompt_logprob) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -161,6 +161,8 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples = valid_samples[:i + 1] break + track_delta = sampling_params.output_kind == RequestOutputKind.DELTA + # Incrementally append tokens to the sequence, as if we had only one new # token. 
for output_token_id, output_logprob in zip(output_token_ids, @@ -168,6 +170,7 @@ def _process_seq_outputs(self, seq: Sequence, seq.append_token_id( token_id=output_token_id, logprobs=output_logprob, + track_delta=track_delta, ) self._process_decode_and_stop(seq, sampling_params) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index e288aa0c4aaf..2aad6d0e63f7 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -6,7 +6,7 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -119,7 +119,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # only have one sequence seq = seq_group.seqs[0] if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs) + delta = sampling_params.output_kind == RequestOutputKind.DELTA + seq.append_token_id(sample.output_token, sample.logprobs, + delta) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 91a29ff5c820..9ac4fe41a3b3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -276,13 +276,14 @@ async def chat_completion_stream_generator( else: tool_choice_function_name = None + # Determine whether tools are in use with "auto" tool choice tool_choice_auto = ( not tool_choice_function_name and self._should_stream_with_auto_tool_parsing(request)) all_previous_token_ids: Optional[List[List[int]]] - # These are only used in "auto" tool choice case if tool_choice_auto: + # These are only required in "auto" tool choice case previous_texts = [""] * num_choices all_previous_token_ids = [[]] * num_choices else: diff --git a/vllm/outputs.py b/vllm/outputs.py index 987f618997f7..3c256d94cfb0 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -145,7 +145,6 @@ def from_seq_group(cls, # logprobs are not requested. 
include_logprobs = sampling_params.logprobs is not None text_buffer_length = sampling_params.output_text_buffer_length - deltas = sampling_params.output_kind == RequestOutputKind.DELTA outputs = [] include_prompt = True @@ -153,18 +152,27 @@ def from_seq_group(cls, output_text = seq.get_output_text_to_return(text_buffer_length) output_logprobs = seq.output_logprobs if include_logprobs else None + all_output_token_ids = seq.data._output_token_ids output_token_ids: GenericSequence[int] - if deltas: + if sampling_params.output_kind == RequestOutputKind.DELTA: + # Get and reset token id delta output_token_ids = seq.data.last_appended_tokens seq.data.last_appended_tokens = [] + # Use last offset to make text delta, update last offset last_text_offset = seq.data.last_output_text_offset new_text_len = len(output_text) output_text = output_text[last_text_offset:] seq.data.last_output_text_offset = new_text_len + # Slice logprobs delta if applicable if output_logprobs: output_logprobs = output_logprobs[-len(output_token_ids):] + # Don't include prompt if this is after the first output + # containing decode token ids + if include_prompt and len(all_output_token_ids) > len( + output_token_ids): + include_prompt = False else: - output_token_ids = seq.data._output_token_ids + output_token_ids = all_output_token_ids outputs.append( CompletionOutput( diff --git a/vllm/sequence.py b/vllm/sequence.py index 1c2560f7d04e..a392ee20baa1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -165,7 +165,7 @@ class SequenceData(msgspec.Struct, # is called. _new_appended_tokens: List[int] = msgspec.field(default_factory=list) - # TODO: Add doc + # These are used to keep track of delta outputs last_appended_tokens: List[int] = [] last_output_text_offset: int = 0 @@ -223,8 +223,12 @@ def output_token_ids_array(self) -> array: assert isinstance(self._output_token_ids, array) return self._output_token_ids - def append_token_id(self, token_id: int, logprob: float) -> None: - self.last_appended_tokens.append(token_id) + def append_token_id(self, + token_id: int, + logprob: float, + track_delta: bool = False) -> None: + if track_delta: + self.last_appended_tokens.append(token_id) self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) @@ -463,11 +467,6 @@ def get_output_text_to_return(self, buffer_length: int) -> str: return self.output_text[:-buffer_length] if truncate else ( self.output_text) - def get_output_text_to_return_len(self, buffer_length: int) -> int: - if not buffer_length or self.is_finished(): - return len(self.output_text) - return max(0, len(self.output_text) - buffer_length) - def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -485,11 +484,14 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id(self, token_id: int, logprobs: Dict[int, - Logprob]) -> None: + def append_token_id(self, + token_id: int, + logprobs: Dict[int, Logprob], + track_delta: bool = False) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) + self.data.append_token_id(token_id, logprobs[token_id].logprob, + track_delta) def get_len(self) -> int: return self.data.get_len() From e7a2b55c983baf815cc7508c627fe00371bf1937 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 11 Sep 2024 16:10:42 -0700 Subject: [PATCH 11/13] Add tests --- 
.buildkite/test-pipeline.yaml | 1 + tests/async_engine/test_async_llm_engine.py | 160 ++++++++++++++++++-- 2 files changed, 147 insertions(+), 14 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 5b8d6a8739f1..766d4a66d30e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -50,6 +50,7 @@ steps: - tests/worker commands: - pytest -v -s async_engine # Async Engine + - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 03494581431d..1afd156ec01b 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -1,8 +1,10 @@ import asyncio import os +import uuid from asyncio import CancelledError +from copy import copy from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import pytest import pytest_asyncio @@ -12,6 +14,7 @@ from vllm.config import ParallelConfig from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from vllm.outputs import RequestOutput as RealRequestOutput +from vllm.sampling_params import RequestOutputKind from ..conftest import cleanup from ..utils import wait_for_gpu_memory_to_clear @@ -130,8 +133,17 @@ def start_engine(): timeout_s=60, ) + num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1")) + print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}") + return AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True)) + AsyncEngineArgs(model="facebook/opt-125m", + enforce_eager=True, + num_scheduler_steps=num_scheduler_steps)) + + +def uid() -> str: + return str(uuid.uuid4()) @pytest_asyncio.fixture(scope="module") @@ -156,57 +168,177 @@ def should_do_global_cleanup_after_test(request) -> bool: @pytest.mark.asyncio(scope="module") async def test_asyncio_run(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + async def run(prompt: str): sampling_params = SamplingParams( temperature=0, max_tokens=32, + min_tokens=32, ) + output_count = 0 + final_output = None async for output in async_engine.generate(prompt, sampling_params, - request_id=prompt): + request_id=uid()): + output_count += 1 final_output = output - return final_output + return final_output, output_count results = await asyncio.gather( run("test0"), - run("test1"), + run("test0"), ) assert len(results) == 2 + first, second = results + + # remove nondeterministic fields for comparison + first[0].metrics = None + second[0].metrics = None + first[0].request_id = None + second[0].request_id = None + + assert str(first) == str(second) + + output_count = results[0][1] + if num_scheduler_steps == 1: + assert output_count == 32 + else: + assert 1 < output_count < 32 + + +@pytest.mark.asyncio(scope="module") +async def test_output_kinds(async_engine): + """Test that output_kind works as expected and that + results are equivalent across different kinds.""" + + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + + sampling_params = SamplingParams( + temperature=0, + max_tokens=32, + min_tokens=32, + ) + + async def run(prompt: str, kind: RequestOutputKind): + params = copy(sampling_params) + 
params.output_kind = kind + + output_count = 0 + final_output = None + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + output_count += 1 + final_output = output + + assert final_output is not None + return (final_output.prompt_token_ids, + final_output.outputs[0].token_ids, + final_output.outputs[0].text, output_count) + + async def run_deltas(prompt: str): + params = copy(sampling_params) + params.output_kind = RequestOutputKind.DELTA + + prompt_tokens = None + output_tokens: List[int] = [] + output_text = "" + output_count = 0 + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + token_ids = output.outputs[0].token_ids + text = output.outputs[0].text + + # Ensure we get prompt ids iff we haven't yet received output tokens + if output_tokens: + assert 1 <= len(token_ids) <= num_scheduler_steps + assert text + assert not output.prompt_token_ids + else: + assert output.prompt_token_ids + prompt_tokens = output.prompt_token_ids + + output_tokens.extend(token_ids) + output_text += text + + output_count += 1 + return prompt_tokens, output_tokens, output_text, output_count + + results = await asyncio.gather( + run("common input prompt", RequestOutputKind.CUMULATIVE), + run("common input prompt", RequestOutputKind.FINAL_ONLY), + run_deltas("common input prompt")) + + # Make sure outputs are the same + prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results) + assert len(prompt_set) == 1 + + text_set = set(text for _, _, text, _ in results) + assert len(text_set) == 1 + + tokens_set = set(tuple(ids) for _, ids, _, _ in results) + assert len(tokens_set) == 1 + + cumulative, final, deltas = results + + # output message counts + assert cumulative[3] == deltas[3] + + if num_scheduler_steps == 1: + assert cumulative[3] == 32 + else: + assert 1 < cumulative[3] < 32 + + assert final[3] == 1 @pytest.mark.asyncio(scope="module") async def test_cancellation(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + sampling_params = SamplingParams( temperature=0, - min_tokens=10, - max_tokens=10, + min_tokens=13, + max_tokens=13, ) + stop_at = 5 if num_scheduler_steps == 1 else 1 + + request_id = uid() + i = 0 with pytest.raises(CancelledError): async for output in async_engine.generate("test2", sampling_params, - request_id="test2"): + request_id=request_id): assert not output.finished i += 1 - if i == 5: - await async_engine.abort("test2") + if i == stop_at: + await async_engine.abort(request_id) - assert i == 5 + assert i == stop_at @pytest.mark.asyncio(scope="module") async def test_delayed_generator(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + + if scheduler_config.num_scheduler_steps != 1: + pytest.skip("no need to test this one with multistep") + sampling_params = SamplingParams( temperature=0, min_tokens=10, max_tokens=10, ) - stream = async_engine.generate("test3", - sampling_params, - request_id="test3") + stream = async_engine.generate("test3", sampling_params, request_id=uid()) i = 0 final_output: Optional[RealRequestOutput] = None async for output in stream: From 6b1f3558517bf6c2ce16ea925d1474b64a32979b Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 11 Sep 2024 17:05:36 -0700 Subject: [PATCH 12/13] Some rework/simplification --- vllm/engine/llm_engine.py | 5 +- vllm/engine/output_processor/multi_step.py | 5 +- vllm/engine/output_processor/single_step.py | 6 +-- vllm/outputs.py | 21 +++----- 
vllm/sequence.py | 60 +++++++++++++-------- 5 files changed, 48 insertions(+), 49 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cf1dc98e234b..a382b3401b95 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1472,13 +1472,10 @@ def _advance_to_next_step( " (i.e sampling_params.n == 1 and no " "sampling_params.best_of > 1)") sample = sequence_group_outputs.samples[0] - track_delta = seq_group.sampling_params.output_kind == ( - RequestOutputKind.DELTA) assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs, - track_delta) + seq.append_token_id(sample.output_token, sample.logprobs) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index ad7b6e42220f..c73db765fc3b 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -8,7 +8,7 @@ single_step_process_prompt_logprob) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -161,8 +161,6 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples = valid_samples[:i + 1] break - track_delta = sampling_params.output_kind == RequestOutputKind.DELTA - # Incrementally append tokens to the sequence, as if we had only one new # token. for output_token_id, output_logprob in zip(output_token_ids, @@ -170,7 +168,6 @@ def _process_seq_outputs(self, seq: Sequence, seq.append_token_id( token_id=output_token_id, logprobs=output_logprob, - track_delta=track_delta, ) self._process_decode_and_stop(seq, sampling_params) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 2aad6d0e63f7..e288aa0c4aaf 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -6,7 +6,7 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -119,9 +119,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # only have one sequence seq = seq_group.seqs[0] if not is_async: - delta = sampling_params.output_kind == RequestOutputKind.DELTA - seq.append_token_id(sample.output_token, sample.logprobs, - delta) + seq.append_token_id(sample.output_token, sample.logprobs) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/outputs.py b/vllm/outputs.py index 3c256d94cfb0..85ea9196b25d 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -145,34 +145,25 @@ def from_seq_group(cls, # logprobs are not requested. 
include_logprobs = sampling_params.logprobs is not None text_buffer_length = sampling_params.output_text_buffer_length + delta = sampling_params.output_kind == RequestOutputKind.DELTA outputs = [] include_prompt = True for seq in top_n_seqs: - output_text = seq.get_output_text_to_return(text_buffer_length) + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) output_logprobs = seq.output_logprobs if include_logprobs else None - all_output_token_ids = seq.data._output_token_ids - output_token_ids: GenericSequence[int] - if sampling_params.output_kind == RequestOutputKind.DELTA: - # Get and reset token id delta - output_token_ids = seq.data.last_appended_tokens - seq.data.last_appended_tokens = [] - # Use last offset to make text delta, update last offset - last_text_offset = seq.data.last_output_text_offset - new_text_len = len(output_text) - output_text = output_text[last_text_offset:] - seq.data.last_output_text_offset = new_text_len + if delta: # Slice logprobs delta if applicable if output_logprobs: output_logprobs = output_logprobs[-len(output_token_ids):] # Don't include prompt if this is after the first output # containing decode token ids - if include_prompt and len(all_output_token_ids) > len( + if include_prompt and seq.get_output_len() > len( output_token_ids): include_prompt = False - else: - output_token_ids = all_output_token_ids outputs.append( CompletionOutput( diff --git a/vllm/sequence.py b/vllm/sequence.py index 047e2aabc391..98a8b7358606 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,8 +5,9 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - Optional, Set, Tuple, Union, cast) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union, cast import msgspec import torch @@ -168,10 +169,6 @@ class SequenceData(msgspec.Struct, # It is used to compute mrope_position_ids. 
_mrope_position_delta: Optional[int] = None - # These are used to keep track of delta outputs - last_appended_tokens: List[int] = [] - last_output_text_offset: int = 0 - def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" @@ -234,13 +231,7 @@ def mrope_position_delta(self) -> Optional[int]: def mrope_position_delta(self, new_mrope_position_delta): self._mrope_position_delta = new_mrope_position_delta - def append_token_id(self, - token_id: int, - logprob: float, - track_delta: bool = False) -> None: - if track_delta: - self.last_appended_tokens.append(token_id) - + def append_token_id(self, token_id: int, logprob: float) -> None: self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) @@ -417,6 +408,10 @@ def __init__( self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None + # These are used to keep track of delta outputs + self._last_token_ids_offset: int = 0 + self._last_output_text_offset: int = 0 + # Used for incremental detokenization self.prefix_offset = 0 self.read_offset = 0 @@ -472,11 +467,35 @@ def prompt_adapter_id(self) -> int: return self.prompt_adapter_request.prompt_adapter_id \ if self.prompt_adapter_request else 0 - def get_output_text_to_return(self, buffer_length: int) -> str: + def get_output_text_to_return(self, buffer_length: int, + delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + # We return the full output text if the sequence is finished. truncate = buffer_length and not self.is_finished() - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) + if not delta: + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + length = len(self.output_text) - buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + + def get_output_token_ids_to_return(self, + delta: bool) -> GenericSequence[int]: + """If delta is True, only new tokens since the last call to + this method are returned""" + if not delta: + return self.get_output_token_ids() + length = self.get_output_len() + last_offset = self._last_token_ids_offset + if last_offset < length: + self._last_token_ids_offset = length + return self.data._output_token_ids[last_offset:] + return () def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -495,14 +514,11 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id(self, - token_id: int, - logprobs: Dict[int, Logprob], - track_delta: bool = False) -> None: + def append_token_id(self, token_id: int, logprobs: Dict[int, + Logprob]) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob, - track_delta) + self.data.append_token_id(token_id, logprobs[token_id].logprob) def get_len(self) -> int: return self.data.get_len() From 3233a92924d9a58b41a8bfc173769f7e6d5d02c3 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 11 Sep 2024 17:45:27 -0700 Subject: [PATCH 13/13] Remove obsolete engine.step_return_finished_only field --- vllm/engine/llm_engine.py | 4 ---- vllm/entrypoints/llm.py | 6 ------ 2 files changed, 10 deletions(-) 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a382b3401b95..f5810d76e993 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -225,9 +225,6 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, - # To improve performance, only final requests outputs may be required. - # If this set to true, then no intermediate outputs will be returned. - step_return_finished_only: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -295,7 +292,6 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats - self.step_return_finished_only = step_return_finished_only if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63bc9a6a6f8f..c01bffeb4289 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -707,9 +707,6 @@ def _run_engine( f"output: {0:.2f} toks/s"), ) - # In the loop below, only finished outputs are used - self.llm_engine.step_return_finished_only = True - # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_in_toks = 0 @@ -734,9 +731,6 @@ def _run_engine( f"output: {out_spd:.2f} toks/s") pbar.update(1) - # Restore original behavior - self.llm_engine.step_return_finished_only = False - if use_tqdm: pbar.close() # Sort the outputs by request ID.
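For reviewers who want to try the change from the caller's side, below is a minimal usage sketch; it is not part of the patch series. It mirrors the new tests above: the model name, prompt text, and token limit are illustrative placeholders, and output_kind is assigned on a SamplingParams instance the same way the tests do.

import asyncio
import uuid

from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.sampling_params import RequestOutputKind, SamplingParams


async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))

    # DELTA: each streamed RequestOutput carries only the tokens/text
    # generated since the previous output, so the client accumulates them.
    delta_params = SamplingParams(temperature=0, max_tokens=32)
    delta_params.output_kind = RequestOutputKind.DELTA
    text = ""
    async for output in engine.generate("An example prompt",
                                        delta_params,
                                        request_id=str(uuid.uuid4())):
        text += output.outputs[0].text
    print(text)

    # FINAL_ONLY: nothing is returned until the request finishes; the single
    # output received is the complete, cumulative result.
    final_params = SamplingParams(temperature=0, max_tokens=32)
    final_params.output_kind = RequestOutputKind.FINAL_ONLY
    async for output in engine.generate("An example prompt",
                                        final_params,
                                        request_id=str(uuid.uuid4())):
        assert output.finished
        print(output.outputs[0].text)


asyncio.run(main())

CUMULATIVE (the pre-existing behavior) remains available for callers that still want the full output returned at every step.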