diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 5e26eac77c85..dc1933fd2e5d 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1,8 +1,7 @@
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import (Dict, Iterable, Iterator, List, Literal, Optional, Tuple,
-                    TypedDict, Union, cast)
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
 
 from pydantic import Field
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -15,6 +14,7 @@
                                               EmbeddingRequest, ErrorResponse,
                                               LogProbs, ModelCard, ModelList,
                                               ModelPermission)
+from vllm.inputs import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob
@@ -23,16 +23,6 @@
 logger = init_logger(__name__)
 
 
-class InputString(TypedDict):
-    text: str
-    is_tokens: Literal[False]
-
-
-class InputTokens(TypedDict):
-    text: List[int]
-    is_tokens: Literal[True]
-
-
 @dataclass
 class LoRAModulePath:
     name: str
@@ -304,38 +294,6 @@ def _tokenize_prompt_inputs(
                 truncate_prompt_tokens=truncate_prompt_tokens,
             )
 
-    def _parse_prompt_input_or_inputs(
-        self,
-        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
-    ) -> List[Union[InputString, InputTokens]]:
-        if isinstance(input_or_inputs, str):
-            # case 1: a string
-            return [InputString(text=input_or_inputs, is_tokens=False)]
-
-        if isinstance(input_or_inputs, list):
-            if len(input_or_inputs) == 0:
-                raise ValueError("please provide at least one prompt")
-            if isinstance(input_or_inputs[0], str):
-                # case 2: array of strings
-                return [
-                    InputString(text=elem, is_tokens=False)
-                    for elem in cast(List[str], input_or_inputs)
-                ]
-            if isinstance(input_or_inputs[0], int):
-                # case 3: array of tokens
-                elem = cast(List[int], input_or_inputs)
-                return [InputTokens(text=elem, is_tokens=True)]
-            if isinstance(input_or_inputs[0], list) and isinstance(
-                    input_or_inputs[0][0], int):
-                # case 4: array of token arrays
-                return [
-                    InputTokens(text=elem, is_tokens=True)
-                    for elem in cast(List[List[int]], input_or_inputs)
-                ]
-
-        raise ValueError("prompt must be a string, array of strings, "
-                         "array of tokens, or array of token arrays")
-
     def _tokenize_prompt_input_or_inputs(
         self,
         request: Union[ChatCompletionRequest, CompletionRequest,
@@ -352,8 +310,7 @@ def _tokenize_prompt_input_or_inputs(
         """
         tokenizer = self.tokenizer
 
-        for prompt_input in self._parse_prompt_input_or_inputs(
-                input_or_inputs):
+        for prompt_input in parse_and_batch_prompt(input_or_inputs):
             # Although our type checking is based on mypy,
             # VSCode Pyright extension should still work properly
             # "is True" is required for Pyright to perform type narrowing
@@ -361,7 +318,7 @@
             if prompt_input["is_tokens"] is False:
                 yield self._normalize_prompt_text_to_input(
                     request,
-                    prompt=prompt_input["text"],
+                    prompt=prompt_input["content"],
                     tokenizer=tokenizer,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
@@ -369,7 +326,7 @@
             else:
                 yield self._normalize_prompt_tokens_to_input(
                     request,
-                    prompt_ids=prompt_input["text"],
+                    prompt_ids=prompt_input["content"],
                     tokenizer=tokenizer,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                 )
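
Reviewer note: parse_and_batch_prompt replaces the removed
_parse_prompt_input_or_inputs helper one-for-one, handling the same four
prompt shapes but yielding dicts keyed by "content" rather than "text".
Below is a minimal sketch of the contract the call sites above now rely
on; the output shapes are inferred from the removed helper and the
"content"/"is_tokens" accesses in this diff, and the TypedDict names used
internally by vllm.inputs are not visible here.

    from vllm.inputs import parse_and_batch_prompt

    # case 1: a single string -> one text prompt
    parse_and_batch_prompt("Hello")
    # -> [{'content': 'Hello', 'is_tokens': False}]

    # case 2: an array of strings -> one text prompt per element
    parse_and_batch_prompt(["Hello", "World"])
    # -> [{'content': 'Hello', 'is_tokens': False},
    #     {'content': 'World', 'is_tokens': False}]

    # case 3: an array of tokens -> one tokenized prompt
    parse_and_batch_prompt([1, 2, 3])
    # -> [{'content': [1, 2, 3], 'is_tokens': True}]

    # case 4: an array of token arrays -> one tokenized prompt per element
    parse_and_batch_prompt([[1, 2], [3, 4]])
    # -> [{'content': [1, 2], 'is_tokens': True},
    #     {'content': [3, 4], 'is_tokens': True}]

    # an empty list is rejected, matching the removed helper
    parse_and_batch_prompt([])  # raises ValueError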