Commit e0ade06

Support logit bias for OpenAI API (#3027)
1 parent 4bd18ec commit e0ade06
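
In practice, this change lets an OpenAI-compatible client pass logit_bias straight through to a vLLM server instead of receiving an error. A minimal usage sketch, assuming a local vLLM OpenAI-compatible server; the base_url, api_key, and token id 1000 are illustrative placeholders, not part of the commit:

# Hypothetical client-side usage of the new logit_bias support.
import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # assumed vLLM server address
    api_key="EMPTY",                      # placeholder key for illustration
)

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    # Keys are token ids as strings; values are clamped to [-100, 100] server-side.
    logit_bias={"1000": 100},
)
print(completion.choices[0].text)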

File tree

4 files changed: +83 −12 lines changed

tests/entrypoints/test_openai_server.py (+48)

@@ -9,6 +9,8 @@
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests

+from vllm.transformers_utils.tokenizer import get_tokenizer
+
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
@@ -310,5 +312,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
     assert texts[0] == texts[1]


+async def test_logits_bias(server, client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 5
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    token_id = 1000
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token_id): 100},
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
+                                add_special_tokens=False)["input_ids"]
+    assert all([
+        response == expected
+        for response, expected in zip(response_tokens, expected_tokens)
+    ])
+
+    # Test ban
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+    )
+    response_tokens = tokenizer(completion.choices[0].text,
+                                add_special_tokens=False)["input_ids"]
+    first_response = completion.choices[0].text
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        logit_bias={str(token): -100
+                    for token in response_tokens},
+    )
+    assert first_response != completion.choices[0].text
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

vllm/entrypoints/openai/protocol.py (+33)

@@ -8,6 +8,8 @@
 from vllm.utils import random_uuid
 from vllm.sampling_params import SamplingParams

+import torch
+

 class ErrorResponse(BaseModel):
     object: str = "error"
@@ -88,6 +90,21 @@ class ChatCompletionRequest(BaseModel):
     def to_sampling_params(self) -> SamplingParams:
         if self.logprobs and not self.top_logprobs:
             raise ValueError("Top logprobs must be set when logprobs is.")
+
+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
@@ -111,6 +128,7 @@ def to_sampling_params(self) -> SamplingParams:
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )


@@ -149,6 +167,20 @@ class CompletionRequest(BaseModel):
     def to_sampling_params(self):
         echo_without_generation = self.echo and self.max_tokens == 0

+        logits_processors = None
+        if self.logit_bias:
+
+            def logit_bias_logits_processor(
+                    token_ids: List[int],
+                    logits: torch.Tensor) -> torch.Tensor:
+                for token_id, bias in self.logit_bias.items():
+                    # Clamp the bias between -100 and 100 per OpenAI API spec
+                    bias = min(100, max(-100, bias))
+                    logits[int(token_id)] += bias
+                return logits
+
+            logits_processors = [logit_bias_logits_processor]
+
         return SamplingParams(
             n=self.n,
             best_of=self.best_of,
@@ -172,6 +204,7 @@ def to_sampling_params(self):
             spaces_between_special_tokens=(self.spaces_between_special_tokens),
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
+            logits_processors=logits_processors,
         )

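The closure built in protocol.py is an ordinary vLLM logits processor, so the same mechanism can be exercised directly through SamplingParams outside the OpenAI server. A rough sketch under that assumption; the model name, bias values, and token ids are placeholders, not part of the commit:

# Illustrative offline sketch of the logits-processor hook this commit relies on.
from typing import List

import torch
from vllm import LLM, SamplingParams

logit_bias = {1000: 100, 2000: -100}  # example biases, not from the commit

def logit_bias_logits_processor(token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    # Mirrors the closure in protocol.py: clamp each bias to [-100, 100]
    # and add it to that token's logit before every sampling step.
    for token_id, bias in logit_bias.items():
        logits[int(token_id)] += min(100, max(-100, bias))
    return logits

params = SamplingParams(temperature=0.0,
                        max_tokens=5,
                        logits_processors=[logit_bias_logits_processor])
llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")  # any supported model
print(llm.generate(["Hello, my name is"], params)[0].outputs[0].text)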

vllm/entrypoints/openai/serving_chat.py (+1 −7)

@@ -39,19 +39,13 @@ async def create_chat_completion(
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.

-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret

-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
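
With the chat-side guard removed, logit_bias can also be sent on the chat completions endpoint. A hedged sketch mirroring the completions example near the top of this page; the server address, model, and token id are again placeholders:

# Hypothetical chat usage: bias token 1000 upward in the assistant reply.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Hello, my name is"}],
    max_tokens=5,
    temperature=0.0,
    logit_bias={"1000": 100},  # token ids as string keys, bias in [-100, 100]
)
print(chat.choices[0].message.content)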

vllm/entrypoints/openai/serving_completion.py (+1 −5)

@@ -264,10 +264,9 @@ async def create_completion(self, request: CompletionRequest,
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.

-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - suffix (the language models we currently support do not support
               suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -277,9 +276,6 @@ async def create_completion(self, request: CompletionRequest,
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")

         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
