From d5e298fcfc192a78cdc2ef742504278f77a0193e Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 18 Sep 2024 18:57:31 +0000 Subject: [PATCH 01/15] [Core][VLM] Add precise multi-modal placeholder tracking --- examples/offline_inference_audio_language.py | 6 +- tests/models/test_ultravox.py | 37 +++-- tests/multimodal/test_utils.py | 57 +++++-- vllm/attention/backends/abstract.py | 7 + vllm/attention/backends/blocksparse_attn.py | 2 + vllm/attention/backends/flash_attn.py | 19 +++ vllm/attention/backends/flashinfer.py | 18 +++ vllm/attention/backends/rocm_flash_attn.py | 2 + vllm/attention/backends/utils.py | 18 +++ vllm/attention/backends/xformers.py | 2 + vllm/core/scheduler.py | 2 + vllm/inputs/data.py | 8 +- vllm/inputs/registry.py | 19 ++- vllm/model_executor/models/blip.py | 10 +- vllm/model_executor/models/blip2.py | 11 +- vllm/model_executor/models/chameleon.py | 21 ++- vllm/model_executor/models/clip.py | 32 ++-- vllm/model_executor/models/fuyu.py | 11 +- vllm/model_executor/models/internvl.py | 4 +- vllm/model_executor/models/llava.py | 8 +- vllm/model_executor/models/llava_next.py | 8 +- .../model_executor/models/llava_next_video.py | 20 ++- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/paligemma.py | 4 +- vllm/model_executor/models/phi3v.py | 4 +- vllm/model_executor/models/pixtral.py | 2 +- vllm/model_executor/models/qwen.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 2 +- vllm/model_executor/models/siglip.py | 24 ++- vllm/model_executor/models/ultravox.py | 53 ++++--- vllm/model_executor/models/utils.py | 19 ++- vllm/multimodal/__init__.py | 7 +- vllm/multimodal/base.py | 146 +++++++++++++++++- vllm/multimodal/utils.py | 11 +- vllm/sequence.py | 11 +- vllm/worker/cpu_model_runner.py | 24 ++- vllm/worker/model_runner.py | 17 +- vllm/worker/openvino_model_runner.py | 33 +++- vllm/worker/tpu_model_runner.py | 4 + vllm/worker/xpu_model_runner.py | 26 +++- 40 files changed, 569 insertions(+), 144 deletions(-) diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 1c6ac06123bb..2be2fe2247e7 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -33,11 +33,7 @@ def run_ultravox(question, audio_count): tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name, - enforce_eager=True, - enable_chunked_prefill=False, - max_model_len=8192, - limit_mm_per_prompt={"audio": audio_count}) + llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count}) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/tests/models/test_ultravox.py b/tests/models/test_ultravox.py index e98db9b65f48..484daa1fa264 100644 --- a/tests/models/test_ultravox.py +++ b/tests/models/test_ultravox.py @@ -19,6 +19,13 @@ VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" HF_PLACEHOLDER = "<|audio|>" +CHUNKED_PREFILL_KWARGS = { + "enable_chunked_prefill": True, + "max_num_seqs": 2, + # Use a very small limit to exercise chunked prefill. 
+ "max_num_batched_tokens": 16 +} + @pytest.fixture(scope="session") def audio_assets(): @@ -70,8 +77,7 @@ def run_test( dtype: str, max_tokens: int, num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + **kwargs, ): """Inference result should be the same between hf and vllm.""" torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -81,11 +87,8 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True, + **kwargs) as vllm_model: vllm_outputs_per_audio = [ vllm_model.generate_greedy_logprobs([vllm_prompt], max_tokens, @@ -137,18 +140,16 @@ def run_multi_audio_test( dtype: str, max_tokens: int, num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + **kwargs, ): with vllm_runner(model, dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, enforce_eager=True, limit_mm_per_prompt={ "audio": max((len(audio) for _, audio in prompts_and_audios)) - }) as vllm_model: + }, + **kwargs) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( [prompt for prompt, _ in prompts_and_audios], max_tokens, @@ -163,8 +164,9 @@ def run_multi_audio_test( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS]) def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, - num_logprobs: int) -> None: + num_logprobs: int, vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) @@ -176,16 +178,17 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, - tensor_parallel_size=1, + **vllm_kwargs, ) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS]) def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: + max_tokens: int, num_logprobs: int, + vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(len(audio_assets), "Describe each of the audios above.", @@ -198,5 +201,5 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, - tensor_parallel_size=1, + **vllm_kwargs, ) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 38cd48629f90..69f04f0a69c0 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model): tokenizer = AutoTokenizer.from_pretrained(model) test_cases = [ - ("", 2, "", [32000, 32000]), - ("", 2, "", [32000, 32000, 32000]), - ("", [3, 2], "", - [32000, 32000, 32000, 32000, 32000]), - ("Image:Image:!", [3, 2], - "Image:Image:!", - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), - ("", [3, 2], "", 
[32000, 32000, 32000]), - ] - - for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + ( + "", + 2, + "", + [32000, 32000], + [{ "offset": 0, "length": 2 }], + ), + ( + "", + 2, + "", + [32000, 32000, 32000], + [{ "offset": 0, "length": 2 }]), + ( + "", + [3, 2], + "", + [32000, 32000, 32000, 32000, 32000], + [{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }], + ), + ( + "Image:Image:!", + [3, 2], + "Image:Image:!", + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }], + ), + ( + "", + [3, 2], + "", + [32000, 32000, 32000], + [{ "offset": 0, "length": 3 }], + ), + ] # yapf: disable + + for ( + prompt, + repeat_count, + expected_prompt, + expected_token_ids, + expected_ranges, + ) in test_cases: + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer=tokenizer, prompt=prompt, prompt_token_ids=tokenizer.encode(prompt, @@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model): ) assert new_prompt == expected_prompt assert new_token_ids == expected_token_ids + assert ranges == expected_ranges diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index adc8390e6f9e..67abba1e2192 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -7,6 +7,8 @@ import torch +from vllm.multimodal import MultiModalPlaceholderMap + if TYPE_CHECKING: from vllm.worker.model_runner_base import (ModelRunnerBase, ModelRunnerInputBase, @@ -105,6 +107,11 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor + # The index tensors that relate multi-modal embeddings to the corresponding + # placeholders. 
+ multi_modal_placeholder_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexTensors]] + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index d84a40890ebb..c316c6c62003 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -212,6 +212,7 @@ def prefill_metadata( num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_maps=self.multi_modal_placeholder_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -240,6 +241,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bf883987bd80..6da2969c8819 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,4 +1,5 @@ """Attention layer with FlashAttention.""" +from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type @@ -13,6 +14,7 @@ compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -296,6 +298,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_maps=self.multi_modal_placeholder_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -324,6 +327,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, @@ -400,6 +404,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -432,6 +439,12 @@ def _add_seq_group( self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for key, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[key].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -544,6 +557,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + placeholder_maps = { + key: placeholder_map.index_tensors(device) + for key, 
placeholder_map in + self.multimodal_placeholder_maps.items() + } torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -559,6 +577,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + multi_modal_placeholder_maps=placeholder_maps, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_prefill_seq_len=max_prefill_seq_len, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 58d62e02e873..bce139954d7f 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,7 +1,10 @@ +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type +from vllm.multimodal import MultiModalPlaceholderMap + try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper @@ -210,6 +213,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): attn_metadata = self.runner.attn_backend.make_metadata( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_maps=None, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -455,6 +459,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -506,6 +513,11 @@ def _add_seq_group( inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for key, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[key].extend( + placeholders) self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -629,6 +641,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + placeholder_maps = { + key: placeholder_map.index_tensors(device) + for key, placeholder_map in + self.multimodal_placeholder_maps.items() + } torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -671,6 +688,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], return FlashInferMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_maps=placeholder_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b0f4d0530b7f..b0cee361616a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -134,6 +134,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_maps=self.multi_modal_placeholder_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], 
max_query_len=self.max_query_len, @@ -162,6 +163,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 0375d3488eb1..113e17d9f5c9 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,4 +1,5 @@ """Attention backend utils""" +from collections import defaultdict from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union @@ -7,6 +8,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) +from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -131,6 +133,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -158,6 +163,12 @@ def _add_seq_group( inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for key, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[key].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -249,6 +260,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + placeholder_maps = { + key: placeholder_map.index_tensors(device) + for key, placeholder_map in + self.multimodal_placeholder_maps.items() + } torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -261,6 +277,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_maps=placeholder_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -311,6 +328,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_maps=None, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=None, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e073d616bf01..b428f9f3f42f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -209,6 +209,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=self.multi_modal_placeholder_maps, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -252,6 +253,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, + 
multi_modal_placeholder_maps=None, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3fa95f57b73..1975fb16ac79 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1202,6 +1202,8 @@ def schedule( # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + multi_modal_placeholders=seq_group.multi_modal_placeholders + if scheduler_outputs.num_prefill_groups > 0 else None, prompt_adapter_request=seq_group.prompt_adapter_request, ) else: diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155..f719a0685482 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -4,7 +4,7 @@ from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict + from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict class TextPrompt(TypedDict): @@ -122,6 +122,12 @@ class LLMInputs(TypedDict): if the model supports it. """ + multi_modal_placeholders: NotRequired[ + Optional["MultiModalPlaceholderDict"]] + """ + Placeholder ranges for the multi-modal data. + """ + class EncoderDecoderLLMInputs(LLMInputs): """ diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f7..f7ec97c26740 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -15,7 +15,8 @@ if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.multimodal import MultiModalDataDict, MultiModalRegistry + from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, + MultiModalRegistry) from vllm.sequence import SequenceData logger = init_logger(__name__) @@ -65,6 +66,9 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: N = TypeVar("N", bound=Type[nn.Module]) +DummyDataTuple = Tuple["SequenceData", Optional["MultiModalDataDict"], + Optional["MultiModalPlaceholderDict"]] + class DummyDataFactory(Protocol): @@ -73,7 +77,7 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyDataTuple: """ Create dummy data to be inputted into the model. @@ -119,7 +123,7 @@ def _default_dummy_data_factory( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyDataTuple: """ The default dummy data factory represents the longest possible text that can be inputted to the model. @@ -133,8 +137,9 @@ def _default_dummy_data_factory( dummy_seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) dummy_multi_modal_data = None + dummy_placeholders = None - return dummy_seq_data, dummy_multi_modal_data + return dummy_seq_data, dummy_multi_modal_data, dummy_placeholders def register_dummy_data(self, factory: DummyDataFactory): """ @@ -163,7 +168,7 @@ def dummy_data_for_profiling( model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyDataTuple: """ Create dummy data for profiling the memory usage of a model. 
@@ -184,7 +189,7 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - seq_data, mm_data = dummy_factory( + seq_data, mm_data, ranges = dummy_factory( InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), @@ -204,7 +209,7 @@ def dummy_data_for_profiling( f"Expected at least {num_expected} dummy '{k}' instances " f"for profiling, but found {num_items} instances instead.") - return seq_data, mm_data + return seq_data, mm_data, ranges def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 583d5d217903..fcebf62dfc29 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -98,6 +98,11 @@ def input_processor_for_blip( if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs + if "multi_modal_placeholders" in llm_inputs and "image" in llm_inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return llm_inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -105,7 +110,7 @@ def input_processor_for_blip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], @@ -116,7 +121,8 @@ def input_processor_for_blip( # NOTE: Create a defensive copy of the original inputs return LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 39f2b2d853a6..dea17c6b2cad 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -433,7 +433,12 @@ def dummy_seq_data_for_blip2( [image_token_id]) * image_feature_size * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData(token_ids), { + "image": [{ + "offset": i * image_feature_size, + "length": image_feature_size + } for i in range(num_images)] + } def dummy_data_for_blip2(ctx: InputContext, seq_len: int, @@ -442,7 +447,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, vision_config = hf_config.vision_config num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_blip2( + seq_data, ranges = dummy_seq_data_for_blip2( hf_config, seq_len, num_images, @@ -452,7 +457,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, if isinstance(vision_config, Blip2VisionConfig): mm_data = dummy_image_for_blip(vision_config, num_images) - return seq_data, mm_data + return seq_data, mm_data, ranges msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 47e020e8ecb7..595f6ac472b8 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -76,7 +76,12 @@ def dummy_seq_data_for_chameleon( [image_token_id]) * image_feature_size * 
num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData(token_ids), { + "image": [{ + "offset": i * image_feature_size, + "length": image_feature_size + } for i in range(num_images)] + } def dummy_image_for_chameleon( @@ -100,14 +105,14 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_chameleon( + seq_data, ranges = dummy_seq_data_for_chameleon( seq_len, num_images, image_token_id=CHAMELEON_IMAGE_TOKEN_ID, ) mm_data = dummy_image_for_chameleon(num_images) - return seq_data, mm_data + return seq_data, mm_data, ranges def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): @@ -122,9 +127,14 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs + if "multi_modal_placeholders" in llm_inputs and "image" in llm_inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return llm_inputs + model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], @@ -143,7 +153,8 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): # NOTE: Create a defensive copy of the original inputs return LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) class ChameleonLayerNorm(nn.LayerNorm): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 078928f281c2..3675500f6616 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -49,14 +49,13 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: return get_clip_image_feature_size(hf_config) -def dummy_seq_data_for_clip( - hf_config: CLIPVisionConfig, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): +def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + mm_key: str = "image"): if image_feature_size_override is None: image_feature_size = get_clip_image_feature_size(hf_config) else: @@ -66,7 +65,12 @@ def dummy_seq_data_for_clip( [image_token_id]) * image_feature_size * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData(token_ids), { + mm_key: [{ + "offset": i * image_feature_size, + "length": image_feature_size + } for i in range(num_images)] + } def dummy_image_for_clip( @@ -98,6 +102,11 @@ def input_processor_for_clip( if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs + if "multi_modal_placeholders" in llm_inputs and "image" in llm_inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
+ return llm_inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -111,7 +120,7 @@ def input_processor_for_clip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], @@ -122,7 +131,8 @@ def input_processor_for_clip( # NOTE: Create a defensive copy of the original inputs return LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index beeae1422957..8675356a7dc8 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -106,7 +106,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData(token_ids), { + "image": [{ + "offset": i * image_feature_size, + "length": image_feature_size + } for i in range(num_images)] + } def dummy_image_for_fuyu( @@ -122,11 +127,11 @@ def dummy_image_for_fuyu( def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) + seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) mm_data = dummy_image_for_fuyu(num_images, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) - return seq_data, mm_data + return seq_data, mm_data, ranges def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 507d7014714a..beb3012cad79 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -302,7 +302,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int, tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_images, @@ -324,7 +324,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int, image_height_override=max_image_height, ) - return seq_data, mm_data + return seq_data, mm_data, ranges @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7a6c991fb133..f9369410bc6b 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -100,7 +100,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, image_feature_size = get_max_llava_image_tokens(ctx) if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_images, @@ -109,9 +109,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_clip(vision_config, num_images) - return seq_data, mm_data + return 
seq_data, mm_data, ranges elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -120,7 +120,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data + return seq_data, mm_data, ranges msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c6bd46dd7eda..77c5f5b616f2 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -170,7 +170,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_feature_size = get_max_llava_next_image_tokens(ctx) if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_images, @@ -185,9 +185,9 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data + return seq_data, mm_data, ranges elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -202,7 +202,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data + return seq_data, mm_data, ranges msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 7fe85e5e4ab3..1e15b1cd1b96 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -111,33 +111,35 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, video_feature_size = frames_per_video * tokens_per_frame if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, + mm_key="video", ) pil_frame = dummy_image_for_clip(vision_config, num_images=1) np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data + return seq_data, mm_data, ranges elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, + mm_key="video", ) pil_frame = dummy_image_for_siglip(vision_config, num_images=1) np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data + return seq_data, mm_data, ranges msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -148,6 +150,11 @@ def input_processor_for_llava_next_video(ctx: InputContext, multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: return llm_inputs + if "multi_modal_placeholders" in 
llm_inputs and "video" in llm_inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return llm_inputs + video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -163,7 +170,7 @@ def input_processor_for_llava_next_video(ctx: InputContext, tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], @@ -173,7 +180,8 @@ def input_processor_for_llava_next_video(ctx: InputContext, return LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"video": ranges}) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index f8be9490ee55..b35cfb47e0fa 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -277,7 +277,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) mm_data = dummy_image_for_minicpmv(hf_config, num_images) - return seq_data, mm_data + return seq_data, mm_data, None def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 5fd39b5e35be..a2a00f239fae 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -60,7 +60,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, vision_config = hf_config.vision_config num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -68,7 +68,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data + return seq_data, mm_data, ranges def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6f17f571ccae..313420957423 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -362,7 +362,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, image_feature_size = get_max_phi3v_image_tokens(ctx) - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, seq_len, num_images, @@ -376,7 +376,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data + return seq_data, mm_data, ranges # Reserve this function to also handle placeholders for additional images diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 010cf85f45e0..66dbd79d482c 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -65,7 +65,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, seq_data = SequenceData(token_ids) mm_data = {"image": max_num_images_per_request * [image]} - return seq_data, mm_data + return seq_data, mm_data, None def input_mapper_for_pixtral(ctx: InputContext, diff --git a/vllm/model_executor/models/qwen.py 
b/vllm/model_executor/models/qwen.py index 18bc6b303f48..f345f1143a1c 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -850,7 +850,7 @@ def dummy_data_for_qwen( # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data + return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data, None @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 3f8c590a39b0..dbb2a43911c8 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -695,7 +695,7 @@ def dummy_data_for_qwen2_vl( return dummy_seqdata, { "image": dummy_image if num_images == 1 else [dummy_image] * num_images - } + }, None def _get_llm_num_vision_tokens( diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index f7976eba7420..12ba0661aab6 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -61,6 +61,7 @@ def dummy_seq_data_for_siglip( *, image_token_id: int, image_feature_size_override: Optional[int] = None, + mm_key: str = "image", ): if image_feature_size_override is None: image_feature_size = get_siglip_image_feature_size(hf_config) @@ -71,7 +72,12 @@ def dummy_seq_data_for_siglip( [image_token_id]) * image_feature_size token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size) - return SequenceData(token_ids) + return SequenceData(token_ids), { + mm_key: [{ + "offset": i * image_feature_size, + "length": image_feature_size + } for i in range(num_images)] + } def dummy_image_for_siglip( @@ -103,6 +109,11 @@ def input_processor_for_siglip( if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs + if "multi_modal_placeholders" in llm_inputs and "image" in llm_inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
+ return llm_inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -116,7 +127,7 @@ def input_processor_for_siglip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], @@ -125,11 +136,10 @@ def input_processor_for_siglip( ) # NOTE: Create a defensive copy of the original inputs - return LLMInputs( - prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - ) + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 416fabda831a..ff51f5c30730 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -29,12 +29,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.utils import (filter_weights, flatten_bn, - init_vllm_registered_model, - merge_multimodal_embeddings) +from vllm.model_executor.models.utils import ( + filter_weights, flatten_bn, init_vllm_registered_model, + merge_multimodal_embeddings_from_map) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, + NestedTensors) from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -49,13 +49,13 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: NestedTensors - """Shape: `(batch_size, num_audios, 80, M)""" + """Shape: `(batch_size, num_audios, 80, M)`""" class UltravoxAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`""" UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, @@ -85,21 +85,26 @@ def dummy_data_for_ultravox( feature_extractor = whisper_feature_extractor(ctx) audio_count = mm_counts["audio"] + audio_length = min(get_ultravox_max_audio_tokens(ctx), + seq_len // audio_count) - audio_placeholder = array( - VLLM_TOKEN_ID_ARRAY_TYPE, - [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx) + audio_placeholder = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [_AUDIO_PLACEHOLDER_TOKEN]) * audio_length - # Add a separator between each chunk. 
- audio_token_ids = (audio_placeholder + - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count + audio_token_ids = audio_placeholder * audio_count other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids)) - audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) - mm_dict = {"audio": [audio_and_sr] * audio_count} + mm_dict = {"audio": [audio_and_sr for _ in range(audio_count)]} + mm_placeholders = { + "audio": [{ + "offset": i * audio_length, + "length": audio_length + } for i in range(audio_count)] + } - return (SequenceData(audio_token_ids + other_token_ids), mm_dict) + return (SequenceData(audio_token_ids + other_token_ids), mm_dict, + mm_placeholders) def input_mapper_for_ultravox(ctx: InputContext, data: object): @@ -146,6 +151,11 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): if multi_modal_data is None or "audio" not in multi_modal_data: return llm_inputs + if "multi_modal_placeholders" in llm_inputs and "audio" in llm_inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return llm_inputs + feature_extractor = whisper_feature_extractor(ctx) audios = multi_modal_data["audio"] if not isinstance(audios, list): @@ -174,7 +184,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], @@ -185,7 +195,8 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): # NOTE: Create a defensive copy of the original inputs return LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"audio": ranges}) class StackAudioFrames(nn.Module): @@ -423,9 +434,9 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, audio_embeddings, - _AUDIO_PLACEHOLDER_TOKEN) + merge_multimodal_embeddings_from_map( + inputs_embeds, audio_embeddings, + attn_metadata.multi_modal_placeholder_maps["audio"]) input_ids = None else: inputs_embeds = None diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 8b80dda96db4..d341e52bbf61 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry -from vllm.multimodal.base import NestedTensors +from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -120,6 +120,23 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: _embedding_count_expression(inner) for inner in embeddings) +def merge_multimodal_embeddings_from_map( + inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, + placeholder_map: MultiModalPlaceholderMap.IndexTensors +) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided + placeholder map . 
+
+    Note:
+        This updates ``inputs_embeds`` in place.
+    """
+    flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
+    inputs_embeds[placeholder_map.dest] = flattened_embeddings[
+        placeholder_map.src]
+    return inputs_embeds
+
+
 def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                 inputs_embeds: torch.Tensor,
                                 multimodal_embeddings: NestedTensors,
diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py
index 489e1e51f05c..53da2badb9b9 100644
--- a/vllm/multimodal/__init__.py
+++ b/vllm/multimodal/__init__.py
@@ -1,6 +1,7 @@
 from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
-                   MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
-                   NestedTensors)
+                   MultiModalDataDict, MultiModalInputs,
+                   MultiModalPlaceholderDict, MultiModalPlaceholderMap,
+                   MultiModalPlugin, NestedTensors)
 from .registry import MultiModalRegistry

 MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -17,6 +18,8 @@
     "MultiModalDataBuiltins",
     "MultiModalDataDict",
     "MultiModalInputs",
+    "MultiModalPlaceholderDict",
+    "MultiModalPlaceholderMap",
     "MultiModalPlugin",
     "NestedTensors",
     "MULTIMODAL_REGISTRY",
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index 032964fe0ac4..12c603165e22 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -1,8 +1,8 @@
 import sys
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
-from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
-                    TypedDict, TypeVar, Union, cast, final)
+from typing import (Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple,
+                    Type, TypedDict, TypeVar, Union, cast, final)

 import numpy as np
 import torch
@@ -14,6 +14,7 @@
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
+from vllm.sequence import SequenceGroupMetadata
 from vllm.utils import JSONTree, is_list_of, json_map_leaves

 logger = init_logger(__name__)
@@ -144,6 +145,22 @@ class MultiModalDataBuiltins(TypedDict, total=False):
     Read more on that :ref:`here `.
     """

+
+class PlaceholderRange(TypedDict):
+    """A placeholder for multi-modal data."""
+
+    offset: int
+    """The start index of the placeholder."""
+
+    length: int
+    """The length of the placeholder."""
+
+
+MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]]
+"""
+A dictionary containing placeholder ranges.
+"""
+
 MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                  MultiModalInputs]
 """
@@ -338,3 +355,128 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
         self._validate_max_multimodal_tokens(max_mm_tokens)

         return max_mm_tokens
+
+
+class MultiModalPlaceholderMap:
+    """
+    Relates multi-modal embeddings to their corresponding placeholders.
+    """
+
+    class IndexTensors(NamedTuple):
+        src: torch.Tensor
+        dest: torch.Tensor
+
+    src_ranges: List[range]
+    """
+    The indices of the multi-modal embeddings that will replace the
+    corresponding placeholder embeddings pointed to by ``dest_ranges``.
+    """
+
+    src_len: int
+    """
+    The total number of flattened multi-modal embeddings.
+    """
+
+    dest_ranges: List[range]
+    """
+    The indices of the placeholder embeddings that will be replaced by the
+    multimodal embeddings.
+    """
+
+    dest_len: int
+    """
+    The total number of embeddings in the destination tensor.
+ """ + + def __init__(self): + self.src_ranges = [] + self.src_len = 0 + self.dest_ranges = [] + self.dest_len = 0 + + @classmethod + def from_seq_group( + cls, seq_group: SequenceGroupMetadata, positions: range + ) -> Tuple[Optional[MultiModalDataDict], Dict[str, + "MultiModalPlaceholderMap"]]: + if (not seq_group.multi_modal_data + or not seq_group.multi_modal_placeholders): + return seq_group.multi_modal_data, {} + + mm_data = {**seq_group.multi_modal_data} + placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + + for key, placeholders in seq_group.multi_modal_placeholders.items(): + mm_items = mm_data.pop(key) + if not isinstance(mm_items, list): + mm_items = [mm_items] + + if positions: + intersecting_items = placeholder_maps[ + key].append_items_from_seq_group(positions, mm_items, + placeholders) + + if intersecting_items: + mm_data[key] = intersecting_items + + return mm_data, placeholder_maps + + def append_items_from_seq_group( + self, positions: range, multi_modal_items: List[_T], + multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]: + intersecting_items = [] + + if len(multi_modal_items) != len(multi_modal_placeholders): + raise ValueError( + "Multi-modal placeholders and items must have the same length." + ) + for placeholder_dict, mm_item in zip(multi_modal_placeholders, + multi_modal_items): + placeholder = range( + placeholder_dict["offset"], + placeholder_dict["offset"] + placeholder_dict["length"]) + intersection = range(max(positions.start, placeholder.start), + min(positions.stop, placeholder.stop)) + + if not intersection: + # Skip this multi-modal item. + continue + + token_embedding_range = range(intersection.start - positions.start, + intersection.stop - positions.start) + + multimodal_embedding_range = range( + intersection.start - placeholder.start + self.src_len, + intersection.stop - placeholder.start + self.src_len) + + intersecting_items.append(mm_item) + self.dest_ranges.append(token_embedding_range) + self.src_ranges.append(multimodal_embedding_range) + self.src_len += len(placeholder) + + self.dest_len += len(positions) + return intersecting_items + + def extend(self, other: "MultiModalPlaceholderMap"): + self.src_ranges.extend( + range(self.src_len + r.start, self.src_len + r.stop) + for r in other.src_ranges) + self.src_len += other.src_len + self.dest_ranges.extend( + range(self.dest_len + r.start, self.dest_len + r.stop) + for r in other.dest_ranges) + self.dest_len += other.dest_len + + def index_tensors(self, device: str) -> "IndexTensors": + src_indices = [i for r in self.src_ranges for i in r] + dest_indices = [i for r in self.dest_ranges for i in r] + + if len(src_indices) != len(dest_indices): + raise ValueError( + f"The number of source ({len(src_indices)}) and destination " + f"indices ({len(dest_indices)}) must be the same.") + + return MultiModalPlaceholderMap.IndexTensors( + src=torch.tensor(src_indices, dtype=torch.int64, device=device), + dest=torch.tensor(dest_indices, dtype=torch.int64, device=device)) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3c801464383a..35c8965f5cdb 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -10,7 +10,7 @@ from vllm.connections import global_http_connection from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT from vllm.logger import init_logger -from vllm.multimodal.base import MultiModalDataDict +from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange from 
vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer logger = init_logger(__name__) @@ -258,7 +258,7 @@ def repeat_and_pad_placeholder_tokens( repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int]]: +) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]: if isinstance(repeat_count, int): repeat_count = [repeat_count] @@ -301,6 +301,7 @@ def repeat_and_pad_placeholder_tokens( new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] + placeholder_ranges: List[PlaceholderRange] = [] placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: @@ -310,6 +311,10 @@ def repeat_and_pad_placeholder_tokens( pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) + placeholder_ranges.append({ + "offset": len(new_token_ids), + "length": len(replacement_ids) + }) new_token_ids.extend(replacement_ids) placeholder_token_idx += 1 @@ -320,4 +325,4 @@ def repeat_and_pad_placeholder_tokens( else: new_token_ids.append(token) - return new_prompt, new_token_ids + return new_prompt, new_token_ids, placeholder_ranges diff --git a/vllm/sequence.py b/vllm/sequence.py index 98a8b7358606..3480e5ab312e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from vllm.inputs import LLMInputs - from vllm.multimodal.base import MultiModalDataDict + from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -458,6 +458,10 @@ def prompt_token_ids(self) -> List[int]: def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.get("multi_modal_data") or {} + @property + def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": + return self.inputs.get("multi_modal_placeholders") or {} + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -690,6 +694,10 @@ def multi_modal_data(self) -> "MultiModalDataDict": # We use the multi-modal data of an arbitrary sequence. return self.seqs[0].multi_modal_data + @property + def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": + return self.seqs[0].multi_modal_placeholders + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -937,6 +945,7 @@ class SequenceGroupMetadata( # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. 
multi_modal_data: Optional[Any] = None + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7b2caf497358..b478a8b0cdae 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,3 +1,4 @@ +from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -13,7 +14,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( @@ -145,6 +146,9 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -163,13 +167,19 @@ def _prepare_prompt( # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) + + if seq_group_metadata.multi_modal_data: + mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) + for key, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[key].extend(placeholder_map) + # Compute the slot mapping. 
block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, @@ -203,6 +213,10 @@ def _prepare_prompt( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore + placeholder_maps = { + key: placeholder_map.index_tensors(self.device) + for key, placeholder_map in multi_modal_placeholder_maps.items() + } attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -214,6 +228,7 @@ def _prepare_prompt( num_decode_tokens=0, block_tables=torch.tensor([]), slot_mapping=slot_mapping, + multi_modal_placeholder_maps=placeholder_maps, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) @@ -288,6 +303,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=None, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_decode_seq_len=max_decode_seq_len, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index acb7bafefc20..f4c4d33fcc0e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -38,7 +38,8 @@ supports_multimodal) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalRegistry) + MultiModalInputs, MultiModalPlaceholderMap, + MultiModalRegistry) from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.worker_manager import ( @@ -238,6 +239,8 @@ def __init__( # Multi-modal inputs. multi_modal_inputs: Optional[MultiModalInputs] = None, + multi_modal_placeholder_maps: Optional[Dict[ + str, MultiModalPlaceholderMap]] = None, # Whether the prefix cache is hit (prefill only). prefix_cache_hit: bool = False, @@ -355,6 +358,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.multi_modal_inputs = multi_modal_inputs + self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit self.n_seqs = len(self.seq_ids) @@ -651,12 +655,16 @@ def _compute_prompt_adapter_input( def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata): """If multi-modal data is given, add it to the input.""" - mm_data = seq_group_metadata.multi_modal_data + positions = inter_data.input_positions[0] + mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group_metadata, + range(positions[0], positions[0] + len(positions))) if not mm_data: return mm_kwargs = self.multi_modal_input_mapper(mm_data) inter_data.multi_modal_inputs = mm_kwargs + inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. 
if self.runner.model_is_mrope: @@ -1184,7 +1192,9 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = self.input_registry \ + (seq_data, + dummy_multi_modal_data, + dummy_placeholders) = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -1198,6 +1208,7 @@ def profile_run(self) -> None: lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, multi_modal_data=dummy_multi_modal_data, + multi_modal_placeholders=dummy_placeholders, ) seqs.append(seq) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index f335e4e32efd..7403c1617611 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, NamedTuple, Optional, Tuple +from collections import defaultdict +from typing import Dict, List, NamedTuple, Optional, Tuple import openvino as ov import torch @@ -14,7 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.openvino import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import SequenceGroupMetadata logger = init_logger(__name__) @@ -116,6 +117,9 @@ def _prepare_model_input( past_lens: List[int] = [] query_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) subsequence_begins: List[int] = [] block_indices: List[int] = [] @@ -169,11 +173,6 @@ def _prepare_model_input( and self.sliding_window is None and is_prompt) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - block_table = seq_group_metadata.block_tables[seq_id] # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. 
@@ -217,7 +216,8 @@ def _prepare_model_input( query_lens.append(query_len) input_tokens.extend(tokens) - input_positions.extend(list(range(computed_len, seq_len))) + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) past_lens.append(computed_len) subsequence_begins.append(subsequence_begins[-1] + query_len) @@ -230,6 +230,17 @@ def _prepare_model_input( ), "seq_len: {}, computed_len: {}, query_len: {}".format( seq_len, computed_len, query_len) + if seq_group_metadata.multi_modal_data: + mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + mm_kwargs = self.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + + for key, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[key].extend( + placeholder_map) + max_query_len = max(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens) @@ -258,12 +269,18 @@ def _prepare_model_input( max_context_len, dtype=torch.int32, device=self.device) # type: ignore + placeholder_maps = { + key: placeholder_map.index_tensors(self.device) + for key, placeholder_map in multi_modal_placeholder_maps.items() + } + attn_metadata = self.attn_backend.make_openvino_metadata( past_lens=past_lens_tensor, subsequence_begins=subsequence_begins_tensor, block_indices=block_indices_tensor, block_indices_begins=block_indices_begins_tensor, max_context_len=max_context_len_tensor, + multi_modal_placeholder_maps=placeholder_maps, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index db306bc743d3..a8c020d5f24e 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -172,6 +172,7 @@ def _dummy_run( num_prefill_tokens=batch_size * seq_len, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=None, block_tables=None, context_lens=None, ) @@ -204,6 +205,7 @@ def _dummy_run( num_prefill_tokens=0, num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=None, block_tables=block_tables, context_lens=context_lens, ) @@ -348,6 +350,7 @@ def _prepare_prompt( num_prefill_tokens=0, # NOTE: This is not used. 
num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=None, block_tables=None, context_lens=None, ) @@ -417,6 +420,7 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=None, block_tables=block_tables, context_lens=context_lens, ) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index f9037625d4af..28a6b9965eae 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,6 +1,7 @@ import dataclasses import time import weakref +from collections import defaultdict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar) @@ -18,7 +19,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalRegistry) + MultiModalInputs, MultiModalPlaceholderMap, + MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad @@ -159,6 +161,9 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -177,7 +182,18 @@ def _prepare_prompt( # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) + + if seq_group_metadata.multi_modal_data: + mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + + for key, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[key].extend(placeholder_map) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -218,6 +234,10 @@ def _prepare_prompt( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore + placeholder_maps = { + key: placeholder_map.index_tensors(self.device) + for key, placeholder_map in multi_modal_placeholder_maps.items() + } max_seqlen = max(seq_lens) tmp = [0] @@ -228,6 +248,7 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=placeholder_maps, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, @@ -311,6 +332,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, + multi_modal_placeholder_maps=None, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0, From 46ee2d55f14dfeac9ff223c8f6207002b990799b Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 18 Sep 2024 22:13:45 +0000 Subject: [PATCH 02/15] Fix msgpack failures --- vllm/multimodal/base.py | 17 ++++++++++------- vllm/multimodal/image.py | 7 +++++-- vllm/multimodal/registry.py | 18 ++++++++++-------- vllm/multimodal/video.py | 8 +++++--- 
vllm/sequence.py | 18 ++++++++---------- 5 files changed, 38 insertions(+), 30 deletions(-) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 12c603165e22..7896da4f6d85 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,8 +1,9 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Callable, Dict, List, Mapping, NamedTuple, Optional, Tuple, - Type, TypedDict, TypeVar, Union, cast, final) +from typing import (TYPE_CHECKING, Callable, Dict, List, Mapping, NamedTuple, + Optional, Tuple, Type, TypedDict, TypeVar, Union, cast, + final) import numpy as np import torch @@ -11,12 +12,14 @@ from torch import nn from typing_extensions import TypeAlias -from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.sequence import SequenceGroupMetadata from vllm.utils import JSONTree, is_list_of, json_map_leaves +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import SequenceGroupMetadata + logger = init_logger(__name__) NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] @@ -252,7 +255,7 @@ def wrapper(model_cls: N) -> N: return wrapper - def map_input(self, model_config: ModelConfig, + def map_input(self, model_config: "ModelConfig", data: MultiModalData[object]) -> MultiModalInputs: """ Transform the data into a dictionary of model inputs using the @@ -324,7 +327,7 @@ def wrapper(model_cls: N) -> N: return wrapper - def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. 
@@ -396,7 +399,7 @@ def __init__(self): @classmethod def from_seq_group( - cls, seq_group: SequenceGroupMetadata, positions: range + cls, seq_group: "SequenceGroupMetadata", positions: range ) -> Tuple[Optional[MultiModalDataDict], Dict[str, "MultiModalPlaceholderMap"]]: if (not seq_group.multi_modal_data diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 6cdde949bc2b..437e17046c39 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,9 +1,9 @@ from functools import lru_cache +from typing import TYPE_CHECKING import torch from PIL import Image -from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor @@ -11,6 +11,9 @@ from .base import MultiModalData, MultiModalInputs, MultiModalPlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) @@ -22,7 +25,7 @@ class ImagePlugin(MultiModalPlugin): def get_data_key(self) -> str: return "image" - def _get_hf_image_processor(self, model_config: ModelConfig): + def _get_hf_image_processor(self, model_config: "ModelConfig"): return cached_get_image_processor( model_config.model, trust_remote_code=model_config.trust_remote_code) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 745fc715caf4..1a4d0fe61d36 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,8 +1,7 @@ import functools from collections import UserDict -from typing import Dict, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence -from vllm.config import ModelConfig from vllm.logger import init_logger from .audio import AudioPlugin @@ -11,6 +10,9 @@ from .image import ImagePlugin from .video import VideoPlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) @@ -20,7 +22,7 @@ class _MultiModalLimits(UserDict): when attempting to access a model that does not exist. """ - def __getitem__(self, key: ModelConfig) -> Dict[str, int]: + def __getitem__(self, key: "ModelConfig") -> Dict[str, int]: try: return super().__getitem__(key) except KeyError as exc: @@ -96,7 +98,7 @@ def register_image_input_mapper( """ return self.register_input_mapper("image", mapper) - def map_input(self, model_config: ModelConfig, + def map_input(self, model_config: "ModelConfig", data: MultiModalDataDict) -> MultiModalInputs: """ Apply an input mapper to the data passed to the model. @@ -134,7 +136,7 @@ def map_input(self, model_config: ModelConfig, return MultiModalInputs(merged_dict) - def create_input_mapper(self, model_config: ModelConfig): + def create_input_mapper(self, model_config: "ModelConfig"): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ @@ -163,7 +165,7 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) - def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. 
@@ -181,7 +183,7 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: def init_mm_limits_per_prompt( self, - model_config: ModelConfig, + model_config: "ModelConfig", ) -> None: """ Initialize the maximum number of multi-modal input instances for each @@ -217,7 +219,7 @@ def init_mm_limits_per_prompt( def get_mm_limits_per_prompt( self, - model_config: ModelConfig, + model_config: "ModelConfig", ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 4401d1315792..be63ec0bdf4e 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,9 +1,8 @@ from functools import lru_cache -from typing import List, Union +from typing import TYPE_CHECKING, List, Union import numpy as np -from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_video_processor @@ -13,6 +12,9 @@ from .base import MultiModalData, MultiModalInputs from .image import ImagePlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) cached_get_video_processor = lru_cache(get_video_processor) @@ -36,7 +38,7 @@ class VideoPlugin(ImagePlugin): def get_data_key(self) -> str: return "video" - def _get_hf_video_processor(self, model_config: ModelConfig): + def _get_hf_video_processor(self, model_config: "ModelConfig"): return cached_get_video_processor( model_config.model, trust_remote_code=model_config.trust_remote_code) diff --git a/vllm/sequence.py b/vllm/sequence.py index a15395745dfe..fdeb45446e76 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,24 +5,22 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import Any, Callable, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union, cast import msgspec import torch +from vllm.inputs import LLMInputs from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -if TYPE_CHECKING: - from vllm.inputs import LLMInputs - from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict - VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -455,11 +453,11 @@ def prompt_token_ids(self) -> List[int]: return self._prompt_token_ids @property - def multi_modal_data(self) -> "MultiModalDataDict": + def multi_modal_data(self) -> MultiModalDataDict: return self.inputs.get("multi_modal_data") or {} @property - def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: return self.inputs.get("multi_modal_placeholders") or {} @property @@ -691,13 +689,13 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: if self.encoder_seq is not None else None) @property - def multi_modal_data(self) -> "MultiModalDataDict": + def multi_modal_data(self) -> MultiModalDataDict: # All sequences in the group should have the same multi-modal data. # We use the multi-modal data of an arbitrary sequence. 
return self.seqs[0].multi_modal_data @property - def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: return self.seqs[0].multi_modal_placeholders @property @@ -947,7 +945,7 @@ class SequenceGroupMetadata( # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. multi_modal_data: Optional[Any] = None - multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None From 6c7830e9b8179c7ea5bf2ca825316dae99f6d5f5 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 18 Sep 2024 23:12:58 +0000 Subject: [PATCH 03/15] Fix test --- tests/worker/test_model_input.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 1e7f560fc68c..a095e8f0fd50 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -73,6 +73,7 @@ def test_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_maps=None, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -124,6 +125,7 @@ def test_embedding_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_maps=None, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), From bf7d874ee175d57adec46cee0dcb2890cccee7d9 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Thu, 19 Sep 2024 17:18:04 +0000 Subject: [PATCH 04/15] Change DummyData to NamedTuple --- vllm/inputs/registry.py | 37 ++++++++++--------- vllm/model_executor/models/blip2.py | 3 +- vllm/model_executor/models/chameleon.py | 3 +- vllm/model_executor/models/fuyu.py | 3 +- vllm/model_executor/models/internvl.py | 3 +- vllm/model_executor/models/llava.py | 5 ++- vllm/model_executor/models/llava_next.py | 5 ++- .../model_executor/models/llava_next_video.py | 5 ++- vllm/model_executor/models/minicpmv.py | 3 +- vllm/model_executor/models/paligemma.py | 3 +- vllm/model_executor/models/phi3v.py | 3 +- vllm/model_executor/models/pixtral.py | 9 ++++- vllm/model_executor/models/qwen.py | 4 +- vllm/model_executor/models/qwen2_vl.py | 8 ++-- vllm/model_executor/models/ultravox.py | 6 +-- vllm/worker/enc_dec_model_runner.py | 13 ++++--- vllm/worker/model_runner.py | 10 ++--- vllm/worker/xpu_model_runner.py | 8 ++-- 18 files changed, 76 insertions(+), 55 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f7ec97c26740..f12d86d9f504 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,8 +2,8 @@ from array import array from collections import UserDict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, - Protocol, Tuple, Type) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, + Optional, Protocol, Type) from torch import nn from transformers import PretrainedConfig @@ -66,8 +66,13 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: N = TypeVar("N", bound=Type[nn.Module]) -DummyDataTuple = Tuple["SequenceData", Optional["MultiModalDataDict"], - Optional["MultiModalPlaceholderDict"]] + +class DummyData(NamedTuple): + """Dummy data used 
for profiling.""" + + seq_data: "SequenceData" + multi_modal_data: Optional["MultiModalDataDict"] = None + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None class DummyDataFactory(Protocol): @@ -77,7 +82,7 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> DummyDataTuple: + ) -> DummyData: """ Create dummy data to be inputted into the model. @@ -123,7 +128,7 @@ def _default_dummy_data_factory( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> DummyDataTuple: + ) -> DummyData: """ The default dummy data factory represents the longest possible text that can be inputted to the model. @@ -134,12 +139,8 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) - dummy_multi_modal_data = None - dummy_placeholders = None - - return dummy_seq_data, dummy_multi_modal_data, dummy_placeholders + return DummyData( + SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)) def register_dummy_data(self, factory: DummyDataFactory): """ @@ -168,7 +169,7 @@ def dummy_data_for_profiling( model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", - ) -> DummyDataTuple: + ) -> DummyData: """ Create dummy data for profiling the memory usage of a model. @@ -189,27 +190,27 @@ def dummy_data_for_profiling( .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - seq_data, mm_data, ranges = dummy_factory( + dummy_data = dummy_factory( InputContext(model_config), seq_len, _MultiModalCounts(mm_counts), ) # Having more tokens is over-conservative but otherwise fine - num_tokens = seq_data.prompt_token_ids + num_tokens = dummy_data.seq_data.prompt_token_ids assert len(num_tokens) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " f"but found {len(num_tokens)} tokens instead.") - if mm_data is not None: - for k, v in mm_data.items(): + if dummy_data.multi_modal_data is not None: + for k, v in dummy_data.multi_modal_data.items(): num_items = len(v) if isinstance(v, list) else 1 num_expected = mm_counts[k] assert num_items >= num_expected, ( f"Expected at least {num_expected} dummy '{k}' instances " f"for profiling, but found {num_items} instances instead.") - return seq_data, mm_data, ranges + return dummy_data def _default_input_processor(self, ctx: InputContext, inputs: LLMInputs) -> LLMInputs: diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index dea17c6b2cad..7aa7a03ace90 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -10,6 +10,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -457,7 +458,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, if isinstance(vision_config, Blip2VisionConfig): mm_data = dummy_image_for_blip(vision_config, num_images) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git 
a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 595f6ac472b8..58d0f002d3a4 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -13,6 +13,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -112,7 +113,7 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_chameleon(num_images) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 8675356a7dc8..dddf509fddc8 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -28,6 +28,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -131,7 +132,7 @@ def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, mm_data = dummy_image_for_fuyu(num_images, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index beb3012cad79..e7272980f0b5 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -19,6 +19,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -324,7 +325,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int, image_height_override=max_image_height, ) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index f9369410bc6b..7e1e1c88c6b4 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -10,6 +10,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -109,7 +110,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_clip(vision_config, 
num_images) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): seq_data, ranges = dummy_seq_data_for_siglip( vision_config, @@ -120,7 +121,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 77c5f5b616f2..236630bfa791 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -185,7 +186,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): seq_data, ranges = dummy_seq_data_for_siglip( vision_config, @@ -202,7 +203,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 1e15b1cd1b96..99bdf626fadd 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -12,6 +12,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( @@ -124,7 +125,7 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): seq_data, ranges = dummy_seq_data_for_siglip( vision_config, @@ -139,7 +140,7 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 89ef5fc96efe..8bd7f3074ce3 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -38,6 +38,7 @@ from vllm.attention import AttentionMetadata from 
vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -277,7 +278,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) mm_data = dummy_image_for_minicpmv(hf_config, num_images) - return seq_data, mm_data, None + return DummyData(seq_data, mm_data) def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index a2a00f239fae..cc80c3219f6c 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -9,6 +9,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -68,7 +69,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 313420957423..c438ff6cb099 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -28,6 +28,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -376,7 +377,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data, ranges + return DummyData(seq_data, mm_data, ranges) # Reserve this function to also handle placeholders for additional images diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 1b8d7aab0bfd..dcf1919c1818 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -15,6 +15,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -71,7 +72,13 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, seq_data = SequenceData(token_ids) mm_data = {"image": num_images * [image]} - return seq_data, mm_data, None + mm_placeholders = { + "image": [{ + "offset": i, + "length": image_feature_size + } for i in range(num_images)] + } + return DummyData(seq_data, mm_data, mm_placeholders) def input_mapper_for_pixtral(ctx: 
InputContext, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index f345f1143a1c..662b4f47f03b 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -24,6 +24,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -850,7 +851,8 @@ def dummy_data_for_qwen( # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data, None + return DummyData(SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), + mm_data) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 8418273cb291..ea0e46373e55 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -49,6 +49,7 @@ from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -693,9 +694,10 @@ def dummy_data_for_qwen2_vl( dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) - return dummy_seqdata, { - "image": dummy_image if num_images == 1 else [dummy_image] * num_images - }, None + return DummyData(dummy_seqdata, { + "image": + dummy_image if num_images == 1 else [dummy_image] * num_images + }) def _get_llm_num_vision_tokens( diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index ff51f5c30730..b3c8a992ee77 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -20,7 +20,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY from vllm.inputs.data import LLMInputs -from vllm.inputs.registry import InputContext +from vllm.inputs.registry import DummyData, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -103,8 +103,8 @@ def dummy_data_for_ultravox( } for i in range(audio_count)] } - return (SequenceData(audio_token_ids + other_token_ids), mm_dict, - mm_placeholders) + return DummyData(SequenceData(audio_token_ids + other_token_ids), mm_dict, + mm_placeholders) def input_mapper_for_ultravox(ctx: InputContext, data: object): diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 09dab0135f39..79b73ac26e22 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -297,25 +297,26 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, _ = self.input_registry \ + dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) # Having 
more tokens is over-conservative but otherwise fine - assert len(seq_data.prompt_token_ids) >= seq_len, ( + assert len(dummy_data.seq_data.prompt_token_ids) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(seq_data.prompt_token_ids)}") + f"but got: {len(dummy_data.seq_data.prompt_token_ids)}") seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, - encoder_seq_data=seq_data, + encoder_seq_data=dummy_data.seq_data, cross_block_table=None, - ) + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders) seqs.append(seq) # Run the model with the dummy inputs. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f971ab457d3a..73902657fe21 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1212,9 +1212,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - (seq_data, - dummy_multi_modal_data, - dummy_placeholders) = self.input_registry \ + dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -1222,13 +1220,13 @@ def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_multi_modal_data, - multi_modal_placeholders=dummy_placeholders, + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders, ) seqs.append(seq) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 28a6b9965eae..db864b9a575f 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -468,7 +468,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = self.input_registry \ + dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -476,12 +476,12 @@ def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=None, - multi_modal_data=dummy_multi_modal_data, - ) + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders) seqs.append(seq) # Run the model with the dummy inputs. 
From 5da94505ac7b26e741f5fc6396c772835654bf63 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Mon, 23 Sep 2024 17:41:49 +0000 Subject: [PATCH 05/15] Add online + chunked prefill test for ultravox --- .../audio_language/test_ultravox.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 4aafd0cc517a..2530921c1e68 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -2,8 +2,10 @@ import numpy as np import pytest +import pytest_asyncio from transformers import AutoModel, AutoTokenizer, BatchEncoding +from tests.utils import RemoteOpenAIServer from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -37,6 +39,26 @@ def audio(request): return AudioAsset(request.param) +@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS)) +def server(request, audio_assets): + args = [ + "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", + f"--limit-mm-per-prompt=audio={len(audio_assets)}" + ] + [ + f"--{key.replace('_','-')}={value}" + for key, value in request.param.items() + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + def _get_prompt(audio_count, question, placeholder): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) placeholder = f"{placeholder}\n" * audio_count @@ -201,3 +223,33 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, num_logprobs=num_logprobs, **vllm_kwargs, ) + + +@pytest.mark.asyncio +async def test_online_inference(client, audio_assets): + messages = [{ + "role": + "user", + "content": [ + *[{ + "type": "audio_url", + "audio_url": { + "url": audio.url + } + } for audio in audio_assets], + { + "type": + "text", + "text": + f"What's happening in these {len(audio_assets)} audio clips?" 
+ }, + ], + }] + + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10) + + assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" From c5d0d7f3ac3b427650de9c81bcbed1c3860478dd Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Mon, 23 Sep 2024 17:56:34 +0000 Subject: [PATCH 06/15] Use SequenceData.from_token_counts in Ultravox --- vllm/model_executor/models/ultravox.py | 29 ++++++++++---------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index edf63113b012..12100dfe5449 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,6 @@ """PyTorch Ultravox model.""" import math -from array import array from functools import lru_cache from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union, cast) @@ -35,7 +34,7 @@ NestedTensors) from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import SequenceData from vllm.transformers_utils.configs.ultravox import UltravoxConfig _AUDIO_PLACEHOLDER_TOKEN = 128002 @@ -81,19 +80,14 @@ def dummy_seq_data_for_ultravox( audio_length = min(get_ultravox_max_audio_tokens(ctx), seq_len // audio_count) - audio_placeholder = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [_AUDIO_PLACEHOLDER_TOKEN]) * audio_length - - audio_token_ids = audio_placeholder * audio_count - other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - len(audio_token_ids)) - - return SequenceData(audio_token_ids + other_token_ids), { - "audio": [{ - "offset": i * audio_length, - "length": audio_length - } for i in range(audio_count)] - } + return SequenceData.from_token_counts( + (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count), + (0, seq_len - audio_length * audio_count)), { + "audio": [{ + "offset": i * audio_length, + "length": audio_length + } for i in range(audio_count)] + } def dummy_audio_for_ultravox( @@ -111,11 +105,10 @@ def dummy_data_for_ultravox( mm_counts: Mapping[str, int], ): audio_count = mm_counts["audio"] - seq_data, placeholders = dummy_seq_data_for_ultravox( - ctx, seq_len, audio_count) + seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) mm_dict = dummy_audio_for_ultravox(ctx, audio_count) - return DummyData(seq_data, mm_dict, placeholders) + return DummyData(seq_data, mm_dict, ranges) def input_mapper_for_ultravox(ctx: InputContext, data: object): From 7a6cbe9c3b50226db3298bbf1a8930b21464b32f Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Mon, 23 Sep 2024 21:47:57 +0000 Subject: [PATCH 07/15] Update test mock --- tests/multimodal/test_processor_kwargs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 5529ccd4fa57..4282a7976ca5 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -6,7 +6,7 @@ import torch from vllm.inputs import InputContext, LLMInputs -from vllm.inputs.registry import InputRegistry +from vllm.inputs.registry import DummyData, InputRegistry from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -56,7 +56,7 @@ def custom_dummy_data_factory(self, 
num_crops=DEFAULT_NUM_CROPS): seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) - return seq_data, None + return DummyData(seq_data, None) with patch( "vllm.inputs.registry.InputRegistry._default_dummy_data_factory", @@ -149,9 +149,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the mm_processor_kwargs. - seq_data, _ = dummy_registry.dummy_data_for_profiling( + dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(seq_data.prompt_token_ids) == expected_seq_count + assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count @pytest.mark.parametrize( @@ -178,9 +178,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the mm_processor_kwargs. - seq_data, _ = dummy_registry.dummy_data_for_profiling( + dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS ### Test overrides for the max token count per multimodal instance From 1a617f2e5aae4bdf455c52898129fd769085cac5 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Tue, 24 Sep 2024 16:30:13 +0000 Subject: [PATCH 08/15] Fix test failures --- tests/worker/test_model_input.py | 1 + vllm/model_executor/models/fuyu.py | 14 +++++++++----- vllm/model_executor/models/qwen.py | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index a095e8f0fd50..32ccfc8a3208 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -176,6 +176,7 @@ def test_multi_step_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_maps=None, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 8f08495346de..b3dd9faf6637 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -41,6 +41,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) +from vllm.utils import is_list_of from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings @@ -133,7 +134,7 @@ def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, - data: Image.Image): + data: List[Image.Image]): image_encoding = image_processor.preprocess(data, return_tensors="pt") batch_images = torch.stack([img[0] for img in image_encoding["images"] ]).unsqueeze(1) @@ -164,8 +165,10 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): model_config = ctx.model_config image_data = multi_modal_data["image"] new_multi_modal_data = {} + image_list = image_data if isinstance(image_data, list) else [image_data] + # process image data - if isinstance(image_data, Image.Image): + if is_list_of(image_list, Image.Image): # Fuyu's 
image_processor can also finish token padding image_processor: FuyuImageProcessor = cached_get_image_processor( model_config.model) @@ -177,7 +180,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): ]) new_multi_modal_data["image"] = image_patches - elif isinstance(image_data, torch.Tensor): + elif is_list_of(image_list, torch.Tensor): raise NotImplementedError("Embeddings input is not supported yet") else: raise TypeError(f"Invalid image type: {type(image_data)}") @@ -204,12 +207,13 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): def input_mapper_for_fuyu(ctx: InputContext, data: object): model_config = ctx.model_config - if isinstance(data, Image.Image): + data_list = data if isinstance(data, list) else [data] + if is_list_of(data_list, Image.Image): # Fuyu's image_processor can also finish token padding image_processor: FuyuImageProcessor = cached_get_image_processor( model_config.model) - model_image_input = _fuyu_image_preprocess(image_processor, data) + model_image_input = _fuyu_image_preprocess(image_processor, data_list) data = torch.stack([ image_patch[0] for image_patch in model_image_input["image_patches"] diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 1044f6d89b45..d0ab2113f4c5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -801,7 +801,7 @@ def dummy_data_for_qwen( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], -) -> Tuple[SequenceData, Optional[Dict]]: +) -> DummyData: """Build dummy data for warming up Qwen models; this will only contain text matching the defaults for VLLM unless the model has a visual config. @@ -820,7 +820,7 @@ def dummy_data_for_qwen( if not hasattr(hf_config, "visual"): seq_data = SequenceData.from_token_counts((0, seq_len)) mm_data = None - return seq_data, mm_data + return DummyData(seq_data, mm_data) # We have a visual component - use images to warm up num_images = mm_counts["image"] From 509cbaccfffdccf037335512fe0b82b9c492064a Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Tue, 24 Sep 2024 16:33:50 +0000 Subject: [PATCH 09/15] Fix llava-onevision dummy data --- vllm/model_executor/models/llava_onevision.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 9099d4f88222..fa00f4719994 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -15,6 +15,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.inputs.registry import DummyData from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( @@ -226,7 +227,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_videos, @@ -235,9 +236,9 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, ) mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - 
seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_videos, @@ -246,7 +247,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, ) mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) From cb2dc49ce876ef6dde293678bb11c0ec17bb13c4 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 25 Sep 2024 16:05:12 +0000 Subject: [PATCH 10/15] Allow None values in AttentionMetadata --- vllm/worker/model_runner_base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 975b88c0e79a..778fdf50c174 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -46,9 +46,8 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val + if field.name in tensor_dict: + valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) tensor_dict["attn_metadata"] = attn_metadata From defb59cc9753a3798849d43fa10a0e59bafe4ed1 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Wed, 25 Sep 2024 17:19:10 +0000 Subject: [PATCH 11/15] Fix phi3v test --- tests/models/decoder_only/vision_language/test_phi3v.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index eba0a1a1bce4..ff586be1e2b2 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -356,14 +356,14 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, mm_processor_kwargs=None, ) - sequence_data, _, = dummy_data_for_phi3v( + dummy_data = dummy_data_for_phi3v( ctx=ctx, seq_len=8192, # Should be bigger than num_imgs * toks_per_img mm_counts={"image": num_imgs}, num_crops=num_crops, ) # Ensure we have the right number of placeholders per num_crops size - img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) + img_tok_count = dummy_data.seq_data.get_token_ids().count(_IMAGE_TOKEN_ID) assert img_tok_count == toks_per_img * num_imgs From a8025f936b4e264fc8bf214ff68468a1b2edcb31 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Thu, 26 Sep 2024 23:11:14 +0000 Subject: [PATCH 12/15] Replace index tensors with lists --- vllm/attention/backends/abstract.py | 4 ++-- vllm/attention/backends/flash_attn.py | 2 +- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/backends/utils.py | 2 +- vllm/model_executor/models/utils.py | 3 +-- vllm/multimodal/base.py | 13 ++++++------- vllm/worker/cpu_model_runner.py | 2 +- vllm/worker/openvino_model_runner.py | 2 +- vllm/worker/xpu_model_runner.py | 2 +- 9 files changed, 15 insertions(+), 17 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index f7431068bb89..ecacfb8e230e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -107,10 +107,10 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. 
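# A minimal sketch of how the ``MultiModalPlaceholderMap.IndexMap`` carried on
# the attention metadata might be consumed when scattering multi-modal
# embeddings into the text embeddings. The helper name, shapes, and example
# indices are illustrative assumptions rather than the vLLM implementation;
# only the IndexMap shape (two plain int lists) comes from the patch itself.
from typing import List, NamedTuple

import torch


class IndexMap(NamedTuple):
    # Mirrors MultiModalPlaceholderMap.IndexMap from this series.
    src: List[int]   # indices into the flattened multi-modal embeddings
    dest: List[int]  # indices into the prompt's input embeddings


def scatter_mm_embeddings(inputs_embeds: torch.Tensor,
                          mm_embeds: torch.Tensor,
                          index_map: IndexMap) -> torch.Tensor:
    # Plain int lists stay device-agnostic; advanced indexing places them on
    # whatever device ``inputs_embeds`` already lives on.
    inputs_embeds[index_map.dest] = mm_embeds[index_map.src]
    return inputs_embeds


# Example: a 4-vector audio embedding filling placeholder slots 5..8.
merged = scatter_mm_embeddings(torch.zeros(16, 8), torch.randn(4, 8),
                               IndexMap(src=[0, 1, 2, 3], dest=[5, 6, 7, 8]))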
slot_mapping: torch.Tensor - # The index tensors that relate multi-modal embeddings to the corresponding + # The index maps that relate multi-modal embeddings to the corresponding # placeholders. multi_modal_placeholder_maps: Optional[Dict[ - str, MultiModalPlaceholderMap.IndexTensors]] + str, MultiModalPlaceholderMap.IndexMap]] @property @abstractmethod diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 96d042f561b3..fba7efbb6b3d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -563,7 +563,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int32, device=device) placeholder_maps = { - key: placeholder_map.index_tensors(device) + key: placeholder_map.index_map() for key, placeholder_map in self.multimodal_placeholder_maps.items() } diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 47cded69df37..b2b63346f08f 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -658,7 +658,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int32, device=device) placeholder_maps = { - key: placeholder_map.index_tensors(device) + key: placeholder_map.index_map() for key, placeholder_map in self.multimodal_placeholder_maps.items() } diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 4b13af8ca53b..5b2a0da13796 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -259,7 +259,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int32, device=device) placeholder_maps = { - key: placeholder_map.index_tensors(device) + key: placeholder_map.index_map() for key, placeholder_map in self.multimodal_placeholder_maps.items() } diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index ed64d83f07c2..818aceb59749 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -156,8 +156,7 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: def merge_multimodal_embeddings_from_map( inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_map: MultiModalPlaceholderMap.IndexTensors -) -> torch.Tensor: + placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided placeholder map . diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 02118fb57a91..5da4678d7f70 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -384,9 +384,9 @@ class MultiModalPlaceholderMap: Relates multi-modal embeddings to their corresponding placeholders. 
""" - class IndexTensors(NamedTuple): - src: torch.Tensor - dest: torch.Tensor + class IndexMap(NamedTuple): + src: List[int] + dest: List[int] src_ranges: List[range] """ @@ -490,7 +490,7 @@ def extend(self, other: "MultiModalPlaceholderMap"): for r in other.dest_ranges) self.dest_len += other.dest_len - def index_tensors(self, device: str) -> "IndexTensors": + def index_map(self) -> "IndexMap": src_indices = [i for r in self.src_ranges for i in r] dest_indices = [i for r in self.dest_ranges for i in r] @@ -499,6 +499,5 @@ def index_tensors(self, device: str) -> "IndexTensors": f"The number of source ({len(src_indices)}) and destination " f"indices ({len(dest_indices)}) must be the same.") - return MultiModalPlaceholderMap.IndexTensors( - src=torch.tensor(src_indices, dtype=torch.int64, device=device), - dest=torch.tensor(dest_indices, dtype=torch.int64, device=device)) + return MultiModalPlaceholderMap.IndexMap(src=src_indices, + dest=dest_indices) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 28617b5d886b..8dc790ef21dd 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -276,7 +276,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) # type: ignore placeholder_maps = { - key: placeholder_map.index_tensors(self.device) + key: placeholder_map.index_map() for key, placeholder_map in multi_modal_placeholder_maps.items() } diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 7403c1617611..a739f43c5abf 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -270,7 +270,7 @@ def _prepare_model_input( device=self.device) # type: ignore placeholder_maps = { - key: placeholder_map.index_tensors(self.device) + key: placeholder_map.index_map() for key, placeholder_map in multi_modal_placeholder_maps.items() } diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 11f5e2acce97..f0e88a642b00 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -235,7 +235,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) # type: ignore placeholder_maps = { - key: placeholder_map.index_tensors(self.device) + key: placeholder_map.index_map() for key, placeholder_map in multi_modal_placeholder_maps.items() } From 862c46c8934e5257c36c04895f70e8c2c15601a8 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 30 Sep 2024 16:23:25 +0000 Subject: [PATCH 13/15] Fix test failures --- vllm/model_executor/models/llava_onevision.py | 9 +++++---- vllm/multimodal/video.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index fa00f4719994..118eb99b5963 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -233,7 +233,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, - ) + mm_key="video") mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames) return DummyData(seq_data, mm_data, ranges) @@ -244,7 +244,7 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, - ) + mm_key="video") mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames) return DummyData(seq_data, 
mm_data, ranges) @@ -326,7 +326,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], @@ -336,7 +336,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, return LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"video": ranges}) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 526815394fad..b8d075254163 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -56,7 +56,8 @@ def _default_input_mapper( model_config = ctx.model_config # single video input as np.ndarray - if isinstance(data, np.ndarray): + if isinstance(data, np.ndarray) or (is_list_of(data, np.ndarray) + and len(data) == 1): video_processor = self._get_hf_video_processor(model_config) if video_processor is None: raise RuntimeError("No HuggingFace processor is available " From d0bc54db2ba430d6787edb0d96f530d71e890c7c Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Thu, 10 Oct 2024 22:48:25 +0000 Subject: [PATCH 14/15] Add docstrings --- vllm/multimodal/base.py | 51 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index b92371f284ef..3e17a9bba6b1 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -430,6 +430,42 @@ def from_seq_group( cls, seq_group: "SequenceGroupMetadata", positions: range ) -> Tuple[Optional[MultiModalDataDict], Dict[str, "MultiModalPlaceholderMap"]]: + """ + Returns the multi-modal items that intersect with the portion of a + prompt (``seq_group``) represented by ``positions``, as well as a + ``MultiModalPlaceholderMap`` that relates the multi-modal embedding + vectors to their corresponding placeholders. + + Consider the following scenarios: + + Prompt: |AAAA BBBB What's in these images?| + Positions: |.................................| + + images = [A, B] + src_ranges = [(0, 4), (4, 8)] + dest_ranges = [(0, 4), (5, 9)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ..... | + + images = [A, B] + src_ranges = [(2, 4), (4, 8)] + dest_ranges = [(0, 2), (3, 7)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ......... | + + images = [B] + src_ranges = [(0, 4)] + dest_ranges = [(0, 4)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | .......................| + + images = [] + src_ranges = [] + dest_ranges = [] + """ if (not seq_group.multi_modal_data or not seq_group.multi_modal_placeholders): return seq_group.multi_modal_data, {} @@ -456,6 +492,10 @@ def from_seq_group( def append_items_from_seq_group( self, positions: range, multi_modal_items: List[_T], multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]: + """ + Adds the multi-modal items that intersect ```positions`` to this + placeholder map and returns the intersecting items. 
+ """ intersecting_items = [] if len(multi_modal_items) != len(multi_modal_placeholders): @@ -490,6 +530,12 @@ def append_items_from_seq_group( return intersecting_items def extend(self, other: "MultiModalPlaceholderMap"): + """ + Adds the placeholders from another ``MultiModalPlaceholderMap`` to this + instance based on the source and destination tensors being + concatenated. + """ + self.src_ranges.extend( range(self.src_len + r.start, self.src_len + r.stop) for r in other.src_ranges) @@ -500,6 +546,11 @@ def extend(self, other: "MultiModalPlaceholderMap"): self.dest_len += other.dest_len def index_map(self) -> "IndexMap": + """ + Finalizes the placeholder map into lists of indices that can be used to + index the source and destination tensors. + """ + src_indices = [i for r in self.src_ranges for i in r] dest_indices = [i for r in self.dest_ranges for i in r] From 5171297d26cd2579e2d6d9a0c7d44517e15dec8a Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Fri, 11 Oct 2024 16:30:40 +0000 Subject: [PATCH 15/15] Update docstrings --- .../decoder_only/audio_language/test_ultravox.py | 2 ++ vllm/multimodal/base.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index 2530921c1e68..8c081d5cba04 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -227,6 +227,8 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, @pytest.mark.asyncio async def test_online_inference(client, audio_assets): + """Exercises online inference with/without chunked prefill enabled.""" + messages = [{ "role": "user", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 3e17a9bba6b1..caf207e12916 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -157,10 +157,18 @@ class MultiModalDataBuiltins(TypedDict, total=False): class PlaceholderRange(TypedDict): - """A placeholder for multi-modal data.""" + """ + A placeholder for multi-modal data. + + For example: + Prompt: AAAA BBBB What is in these images? + Images A and B will have: + A: { "offset": 0, "length": 4 } + B: { "offset": 5, "length": 4 } + """ offset: int - """The start index of the placeholder.""" + """The start index of the placeholder in the prompt.""" length: int """The length of the placeholder.""" @@ -400,7 +408,7 @@ class IndexMap(NamedTuple): src_ranges: List[range] """ The indices of the multi-modal embeddings that will replace the - corresponding placeholder embeddings pointed to be ``dest_ranges``. + corresponding placeholder embeddings pointed to by ``dest_ranges``. """ src_len: int
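# A small self-contained sketch of the intersection rule described in the
# ``from_seq_group`` docstring above, using the placeholder layout from the
# ``PlaceholderRange`` example (image A at offset 0, image B at offset 5, both
# of length 4). The helper re-derives the src/dest ranges independently as an
# assumption-laden illustration; it does not call the vLLM helpers. The ranges
# it produces are what ``index_map()`` would later flatten into plain index
# lists.
from typing import List, Tuple


def intersect_placeholders(
        positions: range,
        placeholders: List[Tuple[int, int]],  # (offset, length) pairs
) -> Tuple[List[range], List[range]]:
    src_ranges: List[range] = []
    dest_ranges: List[range] = []
    flattened_len = 0  # intersecting items contribute their full embeddings
    for offset, length in placeholders:
        start = max(offset, positions.start)
        stop = min(offset + length, positions.stop)
        if start >= stop:
            continue  # this item does not overlap the scheduled chunk
        # src indexes the concatenated embeddings of the intersecting items;
        # dest indexes the positions of the current chunk.
        src_ranges.append(range(flattened_len + start - offset,
                                flattened_len + stop - offset))
        dest_ranges.append(range(start - positions.start,
                                 stop - positions.start))
        flattened_len += length
    return src_ranges, dest_ranges


# A chunk covering prompt positions 2..8 reproduces the src/dest values of the
# second scenario in the docstring above.
src, dest = intersect_placeholders(range(2, 9), [(0, 4), (5, 4)])
assert src == [range(2, 4), range(4, 8)]
assert dest == [range(0, 2), range(3, 7)]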