diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 8d423657630c..f7b84eebc8b6 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install mypy==1.9.0 + pip install mypy==1.11.1 pip install types-setuptools pip install types-PyYAML pip install types-requests diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py index 79b284554f17..0f266d791885 100644 --- a/examples/offline_inference_encoder_decoder.py +++ b/examples/offline_inference_encoder_decoder.py @@ -4,8 +4,8 @@ ''' from vllm import LLM, SamplingParams -from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt -from vllm.utils import zip_enc_dec_prompt_lists +from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, + TokensPrompt, zip_enc_dec_prompts) dtype = "float" @@ -61,9 +61,9 @@ ) # - Finally, here's a useful helper function for zipping encoder and -# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt +# decoder prompts together into a list of ExplicitEncoderDecoderPrompt # instances -zipped_prompt_list = zip_enc_dec_prompt_lists( +zipped_prompt_list = zip_enc_dec_prompts( ['An encoder prompt', 'Another encoder prompt'], ['A decoder prompt', 'Another decoder prompt']) diff --git a/requirements-common.txt b/requirements-common.txt index d8c95bf77240..ebd0fca51919 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.10.3 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 -typing_extensions +typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq gguf == 0.9.1 diff --git a/requirements-lint.txt b/requirements-lint.txt index bd34227d3e82..d0b2fef6deae 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -8,7 +8,7 @@ isort==5.13.2 clang-format==18.1.5 # type checking -mypy==1.9.0 +mypy==1.11.1 types-PyYAML types-requests types-setuptools diff --git a/tests/conftest.py b/tests/conftest.py index c0bf9897c97f..d565da5a1019 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import os import sys from collections import UserList +from enum import Enum from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union import pytest @@ -14,20 +15,19 @@ AutoModelForVision2Seq, AutoTokenizer, BatchEncoding, BatchFeature) -from tests.models.utils import DecoderPromptType from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig from vllm.connections import global_http_connection from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel) -from vllm.inputs import TextPrompt +from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, + to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, - is_cpu, to_enc_dec_tuple_list, - zip_enc_dec_prompt_lists) + is_cpu) logger = init_logger(__name__) @@ -124,10 +124,16 @@ def example_prompts() -> List[str]: return prompts +class DecoderPromptType(Enum): + """For encoder/decoder models only.""" + CUSTOM = 1 + NONE = 2 + EMPTY_STR = 3 + + 
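The hunks above rename `zip_enc_dec_prompt_lists` to `zip_enc_dec_prompts` and move it, together with `build_explicit_enc_dec_prompt` and `ExplicitEncoderDecoderPrompt`, out of `vllm.utils` and into `vllm.inputs`. A minimal usage sketch of the relocated helpers, assuming this diff is applied; the prompt strings and variable names are illustrative only:

    # Build ExplicitEncoderDecoderPrompt instances with the helpers now exported
    # from vllm.inputs (previously vllm.utils); the prompt text is made up.
    from vllm.inputs import build_explicit_enc_dec_prompt, zip_enc_dec_prompts

    # Zip parallel encoder/decoder prompt lists, as in the updated example script.
    zipped_prompts = zip_enc_dec_prompts(
        ['An encoder prompt', 'Another encoder prompt'],
        ['A decoder prompt', 'Another decoder prompt'],
    )

    # Single-pair construction; decoder_prompt may now be None, in which case the
    # engine falls back to a default decoder prompt at processing time.
    single_prompt = build_explicit_enc_dec_prompt(
        encoder_prompt='An encoder prompt',
        decoder_prompt=None,
    )
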
@pytest.fixture -def example_encoder_decoder_prompts() \ - -> Dict[DecoderPromptType, - Tuple[List[str], List[Optional[str]]]]: +def example_encoder_decoder_prompts( +) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]: ''' Returns an encoder prompt list and a decoder prompt list, wherein each pair of same-index entries in both lists corresponds to an (encoder prompt, @@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \ # NONE decoder prompt type return { DecoderPromptType.NONE: - zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts), + zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts), DecoderPromptType.EMPTY_STR: - zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts), + zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts), DecoderPromptType.CUSTOM: - zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts), + zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts), } @@ -444,7 +450,7 @@ def generate_greedy_logprobs_limit( def generate_encoder_decoder_greedy_logprobs_limit( self, - encoder_decoder_prompts: Tuple[List[str], List[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, **kwargs: Any, @@ -608,7 +614,7 @@ def generate_w_logprobs( def generate_encoder_decoder_w_logprobs( self, - encoder_decoder_prompts: Tuple[List[str], List[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: SamplingParams, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ''' @@ -653,7 +659,7 @@ def generate_greedy_logprobs( def generate_encoder_decoder_greedy_logprobs( self, - encoder_decoder_prompts: Tuple[List[str], List[str]], + encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py index 69eae62ca732..9850c823ff5d 100644 --- a/tests/distributed/test_basic_distributed_correctness_enc_dec.py +++ b/tests/distributed/test_basic_distributed_correctness_enc_dec.py @@ -11,9 +11,9 @@ import pytest -from tests.models.utils import DecoderPromptType from vllm.utils import cuda_device_count_stateless +from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close from ..utils import fork_new_process_for_each_test diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py new file mode 100644 index 000000000000..85f1c6f18bf3 --- /dev/null +++ b/tests/entrypoints/openai/test_encoder_decoder.py @@ -0,0 +1,50 @@ +import openai +import pytest + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "facebook/bart-base" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--dtype", + "bfloat16", + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 
1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=2, total_tokens=7) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 9c26b7163ff6..9bca5a86f124 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -2,6 +2,8 @@ Run `pytest tests/models/test_bart.py`. """ +from typing import List, Optional, Tuple + from vllm.utils import is_cpu if not is_cpu(): @@ -11,22 +13,31 @@ import pytest - from tests.models.utils import DecoderPromptType + from vllm.sequence import SampleLogprobs + from ..conftest import DecoderPromptType from .utils import check_logprobs_close MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] - DECODER_PROMPT_TYPES = ([ - DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR, - DecoderPromptType.NONE - ]) + def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, + ): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "" + hf_output_str + + return output_ids, hf_output_str, out_logprobs @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) - @pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES) + @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) def test_models( hf_runner, vllm_runner, @@ -146,8 +157,13 @@ def test_models( hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE else 0) - check_logprobs_close(outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/tests/models/utils.py b/tests/models/utils.py index d96301b853c8..ff29a0ae81d6 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,5 +1,4 @@ import warnings -from enum import Enum from typing import Dict, List, Optional, Sequence, Tuple, Union from vllm.sequence import SampleLogprobs @@ -136,13 +135,3 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) - - -class DecoderPromptType(Enum): - ''' - For encoder/decoder models only - - - ''' - CUSTOM = 1 - NONE = 2 - EMPTY_STR = 3 diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 887c7101decd..3725d8687f25 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -2,7 +2,7 @@ import pytest -from vllm.inputs import parse_and_batch_prompt +from vllm.inputs.parse import parse_and_batch_prompt STRING_INPUTS = [ '', diff --git a/vllm/config.py b/vllm/config.py index 6fc0045fb93a..59cabbfc965d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -464,6 +464,16 @@ def _get_num_seqlen_agnostic_layers( if t != "attention" ]) + @property + def 
is_encoder_decoder_model(self) -> bool: + """Extract the HF encoder/decoder model flag.""" + return getattr(self.hf_config, "is_encoder_decoder", False) + + @property + def is_embedding_model(self) -> bool: + """Extract the embedding model flag.""" + return self.embedding_mode + class CacheConfig: """Configuration for the KV cache. diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index ef82c3dfd0b5..6af347def475 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -5,6 +5,7 @@ Optional, Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer +from typing_extensions import assert_never import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, @@ -12,11 +13,14 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import LLMEngine +from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, + PromptComponents) from vllm.engine.metrics import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.ray_utils import initialize_ray_cluster, ray -from vllm.inputs import LLMInputs, PromptInputs +from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -293,38 +297,138 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() - async def process_model_inputs_async( + async def _tokenize_prompt_async( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """Async version of :meth:`_tokenize_prompt`.""" + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + return await tokenizer.encode_async(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + async def _extract_prompt_components_async( self, + inputs: SingletonPromptInputs, request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + """Async version of :meth:`_extract_prompt_components`.""" + if isinstance(inputs, str): + prompt = inputs + prompt_token_ids = await self._tokenize_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: + prompt = None + prompt_token_ids = inputs["prompt_token_ids"] + else: + # NOTE: This extra assignment is required to pass mypy + prompt = parsed_prompt = inputs["prompt"] + prompt_token_ids = await self._tokenize_prompt_async( + parsed_prompt, + request_id=request_id, + lora_request=lora_request, + ) + + multi_modal_data = inputs.get("multi_modal_data") + else: + assert_never(inputs) + + return prompt, prompt_token_ids, multi_modal_data + + async def _process_encoder_decoder_prompt_async( + self, inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + """Async version of :meth:`_process_encoder_decoder_prompt`.""" + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_task = 
self._extract_prompt_components_async( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + encoder_comps = await encoder_task + decoder_comps = None, None, None + else: + decoder_task = self._extract_prompt_components_async( + decoder_input, + request_id=request_id, + ) + + encoder_comps, decoder_comps = await asyncio.gather( + encoder_task, decoder_task) + else: + encoder_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + ) + + decoder_comps = None, None, None + + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + async def _process_decoder_only_prompt_async( + self, + inputs: SingletonPromptInputs, + request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: - if isinstance(inputs, str): - inputs = {"prompt": inputs} + """Async version of :meth:`_process_decoder_only_prompt`.""" + prompt_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + lora_request=lora_request, + ) - if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) - prompt_token_ids = await tokenizer.encode_async( + async def process_model_inputs_async( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Async version of :meth:`process_model_inputs`.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + model_inputs = await self._process_encoder_decoder_prompt_async( + inputs, request_id=request_id, - prompt=inputs["prompt"], - lora_request=lora_request) + ) else: - prompt_token_ids = inputs["prompt_token_ids"] + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") - if prompt_adapter_request: - prompt_token_ids = [ - 0 - ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ - prompt_token_ids - - llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + # Decoder-only operation + model_inputs = await self._process_decoder_only_prompt_async( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) - return self.input_processor(llm_inputs) + return self.input_processor(model_inputs) async def add_request_async( self, @@ -336,6 +440,7 @@ async def add_request_async( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: + """Async version of :meth:`add_request`.""" if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -343,10 +448,11 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( + inputs, request_id=request_id, - inputs=inputs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) self._add_processed_request( 
request_id=request_id, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 75c6d7e6c9b2..dcaf375f9b15 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,6 +5,8 @@ from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, TypeVar, Union +from typing_extensions import assert_never + import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -22,10 +24,12 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs, - get_prompt_type) +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs, + PromptInputs, SingletonPromptInputs) +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams @@ -43,8 +47,7 @@ AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import (Counter, is_embedding_model_config, - is_encoder_decoder_model_config) +from vllm.utils import Counter from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) +PromptComponents = Tuple[Optional[str], List[int], + Optional[MultiModalDataDict]] +DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], + Optional[MultiModalDataDict]] + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -524,7 +532,7 @@ def _get_eos_token_id(self, return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - def _get_decoder_start_token_id(self, ) -> Optional[int]: + def _get_decoder_start_token_id(self) -> Optional[int]: ''' Obtain the decoder start token id employed by an encoder/decoder model. Returns None for non-encoder/decoder models or if the @@ -553,7 +561,7 @@ def _get_decoder_start_token_id(self, ) -> Optional[int]: def _add_processed_request( self, request_id: str, - processed_inputs: LLMInputs, + processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -613,11 +621,11 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - _LLMInputComponentsType = Tuple[str, List[int], ] + _LLMInputComponentsType = Tuple[str, List[int]] def _prepare_decoder_input_ids_for_generation( self, - decoder_input_ids: Optional[List[int]] = None, + decoder_input_ids: Optional[List[int]], ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. 
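The async methods above mirror a new synchronous pipeline in `LLMEngine` (shown further below): prompts are first reduced to `PromptComponents` tuples of (prompt, prompt_token_ids, multi_modal_data) and then assembled into `LLMInputs` or `EncoderDecoderLLMInputs`. A condensed, illustrative sketch of the top-level dispatch follows; it uses the names introduced in this diff, but the wrapper function itself is hypothetical and is not copied verbatim from the engine:

    # Condensed sketch of the prompt-processing dispatch added in this diff; the real
    # code lives in LLMEngine.process_model_inputs and its async counterpart.
    from typing import Union

    from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs, PromptInputs
    from vllm.inputs.parse import is_explicit_encoder_decoder_prompt


    def process_model_inputs_sketch(
        engine,  # an LLMEngine exposing the helpers added in this diff
        inputs: PromptInputs,
        request_id: str,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        if engine.is_encoder_decoder_model():
            # Encoder/decoder models: map the prompt(s) onto encoder & decoder inputs.
            return engine._process_encoder_decoder_prompt(inputs,
                                                          request_id=request_id)

        if is_explicit_encoder_decoder_prompt(inputs):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation.
        return engine._process_decoder_only_prompt(inputs, request_id=request_id)

The real methods additionally thread `lora_request` and `prompt_adapter_request` through the helpers and run the result through `self.input_processor` before returning.
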
@@ -639,14 +647,13 @@ def _prepare_decoder_input_ids_for_generation( * Processed token list """ - decoder_start_token_id: Optional[int] = ( - self._get_decoder_start_token_id()) + decoder_start_token_id = self._get_decoder_start_token_id() assert decoder_start_token_id is not None if decoder_input_ids is None: # no decoder prompt input -> # use decoder_start_token_id as decoder_input_ids - (decoder_input_ids) = self._get_default_enc_dec_decoder_prompt() + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() if (len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id): @@ -657,12 +664,11 @@ def _prepare_decoder_input_ids_for_generation( def _tokenize_prompt( self, prompt: str, - request_id: Optional[str] = None, - lora_request: Optional[str] = None, + request_id: str, + lora_request: Optional[LoRARequest], ) -> List[int]: ''' - Wrapper around application of the model's - tokenizer. + Wrapper around application of the model's tokenizer. Arguments: @@ -678,87 +684,72 @@ def _tokenize_prompt( tokenizer = self.get_tokenizer_group("prompts must be None if " "skip_tokenizer_init is True") - prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - return prompt_token_ids + return tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) - def _extract_single_prompt_for_enc_dec_input( + def _extract_prompt_components( self, - inputs: Optional[PromptInputs], - request_id: Optional[str] = None, - ptype: Optional[str] = None, - is_encoder_prompt: bool = False, - ) -> Tuple[Optional[str], List[int]]: + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: ''' - Only for encoder/decoder models: - Extract prompt & prompt_token_ids from any single - encoder or decoder input prompt. For encoder input prompts - in particular, also extract multi-modal data. - - This function handles the following scenarios: - 1. The user supplied a singleton encoder prompt - & the prompt/prompt-token-ids must be extracted. - 2. The user supplied an explicit encoder/decoder - prompt & the prompt/prompt-token-ids must be - extracted from either the encoder and decoder prompts. - - For decoder prompts in particular (scenario 2), special - processing is applied to the returned decoder token ids. + Extract the components of any single encoder or decoder input prompt. Arguments: * request_id - * ptype: str representation of the input prompt type. - If `ptype` is `None`, assume that the prompt - type is unknown and must be inferred. This is the - case for ExplicitEncoderDecoder sub-prompts. * inputs: single encoder or decoder input prompt - * is_encoder_prompt: True if encoder input prompt. - If False, decoder prompt tokens - are preprocessed. 
+ * lora_request: this is only valid for decoder prompts Returns: * prompt * prompt_token_ids + * multi_modal_data ''' - prompt_token_ids = None - ptype = (get_prompt_type(inputs) if ptype is None else ptype) - if inputs is None: - prompt = None - elif ptype == 'str': + if isinstance(inputs, str): prompt = inputs prompt_token_ids = self._tokenize_prompt( prompt, request_id=request_id, + lora_request=lora_request, ) - elif ptype == 'TokensPrompt': - prompt = None - prompt_token_ids = inputs['prompt_token_ids'] + multi_modal_data = None + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: + prompt = None + prompt_token_ids = inputs["prompt_token_ids"] + else: + # NOTE: This extra assignment is required to pass mypy + prompt = parsed_prompt = inputs["prompt"] + prompt_token_ids = self._tokenize_prompt( + parsed_prompt, + request_id=request_id, + lora_request=lora_request, + ) + + multi_modal_data = inputs.get("multi_modal_data") else: - prompt = inputs['prompt'] - prompt_token_ids = self._tokenize_prompt( - prompt, - request_id=request_id, - ) + assert_never(inputs) - if not is_encoder_prompt: - # Apply special pre-processing to - # decoder prompts - prompt_token_ids = (self._prepare_decoder_input_ids_for_generation( - prompt_token_ids, )) + return prompt, prompt_token_ids, multi_modal_data - assert prompt_token_ids is not None + def _apply_prompt_adapter( + self, + prompt_token_ids: List[int], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> List[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) - return ( - prompt, - prompt_token_ids, - ) + return prompt_token_ids - def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]: + def _get_default_enc_dec_decoder_prompt(self) -> List[int]: ''' Specifically for encoder/decoder models: generate a default decoder prompt for when @@ -792,18 +783,39 @@ def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]: bos_token_id = self._get_bos_token_id() assert bos_token_id is not None - prompt_token_ids: List[int] = [bos_token_id] - return prompt_token_ids + return [bos_token_id] + + def _build_enc_dec_llm_inputs( + self, + encoder_comps: PromptComponents, + decoder_comps: DecoderPromptComponents, + ) -> EncoderDecoderLLMInputs: + encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + + if encoder_mm_data is not None or decoder_mm_data is not None: + raise ValueError("Multi-modal encoder-decoder models are " + "not supported yet") + + decoder_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + + return EncoderDecoderLLMInputs( + prompt_token_ids=decoder_prompt_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_ids, + encoder_prompt=encoder_prompt, + ) def _process_encoder_decoder_prompt( self, inputs: PromptInputs, - request_id: Optional[str] = None, - ) -> LLMInputs: + request_id: str, + ) -> EncoderDecoderLLMInputs: ''' For encoder/decoder models only: - Process an input prompt - into an `LLMInputs` instance. + Process an input prompt into an + :class:`EncoderDecoderLLMInputs` instance. There are two types of input prompts: singleton prompts which carry only the @@ -830,136 +842,103 @@ def _process_encoder_decoder_prompt( Returns: - * `LLMInputs` instance + * :class:`EncoderDecoderLLMInputs` instance ''' - ptype = get_prompt_type(inputs) - - # Obtain encoder and decoder prompt tokens. 
Note - # that, no matter what, the decoder - # prompt type is unknown. - if ptype == "ExplicitEncoderDecoder": - # If input is explicit encoder/decoder prompt, - # then it remains to be determined what type - # of encoder prompt we have - extracted_encoder_prompt = inputs.get('encoder_prompt') - encoder_ptype = None - # Extract decoder prompt from explicit - # encoder/decoder prompt - extracted_decoder_prompt = inputs.get('decoder_prompt') + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_comps = self._extract_prompt_components( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + decoder_comps = None, None, None + else: + decoder_comps = self._extract_prompt_components( + decoder_input, + request_id=request_id, + ) else: - # If input is singleton encoder prompt, then - # we know the encoder prompt type - extracted_encoder_prompt = inputs - encoder_ptype = ptype - # Decoder prompt is always unknown if - # encoder/decoder prompt is not explicit - extracted_decoder_prompt = None - - # Invoke helper function to obtain encoder - # prompt and prompt token ids, either from - # singleton encoder prompt or from the - # encoder sub-prompt of an explicit - # encoder/decode scenario 2), special - # processing is applied to the returned decoder token ids - ( - encoder_prompt, - encoder_prompt_token_ids, - ) = self._extract_single_prompt_for_enc_dec_input( - extracted_encoder_prompt, - request_id=request_id, - ptype=encoder_ptype, - is_encoder_prompt=True, - ) + encoder_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + ) - # Invoke helper method to obtain - # decoder prompt and prompt token ids. - # - # The helper method will detect the decoder - # prompt type. - # - # Helper method will also apply special - # preprocessing unique to decoder prompts. - ( - decoder_prompt, - decoder_prompt_token_ids, - ) = self._extract_single_prompt_for_enc_dec_input( - extracted_decoder_prompt, - request_id=request_id, - ptype=None, - is_encoder_prompt=False, - ) + decoder_comps = None, None, None - return LLMInputs( - prompt_token_ids=decoder_prompt_token_ids, - prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids, - encoder_prompt=encoder_prompt, - ) + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + def _build_decoder_only_llm_inputs( + self, + prompt_comps: PromptComponents, + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> LLMInputs: + prompt, prompt_token_ids, multi_modal_data = prompt_comps + + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data) def _process_decoder_only_prompt( self, - inputs: PromptInputs, + inputs: SingletonPromptInputs, + request_id: str, lora_request: Optional[LoRARequest] = None, - request_id: Optional[str] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: ''' For decoder-only models: - Process an input prompt - into an `LLMInputs` instance. + Process an input prompt into an :class:`LLMInputs` instance. 
Arguments: * inputs: input prompt - * lora_request * request_id + * lora_request * prompt_adapter_request Returns: - * `LLMInputs` instance + * :class:`LLMInputs` instance ''' - if isinstance(inputs, str): - inputs = {"prompt": inputs} - prompt = inputs.get("prompt") - - if "prompt_token_ids" not in inputs: - prompt_token_ids = self._tokenize_prompt( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - else: - prompt_token_ids = inputs["prompt_token_ids"] - - if prompt_adapter_request: - prompt_token_ids = ( - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) + prompt_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + lora_request=lora_request, + ) - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=inputs.get("multi_modal_data")) + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) def process_model_inputs( self, - request_id: str, inputs: PromptInputs, + request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder - model_inputs = self._process_encoder_decoder_prompt( inputs, request_id=request_id, ) else: + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + # Decoder-only operation model_inputs = self._process_decoder_only_prompt( inputs, @@ -1029,10 +1008,11 @@ def add_request( arrival_time = time.time() processed_inputs = self.process_model_inputs( + inputs, request_id=request_id, - inputs=inputs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + ) self._add_processed_request( request_id=request_id, @@ -1597,7 +1577,7 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) def is_encoder_decoder_model(self): - return is_encoder_decoder_model_config(self.model_config) + return self.model_config.is_encoder_decoder_model def is_embedding_model(self): - return is_embedding_model_config(self.model_config) + return self.model_config.is_embedding_model diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 12634c326185..1197c70d88ae 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,8 +2,7 @@ from dataclasses import dataclass from functools import lru_cache from pathlib import Path -from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union, - cast, final) +from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast # yapf conflicts with isort for this block # yapf: disable @@ -59,7 +58,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): CustomChatCompletionMessageParam] -@final # So that it should be compatible with Dict[str, str] +# TODO: Make fields ReadOnly once mypy supports it class ConversationMessage(TypedDict): role: str content: str diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index eaa157209493..175f418a1294 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,8 +6,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from 
vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, - parse_and_batch_prompt) +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 84871fc83ef5..c0cd820e30c0 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor( return AllowedTokenIdsLogitsProcessor(allowed_token_ids) -def logit_bias_logits_processor(logit_bias: Dict[str, - float], token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: +def logit_bias_logits_processor( + logit_bias: Dict[int, float], + token_ids: List[int], + logits: torch.Tensor, +) -> torch.Tensor: for token_id, bias in logit_bias.items(): logits[token_id] += bias return logits diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index df4932d8fe18..8d8b5ea4bdf5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -22,7 +22,7 @@ TokenizeCompletionRequest, TokenizeRequest) # yapf: enable -from vllm.inputs import parse_and_batch_prompt +from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index e22b88f2fc38..0b08e9691f91 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,7 +1,7 @@ -from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText, - ParsedTokens, PromptInputs, SingletonPromptInputs, - TextPrompt, TokensPrompt, get_prompt_type, - is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt) +from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -14,18 +14,17 @@ """ __all__ = [ - "ParsedText", - "ParsedTokens", - "parse_and_batch_prompt", "TextPrompt", "TokensPrompt", "PromptInputs", + "SingletonPromptInputs", + "ExplicitEncoderDecoderPrompt", "LLMInputs", + "EncoderDecoderLLMInputs", + "build_explicit_enc_dec_prompt", + "to_enc_dec_tuple_list", + "zip_enc_dec_prompts", "INPUT_REGISTRY", "InputContext", "InputRegistry", - "get_prompt_type", - "is_valid_encoder_decoder_llm_inputs", - "ExplicitEncoderDecoderPrompt", - "SingletonPromptInputs", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 86c2901dc4c8..75ab0c770155 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,71 +1,12 @@ -from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, - TypedDict, Union, cast, overload) +from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple, + Union) -from typing_extensions import NotRequired +from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: from vllm.multimodal import MultiModalDataDict -class ParsedText(TypedDict): - content: str - is_tokens: Literal[False] - - -class ParsedTokens(TypedDict): - content: List[int] - is_tokens: Literal[True] - - -# 
https://github.com/vllm-project/vllm/pull/4028 -@overload -def parse_and_batch_prompt( - prompt: Union[str, List[str]]) -> Sequence[ParsedText]: - ... - - -@overload -def parse_and_batch_prompt( - prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: - ... - - -def parse_and_batch_prompt( - prompt: Union[str, List[str], List[int], List[List[int]]], -) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: - if isinstance(prompt, str): - # case 1: a string - return [ParsedText(content=prompt, is_tokens=False)] - - if isinstance(prompt, list): - if len(prompt) == 0: - raise ValueError("please provide at least one prompt") - - if isinstance(prompt[0], str): - # case 2: array of strings - return [ - ParsedText(content=elem, is_tokens=False) - for elem in cast(List[str], prompt) - ] - if isinstance(prompt[0], int): - # case 3: array of tokens - elem = cast(List[int], prompt) - return [ParsedTokens(content=elem, is_tokens=True)] - if isinstance(prompt[0], list): - if len(prompt[0]) == 0: - raise ValueError("please provide at least one prompt") - - if isinstance(prompt[0][0], int): - # case 4: array of token arrays - return [ - ParsedTokens(content=elem, is_tokens=True) - for elem in cast(List[List[int]], prompt) - ] - - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") - - class TextPrompt(TypedDict): """Schema for a text prompt.""" @@ -103,39 +44,49 @@ class TokensPrompt(TypedDict): which encapsulates multiple prompts, i.e. of the sort which may be utilized for encoder/decoder models when the user desires to express both the encoder & decoder -prompts explicitly, i.e. ExplicitEncoderDecoderPrompt +prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type SingletonPromptInputs may be employed +A prompt of type :class:`SingletonPromptInputs` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or (3) as a member of a larger data structure encapsulating -more than one prompt, i.e. ExplicitEncoderDecoderPrompt +more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt` """ +_T1_co = TypeVar("_T1_co", + bound=SingletonPromptInputs, + default=SingletonPromptInputs, + covariant=True) +_T2_co = TypeVar("_T2_co", + bound=SingletonPromptInputs, + default=SingletonPromptInputs, + covariant=True) -class ExplicitEncoderDecoderPrompt(TypedDict): + +# TODO: Make fields ReadOnly once mypy supports it +class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): """Represents an encoder/decoder model input prompt, comprising an explicit encoder prompt and a decoder prompt. The encoder and decoder prompts, respectively, may formatted according to any of the - SingletonPromptInputs schemas, and are not + :class:`SingletonPromptInputs` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. - Note that an ExplicitEncoderDecoderPrompt may not + Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` - fields of this data structure may not themselves - must be SingletonPromptInputs instances. + fields of this data structure themselves must be + :class:`SingletonPromptInputs` instances. 
""" - encoder_prompt: SingletonPromptInputs + encoder_prompt: _T1_co - decoder_prompt: SingletonPromptInputs + decoder_prompt: Optional[_T2_co] PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] @@ -150,60 +101,12 @@ class ExplicitEncoderDecoderPrompt(TypedDict): """ -def _has_required_keys( - d: dict, - required_keys: set, -) -> bool: - return required_keys.issubset(d.keys()) - - -def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]: - """ - Get the type-name of the prompt argument instance, given that - isinstance() cannot apply to TypedDict subclasses directly. - If the prompt is None, return 'None' as the type name. - - Arguments: - - * prompt: LLM input prompt or None - - Returns: - - * String representation of prompt type - """ - - if prompt is None: - return 'None' - - required_keys_dict = { - 'TextPrompt': {'prompt'}, - 'TokensPrompt': {'prompt_token_ids'}, - 'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'}, - } - - if isinstance(prompt, dict): - for (ptype, required_keys) in required_keys_dict.items(): - # Ignore type checking in the conditional below because type - # checker does not understand that is_dict(prompt) narrows - # down the possible types - if _has_required_keys( - prompt, # type: ignore - required_keys): - return ptype - - raise ValueError(f"Invalid prompt {prompt}, valid types are " - "required_keys_dict={required_keys_dict}") - - if isinstance(prompt, str): - return "str" - - raise ValueError(f"Invalid prompt {prompt}") - - class LLMInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. + + This specifies the data required for decoder-only models. """ prompt_token_ids: List[int] """The token IDs of the prompt.""" @@ -213,7 +116,21 @@ class LLMInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ - encoder_prompt_token_ids: NotRequired[List[int]] + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class EncoderDecoderLLMInputs(LLMInputs): + """ + The inputs in :class:`~vllm.LLMEngine` before they are + passed to the model executor. + + This specifies the required data for encoder-decoder models. + """ + encoder_prompt_token_ids: List[int] """The token IDs of the encoder prompt.""" encoder_prompt: NotRequired[Optional[str]] @@ -222,20 +139,40 @@ class LLMInputs(TypedDict): available. """ - multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] - """ - Optional multi-modal data to pass to the model, - if the model supports it. - """ + +_T1 = TypeVar("_T1", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) +_T2 = TypeVar("_T2", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) -def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool: +def build_explicit_enc_dec_prompt( + encoder_prompt: _T1, + decoder_prompt: Optional[_T2], +) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: + return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt) + + +def zip_enc_dec_prompts( + enc_prompts: Iterable[_T1], + dec_prompts: Iterable[Optional[_T2]], +) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ - Return True if the LLMInputs instance has the correct configuration - for encoder/decoder. + Zip encoder and decoder prompts together into a list of + :class:`ExplicitEncoderDecoderPrompt` instances. 
""" + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) + for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) + ] + - # True if encoder prompt token ids field exists & - # is not None - return ('encoder_prompt_token_ids' in inputs - and inputs['encoder_prompt_token_ids'] is not None) +def to_enc_dec_tuple_list( + enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], +) -> List[Tuple[_T1, Optional[_T2]]]: + return [(enc_dec_prompt["encoder_prompt"], + enc_dec_prompt["decoder_prompt"]) + for enc_dec_prompt in enc_dec_prompts] diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py new file mode 100644 index 000000000000..b5e8ef786059 --- /dev/null +++ b/vllm/inputs/parse.py @@ -0,0 +1,75 @@ +from typing import List, Literal, Sequence, TypedDict, Union, overload + +from typing_extensions import TypeIs + +from vllm.utils import is_list_of + +from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, + LLMInputs, PromptInputs) + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: List[int] + is_tokens: Literal[True] + + +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt, str): + # case 2: array of strings + return [ + ParsedText(content=elem, is_tokens=False) for elem in prompt + ] + if is_list_of(prompt, int): + # case 3: array of tokens + return [ParsedTokens(content=prompt, is_tokens=True)] + if is_list_of(prompt, list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt[0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in prompt + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +def is_explicit_encoder_decoder_prompt( + inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(inputs, dict) and "encoder_prompt" in inputs + + +def is_valid_encoder_decoder_llm_inputs( + inputs: Union[LLMInputs, EncoderDecoderLLMInputs], +) -> TypeIs[EncoderDecoderLLMInputs]: + return "encoder_prompt_token_ids" in inputs diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 6fdacd446978..db0d6b429d64 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,7 +1,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, Union, overload, runtime_checkable) -from typing_extensions import TypeGuard +from typing_extensions import TypeIs from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger @@ -37,18 +37,18 @@ def __call__(self, *, multimodal_config: MultiModalConfig) -> None: @overload -def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]: +def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]: 
... @overload -def supports_vision(model: object) -> TypeGuard[SupportsVision]: +def supports_vision(model: object) -> TypeIs[SupportsVision]: ... def supports_vision( model: Union[Type[object], object], -) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]: +) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]: if isinstance(model, type): return isinstance(model, _SupportsVisionType) @@ -94,18 +94,18 @@ def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: @overload -def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]: +def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]: ... @overload -def supports_lora(model: object) -> TypeGuard[SupportsLoRA]: +def supports_lora(model: object) -> TypeIs[SupportsLoRA]: ... def supports_lora( model: Union[Type[object], object], -) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: +) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: result = _supports_lora(model) if not result: @@ -137,7 +137,7 @@ def supports_lora( def _supports_lora( model: Union[Type[object], object], -) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]: +) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: if isinstance(model, type): return isinstance(model, _SupportsLoRAType) @@ -172,18 +172,18 @@ def __init__(self, @overload -def has_inner_state(model: object) -> TypeGuard[HasInnerState]: +def has_inner_state(model: object) -> TypeIs[HasInnerState]: ... @overload -def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]: +def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]: ... def has_inner_state( model: Union[Type[object], object] -) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]: +) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]: if isinstance(model, type): return isinstance(model, _HasInnerStateType) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index b6a3909e9563..db50229bda31 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -10,6 +10,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils import is_list_of from .base import MultiModalInputs, MultiModalPlugin @@ -113,7 +114,8 @@ def _get_hf_image_processor(self, model_config: ModelConfig): def _default_input_mapper(self, ctx: InputContext, data: object) -> MultiModalInputs: model_config = ctx.model_config - if isinstance(data, (Image.Image, list)): + + if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " @@ -127,7 +129,7 @@ def _default_input_mapper(self, ctx: InputContext, raise return MultiModalInputs(batch_data) - elif isinstance(data, torch.Tensor): + elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): raise NotImplementedError("Embeddings input is not supported yet") raise TypeError(f"Invalid image type: {type(data)}") diff --git a/vllm/sequence.py b/vllm/sequence.py index 634785533382..fbd148001cc7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -11,7 +11,7 @@ import torch -from vllm.inputs import is_valid_encoder_decoder_llm_inputs +from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest from 
vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest diff --git a/vllm/utils.py b/vllm/utils.py index 4137aaec8a93..f8251284af4a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -17,8 +17,8 @@ from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, - Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, - Union, overload) + Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, + Type, TypeVar, Union, overload) from uuid import uuid4 import numpy as np @@ -26,12 +26,10 @@ import psutil import torch import torch.types -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs, - SingletonPromptInputs) from vllm.logger import enable_trace_function_call, init_logger logger = init_logger(__name__) @@ -812,6 +810,24 @@ def get_dtype_size(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() +# `collections` helpers +def is_list_of( + value: object, + typ: Type[T], + *, + check: Literal["first", "all"] = "first", +) -> TypeIs[List[T]]: + if not isinstance(value, list): + return False + + if check == "first": + return len(value) == 0 or isinstance(value[0], typ) + elif check == "all": + return all(isinstance(v, typ) for v in value) + + assert_never(check) + + def merge_dicts(dict1: Dict[K, List[T]], dict2: Dict[K, List[T]]) -> Dict[K, List[T]]: """Merge 2 dicts that have key -> List of items. @@ -959,6 +975,7 @@ def enable_trace_function_call_for_thread() -> None: enable_trace_function_call(log_path) +# `functools` helpers def identity(value: T) -> T: return value @@ -1080,50 +1097,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) - - -def is_encoder_decoder_model_config(model_config) -> bool: - ''' - Extract the HF encoder/decoder model flag from the ModelConfig instance. - Return False if model_config is None. - ''' - return model_config is not None and \ - getattr(model_config.hf_config, - "is_encoder_decoder", - False) - - -def is_embedding_model_config(model_config) -> bool: - ''' - Extract the embedding model flag from the ModelConfig instance. - Return False if model_config is None. 
- ''' - return model_config is not None and \ - model_config.embedding_mode - - -def build_explicit_enc_dec_prompt( - encoder_prompt: SingletonPromptInputs, - decoder_prompt: SingletonPromptInputs, -) -> ExplicitEncoderDecoderPrompt: - return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, - decoder_prompt=decoder_prompt) - - -def zip_enc_dec_prompt_lists( - enc_prompt_list: List[SingletonPromptInputs], - dec_prompt_list: List[SingletonPromptInputs], -) -> List[ExplicitEncoderDecoderPrompt]: - return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) - for (encoder_prompt, - decoder_prompt) in zip(enc_prompt_list, dec_prompt_list) - ] - - -def to_enc_dec_tuple_list( - enc_dec_prompts: List[ExplicitEncoderDecoderPrompt], -) -> List[Tuple[PromptInputs, PromptInputs]]: - return [(enc_dec_prompt['encoder_prompt'], - enc_dec_prompt['decoder_prompt']) - for enc_dec_prompt in enc_dec_prompts] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ad6f6750ff98..45751eceacbc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -19,8 +19,6 @@ from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (is_embedding_model_config, - is_encoder_decoder_model_config) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner @@ -113,10 +111,10 @@ def __init__( self.gpu_cache: Optional[List[List[torch.Tensor]]] = None def _is_encoder_decoder_model(self): - return is_encoder_decoder_model_config(self.model_config) + return self.model_config.is_encoder_decoder_model def _is_embedding_model(self): - return is_embedding_model_config(self.model_config) + return self.model_config.is_embedding_model def init_device(self) -> None: if self.device_config.device.type == "cuda":
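Finally, the `is_list_of` helper added to `vllm/utils.py` above provides a `TypeIs`-based narrowing check that the multimodal image plugin and `vllm.inputs.parse` now rely on. A small usage sketch, assuming this diff is applied; the sample lists are made up:

    # Usage sketch for the new vllm.utils.is_list_of helper; sample data is made up.
    from vllm.utils import is_list_of

    tokens = [1, 2, 3]
    mixed = [1, "two", 3]

    assert is_list_of(tokens, int)                   # default check="first" inspects only element 0
    assert is_list_of(mixed, int)                    # still True: the first element is an int
    assert not is_list_of(mixed, int, check="all")   # False once every element is inspected

The default check="first" keeps the helper cheap for long token lists at the cost of only sampling the first element; callers that need a strict guarantee can pass check="all".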