
Support logit bias for OpenAI API #3027

Merged · 8 commits · Feb 27, 2024 · Changes from 5 commits
48 changes: 48 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -8,6 +8,8 @@
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
import openai # use the official client for correctness check

from vllm.transformers_utils.tokenizer import get_tokenizer

MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here

@@ -250,5 +252,51 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
    assert texts[0] == texts[1]


async def test_logits_bias(server, client: openai.AsyncOpenAI):
    prompt = "Hello, my name is"
    max_tokens = 5
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    token_id = 1000
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
                                add_special_tokens=False)["input_ids"]
    assert all([
        response == expected
        for response, expected in zip(response_tokens, expected_tokens)
    ])

    # Test ban
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
    )
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    first_response = completion.choices[0].text
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token): -100
                    for token in response_tokens},
    )
    assert first_response != completion.choices[0].text


if __name__ == "__main__":
    pytest.main([__file__])
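
The ban case above suggests the general client-side recipe: tokenize the text you want to suppress, then send every resulting token id with a -100 bias. Below is a minimal sketch of that pattern against a running vLLM OpenAI-compatible server; the base URL, API key, prompt, and banned word are illustrative assumptions, not part of this PR.

import openai
from transformers import AutoTokenizer

# Assumed deployment details -- adjust to your server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Ban every sub-token of a word we never want to see in the output.
banned_ids = tokenizer(" Paris", add_special_tokens=False)["input_ids"]
completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="The capital of France is",
    max_tokens=5,
    temperature=0.0,
    logit_bias={str(tid): -100 for tid in banned_ids},
)
print(completion.choices[0].text)

Because logit_bias operates on token ids rather than strings, a multi-token word needs each of its sub-tokens banned, which is why the test above biases every token of the first response.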
18 changes: 18 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -8,6 +8,8 @@
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams

import torch


class ErrorResponse(BaseModel):
    object: str = "error"
@@ -137,6 +139,21 @@ class CompletionRequest(BaseModel):
    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = None

        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
@@ -158,6 +175,7 @@ def to_sampling_params(self):
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )
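
The closure above is attached to SamplingParams and invoked by the engine at each decoding step, after the model produces the logits for the next token and before sampling. A self-contained sketch of the same arithmetic (vocabulary size and token ids are made-up values for illustration) shows why a +100 bias pins greedy decoding to one token:

import torch
from typing import Dict

def apply_logit_bias(logits: torch.Tensor,
                     logit_bias: Dict[str, float]) -> torch.Tensor:
    # Same arithmetic as the processor above: clamp to [-100, 100], add in place.
    for token_id, bias in logit_bias.items():
        logits[int(token_id)] += min(100.0, max(-100.0, bias))
    return logits

vocab_size = 32000  # made-up vocabulary size
logits = torch.randn(vocab_size)  # stand-in for one decoding step's output
biased = apply_logit_bias(logits.clone(), {"1000": 100, "42": -100})

# At temperature 0.0 (greedy), the boosted token wins the argmax --
# which is what test_logits_bias asserts for five consecutive steps.
assert int(torch.argmax(biased)) == 1000

Clamping to ±100 mirrors the OpenAI API contract, where -100 effectively bans a token and +100 effectively forces exclusive selection.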


8 changes: 1 addition & 7 deletions vllm/entrypoints/openai/serving_chat.py
@@ -36,19 +36,13 @@ async def create_chat_completion(
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
-
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
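
With the guard removed, chat requests carrying logit_bias are no longer rejected. A hedged usage sketch, assuming ChatCompletionRequest.to_sampling_params gains the same wiring shown above for CompletionRequest (this five-commit view only shows the completion side); the base URL, API key, and token id are illustrative:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "Pick a number from 1 to 10."}],
    max_tokens=10,
    temperature=0.0,
    logit_bias={"1000": 100},  # hypothetical token id, for illustration only
)
print(chat.choices[0].message.content)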
6 changes: 1 addition & 5 deletions vllm/entrypoints/openai/serving_completion.py
@@ -259,10 +259,9 @@ async def create_completion(self, request: CompletionRequest,
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
 
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - suffix (the language models we currently support do not support
             suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -272,9 +271,6 @@ async def create_completion(self, request: CompletionRequest,
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")
 
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"