Add LogProbs for Chat Completions in OpenAI #2918

Merged (18 commits) on Feb 26, 2024
6 changes: 6 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel):
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@@ -96,6 +98,8 @@ def to_sampling_params(self) -> SamplingParams:
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
@@ -216,6 +220,7 @@ class ChatMessage(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


@@ -236,6 +241,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None


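For context, a minimal client-side sketch (not part of this diff) of how the new logprobs and top_logprobs request fields are exercised once this change is in place. The model name and base_url are illustrative assumptions; it assumes a running vLLM OpenAI-compatible server and the openai Python client.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="meta-llama/Llama-2-7b-chat-hf",  # illustrative model name
        messages=[{"role": "user", "content": "Hello!"}],
        logprobs=True,   # maps to ChatCompletionRequest.logprobs
        top_logprobs=5,  # forwarded to SamplingParams.logprobs in to_sampling_params
    )
    # With this change, ChatCompletionResponseChoice.logprobs is populated in the response.
    print(resp.choices[0].logprobs)
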
68 changes: 57 additions & 11 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,20 +1,25 @@
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union, Dict, Callable
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
LogProbs, UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

logger = init_logger(__name__)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]


class OpenAIServingChat(OpenAIServing):

@@ -77,10 +82,11 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
request, result_generator, request_id, self._create_logprobs)
else:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
request, raw_request, result_generator, request_id,
self._create_logprobs)
Collaborator: Revert this since we can directly call that method.
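
A self-contained toy sketch of that suggestion (class and method names are illustrative, not vLLM's actual ones): a helper defined on the base class is reachable through self, so threading it through as a callable argument only adds indirection.

    from typing import Callable, Dict, List

    class ServingBase:
        def _create_logprobs_stub(self, token_ids: List[int]) -> Dict[str, List[int]]:
            # Stand-in for the inherited OpenAIServing._create_logprobs helper.
            return {"tokens": token_ids}

    class ChatServingWithCallable(ServingBase):
        # Mirrors the approach in this diff: the helper is passed in as an argument.
        def generate(self, token_ids: List[int],
                     create_logprobs_fn: Callable[[List[int]], Dict[str, List[int]]]):
            return create_logprobs_fn(token_ids)

    class ChatServingDirect(ServingBase):
        # Mirrors the reviewer's suggestion: call the inherited helper directly.
        def generate(self, token_ids: List[int]):
            return self._create_logprobs_stub(token_ids)

    with_callable = ChatServingWithCallable()
    assert with_callable.generate([1, 2], with_callable._create_logprobs_stub) == {"tokens": [1, 2]}
    assert ChatServingDirect().generate([1, 2]) == {"tokens": [1, 2]}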


def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@@ -89,8 +95,9 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
return request.messages[-1].role

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str,
create_logprobs_fn: TypeCreateLogProbsFn
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:

model_name = request.model
@@ -101,7 +108,10 @@ async def chat_completion_stream_generator(
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
index=i,
delta=DeltaMessage(role=role),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
@@ -118,6 +128,7 @@ async def chat_completion_stream_generator(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]

if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +140,7 @@ async def chat_completion_stream_generator(
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
@@ -145,15 +157,31 @@ async def chat_completion_stream_generator(
if finish_reason_sent[i]:
continue

delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs is not None:
Collaborator (suggested change): if request.logprobs:

assert(top_logprobs is not None),\
"top_logprobs must be provided when logprobs is requested"
logprobs = create_logprobs_fn(
Collaborator: Why not invoke self._create_logprobs directly?

Contributor (author): This is what serving_completion.py did. I can only speculate on the design decision, but perhaps it was to allow for extensibility with other log-probability formats in the future, not just the OpenAI one.

Contributor (author): I am happy to change it if you think that is best.

Collaborator: Because this function is already defined in the class, there's no need to pass it in as a callable here. We can delete this and the related plumbing and invoke the parent method _create_logprobs directly.

token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None

delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
Contributor (author): I just moved this line down.


if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -174,6 +202,7 @@ async def chat_completion_stream_generator(
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
@@ -191,9 +220,10 @@ async def chat_completion_stream_generator(
yield "data: [DONE]\n\n"
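
Likewise, a hedged client-side sketch (not part of this diff) of consuming the streamed logprobs this generator now emits; same assumptions as the non-streaming example above, and it assumes the streaming chunk choices expose a logprobs field when requested.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = client.chat.completions.create(
        model="meta-llama/Llama-2-7b-chat-hf",  # illustrative model name
        messages=[{"role": "user", "content": "Hi"}],
        logprobs=True,
        top_logprobs=2,
        stream=True,
    )
    for chunk in stream:
        if not chunk.choices:
            continue
        choice = chunk.choices[0]
        # Content chunks now carry a per-token logprobs entry alongside the delta text.
        if choice.delta.content:
            print(choice.delta.content, choice.logprobs)
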

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput], request_id: str,
create_logprobs_fn: TypeCreateLogProbsFn
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = request.model
created_time = int(time.monotonic())
@@ -208,11 +238,27 @@ async def chat_completion_full_generator(
assert final_res is not None

choices = []

role = self.get_chat_request_role(request)
for output in final_res.outputs:
token_ids = output.token_ids
top_logprobs = output.logprobs

if request.logprobs is not None:
Collaborator: ditto.

assert(top_logprobs is not None),\
Collaborator (@esmeetu, Feb 23, 2024): It's better to move this assertion into protocol.py using a validator or something similar, and return 400 when this condition isn't met. (A sketch of such a validator follows at the end of this diff.)

"top_logprobs must be provided when logprobs is requested"
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None

choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
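
Following up on the validator suggestion above, a hedged sketch of what that check could look like in protocol.py (assuming pydantic v2; the field names match this diff, while the class name and validator are illustrative). Moving the check to request parsing lets the API layer reject a malformed request with a 4xx instead of tripping an assertion inside the generator.

    from typing import Optional
    from pydantic import BaseModel, model_validator

    class ChatCompletionRequestSketch(BaseModel):
        logprobs: Optional[bool] = False
        top_logprobs: Optional[int] = None

        @model_validator(mode="after")
        def check_logprobs(self) -> "ChatCompletionRequestSketch":
            # Reject logprobs requests that omit the number of top logprobs to return.
            if self.logprobs and self.top_logprobs is None:
                raise ValueError(
                    "top_logprobs must be provided when logprobs is requested")
            return self

    # ChatCompletionRequestSketch(logprobs=True)                  # raises ValidationError
    # ChatCompletionRequestSketch(logprobs=True, top_logprobs=5)  # accepted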