From 179a6a36f2a585df49ce9c26701b1b9d894bd00e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 4 Aug 2024 16:12:41 +0800 Subject: [PATCH] [Model]Refactor MiniCPMV (#7020) Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 2 +- .../models/idefics2_vision_model.py | 296 +++++ vllm/model_executor/models/minicpmv.py | 1023 ++++++++++------- vllm/model_executor/models/na_vit.py | 2 +- 4 files changed, 937 insertions(+), 386 deletions(-) create mode 100644 vllm/model_executor/models/idefics2_vision_model.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index a1ea366b82b0..fd5d154006ae 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -220,7 +220,7 @@ Vision Language Models - Phi-3-Vision - :code:`microsoft/Phi-3-vision-128k-instruct`, etc. - - * - :code:`MiniCPM-V` + * - :code:`MiniCPMV` - MiniCPM-V - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py new file mode 100644 index 000000000000..cc448ed28d2d --- /dev/null +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -0,0 +1,296 @@ +# coding=utf-8 + +# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Idefics2 model.""" + +from typing import Optional + +import torch +from torch import nn +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2Config, Idefics2VisionConfig) +from xformers import ops as xops + +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings + ` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model (which uses fixed-size square + images) and adapt it by training on images of variable resolutions. 
+ """ + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, + max_nb_patches_h * max_nb_patches_w), + fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class 
Idefics2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + qkv, _ = self.qkv_proj( + hidden_states + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim + query_states, key_states, value_states = qkv.chunk(3, dim=-1) + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + # see: https://facebookresearch.github.io/xformers/components/ops.html + out = xops.memory_efficient_attention_forward( + query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale, + ) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) + return attn_output + + +class 
Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__(self, config: Idefics2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. Each layer is a + [`Idefics2EncoderLayer`]. 
+ + Args: + config: Idefics2Config + """ + + def __init__(self, config: Idefics2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Idefics2EncoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectors than the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ) -> torch.tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask) + encoder_outputs = self.encoder(hidden_states) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2a7fe7ba0eba..095bb49f6ba7 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,7 +24,8 @@ import math import re from functools import partial -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import (Any, Callable, Iterable, List, 
Optional, Tuple, TypedDict, + Union) import numpy as np import torch @@ -38,11 +39,14 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.llama import LlamaModel @@ -54,12 +58,45 @@ cached_get_tokenizer) from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from .idefics2_vision_model import Idefics2VisionTransformer + +logger = init_logger(__name__) + _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", "llm.model": "llm", } +class MiniCPMVImagePixelInputs(TypedDict): + pixel_values: List[torch.Tensor] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that the image size may vary, so we pass it as a list + instead of a batched tensor. + """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + tgt_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. 
+ """ + + +MiniCPMVImageInputs = MiniCPMVImagePixelInputs + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # abs_pos: L, C # tgt_size: (H, W) @@ -68,23 +105,25 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # tgt_size = int(math.sqrt(tgt_size)) dtype = abs_pos.dtype - return F.interpolate( + return (F.interpolate( abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), size=(tgt_size[0], tgt_size[1]), mode="bicubic", align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed(embed_dim: int, - grid_size: Union[int, Tuple[int, int]], - cls_token: bool = False, - version: Tuple[int, int] = (2, 0)): +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +): """ grid_size: int of the grid height and width return: - pos_embed: [grid_size*grid_size, embed_dim] or + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ if isinstance(grid_size, int): @@ -109,7 +148,7 @@ def get_2d_sincos_pos_embed(embed_dim: int, def get_2d_sincos_pos_embed_from_grid(embed_dim: int, - grid: Union[int, Tuple[int, int]], + grid: np.ndarray, version: Tuple[int, int] = (2, 0)): assert embed_dim % 2 == 0 @@ -127,7 +166,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, def get_1d_sincos_pos_embed_from_grid(embed_dim: int, - pos: int, + pos: np.ndarray, version: Tuple[int, int] = (2, 0)): """ embed_dim: output dimension for each position @@ -136,24 +175,24 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. 
/ 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) if version == (2, 0): pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) else: - out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product emb_sin = np.sin(out) # (H, W, D/2) emb_cos = np.cos(out) # (H, W, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) return emb -class Resampler(nn.Module): +class BaseResampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by (grid_size**2) learnable queries and 2d sincos pos_emb @@ -161,89 +200,151 @@ class Resampler(nn.Module): A tensor with the shape of (grid_size**2, embed_dim) """ - default_norm_layer = partial(nn.LayerNorm, eps=1e-6) - - def __init__(self, - num_queries: int, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: nn.Module = default_norm_layer, - adaptive: bool = False, - max_size: Tuple[int, int] = (70, 70), - version: Tuple[int, int] = (2, 0)): + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + ) -> None: super().__init__() - self.version = version - if self.version == (2, 0): - self.num_queries = grid_size**2 - else: - self.num_queries = num_queries - self.max_size = max_size + self.num_queries = num_queries self.embed_dim = embed_dim self.num_heads = num_heads - self.adaptive = adaptive self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) - trunc_normal_(self.query, std=.02) + trunc_normal_(self.query, std=0.02) if kv_dim is not None and kv_dim != embed_dim: - 
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) else: - self.kv_proj = nn.Identity() + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( + nn.Identity()(*args, **kwargs), + None, + ) self.attn = nn.MultiheadAttention(embed_dim, num_heads) self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) - self.ln_post = norm_layer(embed_dim) self.proj = nn.Parameter( (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) - if self.version == (2, 0): - self.pos_embed = nn.Parameter( - torch.from_numpy( - get_2d_sincos_pos_embed( - embed_dim, grid_size, - version=self.version)).float()).requires_grad_(False) + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + ) -> None: + super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, + norm_layer) + + self.adaptive = adaptive + + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 0)) + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).float()).requires_grad_(False) + + self.apply(self._init_weights) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ): + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + 
dtype=x.dtype) else: - self._set_2d_pos_cache(self.max_size) + pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + + x = self.ln_post(x) + x = x @ self.proj + return x + + +class Resampler2_5(BaseResampler): + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + ) -> None: + super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer) + + self.max_size = max_size + self._set_2d_pos_cache(self.max_size) self.apply(self._init_weights) def _set_2d_pos_cache(self, max_size: Tuple[int, int], - device: torch.types.Device = 'cpu'): - pos_embed = torch.from_numpy( - get_2d_sincos_pos_embed(self.embed_dim, - max_size, - version=self.version)).float().to(device) + device: torch.types.Device = "cpu") -> None: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + max_size, + version=(2, 5)) + pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, - device: torch.types.Device): - max_h = torch.max(tgt_sizes[:, 0]) - max_w = torch.max(tgt_sizes[:, 1]) + device: torch.types.Device) -> None: + max_h = tgt_sizes[:, 0].max().item() + max_w = tgt_sizes[:, 1].max().item() + assert isinstance(max_h, int) and isinstance(max_w, int) + if max_h > self.max_size[0] or max_w > self.max_size[1]: - self.max_size = [ + self.max_size = ( max(max_h, self.max_size[0]), - max(max_w, self.max_size[1]) - ] + max(max_w, self.max_size[1]), + ) self._set_2d_pos_cache(self.max_size, device) - def _init_weights(self, m: nn.Module): - if 
isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward_2_5(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, + tgt_sizes: torch.Tensor) -> torch.Tensor: assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -254,25 +355,25 @@ def forward_2_5(self, self._adjust_pos_cache(tgt_sizes, device=device) - max_patch_len = torch.max(patch_len) + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device) pos_embed = [] for i in range(bs): - tgt_h, tgt_w = tgt_sizes[i] + tgt_h, tgt_w = tgt_sizes[i].tolist() pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( (tgt_h * tgt_w, -1)).to(dtype)) # patches * D key_padding_mask[i, patch_len[i]:] = True - pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute( 1, 0, 2) # BLD => L * B * D - - x = self.kv_proj(x) # B * L * D + x, _ = self.kv_proj(x) # B * L * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D q = self.ln_q(self.query) # Q * D @@ -281,7 +382,8 @@ def forward_2_5(self, self._repeat(q, bs), # Q * B * D x + pos_embed, # L * B * D + L * B * D x, - key_padding_mask=key_padding_mask)[0] + key_padding_mask=key_padding_mask, + )[0] # out: Q * B * D x = out.permute(1, 0, 2) # B * Q * D @@ -289,45 +391,6 @@ def forward_2_5(self, x = x @ self.proj return x - def forward_2(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None): - if self.adaptive: - pos_embed = torch.Tensor( - get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes)).float().to(device=x.device, - dtype=x.dtype) - else: - pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) - - x = 
self.kv_proj(x) - x = self.ln_kv(x).permute(1, 0, 2) - - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask)[0] - x = out.permute(1, 0, 2) - - x = self.ln_post(x) - x = x @ self.proj - return x - - def forward(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None): - if self.version == (2, 0): - return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) - else: - return self.forward_2_5(x, tgt_sizes=tgt_sizes) - - def _repeat(self, query, N: int): - return query.unsqueeze(1).repeat(1, N, 1) - def get_max_minicpmv_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PretrainedConfig) @@ -348,10 +411,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig): def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(PretrainedConfig) - # image_feature_size = get_max_minicpmv_image_tokens(ctx) - seq_data = dummy_seq_data_for_minicpmv(seq_len) - mm_data = dummy_image_for_minicpmv(hf_config) return seq_data, mm_data @@ -376,25 +436,36 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): pattern = "(./)" image = multi_modal_data["image"] image_tags = re.findall(pattern, prompt) - assert len(image_tags) <= 1 - text_chunks = prompt.split(pattern) - new_prompt = text_chunks[0] \ - + image_processor.get_slice_image_placeholder(image.size) \ - + text_chunks[1] - new_token_ids = tokenizer.encode(new_prompt) - - llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + if len(image_tags) == 0: + new_token_ids = token_ids + new_prompt = prompt + else: + if len(image_tags) > 1: + logger.warning("Multiple image input is not supported yet, " + "so any extra image tokens will be treated " + "as plain text.") + + text_chunks = prompt.split(pattern) + new_prompt = 
(text_chunks[0] + + image_processor.get_slice_image_placeholder(image.size) + + "".join(text_chunks[1:])) + + new_token_ids = tokenizer.encode(new_prompt) + + llm_inputs = LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) return llm_inputs -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) -class MiniCPMV(nn.Module, SupportsVision): +class MiniCPMVBaseModel(nn.Module, SupportsVision): + """ + The abstract class of MiniCPMV can only be inherited, but cannot be + instantiated. + """ def __init__( self, @@ -419,8 +490,8 @@ def __init__( self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) - self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \ - else self.vpm.embeddings.embed_dim + self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else + self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler.to(device="cuda", dtype=param_dtype) @@ -430,248 +501,100 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def init_llm(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): - if self.version == (2, 0): - return MiniCPMModel(config, - cache_config=cache_config, - quant_config=quant_config) - elif self.version == (2, 5): - return LlamaModel(config, - cache_config=cache_config, - quant_config=quant_config) - else: - return Qwen2Model(config, - cache_config=cache_config, - quant_config=quant_config) - - def init_vision_module(self): - if self.version == (2, 0): - try: - import timm - except ImportError: - 
raise ImportError( - 'Please install timm==0.9.10') from ImportError - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float16) - model = timm.create_model('vit_so400m_patch14_siglip_384.webli', - pretrained=False, - num_classes=0, - dynamic_img_size=True, - dynamic_img_pad=True) - torch.set_default_dtype(default_dtype) - if isinstance(model, timm.models.VisionTransformer - ) and model.attn_pool is not None: - model.attn_pool = torch.nn.Identity() - - if self.config.drop_vision_last_layer: - model.blocks = model.blocks[:-1] - elif self.version == (2, 5): - from transformers.models.idefics2.modeling_idefics2 import ( - Idefics2VisionTransformer) - model = Idefics2VisionTransformer(self.config.vision_config) - if self.config.drop_vision_last_layer: - model.encoder.layers = model.encoder.layers[:-1] - else: - from vllm.model_executor.models.na_vit import ( - SiglipVisionTransformer) - if self.config._attn_implementation == 'flash_attention_2': - self.config.vision_config._attn_implementation \ - = 'flash_attention_2' - else: - # not support sdpa - self.config.vision_config._attn_implementation = 'eager' - model = SiglipVisionTransformer(self.config.vision_config) - if self.config.drop_vision_last_layer: - model.encoder.layers = model.encoder.layers[:-1] - return model - - def init_resampler(self, embed_dim: int, vision_dim: int): - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float16) - if self.version == (2, 0): - resampler = Resampler(grid_size=int( - math.sqrt(self.config.query_num)), - num_queries=None, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - adaptive=True, - version=self.version) + def get_embedding( + self, + input_ids: torch.Tensor, + image_inputs: Optional[MiniCPMVImageInputs], + ) -> Tuple[torch.Tensor, torch.Tensor]: + vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) + if hasattr(self.config, "scale_emb"): + vlm_embedding *= self.config.scale_emb + + if 
image_inputs is None: # No image + vision_hidden_states = torch.tensor([], device=input_ids.device) else: - resampler = Resampler(num_queries=self.config.query_num, - grid_size=None, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - adaptive=True, - version=self.version) - torch.set_default_dtype(default_dtype) - return resampler + vision_hidden_states = self.get_vision_hidden_states(image_inputs) + + # See NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack([ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ]).to(vlm_embedding.device) + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, + vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, + vision_hidden_states.shape[-1]), + ) - def get_vision_embedding(self, - pixel_values: List[List[torch.Tensor]], - patch_attn_mask: Optional[torch.Tensor] = None, - tgt_sizes: Optional[torch.Tensor] = None, - version: Tuple[int, int] = (2, 0)): - if version == (2, 0): - res = [] - dtype = self.vpm.pos_embed.data.dtype - for pixel_value in pixel_values: - # V2.0 start - H, W = pixel_value[0].shape[-2:] - tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]), - math.ceil(W / self.vpm.patch_embed.patch_size[0])) - # V2.0 end - vision_embedding = self.vpm.forward_features( - pixel_value.unsqueeze(0).type(dtype)) - if hasattr(self.vpm, 'num_prefix_tokens' - ) and self.vpm.num_prefix_tokens > 0: - vision_embedding = vision_embedding[:, self.vpm. 
- num_prefix_tokens:] - res.append(self.resampler(vision_embedding, tgt_size)) - return torch.vstack(res) - elif version == (2, 5): - vision_embedding = self.vpm( - pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask).last_hidden_state - vision_embedding = self.resampler(vision_embedding, tgt_sizes) - else: - vision_embedding = self.vpm(pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state + return vlm_embedding, vision_hidden_states - def get_image_bounds(self, input_ids: torch.Tensor): + def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor: tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) - if not hasattr(tokenizer, "slice_start_id"): - start_cond = input_ids == tokenizer.im_start_id - end_cond = input_ids == tokenizer.im_end_id - else: - start_cond = (input_ids == tokenizer.im_start_id) | ( - input_ids == tokenizer.slice_start_id) - end_cond = (input_ids == tokenizer.im_end_id) | ( - input_ids == tokenizer.slice_end_id) + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + if hasattr(tokenizer, "slice_start_id"): + start_cond |= (input_ids == tokenizer.slice_start_id) + end_cond |= (input_ids == tokenizer.slice_end_id) - image_start_tokens = torch.where(start_cond)[0] + image_start_tokens, = torch.where(start_cond) image_start_tokens += 1 - image_end_tokens = torch.where(end_cond)[0] + image_end_tokens, = torch.where(end_cond) valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + if valid_image_nums == 0: - return [] - image_bound = torch.hstack([ + return torch.zeros((0, 2), device=input_ids.device) + + return torch.hstack([ image_start_tokens[:valid_image_nums].unsqueeze(-1), image_end_tokens[:valid_image_nums].unsqueeze(-1), ]) - return image_bound - - def get_vision_hidden_states(self, data: Dict[str, - Union[List[torch.Tensor], - torch.Tensor]]): - if "vision_hidden_states" not in 
data: - pixel_values = data["pixel_values"] - tgt_sizes = data["tgt_sizes"] - vision_hidden_states = [] - if self.version == (2, 0): - if pixel_values is not None and len(pixel_values) > 0: - vision_hidden_states = self.get_vision_embedding( - pixel_values) - else: - vision_hidden_states = torch.tensor([]).to( - data["input_ids"].device) - else: - device = self.vpm.embeddings.position_embedding.weight.device - dtype = self.vpm.embeddings.position_embedding.weight.dtype - all_pixel_values = [ - i.flatten(end_dim=1).permute(1, 0) for i in pixel_values - ] - if all_pixel_values: - tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) - max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence( - all_pixel_values, batch_first=True, padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute( - 0, 2, 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=device) - if self.version == (2, 5): - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask - ).last_hidden_state - else: - for i in range(B): - patch_attn_mask[i, 0, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state - - vision_hidden_states = self.resampler( - vision_embedding, tgt_sizes) - - else: # no image - dummy_feature = [] - vision_hidden_states = dummy_feature - else: - vision_hidden_states = data["vision_hidden_states"] - - return vision_hidden_states - - def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], - torch.Tensor]]): - input_ids = data["input_ids"] - - vision_hidden_states = self.get_vision_hidden_states(data) - if vision_hidden_states is not None and len(vision_hidden_states) > 
0: - image_bounds = self.get_image_bounds(input_ids) - else: - image_bounds = [] - - if hasattr(self.config, 'scale_emb'): - vlm_embedding = self.llm.embed_tokens( - input_ids) * self.config.scale_emb - else: - vlm_embedding = self.llm.embed_tokens(input_ids) - vision_hidden_states = [ - i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i - for i in vision_hidden_states - ] - - if len(vision_hidden_states) > 0 and len(image_bounds) > 0: - vision_hidden_states = torch.cat(vision_hidden_states, dim=0) - image_indices = torch.stack([ - torch.arange(r[0], r[1], dtype=torch.long) - for r in image_bounds - ]).to(vlm_embedding.device) - vlm_embedding.scatter_( - 0, - image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), - vision_hidden_states.view(-1, vision_hidden_states.shape[-1])) - return vlm_embedding, vision_hidden_states - - def process_multimodal_inputs(self, inputs: Dict[str, - Union[List[torch.Tensor], - torch.Tensor]]): - pixel_values = [] - tgt_sizes = [] - for b in range(len(inputs["pixel_values"])): - pixel_values += inputs["pixel_values"][b] - tgt_sizes += inputs["tgt_sizes"][b] - return { - "pixel_values": pixel_values, - "input_ids": inputs["input_ids"], - "tgt_sizes": tgt_sizes - } + def _parse_and_validate_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object, + ) -> Optional[MiniCPMVImageInputs]: + pixel_values = kwargs.pop("pixel_values", []) + tgt_sizes = kwargs.pop("tgt_sizes", []) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of target sizes. " + f"Got type: {type(tgt_sizes)}") + + if len(pixel_values) != len(tgt_sizes): + raise ValueError("Inconsistent batch lengths, found: " + f"{len(pixel_values)} vs. 
{len(tgt_sizes)}") + + pixel_values_flat: List[torch.Tensor] = [] + tgt_sizes_flat: List[torch.Tensor] = [] + for b in range(len(pixel_values)): + pixel_values_flat += pixel_values[b] + tgt_sizes_flat += tgt_sizes[b] + + # NOTE: Input IDs does not contain image tokens during memory profiling, + # so we allow it to be empty + if len(pixel_values_flat) != len(tgt_sizes_flat): + raise ValueError("Inconsistent flattened lengths, found: " + f"{len(pixel_values_flat)} vs. " + f"{len(tgt_sizes_flat)}") + + if len(pixel_values_flat) == 0: + return None + + return MiniCPMVImageInputs( + image_bounds=self._get_image_bounds(input_ids), + pixel_values=pixel_values_flat, + tgt_sizes=torch.stack(tgt_sizes_flat), + ) def forward( self, @@ -680,23 +603,20 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object, - ): - inputs = { - "pixel_values": kwargs.pop("pixel_values", []), - "input_ids": input_ids, - "tgt_sizes": kwargs.pop("tgt_sizes", None), - } - inputs = self.process_multimodal_inputs(inputs) - - vlm_embeddings, vision_hidden_states = self.get_embedding(inputs) - - output = self.llm(input_ids=None, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=vlm_embeddings) + **kwargs: Any, + ) -> torch.Tensor: + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + + output = self.llm( + input_ids=None, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=vlm_embeddings, + ) return output def compute_logits(self, hidden_states: torch.Tensor, @@ -735,13 +655,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # the checkpoint. Skip them. 
continue use_default_weight_loading = False - if "vpm" in name or 'resampler' in name: - # We only do sharding for language model and - # not vision model for now. + if self.is_default_weight_loading(name): use_default_weight_loading = True else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue param = params_dict[name.replace(weight_name, param_name)] @@ -755,3 +672,341 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def init_vision_module(self) -> nn.Module: + raise NotImplementedError + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + raise NotImplementedError + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + raise NotImplementedError + + def is_default_weight_loading(self, name: str) -> bool: + raise NotImplementedError + + +class MiniCPMV2(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 0) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return 
MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + # TODO :refactor this vision model + try: + import timm + except ImportError: + raise ImportError("Please install timm==0.9.10") from ImportError + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + if (isinstance(model, timm.models.VisionTransformer) + and model.attn_pool is not None): + model.attn_pool = torch.nn.Identity() + + if self.config.drop_vision_last_layer: + model.blocks = model.blocks[:-1] + + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2( + embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int(math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=True, + ) + + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res = [] + dtype = self.vpm.pos_embed.data.dtype + for pixel_value in pixel_values: + H, W = pixel_value[0].shape[-2:] + tgt_size = ( + math.ceil(H / self.vpm.patch_embed.patch_size[0]), + math.ceil(W / self.vpm.patch_embed.patch_size[0]), + ) + vision_embedding = self.vpm.forward_features( + pixel_value.unsqueeze(0).type(dtype)) + if (hasattr(self.vpm, "num_prefix_tokens") + and self.vpm.num_prefix_tokens > 0): + vision_embedding = vision_embedding[:, self.vpm. 
+ num_prefix_tokens:] + res.append(self.resampler(vision_embedding, tgt_size)) + return torch.vstack(res) + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + + return self.get_vision_embedding(pixel_values) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name or "vpm" in name + + +class MiniCPMV2_5(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 5) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + model = Idefics2VisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm(pixel_values, + patch_attention_mask=patch_attn_mask) + vision_embedding = self.resampler(vision_embedding, tgt_sizes) + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = 
data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, + 1).reshape(B, 3, -1, L) + + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=device) + for i in range(B): + patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + + return self.get_vision_embedding(all_pixel_values.type(dtype), + patch_attn_mask, tgt_sizes) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name + + +# NOTE: Currently, information about this model is unavailable. We are +# temporarily using `MiniCPMVQwen2` as its name. The name may need +# to be modified in the future.
+class MiniCPMVQwen2(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + # A custom version of SiglipVisionTransformer, won't work with TP + from vllm.model_executor.models.na_vit import SiglipVisionTransformer + + if self.config._attn_implementation == "flash_attention_2": + self.config.vision_config._attn_implementation = "flash_attention_2" + else: + # not support sdpa + self.config.vision_config._attn_implementation = "eager" + model = SiglipVisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm( + pixel_values, + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ).last_hidden_state + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + device = 
self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, + 1).reshape(B, 3, -1, L) + + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=device) + for i in range(B): + patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ).last_hidden_state + + return self.resampler(vision_embedding, tgt_sizes) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name or "vpm" in name + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) +class MiniCPMV(MiniCPMVBaseModel): + """ + Different versions of MiniCPMV use different visual encoders and LLMs, + which is not conducive to the current integration logic of LoRA and + bitsandbytes in vLLM. Therefore, it is necessary to separate them. 
+ """ + + def __new__( + cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + if not hasattr(config, "version"): + if config.hidden_size == 2304 and config.query_num == 64: + version = (2, 0) + else: + version = (2, 5) + else: + version = str(config.version).split(".") + version = tuple([int(x) for x in version]) + # Dispatch class based on version + if version == (2, 0): + instance_class = MiniCPMV2 + elif version == (2, 5): + instance_class = MiniCPMV2_5 + else: + instance_class = MiniCPMVQwen2 + return instance_class(config, multimodal_config, cache_config, + quant_config) diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py index 871e4128b66e..1d6f26f0d4fb 100644 --- a/vllm/model_executor/models/na_vit.py +++ b/vllm/model_executor/models/na_vit.py @@ -100,7 +100,7 @@ def _get_unpad_data(attention_mask): indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens,