[Core][VLM] Add precise multi-modal placeholder tracking #8346

Open
wants to merge 27 commits into base: main

Changes from 23 commits

Commits (27)
d5e298f
[Core][VLM] Add precise multi-modal placeholder tracking
petersalas Sep 18, 2024
aa756bf
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Sep 18, 2024
46ee2d5
Fix msgpack failures
petersalas Sep 18, 2024
6c7830e
Fix test
petersalas Sep 18, 2024
bf7d874
Change DummyData to NamedTuple
petersalas Sep 19, 2024
33e3023
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Sep 23, 2024
8ebd665
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Sep 23, 2024
5da9450
Add online + chunked prefill test for ultravox
petersalas Sep 23, 2024
c5d0d7f
Use SequenceData.from_token_counts in Ultravox
petersalas Sep 23, 2024
7a6cbe9
Update test mock
petersalas Sep 23, 2024
1a617f2
Fix test failures
petersalas Sep 24, 2024
509cbac
Fix llava-onevision dummy data
petersalas Sep 24, 2024
cb2dc49
Allow None values in AttentionMetadata
petersalas Sep 25, 2024
b60ec8c
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Sep 25, 2024
defb59c
Fix phi3v test
petersalas Sep 25, 2024
5a48342
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Sep 25, 2024
a8025f9
Replace index tensors with lists
petersalas Sep 26, 2024
c10dff9
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Sep 27, 2024
862c46c
Fix test failures
petersalas Sep 30, 2024
f14ab68
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Sep 30, 2024
938b857
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Oct 7, 2024
6611261
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Oct 9, 2024
d0d1ea2
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Oct 9, 2024
d0bc54d
Add docstrings
petersalas Oct 10, 2024
f0b7a3f
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Oct 10, 2024
5171297
Update docstrings
petersalas Oct 11, 2024
fc0c190
Merge remote-tracking branch 'upstream/main' into psalas/placeholder-…
petersalas Oct 11, 2024
6 changes: 1 addition & 5 deletions examples/offline_inference_audio_language.py
@@ -33,11 +33,7 @@ def run_ultravox(question, audio_count):
tokenize=False,
add_generation_prompt=True)

llm = LLM(model=model_name,
enforce_eager=True,
enable_chunked_prefill=False,
max_model_len=8192,
limit_mm_per_prompt={"audio": audio_count})
llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None
return llm, prompt, stop_token_ids

89 changes: 72 additions & 17 deletions tests/models/decoder_only/audio_language/test_ultravox.py
@@ -2,8 +2,10 @@

import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding

from tests.utils import RemoteOpenAIServer
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

@@ -17,6 +19,13 @@
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>"

CHUNKED_PREFILL_KWARGS = {
"enable_chunked_prefill": True,
"max_num_seqs": 2,
# Use a very small limit to exercise chunked prefill.
"max_num_batched_tokens": 16
}


@pytest.fixture(scope="session")
def audio_assets():
@@ -30,6 +39,26 @@ def audio(request):
return AudioAsset(request.param)


@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
def server(request, audio_assets):
args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
f"--limit-mm-per-prompt=audio={len(audio_assets)}"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
placeholder = f"{placeholder}\n" * audio_count
@@ -68,8 +97,7 @@ def run_test(
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
**kwargs,
):
"""Inference result should be the same between hf and vllm."""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -79,11 +107,8 @@
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
with vllm_runner(model, dtype=dtype, enforce_eager=True,
**kwargs) as vllm_model:
vllm_outputs_per_audio = [
vllm_model.generate_greedy_logprobs([vllm_prompt],
max_tokens,
@@ -135,18 +160,16 @@ def run_multi_audio_test(
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
**kwargs,
):
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True,
limit_mm_per_prompt={
"audio":
max((len(audio) for _, audio in prompts_and_audios))
}) as vllm_model:
},
**kwargs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
[prompt for prompt, _ in prompts_and_audios],
max_tokens,
@@ -161,8 +184,9 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
num_logprobs: int, vllm_kwargs: dict) -> None:

vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
@@ -174,16 +198,17 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
**vllm_kwargs,
)


@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
max_tokens: int, num_logprobs: int,
vllm_kwargs: dict) -> None:

vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.",
@@ -196,5 +221,35 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
**vllm_kwargs,
)


@pytest.mark.asyncio
async def test_online_inference(client, audio_assets):
messages = [{
"role":
"user",
"content": [
*[{
"type": "audio_url",
"audio_url": {
"url": audio.url
}
} for audio in audio_assets],
{
"type":
"text",
"text":
f"What's happening in these {len(audio_assets)} audio clips?"
},
],
}]

chat_completion = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=10)

assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
4 changes: 2 additions & 2 deletions tests/models/decoder_only/vision_language/test_phi3v.py
@@ -356,14 +356,14 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
mm_processor_kwargs=None,
)

sequence_data, _, = dummy_data_for_phi3v(
dummy_data = dummy_data_for_phi3v(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
num_crops=num_crops,
)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
img_tok_count = dummy_data.seq_data.get_token_ids().count(_IMAGE_TOKEN_ID)
assert img_tok_count == toks_per_img * num_imgs


12 changes: 6 additions & 6 deletions tests/multimodal/test_processor_kwargs.py
@@ -6,7 +6,7 @@
import torch

from vllm.inputs import InputContext, LLMInputs
from vllm.inputs.registry import InputRegistry
from vllm.inputs.registry import DummyData, InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

@@ -56,7 +56,7 @@ def custom_dummy_data_factory(self,
num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None
return DummyData(seq_data, None)

with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count
assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count


@pytest.mark.parametrize(
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS


### Test overrides for the max token count per multimodal instance
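The updated tests above construct and read the dummy-data result by name (DummyData(seq_data, None), dummy_data.seq_data) rather than by tuple unpacking. As a rough sketch only, a NamedTuple with the shape those call sites imply could look like the following; the name of the second field is an assumption made for illustration, and the real type is the DummyData imported from vllm.inputs.registry above.

from typing import NamedTuple, Optional

from vllm.sequence import SequenceData


class DummyDataSketch(NamedTuple):
    """Illustrative stand-in for vllm.inputs.registry.DummyData."""
    seq_data: SequenceData
    # Field name assumed for illustration; holds the dummy multi-modal inputs
    # (None in the patched factory above).
    multi_modal_data: Optional[dict] = None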
57 changes: 45 additions & 12 deletions tests/multimodal/test_utils.py
@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
tokenizer = AutoTokenizer.from_pretrained(model)

test_cases = [
("<image>", 2, "<image><image>", [32000, 32000]),
("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
("<image><image>", [3, 2], "<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000]),
("Image:<image>Image:<image>!", [3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
]

for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
(
"<image>",
2,
"<image><image>",
[32000, 32000],
[{ "offset": 0, "length": 2 }],
),
(
"<image><image>",
2,
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 2 }]),
(
"<image><image>",
[3, 2],
"<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000],
[{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
),
(
"Image:<image>Image:<image>!",
[3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
),
(
"<image>",
[3, 2],
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 3 }],
),
] # yapf: disable

for (
prompt,
repeat_count,
expected_prompt,
expected_token_ids,
expected_ranges,
) in test_cases:
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer=tokenizer,
prompt=prompt,
prompt_token_ids=tokenizer.encode(prompt,
@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
)
assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids
assert ranges == expected_ranges
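As a hedged sketch (not the vLLM implementation), ranges of the {"offset": ..., "length": ...} form used in expected_ranges above can be recovered from a padded token-id sequence by collecting maximal runs of the placeholder token id:

from typing import Dict, List, Optional


def placeholder_ranges(token_ids: List[int],
                       placeholder_id: int = 32000) -> List[Dict[str, int]]:
    """Collect maximal runs of placeholder_id as offset/length ranges."""
    ranges: List[Dict[str, int]] = []
    start: Optional[int] = None
    # The trailing None sentinel flushes a run that ends at the last token.
    for i, tok in enumerate(list(token_ids) + [None]):
        if tok == placeholder_id and start is None:
            start = i
        elif tok != placeholder_id and start is not None:
            ranges.append({"offset": start, "length": i - start})
            start = None
    return ranges


# Matches the fourth test case above:
# placeholder_ranges([9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918])
# == [{"offset": 2, "length": 3}, {"offset": 7, "length": 2}]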
3 changes: 3 additions & 0 deletions tests/worker/test_model_input.py
@@ -73,6 +73,7 @@ def test_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_maps=None,
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_maps=None,
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
multi_modal_placeholder_maps=None,
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
7 changes: 7 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -7,6 +7,8 @@

import torch

from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
@@ -105,6 +107,11 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
multi_modal_placeholder_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]

Contributor: Would it make sense to give multi_modal_placeholder_maps a default value, so that in non-multi-modal scenarios multi_modal_placeholder_maps need not be specified?

Contributor (Author): I wish I could -- unfortunately, because this is the base type for the other AttentionMetadata types, doing so would require that all fields in all derived types also have default values.
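As a minimal sketch of the constraint described in that reply (hypothetical class and field names, not the actual vLLM metadata types): in a dataclass hierarchy, giving a base-class field a default forces every field declared by derived dataclasses to have a default as well, because dataclasses place base fields first when generating __init__.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class BaseMetadata:  # hypothetical stand-in for AttentionMetadata
    slot_mapping: int
    # Hypothetical default on the base field.
    placeholder_maps: Optional[Dict[str, List[int]]] = None


try:

    @dataclass
    class DerivedMetadata(BaseMetadata):  # stand-in for a backend-specific subclass
        seq_lens: List[int]  # no default, so the dataclass machinery rejects it

except TypeError as err:
    print(err)  # "non-default argument 'seq_lens' follows default argument"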

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
2 changes: 2 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -218,6 +218,7 @@ def prefill_metadata(
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_maps=self.multi_modal_placeholder_maps,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
@@ -246,6 +247,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_maps=None,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,