diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 1e08cd9712bc..56e35950410a 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -258,6 +258,14 @@ async def completion_stream_generator(
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
 
+        stream_options = request.stream_options
+        if stream_options:
+            include_usage = stream_options.include_usage
+            include_continuous_usage = include_usage and \
+                                       stream_options.continuous_usage_stats
+        else:
+            include_usage, include_continuous_usage = False, False
+
         try:
             async for prompt_idx, res in result_generator:
                 prompt_token_ids = res.prompt_token_ids
@@ -276,28 +284,25 @@ async def completion_stream_generator(
                     i = output.index + prompt_idx * num_choices
 
                     assert request.max_tokens is not None
-                    if request.echo and request.max_tokens == 0:
+                    if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
                         assert prompt_text is not None
-                        # only return the prompt
-                        delta_text = prompt_text
-                        delta_token_ids = prompt_token_ids
-                        out_logprobs = prompt_logprobs
-                        has_echoed[i] = True
-                    elif (request.echo and request.max_tokens > 0
-                          and not has_echoed[i]):
-                        assert prompt_token_ids is not None
-                        assert prompt_text is not None
-                        assert prompt_logprobs is not None
-                        # echo the prompt and first token
-                        delta_text = prompt_text + output.text
-                        delta_token_ids = [
-                            *prompt_token_ids, *output.token_ids
-                        ]
-                        out_logprobs = [
-                            *prompt_logprobs,
-                            *(output.logprobs or []),
-                        ]
+                        if request.max_tokens == 0:
+                            # only return the prompt
+                            delta_text = prompt_text
+                            delta_token_ids = prompt_token_ids
+                            out_logprobs = prompt_logprobs
+                        else:
+                            assert prompt_logprobs is not None
+                            # echo the prompt and first token
+                            delta_text = prompt_text + output.text
+                            delta_token_ids = [
+                                *prompt_token_ids, *output.token_ids
+                            ]
+                            out_logprobs = [
+                                *prompt_logprobs,
+                                *(output.logprobs or []),
+                            ]
                         has_echoed[i] = True
                     else:
                         # return just the delta
@@ -341,45 +346,39 @@ async def completion_stream_generator(
                                 stop_reason=stop_reason,
                             )
                         ])
-                    if (request.stream_options
-                            and request.stream_options.include_usage):
-                        if (request.stream_options.continuous_usage_stats
-                                or output.finish_reason is not None):
-                            prompt_tokens = num_prompt_tokens[prompt_idx]
-                            completion_tokens = previous_num_tokens[i]
-                            usage = UsageInfo(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens + completion_tokens,
-                            )
-                        if request.stream_options.continuous_usage_stats:
-                            chunk.usage = usage
-                        else:
-                            chunk.usage = None
+                    if include_continuous_usage:
+                        prompt_tokens = num_prompt_tokens[prompt_idx]
+                        completion_tokens = previous_num_tokens[i]
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        )
 
                     response_json = chunk.model_dump_json(exclude_unset=False)
                     yield f"data: {response_json}\n\n"
 
-            if (request.stream_options
-                    and request.stream_options.include_usage):
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
+            if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
                     model=model_name,
                     choices=[],
-                    usage=usage,
+                    usage=final_usage_info,
                 )
                 final_usage_data = (final_usage_chunk.model_dump_json(
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
             # report to FastAPI middleware aggregate usage across all choices
-            total_prompt_tokens = sum(num_prompt_tokens)
-            total_completion_tokens = sum(previous_num_tokens)
-            request_metadata.final_usage_info = UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+            request_metadata.final_usage_info = final_usage_info
 
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -413,26 +412,26 @@ def request_output_to_completion_response(
 
         for output in final_res.outputs:
             assert request.max_tokens is not None
-            if request.echo and request.max_tokens == 0:
-                assert prompt_text is not None
-                token_ids = prompt_token_ids
-                out_logprobs = prompt_logprobs
-                output_text = prompt_text
-            elif request.echo and request.max_tokens > 0:
+            if request.echo:
                 assert prompt_text is not None
-                token_ids = [*prompt_token_ids, *output.token_ids]
-
-                if request.logprobs is None:
-                    out_logprobs = None
+                if request.max_tokens == 0:
+                    token_ids = prompt_token_ids
+                    out_logprobs = prompt_logprobs
+                    output_text = prompt_text
                 else:
-                    assert prompt_logprobs is not None
-                    assert output.logprobs is not None
-                    out_logprobs = [
-                        *prompt_logprobs,
-                        *output.logprobs,
-                    ]
-
-                output_text = prompt_text + output.text
+                    token_ids = [*prompt_token_ids, *output.token_ids]
+
+                    if request.logprobs is None:
+                        out_logprobs = None
+                    else:
+                        assert prompt_logprobs is not None
+                        assert output.logprobs is not None
+                        out_logprobs = [
+                            *prompt_logprobs,
+                            *output.logprobs,
+                        ]
+
+                    output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
                 out_logprobs = output.logprobs