Skip to content

Commit

Permalink
Parse and batch the prompt using vllm-project#4328
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 committed May 26, 2024
1 parent 18d5bcd commit fab7f92
Showing 1 changed file with 5 additions and 48 deletions.
53 changes: 5 additions & 48 deletions vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
from dataclasses import dataclass
from http import HTTPStatus
from typing import (Dict, Iterable, Iterator, List, Literal, Optional, Tuple,
TypedDict, Union, cast)
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union

from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
Expand All @@ -15,6 +14,7 @@
EmbeddingRequest, ErrorResponse,
LogProbs, ModelCard, ModelList,
ModelPermission)
from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
Expand All @@ -23,16 +23,6 @@
logger = init_logger(__name__)


# A text prompt tagged with ``is_tokens=False`` so consumers can
# discriminate it from ``InputTokens`` at runtime.
InputString = TypedDict("InputString", {
    "text": str,
    "is_tokens": Literal[False],
})


# A pre-tokenized prompt (list of token ids) tagged with
# ``is_tokens=True`` so consumers can discriminate it from
# ``InputString`` at runtime.
InputTokens = TypedDict("InputTokens", {
    "text": List[int],
    "is_tokens": Literal[True],
})


@dataclass
class LoRAModulePath:
name: str
Expand Down Expand Up @@ -304,38 +294,6 @@ def _tokenize_prompt_inputs(
truncate_prompt_tokens=truncate_prompt_tokens,
)

def _parse_prompt_input_or_inputs(
    self,
    input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
) -> List[Union[InputString, InputTokens]]:
    """Normalize the OpenAI-style ``prompt`` payload into a batch.

    The payload may be a single string, an array of strings, a single
    token array, or an array of token arrays. Each accepted shape is
    converted into a list of entries tagged with ``is_tokens`` so that
    downstream code can branch without re-sniffing the payload type.

    Raises:
        ValueError: If the list is empty or the payload is not one of
            the four accepted shapes.
    """
    if isinstance(input_or_inputs, str):
        # case 1: a string
        return [InputString(text=input_or_inputs, is_tokens=False)]

    if isinstance(input_or_inputs, list):
        if len(input_or_inputs) == 0:
            raise ValueError("please provide at least one prompt")
        first = input_or_inputs[0]
        if isinstance(first, str):
            # case 2: array of strings
            return [
                InputString(text=elem, is_tokens=False)
                for elem in cast(List[str], input_or_inputs)
            ]
        if isinstance(first, int):
            # case 3: array of tokens
            elem = cast(List[int], input_or_inputs)
            return [InputTokens(text=elem, is_tokens=True)]
        # BUG FIX: the original indexed ``first[0]`` unconditionally,
        # raising IndexError (instead of the documented ValueError)
        # when the first inner list was empty, e.g. ``[[]]``. An empty
        # token list is a valid (empty) prompt, so accept it here.
        if isinstance(first, list) and (len(first) == 0
                                        or isinstance(first[0], int)):
            # case 4: array of token arrays
            return [
                InputTokens(text=elem, is_tokens=True)
                for elem in cast(List[List[int]], input_or_inputs)
            ]

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")

def _tokenize_prompt_input_or_inputs(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
Expand All @@ -352,24 +310,23 @@ def _tokenize_prompt_input_or_inputs(
"""
tokenizer = self.tokenizer

for prompt_input in self._parse_prompt_input_or_inputs(
input_or_inputs):
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
prompt=prompt_input["text"],
prompt=prompt_input["content"],
tokenizer=tokenizer,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
prompt_ids=prompt_input["text"],
prompt_ids=prompt_input["content"],
tokenizer=tokenizer,
truncate_prompt_tokens=truncate_prompt_tokens,
)

0 comments on commit fab7f92

Please sign in to comment.