diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index 0896e337b5d2..5942462861d5 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -537,6 +537,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text
 
 
+@pytest.mark.asyncio
+async def test_allowed_token_ids(client: openai.AsyncOpenAI):
+    prompt = "Hello, my name is"
+    max_tokens = 1
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+
+    # Test exclusive selection
+    allowed_ids = [21555, 21557, 21558]
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=prompt,
+        max_tokens=max_tokens,
+        temperature=0.0,
+        seed=42,
+        extra_body=dict(allowed_token_ids=allowed_ids),
+        logprobs=1,
+    )
+    response_tokens = completion.choices[0].logprobs.tokens
+    assert len(response_tokens) == 1
+    assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 464465494b71..168ba7ba888e 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -1,7 +1,12 @@
 import asyncio
+from contextlib import suppress
 from dataclasses import dataclass
+from unittest.mock import MagicMock
 
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -42,3 +47,37 @@ async def _async_serving_chat_init():
 def test_async_serving_chat_init():
     serving_completion = asyncio.run(_async_serving_chat_init())
     assert serving_completion.chat_template == CHAT_TEMPLATE
+
+
+def test_serving_chat_should_set_correct_max_tokens():
+    mock_engine = MagicMock(spec=AsyncLLMEngine)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     served_model_names=[MODEL_NAME],
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     lora_modules=None,
+                                     prompt_adapters=None,
+                                     request_logger=None)
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+        guided_decoding_backend="outlines",
+    )
+
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    # AsyncLLMEngine.generate(inputs, sampling_params, ...)
+    assert mock_engine.generate.call_args.args[1].max_tokens == 93
+
+    req.max_tokens = 10
+    with suppress(Exception):
+        asyncio.run(serving_chat.create_chat_completion(req))
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py
new file mode 100644
index 000000000000..31eb5aa628c5
--- /dev/null
+++ b/vllm/entrypoints/openai/logits_processors.py
@@ -0,0 +1,74 @@
+from functools import lru_cache
+from typing import Dict, FrozenSet, Iterable, List, Optional, Union
+
+import torch
+from transformers import PreTrainedTokenizer
+
+from vllm.sampling_params import LogitsProcessor
+
+
+class AllowedTokenIdsLogitsProcessor:
+    """Logits processor for constraining generated tokens to a
+    specific set of token ids."""
+
+    def __init__(self, allowed_ids: Iterable[int]):
+        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
+        self.mask: Optional[torch.Tensor] = None
+
+    def __call__(self, token_ids: List[int],
+                 logits: torch.Tensor) -> torch.Tensor:
+        if self.mask is None:
+            self.mask = torch.ones((logits.shape[-1], ),
+                                   dtype=torch.bool,
+                                   device=logits.device)
+            self.mask[self.allowed_ids] = False
+            self.allowed_ids = None
+        logits.masked_fill_(self.mask, float("-inf"))
+        return logits
+
+
+@lru_cache(maxsize=32)
+def _get_allowed_token_ids_logits_processor(
+    allowed_token_ids: FrozenSet[int],
+    vocab_size: int,
+) -> LogitsProcessor:
+    if not allowed_token_ids:
+        raise ValueError("Empty allowed_token_ids provided")
+    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
+        raise ValueError("allowed_token_ids contains "
+                         "out-of-vocab token id")
+    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
+
+
+def get_logits_processors(
+        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
+        allowed_token_ids: Optional[List[int]],
+        tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
+    logits_processors = []
+    if logit_bias:
+        try:
+            # Convert token_id to integer
+            # Clamp the bias between -100 and 100 per OpenAI API spec
+            clamped_logit_bias: Dict[int, float] = {
+                int(token_id): min(100.0, max(-100.0, bias))
+                for token_id, bias in logit_bias.items()
+            }
+        except ValueError as exc:
+            raise ValueError(
+                "Found token_id in logit_bias that is not "
+                "an integer or string representing an integer") from exc
+
+        def logit_bias_logits_processor(token_ids: List[int],
+                                        logits: torch.Tensor) -> torch.Tensor:
+            for token_id, bias in clamped_logit_bias.items():
+                logits[token_id] += bias
+            return logits
+
+        logits_processors.append(logit_bias_logits_processor)
+
+    if allowed_token_ids is not None:
+        logits_processors.append(
+            _get_allowed_token_ids_logits_processor(
+                frozenset(allowed_token_ids), tokenizer.vocab_size))
+
+    return logits_processors
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c024bbc07c06..3b35ae1ebd70 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -5,11 +5,13 @@
 
 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
+from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.utils import random_uuid
 
 
@@ -213,30 +215,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(self) -> SamplingParams:
-        # We now allow logprobs being true without top_logrobs.
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
 
-        logits_processors = None
-        if self.logit_bias:
-            logit_bias: Dict[int, float] = {}
-            try:
-                for token_id, bias in self.logit_bias.items():
-                    # Convert token_id to integer before we add to LLMEngine
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    logit_bias[int(token_id)] = min(100, max(-100, bias))
-            except ValueError as exc:
-                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
-                                 f"but token_id must be an integer or string "
-                                 f"representing an integer") from exc
-
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in logit_bias.items():
-                    logits[token_id] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
+        # We now allow logprobs being true without top_logrobs.
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=None,
+            tokenizer=tokenizer,
+        )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
@@ -254,7 +248,7 @@ def logit_bias_logits_processor(
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens,
+            max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
@@ -358,6 +352,7 @@ class CompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    allowed_token_ids: Optional[List[int]] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
@@ -407,30 +402,23 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
-    def to_sampling_params(self):
+    def to_sampling_params(
+            self, tokenizer: PreTrainedTokenizer,
+            guided_decode_logits_processor: Optional[LogitsProcessor],
+            default_max_tokens: int) -> SamplingParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
         echo_without_generation = self.echo and self.max_tokens == 0
 
-        logits_processors = None
-        if self.logit_bias:
-            logit_bias: Dict[int, float] = {}
-            try:
-                for token_id, bias in self.logit_bias.items():
-                    # Convert token_id to integer
-                    # Clamp the bias between -100 and 100 per OpenAI API spec
-                    logit_bias[int(token_id)] = min(100, max(-100, bias))
-            except ValueError as exc:
-                raise ValueError(f"Found token_id `{token_id}` in logit_bias "
-                                 f"but token_id must be an integer or string "
-                                 f"representing an integer") from exc
-
-            def logit_bias_logits_processor(
-                    token_ids: List[int],
-                    logits: torch.Tensor) -> torch.Tensor:
-                for token_id, bias in logit_bias.items():
-                    logits[token_id] += bias
-                return logits
-
-            logits_processors = [logit_bias_logits_processor]
+        logits_processors = get_logits_processors(
+            logit_bias=self.logit_bias,
+            allowed_token_ids=self.allowed_token_ids,
+            tokenizer=tokenizer,
+        )
+        if guided_decode_logits_processor:
+            logits_processors.append(guided_decode_logits_processor)
 
         return SamplingParams(
             n=self.n,
@@ -447,7 +435,7 @@ def logit_bias_logits_processor(
             stop_token_ids=self.stop_token_ids,
             logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
-            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            max_tokens=max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index b21c2bc51318..21bfee700a6b 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -25,8 +25,6 @@
                                                      PromptAdapterPath)
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.multimodal import MultiModalDataDict
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
@@ -132,28 +130,23 @@ async def create_chat_completion(
 
         request_id = f"chat-{random_uuid()}"
         try:
-            sampling_params = request.to_sampling_params()
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logits_processor:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logits_processor)
+                await self._guided_decode_logits_processor(request, tokenizer))
 
             prompt_inputs = self._tokenize_prompt_input(
                 request,
                 tokenizer,
                 prompt,
-                truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
 
+            sampling_params = request.to_sampling_params(
+                tokenizer,
+                guided_decode_logits_processor,
+                default_max_tokens=self.max_model_len -
+                len(prompt_inputs["prompt_token_ids"]))
+
             self._log_inputs(request_id,
                              prompt_inputs,
                              params=sampling_params,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 6aef4c9f9615..5356cd2f7cbc 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -24,8 +24,6 @@
                                                      OpenAIServing,
                                                      PromptAdapterPath)
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -93,31 +91,24 @@ async def create_completion(self, request: CompletionRequest,
 
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
-            sampling_params = request.to_sampling_params()
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
-            guided_decode_logit_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logit_processor is not None:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logit_processor)
-
+            guided_decode_logits_processor = (
+                await self._guided_decode_logits_processor(request, tokenizer))
             prompts = list(
                 self._tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     request.prompt,
-                    truncate_prompt_tokens=sampling_params.
-                    truncate_prompt_tokens,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 ))
 
             for i, prompt_inputs in enumerate(prompts):
+                sampling_params = request.to_sampling_params(
+                    tokenizer,
+                    guided_decode_logits_processor,
+                    default_max_tokens=self.max_model_len -
+                    len(prompt_inputs["prompt_token_ids"]))
+
                 request_id_item = f"{request_id}-{i}"
 
                 self._log_inputs(request_id_item,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8c6bd10b9b4d..044bd5e72f07 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -26,9 +26,11 @@
 from vllm.inputs import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.sequence import Logprob
 
 logger = init_logger(__name__)
@@ -150,6 +152,15 @@ def create_streaming_error_response(
         })
         return json_str
 
+    async def _guided_decode_logits_processor(
+            self, request: Union[ChatCompletionRequest, CompletionRequest],
+            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
+        decoding_config = await self.engine.get_decoding_config()
+        guided_decoding_backend = request.guided_decoding_backend \
+            or decoding_config.guided_decoding_backend
+        return await get_guided_decoding_logits_processor(
+            guided_decoding_backend, request, tokenizer)
+
     async def _check_model(
         self,
         request: AnyRequest,
@@ -254,9 +265,7 @@ def _validate_input(
                     f"{self.max_model_len} tokens. However, you requested "
                     f"{token_num} tokens in the messages, "
                     f"Please reduce the length of the messages.")
-            request.max_tokens = self.max_model_len - token_num
-
-        if token_num + request.max_tokens > self.max_model_len:
+        elif token_num + request.max_tokens > self.max_model_len:
             raise ValueError(
                 f"This model's maximum context length is "
                 f"{self.max_model_len} tokens. However, you requested "
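
For reference, a minimal client-side sketch of the two behaviors this diff introduces, written against the tests above. It assumes a vLLM OpenAI-compatible server carrying this patch is running locally; the model name "my-model" is a placeholder, and the token ids are the GPT-2 ids used in test_allowed_token_ids, so they are tokenizer-specific.

    import asyncio

    import openai

    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")


    async def demo():
        # allowed_token_ids is passed through extra_body; the server builds an
        # AllowedTokenIdsLogitsProcessor that masks every other vocab id to
        # -inf, so the single sampled token must come from this list.
        completion = await client.completions.create(
            model="my-model",
            prompt="Hello, my name is",
            max_tokens=1,
            temperature=0.0,
            extra_body=dict(allowed_token_ids=[21555, 21557, 21558]),
        )

        # Omitting max_tokens now falls back to max_model_len minus the prompt
        # length inside to_sampling_params (previously _validate_input patched
        # the default onto the request object).
        chat = await client.chat.completions.create(
            model="my-model",
            messages=[{"role": "user", "content": "what is 1+1?"}],
        )
        return completion, chat


    asyncio.run(demo())

Note that an empty allowed_token_ids list or an out-of-vocab id is rejected by _get_allowed_token_ids_logits_processor, so such requests surface as error responses rather than silently sampling from the full vocabulary.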