Merge pull request #117 from Xaenalt/bugfix/SamplingParams
[Cherry-Pick] [Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (vllm-project#6954)
Xaenalt authored Aug 2, 2024
2 parents 7a21f52 + e2652b2 commit 69f02b7
Showing 7 changed files with 201 additions and 85 deletions.
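
For readers skimming the diff: the user-visible effect is that a request which omits max_tokens no longer reaches the engine with SamplingParams.max_tokens left as None; the server now derives an explicit default from the model's context window and the prompt length. A minimal client-side sketch (the base_url, api_key, and model name are illustrative and assume a local vLLM OpenAI-compatible server running this branch):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="facebook/opt-125m",  # illustrative served model name
    prompt="Hello, my name is",
    # max_tokens deliberately omitted: with this change the server fills in
    # SamplingParams.max_tokens = max_model_len - len(prompt_token_ids)
    # instead of leaving it unset.
)
print(completion.choices[0].text)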
22 changes: 22 additions & 0 deletions tests/entrypoints/openai/test_completion.py
@@ -537,6 +537,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text


@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 1
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

# Test exclusive selection
allowed_ids = [21555, 21557, 21558]
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
seed=42,
extra_body=dict(allowed_token_ids=allowed_ids),
logprobs=1,
)
response_tokens = completion.choices[0].logprobs.tokens
assert len(response_tokens) == 1
assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
39 changes: 39 additions & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -1,7 +1,12 @@
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
def test_async_serving_chat_init():
serving_completion = asyncio.run(_async_serving_chat_init())
assert serving_completion.chat_template == CHAT_TEMPLATE


def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLMEngine)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
served_model_names=[MODEL_NAME],
response_role="assistant",
chat_template=CHAT_TEMPLATE,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
guided_decoding_backend="outlines",
)

with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

# AsyncLLMEngine.generate(inputs, sampling_params, ...)
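# 93 here is presumably the mock model's max_model_len (set on MockModelConfig
# earlier in this file, outside this hunk) minus the tokenized prompt length,
# i.e. the derived default when the request omits max_tokens.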
assert mock_engine.generate.call_args.args[1].max_tokens == 93

req.max_tokens = 10
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))

assert mock_engine.generate.call_args.args[1].max_tokens == 10
74 changes: 74 additions & 0 deletions vllm/entrypoints/openai/logits_processors.py
@@ -0,0 +1,74 @@
from functools import lru_cache
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer

from vllm.sampling_params import LogitsProcessor


class AllowedTokenIdsLogitsProcessor:
"""Logits processor for constraining generated tokens to a
specific set of token ids."""

def __init__(self, allowed_ids: Iterable[int]):
self.allowed_ids: Optional[List[int]] = list(allowed_ids)
self.mask: Optional[torch.Tensor] = None

def __call__(self, token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
if self.mask is None:
self.mask = torch.ones((logits.shape[-1], ),
dtype=torch.bool,
device=logits.device)
self.mask[self.allowed_ids] = False
self.allowed_ids = None
logits.masked_fill_(self.mask, float("-inf"))
return logits


@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
allowed_token_ids: FrozenSet[int],
vocab_size: int,
) -> LogitsProcessor:
if not allowed_token_ids:
raise ValueError("Empty allowed_token_ids provided")
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
raise ValueError("allowed_token_ids contains "
"out-of-vocab token id")
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)


def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
logits_processors = []
if logit_bias:
try:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
clamped_logit_bias: Dict[int, float] = {
int(token_id): min(100.0, max(-100.0, bias))
for token_id, bias in logit_bias.items()
}
except ValueError as exc:
raise ValueError(
"Found token_id in logit_bias that is not "
"an integer or string representing an integer") from exc

def logit_bias_logits_processor(token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in clamped_logit_bias.items():
logits[token_id] += bias
return logits

logits_processors.append(logit_bias_logits_processor)

if allowed_token_ids is not None:
logits_processors.append(
_get_allowed_token_ids_logits_processor(
frozenset(allowed_token_ids), tokenizer.vocab_size))

return logits_processors
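
For context, a quick sketch of what the new processor does to a logits row; it assumes this branch of vLLM and torch are importable, and the vocabulary size and allowed ids below are illustrative:

import torch

from vllm.entrypoints.openai.logits_processors import (
    AllowedTokenIdsLogitsProcessor)

# Illustrative values: a 10-token "vocabulary" and two allowed ids.
processor = AllowedTokenIdsLogitsProcessor(allowed_ids=[2, 5])
logits = torch.zeros(10)

out = processor(token_ids=[], logits=logits)
# Every position except 2 and 5 is now -inf, so sampling can only pick
# one of the allowed ids.
assert torch.isinf(out).sum().item() == 8
assert torch.isfinite(out[[2, 5]]).all()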
84 changes: 36 additions & 48 deletions vllm/entrypoints/openai/protocol.py
@@ -5,11 +5,13 @@

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid


@@ -213,30 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel):

# doc: end-chat-completion-extra-params

def to_sampling_params(self) -> SamplingParams:
# We now allow logprobs being true without top_logprobs.
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

logits_processors = None
if self.logit_bias:
logit_bias: Dict[int, float] = {}
try:
for token_id, bias in self.logit_bias.items():
# Convert token_id to integer before we add to LLMEngine
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias[int(token_id)] = min(100, max(-100, bias))
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc

def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits

logits_processors = [logit_bias_logits_processor]
# We now allow logprobs being true without top_logprobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=None,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)

return SamplingParams(
n=self.n,
@@ -254,7 +248,7 @@ def logit_bias_logits_processor(
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
@@ -358,6 +352,7 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
@@ -407,30 +402,23 @@ class CompletionRequest(OpenAIBaseModel):

# doc: end-completion-extra-params

def to_sampling_params(self):
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = None
if self.logit_bias:
logit_bias: Dict[int, float] = {}
try:
for token_id, bias in self.logit_bias.items():
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias[int(token_id)] = min(100, max(-100, bias))
except ValueError as exc:
raise ValueError(f"Found token_id `{token_id}` in logit_bias "
f"but token_id must be an integer or string "
f"representing an integer") from exc

def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits

logits_processors = [logit_bias_logits_processor]
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
tokenizer=tokenizer,
)
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)

return SamplingParams(
n=self.n,
@@ -447,7 +435,7 @@ def logit_bias_logits_processor(
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
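
Both request classes now receive the fallback via default_max_tokens from the serving layer (see the serving_chat.py and serving_completion.py hunks below) rather than computing it themselves. A toy calculation with purely illustrative numbers, not taken from this diff:

# Illustrative numbers only: a 4096-token context window and a 1200-token prompt.
max_model_len = 4096
prompt_token_ids = list(range(1200))   # stand-in for the tokenized prompt

default_max_tokens = max_model_len - len(prompt_token_ids)
assert default_max_tokens == 2896      # used when the request omits max_tokens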
23 changes: 8 additions & 15 deletions vllm/entrypoints/openai/serving_chat.py
@@ -25,8 +25,6 @@
PromptAdapterPath)
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
@@ -132,28 +130,23 @@ async def create_chat_completion(

request_id = f"chat-{random_uuid()}"
try:
sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logits_processor)
await self._guided_decode_logits_processor(request, tokenizer))

prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)

sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

self._log_inputs(request_id,
prompt_inputs,
params=sampling_params,
27 changes: 9 additions & 18 deletions vllm/entrypoints/openai/serving_completion.py
@@ -24,8 +24,6 @@
OpenAIServing,
PromptAdapterPath)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -93,31 +91,24 @@ async def create_completion(self, request: CompletionRequest,

tokenizer = await self.engine.get_tokenizer(lora_request)

sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)

guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
))

for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))

request_id_item = f"{request_id}-{i}"

self._log_inputs(request_id_item,