diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index def8a460e84a..08e132d0c68b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -63,9 +63,9 @@ steps: mirror_hardwares: [amd] commands: - # these tests have to be separated, because each one will allocate all posible GPU memory - - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - - pytest -v -s entrypoints/test_server_oot_registration.py + - pytest -v -s test_inputs.py + - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m openai - label: Examples Test working_dir: "/vllm-workspace/examples" @@ -110,6 +110,9 @@ steps: mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py +- label: Utils Test + command: pytest -v -s test_utils.py + - label: Worker Test mirror_hardwares: [amd] command: pytest -v -s worker diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a9657f785975..3146fb33cc27 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -3,13 +3,14 @@ import json import time from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.inputs import PromptStrictInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -48,7 +49,9 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() + dummy_inputs: List[PromptStrictInputs] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -59,13 +62,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/offline_inference/llm.rst b/docs/source/dev/offline_inference/llm.rst similarity index 86% rename from docs/source/offline_inference/llm.rst rename to docs/source/dev/offline_inference/llm.rst index 1a443ea40699..83ba1b6987c6 100644 --- a/docs/source/offline_inference/llm.rst +++ b/docs/source/dev/offline_inference/llm.rst @@ -1,5 +1,5 @@ LLM Class -========== +========= .. autoclass:: vllm.LLM :members: diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst new file mode 100644 index 000000000000..31c3d16a3c8e --- /dev/null +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -0,0 +1,14 @@ +LLM Inputs +========== + +.. autodata:: vllm.inputs.PromptStrictInputs + +.. autoclass:: vllm.inputs.TextPrompt + :show-inheritance: + :members: + :member-order: bysource + +.. 
autoclass:: vllm.inputs.TokensPrompt + :show-inheritance: + :members: + :member-order: bysource diff --git a/docs/source/dev/offline_inference/offline_index.rst b/docs/source/dev/offline_inference/offline_index.rst new file mode 100644 index 000000000000..27dfb0e9df90 --- /dev/null +++ b/docs/source/dev/offline_inference/offline_index.rst @@ -0,0 +1,8 @@ +Offline Inference +================================= + +.. toctree:: + :maxdepth: 1 + + llm + llm_inputs diff --git a/docs/source/offline_inference/sampling_params.rst b/docs/source/dev/sampling_params.rst similarity index 100% rename from docs/source/offline_inference/sampling_params.rst rename to docs/source/dev/sampling_params.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 5db1c9346c45..5f18fe9ae0a7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -68,13 +68,6 @@ Documentation getting_started/quickstart getting_started/examples/examples_index -.. toctree:: - :maxdepth: 1 - :caption: Offline Inference - - offline_inference/llm - offline_inference/sampling_params - .. toctree:: :maxdepth: 1 :caption: Serving @@ -108,7 +101,9 @@ Documentation .. toctree:: :maxdepth: 2 :caption: Developer Documentation - + + dev/sampling_params + dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention dev/dockerfile/dockerfile diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a775c6addf1d..15a8761eb573 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -48,7 +48,7 @@ completion = client.chat.completions.create( ``` ### Extra Parameters for Chat API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -65,7 +65,7 @@ The following extra parameters are supported: ``` ### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python diff --git a/examples/llava_example.py b/examples/llava_example.py index 3d22b492654b..60250c4303fb 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -23,11 +23,15 @@ def run_llava_pixel_values(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_pixel_values.pt") + image = torch.load("images/stop_sign_pixel_values.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) for o in outputs: generated_text = o.outputs[0].text print(generated_text) @@ -46,11 +50,14 @@ def run_llava_image_features(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. 
- images = torch.load("images/stop_sign_image_features.pt") - - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + image = torch.load("images/stop_sign_image_features.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/pyproject.toml b/pyproject.toml index 96f78c37cfef..0e9096fb4c03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" [tool.isort] use_parentheses = true skip_gitignore = true + +[tool.pytest.ini_options] +markers = [ + "skip_global_cleanup", + "llm: run tests for vLLM API only", + "openai: run tests for OpenAI API only", +] diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index b69cdc0a2140..10a46422887e 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,7 +25,7 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] - async def encode_request_async(self, *args, **kwargs): + async def process_model_inputs_async(self, *args, **kwargs): pass def generate(self, request_id): diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index ace4c53916c7..7a8d4b391561 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -29,7 +29,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", diff --git a/tests/conftest.py b/tests/conftest.py index c1a44a606e1b..af04cfbbb990 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.sequence import MultiModalData @@ -402,12 +403,22 @@ def generate( ) -> List[Tuple[List[int], str]]: if images is not None: assert len(prompts) == images.shape[0] - req_outputs = self.model.generate( - prompts, - sampling_params=sampling_params, - multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) - if images is not None else None) + + prompt_inputs: List[PromptInputs] = [] + for i, prompt in enumerate(prompts): + image = None if images is None else images[i:i + 1] + mm_data = None if image is None else MultiModalData( + type=MultiModalData.Type.IMAGE, + data=image, + ) + + prompt_inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 22a9f0cf47d3..88cd4f98091f 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -133,8 +133,11 @@ def test_append_slot_cow(): # Allocate prompt to gpu block. There is one slot left in the block. 
prompt = Sequence(seq_id=1, - prompt="one two three", - prompt_token_ids=[1, 2, 3], + inputs={ + "prompt": "one two three", + "prompt_token_ids": [1, 2, 3], + "multi_modal_data": None + }, block_size=block_size) # Fork the sequence, such that a COW will be required when we append a new @@ -304,7 +307,13 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - parent = Sequence(1, "one two three", [0, 1, 2], block_size) + parent = Sequence(seq_id=1, + inputs={ + "prompt": "one two three", + "prompt_token_ids": [0, 1, 2], + "multi_modal_data": None + }, + block_size=block_size) seq_group = SequenceGroup(request_id="1", seqs=[parent], arrival_time=time.time(), diff --git a/tests/core/utils.py b/tests/core/utils.py index 8fb13177a2d6..1c5724090b69 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -21,7 +21,13 @@ def create_dummy_prompt( # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + prompt = Sequence(int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[prompt], arrival_time=time.time(), @@ -51,8 +57,11 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index baa463a31690..338b208723ba 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str): with pytest.raises(ValueError) as err: llm.generate("abc", sampling_params) assert "prompts must be None if" in str(err.value) - outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 completions = outputs[0].outputs diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 74b49726734b..c45f02fe564a 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,11 +1,15 @@ import asyncio from dataclasses import dataclass +import pytest + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" +pytestmark = pytest.mark.openai + @dataclass class MockModelConfig: diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 41c871ca40bc..5d4163e96fd8 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -52,6 +52,8 @@ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +pytestmark = pytest.mark.openai + def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py new file mode 100644 index 000000000000..7c3fbe43a838 --- /dev/null +++ 
b/tests/entrypoints/test_llm_encode.py @@ -0,0 +1,144 @@ +import weakref +from typing import List + +import pytest + +from vllm import LLM, EmbeddingRequestOutput, PoolingParams + +from ..conftest import cleanup + +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +TOKEN_IDS = [ + # Using ID={0, 1, 2, 3} results in NaN values, + # so we add this offset of 1000 + [1000], + [1000, 1001], + [1000, 1002, 1001], + [1000, 1003, 1001, 1002], +] + +pytestmark = pytest.mark.llm + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[EmbeddingRequestOutput], + o2: List[EmbeddingRequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) + + v2_output = llm.encode(prompt, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=prompt_token_ids, + pooling_params=pooling_params) + + v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) + + v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode( + [{ + "prompt": p + } for p in PROMPTS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, + pooling_params=pooling_params) + + v2_output = llm.encode( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_multiple_pooling_params(llm: LLM): + pooling_params = [ + PoolingParams(), + PoolingParams(), + PoolingParams(), + PoolingParams(), + ] + + # Multiple PoolingParams should be matched with each prompt + outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + assert len(PROMPTS) == 
len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + + # Single PoolingParams should be applied to every prompt + single_pooling_params = PoolingParams() + outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + assert len(PROMPTS) == len(outputs) + + # pooling_params is None, default params should be applied + outputs = llm.encode(PROMPTS, pooling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 5e8b7ca4d997..a00fff91a310 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -1,21 +1,124 @@ +import weakref +from typing import List + import pytest -from vllm import LLM, SamplingParams +from vllm import LLM, RequestOutput, SamplingParams + +from ..conftest import cleanup + +MODEL_NAME = "facebook/opt-125m" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +TOKEN_IDS = [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], +] -def test_multiple_sampling_params(): +pytestmark = pytest.mark.llm - llm = LLM(model="facebook/opt-125m", + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, max_num_batched_tokens=4096, - tensor_parallel_size=1) + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=prompt, + sampling_params=sampling_params) + + v2_output = llm.generate(prompt, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate({"prompt": prompt}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) + + v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=PROMPTS, + sampling_params=sampling_params) + + v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate( + [{ + "prompt": p + } for p in PROMPTS], + sampling_params=sampling_params, + ) + 
assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, + sampling_params=sampling_params) + + v2_output = llm.generate( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] +@pytest.mark.skip_global_cleanup +def test_multiple_sampling_params(llm: LLM): sampling_params = [ SamplingParams(temperature=0.01, top_p=0.95), SamplingParams(temperature=0.3, top_p=0.95), @@ -24,18 +127,18 @@ def test_multiple_sampling_params(): ] # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params) + assert len(PROMPTS) == len(outputs) # Exception raised, if the size of params does not match the size of prompts with pytest.raises(ValueError): - outputs = llm.generate(prompts, sampling_params=sampling_params[:3]) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3]) # Single SamplingParams should be applied to every prompt single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) - outputs = llm.generate(prompts, sampling_params=single_sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params) + assert len(PROMPTS) == len(outputs) # sampling_params is None, default params should be applied - outputs = llm.generate(prompts, sampling_params=None) - assert len(prompts) == len(outputs) \ No newline at end of file + outputs = llm.generate(PROMPTS, sampling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1b04e3205c4b..2463ccde2bc8 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -71,7 +71,7 @@ "Swift", "Kotlin" ] -pytestmark = pytest.mark.asyncio +pytestmark = pytest.mark.openai @pytest.fixture(scope="session") @@ -91,6 +91,8 @@ def server(zephyr_lora_files): "--max-model-len", "8192", "--enforce-eager", + "--gpu-memory-utilization", + "0.75", # lora config below "--enable-lora", "--lora-modules", @@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files): # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", + "--enforce-eager", + "--gpu-memory-utilization", + "0.75", "--max-model-len", "8192", - "--enforce-eager", ]) ray.get(server_runner.ready.remote()) yield server_runner @@ -136,6 +140,7 @@ def client(): yield client +@pytest.mark.asyncio async def test_check_models(server, client: openai.AsyncOpenAI): models = await client.models.list() models = models.data @@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI): assert lora_models[1].id == "zephyr-lora2" +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio @pytest.mark.parametrize( # first 
test base model, then test loras "model_name", @@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, assert choice.logprobs.top_logprobs is None +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, model_name: str): @@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, assert texts[0] == texts[1] +@pytest.mark.asyncio async def test_logits_bias(server, client: openai.AsyncOpenAI): prompt = "Hello, my name is" max_tokens = 5 @@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_completion(server, client: openai.AsyncOpenAI, @@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI, jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_chat(server, client: openai.AsyncOpenAI, @@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI, assert json1["age"] != json2["age"] +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, @@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, @@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, assert ip1 != ip2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, @@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, assert completion.choices[i].text in TEST_CHOICE +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, @@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, assert choice1 != choice2 
+@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, @@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, @@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, for token, logprob in token_dict.items()) +@pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( @@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +@pytest.mark.asyncio async def test_extra_fields(server, client: openai.AsyncOpenAI): with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( @@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): assert "extra_forbidden" in exc_info.value.message +@pytest.mark.asyncio async def test_complex_message_content(server, client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, @@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +@pytest.mark.asyncio async def test_custom_role(server, client: openai.AsyncOpenAI): # Not sure how the model handles custom roles so we just check that # both string and complex message content are handled in the same way @@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI): assert content1 == content2 +@pytest.mark.asyncio async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement @@ -847,6 +873,7 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, assert len(logprobs.tokens) > 5 +@pytest.mark.asyncio async def test_long_seed(server, client: openai.AsyncOpenAI): for seed in [ torch.iinfo(torch.long).min - 1, @@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], @@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 5 +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 22e65bf7e7da..3e55d7f4297f 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -1,7 +1,7 @@ -import multiprocessing import sys import time +import pytest import torch from openai import OpenAI, OpenAIError @@ -10,6 +10,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import get_open_port +pytestmark = pytest.mark.openai + class 
MyOPTForCausalLM(OPTForCausalLM): @@ -26,15 +28,16 @@ def server_function(port): # register our dummy model ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ - ("--model facebook/opt-125m --dtype" - f" float32 --api-key token-abc123 --port {port}").split() + ("--model facebook/opt-125m --gpu-memory-utilization 0.10 " + f"--dtype float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') def test_oot_registration_for_api_server(): port = get_open_port() - server = multiprocessing.Process(target=server_function, args=(port, )) + ctx = torch.multiprocessing.get_context() + server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( base_url=f"http://localhost:{port}/v1", diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 15189f421a53..4361e5452cdf 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -86,20 +86,18 @@ def generate( def batched_generate( - llm, + llm: vllm.LLM, inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], ): for input in inputs: prompt, sampling_param, lora_req = input - requests_data = llm._validate_and_prepare_requests( + # Add requests to the engine and run the engine + llm._validate_and_add_requests( prompt, sampling_param, lora_request=lora_req, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - llm._add_request(**request_data) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index be4c2ea1b781..0ccbabfff640 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -35,28 +35,25 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[0], + example_prompts[0], params=params_with_logprobs, - prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[1], + example_prompts[1], params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), - prompt_token_ids=None, ) # test grouped requests vllm_model.model._add_request( - prompt=example_prompts[2], + example_prompts[2], params=SamplingParams(max_tokens=max_tokens), - prompt_token_ids=None, ) - outputs = vllm_model.model._run_engine(False) + outputs = vllm_model.model._run_engine(use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index ce4501bbf71e..fef5ff3fb9e8 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -57,11 +57,7 @@ def test_random_sample_with_seed( sampling_params_seed_1, sampling_params_seed_2, ): - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - params=params, - ) + llm._add_request(prompt, params=params) results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3b257ac062f5..97864af88e40 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, 
max_num_seqs: int, for prompt in prompts: hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - tokenizer.tokenizer.eos_token_id, lora_request) + seq = Sequence(seq_id, + inputs={ + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=block_size, + eos_token_id=tokenizer.tokenizer.eos_token_id, + lora_request=lora_request) num_blocks = len(prompt_token_ids) // block_size for idx in range(num_blocks): diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 000000000000..887c7101decd --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,53 @@ +from typing import List + +import pytest + +from vllm.inputs import parse_and_batch_prompt + +STRING_INPUTS = [ + '', + 'foo', + 'foo bar', + 'foo baz bar', + 'foo bar qux baz', +] + +TOKEN_INPUTS = [ + [-1], + [1], + [1, 2], + [1, 3, 4], + [1, 2, 4, 3], +] + +INPUTS_SLICES = [ + slice(None, None, -1), + slice(None, None, 2), + slice(None, None, -2), +] + + +def test_parse_single_batch_empty(): + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([]) + + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([[]]) + + +@pytest.mark.parametrize('string_input', STRING_INPUTS) +def test_parse_single_batch_string_consistent(string_input: str): + assert parse_and_batch_prompt(string_input) \ + == parse_and_batch_prompt([string_input]) + + +@pytest.mark.parametrize('token_input', TOKEN_INPUTS) +def test_parse_single_batch_token_consistent(token_input: List[int]): + assert parse_and_batch_prompt(token_input) \ + == parse_and_batch_prompt([token_input]) + + +@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) +def test_parse_single_batch_string_slice(inputs_slice: slice): + assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ + == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000000..54dc5c6f5bfb --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,63 @@ +import pytest + +from vllm.utils import deprecate_kwargs + +from .utils import error_on_warning + + +def test_deprecate_kwargs_always(): + + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_dynamic(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) diff --git 
a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9bc9becb2a6f..1d4c74d6bd8d 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/utils.py b/tests/utils.py index 689d8c8c5ba8..329842911e15 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,8 @@ import subprocess import sys import time +import warnings +from contextlib import contextmanager import ray import requests @@ -87,3 +89,15 @@ def multi_process_tensor_parallel( ray.get(refs) ray.shutdown() + + +@contextmanager +def error_on_warning(): + """ + Within the scope of this context manager, tests will fail if any warning + is emitted. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error") + + yield diff --git a/vllm/__init__.py b/vllm/__init__.py index 74674ca0d12a..a0e154d24087 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,6 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -16,6 +17,9 @@ __all__ = [ "LLM", "ModelRegistry", + "PromptStrictInputs", + "TextPrompt", + "TokensPrompt", "SamplingParams", "RequestOutput", "CompletionOutput", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a15ed67e332..d4289c715d9e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,12 +12,13 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -244,64 +245,69 @@ async def step_async( return request_outputs - async def encode_request_async( + async def process_model_inputs_async( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = await self.tokenizer.encode_async( + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = await tokenizer.encode_async( request_id=request_id, - prompt=prompt, + prompt=inputs["prompt"], lora_request=lora_request) - return prompt_token_ids + else: + prompt_token_ids = 
inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) async def add_request_async( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = await self.encode_request_async( + + processed_inputs = await self.process_model_inputs_async( + request_id=request_id, inputs=inputs, lora_request=lora_request) + + self._add_processed_request( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - return self.add_request(request_id, - prompt=prompt, - params=params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) async def check_health_async(self) -> None: self.model_executor.check_health() class AsyncLLMEngine: - """An asynchronous wrapper for LLMEngine. + """An asynchronous wrapper for :class:`LLMEngine`. - This class is used to wrap the LLMEngine class to make it asynchronous. It - uses asyncio to create a background loop that keeps processing incoming - requests. The LLMEngine is kicked by the generate method when there - are requests in the waiting queue. The generate method yields the outputs - from the LLMEngine to the caller. + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. - NOTE: For the comprehensive list of arguments, see `LLMEngine`. + NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`. Args: worker_use_ray: Whether to use Ray for model workers. Required for @@ -315,8 +321,8 @@ class AsyncLLMEngine: being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args: Arguments for LLMEngine. - *kwargs: Arguments for LLMEngine. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. 
""" _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine @@ -526,22 +532,26 @@ async def run_engine_loop(self): async def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: - shortened_prompt = prompt - shortened_token_ids = prompt_token_ids - if self.max_log_len is not None: + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + else: + shortened_prompt = inputs.get("prompt") + shortened_token_ids = inputs.get("prompt_token_ids") + + max_log_len = self.max_log_len + if max_log_len is not None: if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:self.max_log_len] + shortened_prompt = shortened_prompt[:max_log_len] if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:self. - max_log_len] + shortened_token_ids = shortened_token_ids[:max_log_len] + logger.info( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " @@ -562,39 +572,33 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await ( - self.engine.encode_request_async.remote( # type: ignore + processed_inputs = await self.engine.process_model_inputs_async \ + .remote( # type: ignore request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request)) + inputs=inputs, + lora_request=lora_request) else: - prompt_token_ids = await self.engine.encode_request_async( + processed_inputs = await self.engine.process_model_inputs_async( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, + inputs=inputs, lora_request=lora_request) stream = self._request_tracker.add_request( request_id, - prompt=prompt, + inputs=processed_inputs, params=params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) return stream async def generate( self, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -603,14 +607,12 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. Yields: The output `RequestOutput` objects from the LLMEngine @@ -659,24 +661,20 @@ async def generate( >>> # Process and return the final output >>> ... 
""" - async for output in self.process_request( + async for output in self._process_request( request_id, - prompt, + inputs, sampling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, RequestOutput) async def encode( self, - prompt: Optional[str], + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model. @@ -685,14 +683,12 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. Yields: The output `EmbeddingRequestOutput` objects from the LLMEngine @@ -739,24 +735,21 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, - prompt, + inputs, pooling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, EmbeddingRequestOutput) - async def process_request( + async def _process_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, + *, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -764,12 +757,10 @@ async def process_request( stream = await self.add_request( request_id, - prompt, + inputs, params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0631c0de7682..08bccf209b7c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,8 @@ import time -from typing import Iterable, List, Optional, Type, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Type, TypeVar, Union from transformers import GenerationConfig, PreTrainedTokenizer @@ -18,6 +21,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -25,8 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import 
(EmbeddingSequenceGroupOutput, ExecuteModelRequest, - MultiModalData, PoolerOutput, SamplerOutput, - Sequence, SequenceGroup, SequenceGroupMetadata, + PoolerOutput, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, @@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig): return {} +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -60,11 +67,11 @@ class LLMEngine: iteration-level scheduling and efficient memory management to maximize the serving throughput. - The `LLM` class wraps this class for offline batched inference and the - `AsyncLLMEngine` class wraps this class for online serving. + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the `EngineArgs` class. For the - comprehensive list of arguments, see `EngineArgs`. + NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs` + class. For the comprehensive list of arguments, see :ref:`engine_args`. Args: model_config: The configuration related to the LLM model. @@ -81,9 +88,60 @@ class LLMEngine: executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection + usage_context: Specified entry point, used for usage info collection. """ + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[BaseTokenizerGroup] + def __init__( self, model_config: ModelConfig, @@ -151,12 +209,11 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - self.tokenizer: BaseTokenizerGroup - self._init_tokenizer() + self.tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(self.tokenizer) else: - self.detokenizer = None self.tokenizer = None + self.detokenizer = None self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( @@ -318,14 +375,26 @@ def __del__(self): if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() + MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + def get_tokenizer_group( + self, + 
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(None) + return self.get_tokenizer_group().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + return self.get_tokenizer_group().get_lora_tokenizer( + sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: init_kwargs = dict( tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), @@ -335,8 +404,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer = get_tokenizer_group( - self.parallel_config.tokenizer_pool_config, **init_kwargs) + + return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, + **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -346,29 +416,85 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) - def encode_request( + def _get_eos_token_id( + self, lora_request: Optional[LoRARequest]) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def _add_processed_request( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + processed_inputs: LLMInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> None: + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self._get_eos_token_id(lora_request) + + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. 
+ self.scheduler.add_seq_group(seq_group) + + def process_model_inputs( + self, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - return prompt_token_ids + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: """Add a request to the engine's request pool. @@ -378,15 +504,14 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt string. Can be None if prompt_token_ids is - provided. - params: Parameters for sampling or pooling. SamplingParams - for text generation. PoolingParams for pooling. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. + :class:`~vllm.PoolingParams` for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - multi_modal_data: Multi modal data per request. Details: - Set arrival_time to the current time if it is None. @@ -417,59 +542,26 @@ def add_request( "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = self.encode_request( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = None - if self.tokenizer: - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - else: - logger.warning("Use None for EOS token id because tokenizer is " - "not initialized") - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - eos_token_id, lora_request) - # Create a SequenceGroup based on SamplingParams or PoolingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time, - lora_request, - multi_modal_data, - ) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time, - lora_request, - multi_modal_data, - ) - else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") + processed_inputs = self.process_model_inputs(request_id=request_id, + inputs=inputs, + lora_request=lora_request) - # Add the sequence group to the scheduler. 
- self.scheduler.add_seq_group(seq_group) + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) def _create_sequence_group_with_sampling( self, request_id: str, seq: Sequence, sampling_params: SamplingParams, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + arrival_time: float, + lora_request: Optional[LoRARequest], ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -495,8 +587,7 @@ def _create_sequence_group_with_sampling( seqs=[seq], arrival_time=arrival_time, sampling_params=sampling_params, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) return seq_group @@ -505,9 +596,8 @@ def _create_sequence_group_with_pooling( request_id: str, seq: Sequence, pooling_params: PoolingParams, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + arrival_time: float, + lora_request: Optional[LoRARequest], ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -517,7 +607,6 @@ def _create_sequence_group_with_pooling( seqs=[seq], arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, pooling_params=pooling_params) return seq_group @@ -570,7 +659,7 @@ def _process_sequence_group_outputs( def _process_model_outputs( self, - output: List[Union[SamplerOutput, PoolerOutput]], + output: GenericSequence[Union[SamplerOutput, PoolerOutput]], scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], @@ -585,7 +674,7 @@ def _process_model_outputs( # Organize outputs by [sequence group][step] instead of # [step][sequence group]. output_by_sequence_group = create_output_by_sequence_group( - sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) + output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs, seq_group_meta in zip( diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 9816e966c1e3..57cc33d91118 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,18 +1,20 @@ from typing import List +from typing import Sequence as GenericSequence +from typing import Union -from vllm.sequence import SamplerOutput, SequenceGroupOutput +from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput def create_output_by_sequence_group( - sampler_outputs: List[SamplerOutput], + outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], num_seq_groups: int) -> List[List[SequenceGroupOutput]]: """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. 
""" - output_by_sequence_group: List[List[SamplerOutput]] = [ + output_by_sequence_group: List[List[SequenceGroupOutput]] = [ [] for _ in range(num_seq_groups) ] - for step in sampler_outputs: + for step in outputs: for i, sequence_group_output in enumerate(step): output_by_sequence_group[i].append(sequence_group_output) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 25f4428100b2..9759d0557779 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,11 +1,14 @@ -from typing import List, Optional, Union +from contextlib import contextmanager +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload -import torch from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, + TextTokensPrompt, TokensPrompt, + parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -13,7 +16,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter +from vllm.utils import Counter, deprecate_kwargs logger = init_logger(__name__) @@ -28,8 +31,10 @@ class LLM: mechanism and efficient memory management. NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + + NOTE: For the comprehensive list of arguments, see + :class:`~vllm.EngineArgs`. Args: model: The name or path of a HuggingFace Transformers model. @@ -81,6 +86,18 @@ class LLM: disable_custom_all_reduce: See ParallelConfig """ + DEPRECATE_LEGACY: ClassVar[bool] = False + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + def __init__( self, model: str, @@ -138,15 +155,101 @@ def set_tokenizer( ) -> None: self.llm_engine.tokenizer.tokenizer = tokenizer + @overload # LEGACY: single (prompt + optional token ids) + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) def generate( self, - prompts: Optional[Union[str, List[str]]] = None, + prompts: List[str], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... 
+ + @overload # LEGACY: single (token ids + optional prompt) + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def generate( + self, + prompts: Optional[List[str]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload + def generate( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[RequestOutput]: + ... + + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def generate( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -155,49 +258,138 @@ def generate( into a single list and pass it to this method. Args: - prompts: A list of prompts to generate completions for. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + if sampling_params is None: # Use default sampling params. 
sampling_params = SamplingParams() - requests_data = self._validate_and_prepare_requests( - prompts, - sampling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + self._validate_and_add_requests( + inputs=inputs, + params=sampling_params, + lora_request=lora_request, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - self._add_request(**request_data) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) - return self._run_engine(use_tqdm) + @overload # LEGACY: single (prompt + optional token ids) + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + @overload # LEGACY: multi (prompt + optional token ids) def encode( self, - prompts: Optional[Union[str, List[str]]] = None, + prompts: List[str], pooling_params: Optional[Union[PoolingParams, - List[PoolingParams]]] = None, + Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def encode( + self, + prompts: Optional[List[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload + def encode( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[EmbeddingRequestOutput]: + ... 
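# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch): the `encode` overloads
# above mirror `generate` but return embeddings. The embedding model name is
# a placeholder; any model served in vLLM's embedding mode would do.
from vllm import LLM

embed_llm = LLM(model="intfloat/e5-mistral-7b-instruct")

# New call style: a single inputs dict (or a sequence of them) passed
# positionally as `inputs`.
embed_outputs = embed_llm.encode({"prompt": "A sentence to embed"})

# Legacy call style: pre-tokenized ids via the keyword argument that the
# decorator below flags for deprecation once DEPRECATE_LEGACY is enabled.
legacy_embed_outputs = embed_llm.encode(prompt_token_ids=[[1, 2, 3]])

for out in embed_outputs:
    # EmbeddingRequestOutput.outputs is an EmbeddingOutput holding the vector.
    print(len(out.outputs.embedding))
# ----------------------------------------------------------------------------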
+ + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def encode( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -206,124 +398,133 @@ def encode( into a single list and pass it to this method. Args: - prompts: A list of prompts to generate completions for. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `EmbeddingRequestOutput` objects containing the generated embeddings in the same order as the input prompts. """ + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - requests_data = self._validate_and_prepare_requests( - prompts, - pooling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + lora_request=lora_request, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - self._add_request(**request_data) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) - return self._run_engine(use_tqdm) - - def _validate_and_prepare_requests( + # LEGACY + def _convert_v1_inputs( self, prompts: Optional[Union[str, List[str]]], - params: Union[Union[SamplingParams, PoolingParams], - List[Union[SamplingParams, - PoolingParams]]], # Unified parameter - prompt_token_ids: Optional[List[List[int]]] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, - ) -> List[dict]: - """Validates and prepares request data for adding to the engine. + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + multi_modal_data: Optional[MultiModalData], + ): + # skip_tokenizer_init is now checked in engine - Ensures prompts and token IDs are consistent, and returns a list of - dictionaries with request data for further processing. 
- """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - if self.llm_engine.model_config.skip_tokenizer_init \ - and prompts is not None: - raise ValueError("prompts must be None if skip_tokenizer_init " - "is True") - if isinstance(prompts, str): - # Convert a single prompt to a list. - prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + num_requests = None if prompts is not None: num_requests = len(prompts) - else: - assert prompt_token_ids is not None + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + inputs: List[PromptInputs] = [] + for i in range(num_requests): + if prompts is not None: + if prompt_token_ids is not None: + item = TextTokensPrompt( + prompt=prompts[i], + prompt_token_ids=prompt_token_ids[i]) + else: + item = TextPrompt(prompt=prompts[i]) + else: + if prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + if multi_modal_data is not None: + item["multi_modal_data"] = multi_modal_data + + inputs.append(item) + + return inputs + + def _validate_and_add_requests( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + lora_request: Optional[LoRARequest], + ) -> None: + if isinstance(inputs, (str, dict)): + # Convert a single prompt to a list. + inputs = [inputs] + + num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") - if multi_modal_data: - multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. 
- requests_data = [] - for i in range(num_requests): - prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] - - multi_modal_item = MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0), - ) if multi_modal_data else None - - requests_data.append({ - "prompt": - prompt, - "params": - params[i] if isinstance(params, list) else params, - "prompt_token_ids": - token_ids, - "lora_request": - lora_request, - "multi_modal_data": - multi_modal_item, - }) - - return requests_data + for i, request_inputs in enumerate(inputs): + self._add_request( + request_inputs, + params[i] if isinstance(params, Sequence) else params, + lora_request=lora_request, + ) def _add_request( self, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, - prompt, + inputs, params, - prompt_token_ids, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) def _run_engine( - self, use_tqdm: bool + self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -355,5 +556,4 @@ def _run_engine( # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. - outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs + return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7e179362eef8..33daabd881df 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -176,9 +176,15 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt_text, sampling_params, - request_id, prompt_ids, - lora_request) + result_generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + request_id, + lora_request, + ) # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 158d8ed7fbbf..d1812c8f44f4 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -119,12 +119,17 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - sampling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids, - lora_request=lora_request)) + generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + f"{request_id}-{i}", + lora_request=lora_request, + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7a57be0c8891..5a3448de3d7a 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ 
b/vllm/entrypoints/openai/serving_embedding.py @@ -1,5 +1,5 @@ import time -from typing import AsyncIterator, List, Tuple +from typing import AsyncIterator, List, Optional, Tuple from fastapi import Request @@ -100,11 +100,16 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - pooling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids)) + generator = self.engine.encode( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + pooling_params, + f"{request_id}-{i}", + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -113,16 +118,21 @@ async def create_embedding(self, request: EmbeddingRequest, int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) - async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") - # TODO: Use a vllm-specific Validation Error - return self.create_error_response("Client disconnected") - final_res_batch[i] = res - response = request_output_to_embedding_response( - final_res_batch, request_id, created_time, model_name) + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * len(prompts) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) return response diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0df0223b9dbb..708b0dad102c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -143,7 +143,8 @@ def create_streaming_error_response( return json_str async def _check_model( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None @@ -155,7 +156,8 @@ async def _check_model( status_code=HTTPStatus.NOT_FOUND) def _maybe_get_lora( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[LoRARequest]: if request.model in self.served_model_names: return None diff --git a/vllm/inputs.py b/vllm/inputs.py new file mode 100644 index 000000000000..f5d99b1b66b7 --- /dev/null +++ b/vllm/inputs.py @@ -0,0 +1,130 @@ +from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, + TypedDict, Union, cast, overload) + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from vllm.sequence import MultiModalData + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: List[int] + is_tokens: Literal[True] + + +# https://github.com/vllm-project/vllm/pull/4028 
+@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0], str): + # case 2: array of strings + return [ + ParsedText(content=elem, is_tokens=False) + for elem in cast(List[str], prompt) + ] + if isinstance(prompt[0], int): + # case 3: array of tokens + elem = cast(List[int], prompt) + return [ParsedTokens(content=elem, is_tokens=True)] + if isinstance(prompt[0], list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0][0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in cast(List[List[int]], prompt) + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class TextPrompt(TypedDict): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TextTokensPrompt(TypedDict): + """It is assumed that :attr:`prompt` is consistent with + :attr:`prompt_token_ids`. This is currently used in + :class:`AsyncLLMEngine` for logging both the text and token IDs.""" + + prompt: str + """The prompt text.""" + + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +""" +The inputs to the LLM, which can take one of the following forms: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +""" + +PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +"""Same as :const:`PromptStrictInputs` but additionally accepts +:class:`TextTokensPrompt`.""" + + +class LLMInputs(TypedDict): + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalData"] diff --git a/vllm/outputs.py b/vllm/outputs.py index f9bce9e683f2..49f526b5f930 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from typing import List, Optional, Union from vllm.lora.request import LoRARequest @@ -6,6 +7,7 @@ SequenceGroup, SequenceStatus) +@dataclass class CompletionOutput: """The output data of one completion output of a request. @@ -24,25 +26,14 @@ class CompletionOutput: lora_request: The LoRA request that was used to generate the output. 
""" - def __init__( - self, - index: int, - text: str, - token_ids: List[int], - cumulative_logprob: float, - logprobs: Optional[SampleLogprobs], - finish_reason: Optional[str] = None, - stop_reason: Union[int, str, None] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.index = index - self.text = text - self.token_ids = token_ids - self.cumulative_logprob = cumulative_logprob - self.logprobs = logprobs - self.finish_reason = finish_reason - self.stop_reason = stop_reason - self.lora_request = lora_request + index: int + text: str + token_ids: List[int] + cumulative_logprob: float + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None def finished(self) -> bool: return self.finish_reason is not None @@ -57,6 +48,7 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +@dataclass class EmbeddingOutput: """The output data of one completion output of a request. @@ -65,15 +57,11 @@ class EmbeddingOutput: length of vector depends on the model as listed in the embedding guide. """ - def __init__( - self, - embedding: List[float], - ) -> None: - self.embedding = embedding + embedding: List[float] def __repr__(self) -> str: return (f"EmbeddingOutput(" - f"embedding={len(self.embedding)}") + f"embedding={len(self.embedding)})") class RequestOutput: @@ -93,7 +81,7 @@ class RequestOutput: def __init__( self, request_id: str, - prompt: str, + prompt: Optional[str], prompt_token_ids: List[int], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], @@ -183,7 +171,7 @@ class EmbeddingRequestOutput: finished (bool): A flag indicating whether the embedding is completed. """ - def __init__(self, request_id: str, outputs: 'EmbeddingOutput', + def __init__(self, request_id: str, outputs: "EmbeddingOutput", prompt_token_ids: List[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids diff --git a/vllm/sequence.py b/vllm/sequence.py index aa759448d82b..f8e9da6c7965 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock +from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -210,8 +211,7 @@ class Sequence: Args: seq_id: The ID of the sequence. - prompt: The prompt of the sequence. - prompt_token_ids: The token IDs of the prompt. + inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. @@ -220,25 +220,24 @@ class Sequence: def __init__( self, seq_id: int, - prompt: str, - prompt_token_ids: List[int], + inputs: LLMInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id - self.prompt = prompt + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data: SequenceData = SequenceData(prompt_token_ids) + self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. 
- self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(self.prompt_token_ids) self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None @@ -248,6 +247,18 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def prompt(self) -> Optional[str]: + return self.inputs["prompt"] + + @property + def prompt_token_ids(self) -> List[int]: + return self.inputs["prompt_token_ids"] + + @property + def multi_modal_data(self) -> Optional["MultiModalData"]: + return self.inputs["multi_modal_data"] + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -415,7 +426,6 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - multi_modal_data: Multi modal data associated with the request. embeddings: The embeddings vectors of the prompt of the sequence group for an embedding model. pooling_params: The pooling parameters used to generate the pooling @@ -429,7 +439,6 @@ def __init__( arrival_time: float, sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, ) -> None: @@ -444,12 +453,11 @@ def __init__( self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() - self.multi_modal_data = multi_modal_data self.embeddings = embeddings self.pooling_params = pooling_params @property - def prompt(self) -> str: + def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt @@ -458,7 +466,13 @@ def prompt(self) -> str: def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return next(iter(self.seqs_dict.values())).data.prompt_token_ids + return next(iter(self.seqs_dict.values())).prompt_token_ids + + @property + def multi_modal_data(self) -> Optional[MultiModalData]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. 
+ return next(iter(self.seqs_dict.values())).multi_modal_data @property def lora_int_id(self) -> int: diff --git a/vllm/utils.py b/vllm/utils.py index 4cb9d905097b..c8bc54dab41b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -11,7 +11,7 @@ import uuid import warnings from collections import defaultdict -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Tuple, TypeVar, @@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None: filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) + + +def identity(value: T) -> T: + return value + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper
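# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch): how the `deprecate_kwargs`
# helper added above is meant to be applied. The function and argument names
# below are invented for the example.
import warnings

from vllm.utils import deprecate_kwargs


@deprecate_kwargs("old_arg",
                  additional_message="Please use 'new_arg' instead.")
def do_work(*, new_arg: int = 0, old_arg: int = 0) -> int:
    # Only keyword usage of "old_arg" triggers the warning, since the wrapper
    # inspects kwargs.keys().
    return new_arg or old_arg


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    do_work(old_arg=1)  # emits DeprecationWarning with the extra message
    do_work(new_arg=2)  # no warning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# ----------------------------------------------------------------------------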