[BugFix] Fix and simplify completion API usage streaming
Similar to what was done for the chat API in #9357, ensure that the final chunk with usage data contains aggregate counts across all choices.

Also simplify some of the prompt-handling logic in the API implementation.
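
As a rough sketch of the intended client-visible behaviour (not part of this commit; the server URL and model name below are placeholders), a streaming completion request with `stream_options.include_usage` set should end with a data chunk whose `choices` list is empty and whose `usage` sums prompt and completion tokens over all `n` choices. Per the diff below, setting `continuous_usage_stats` as well additionally attaches running per-choice usage to every intermediate chunk.

```python
# Hedged sketch: assumes a vLLM OpenAI-compatible server on localhost:8000
# and a model named "my-model"; both are placeholders, not values from this
# commit.
import json
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",
        "prompt": "Hello",
        "max_tokens": 8,
        "n": 2,  # two choices; the final usage chunk should aggregate both
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
)

for line in response.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    # After this change, the last chunk has an empty "choices" list and
    # carries usage totals summed across all n choices.
    if not chunk["choices"]:
        print(chunk["usage"])
```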
njhill committed Oct 17, 2024
1 parent eca2c5f commit 82f5103
Showing 1 changed file with 61 additions and 62 deletions.
vllm/entrypoints/openai/serving_completion.py
@@ -258,6 +258,14 @@ async def completion_stream_generator(
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
 
+        stream_options = request.stream_options
+        if stream_options:
+            include_usage = stream_options.include_usage
+            include_continuous_usage = include_usage and \
+                stream_options.continuous_usage_stats
+        else:
+            include_usage, include_continuous_usage = False, False
+
         try:
             async for prompt_idx, res in result_generator:
                 prompt_token_ids = res.prompt_token_ids
@@ -276,28 +284,25 @@
                     i = output.index + prompt_idx * num_choices
 
                     assert request.max_tokens is not None
-                    if request.echo and request.max_tokens == 0:
+                    if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
                         assert prompt_text is not None
-                        # only return the prompt
-                        delta_text = prompt_text
-                        delta_token_ids = prompt_token_ids
-                        out_logprobs = prompt_logprobs
-                        has_echoed[i] = True
-                    elif (request.echo and request.max_tokens > 0
-                          and not has_echoed[i]):
-                        assert prompt_token_ids is not None
-                        assert prompt_text is not None
-                        assert prompt_logprobs is not None
-                        # echo the prompt and first token
-                        delta_text = prompt_text + output.text
-                        delta_token_ids = [
-                            *prompt_token_ids, *output.token_ids
-                        ]
-                        out_logprobs = [
-                            *prompt_logprobs,
-                            *(output.logprobs or []),
-                        ]
+                        if request.max_tokens == 0:
+                            # only return the prompt
+                            delta_text = prompt_text
+                            delta_token_ids = prompt_token_ids
+                            out_logprobs = prompt_logprobs
+                        else:
+                            assert prompt_logprobs is not None
+                            # echo the prompt and first token
+                            delta_text = prompt_text + output.text
+                            delta_token_ids = [
+                                *prompt_token_ids, *output.token_ids
+                            ]
+                            out_logprobs = [
+                                *prompt_logprobs,
+                                *(output.logprobs or []),
+                            ]
                         has_echoed[i] = True
                     else:
                         # return just the delta
@@ -341,45 +346,39 @@
                                 stop_reason=stop_reason,
                             )
                         ])
-                    if (request.stream_options
-                            and request.stream_options.include_usage):
-                        if (request.stream_options.continuous_usage_stats
-                                or output.finish_reason is not None):
-                            prompt_tokens = num_prompt_tokens[prompt_idx]
-                            completion_tokens = previous_num_tokens[i]
-                            usage = UsageInfo(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens + completion_tokens,
-                            )
-                        if request.stream_options.continuous_usage_stats:
-                            chunk.usage = usage
-                        else:
-                            chunk.usage = None
+                    if include_continuous_usage:
+                        prompt_tokens = num_prompt_tokens[prompt_idx]
+                        completion_tokens = previous_num_tokens[i]
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        )
 
                     response_json = chunk.model_dump_json(exclude_unset=False)
                     yield f"data: {response_json}\n\n"
 
-            if (request.stream_options
-                    and request.stream_options.include_usage):
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
+            if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
                     model=model_name,
                     choices=[],
-                    usage=usage,
+                    usage=final_usage_info,
                 )
                 final_usage_data = (final_usage_chunk.model_dump_json(
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
             # report to FastAPI middleware aggregate usage across all choices
-            total_prompt_tokens = sum(num_prompt_tokens)
-            total_completion_tokens = sum(previous_num_tokens)
-            request_metadata.final_usage_info = UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+            request_metadata.final_usage_info = final_usage_info
 
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -413,26 +412,26 @@ def request_output_to_completion_response(
 
             for output in final_res.outputs:
                 assert request.max_tokens is not None
-                if request.echo and request.max_tokens == 0:
-                    assert prompt_text is not None
-                    token_ids = prompt_token_ids
-                    out_logprobs = prompt_logprobs
-                    output_text = prompt_text
-                elif request.echo and request.max_tokens > 0:
+                if request.echo:
                     assert prompt_text is not None
-                    token_ids = [*prompt_token_ids, *output.token_ids]
-
-                    if request.logprobs is None:
-                        out_logprobs = None
+                    if request.max_tokens == 0:
+                        token_ids = prompt_token_ids
+                        out_logprobs = prompt_logprobs
+                        output_text = prompt_text
                     else:
-                        assert prompt_logprobs is not None
-                        assert output.logprobs is not None
-                        out_logprobs = [
-                            *prompt_logprobs,
-                            *output.logprobs,
-                        ]
-
-                    output_text = prompt_text + output.text
+                        token_ids = [*prompt_token_ids, *output.token_ids]
+
+                        if request.logprobs is None:
+                            out_logprobs = None
+                        else:
+                            assert prompt_logprobs is not None
+                            assert output.logprobs is not None
+                            out_logprobs = [
+                                *prompt_logprobs,
+                                *output.logprobs,
+                            ]
+
+                        output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
                     out_logprobs = output.logprobs
