From 2f707fcb35c5bc4b9164cf2bbce0254a72f7348b Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sat, 7 Sep 2024 10:57:24 +0800
Subject: [PATCH] [Model] Multi-input support for LLaVA (#8238)

---
 docs/source/models/supported_models.rst      |  16 +-
 tests/conftest.py                            |  12 +-
 .../distributed/test_multimodal_broadcast.py |   6 +-
 tests/models/test_llava.py                   | 141 ++++++++++++++++--
 vllm/model_executor/models/clip.py           |   2 +-
 vllm/model_executor/models/internvl.py       |   2 +-
 vllm/model_executor/models/llava.py          |  32 ++--
 vllm/model_executor/models/llava_next.py     |   4 +-
 vllm/model_executor/models/phi3v.py          |   4 +-
 vllm/model_executor/models/siglip.py         |   2 +-
 10 files changed, 176 insertions(+), 45 deletions(-)

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 0c0a54281e3f..fe01e1681353 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -219,7 +219,7 @@ Multimodal Language Models
     -
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
-    - Image\ :sup:`E`
+    - Image\ :sup:`E+`
    - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
     -
   * - :code:`LlavaNextForConditionalGeneration`
@@ -227,6 +227,11 @@ Multimodal Language Models
     - Image\ :sup:`E+`
     - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
     -
+  * - :code:`MiniCPMV`
+    - MiniCPM-V
+    - Image\ :sup:`+`
+    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
+    -
   * - :code:`PaliGemmaForConditionalGeneration`
     - PaliGemma
     - Image\ :sup:`E`
@@ -237,14 +242,9 @@ Multimodal Language Models
     - Image\ :sup:`E+`
     - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
     -
-  * - :code:`MiniCPMV`
-    - MiniCPM-V
-    - Image\ :sup:`+`
-    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-    -
   * - :code:`QWenLMHeadModel`
-    - Qwen
-    - Image
+    - Qwen-VL
+    - Image\ :sup:`E`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
   * - :code:`UltravoxModel`
diff --git a/tests/conftest.py b/tests/conftest.py
index e66a14598c34..cd0091b7cba6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -278,7 +278,7 @@ def __init__(
     def generate(
         self,
         prompts: List[str],
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
@@ -314,7 +314,7 @@ def generate_greedy(
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
@@ -351,7 +351,7 @@ def generate_greedy_logprobs(
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[List[torch.Tensor]]:
         all_logprobs: List[List[torch.Tensor]] = []
@@ -433,8 +433,8 @@ def generate_greedy_logprobs_limit(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[List[Image.Image]] = None,
-        audios: Optional[List[Tuple[np.ndarray, int]]] = None,
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
         all_logprobs: List[List[Dict[int, float]]] = []
@@ -671,7 +671,7 @@ def generate_greedy(
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params, images=images)
diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py
index e7723a7ae248..73ef863c2f19 100644
--- a/tests/distributed/test_multimodal_broadcast.py
+++ b/tests/distributed/test_multimodal_broadcast.py
@@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
     if model.startswith("llava-hf/llava-1.5"):
         from ..models.test_llava import models, run_test
     elif model.startswith("llava-hf/llava-v1.6"):
-        from ..models.test_llava_next import models, run_test
+        from ..models.test_llava_next import run_test  # type: ignore[no-redef]
+        from ..models.test_llava_next import models
     elif model.startswith("facebook/chameleon"):
-        from ..models.test_chameleon import models, run_test
+        from ..models.test_chameleon import run_test  # type: ignore[no-redef]
+        from ..models.test_chameleon import models
     else:
         raise NotImplementedError(f"Unsupported model: {model}")

diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py
index 9d7da5f803ea..84ca23f6222a 100644
--- a/tests/models/test_llava.py
+++ b/tests/models/test_llava.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, overload

 import pytest
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
@@ -8,11 +8,14 @@
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                        _ImageAssets)
 from .utils import check_logprobs_close

 pytestmark = pytest.mark.vlm

+_LIMIT_IMAGE_PER_PROMPT = 4
+
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
     "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
@@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     return hf_output_ids, hf_output_str, out_logprobs


+@overload
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
@@ -64,6 +68,78 @@ def run_test(
     num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+@overload
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    sizes: List[Tuple[int, int]],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    images = [asset.pil_image for asset in image_assets]
+
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [prompt for _ in sizes],
+            [image.resize(size) for size in sizes],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")
+
+    _run_test(hf_runner,
+              vllm_runner,
+              inputs_per_image,
+              model,
+              dtype=dtype,
+              max_tokens=max_tokens,
+              num_logprobs=num_logprobs,
+              tensor_parallel_size=tensor_parallel_size,
+              distributed_executor_backend=distributed_executor_backend)
+
+
+def _run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
 ):
     """Inference result should be the same between hf and vllm.

@@ -85,13 +161,6 @@ def run_test(
     else:
         mantis_processor = None

-    images = [asset.pil_image for asset in image_assets]
-
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-
     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
@@ -100,15 +169,18 @@ def run_test(
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
+                     max_model_len=4096,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+                     enforce_eager=True,
+                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
+                                          }) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]

     if mantis_processor is not None:
@@ -131,7 +203,7 @@ def process(hf_inputs: BatchEncoding):
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]

     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -181,6 +253,51 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     )


+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
+                                      model, dtype, max_tokens,
+                                      num_logprobs) -> None:
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
+            "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
+            "USER: <image><image><image><image>\nDescribe 4 images.\nASSISTANT:",  # noqa: E501
+            "USER: <image>\nWhat is the season?\nASSISTANT:",
+        ],
+        [
+            [stop_sign, cherry_blossom],
+            # Images with different sizes and aspect-ratios
+            [
+                rescale_image_size(stop_sign, 0.1),
+                stop_sign,
+            ],
+            [
+                stop_sign,
+                rescale_image_size(stop_sign, 0.25),
+                cherry_blossom.resize((183, 488)),
+                cherry_blossom.resize((488, 183))
+            ],
+            cherry_blossom,
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
 @pytest.mark.parametrize("model", models)
 def test_context_length_too_short(vllm_runner, image_assets, model):
     images = [asset.pil_image for asset in image_assets]
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index b581a501e333..70f1522ae252 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -105,7 +105,7 @@ def input_processor_for_clip(
         if isinstance(image_data, Image.Image):
             image_feature_size = get_clip_image_feature_size(hf_config)
         elif isinstance(image_data, torch.Tensor):
-            image_feature_size = image_data.shape[0]
+            num_images, image_feature_size, hidden_size = image_data.shape
         else:
             raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index d317fdce3ba6..10fbb5663d27 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -209,7 +209,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):

         image_feature_size = num_blocks * num_patches
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 43c485bdf366..7a6c991fb133 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -4,6 +4,7 @@

 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig

 from vllm.attention import AttentionMetadata
@@ -16,6 +17,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of

 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
@@ -24,7 +26,7 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)


@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config

-    image_feature_size = get_max_llava_image_tokens(ctx)
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        image_feature_size = get_max_llava_image_tokens(ctx)
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [get_max_llava_image_tokens(ctx)
+                              ] * len(image_data)
+    elif isinstance(image_data, torch.Tensor):
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")

     if isinstance(vision_config, CLIPVisionConfig):
         return input_processor_for_clip(
@@ -230,29 +243,24 @@ def _parse_and_validate_image_input(
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            # Remove the N dimension until multiple images are supported.
-            pixel_values = pixel_values.squeeze(1)
-
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
             )

         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
+            if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds, concat=True),
             )

         raise AssertionError("This line should be unreachable.")
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 5a179e960371..c6bd46dd7eda 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             for img in image_data
         ]
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index c449e0fc759a..6f17f571ccae 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
                                              input_width=w,
                                              input_height=h))
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index 0bee75e2f0cb..fb4c30c1a13f 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -110,7 +110,7 @@ def input_processor_for_siglip(
         if isinstance(image_data, Image.Image):
             image_feature_size = get_siglip_image_feature_size(hf_config)
         elif isinstance(image_data, torch.Tensor):
-            image_feature_size = image_data.shape[0]
+            num_images, image_feature_size, hidden_size = image_data.shape
         else:
             raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
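
The multi-image path exercised by test_models_multiple_image_inputs above can also be driven through vLLM's offline LLM API. The sketch below is illustrative and not part of the patch: the model name, prompt format, and the limit_mm_per_prompt override are taken from the test code, while the image file names are placeholders. Passing a list under the "image" key of multi_modal_data, with a matching number of <image> tokens in the prompt, is what the new flatten_bn-based input path in llava.py consumes.

# A minimal multi-image sketch, assuming "stop_sign.jpg" and
# "cherry_blossom.jpg" are placeholder paths to local images.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    # Raise the per-prompt image limit (the default is 1), as the test above
    # does via limit_mm_per_prompt.
    limit_mm_per_prompt={"image": 2},
)

# One <image> placeholder per image, matching the prompt format in the test.
prompt = "USER: <image><image>\nDescribe 2 images.\nASSISTANT:"
images = [Image.open("stop_sign.jpg"), Image.open("cherry_blossom.jpg")]

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)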