Skip to content

Commit

Permalink
rename PromptInputs and inputs with backward compatibility (vllm-proj…
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and siddharth9820 committed Sep 30, 2024
1 parent b326519 commit 200c0ba
Show file tree
Hide file tree
Showing 21 changed files with 438 additions and 245 deletions.
8 changes: 4 additions & 4 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

Expand Down Expand Up @@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptInputs] = [{
dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

Expand All @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.

Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/dev/offline_inference/llm_inputs.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptInputs
.. autodata:: vllm.inputs.PromptType

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/models/vlm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:

* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
Expand Down
8 changes: 5 additions & 3 deletions tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

@pytest.mark.asyncio
async def test_new_requests_event():
params = SamplingParams()

engine = MockAsyncLLMEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0

await engine.add_request("1", "", None)
await engine.add_request("1", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1

await engine.add_request("2", "", None)
await engine.add_request("2", "", params)
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
Expand All @@ -111,7 +113,7 @@ async def test_new_requests_event():
await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls

await engine.add_request("3", "", None)
await engine.add_request("3", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1
Expand Down
34 changes: 0 additions & 34 deletions tests/entrypoints/llm/test_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert [o.outputs for o in o1] == [o.outputs for o in o2]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)

v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
Expand All @@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)

v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()
Expand Down
37 changes: 0 additions & 37 deletions tests/entrypoints/llm/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)

v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
Expand All @@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)

v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
Expand Down
12 changes: 6 additions & 6 deletions tests/mq_llm_engine/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ async def test_evil_forward(tmp_socket):

# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored

# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
Expand Down Expand Up @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):

# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
Expand Down Expand Up @@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
Expand All @@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):

# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
Expand All @@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
pass

# This request should be okay.
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
Expand Down
2 changes: 1 addition & 1 deletion tests/mq_llm_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def generate(
count = 0
async for out in client.generate(
request_id=request_id,
inputs="Hello my name is Robert and",
prompt="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):

Expand Down
4 changes: 2 additions & 2 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
Expand All @@ -19,7 +19,7 @@
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptInputs",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
Expand Down
Loading

0 comments on commit 200c0ba

Please sign in to comment.