[Core] Add engine option to return only deltas or final output #7381

Merged
22 commits, merged Sep 12, 2024

Changes shown below are from 5 of the 22 commits

Commits
1eb9991
[Core] Add engine option to return only deltas or final output
njhill Aug 9, 2024
9bc3fdd
Fixes
njhill Aug 10, 2024
ef2e59f
Fix ignored sequence case
njhill Aug 10, 2024
dc1f3f2
Also exclude prompt details in subsequent outputs in delta mode
njhill Aug 13, 2024
9d35a00
Fix prompt token counts in streaming cases
njhill Aug 13, 2024
b7ff44e
Simplification suggestion from @joerunde
njhill Aug 14, 2024
34df9bd
Make tests more robust
njhill Aug 15, 2024
a68506f
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Aug 15, 2024
cfe7118
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Aug 27, 2024
45fd069
Post-merge wip
njhill Aug 27, 2024
3f21ad6
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 6, 2024
d59ffd1
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 8, 2024
d2f36dd
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 10, 2024
2843365
Fix delta computation, remove unrelated changes
njhill Sep 10, 2024
2736ab1
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 10, 2024
a045dff
Address Alex's comments, fix include_prompt logic
njhill Sep 10, 2024
58f6112
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 11, 2024
e7a2b55
Add tests
njhill Sep 11, 2024
6b1f355
Some rework/simplification
njhill Sep 12, 2024
3233a92
Remove obsolete engine.step_return_finished_only field
njhill Sep 12, 2024
f351ed2
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 12, 2024
75814bd
Merge remote-tracking branch 'origin/main' into reduce-output
njhill Sep 12, 2024
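
Before the file diffs, a quick orientation sketch of how the option introduced by this PR could be set by a caller. The `RequestOutputKind` enum and the `output_kind` field on `SamplingParams` come from the changes below; everything else (argument values, the print) is illustrative only and not part of the PR.

# Sketch only: assumes vLLM at the state of this PR, where SamplingParams
# carries an `output_kind` field backed by the RequestOutputKind enum.
from vllm.sampling_params import RequestOutputKind, SamplingParams

# Streaming-style use: each RequestOutput carries only what was generated
# since the previous one (the OpenAI server selects this when stream=True).
delta_params = SamplingParams(max_tokens=64,
                              output_kind=RequestOutputKind.DELTA)

# Batch-style use: intermediate outputs are suppressed and a single
# RequestOutput is returned once the request finishes (LLM.generate is
# switched to this mode in vllm/entrypoints/llm.py below).
final_params = SamplingParams(max_tokens=64,
                              output_kind=RequestOutputKind.FINAL_ONLY)

print(delta_params.output_kind, final_params.output_kind)
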
51 changes: 45 additions & 6 deletions vllm/engine/llm_engine.py
@@ -34,7 +34,7 @@
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
@@ -1182,11 +1182,32 @@ def _process_model_outputs(
output_by_sequence_group = create_output_by_sequence_group(
output, num_seq_groups=len(scheduled_seq_groups))

# seq_id to (output token count, text len),
# only for delta output seq groups
previous_output_lens: Dict[int, Tuple[int, int]] = {}
# Seq groups whose outputs should not have prompt details included,
# only applies to delta output seq groups
exclude_prompt_seq_group_ids = set()

# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(
scheduled_seq_groups, output_by_sequence_group,
seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group: SequenceGroup = scheduled_seq_group.seq_group
params = seq_group.sampling_params
if params is not None and (params.output_kind
== RequestOutputKind.DELTA):
text_buffer_length = params.output_text_buffer_length
for seq in seq_group.seqs:
output_len = seq.get_output_len()
if output_len:
# Exclude the prompt if the seq group already has
# completion tokens
exclude_prompt_seq_group_ids.add(seq_group.request_id)
previous_output_lens[seq.seq_id] = (
output_len,
seq.get_output_text_to_return_len(text_buffer_length))

seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
if output is not None and len(output) > 0:
@@ -1223,11 +1244,29 @@ def _process_model_outputs(
for scheduled_seq_group in scheduled_seq_groups:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
include_prompt = seq_group.request_id not in (
exclude_prompt_seq_group_ids)
request_output = RequestOutputFactory.create(
seq_group, previous_output_lens, include_prompt)
if request_output:
request_outputs.append(request_output)
for seq_group in ignored_seq_groups:
request_output = RequestOutputFactory.create(seq_group)
request_outputs.append(request_output)
params = seq_group.sampling_params
include_prompt = True
if params is not None and params.output_kind == (
RequestOutputKind.DELTA):
if not seq_group.is_finished():
continue
# Ignored seq groups have no delta, but we must still return
# an "empty" RequestOutput when finished
include_prompt = False
for seq in seq_group.seqs:
previous_output_lens[seq.seq_id] = (seq.get_output_len(),
seq.output_text)
request_output = RequestOutputFactory.create(
seq_group, previous_output_lens, include_prompt)
if request_output:
request_outputs.append(request_output)
return request_outputs

def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
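
A note on the bookkeeping in the hunks above: for delta-mode sequence groups, the engine records how many output tokens and how many characters of output text have already been returned, and the output factory is expected to emit only what lies past those marks. A self-contained sketch of that slicing logic, using hypothetical values in place of real Sequence objects (the function name compute_delta is made up for this sketch):

from typing import Dict, List, Tuple

# seq_id -> (tokens already returned, chars of text already returned),
# mirroring the previous_output_lens dict built in _process_model_outputs.
previous_output_lens: Dict[int, Tuple[int, int]] = {7: (3, 12)}

def compute_delta(seq_id: int, token_ids: List[int], text: str):
    """Return only the portion of the output produced since the last step."""
    prev_tokens, prev_text_len = previous_output_lens.get(seq_id, (0, 0))
    delta_token_ids = token_ids[prev_tokens:]
    delta_text = text[prev_text_len:]
    # Record the new high-water mark for the next step.
    previous_output_lens[seq_id] = (len(token_ids), len(text))
    return delta_token_ids, delta_text

# Hypothetical cumulative state after one more decoding step.
print(compute_delta(7, [11, 22, 33, 44], "Hello, world!"))  # ([44], '!')
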
16 changes: 7 additions & 9 deletions vllm/entrypoints/llm.py
@@ -16,7 +16,7 @@
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
@@ -547,14 +547,12 @@ def _validate_and_add_requests(
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options)

# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
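
The rewritten block above folds the scalar and list cases into one loop, mutates each `SamplingParams` in place, and forces `FINAL_ONLY` output for the synchronous `LLM` API, which only consumes finished results. A small sketch of that single-or-list idiom in isolation, with a hypothetical stand-in class instead of the real `SamplingParams`:

from typing import List, Union

class FakeParams:
    """Hypothetical stand-in for SamplingParams in this sketch."""
    def __init__(self) -> None:
        self.output_kind = "CUMULATIVE"

def force_final_only(params: Union[FakeParams, List[FakeParams]]) -> None:
    # Same shape as the loop in _validate_and_add_requests: a scalar is
    # wrapped in a one-element tuple so a single loop covers both cases,
    # and each item is mutated in place rather than rebuilt.
    for p in params if isinstance(params, list) else (params, ):
        p.output_kind = "FINAL_ONLY"

single = FakeParams()
many = [FakeParams(), FakeParams()]
force_final_only(single)
force_final_only(many)
print(single.output_kind, [p.output_kind for p in many])
# FINAL_ONLY ['FINAL_ONLY', 'FINAL_ONLY']
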
7 changes: 6 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -12,7 +12,8 @@
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.utils import random_uuid

# torch is mocked during docs generation,
@@ -275,6 +276,8 @@ def to_sampling_params(
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode='before')
@@ -461,6 +464,8 @@ def to_sampling_params(
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode="before")
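
Both `to_sampling_params` methods above now derive the output kind from the request's `stream` flag. A minimal sketch of that mapping in isolation; the enum here is a local mirror defined only for the sketch (the real one lives in vllm.sampling_params and may have more members), and the request is reduced to a hypothetical dataclass:

from dataclasses import dataclass
from enum import Enum, auto

class RequestOutputKind(Enum):
    # Local mirror for the sketch only.
    DELTA = auto()
    FINAL_ONLY = auto()

@dataclass
class FakeRequest:
    stream: bool = False

def output_kind_for(request: FakeRequest) -> RequestOutputKind:
    # Streaming clients want per-step deltas; non-streaming clients only
    # need the finished result, so intermediate outputs are skipped.
    return (RequestOutputKind.DELTA
            if request.stream else RequestOutputKind.FINAL_ONLY)

print(output_kind_for(FakeRequest(stream=True)).name)   # DELTA
print(output_kind_for(FakeRequest(stream=False)).name)  # FINAL_ONLY
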
59 changes: 27 additions & 32 deletions vllm/entrypoints/openai/serving_chat.py
@@ -213,12 +213,16 @@ async def chat_completion_stream_generator(

# Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices

num_prompt_tokens = 0

try:
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)

# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
@@ -240,11 +244,11 @@
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
if request.stream_options.continuous_usage_stats:
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
@@ -280,12 +284,10 @@
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
@@ -301,25 +303,20 @@
if finish_reason_sent[i]:
continue

delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
assert output.logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
token_ids=output.token_ids,
top_logprobs=output.logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
delta_text = output.text
previous_num_tokens[i] += len(output.token_ids)

if request.tool_choice and type(
request.tool_choice
@@ -348,13 +345,12 @@
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@@ -365,7 +361,6 @@
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
@@ -380,13 +375,12 @@
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
Expand All @@ -398,10 +392,11 @@ async def chat_completion_stream_generator(

if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)

final_usage_chunk = ChatCompletionStreamResponse(
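
Because each result is now a delta, the generator above caches the prompt token count from the first response (later responses omit prompt details) and accumulates per-choice completion counts with `+=` instead of recomputing them from cumulative outputs. A runnable sketch of that accounting, with made-up chunk contents and simplified stand-in classes:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FakeOutput:
    index: int
    token_ids: List[int]
    text: str

@dataclass
class FakeResult:
    prompt_token_ids: Optional[List[int]]
    outputs: List[FakeOutput] = field(default_factory=list)

num_choices = 2
previous_num_tokens = [0] * num_choices
texts = [""] * num_choices
num_prompt_tokens = 0

# Two hypothetical delta results for a request with n=2.
results = [
    FakeResult([1, 2, 3], [FakeOutput(0, [10, 11], "He")]),
    FakeResult(None, [FakeOutput(0, [12], "llo"), FakeOutput(1, [20], "Hi")]),
]

for res in results:
    # Prompt details only arrive with the first delta, so cache them once.
    if res.prompt_token_ids is not None:
        num_prompt_tokens = len(res.prompt_token_ids)
    for output in res.outputs:
        # Each output carries only new tokens/text, so counts accumulate
        # and the delta texts concatenate back to the full completion.
        previous_num_tokens[output.index] += len(output.token_ids)
        texts[output.index] += output.text

print(num_prompt_tokens, previous_num_tokens, texts)  # 3 [3, 1] ['Hello', 'Hi']
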
23 changes: 12 additions & 11 deletions vllm/entrypoints/openai/serving_completion.py
@@ -221,12 +221,15 @@ async def completion_stream_generator(
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts

try:
async for prompt_idx, res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

for output in res.outputs:
i = output.index + prompt_idx * num_choices
@@ -251,11 +254,9 @@
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[
previous_num_tokens[i]:]
out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs

if request.logprobs is not None:
assert out_logprobs is not None, (
@@ -265,13 +266,13 @@
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
initial_text_offset=previous_text_lens[i],
)
else:
logprobs = None

previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason

@@ -292,8 +293,8 @@
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
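
The completion generator above keeps a running character offset (`previous_text_lens`) so that logprob text offsets stay correct even though each output now carries only delta text. A tiny sketch of that offset bookkeeping, with illustrative chunk values:

# Sketch: maintaining the text offset across delta chunks, as
# completion_stream_generator now does via previous_text_lens.
previous_text_lens = [0]  # one entry per choice; a single choice here
delta_chunks = ["Hel", "lo ", "world"]

for delta_text in delta_chunks:
    i = 0
    # The offset handed to logprob rendering is the length of everything
    # already streamed, not the length of the current delta-only text.
    initial_text_offset = previous_text_lens[i]
    print(f"offset={initial_text_offset} delta={delta_text!r}")
    previous_text_lens[i] += len(delta_text)

print(previous_text_lens)  # [11]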