diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index c1a513cb9fa9..97161aa15ddd 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1182,9 +1182,12 @@ def _process_model_outputs(
         output_by_sequence_group = create_output_by_sequence_group(
             output, num_seq_groups=len(scheduled_seq_groups))
 
-        # seq_id to (output token count, text len)
+        # seq_id to (output token count, text len),
         # only for delta output seq groups
         previous_output_lens: Dict[int, Tuple[int, int]] = {}
+        # Seq groups whose outputs should not have prompt details included,
+        # only applies to delta output seq groups
+        exclude_prompt_seq_group_ids = set()
 
         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs, seq_group_meta in zip(
@@ -1196,8 +1199,13 @@ def _process_model_outputs(
                     == RequestOutputKind.DELTA):
                 text_buffer_length = params.output_text_buffer_length
                 for seq in seq_group.seqs:
+                    output_len = seq.get_output_len()
+                    if output_len:
+                        # Exclude the prompt if the seq group already has
+                        # completion tokens
+                        exclude_prompt_seq_group_ids.add(seq_group.request_id)
                     previous_output_lens[seq.seq_id] = (
-                        seq.get_output_len(),
+                        output_len,
                         seq.get_output_text_to_return_len(text_buffer_length))
 
             seq_group.update_num_computed_tokens(
@@ -1236,23 +1244,27 @@ def _process_model_outputs(
         for scheduled_seq_group in scheduled_seq_groups:
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
+            include_prompt = seq_group.request_id not in (
+                exclude_prompt_seq_group_ids)
             request_output = RequestOutputFactory.create(
-                seq_group, previous_output_lens)
+                seq_group, previous_output_lens, include_prompt)
             if request_output:
                 request_outputs.append(request_output)
         for seq_group in ignored_seq_groups:
             params = seq_group.sampling_params
+            include_prompt = True
             if params is not None and params.output_kind == (
                     RequestOutputKind.DELTA):
                 if not seq_group.is_finished():
                     continue
                 # Ignored seq groups have no delta, but we must still return
                 # an "empty" RequestOutput when finished
+                include_prompt = False
                 for seq in seq_group.seqs:
                     previous_output_lens[seq.seq_id] = (seq.get_output_len(),
                                                         seq.output_text)
             request_output = RequestOutputFactory.create(
-                seq_group, previous_output_lens)
+                seq_group, previous_output_lens, include_prompt)
             if request_output:
                 request_outputs.append(request_output)
         return request_outputs
diff --git a/vllm/outputs.py b/vllm/outputs.py
index d57b7c88c69d..78143ed43329 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -93,7 +93,7 @@ def __init__(
         self,
         request_id: str,
         prompt: Optional[str],
-        prompt_token_ids: List[int],
+        prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
@@ -118,6 +118,7 @@ def from_seq_group(
         cls,
         seq_group: SequenceGroup,
         prior_output_lens: Dict[int, Tuple[int, int]],
+        include_prompt: bool = True,
     ) -> Optional["RequestOutput"]:
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
@@ -176,11 +177,18 @@ def from_seq_group(
                                  seq.stop_reason))
 
         # Every sequence in the sequence group should have the same prompt.
-        prompt = seq_group.prompt
-        prompt_token_ids = seq_group.prompt_token_ids
-        encoder_prompt = seq_group.encoder_prompt
-        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
-        prompt_logprobs = seq_group.prompt_logprobs
+        if include_prompt:
+            prompt = seq_group.prompt
+            prompt_token_ids = seq_group.prompt_token_ids
+            encoder_prompt = seq_group.encoder_prompt
+            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
+            prompt_logprobs = seq_group.prompt_logprobs
+        else:
+            prompt = None
+            prompt_token_ids = None
+            encoder_prompt = None
+            encoder_prompt_token_ids = None
+            prompt_logprobs = None
         finished_time = time.time() if finished else None
         seq_group.set_finished_time(finished_time)
         return cls(seq_group.request_id,
@@ -256,12 +264,15 @@ def __repr__(self):
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group,
-               previous_output_lens: Dict[int, Tuple[int, int]] = {}):  # noqa
+    def create(
+            seq_group,
+            previous_output_lens: Dict[int, Tuple[int, int]] = {},  # noqa
+            include_prompt: bool = True):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
             return RequestOutput.from_seq_group(seq_group,
-                                                previous_output_lens)
+                                                previous_output_lens,
+                                                include_prompt)
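
The behavioural change above, in one sentence: for RequestOutputKind.DELTA sequence groups, prompt details (prompt, prompt_token_ids, the encoder prompt fields, and prompt_logprobs) are attached only while the sequence group has produced no completion tokens; once output tokens exist, later deltas carry None for those fields. The standalone sketch below mirrors that decision logic only for illustration; DeltaPromptTracker and its methods are hypothetical names and are not part of this diff or of vLLM's API.

# Minimal sketch, assuming only the rule encoded by exclude_prompt_seq_group_ids:
# include prompt fields until a request has completion tokens, omit them afterwards.
from typing import Dict


class DeltaPromptTracker:
    """Hypothetical helper that decides whether a delta should carry prompt fields."""

    def __init__(self) -> None:
        # request_id -> output token count seen when the last delta was built
        self._prev_output_lens: Dict[str, int] = {}

    def include_prompt(self, request_id: str, output_len: int) -> bool:
        # Mirrors the diff: a seq group with completion tokens is added to
        # exclude_prompt_seq_group_ids, so its later deltas get prompt=None,
        # prompt_token_ids=None, prompt_logprobs=None.
        self._prev_output_lens[request_id] = output_len
        return output_len == 0


if __name__ == "__main__":
    tracker = DeltaPromptTracker()
    # First delta: no completion tokens yet, so prompt details are included.
    assert tracker.include_prompt("req-0", 0) is True
    # Subsequent deltas: completion tokens exist, so prompt fields are omitted.
    assert tracker.include_prompt("req-0", 4) is False

A practical consequence for streaming consumers: cache the prompt fields from the first RequestOutput of a request, since later delta outputs will not repeat them.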