Commit 56577b2

jlcmoore authored and jimpang committed

Add LogProbs for Chat Completions in OpenAI (vllm-project#2918)

1 parent aaa1428, commit 56577b2

File tree

3 files changed: +57 −14 lines changed

tests/entrypoints/test_openai_server.py (+13 −12)

@@ -155,15 +155,18 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     }]

     # test single completion
-    chat_completion = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-    )
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                            messages=messages,
+                                                            max_tokens=10,
+                                                            logprobs=True,
+                                                            top_logprobs=10)
     assert chat_completion.id is not None
     assert chat_completion.choices is not None and len(
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.top_logprobs is not None
+    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"

@@ -198,13 +201,11 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     single_output = single_completion.choices[0].text
     single_usage = single_completion.usage

-    stream = await client.completions.create(
-        model=model_name,
-        prompt=prompt,
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True)
     chunks = []
     async for chunk in stream:
         chunks.append(chunk.choices[0].text)
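
For reference, here is roughly what the updated test exercises from a client's point of view. This is a sketch, not part of the commit: it assumes a vLLM OpenAI-compatible server is already running at http://localhost:8000/v1, and the model name is a placeholder.

import openai

# Placeholder endpoint and model; adjust to your own deployment.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="your-model-name",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=10,
    logprobs=True,     # new: ask the server to return log probabilities
    top_logprobs=10,   # new: number of alternatives per generated token
)

choice = chat_completion.choices[0]
print(choice.message.content)
# At this commit the server fills `logprobs` using the same LogProbs schema
# as the Completions endpoint (tokens / token_logprobs / top_logprobs lists).
print(choice.logprobs)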

vllm/entrypoints/openai/protocol.py (+8 −0)

@@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None

@@ -84,6 +86,8 @@ class ChatCompletionRequest(BaseModel):
     length_penalty: Optional[float] = 1.0

     def to_sampling_params(self) -> SamplingParams:
+        if self.logprobs and not self.top_logprobs:
+            raise ValueError("Top logprobs must be set when logprobs is.")
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,

@@ -96,6 +100,8 @@ def to_sampling_params(self) -> SamplingParams:
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            logprobs=self.top_logprobs if self.logprobs else None,
+            prompt_logprobs=self.top_logprobs if self.echo else None,
             best_of=self.best_of,
             top_k=self.top_k,
             ignore_eos=self.ignore_eos,

@@ -216,6 +222,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None

@@ -236,6 +243,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
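
To see how the new request fields are meant to interact, here is a small self-contained sketch of the validation and mapping added above. MiniChatRequest is a stripped-down stand-in for ChatCompletionRequest, not the real class, and the returned dict only mimics the two SamplingParams arguments touched by this commit.

from typing import Optional

from pydantic import BaseModel


class MiniChatRequest(BaseModel):
    """Stripped-down stand-in for ChatCompletionRequest (illustration only)."""
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    echo: Optional[bool] = False

    def to_sampling_params(self) -> dict:
        # Mirrors the check added in this commit: requesting logprobs
        # without top_logprobs is rejected before sampling params are built.
        if self.logprobs and not self.top_logprobs:
            raise ValueError("Top logprobs must be set when logprobs is.")
        return {
            "logprobs": self.top_logprobs if self.logprobs else None,
            "prompt_logprobs": self.top_logprobs if self.echo else None,
        }


print(MiniChatRequest(logprobs=True, top_logprobs=5).to_sampling_params())
# -> {'logprobs': 5, 'prompt_logprobs': None}

try:
    MiniChatRequest(logprobs=True).to_sampling_params()
except ValueError as err:
    print(err)  # Top logprobs must be set when logprobs is.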

vllm/entrypoints/openai/serving_chat.py (+36 −2)

@@ -101,7 +101,10 @@ async def chat_completion_stream_generator(
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i,
+                delta=DeltaMessage(role=role),
+                logprobs=None,
+                finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,

@@ -118,6 +121,7 @@ async def chat_completion_stream_generator(
                     "content") and request.messages[-1].get(
                         "role") == role:
                 last_msg_content = request.messages[-1]["content"]
+
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(

@@ -129,6 +133,7 @@ async def chat_completion_stream_generator(
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                         model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"

@@ -145,15 +150,29 @@ async def chat_completion_stream_generator(
                 if finish_reason_sent[i]:
                     continue

+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+
+                if request.logprobs:
+                    logprobs = self._create_logprobs(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )
+                else:
+                    logprobs = None
+
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
-
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,

@@ -174,6 +193,7 @@ async def chat_completion_stream_generator(
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,

@@ -208,11 +228,25 @@ async def chat_completion_full_generator(
         assert final_res is not None

         choices = []
+
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            token_ids = output.token_ids
+            top_logprobs = output.logprobs
+
+            if request.logprobs:
+                logprobs = self._create_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
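
The streaming path above only forwards the tokens generated since the previous chunk to _create_logprobs. A rough standalone sketch of that bookkeeping, with a hypothetical helper name and toy data (not code from the commit):

from typing import Dict, List, Optional, Tuple


def slice_new_logprobs(
    token_ids: List[int],
    logprobs: Optional[List[Dict[int, float]]],
    previous_num_tokens: int,
) -> Tuple[List[int], Optional[List[Dict[int, float]]]]:
    """Mimic the per-chunk slicing in chat_completion_stream_generator:
    keep only the tokens (and their top-logprob dicts) produced since the
    last streamed chunk."""
    delta_token_ids = token_ids[previous_num_tokens:]
    top_logprobs = logprobs[previous_num_tokens:] if logprobs else None
    return delta_token_ids, top_logprobs


# Toy example: three tokens exist, two were already streamed, so only the
# last token id and its top-logprob dict would be passed to _create_logprobs.
ids, lps = slice_new_logprobs(
    token_ids=[11, 22, 33],
    logprobs=[{11: -0.1}, {22: -0.5}, {33: -0.2}],
    previous_num_tokens=2,
)
print(ids, lps)  # [33] [{33: -0.2}]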
