From e729f1d6147f1fb700aa0e10cdf7592919282c94 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 24 Sep 2024 04:10:36 +0000 Subject: [PATCH] Add backwards compatibility for #8673 --- tests/async_engine/test_async_llm_engine.py | 8 +- tests/entrypoints/llm/test_encode.py | 34 -------- tests/entrypoints/llm/test_generate.py | 37 --------- vllm/engine/async_llm_engine.py | 86 +++++++++++++++++++-- vllm/engine/llm_engine.py | 41 +++++++++- vllm/engine/multiprocessing/__init__.py | 57 +++++++++++++- vllm/engine/multiprocessing/client.py | 75 +++++++++++++++++- vllm/entrypoints/llm.py | 6 +- 8 files changed, 256 insertions(+), 88 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 6cae76f74603d..1903a7582dc89 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): + params = SamplingParams() + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 - await engine.add_request("1", "", None) + await engine.add_request("1", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 1 assert engine.engine.step_calls == 1 - await engine.add_request("2", "", None) + await engine.add_request("2", "", params) engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) @@ -111,7 +113,7 @@ async def test_new_requests_event(): await asyncio.sleep(0.001) assert engine.engine.step_calls == old_step_calls - await engine.add_request("3", "", None) + await engine.add_request("3", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index d1056a0490509..1885f2e168d80 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput], assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) - - v2_output = llm.encode(prompt, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) - - v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode( - [{ - "prompt": p - } for p in PROMPTS], - pooling_params=pooling_params, - ) - 
assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): pooling_params = PoolingParams() diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index cd989225e2483..6543c4bb1b58e 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=prompt, - sampling_params=sampling_params) - - v2_output = llm.generate(prompt, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate({"prompt": prompt}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=PROMPTS, - sampling_params=sampling_params) - - v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate( - [{ - "prompt": p - } for p in PROMPTS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f108751056ab5..54c5af2fe3665 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import time import weakref from functools import partial -from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, - Mapping, Optional, Set, Tuple, Type, Union) +from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, + List, Mapping, Optional, Set, Tuple, Type, Union, overload) from weakref import ReferenceType import vllm.envs as envs @@ -28,7 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import weak_bind +from vllm.utils import deprecate_kwargs, weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -402,6 +402,21 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() + @overload # DEPRECATED + async def add_request_async( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + 
trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload async def add_request_async( self, request_id: str, @@ -411,8 +426,30 @@ async def add_request_async( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request_async( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Async version of :meth:`add_request`.""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") @@ -774,7 +811,23 @@ async def run_engine_loop(engine_ref: ReferenceType): # This method does not need to be async, but kept that way # for backwards compatibility. - async def add_request( + @overload # DEPRECATED + def add_request( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... + + @overload + def add_request( self, request_id: str, prompt: PromptType, @@ -782,8 +835,31 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... 
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if not self.is_running: if self.start_engine_loop: self.start_background_loop() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1e77a01bfa9d9..3c4f6a4ab272e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union +from typing import Set, Type, Union, overload import torch from typing_extensions import TypeVar @@ -51,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device, weak_bind +from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -686,6 +686,21 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() + @overload # DEPRECATED + def add_request( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload def add_request( self, request_id: str, @@ -695,6 +710,24 @@ def add_request( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: """Add a request to the engine's request pool. @@ -737,6 +770,10 @@ def add_request( >>> # continue the request processing >>> ... 
""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 09aa279f1e22c..48f2c46b7114d 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Mapping, Optional, Union +from typing import List, Mapping, Optional, Union, overload from vllm import PoolingParams from vllm.inputs import PromptType @@ -8,6 +8,7 @@ from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.utils import deprecate_kwargs VLLM_RPC_SUCCESS_STR = "SUCCESS" @@ -30,6 +31,60 @@ class RPCProcessRequest: trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None + @overload # DEPRECATED + def __init__( + self, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @overload + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def __init__( + self, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + if inputs is not None: + prompt = inputs + assert (prompt is not None and params is not None + and request_id is not None) + + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + @dataclass class RPCError: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 71099115ea125..57cad168704a3 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, - Union) + Union, overload) import cloudpickle import zmq @@ -32,6 +32,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import deprecate_kwargs logger = init_logger(__name__) @@ -373,6 +374,20 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return ENGINE_DEAD_ERROR(self._errored_with) + @overload # DEPRECATED + def generate( + self, + *, + inputs: PromptType, + 
sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @overload def generate( self, prompt: PromptType, @@ -380,7 +395,24 @@ def generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def generate( + self, + prompt: Optional[PromptType] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. @@ -398,10 +430,28 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ + if inputs is not None: + prompt = inputs + assert (prompt is not None and sampling_params is not None + and request_id is not None) + return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) + @overload # DEPRECATED + def encode( + self, + *, + inputs: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @overload def encode( self, prompt: PromptType, @@ -409,6 +459,22 @@ def encode( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def encode( + self, + prompt: Optional[PromptType] = None, + pooling_params: Optional[PoolingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + *, + inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model. @@ -428,6 +494,11 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. """ + if inputs is not None: + prompt = inputs + assert (prompt is not None and pooling_params is not None + and request_id is not None) + return self._process_request(prompt, pooling_params, request_id, lora_request, trace_headers) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63ffde06768cb..766b03fb7953d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -304,10 +304,9 @@ def generate( ... @deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def generate( self, @@ -658,10 +657,9 @@ def encode( ... 
@deprecate_kwargs( - "prompts", "prompt_token_ids", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'inputs' parameter instead.", + additional_message="Please use the 'prompts' parameter instead.", ) def encode( self,
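
Usage sketch (reviewer note, not part of the patch): the pattern added above keeps the pre-#8673 keyword spelling working while steering callers to the new 'prompt' parameter. Assuming the existing vllm.utils.deprecate_kwargs helper emits a DeprecationWarning as the removed tests did, the two call styles behave as below; RPCProcessRequest is used because it can be constructed without a running engine, and the request IDs are made up for illustration.

    import warnings

    from vllm import SamplingParams
    from vllm.engine.multiprocessing import RPCProcessRequest

    params = SamplingParams(temperature=0.0)

    # New spelling from #8673: the prompt is passed positionally (or as 'prompt=').
    new_style = RPCProcessRequest("Hello, world", params, "req-1")

    # Old spelling kept working by this patch: 'inputs=' is still accepted, but the
    # @deprecate_kwargs decorator warns and points callers at 'prompt'.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_style = RPCProcessRequest(inputs="Hello, world",
                                      params=params,
                                      request_id="req-2")

    # Both spellings end up populating the same 'prompt' attribute.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert old_style.prompt == new_style.prompt == "Hello, world"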