[BugFix] Fix chat API continuous usage stats (vllm-project#9357)
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
njhill authored and mfournioux committed Nov 20, 2024
1 parent 76359e1 commit a754e52
Showing 2 changed files with 53 additions and 76 deletions.
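The change makes per-chunk ("continuous") usage reporting consistent across all branches of the streaming chat handler. For orientation, this is roughly how a client exercises the option; a minimal sketch assuming a locally running vLLM OpenAI-compatible server, with the base URL, API key, and model name as placeholder values:

import asyncio

import openai


async def main() -> None:
    # Placeholder endpoint/credentials; any OpenAI-compatible vLLM server works.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="my-model",  # hypothetical model name
        messages=[{"role": "user", "content": "Hello!"}],
        max_tokens=10,
        stream=True,
        # continuous_usage_stats asks for a usage object on every chunk,
        # with completion_tokens growing as tokens are generated.
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
    )
    async for chunk in stream:
        if chunk.usage is not None:
            print(chunk.usage.completion_tokens, chunk.usage.total_tokens)

asyncio.run(main())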
tests/entrypoints/openai/test_chat.py (14 changes: 12 additions & 2 deletions)
@@ -433,18 +433,28 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
         model=model_name,
         messages=messages,
         max_tokens=10,
+        extra_body=dict(min_tokens=10),
         temperature=0.0,
         stream=True,
         stream_options={
             "include_usage": True,
-            "continuous_usage_stats": True
+            "continuous_usage_stats": True,
         },
     )
+    last_completion_tokens = 0
     async for chunk in stream:
         assert chunk.usage.prompt_tokens >= 0
-        assert chunk.usage.completion_tokens >= 0
+        assert last_completion_tokens == 0 or \
+               chunk.usage.completion_tokens > last_completion_tokens or \
+               (
+                   not chunk.choices and
+                   chunk.usage.completion_tokens == last_completion_tokens
+               )
         assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
                                             chunk.usage.completion_tokens)
+        last_completion_tokens = chunk.usage.completion_tokens
+
+    assert last_completion_tokens == 10
 
 
 # NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
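The strengthened assertion encodes the intended chunk sequence: while choices are present, completion_tokens must strictly increase from chunk to chunk, and the final usage-only chunk (empty choices) may only repeat the last count. A standalone illustration of that invariant, with made-up numbers:

# Made-up chunk sequence illustrating the invariant the test checks.
chunks = [
    {"choices": ["delta"], "completion_tokens": 1},
    {"choices": ["delta"], "completion_tokens": 2},
    {"choices": [], "completion_tokens": 2},  # final usage-only chunk
]
last = 0
for c in chunks:
    assert last == 0 or c["completion_tokens"] > last or \
        (not c["choices"] and c["completion_tokens"] == last)
    last = c["completion_tokens"]
print("invariant holds; final count:", last)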
vllm/entrypoints/openai/serving_chat.py (115 changes: 41 additions & 74 deletions)
@@ -330,6 +330,14 @@ async def chat_completion_stream_generator(
                 yield "data: [DONE]\n\n"
                 return
 
+        stream_options = request.stream_options
+        if stream_options:
+            include_usage = stream_options.include_usage
+            include_continuous_usage = include_usage and \
+                                       stream_options.continuous_usage_stats
+        else:
+            include_usage, include_continuous_usage = False, False
+
         try:
             async for res in result_generator:
                 if res.prompt_token_ids is not None:
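The block added above derives both flags once per request rather than re-testing request.stream_options in every code path, which appears to be how the branches drifted apart. The same pattern in isolation (the StreamOptions class here is a simplified stand-in for the request field, not vLLM's actual model):

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class StreamOptions:
    # Simplified stand-in for the request's stream_options field.
    include_usage: bool = False
    continuous_usage_stats: bool = False


def usage_flags(opts: Optional[StreamOptions]) -> Tuple[bool, bool]:
    """Compute (include_usage, include_continuous_usage) once per request."""
    if opts:
        # Continuous per-chunk usage only applies when usage is included.
        return opts.include_usage, (opts.include_usage
                                    and opts.continuous_usage_stats)
    return False, False


# continuous_usage_stats without include_usage yields no usage at all.
assert usage_flags(StreamOptions(continuous_usage_stats=True)) == (False, False)
assert usage_flags(None) == (False, False)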
@@ -348,7 +356,6 @@ async def chat_completion_stream_generator(
                     # NOTE num_choices defaults to 1 so this usually executes
                     # once per request
                     for i in range(num_choices):
-                        tool_parser = tool_parsers[i]
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=DeltaMessage(
@@ -364,19 +371,12 @@ async def chat_completion_stream_generator(
                             choices=[choice_data],
                             model=model_name)
 
-                        # if usage should be included
-                        if (request.stream_options
-                                and request.stream_options.include_usage):
-                            # if continuous usage stats are requested, add it
-                            if request.stream_options.continuous_usage_stats:
-                                usage = UsageInfo(
-                                    prompt_tokens=num_prompt_tokens,
-                                    completion_tokens=0,
-                                    total_tokens=num_prompt_tokens)
-                                chunk.usage = usage
-                            # otherwise don't
-                            else:
-                                chunk.usage = None
+                        # if continuous usage stats are requested, add it
+                        if include_continuous_usage:
+                            chunk.usage = UsageInfo(
+                                prompt_tokens=num_prompt_tokens,
+                                completion_tokens=0,
+                                total_tokens=num_prompt_tokens)
 
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
@@ -404,17 +404,11 @@ async def chat_completion_stream_generator(
                                 created=created_time,
                                 choices=[choice_data],
                                 model=model_name)
-                            if (request.stream_options and
-                                    request.stream_options.include_usage):
-                                if (request.stream_options.
-                                        continuous_usage_stats):
-                                    usage = UsageInfo(
-                                        prompt_tokens=num_prompt_tokens,
-                                        completion_tokens=0,
-                                        total_tokens=num_prompt_tokens)
-                                    chunk.usage = usage
-                                else:
-                                    chunk.usage = None
+                            if include_continuous_usage:
+                                chunk.usage = UsageInfo(
+                                    prompt_tokens=num_prompt_tokens,
+                                    completion_tokens=0,
+                                    total_tokens=num_prompt_tokens)
 
                             data = chunk.model_dump_json(
                                 exclude_unset=True)
@@ -494,36 +488,11 @@ async def chat_completion_stream_generator(
 
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
-
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=delta_message,
                         logprobs=logprobs,
                         finish_reason=None)
-                    chunk = ChatCompletionStreamResponse(
-                        id=request_id,
-                        object=chunk_object_type,
-                        created=created_time,
-                        choices=[choice_data],
-                        model=model_name)
-
-                    # handle usage stats if requested & if continuous
-                    if (request.stream_options
-                            and request.stream_options.include_usage):
-                        if request.stream_options.continuous_usage_stats:
-                            completion_tokens = len(output.token_ids)
-                            usage = UsageInfo(
-                                prompt_tokens=num_prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=num_prompt_tokens +
-                                completion_tokens,
-                            )
-                            chunk.usage = usage
-                        else:
-                            chunk.usage = None
-
-                    data = chunk.model_dump_json(exclude_unset=True)
-                    yield f"data: {data}\n\n"
 
                 # if the model is finished generating
                 else:
@@ -573,34 +542,32 @@ async def chat_completion_stream_generator(
                         finish_reason=output.finish_reason
                         if not auto_tools_called else "tool_calls",
                         stop_reason=output.stop_reason)
-                    chunk = ChatCompletionStreamResponse(
-                        id=request_id,
-                        object=chunk_object_type,
-                        created=created_time,
-                        choices=[choice_data],
-                        model=model_name)
-                    if (request.stream_options
-                            and request.stream_options.include_usage):
-                        if request.stream_options.continuous_usage_stats:
-                            completion_tokens = len(output.token_ids)
-                            usage = UsageInfo(
-                                prompt_tokens=num_prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=num_prompt_tokens +
-                                completion_tokens,
-                            )
-                            chunk.usage = usage
-                        else:
-                            chunk.usage = None
-                    data = chunk.model_dump_json(exclude_unset=True)
-                    yield f"data: {data}\n\n"
 
                     finish_reason_sent[i] = True
+
+                chunk = ChatCompletionStreamResponse(
+                    id=request_id,
+                    object=chunk_object_type,
+                    created=created_time,
+                    choices=[choice_data],
+                    model=model_name)
+
+                # handle usage stats if requested & if continuous
+                if include_continuous_usage:
+                    completion_tokens = previous_num_tokens[i]
+                    chunk.usage = UsageInfo(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=num_prompt_tokens + completion_tokens,
+                    )
+
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+
             # once the final token is handled, if stream_options.include_usage
             # is sent, send the usage
-            if (request.stream_options
-                    and request.stream_options.include_usage):
-                completion_tokens = previous_num_tokens[i]
+            if include_usage:
+                completion_tokens = sum(previous_num_tokens)
                 final_usage = UsageInfo(
                     prompt_tokens=num_prompt_tokens,
                     completion_tokens=completion_tokens,
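Net effect of the refactor: every per-token chunk is built once and, when continuous stats are enabled, carries cumulative usage; a final choices-less chunk carries the request total whenever include_usage is set. A toy generator sketching that protocol shape (names and token accounting are illustrative, not vLLM internals):

import json


def stream_chunks(tokens, num_prompt_tokens, include_usage,
                  include_continuous_usage):
    """Toy SSE loop: cumulative usage per chunk, then a final total."""
    completion_tokens = 0
    for tok in tokens:
        completion_tokens += 1
        chunk = {"choices": [{"delta": {"content": tok}}]}
        if include_continuous_usage:
            chunk["usage"] = {"prompt_tokens": num_prompt_tokens,
                              "completion_tokens": completion_tokens,
                              "total_tokens": num_prompt_tokens +
                              completion_tokens}
        yield f"data: {json.dumps(chunk)}\n\n"
    if include_usage:
        # Final usage-only chunk: empty choices, totals for the request.
        yield "data: " + json.dumps(
            {"choices": [],
             "usage": {"prompt_tokens": num_prompt_tokens,
                       "completion_tokens": completion_tokens,
                       "total_tokens": num_prompt_tokens +
                       completion_tokens}}) + "\n\n"
    yield "data: [DONE]\n\n"


for line in stream_chunks(["Hi", "!"], num_prompt_tokens=5,
                          include_usage=True, include_continuous_usage=True):
    print(line, end="")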
