From 200c0babc32bb992437b2e9560af9afbd25c8fa5 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 26 Sep 2024 00:36:47 +0800 Subject: [PATCH] rename PromptInputs and inputs with backward compatibility (#8760) --- benchmarks/benchmark_latency.py | 8 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- tests/async_engine/test_async_llm_engine.py | 8 +- tests/entrypoints/llm/test_encode.py | 34 ------ tests/entrypoints/llm/test_generate.py | 37 ------ tests/mq_llm_engine/test_error_handling.py | 12 +- tests/mq_llm_engine/utils.py | 2 +- vllm/__init__.py | 4 +- vllm/engine/async_llm_engine.py | 110 +++++++++++++++--- vllm/engine/llm_engine.py | 52 +++++++-- vllm/engine/multiprocessing/__init__.py | 61 +++++++++- vllm/engine/multiprocessing/client.py | 95 ++++++++++++--- vllm/engine/multiprocessing/engine.py | 2 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 68 +++++------ vllm/inputs/__init__.py | 20 +++- vllm/inputs/data.py | 48 +++++--- vllm/inputs/parse.py | 22 ++-- vllm/inputs/preprocess.py | 86 +++++++------- 21 files changed, 438 insertions(+), 245 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a39d1cf842f06..eadf994cacd34 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 241b2ccd0991e..e112b43aade5e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 9adf82d43f3e0..0d47281db485e 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptInputs +.. autodata:: vllm.inputs.PromptType .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 08db891665044..ca5b125369c85 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 6cae76f74603d..1903a7582dc89 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): + params = SamplingParams() + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 - await engine.add_request("1", "", None) + await engine.add_request("1", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 1 assert engine.engine.step_calls == 1 - await engine.add_request("2", "", None) + await engine.add_request("2", "", params) engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) @@ -111,7 +113,7 @@ async def test_new_requests_event(): await asyncio.sleep(0.001) assert engine.engine.step_calls == old_step_calls - await engine.add_request("3", "", None) + await engine.add_request("3", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index d1056a0490509..1885f2e168d80 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput], assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) - - v2_output = llm.encode(prompt, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) - - v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode( - [{ - "prompt": p - } for p in PROMPTS], - pooling_params=pooling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): pooling_params = PoolingParams() diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index cd989225e2483..6543c4bb1b58e 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=prompt, - sampling_params=sampling_params) - - v2_output = llm.generate(prompt, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate({"prompt": prompt}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=PROMPTS, - sampling_params=sampling_params) - - v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate( - [{ - "prompt": p - } for p in PROMPTS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 76b2f494d5b25..616a15a1328de 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket): # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - inputs="Hello my name is", + prompt="Hello my name is", sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass @@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. - async for _ in client.generate(inputs="Hello my name is", + async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index e27fd77923412..3ffa126070ca0 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - inputs="Hello my name is Robert and", + prompt="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/vllm/__init__.py b/vllm/__init__.py index 90363b3e49b73..8f477ea84756d 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version_tuple__", "LLM", "ModelRegistry", - "PromptInputs", + "PromptType", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 34e7e05341f02..54c5af2fe3665 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import time import weakref from functools import partial -from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, - Mapping, Optional, Set, Tuple, Type, Union) +from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, + List, Mapping, Optional, Set, Tuple, Type, Union, overload) from weakref import ReferenceType import vllm.envs as envs @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -28,7 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import weak_bind +from vllm.utils import deprecate_kwargs, weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -402,17 +402,54 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() + @overload # DEPRECATED async def add_request_async( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload + async def add_request_async( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request_async( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Async version of :meth:`add_request`.""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -420,7 +457,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -774,16 +811,55 @@ async def run_engine_loop(engine_ref: ReferenceType): # This method does not need to be async, but kept that way # for backwards compatibility. - async def add_request( + @overload # DEPRECATED + def add_request( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... + + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -797,7 +873,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - inputs=inputs, + prompt=prompt, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +884,7 @@ async def add_request( async def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,8 +898,7 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -881,7 +956,7 @@ async def generate( """ async for output in await self.add_request( request_id, - inputs, + prompt, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -891,7 +966,7 @@ async def generate( async def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -904,8 +979,7 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -959,7 +1033,7 @@ async def encode( """ async for output in await self.add_request( request_id, - inputs, + prompt, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c341b236003a3..7266d8e18a8ab 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union +from typing import Set, Type, Union, overload import torch from typing_extensions import TypeVar @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs) + InputRegistry, LLMInputs, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -51,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device, weak_bind +from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -689,16 +689,51 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() + @overload # DEPRECATED def add_request( self, request_id: str, - inputs: PromptInputs, + *, + inputs: PromptType, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + ) -> None: + ... + + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Add a request to the engine's request pool. @@ -708,8 +743,7 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. @@ -744,6 +778,10 @@ def add_request( >>> # continue the request processing >>> ... """ + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -756,7 +794,7 @@ def add_request( arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 165e6cc2146c3..05067a6a192d5 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,13 +1,14 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Mapping, Optional, Union +from typing import List, Mapping, Optional, Union, overload from vllm import PoolingParams -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.utils import deprecate_kwargs VLLM_RPC_SUCCESS_STR = "SUCCESS" @@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - inputs: PromptInputs + prompt: PromptType params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None + @overload # DEPRECATED + def __init__( + self, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def __init__( + self, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + if inputs is not None: + prompt = inputs + assert (prompt is not None and params is not None + and request_id is not None) + + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + @dataclass class RPCError: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 7e397cf408fba..239ca52ef13e2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union) + Union, overload) import cloudpickle import zmq @@ -24,13 +24,14 @@ RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import deprecate_kwargs logger = init_logger(__name__) @@ -366,14 +367,45 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return ENGINE_DEAD_ERROR(self._errored_with) + @overload # DEPRECATED def generate( self, - inputs: PromptInputs, + *, + inputs: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @overload + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def generate( + self, + prompt: Optional[PromptType] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. @@ -382,8 +414,7 @@ def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -392,17 +423,51 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(inputs, sampling_params, request_id, + if inputs is not None: + prompt = inputs + assert (prompt is not None and sampling_params is not None + and request_id is not None) + + return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) + @overload # DEPRECATED def encode( self, - inputs: PromptInputs, + *, + inputs: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @overload + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def encode( + self, + prompt: Optional[PromptType] = None, + pooling_params: Optional[PoolingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model. @@ -411,8 +476,7 @@ def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. See - :class:`~vllm.inputs.PromptInputs` + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -423,12 +487,17 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. """ - return self._process_request(inputs, pooling_params, request_id, + if inputs is not None: + prompt = inputs + assert (prompt is not None and pooling_params is not None + and request_id is not None) + + return self._process_request(prompt, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -461,7 +530,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - inputs=inputs, + prompt=prompt, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index b1dd9915cbbf5..b406d4a759667 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -271,7 +271,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - inputs=request.inputs, + prompt=request.prompt, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 70444faa670a2..d0bbeb357b506 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptInputs +from vllm.inputs.data import PromptType from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - inputs: PromptInputs, + prompt: PromptType, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generates outputs for a request""" + """Generate outputs for a request.""" ... def encode( self, - inputs: PromptInputs, + prompt: PromptType, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 77ae7b088398a..f4943cb38da44 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -12,7 +12,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -293,8 +293,8 @@ def generate( @overload def generate( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -304,14 +304,13 @@ def generate( ... @deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def generate( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -330,7 +329,9 @@ def generate( into a single list and pass it to this method. Args: - inputs: A list of inputs to generate completions for. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -358,12 +359,13 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -378,7 +380,7 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -648,8 +650,8 @@ def encode( @overload def encode( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], - /, # We may enable `inputs` keyword after removing the old API + prompts: Union[PromptType, Sequence[PromptType]], + /, *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -659,14 +661,13 @@ def encode( ... @deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def encode( self, - prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[Union[PromptType, Sequence[PromptType]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -682,9 +683,9 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptInputs` - for more details about the format of each input. + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. @@ -707,19 +708,20 @@ def encode( ) if prompt_token_ids is not None: - inputs = self._convert_v1_inputs( + parsed_prompts = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - inputs=inputs, + prompts=parsed_prompts, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -763,9 +765,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - inputs: List[PromptInputs] = [] + parsed_prompts: List[PromptType] = [] for i in range(num_requests): - item: PromptInputs + item: PromptType if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -774,13 +776,13 @@ def _convert_v1_inputs( else: raise AssertionError - inputs.append(item) + parsed_prompts.append(item) - return inputs + return parsed_prompts def _validate_and_add_requests( self, - inputs: Union[PromptInputs, Sequence[PromptInputs]], + prompts: Union[PromptType, Sequence[PromptType]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], @@ -788,11 +790,11 @@ def _validate_and_add_requests( guided_options: Optional[GuidedDecodingRequest] = None, priority: Optional[List[int]] = None, ) -> None: - if isinstance(inputs, (str, dict)): + if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - inputs = [inputs] + prompts = [prompts] - num_requests = len(inputs) + num_requests = len(prompts) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -809,9 +811,9 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. - for i, request_inputs in enumerate(inputs): + for i, prompt in enumerate(prompts): self._add_request( - request_inputs, + prompt, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, @@ -821,7 +823,7 @@ def _validate_and_add_requests( def _add_request( self, - inputs: PromptInputs, + prompt: PromptType, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -830,7 +832,7 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - inputs, + prompt, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0b08e9691f915..a8c8672cb5fe7 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptInputs", - "SingletonPromptInputs", + "PromptType", + "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", @@ -28,3 +28,17 @@ "InputContext", "InputRegistry", ] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155b..9e6238cb85ac0 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] +SingletonPrompt = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptInputs` may be employed +A prompt of type :class:`SingletonPrompt` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,33 +55,33 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPromptInputs, - default=SingletonPromptInputs, + bound=SingletonPrompt, + default=SingletonPrompt, covariant=True) # TODO: Make fields ReadOnly once mypy supports it class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): - """Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a - decoder prompt. + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptInputs` schemas, and are not + :class:`SingletonPrompt` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. Note that an :class:`ExplicitEncoderDecoderPrompt` may not be used as an input to a decoder-only model, - and that the `encoder_prompt` and `decoder_prompt` + and that the :code:`encoder_prompt` and :code:`decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptInputs` instances. + :class:`SingletonPrompt` instances. """ encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) -_T2 = TypeVar("_T2", - bound=SingletonPromptInputs, - default=SingletonPromptInputs) +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) def build_explicit_enc_dec_prompt( @@ -176,3 +172,17 @@ def to_enc_dec_tuple_list( return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index ac9d355c64c80..e5fa1e4184277 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(inputs, str): - return ParsedStrPrompt(type="str", content=inputs) - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if "prompt_token_ids" in prompt: return ParsedTokensPrompt(type="tokens", - content=inputs) # type: ignore - elif "prompt" in inputs: - return ParsedTextPrompt(type="text", content=inputs) + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(inputs, dict) and "encoder_prompt" in inputs + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt def is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index be2aa5f8cb7d0..1f1b048d37e9b 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, + SingletonPrompt) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -206,7 +206,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -216,7 +216,7 @@ def _extract_prompt_components( Arguments: * request_id - * inputs: single encoder or decoder input prompt + * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -226,24 +226,24 @@ def _extract_prompt_components( * multi_modal_data ''' - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -251,33 +251,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(inputs) + parsed = parse_singleton_prompt(prompt) if parsed["type"] == "str": - prompt = parsed["content"] + prompt_text = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt = None + prompt_text = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt = parsed["content"]["prompt"] + prompt_text = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt, + prompt_text, request_id=request_id, lora_request=lora_request, ) @@ -285,7 +285,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt, prompt_token_ids, multi_modal_data + return prompt_text, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -311,7 +311,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: ''' @@ -339,7 +339,7 @@ def _process_encoder_decoder_prompt( Arguments: - * inputs: an input prompt + * prompt: an input prompt * request_id Returns: @@ -350,13 +350,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -365,7 +365,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, ) @@ -375,20 +375,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], + prompt["encoder_prompt"], request_id=request_id, ) - if (decoder_input := inputs["decoder_prompt"]) is None: + if (decoder_input := prompt["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -401,7 +401,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, ) @@ -425,7 +425,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -436,7 +436,7 @@ def _process_decoder_only_prompt( Arguments: - * inputs: input prompt + * prompt: input prompt * request_id * lora_request * prompt_adapter_request @@ -447,7 +447,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -459,14 +459,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - inputs: SingletonPromptInputs, + prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -478,7 +478,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -488,17 +488,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -506,7 +506,7 @@ def preprocess( async def preprocess_async( self, - inputs: PromptInputs, + prompt: PromptType, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -516,17 +516,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - inputs, + prompt, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(inputs): + if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - inputs, + prompt, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request,