[Feature][Frontend]: Add support for stream_options in `ChatCompletionRequest` (vllm-project#5135)
Etelis authored and jimpang committed Jul 8, 2024
1 parent a23a014 commit 6de6718
Showing 3 changed files with 149 additions and 10 deletions.
101 changes: 101 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -1343,5 +1343,106 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 17


@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_stream_options(server, client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is the capital of France?"

# Test stream=True, stream_options=None
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options=None,
)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
assert len(chunks) > 0
assert "usage" not in chunk

# Test stream=True, stream_options={"include_usage": False}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": False},
)
chunks = []
async for chunk in stream:
chunks.append(chunk.choices[0].text)
assert len(chunks) > 0
assert "usage" not in chunk

# Test stream=True, stream_options={"include_usage": True}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": True},
)
chunks = []
finish_reason_count = 0
async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
chunks.append(chunk.choices[0].text)
else:
assert chunk.usage is None
finish_reason_count += 1

# The last message should have usage and no choices
last_message = await stream.__anext__()
assert last_message.usage is not None
assert last_message.usage.prompt_tokens > 0
assert last_message.usage.completion_tokens > 0
assert last_message.usage.total_tokens == (
last_message.usage.prompt_tokens +
last_message.usage.completion_tokens)
assert last_message.choices == []

# Test stream=False, stream_options={"include_usage": None}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": None},
)

# Test stream=False, stream_options={"include_usage": False}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": False},
)

# Test stream=False, stream_options={"include_usage": True}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": True},
)


if __name__ == "__main__":
pytest.main([__file__])
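
For context, here is a minimal client-side sketch of the behavior these tests exercise, written against the chat completions endpoint that this PR targets. It assumes a vLLM OpenAI-compatible server is already running; the base URL, API key, and model name are placeholders, not values from the diff.

```python
# Hypothetical usage of stream_options against a running vLLM
# OpenAI-compatible server; base_url, api_key, and model are placeholders.
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")

    stream = await client.chat.completions.create(
        model="some-model",
        messages=[{
            "role": "user",
            "content": "What is the capital of France?"
        }],
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": True},
    )

    async for chunk in stream:
        if chunk.choices:
            # Content chunks: chunk.usage is None here.
            print(chunk.choices[0].delta.content or "", end="")
        else:
            # Final chunk: empty choices, populated usage.
            print("\nusage:", chunk.usage)


asyncio.run(main())
```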
14 changes: 14 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -102,6 +102,10 @@ class ResponseFormat(OpenAIBaseModel):
type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool]


class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
@@ -140,6 +144,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
@@ -269,6 +274,15 @@ def logit_bias_logits_processor(
logits_processors=logits_processors,
)

@model_validator(mode='before')
@classmethod
def validate_stream_options(cls, values):
if (values.get('stream_options') is not None
and not values.get('stream')):
raise ValueError(
"stream_options can only be set if stream is true")
return values

@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
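
A small sketch (not part of the diff) of how the new `validate_stream_options` validator behaves, assuming `ChatCompletionRequest` is importable from `vllm.entrypoints.openai.protocol` as shown above:

```python
# Illustrative only: exercising the new validate_stream_options validator.
from pydantic import ValidationError

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

messages = [{"role": "user", "content": "Hello"}]

# Accepted: stream_options together with stream=True.
ChatCompletionRequest(model="some-model",
                      messages=messages,
                      stream=True,
                      stream_options={"include_usage": True})

# Rejected: stream_options without stream=True fails validation.
try:
    ChatCompletionRequest(model="some-model",
                          messages=messages,
                          stream=False,
                          stream_options={"include_usage": True})
except ValidationError as exc:
    print(exc)  # "stream_options can only be set if stream is true"
```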
44 changes: 34 additions & 10 deletions vllm/entrypoints/openai/serving_chat.py
@@ -247,6 +247,9 @@ async def chat_completion_stream_generator(
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"

@@ -274,6 +277,9 @@ async def chat_completion_stream_generator(
choices=[choice_data],
logprobs=None,
model=model_name)
if (request.stream_options and
request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(
exclude_unset=True)
yield f"data: {data}\n\n"
@@ -327,17 +333,14 @@ async def chat_completion_stream_generator(
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
@@ -350,12 +353,33 @@ async def chat_completion_stream_generator(
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True)
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True

if (request.stream_options
and request.stream_options.include_usage):
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens +
previous_num_tokens[i],
)

final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"

except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
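
A rough sketch of the usage-only final chunk this branch emits, built with the protocol classes used in the diff above; the id, model name, and token counts are placeholders, not values from the commit.

```python
# Illustrative only: constructing the usage-only final chunk the way the
# generator above does; all concrete values here are made up.
import time

from vllm.entrypoints.openai.protocol import (ChatCompletionStreamResponse,
                                              UsageInfo)

final_usage = UsageInfo(prompt_tokens=7,
                        completion_tokens=5,
                        total_tokens=12)

final_usage_chunk = ChatCompletionStreamResponse(
    id="chatcmpl-example",
    object="chat.completion.chunk",
    created=int(time.time()),
    choices=[],  # the usage-only chunk carries no choices
    model="some-model",
    usage=final_usage)

final_usage_data = final_usage_chunk.model_dump_json(exclude_unset=True,
                                                     exclude_none=True)
print(f"data: {final_usage_data}\n\n")
```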
