Also exclude prompt details in subsequent outputs in delta mode
njhill committed Aug 13, 2024
1 parent ef2e59f commit dc1f3f2
Showing 2 changed files with 36 additions and 13 deletions.
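In delta mode (RequestOutputKind.DELTA), each RequestOutput carries only the newly generated tokens; after this commit, prompt details are likewise sent only on the first output for a request. A minimal sketch of the intended client-side view, assuming a stream of RequestOutput objects (the consume loop is illustrative, not a vLLM API):

# Illustration only: how a consumer of delta-mode RequestOutputs should see
# this change. `stream` is a hypothetical iterable of RequestOutput objects.
def consume(stream):
    text = ""
    for i, out in enumerate(stream):
        if i == 0:
            # First output for the request: prompt details included.
            assert out.prompt_token_ids is not None
        else:
            # Subsequent delta outputs: prompt details excluded.
            assert out.prompt_token_ids is None
        text += out.outputs[0].text  # each delta carries only the new text
    return text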
20 changes: 16 additions & 4 deletions vllm/engine/llm_engine.py
@@ -1182,9 +1182,12 @@ def _process_model_outputs(
         output_by_sequence_group = create_output_by_sequence_group(
             output, num_seq_groups=len(scheduled_seq_groups))

-        # seq_id to (output token count, text len)
+        # seq_id to (output token count, text len),
+        # only for delta output seq groups
         previous_output_lens: Dict[int, Tuple[int, int]] = {}
+        # Seq groups whose outputs should not have prompt details included,
+        # only applies to delta output seq groups
+        exclude_prompt_seq_group_ids = set()

         # Update the scheduled sequence groups with the model outputs.
         for scheduled_seq_group, outputs, seq_group_meta in zip(
@@ -1196,8 +1199,13 @@ def _process_model_outputs(
                     == RequestOutputKind.DELTA):
                 text_buffer_length = params.output_text_buffer_length
                 for seq in seq_group.seqs:
+                    output_len = seq.get_output_len()
+                    if output_len:
+                        # Exclude the prompt if the seq group already has
+                        # completion tokens
+                        exclude_prompt_seq_group_ids.add(seq_group.request_id)
                     previous_output_lens[seq.seq_id] = (
-                        seq.get_output_len(),
+                        output_len,
                         seq.get_output_text_to_return_len(text_buffer_length))

             seq_group.update_num_computed_tokens(
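Restated outside the diff, the new bookkeeping amounts to the following (a sketch; seq_group stands in for vLLM's SequenceGroup):

# Sketch of the bookkeeping added in the hunk above. A request id is flagged
# once any of its sequences already has completion tokens, so later outputs
# for that request can omit the prompt details.
from typing import Dict, Set, Tuple

previous_output_lens: Dict[int, Tuple[int, int]] = {}
exclude_prompt_seq_group_ids: Set[str] = set()

def record_prior_state(seq_group, text_buffer_length: int) -> None:
    for seq in seq_group.seqs:
        output_len = seq.get_output_len()
        if output_len:
            exclude_prompt_seq_group_ids.add(seq_group.request_id)
        previous_output_lens[seq.seq_id] = (
            output_len,
            seq.get_output_text_to_return_len(text_buffer_length))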
@@ -1236,23 +1244,27 @@ def _process_model_outputs(
         for scheduled_seq_group in scheduled_seq_groups:
             seq_group = scheduled_seq_group.seq_group
             seq_group.maybe_set_first_token_time(now)
+            include_prompt = seq_group.request_id not in (
+                exclude_prompt_seq_group_ids)
             request_output = RequestOutputFactory.create(
-                seq_group, previous_output_lens)
+                seq_group, previous_output_lens, include_prompt)
             if request_output:
                 request_outputs.append(request_output)
         for seq_group in ignored_seq_groups:
             params = seq_group.sampling_params
+            include_prompt = True
             if params is not None and params.output_kind == (
                     RequestOutputKind.DELTA):
                 if not seq_group.is_finished():
                     continue
                 # Ignored seq groups have no delta, but we must still return
                 # an "empty" RequestOutput when finished
+                include_prompt = False
                 for seq in seq_group.seqs:
                     previous_output_lens[seq.seq_id] = (seq.get_output_len(),
                                                         len(seq.output_text))
             request_output = RequestOutputFactory.create(
-                seq_group, previous_output_lens)
+                seq_group, previous_output_lens, include_prompt)
             if request_output:
                 request_outputs.append(request_output)
         return request_outputs
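The (output token count, text len) pairs recorded above are consumed when the next RequestOutput is built; the slicing itself lives in from_seq_group, which this diff only partly shows, but plausibly reduces to:

# How the recorded prior lengths plausibly turn a full output into a delta
# (a sketch; the actual slicing happens inside RequestOutput.from_seq_group).
from typing import List, Tuple

def to_delta(token_ids: List[int], text: str,
             prior: Tuple[int, int]) -> Tuple[List[int], str]:
    prior_token_count, prior_text_len = prior
    return token_ids[prior_token_count:], text[prior_text_len:]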
29 changes: 20 additions & 9 deletions vllm/outputs.py
@@ -93,7 +93,7 @@ def __init__(
         self,
         request_id: str,
         prompt: Optional[str],
-        prompt_token_ids: List[int],
+        prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
@@ -118,6 +118,7 @@ def from_seq_group(
         cls,
         seq_group: SequenceGroup,
         prior_output_lens: Dict[int, Tuple[int, int]],
+        include_prompt: bool = True,
     ) -> Optional["RequestOutput"]:
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
@@ -176,11 +177,18 @@ def from_seq_group(
                     seq.stop_reason))

         # Every sequence in the sequence group should have the same prompt.
-        prompt = seq_group.prompt
-        prompt_token_ids = seq_group.prompt_token_ids
-        encoder_prompt = seq_group.encoder_prompt
-        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
-        prompt_logprobs = seq_group.prompt_logprobs
+        if include_prompt:
+            prompt = seq_group.prompt
+            prompt_token_ids = seq_group.prompt_token_ids
+            encoder_prompt = seq_group.encoder_prompt
+            encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
+            prompt_logprobs = seq_group.prompt_logprobs
+        else:
+            prompt = None
+            prompt_token_ids = None
+            encoder_prompt = None
+            encoder_prompt_token_ids = None
+            prompt_logprobs = None
         finished_time = time.time() if finished else None
         seq_group.set_finished_time(finished_time)
         return cls(seq_group.request_id,
@@ -256,12 +264,15 @@ def __repr__(self):
 class RequestOutputFactory:

     @staticmethod
-    def create(seq_group,
-               previous_output_lens: Dict[int, Tuple[int, int]] = {}):  # noqa
+    def create(
+            seq_group,
+            previous_output_lens: Dict[int, Tuple[int, int]] = {},  # noqa
+            include_prompt: bool = True):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
             return RequestOutput.from_seq_group(seq_group,
-                                                previous_output_lens)
+                                                previous_output_lens,
+                                                include_prompt)
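Putting the two files together: a caller that has flagged a request can now drop the prompt fields entirely. A hedged usage sketch (seq_group and previous_output_lens are stand-ins for objects the engine would hold):

# Usage sketch: with include_prompt=False, every prompt-related field on the
# resulting RequestOutput is None. `seq_group` is a stand-in object here.
out = RequestOutputFactory.create(seq_group, previous_output_lens,
                                  include_prompt=False)
if out is not None:
    assert out.prompt is None
    assert out.prompt_token_ids is None
    assert out.prompt_logprobs is None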
