Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][VLM] Support multi-images inputs for InternVL2 models #8201

Merged
merged 10 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ Multimodal Language Models
-
* - :code:`InternVLChatModel`
- InternVL2
- Image\ :sup:`E`
- Image\ :sup:`E+`
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
-
* - :code:`LlavaForConditionalGeneration`
Expand Down
95 changes: 77 additions & 18 deletions examples/offline_inference_vision_language_multi_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from argparse import Namespace
from typing import List

from vllm import LLM
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser

Expand All @@ -17,36 +19,85 @@
]


def _load_phi3v(image_urls: List[str]):
return LLM(
def load_phi3v(question, image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={"image": len(image_urls)},
)


def run_phi3v_generate(question: str, image_urls: List[str]):
llm = _load_phi3v(image_urls)

placeholders = "\n".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None
return llm, prompt, stop_token_ids

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": [fetch_image(url) for url in image_urls]

def load_internvl(question, image_urls: List[str]):
# model_name = "OpenGVLab/InternVL2-2B"
model_name = "/data/LLM-model/InternVL2-2B"
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved

llm = LLM(
model=model_name,
trust_remote_code=True,
max_num_seqs=5,
max_model_len=4096,
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

# Stop tokens for InternVL
# models variants may have different stop tokens
# please refer to the model card for the correct "stop words":
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids


model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
}


def run_generate(model, question: str, image_urls: List[str]):
llm, prompt, stop_token_ids = model_example_map[model](question,
image_urls)

sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=stop_token_ids)

outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": {
"image": [fetch_image(url) for url in image_urls]
},
},
})
sampling_params=sampling_params)

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def run_phi3v_chat(question: str, image_urls: List[str]):
llm = _load_phi3v(image_urls)
def run_chat(model: str, question: str, image_urls: List[str]):
llm, _, stop_token_ids = model_example_map[model](question, image_urls)

sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
stop_token_ids=stop_token_ids)

outputs = llm.chat([{
"role":
Expand All @@ -63,20 +114,22 @@ def run_phi3v_chat(question: str, image_urls: List[str]):
},
} for image_url in image_urls),
],
}])
}],
sampling_params=sampling_params)

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def main(args: Namespace):
model = args.model_type
method = args.method

if method == "generate":
run_phi3v_generate(QUESTION, IMAGE_URLS)
run_generate(model, QUESTION, IMAGE_URLS)
elif method == "chat":
run_phi3v_chat(QUESTION, IMAGE_URLS)
run_chat(model, QUESTION, IMAGE_URLS)
else:
raise ValueError(f"Invalid method: {method}")

Expand All @@ -85,6 +138,12 @@ def main(args: Namespace):
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input')
parser.add_argument('--model-type',
'-m',
type=str,
default="phi3_v",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument("--method",
type=str,
default="generate",
Expand Down
89 changes: 71 additions & 18 deletions tests/models/test_internvl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import types
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type, Union

import pytest
import torch
Expand All @@ -20,6 +20,7 @@
"cherry_blossom":
"<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "Image-1: <image>\nImage-2: <image>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501

models = [
"OpenGVLab/InternVL2-1B",
Expand Down Expand Up @@ -64,13 +65,13 @@ def generate(
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
inputs: List[Tuple[List[str], Union[List[Image], List[List[Image]]]]],
Isotr0py marked this conversation as resolved.
Show resolved Hide resolved
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
mm_limit: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
Expand All @@ -83,12 +84,6 @@ def run_test(
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
Expand All @@ -110,13 +105,21 @@ def __init__(self, hf_runner: HfRunner):
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Image, **kwargs):
def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
pixel_values = image_to_pixel_values(
images, self.image_size, self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image, self.image_size, self.min_num,
self.max_num,
self.use_thumbnail).to(self.dtype)
for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
Expand All @@ -130,6 +133,7 @@ def __call__(self, text: str, images: Image, **kwargs):
with vllm_runner(model,
max_model_len=4096,
dtype=dtype,
limit_mm_per_prompt={"image": mm_limit},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
Expand All @@ -138,7 +142,7 @@ def __call__(self, text: str, images: Image, **kwargs):
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]

with hf_runner(model, dtype=dtype) as hf_model:
Expand All @@ -156,7 +160,7 @@ def __call__(self, text: str, images: Image, **kwargs):
num_logprobs=num_logprobs,
images=hf_images,
eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image
for prompts, hf_images in inputs
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
Expand Down Expand Up @@ -264,15 +268,64 @@ def run_awq_test(
@torch.inference_mode()
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

run_test(
hf_runner,
vllm_runner,
image_assets,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=1,
tensor_parallel_size=1,
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
size_factors, dtype: str, max_tokens: int,
num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]

inputs_per_case = [
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
[[rescale_image_size(image, factor) for image in images]
for factor in size_factors])
]

run_test(
hf_runner,
vllm_runner,
inputs_per_case,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=2,
tensor_parallel_size=1,
)

Expand Down
Loading
Loading