This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 1690706

DarkLight1337 and ywang96 authored and committed
[Core] Consolidate prompt arguments to LLM engines (vllm-project#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent 1c49a67 commit 1690706

43 files changed (+1404 −439 lines). This is a large commit, so not every file's diff is shown below.
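The diffs below migrate every call site from the separate prompt, prompt_token_ids, and multi_modal_data keyword arguments of LLM.generate to a single inputs argument that takes a string, a dict, or a list of these. A minimal sketch of the new call shape, pieced together only from the call sites changed in this commit (the model name and sampling settings are placeholders):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8)  # placeholder sampling settings

# Plain text prompts still work as before.
outputs = llm.generate("Hello, my name is", sampling_params=params)

# Token IDs now ride inside the same inputs argument instead of a
# separate prompt_token_ids= keyword.
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=params)

# A batch is just a list of such inputs.
outputs = llm.generate([
    {"prompt": "Hello, my name is"},
    {"prompt": "The future of AI is"},
], sampling_params=params)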

.buildkite/test-pipeline.yaml (+6 −3)

@@ -63,9 +63,9 @@ steps:
   mirror_hardwares: [amd]
 
   commands:
-  # these tests have to be separated, because each one will allocate all posible GPU memory
-  - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
-  - pytest -v -s entrypoints/test_server_oot_registration.py
+  - pytest -v -s test_inputs.py
+  - pytest -v -s entrypoints -m llm
+  - pytest -v -s entrypoints -m openai
 
 - label: Examples Test
   working_dir: "/vllm-workspace/examples"

@@ -110,6 +110,9 @@ steps:
   mirror_hardwares: [amd]
   command: pytest -v -s test_logits_processor.py
 
+- label: Utils Test
+  command: pytest -v -s test_utils.py
+
 - label: Worker Test
   mirror_hardwares: [amd]
   command: pytest -v -s worker

benchmarks/benchmark_latency.py (+7 −4)

@@ -3,13 +3,14 @@
 import json
 import time
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
 import torch
 from tqdm import tqdm
 
 from vllm import LLM, SamplingParams
+from vllm.inputs import PromptStrictInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
 
@@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
+    dummy_inputs: List[PromptStrictInputs] = [{
+        "prompt_token_ids": batch
+    } for batch in dummy_prompt_token_ids.tolist()]
 
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:

@@ -59,13 +62,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
                 ],
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     str(profile_dir))) as p:
-                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                llm.generate(dummy_inputs,
                              sampling_params=sampling_params,
                              use_tqdm=False)
             print(p.key_averages())
         else:
             start_time = time.perf_counter()
-            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+            llm.generate(dummy_inputs,
                          sampling_params=sampling_params,
                          use_tqdm=False)
             end_time = time.perf_counter()
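In isolation, the pattern the benchmark now uses — a typed list of token-ID prompts fed to generate in a single call — looks roughly like this (batch size, input length, model, and sampling settings are placeholders):

from typing import List

import numpy as np

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs

llm = LLM(model="facebook/opt-125m")  # placeholder model
sampling_params = SamplingParams(max_tokens=16)  # placeholder settings

# Random token IDs standing in for real prompts: batch_size=2, input_len=8.
dummy_prompt_token_ids = np.random.randint(10000, size=(2, 8))

# Each row becomes one token-ID prompt dict; the list is the whole batch.
dummy_inputs: List[PromptStrictInputs] = [{
    "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False)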

docs/source/offline_inference/llm.rst → docs/source/dev/offline_inference/llm.rst (renamed, +1 −1)

@@ -1,5 +1,5 @@
 LLM Class
-==========
+=========
 
 .. autoclass:: vllm.LLM
     :members:
New file (+14; path not shown in this view)

@@ -0,0 +1,14 @@
+LLM Inputs
+==========
+
+.. autodata:: vllm.inputs.PromptStrictInputs
+
+.. autoclass:: vllm.inputs.TextPrompt
+    :show-inheritance:
+    :members:
+    :member-order: bysource
+
+.. autoclass:: vllm.inputs.TokensPrompt
+    :show-inheritance:
+    :members:
+    :member-order: bysource
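The two documented classes correspond to the two dict shapes used throughout this commit. The vllm/inputs.py diff itself is not shown in this view, so the following is only an assumed sketch of what those types look like, inferred from the call sites above and below (the definitions here are illustrative re-declarations, not the library source):

# Hypothetical sketch of the documented input types; the real definitions
# live in vllm/inputs.py, which is hidden in this diff view.
from typing import List, Optional, TypedDict, Union

from vllm.sequence import MultiModalData


class TextPrompt(TypedDict, total=False):
    prompt: str  # raw text, e.g. "Hello, my name is"
    multi_modal_data: Optional[MultiModalData]  # optional image/feature payload


class TokensPrompt(TypedDict, total=False):
    prompt_token_ids: List[int]  # pre-tokenized prompt
    multi_modal_data: Optional[MultiModalData]


# PromptStrictInputs accepts a plain string or one of the dicts above.
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]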
New file (+8; path not shown in this view)

@@ -0,0 +1,8 @@
+Offline Inference
+=================================
+
+.. toctree::
+   :maxdepth: 1
+
+   llm
+   llm_inputs

docs/source/index.rst (+3 −8)

@@ -68,13 +68,6 @@ Documentation
    getting_started/quickstart
    getting_started/examples/examples_index
 
-.. toctree::
-   :maxdepth: 1
-   :caption: Offline Inference
-
-   offline_inference/llm
-   offline_inference/sampling_params
-
 .. toctree::
    :maxdepth: 1
    :caption: Serving

@@ -108,7 +101,9 @@ Documentation
 .. toctree::
    :maxdepth: 2
    :caption: Developer Documentation
-
+
+   dev/sampling_params
+   dev/offline_inference/offline_index
    dev/engine/engine_index
    dev/kernel/paged_attention
    dev/dockerfile/dockerfile

docs/source/serving/openai_compatible_server.md (+2 −2)

@@ -48,7 +48,7 @@ completion = client.chat.completions.create(
 ```
 
 ### Extra Parameters for Chat API
-The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
+The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
 
 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python

@@ -65,7 +65,7 @@ The following extra parameters are supported:
 ```
 
 ### Extra Parameters for Completions API
-The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
+The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
 
 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python

examples/llava_example.py (+16 −9)

@@ -23,11 +23,15 @@ def run_llava_pixel_values():
            "\nUSER: What is the content of this image?\nASSISTANT:")
 
     # This should be provided by another online or offline component.
-    images = torch.load("images/stop_sign_pixel_values.pt")
+    image = torch.load("images/stop_sign_pixel_values.pt")
+
+    outputs = llm.generate({
+        "prompt":
+        prompt,
+        "multi_modal_data":
+        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
+    })
 
-    outputs = llm.generate(prompt,
-                           multi_modal_data=MultiModalData(
-                               type=MultiModalData.Type.IMAGE, data=images))
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)

@@ -46,11 +50,14 @@ def run_llava_image_features():
            "\nUSER: What is the content of this image?\nASSISTANT:")
 
     # This should be provided by another online or offline component.
-    images = torch.load("images/stop_sign_image_features.pt")
-
-    outputs = llm.generate(prompt,
-                           multi_modal_data=MultiModalData(
-                               type=MultiModalData.Type.IMAGE, data=images))
+    image = torch.load("images/stop_sign_image_features.pt")
+
+    outputs = llm.generate({
+        "prompt":
+        prompt,
+        "multi_modal_data":
+        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
+    })
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
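Multi-modal data now travels inside the same inputs dict rather than as a separate keyword argument, so a batch of image prompts is simply a list of such dicts, mirroring what tests/conftest.py below builds. A rough sketch of just the input construction, assuming an llm instance already configured for LLaVA as in the example above (prompts and the tensor path are placeholders):

import torch

from vllm.sequence import MultiModalData

# `llm` is assumed to be an LLM instance configured for LLaVA as in the
# example above; the prompt text and tensor path below are placeholders.
prompts = 2 * ["<image tokens>\nUSER: What is the content of this image?\nASSISTANT:"]
images = torch.load("images/stop_sign_pixel_values.pt")  # stacked pixel values, one per prompt

# One inputs dict per prompt, each carrying its own slice of the image batch.
batch = [{
    "prompt": prompt,
    "multi_modal_data": MultiModalData(type=MultiModalData.Type.IMAGE,
                                       data=images[i:i + 1]),
} for i, prompt in enumerate(prompts)]

outputs = llm.generate(batch)
for o in outputs:
    print(o.outputs[0].text)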

pyproject.toml (+7 −0)

@@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
 [tool.isort]
 use_parentheses = true
 skip_gitignore = true
+
+[tool.pytest.ini_options]
+markers = [
+    "skip_global_cleanup",
+    "llm: run tests for vLLM API only",
+    "openai: run tests for OpenAI API only",
+]
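The markers registered here are what the new `pytest -m llm` / `pytest -m openai` pipeline steps select on; individual test modules opt in with a module-level pytestmark, as the entrypoints tests further down do. A minimal sketch (the test body is a placeholder):

import pytest

# Every test in this module is collected by `pytest -m llm` and skipped by
# `pytest -m openai`, matching the split in .buildkite/test-pipeline.yaml.
pytestmark = pytest.mark.llm


def test_placeholder():
    assert True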

tests/async_engine/test_async_llm_engine.py (+1 −1)

@@ -25,7 +25,7 @@ async def step_async(self):
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []
 
-    async def encode_request_async(self, *args, **kwargs):
+    async def process_model_inputs_async(self, *args, **kwargs):
         pass
 
     def generate(self, request_id):

tests/async_engine/test_openapi_server_ray.py (+1 −1)

@@ -29,7 +29,7 @@ def server():
     ray.shutdown()
 
 
-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",

tests/conftest.py (+17 −6)

@@ -14,6 +14,7 @@
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.sequence import MultiModalData
 

@@ -587,12 +588,22 @@ def generate(
     ) -> List[Tuple[List[int], str]]:
         if images is not None:
             assert len(prompts) == images.shape[0]
-        req_outputs = self.model.generate(
-            prompts,
-            sampling_params=sampling_params,
-            multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
-                                            data=images)
-            if images is not None else None)
+
+        prompt_inputs: List[PromptInputs] = []
+        for i, prompt in enumerate(prompts):
+            image = None if images is None else images[i:i + 1]
+            mm_data = None if image is None else MultiModalData(
+                type=MultiModalData.Type.IMAGE,
+                data=image,
+            )
+
+            prompt_inputs.append({
+                "prompt": prompt,
+                "multi_modal_data": mm_data,
+            })
+
+        req_outputs = self.model.generate(prompt_inputs,
+                                          sampling_params=sampling_params)
         outputs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt

tests/core/test_block_manager.py (+12 −3)

@@ -133,8 +133,11 @@ def test_append_slot_cow():
 
     # Allocate prompt to gpu block. There is one slot left in the block.
     prompt = Sequence(seq_id=1,
-                      prompt="one two three",
-                      prompt_token_ids=[1, 2, 3],
+                      inputs={
+                          "prompt": "one two three",
+                          "prompt_token_ids": [1, 2, 3],
+                          "multi_modal_data": None
+                      },
                       block_size=block_size)
 
     # Fork the sequence, such that a COW will be required when we append a new

@@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():
 
     assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
 
-    parent = Sequence(1, "one two three", [0, 1, 2], block_size)
+    parent = Sequence(seq_id=1,
+                      inputs={
+                          "prompt": "one two three",
+                          "prompt_token_ids": [0, 1, 2],
+                          "multi_modal_data": None
+                      },
+                      block_size=block_size)
     seq_group = SequenceGroup(request_id="1",
                               seqs=[parent],
                               arrival_time=time.time(),

tests/core/utils.py (+12 −3)

@@ -21,7 +21,13 @@ def create_dummy_prompt(
     # and prompt "0 ... block_size".
     prompt_tokens = list(range(prompt_length))
     prompt_str = " ".join([str(t) for t in prompt_tokens])
-    prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
+    prompt = Sequence(int(request_id),
+                      inputs={
+                          "prompt": prompt_str,
+                          "prompt_token_ids": prompt_tokens,
+                          "multi_modal_data": None,
+                      },
+                      block_size=block_size)
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[prompt],
                               arrival_time=time.time(),

@@ -51,8 +57,11 @@ def create_seq_group(
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         seq = Sequence(
             seq_id=seq_id_start + seq_id_offset,
-            prompt="",
-            prompt_token_ids=prompt_token_ids,
+            inputs={
+                "prompt": "",
+                "prompt_token_ids": prompt_token_ids,
+                "multi_modal_data": None,
+            },
             block_size=16,
         )

tests/engine/test_skip_tokenizer_init.py (+1 −1)

@@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
     with pytest.raises(ValueError) as err:
         llm.generate("abc", sampling_params)
     assert "prompts must be None if" in str(err.value)
-    outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                            sampling_params=sampling_params)
     assert len(outputs) > 0
     completions = outputs[0].outputs

tests/entrypoints/openai/test_serving_chat.py (+4 −0)

@@ -1,11 +1,15 @@
 import asyncio
 from dataclasses import dataclass
 
+import pytest
+
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 
 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 
+pytestmark = pytest.mark.openai
+
 
 @dataclass
 class MockModelConfig:

tests/entrypoints/test_guided_processors.py (+2 −0)

@@ -52,6 +52,8 @@
 TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
               r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
 
+pytestmark = pytest.mark.openai
+
 
 def test_guided_logits_processors():
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
