diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 0fbc4cca83bd2..3af0032fd2fb0 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -433,18 +433,28 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, model=model_name, messages=messages, max_tokens=10, + extra_body=dict(min_tokens=10), temperature=0.0, stream=True, stream_options={ "include_usage": True, - "continuous_usage_stats": True + "continuous_usage_stats": True, }, ) + last_completion_tokens = 0 async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 - assert chunk.usage.completion_tokens >= 0 + assert last_completion_tokens == 0 or \ + chunk.usage.completion_tokens > last_completion_tokens or \ + ( + not chunk.choices and + chunk.usage.completion_tokens == last_completion_tokens + ) assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + chunk.usage.completion_tokens) + last_completion_tokens = chunk.usage.completion_tokens + + assert last_completion_tokens == 10 # NOTE: Not sure why, but when I place this after `test_guided_regex_chat` diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9470b6ea03ef6..acb56e4a886e1 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -330,6 +330,14 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" return + stream_options = request.stream_options + if stream_options: + include_usage = stream_options.include_usage + include_continuous_usage = include_usage and \ + stream_options.continuous_usage_stats + else: + include_usage, include_continuous_usage = False, False + try: async for res in result_generator: if res.prompt_token_ids is not None: @@ -348,7 +356,6 @@ async def chat_completion_stream_generator( # NOTE num_choices defaults to 1 so this usually executes # once per request for i in range(num_choices): - tool_parser = tool_parsers[i] choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage( @@ -364,19 +371,12 @@ async def chat_completion_stream_generator( choices=[choice_data], model=model_name) - # if usage should be included - if (request.stream_options - and request.stream_options.include_usage): - # if continuous usage stats are requested, add it - if request.stream_options.continuous_usage_stats: - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=0, - total_tokens=num_prompt_tokens) - chunk.usage = usage - # otherwise don't - else: - chunk.usage = None + # if continuous usage stats are requested, add it + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" @@ -404,17 +404,11 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) - if (request.stream_options and - request.stream_options.include_usage): - if (request.stream_options. - continuous_usage_stats): - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=0, - total_tokens=num_prompt_tokens) - chunk.usage = usage - else: - chunk.usage = None + if include_continuous_usage: + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) data = chunk.model_dump_json( exclude_unset=True) @@ -494,36 +488,11 @@ async def chat_completion_stream_generator( if output.finish_reason is None: # Send token-by-token response for each request.n - choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - - # handle usage stats if requested & if continuous - if (request.stream_options - and request.stream_options.include_usage): - if request.stream_options.continuous_usage_stats: - completion_tokens = len(output.token_ids) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens, - ) - chunk.usage = usage - else: - chunk.usage = None - - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" # if the model is finished generating else: @@ -573,34 +542,32 @@ async def chat_completion_stream_generator( finish_reason=output.finish_reason if not auto_tools_called else "tool_calls", stop_reason=output.stop_reason) - chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) - if (request.stream_options - and request.stream_options.include_usage): - if request.stream_options.continuous_usage_stats: - completion_tokens = len(output.token_ids) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens, - ) - chunk.usage = usage - else: - chunk.usage = None - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" + finish_reason_sent[i] = True + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if include_continuous_usage: + completion_tokens = previous_num_tokens[i] + chunk.usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + # once the final token is handled, if stream_options.include_usage # is sent, send the usage - if (request.stream_options - and request.stream_options.include_usage): - completion_tokens = previous_num_tokens[i] + if include_usage: + completion_tokens = sum(previous_num_tokens) final_usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens,