From 19154e0f9c1cb43bf1bc09ab6609b64d8435ee08 Mon Sep 17 00:00:00 2001 From: mrsalehi Date: Tue, 1 Oct 2024 20:03:10 -0700 Subject: [PATCH 01/12] molmo vllm integration --- examples/offline_inference_molmo.py | 130 +++ vllm/entrypoints/chat_utils.py | 2 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/molmo.py | 1130 ++++++++++++++++++++++++ 4 files changed, 1263 insertions(+) create mode 100644 examples/offline_inference_molmo.py create mode 100644 vllm/model_executor/models/molmo.py diff --git a/examples/offline_inference_molmo.py b/examples/offline_inference_molmo.py new file mode 100644 index 000000000000..3d7434ea8357 --- /dev/null +++ b/examples/offline_inference_molmo.py @@ -0,0 +1,130 @@ +import argparse +import numpy as np +import requests +from io import BytesIO +import base64 +from PIL import Image, ImageFile, ImageOps +import torch +from typing import Optional + +from vllm import LLM +from vllm.sampling_params import SamplingParams +from vllm_molmo.molmo import MolmoForCausalLM + + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def download_image_to_numpy(url): + # Send a GET request to the URL + response = requests.get(url) + + # Check if the request was successful + if response.status_code == 200: + # Open the image from the response content + image = Image.open(BytesIO(response.content)).convert("RGB") + + image = ImageOps.exif_transpose(image) + + # Convert the image to a NumPy array + image_array = np.array(image).astype(np.uint8) + + return image_array + else: + raise Exception(f"Failed to download image. Status code: {response.status_code}") + + +def vllm_generate(): + inputs = [ + { + "prompt": "Describe this image.", + "multi_modal_data": {"image": download_image_to_numpy("https://picsum.photos/id/9/1080/720")} + }, + { + "prompt": "Describe what you see in this image.", + "multi_modal_data": {"image": download_image_to_numpy("https://picsum.photos/id/23/1080/720")} + }, + ] + + outputs = llm.generate( + inputs, + sampling_params=sampling_params + ) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def vllm_chat(): + url_1 = "https://picsum.photos/id/9/1080/720" + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe the image." 
+ }, + { + "type": "image_url", + "image_url": { + "url": url_1 + } + }, + ], + }, + { + "role": "assistant", + "content": "The image shows some objects.", + }, + { + "role": "user", + "content": "What objects do you exactly see in the image?", + }, + ] + + outputs = llm.chat( + messages, + sampling_params=sampling_params + ) + + # Error: Invalid message role {{ message['role'] }} at index {{ loop.index }} + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def set_tf_memory_growth(): + import tensorflow as tf + gpus = tf.config.experimental.list_physical_devices('GPU') + + if gpus: + try: + # Set memory growth for each GPU + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + print("Memory growth set for GPUs") + except RuntimeError as e: + print(e) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="vllm example for Molmo-7B-O-0924, Molmo-7B-D-0924, Molmo-72B-0924" + ) + parser.add_argument("--model_path", type=str, default="allenai/Molmo-7B-D-0924") + args = parser.parse_args() + llm = LLM( + model=args.model_path, + trust_remote_code=True, + gpu_memory_utilization=0.95, + dtype="bfloat16", + ) + sampling_params = SamplingParams( + max_tokens=768, + temperature=0, + ) + set_tf_memory_growth() + vllm_generate() + vllm_chat() \ No newline at end of file diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 130f3ba49f3e..21bf1515a446 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -163,6 +163,8 @@ def _placeholder_str(self, modality: ModalityStr, return "<|image|>" if model_type == "qwen2_vl": return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "molmo": + return "" raise TypeError(f"Unknown model type: {model_type}") elif modality == "audio": diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ad6cf659c3e6..1aafaf754dc9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -105,6 +105,7 @@ "UltravoxModel": ("ultravox", "UltravoxModel"), "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), + "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py new file mode 100644 index 000000000000..3b7a003c5dab --- /dev/null +++ b/vllm/model_executor/models/molmo.py @@ -0,0 +1,1130 @@ +from array import array +from functools import lru_cache +import re +import logging +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union, Any, Mapping +from PIL import Image +import math + +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from transformers import PretrainedConfig +import vllm.envs as envs +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.selector import (_Backend, backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.layernorm import RMSNorm +from 
vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.multimodal import ( + MULTIMODAL_REGISTRY, + MultiModalInputs, +) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.utils import make_layers +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import ( + VLLM_TOKEN_ID_ARRAY_TYPE, + SequenceData, +) +from vllm.model_executor import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processor import get_processor +from vllm.platforms import current_platform + +log = logging.getLogger(__name__) + + +@dataclass +class VisionBackboneConfig: + image_default_input_size: Tuple[int, int] = (336, 336) + image_patch_size: int = 14 + image_pos_patch_size: int = 14 + image_emb_dim: int = 1024 + image_num_heads: int = 16 + image_num_key_value_heads: int = 16 + image_num_layers: int = 23 + image_mlp_dim: int = 4096 + image_mlp_activations: str = "quick_gelu" + image_num_pos: int = 577 + image_norm_eps: float = 1e-5 + + def __post_init__(self): + self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] + + @property + def image_num_patch(self): + h, w = self.image_default_input_size + return h // self.image_patch_size, w // self.image_patch_size + + +class ViTMLP(nn.Module): + """MLP used in Vision Transformer.""" + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.w1 = ColumnParallelLinear( + config.image_emb_dim, + config.image_mlp_dim, + bias=True, + quant_config=quant_config, + ) + # Activation function. 
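+        # Note (descriptive comment, not in the original patch): only the
+        # quick_gelu activation used by Molmo's released ViT config is
+        # supported; the assert below rejects anything else.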
+ assert config.image_mlp_activations == "quick_gelu" + self.act = QuickGELU() + self.w2 = RowParallelLinear( + config.image_mlp_dim, + config.image_emb_dim, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.w1(x) + x = self.act(x) + x, _ = self.w2(x) + return x + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head attention used in Vision Transformer.""" + def __init__( + self, + config: VisionBackboneConfig, + use_bias: bool = True, + nlayers: int = 1, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.hidden_size = config.image_emb_dim + self.total_num_heads = config.image_num_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + + self.total_num_kv_heads = config.image_num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.wq = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wk = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wv = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=use_bias, + quant_config=quant_config, + ) + + # Detect attention implementation. + selected_backend: Optional[_Backend] = get_global_forced_attn_backend() + if selected_backend is None: + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + # For Volta and Turing GPUs, use xformers instead. + device_available = current_platform.get_device_capability()[0] >= 8 + if device_available: + from transformers.utils import is_flash_attn_2_available + if is_flash_attn_2_available(): + self._use_flash_attn = True + else: + log.warning( + "Current Molmo implementation has a bug with " + "`vllm-flash-attn` inside vision module, so we use " + "xformers backend instead. You can run `pip install " + "flash-attn to use flash-attention backend.") + self._use_flash_attn = False + else: + self._use_flash_attn = False + else: + if selected_backend == _Backend.FLASH_ATTN: + self._use_flash_attn = True + elif selected_backend == _Backend.XFORMERS: + self._use_flash_attn = False + else: + raise RuntimeError( + f"Molmo does not support {selected_backend} backend now." 
+ ) + + def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: + + if inputs_kv is not None: + inputs_k = inputs_kv + inputs_v = inputs_kv + else: + inputs_k = inputs_q + inputs_v = inputs_q + + xq, _ = self.wq(inputs_q) + xk, _ = self.wk(inputs_k) + xv, _ = self.wv(inputs_v) + q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim) + kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim) + xq = xq.view(*q_shape) + xk = xk.view(*kv_shape) + xv = xv.view(*kv_shape) + + if self._use_flash_attn: + from flash_attn import flash_attn_func + output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False) + else: + from xformers import ops as xops + output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0) + + output = rearrange(output, "b s h d -> b s (h d)").contiguous() + output, _ = self.wo(output) + + return output + + +class ResidualAttentionBlock(nn.Module): + """Residual attention block used in Vision Transformer.""" + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config) + self.feed_forward = ViTMLP(config, quant_config) + self.attention_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + self.ffn_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attention(self.attention_norm(x)) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class BlockCollection(nn.Module): + """Collection of residual attention blocks used in Vision Transformer.""" + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.resblocks = nn.ModuleList( + [ResidualAttentionBlock(config, quant_config) for _ in range(config.image_num_layers)] + ) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + hidden_states = [] + for r in self.resblocks: + x = r(x) + hidden_states.append(x) + return hidden_states + + +def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor: + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class VisionTransformer(nn.Module): + """Vision Transformer used in Vision Backbone.""" + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + scale = config.image_emb_dim ** -0.5 + self.patch_num = config.image_num_patch + self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale) + self.num_prefix_tokens: int = 1 + self.positional_embedding = nn.Parameter( + torch.randn(config.image_num_pos, config.image_emb_dim) * scale + ) + image_patch_size = config.image_patch_size + self.patch_embedding = nn.Linear( + image_patch_size * image_patch_size * 3, + config.image_emb_dim, + bias=False, + ) + self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps) + self.transformer = BlockCollection(config, quant_config) + + def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: + cls_emb = self.positional_embedding[0:1] + pos_emb = self.positional_embedding[1:] + + pos_emb = pos_emb.reshape( + (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]) + ) + + (patch_num_0, patch_num_1) = patch_num + + if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: + # from 
https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) + pos_emb = F.interpolate( + pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True, + ) + pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) + + pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) + x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype) + return x + + def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]: + """ + : param x: (batch_size, num_patch, n_pixels) + """ + if patch_num is None: + patch_num = self.patch_num + B, N, D = x.shape + + x = self.patch_embedding(x) + + # class embeddings and positional embeddings + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + x = self.add_pos_emb(x, patch_num) + + x = self.pre_ln(x) + + hidden_states = self.transformer(x) + return hidden_states + + +class MolmoAttention(nn.Module): + """Molmo's LLM attention.""" + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + + self.k_norm: Optional[nn.Module] = None + self.q_norm: Optional[nn.Module] = None + if config.attention_layer_norm: + self.k_norm = RMSNorm(self.kv_size, eps=config.layer_norm_eps) + self.q_norm = RMSNorm(self.q_size, eps=config.layer_norm_eps) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. 
+ self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MolmoMLP(nn.Module): + """Molmo's LLM mlp.""" + + def __init__( + self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + input_dim or self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MolmoDecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + # Attention block. + self.self_attn = MolmoAttention(config, cache_config, quant_config) + + # MLP block. 
+ self.mlp = MolmoMLP(config, quant_config=quant_config) + + # LayerNorm + assert config.layer_norm_type == "rms" + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.norm_after = config.norm_after + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + if self.norm_after: + residual = hidden_states + elif residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if self.norm_after: + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + + if not self.norm_after: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + if self.norm_after: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = None + return hidden_states, residual + + +class MolmoVisionBackbone(nn.Module): + def __init__( + self, + config: PretrainedConfig, + vision_config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + # TODO: vit_layers hard-coded for now. Consider making it configurable. + self.vit_layers = [-2, -9] + self.image_num_patch = vision_config.image_num_patch + self.llm_patches_per_crop = ( + (self.image_num_patch[0] + 1) // 2, + (self.image_num_patch[1] + 1) // 2, + ) + self.image_vit = VisionTransformer(vision_config, quant_config=quant_config) + self.num_prefix_tokens = self.image_vit.num_prefix_tokens + assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported" + self.image_pooling_2d = MultiHeadDotProductAttention( + vision_config, nlayers=len(self.vit_layers), quant_config=quant_config + ) + self.image_projector = MolmoMLP( + config, + input_dim=vision_config.image_emb_dim, + quant_config=quant_config, + ) + + image_dim = vision_config.image_emb_dim * len(self.vit_layers) + self.pad_embed = nn.Parameter(torch.zeros((2, image_dim))) + + def encode_image(self, images: torch.Tensor) -> torch.Tensor: + """ + : param images: (batch_size, num_crops, num_patch, n_pixels) + """ + B, T, N, D = images.shape + + mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) + + images = images.view(B * T, N, D) + image_features = self.image_vit(images) + + if self.vit_layers is not None: + features = [] + for layer in self.vit_layers: + features.append(image_features[layer]) + image_features = torch.cat(features, dim=-1) + else: + image_features = image_features[-1] + + cls_embed: torch.Tensor = None + if self.num_prefix_tokens > 0: + cls_embed = image_features[:, 0] + image_features = image_features[:, 1:] + + image_features = image_features * mask + image_features = image_features.view(B, T, N, -1) + + cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None + + return image_features, cls_embed + + def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, 
Optional[torch.Tensor]]: + + # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) + batch_size, num_image = images.shape[:2] + image_features, cls_embed = self.encode_image(images) + + og_dtype = image_features.dtype + assert image_masks is not None + pad_embed = self.pad_embed[:, None, None, None, :] + all_pad = image_masks == 0 + partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=torch.float32) + all_pad = all_pad.to(dtype=torch.float32) + image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1) + image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1) + + image_features = image_features.to(og_dtype) + + image_features = image_features.reshape( + (batch_size, num_image) + self.image_num_patch + (-1,), + ) + + if self.image_num_patch[0] % 2 == 1: + # Pad so we can still pool 2x2 patches + image_features = F.pad( + image_features, + (0, 0, 0, 1, 0, 1, 0, 0, 0, 0), + ) + + # image pooling + image_features = rearrange( + image_features, + 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', + dh=2, + dw=2, + ) + + query = image_features.mean(-2, keepdim=True) + image_features = self.image_pooling_2d(query, image_features) + + h, w = self.llm_patches_per_crop + image_features = image_features.view(batch_size, num_image, h * w, -1) + + image_features = self.image_projector(image_features) + + # image_features: (batch_size, num_image, num_patch, d_model) + # cls_embed: (batch_size, num_image, d_model) + return image_features, cls_embed + + +class MolmoModel(nn.Module): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embedding_size = config.embedding_size or config.vocab_size + # TODO: extra embedding_size hard-coded for now. Consider making it configurable. + self.embedding_size += 128 + self.embed_tokens = VocabParallelEmbedding( + self.embedding_size, + config.hidden_size, + quant_config=quant_config, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MolmoDecoderLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) + + assert config.layer_norm_type == "rms" + self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Apply blocks one-by-one. 
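+        # Note (descriptive comment, not in the original patch): with pipeline
+        # parallelism this rank only runs layers [start_layer, end_layer);
+        # non-last ranks return IntermediateTensors so the next stage can
+        # continue from hidden_states/residual.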
+ for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + return hidden_states + + +cached_get_processor = lru_cache(get_processor) + + +def get_num_patches(num_tiles, crop_patches, left_margin, right_margin, pooling_size): + crop_window_patches = crop_patches - (left_margin + right_margin) + if num_tiles > 1: + left_crop_window_patches = (crop_window_patches + left_margin + pooling_size - 1) // pooling_size * pooling_size + middle_crop_window_patches = (crop_window_patches + pooling_size - 1) // pooling_size * pooling_size + right_crop_window_patches = (crop_window_patches + right_margin + pooling_size - 1) // pooling_size * pooling_size + return left_crop_window_patches + (num_tiles - 2) * middle_crop_window_patches + right_crop_window_patches + else: + single_crop_window_patches = (crop_patches + pooling_size - 1) // pooling_size * pooling_size + return single_crop_window_patches + + +def get_tokens(tiling_h, tiling_w, crop_patches, left_margin, right_margin, pooling_size): + h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, pooling_size) + w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, pooling_size) + per_row = w // pooling_size + 1 + joint = per_row * (h // pooling_size) + 2 + image_token_length = (crop_patches + pooling_size - 1) // pooling_size + resize = (image_token_length + 1) * image_token_length + 2 + return resize + joint + + +def get_max_tokens(max_crops, crop_patches, left_margin, right_margin, pooling_size): + tilings = [] + for i in range(1, max_crops+1): + for j in range(1, max_crops + 1): + if i * j <= max_crops: + tilings.append((i, j)) + tokens = [get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, right_margin, pooling_size) for i in range(len(tilings))] + return max(tokens) + + +def get_max_molmo_image_tokens(ctx: InputContext) -> int: + processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True) + image_processor = processor.image_processor + max_llm_image_tokens = get_max_tokens( + image_processor.max_crops, + image_processor.base_image_input_size[0] // image_processor.image_patch_size, + image_processor.overlap_margins[0], + image_processor.overlap_margins[1], + 2, + ) + return max_llm_image_tokens + + +def image_input_mapper_for_molmo( + ctx: InputContext, + data: Union[Image.Image, List[Image.Image]], +): + return MultiModalInputs(data) + + +def dummy_data_for_molmo( + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] +): + processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True) + image_processor = processor.image_processor + + base_image_input_d = image_processor.image_patch_size + left_margin, right_margin = image_processor.overlap_margins + max_crops = image_processor.max_crops + + # Assume: prompt_token_ids always starts with bos_token_id followed image tokens + max_llm_image_tokens = get_max_molmo_image_tokens(ctx) + if seq_len - max_llm_image_tokens - 1 < 0: + raise RuntimeError( + f"Molmo cannot process {max_crops} crops in a prompt, " + "please increase max_model_len or reduce number of crops" + ) + + # The vertical image has 
the maximum number of image tokens due to column tokens. + tiling = (max_crops, 1) + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = image_processor.base_image_input_size[0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + + h = crop_window_size * tiling[0] + total_margin_pixels + w = crop_window_size * tiling[1] + total_margin_pixels + + dummy_image = Image.new( + "RGB", (w, h), color="red" + ) + + out = processor.process("dummy prompt", dummy_image) + + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + out["input_ids"][:1 + max_llm_image_tokens]) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - max_llm_image_tokens - 1) + dummy_seqdata = SequenceData(token_ids) + dummy_imgdata = { + "images": out["images"], + "image_input_idx": out["image_input_idx"], + } + if "image_masks" in out: + dummy_imgdata["image_masks"] = out["image_masks"] + dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) + return dummy_seqdata, {"image": dummy_imgdata} + + +def pad_images( + max_total_crops: int, + images: torch.Tensor, + image_input_idx: torch.Tensor, + image_masks: Optional[torch.Tensor] = None, +): + n = max_total_crops - images.shape[0] + images = F.pad( + images, (0, 0, 0, 0, 0, n), value=-1 + ) + image_input_idx = F.pad( + image_input_idx, (0, 0, 0, n), value=-1 + ) + if image_masks is not None: + image_masks = F.pad( + image_masks, (0, 0, 0, n), value=-1 + ) + return images, image_input_idx, image_masks + + +def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): + prompt = llm_inputs["prompt"] + multi_modal_data = llm_inputs.get("multi_modal_data") + image = multi_modal_data.get("image") + processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True) + + if re.match(r"^User:[\s\S]*?(Assistant:)*$", prompt): + out = processor.process(prompt, image, message_format="none") + else: + out = processor.process(prompt, image) + + image_processor = processor.image_processor + max_total_crops = 1 + image_processor.max_crops + if image is not None: + images, image_input_idx, image_masks = pad_images( + max_total_crops, + out["images"], + out["image_input_idx"], + out.get("image_masks"), + ) + else: + base_image_input_size = image_processor.base_image_input_size + image_patch_size = image_processor.image_patch_size + image_num_patch = ( + base_image_input_size[0] // image_patch_size, + base_image_input_size[1] // image_patch_size, + ) + n_pixels = image_patch_size * image_patch_size * 3 + n_patches = image_num_patch[0] * image_num_patch[1] + tokens_per_image = image_processor.image_token_length_w * image_processor.image_token_length_h + images = torch.full( + (max_total_crops, n_patches, n_pixels), -1, dtype=torch.float32, + ) + image_input_idx = torch.full( + (max_total_crops, tokens_per_image), -1, dtype=torch.int32, + ) + if image_processor.image_padding_mask: + image_masks = torch.full( + (max_total_crops, n_patches), -1, dtype=torch.float32, + ) + + image_data = dict( + images=images, + image_input_idx=image_input_idx, + ) + if image_masks is not None: + image_data["image_masks"] = image_masks + + image_data["seq_len"] = torch.tensor(len(out["input_ids"]), dtype=torch.long) + + multi_modal_data = dict(image=image_data) + + return LLMInputs( + prompt_token_ids=out["input_ids"], + prompt=llm_inputs["prompt"], + multi_modal_data=multi_modal_data, + ) + + 
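+# Note (descriptive comment, not in the original patch): the registrations
+# below wire Molmo into vLLM's multimodal pipeline: a passthrough image input
+# mapper, a max-image-token estimate, dummy data used for profiling, and the
+# HF-processor-based input processor.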
+@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) +@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) +class MolmoForCausalLM(nn.Module, SupportsMultiModal): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: Optional[MultiModalConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[Mapping[str, Any]] = None, + ): + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + vision_config = VisionBackboneConfig() + self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config) + self.model = MolmoModel(config, cache_config, quant_config) + + if self.config.weight_tying: + self.lm_head = self.model.transformer.wte + else: + self.lm_head = ParallelLMHead( + config.embedding_size or config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + + self.logits_processor = LogitsProcessor(config.embedding_size or config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + images: Optional[torch.Tensor] = None, + image_masks: Optional[torch.Tensor] = None, + image_input_idx: Optional[torch.Tensor] = None, + seq_len: Optional[torch.Tensor] = None, + **kwargs, + ) -> SamplerOutput: + + has_image = images is not None + num_image: Optional[int] = None + + if has_image: + x = self.model.embed_tokens(input_ids) + batch_size = images.size(0) + if not isinstance(seq_len, torch.Tensor): + seq_len = torch.tensor(seq_len, device=x.device) + images = images.to(device=x.device, dtype=x.dtype) + image_features: torch.Tensor + image_features, cls_embed = self.vision_backbone(images=images, image_masks=image_masks) + num_image, num_patch = image_features.shape[1:3] + assert image_input_idx.shape == (batch_size, num_image, num_patch) + + image_features = image_features.to(x.device) + + # insert the image feature into the embedding. 
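+            # Note (descriptive comment, not in the original patch): each image
+            # token's position in the flattened batch comes from image_input_idx
+            # offset by the cumulative sequence lengths; a one-hot matrix built
+            # from those positions lets the einsum below scatter-add patch
+            # features onto the token embeddings, with padded entries
+            # (index < 0) masked to zero.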
+ image_features = image_features.view(batch_size, num_image * num_patch, -1) + image_input_idx = image_input_idx.view(batch_size, num_image * num_patch) + + valid = image_input_idx >= 0 + image_features = image_features * valid[:, :, None].to(image_features.dtype) + image_features = image_features.view(batch_size * num_image * num_patch, -1).contiguous() + + image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) + offset = torch.cat([seq_len.new_zeros((1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None] + image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) + image_input_idx = image_input_idx.flatten()[:, None] + mat = image_input_idx == torch.arange(seq_len.sum().item(), device=x.device)[None, :] + mat = mat.to(image_features.dtype) + x = x + torch.einsum('nd,nm->md', image_features, mat) + + inputs_embeds = x + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + embedding_weight = dict() + projector_weight = dict() + for name, loaded_weight in weights: + log.info(f"Original name: {name}") + if "rotary_emb.inv_freq" in name: + log.info(f"Skipping {name}") + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + + if "wte.embedding" in name: + embedding_weight["embedding"] = loaded_weight + continue + + if "wte.new_embedding" in name: + embedding_weight["new_embedding"] = loaded_weight + continue + + if "vision_backbone" in name: + if name.startswith("model"): + name = name[len("model."):] + if 'image_projector' in name: + if 'w1' in name: + projector_weight['gate_proj'] = loaded_weight + elif 'w3' in name: + projector_weight['up_proj'] = loaded_weight + elif 'w2' in name: + projector_weight['down_proj'] = loaded_weight + else: + raise ValueError(f"Unexpected projector weight: {name}") + continue + else: + if "ln_f.weight" in name: + name = "model.norm.weight" + + if "transformer.blocks" in name: + name = name.replace("transformer.blocks", "layers") + + if "attn_out" in name: + name = name.replace("attn_out", "self_attn.o_proj") + + if "att_proj" in name: + name = name.replace("att_proj", "self_attn.qkv_proj") + + if 'q_norm' in name: + name = name.replace("q_norm", "self_attn.q_norm") + + if 'k_norm' in name: + name = name.replace("k_norm", "self_attn.k_norm") + + if "ff_proj" in name: + name = name.replace("ff_proj", "mlp.gate_up_proj") + assert 'weight' in name + up_weight, gate_weight = loaded_weight.chunk(2, dim=0) + loaded_weight = torch.cat([gate_weight, up_weight], dim=0) + + if "ff_out" in name: + if "layers" in name: + name = name.replace("ff_out", "mlp.down_proj") + else: + # lm head + name = name.replace("model.transformer.ff_out", "lm_head") + + if "attn_norm" in name: + name = name.replace("attn_norm", "input_layernorm") + + if "ff_norm" in name: + name = name.replace("ff_norm", 
"post_attention_layernorm") + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + print(params_dict.keys()) + raise + weight_loader = getattr(param, "weight_loader", default_weight_loader) + try: + weight_loader(param, loaded_weight) + except: + raise + + gate_up_proj_weight = torch.cat( + [projector_weight["gate_proj"], projector_weight["up_proj"]], dim=0 + ) + name = "vision_backbone.image_projector.gate_up_proj.weight" + try: + param = params_dict[name] + except KeyError: + print(f"no {name}") + raise + weight_loader = getattr(param, "weight_loader", default_weight_loader) + try: + weight_loader(param, gate_up_proj_weight) + except: + raise + down_proj_weight = projector_weight["down_proj"] + name = "vision_backbone.image_projector.down_proj.weight" + try: + param = params_dict[name] + except KeyError: + print(f"no {name}") + raise + weight_loader = getattr(param, "weight_loader", default_weight_loader) + try: + weight_loader(param, down_proj_weight) + except: + raise + + embedding_weight = torch.cat( + [embedding_weight["embedding"], embedding_weight["new_embedding"]], dim=0 + ) + name = "model.embed_tokens.weight" + try: + param = params_dict[name] + except KeyError: + print(f"no {name}") + raise + weight_loader = getattr(param, "weight_loader", default_weight_loader) + try: + weight_loader(param, embedding_weight) + except: + raise \ No newline at end of file From 4af9e3f032f0878f7874bee03fcd5ca0b10541ea Mon Sep 17 00:00:00 2001 From: mrsalehi Date: Tue, 1 Oct 2024 20:08:58 -0700 Subject: [PATCH 02/12] rmvd import --- examples/offline_inference_molmo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/offline_inference_molmo.py b/examples/offline_inference_molmo.py index 3d7434ea8357..c6cb6e531d8f 100644 --- a/examples/offline_inference_molmo.py +++ b/examples/offline_inference_molmo.py @@ -9,7 +9,6 @@ from vllm import LLM from vllm.sampling_params import SamplingParams -from vllm_molmo.molmo import MolmoForCausalLM ImageFile.LOAD_TRUNCATED_IMAGES = True From 72c365e6e3235dbf020a7cfb2643a6506d3cd649 Mon Sep 17 00:00:00 2001 From: sanghol Date: Thu, 10 Oct 2024 23:21:58 +0000 Subject: [PATCH 03/12] code cleanup --- examples/offline_inference_molmo.py | 35 +- examples/offline_inference_vision_language.py | 18 + vllm/model_executor/models/molmo.py | 393 +++++++++++------- 3 files changed, 277 insertions(+), 169 deletions(-) diff --git a/examples/offline_inference_molmo.py b/examples/offline_inference_molmo.py index c6cb6e531d8f..e7d4bea3e2fa 100644 --- a/examples/offline_inference_molmo.py +++ b/examples/offline_inference_molmo.py @@ -1,11 +1,7 @@ import argparse -import numpy as np import requests from io import BytesIO -import base64 -from PIL import Image, ImageFile, ImageOps -import torch -from typing import Optional +from PIL import Image, ImageFile from vllm import LLM from vllm.sampling_params import SamplingParams @@ -14,7 +10,7 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True -def download_image_to_numpy(url): +def download_image(url: str): # Send a GET request to the URL response = requests.get(url) @@ -22,13 +18,8 @@ def download_image_to_numpy(url): if response.status_code == 200: # Open the image from the response content image = Image.open(BytesIO(response.content)).convert("RGB") - - image = ImageOps.exif_transpose(image) - - # Convert the image to a NumPy array - image_array = np.array(image).astype(np.uint8) - return image_array + return image else: raise 
Exception(f"Failed to download image. Status code: {response.status_code}") @@ -37,11 +28,11 @@ def vllm_generate(): inputs = [ { "prompt": "Describe this image.", - "multi_modal_data": {"image": download_image_to_numpy("https://picsum.photos/id/9/1080/720")} + "multi_modal_data": {"image": download_image("https://picsum.photos/id/9/1080/720")} }, { "prompt": "Describe what you see in this image.", - "multi_modal_data": {"image": download_image_to_numpy("https://picsum.photos/id/23/1080/720")} + "multi_modal_data": {"image": download_image("https://picsum.photos/id/23/1080/720")} }, ] @@ -88,26 +79,11 @@ def vllm_chat(): sampling_params=sampling_params ) - # Error: Invalid message role {{ message['role'] }} at index {{ loop.index }} for o in outputs: generated_text = o.outputs[0].text print(generated_text) -def set_tf_memory_growth(): - import tensorflow as tf - gpus = tf.config.experimental.list_physical_devices('GPU') - - if gpus: - try: - # Set memory growth for each GPU - for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - print("Memory growth set for GPUs") - except RuntimeError as e: - print(e) - - if __name__ == "__main__": parser = argparse.ArgumentParser( description="vllm example for Molmo-7B-O-0924, Molmo-7B-D-0924, Molmo-72B-0924" @@ -124,6 +100,5 @@ def set_tf_memory_growth(): max_tokens=768, temperature=0, ) - set_tf_memory_growth() vllm_generate() vllm_chat() \ No newline at end of file diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index b94ef537d783..2c15e4509d34 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -275,6 +275,23 @@ def run_mllama(question, modality): return llm, prompt, stop_token_ids +# Molmo +def run_molmo(question, modality): + assert modality == "image" + + model_name = "allenai/Molmo-7B-D-0924" + + llm = LLM( + model=model_name, + trust_remote_code=True, + dtype="bfloat16", + ) + + prompt = question + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -290,6 +307,7 @@ def run_mllama(question, modality): "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, "mllama": run_mllama, + "molmo": run_molmo, } diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 3b7a003c5dab..822ebd9d0672 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1,9 +1,9 @@ from array import array -from functools import lru_cache +from functools import lru_cache, partial import re import logging from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union, Any, Mapping +from typing import Iterable, List, Optional, Tuple, Union, Any, Mapping, TypedDict from PIL import Image import math @@ -17,7 +17,13 @@ from vllm.attention.selector import (_Backend, backend_name_to_enum, get_global_forced_attn_backend) from vllm.config import CacheConfig, MultiModalConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ColumnParallelLinear, @@ -51,6 +57,34 @@ log = logging.getLogger(__name__) +# 
TODO: hard-coded for now. Consider making it configurable. +VIT_LAYERS = [-2, -9] +NUM_PREFIX_TOKENS = 1 +ADDITIONAL_VOCAB_SIZE = 128 + + +class MolmoImageInputs(TypedDict): + images: torch.Tensor + """Shape: + `(batch_size, num_crops, num_patch, patch_dim)` + """ + + image_input_idx: torch.Tensor + """Shape: + `(batch_size, num_crops, num_patch)` + """ + + seq_len: torch.Tensor + """Shape: + `(batch_size, )` + """ + + image_masks: Optional[torch.Tensor] + """Shape: + `(batch_size, num_crops, num_patch)` + """ + + @dataclass class VisionBackboneConfig: image_default_input_size: Tuple[int, int] = (336, 336) @@ -282,7 +316,7 @@ def __init__( scale = config.image_emb_dim ** -0.5 self.patch_num = config.image_num_patch self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale) - self.num_prefix_tokens: int = 1 + self.num_prefix_tokens: int = NUM_PREFIX_TOKENS self.positional_embedding = nn.Parameter( torch.randn(config.image_num_pos, config.image_emb_dim) * scale ) @@ -347,20 +381,20 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() + self.tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.hidden_size % self.total_num_heads == 0 - assert self.total_num_heads % tp_size == 0 + assert self.total_num_heads % self.tp_size == 0 - self.num_heads = self.total_num_heads // tp_size + self.num_heads = self.total_num_heads // self.tp_size self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads - if self.total_num_kv_heads >= tp_size: - assert self.total_num_kv_heads % tp_size == 0 + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 else: - assert tp_size % self.total_num_kv_heads == 0 + assert self.tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -377,11 +411,13 @@ def __init__( quant_config=quant_config, ) + self.tp_rank: Optional[int] = None self.k_norm: Optional[nn.Module] = None self.q_norm: Optional[nn.Module] = None if config.attention_layer_norm: - self.k_norm = RMSNorm(self.kv_size, eps=config.layer_norm_eps) - self.q_norm = RMSNorm(self.q_size, eps=config.layer_norm_eps) + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps) + self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) # Rotary embeddings. 
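+        # Note (descriptive comment, not in the original patch): RoPE is
+        # applied over the full head_dim using the model's rope_theta and
+        # max_position_embeddings.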
self.rotary_emb = get_rope( @@ -405,6 +441,25 @@ def __init__( bias=False, quant_config=quant_config, ) + + def _apply_qk_norm( + self, + q: torch.Tensor, + k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial( + split_tensor_along_last_dim, + num_partitions=self.tp_size + ) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k def forward( self, @@ -416,8 +471,7 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.q_norm is not None and self.k_norm is not None: - q = self.q_norm.forward_native(q) - k = self.k_norm.forward_native(k) + q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -486,7 +540,6 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) - self.norm_after = config.norm_after def forward( self, @@ -497,9 +550,7 @@ def forward( residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Self Attention - if self.norm_after: - residual = hidden_states - elif residual is None: + if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: @@ -512,19 +563,39 @@ def forward( attn_metadata=attn_metadata, ) - if self.norm_after: - hidden_states = self.input_layernorm(hidden_states) - hidden_states = hidden_states + residual - residual = hidden_states + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states - if not self.norm_after: - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) hidden_states = self.mlp(hidden_states) - if self.norm_after: - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = hidden_states + residual - residual = None + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = None return hidden_states, residual @@ -536,8 +607,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - # TODO: vit_layers hard-coded for now. Consider making it configurable. 
- self.vit_layers = [-2, -9] + self.vit_layers = VIT_LAYERS self.image_num_patch = vision_config.image_num_patch self.llm_patches_per_crop = ( (self.image_num_patch[0] + 1) // 2, @@ -558,6 +628,14 @@ def __init__( image_dim = vision_config.image_emb_dim * len(self.vit_layers) self.pad_embed = nn.Parameter(torch.zeros((2, image_dim))) + @property + def dtype(self) -> torch.dtype: + return self.image_vit.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.image_vit.patch_embedding.weight.device + def encode_image(self, images: torch.Tensor) -> torch.Tensor: """ : param images: (batch_size, num_crops, num_patch, n_pixels) @@ -577,23 +655,20 @@ def encode_image(self, images: torch.Tensor) -> torch.Tensor: else: image_features = image_features[-1] - cls_embed: torch.Tensor = None if self.num_prefix_tokens > 0: - cls_embed = image_features[:, 0] image_features = image_features[:, 1:] image_features = image_features * mask image_features = image_features.view(B, T, N, -1) - cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None - - return image_features, cls_embed + return image_features def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) batch_size, num_image = images.shape[:2] - image_features, cls_embed = self.encode_image(images) + images = images.to(device=self.device, dtype=self.dtype) + image_features = self.encode_image(images) og_dtype = image_features.dtype assert image_masks is not None @@ -634,8 +709,7 @@ def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torc image_features = self.image_projector(image_features) # image_features: (batch_size, num_image, num_patch, d_model) - # cls_embed: (batch_size, num_image, d_model) - return image_features, cls_embed + return image_features class MolmoModel(nn.Module): @@ -650,17 +724,17 @@ def __init__( self.config = config self.embedding_size = config.embedding_size or config.vocab_size - # TODO: extra embedding_size hard-coded for now. Consider making it configurable. 
- self.embedding_size += 128 + self.embedding_size += ADDITIONAL_VOCAB_SIZE self.embed_tokens = VocabParallelEmbedding( self.embedding_size, config.hidden_size, quant_config=quant_config, ) + decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MolmoDecoderLayer(config, cache_config, quant_config), + lambda prefix: decoder_layer(config, cache_config, quant_config), prefix=f"{prefix}.layers", ) @@ -712,7 +786,13 @@ def forward( cached_get_processor = lru_cache(get_processor) -def get_num_patches(num_tiles, crop_patches, left_margin, right_margin, pooling_size): +def get_num_patches( + num_tiles: int, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int +) -> int: crop_window_patches = crop_patches - (left_margin + right_margin) if num_tiles > 1: left_crop_window_patches = (crop_window_patches + left_margin + pooling_size - 1) // pooling_size * pooling_size @@ -724,7 +804,14 @@ def get_num_patches(num_tiles, crop_patches, left_margin, right_margin, pooling_ return single_crop_window_patches -def get_tokens(tiling_h, tiling_w, crop_patches, left_margin, right_margin, pooling_size): +def get_tokens( + tiling_h: int, + tiling_w: int, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int +) -> int: h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, pooling_size) w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, pooling_size) per_row = w // pooling_size + 1 @@ -734,7 +821,13 @@ def get_tokens(tiling_h, tiling_w, crop_patches, left_margin, right_margin, pool return resize + joint -def get_max_tokens(max_crops, crop_patches, left_margin, right_margin, pooling_size): +def get_max_tokens( + max_crops: int, + crop_patches: int, + left_margin: int, + right_margin: int, + pooling_size: int +) -> int: tilings = [] for i in range(1, max_crops+1): for j in range(1, max_crops + 1): @@ -745,7 +838,7 @@ def get_max_tokens(max_crops, crop_patches, left_margin, right_margin, pooling_s def get_max_molmo_image_tokens(ctx: InputContext) -> int: - processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True) + processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True, revision=ctx.model_config.code_revision) image_processor = processor.image_processor max_llm_image_tokens = get_max_tokens( image_processor.max_crops, @@ -759,7 +852,7 @@ def get_max_molmo_image_tokens(ctx: InputContext) -> int: def image_input_mapper_for_molmo( ctx: InputContext, - data: Union[Image.Image, List[Image.Image]], + data: object, ): return MultiModalInputs(data) @@ -767,7 +860,7 @@ def image_input_mapper_for_molmo( def dummy_data_for_molmo( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] ): - processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True) + processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True, revision=ctx.model_config.code_revision) image_processor = processor.image_processor base_image_input_d = image_processor.image_patch_size @@ -837,12 +930,14 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): prompt = llm_inputs["prompt"] multi_modal_data = llm_inputs.get("multi_modal_data") image = multi_modal_data.get("image") - processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True) + processor = cached_get_processor(ctx.model_config.model, 
trust_remote_code=True, revision=ctx.model_config.code_revision) - if re.match(r"^User:[\s\S]*?(Assistant:)*$", prompt): + if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", prompt): out = processor.process(prompt, image, message_format="none") - else: + elif prompt is not None: out = processor.process(prompt, image) + else: + out = processor.process(None, image, tokens=llm_inputs["prompt_token_ids"]) image_processor = processor.image_processor max_total_crops = 1 + image_processor.max_crops @@ -904,7 +999,7 @@ def __init__( multimodal_config: Optional[MultiModalConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[Mapping[str, Any]] = None, - ): + ) -> None: super().__init__() self.config = config @@ -925,6 +1020,76 @@ def __init__( self.logits_processor = LogitsProcessor(config.embedding_size or config.vocab_size) self.sampler = Sampler() + + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Optional[MolmoImageInputs]: + images = kwargs.pop("images", None) + image_masks = kwargs.pop("image_masks", None) + if images is None: + return None + + image_input_idx = kwargs.pop("image_input_idx", None) + seq_len = kwargs.pop("seq_len", None) + if image_input_idx is None: + raise ValueError("image_input_idx is required for Molmo model.") + if seq_len is None: + raise ValueError("seq_len is required for Molmo model.") + if not isinstance(seq_len, torch.Tensor): + seq_len = torch.tensor(seq_len) + + return MolmoImageInputs( + images=images, + image_input_idx=image_input_idx, + seq_len=seq_len, + image_masks=image_masks, + ) + + def _process_image_input( + self, + image_input: MolmoImageInputs, + ) -> torch.Tensor: + + image_features = self.vision_backbone( + images=image_input["images"], + image_masks=image_input["image_masks"], + ) + + return image_features + + def _merge_multimodal_embeddings( + self, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + image_input_idx: torch.Tensor, + seq_len: Union[torch.Tensor, List[torch.Tensor]], + ) -> torch.Tensor: + batch_size, num_image, num_patch = image_features.shape[:3] + assert image_input_idx.shape == (batch_size, num_image, num_patch) + + image_features = image_features.to(inputs_embeds.device) + seq_len = seq_len.to(inputs_embeds.device) + + # insert the image feature into the embedding. 
+ image_features = image_features.view(batch_size, num_image * num_patch, -1) + image_input_idx = image_input_idx.view(batch_size, num_image * num_patch) + + valid = image_input_idx >= 0 + image_features = image_features * valid[:, :, None].to(image_features.dtype) + image_features = image_features.view(batch_size * num_image * num_patch, -1).contiguous() + + image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) + offset = torch.cat([seq_len.new_zeros((1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None] + image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) + image_input_idx = image_input_idx.flatten()[:, None] + mat = image_input_idx == torch.arange(seq_len.sum().item(), device=inputs_embeds.device)[None, :] + mat = mat.to(image_features.dtype) + + inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md', image_features, mat) + + return inputs_embeds + def forward( self, @@ -932,46 +1097,21 @@ def forward( positions: torch.LongTensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - images: Optional[torch.Tensor] = None, - image_masks: Optional[torch.Tensor] = None, - image_input_idx: Optional[torch.Tensor] = None, - seq_len: Optional[torch.Tensor] = None, - **kwargs, + **kwargs: object, ) -> SamplerOutput: - has_image = images is not None - num_image: Optional[int] = None - - if has_image: - x = self.model.embed_tokens(input_ids) - batch_size = images.size(0) - if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len, device=x.device) - images = images.to(device=x.device, dtype=x.dtype) - image_features: torch.Tensor - image_features, cls_embed = self.vision_backbone(images=images, image_masks=image_masks) - num_image, num_patch = image_features.shape[1:3] - assert image_input_idx.shape == (batch_size, num_image, num_patch) - - image_features = image_features.to(x.device) - - # insert the image feature into the embedding. 
- image_features = image_features.view(batch_size, num_image * num_patch, -1) - image_input_idx = image_input_idx.view(batch_size, num_image * num_patch) - - valid = image_input_idx >= 0 - image_features = image_features * valid[:, :, None].to(image_features.dtype) - image_features = image_features.view(batch_size * num_image * num_patch, -1).contiguous() - - image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) - offset = torch.cat([seq_len.new_zeros((1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None] - image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) - image_input_idx = image_input_idx.flatten()[:, None] - mat = image_input_idx == torch.arange(seq_len.sum().item(), device=x.device)[None, :] - mat = mat.to(image_features.dtype) - x = x + torch.einsum('nd,nm->md', image_features, mat) + image_input = self._parse_and_validate_image_input(**kwargs) - inputs_embeds = x + if image_input is not None: + inputs_embeds = self.model.embed_tokens(input_ids) + image_features = self._process_image_input(image_input) + + inputs_embeds = self._merge_multimodal_embeddings( + inputs_embeds, + image_features, + image_input["image_input_idx"], + image_input["seq_len"], + ) input_ids = None else: @@ -1003,12 +1143,21 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_mapping = [ + ("model.transformer.ln_f.weight", "model.norm.weight"), + ("attn_out", "self_attn.o_proj"), + ("att_proj", "self_attn.qkv_proj"), + ("q_norm", "self_attn.q_norm"), + ("k_norm", "self_attn.k_norm"), + ("attn_norm", "input_layernorm"), + ("ff_norm", "post_attention_layernorm"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) embedding_weight = dict() projector_weight = dict() for name, loaded_weight in weights: - log.info(f"Original name: {name}") if "rotary_emb.inv_freq" in name: log.info(f"Skipping {name}") continue @@ -1037,23 +1186,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): raise ValueError(f"Unexpected projector weight: {name}") continue else: - if "ln_f.weight" in name: - name = "model.norm.weight" - if "transformer.blocks" in name: name = name.replace("transformer.blocks", "layers") - - if "attn_out" in name: - name = name.replace("attn_out", "self_attn.o_proj") - - if "att_proj" in name: - name = name.replace("att_proj", "self_attn.qkv_proj") - - if 'q_norm' in name: - name = name.replace("q_norm", "self_attn.q_norm") - - if 'k_norm' in name: - name = name.replace("k_norm", "self_attn.k_norm") if "ff_proj" in name: name = name.replace("ff_proj", "mlp.gate_up_proj") @@ -1061,18 +1195,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): up_weight, gate_weight = loaded_weight.chunk(2, dim=0) loaded_weight = torch.cat([gate_weight, up_weight], dim=0) - if "ff_out" in name: + elif "ff_out" in name: if "layers" in name: name = name.replace("ff_out", "mlp.down_proj") else: # lm head name = name.replace("model.transformer.ff_out", "lm_head") - - if "attn_norm" in name: - name = name.replace("attn_norm", "input_layernorm") - if "ff_norm" in name: - name = name.replace("ff_norm", "post_attention_layernorm") + else: + for (param_name, weight_name) in params_mapping: + if param_name in name: + name = name.replace(param_name, weight_name) + break + try: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: @@ -1091,40 +1226,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): [projector_weight["gate_proj"], projector_weight["up_proj"]], dim=0 ) name = "vision_backbone.image_projector.gate_up_proj.weight" - try: - param = params_dict[name] - except KeyError: - print(f"no {name}") - raise + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - try: - weight_loader(param, gate_up_proj_weight) - except: - raise + weight_loader(param, gate_up_proj_weight) + down_proj_weight = projector_weight["down_proj"] name = "vision_backbone.image_projector.down_proj.weight" - try: - param = params_dict[name] - except KeyError: - print(f"no {name}") - raise + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - try: - weight_loader(param, down_proj_weight) - except: - raise + weight_loader(param, down_proj_weight) embedding_weight = torch.cat( [embedding_weight["embedding"], embedding_weight["new_embedding"]], dim=0 ) name = "model.embed_tokens.weight" - try: - param = params_dict[name] - except KeyError: - print(f"no {name}") - raise + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - try: - weight_loader(param, embedding_weight) - except: - raise \ No newline at end of file + weight_loader(param, embedding_weight) \ No newline at end of file From 0a8946a2d783b9d2f264fd50177d412319903eb4 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 13 Oct 2024 22:07:25 -0700 Subject: [PATCH 04/12] format --- examples/offline_inference_vision_language.py | 2 + vllm/model_executor/models/molmo.py | 338 ++++++++++-------- 2 files changed, 194 insertions(+), 146 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index af44ce03aa5f..8f35ca78a3b6 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -316,6 +316,8 @@ def run_molmo(question, modality): stop_token_ids = None return llm, prompt, stop_token_ids + +# GLM4V def run_glm4v(question: str, modality: str): assert modality == "image" model_name = "THUDM/glm-4v-9b" diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 822ebd9d0672..e4fac9900a08 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -56,7 +56,6 @@ log = logging.getLogger(__name__) - # TODO: hard-coded for now. Consider making it configurable. 
VIT_LAYERS = [-2, -9] NUM_PREFIX_TOKENS = 1 @@ -100,7 +99,8 @@ class VisionBackboneConfig: image_norm_eps: float = 1e-5 def __post_init__(self): - self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] + self.image_default_input_size = tuple( + self.image_default_input_size) # type: ignore[assignment] @property def image_num_patch(self): @@ -110,6 +110,7 @@ def image_num_patch(self): class ViTMLP(nn.Module): """MLP used in Vision Transformer.""" + def __init__( self, config: VisionBackboneConfig, @@ -141,6 +142,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MultiHeadDotProductAttention(nn.Module): """Multi-head attention used in Vision Transformer.""" + def __init__( self, config: VisionBackboneConfig, @@ -167,7 +169,7 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - + self.wq = ColumnParallelLinear( nlayers * self.hidden_size, self.total_num_heads * self.head_dim, @@ -222,10 +224,11 @@ def __init__( self._use_flash_attn = False else: raise RuntimeError( - f"Molmo does not support {selected_backend} backend now." - ) + f"Molmo does not support {selected_backend} backend now.") - def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, + inputs_q: torch.Tensor, + inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: if inputs_kv is not None: inputs_k = inputs_kv @@ -249,7 +252,7 @@ def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = No else: from xformers import ops as xops output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0) - + output = rearrange(output, "b s h d -> b s (h d)").contiguous() output, _ = self.wo(output) @@ -258,13 +261,15 @@ def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = No class ResidualAttentionBlock(nn.Module): """Residual attention block used in Vision Transformer.""" + def __init__( self, config: VisionBackboneConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config) + self.attention = MultiHeadDotProductAttention( + config, quant_config=quant_config) self.feed_forward = ViTMLP(config, quant_config) self.attention_norm = nn.LayerNorm( config.image_emb_dim, @@ -283,15 +288,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BlockCollection(nn.Module): """Collection of residual attention blocks used in Vision Transformer.""" + def __init__( self, config: VisionBackboneConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.resblocks = nn.ModuleList( - [ResidualAttentionBlock(config, quant_config) for _ in range(config.image_num_layers)] - ) + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock(config, quant_config) + for _ in range(config.image_num_layers) + ]) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: hidden_states = [] @@ -307,26 +314,28 @@ def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor: class VisionTransformer(nn.Module): """Vision Transformer used in Vision Backbone.""" + def __init__( self, config: VisionBackboneConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - scale = config.image_emb_dim ** -0.5 + scale = config.image_emb_dim**-0.5 self.patch_num = config.image_num_patch - self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale) + 
self.class_embedding = nn.Parameter( + torch.randn(config.image_emb_dim) * scale) self.num_prefix_tokens: int = NUM_PREFIX_TOKENS self.positional_embedding = nn.Parameter( - torch.randn(config.image_num_pos, config.image_emb_dim) * scale - ) + torch.randn(config.image_num_pos, config.image_emb_dim) * scale) image_patch_size = config.image_patch_size self.patch_embedding = nn.Linear( image_patch_size * image_patch_size * 3, config.image_emb_dim, bias=False, ) - self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps) + self.pre_ln = nn.LayerNorm(config.image_emb_dim, + eps=config.image_norm_eps) self.transformer = BlockCollection(config, quant_config) def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: @@ -334,8 +343,8 @@ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: pos_emb = self.positional_embedding[1:] pos_emb = pos_emb.reshape( - (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]) - ) + (int(math.sqrt(pos_emb.shape[0])), + int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) (patch_num_0, patch_num_1) = patch_num @@ -343,15 +352,22 @@ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) pos_emb = F.interpolate( - pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True, + pos_emb, + size=(patch_num_0, patch_num_1), + mode="bicubic", + align_corners=False, + antialias=True, ) pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) - x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype) + x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], + dim=1).to(x.dtype) return x - def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]: + def forward(self, + x: torch.Tensor, + patch_num: int = None) -> List[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) """ @@ -362,7 +378,9 @@ def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]: x = self.patch_embedding(x) # class embeddings and positional embeddings - x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + x = torch.cat( + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], + dim=1) x = self.add_pos_emb(x, patch_num) x = self.pre_ln(x) @@ -373,6 +391,7 @@ def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]: class MolmoAttention(nn.Module): """Molmo's LLM attention.""" + def __init__( self, config: PretrainedConfig, @@ -393,7 +412,7 @@ def __init__( assert self.total_num_kv_heads % self.tp_size == 0 else: assert self.tp_size % self.total_num_kv_heads == 0 - + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim @@ -416,8 +435,10 @@ def __init__( self.q_norm: Optional[nn.Module] = None if config.attention_layer_norm: self.tp_rank = get_tensor_model_parallel_rank() - self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps) - self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, + eps=config.layer_norm_eps) + self.q_norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) # Rotary embeddings. 
self.rotary_emb = get_rope( @@ -441,22 +462,17 @@ def __init__( bias=False, quant_config=quant_config, ) - - def _apply_qk_norm( - self, - q: torch.Tensor, - k: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm.forward_native(q) k = self.k_norm.forward_native(k) if self.tp_size > 1: - splitter = partial( - split_tensor_along_last_dim, - num_partitions=self.tp_size - ) + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -521,6 +537,7 @@ def forward( class MolmoDecoderLayer(nn.Module): + def __init__( self, config: PretrainedConfig, @@ -600,6 +617,7 @@ def forward( class MolmoVisionBackbone(nn.Module): + def __init__( self, config: PretrainedConfig, @@ -613,12 +631,16 @@ def __init__( (self.image_num_patch[0] + 1) // 2, (self.image_num_patch[1] + 1) // 2, ) - self.image_vit = VisionTransformer(vision_config, quant_config=quant_config) + self.image_vit = VisionTransformer(vision_config, + quant_config=quant_config) self.num_prefix_tokens = self.image_vit.num_prefix_tokens - assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported" + assert self.num_prefix_tokens in { + 0, 1 + }, "Only 0 or 1 prefix tokens are supported" self.image_pooling_2d = MultiHeadDotProductAttention( - vision_config, nlayers=len(self.vit_layers), quant_config=quant_config - ) + vision_config, + nlayers=len(self.vit_layers), + quant_config=quant_config) self.image_projector = MolmoMLP( config, input_dim=vision_config.image_emb_dim, @@ -642,7 +664,8 @@ def encode_image(self, images: torch.Tensor) -> torch.Tensor: """ B, T, N, D = images.shape - mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) + mask = ~torch.all( + images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) images = images.view(B * T, N, D) image_features = self.image_vit(images) @@ -663,27 +686,32 @@ def encode_image(self, images: torch.Tensor) -> torch.Tensor: return image_features - def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + def forward( + self, images: torch.Tensor, image_masks: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) batch_size, num_image = images.shape[:2] images = images.to(device=self.device, dtype=self.dtype) image_features = self.encode_image(images) - + og_dtype = image_features.dtype assert image_masks is not None pad_embed = self.pad_embed[:, None, None, None, :] all_pad = image_masks == 0 - partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=torch.float32) + partial_pad = torch.logical_and( + image_masks < 1, + torch.logical_not(all_pad)).to(dtype=torch.float32) all_pad = all_pad.to(dtype=torch.float32) - image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1) - image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1) + image_features = image_features + pad_embed[0] * torch.unsqueeze( + all_pad, -1) + image_features = image_features + pad_embed[1] * torch.unsqueeze( + partial_pad, -1) image_features = image_features.to(og_dtype) image_features = image_features.reshape( - (batch_size, 
num_image) + self.image_num_patch + (-1,), - ) + (batch_size, num_image) + self.image_num_patch + (-1, ), ) if self.image_num_patch[0] % 2 == 1: # Pad so we can still pool 2x2 patches @@ -713,6 +741,7 @@ def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torc class MolmoModel(nn.Module): + def __init__( self, config: PretrainedConfig, @@ -786,34 +815,33 @@ def forward( cached_get_processor = lru_cache(get_processor) -def get_num_patches( - num_tiles: int, - crop_patches: int, - left_margin: int, - right_margin: int, - pooling_size: int -) -> int: +def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int, + right_margin: int, pooling_size: int) -> int: crop_window_patches = crop_patches - (left_margin + right_margin) if num_tiles > 1: - left_crop_window_patches = (crop_window_patches + left_margin + pooling_size - 1) // pooling_size * pooling_size - middle_crop_window_patches = (crop_window_patches + pooling_size - 1) // pooling_size * pooling_size - right_crop_window_patches = (crop_window_patches + right_margin + pooling_size - 1) // pooling_size * pooling_size - return left_crop_window_patches + (num_tiles - 2) * middle_crop_window_patches + right_crop_window_patches + left_crop_window_patches = (crop_window_patches + left_margin + + pooling_size - + 1) // pooling_size * pooling_size + middle_crop_window_patches = (crop_window_patches + pooling_size - + 1) // pooling_size * pooling_size + right_crop_window_patches = (crop_window_patches + right_margin + + pooling_size - + 1) // pooling_size * pooling_size + return left_crop_window_patches + ( + num_tiles - + 2) * middle_crop_window_patches + right_crop_window_patches else: - single_crop_window_patches = (crop_patches + pooling_size - 1) // pooling_size * pooling_size + single_crop_window_patches = (crop_patches + pooling_size - + 1) // pooling_size * pooling_size return single_crop_window_patches -def get_tokens( - tiling_h: int, - tiling_w: int, - crop_patches: int, - left_margin: int, - right_margin: int, - pooling_size: int -) -> int: - h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, pooling_size) - w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, pooling_size) +def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int, + left_margin: int, right_margin: int, pooling_size: int) -> int: + h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, + pooling_size) + w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, + pooling_size) per_row = w // pooling_size + 1 joint = per_row * (h // pooling_size) + 2 image_token_length = (crop_patches + pooling_size - 1) // pooling_size @@ -821,28 +849,29 @@ def get_tokens( return resize + joint -def get_max_tokens( - max_crops: int, - crop_patches: int, - left_margin: int, - right_margin: int, - pooling_size: int -) -> int: +def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int, + right_margin: int, pooling_size: int) -> int: tilings = [] - for i in range(1, max_crops+1): + for i in range(1, max_crops + 1): for j in range(1, max_crops + 1): if i * j <= max_crops: tilings.append((i, j)) - tokens = [get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, right_margin, pooling_size) for i in range(len(tilings))] + tokens = [ + get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, + right_margin, pooling_size) for i in range(len(tilings)) + ] return max(tokens) def get_max_molmo_image_tokens(ctx: InputContext) -> int: - processor = 
cached_get_processor(ctx.model_config.model, trust_remote_code=True, revision=ctx.model_config.code_revision) + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) image_processor = processor.image_processor max_llm_image_tokens = get_max_tokens( image_processor.max_crops, - image_processor.base_image_input_size[0] // image_processor.image_patch_size, + image_processor.base_image_input_size[0] // + image_processor.image_patch_size, image_processor.overlap_margins[0], image_processor.overlap_margins[1], 2, @@ -857,10 +886,11 @@ def image_input_mapper_for_molmo( return MultiModalInputs(data) -def dummy_data_for_molmo( - ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] -): - processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True, revision=ctx.model_config.code_revision) +def dummy_data_for_molmo(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) image_processor = processor.image_processor base_image_input_d = image_processor.image_patch_size @@ -872,29 +902,27 @@ def dummy_data_for_molmo( if seq_len - max_llm_image_tokens - 1 < 0: raise RuntimeError( f"Molmo cannot process {max_crops} crops in a prompt, " - "please increase max_model_len or reduce number of crops" - ) - + "please increase max_model_len or reduce number of crops") + # The vertical image has the maximum number of image tokens due to column tokens. tiling = (max_crops, 1) total_margin_pixels = base_image_input_d * (right_margin + left_margin) - crop_patches = image_processor.base_image_input_size[0] // base_image_input_d + crop_patches = image_processor.base_image_input_size[ + 0] // base_image_input_d crop_window_patches = crop_patches - (right_margin + left_margin) crop_window_size = crop_window_patches * base_image_input_d h = crop_window_size * tiling[0] + total_margin_pixels w = crop_window_size * tiling[1] + total_margin_pixels - dummy_image = Image.new( - "RGB", (w, h), color="red" - ) + dummy_image = Image.new("RGB", (w, h), color="red") out = processor.process("dummy prompt", dummy_image) token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - out["input_ids"][:1 + max_llm_image_tokens]) + out["input_ids"][:1 + max_llm_image_tokens]) token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - max_llm_image_tokens - 1) + [0]) * (seq_len - max_llm_image_tokens - 1) dummy_seqdata = SequenceData(token_ids) dummy_imgdata = { "images": out["images"], @@ -913,16 +941,10 @@ def pad_images( image_masks: Optional[torch.Tensor] = None, ): n = max_total_crops - images.shape[0] - images = F.pad( - images, (0, 0, 0, 0, 0, n), value=-1 - ) - image_input_idx = F.pad( - image_input_idx, (0, 0, 0, n), value=-1 - ) + images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1) + image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1) if image_masks is not None: - image_masks = F.pad( - image_masks, (0, 0, 0, n), value=-1 - ) + image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1) return images, image_input_idx, image_masks @@ -930,14 +952,19 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): prompt = llm_inputs["prompt"] multi_modal_data = llm_inputs.get("multi_modal_data") image = multi_modal_data.get("image") - processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True, revision=ctx.model_config.code_revision) + processor = 
cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) - if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", prompt): + if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", + prompt): out = processor.process(prompt, image, message_format="none") elif prompt is not None: out = processor.process(prompt, image) else: - out = processor.process(None, image, tokens=llm_inputs["prompt_token_ids"]) + out = processor.process(None, + image, + tokens=llm_inputs["prompt_token_ids"]) image_processor = processor.image_processor max_total_crops = 1 + image_processor.max_crops @@ -959,14 +986,20 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): n_patches = image_num_patch[0] * image_num_patch[1] tokens_per_image = image_processor.image_token_length_w * image_processor.image_token_length_h images = torch.full( - (max_total_crops, n_patches, n_pixels), -1, dtype=torch.float32, + (max_total_crops, n_patches, n_pixels), + -1, + dtype=torch.float32, ) image_input_idx = torch.full( - (max_total_crops, tokens_per_image), -1, dtype=torch.int32, + (max_total_crops, tokens_per_image), + -1, + dtype=torch.int32, ) if image_processor.image_padding_mask: image_masks = torch.full( - (max_total_crops, n_patches), -1, dtype=torch.float32, + (max_total_crops, n_patches), + -1, + dtype=torch.float32, ) image_data = dict( @@ -975,9 +1008,10 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): ) if image_masks is not None: image_data["image_masks"] = image_masks - - image_data["seq_len"] = torch.tensor(len(out["input_ids"]), dtype=torch.long) - + + image_data["seq_len"] = torch.tensor(len(out["input_ids"]), + dtype=torch.long) + multi_modal_data = dict(image=image_data) return LLMInputs( @@ -994,19 +1028,20 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): class MolmoForCausalLM(nn.Module, SupportsMultiModal): def __init__( - self, + self, config: PretrainedConfig, multimodal_config: Optional[MultiModalConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[Mapping[str, Any]] = None, ) -> None: super().__init__() - + self.config = config self.multimodal_config = multimodal_config vision_config = VisionBackboneConfig() - self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config) + self.vision_backbone = MolmoVisionBackbone(config, vision_config, + quant_config) self.model = MolmoModel(config, cache_config, quant_config) if self.config.weight_tying: @@ -1018,7 +1053,8 @@ def __init__( quant_config=quant_config, ) - self.logits_processor = LogitsProcessor(config.embedding_size or config.vocab_size) + self.logits_processor = LogitsProcessor(config.embedding_size + or config.vocab_size) self.sampler = Sampler() def _parse_and_validate_image_input( @@ -1029,7 +1065,7 @@ def _parse_and_validate_image_input( image_masks = kwargs.pop("image_masks", None) if images is None: return None - + image_input_idx = kwargs.pop("image_input_idx", None) seq_len = kwargs.pop("seq_len", None) if image_input_idx is None: @@ -1038,26 +1074,26 @@ def _parse_and_validate_image_input( raise ValueError("seq_len is required for Molmo model.") if not isinstance(seq_len, torch.Tensor): seq_len = torch.tensor(seq_len) - + return MolmoImageInputs( images=images, image_input_idx=image_input_idx, seq_len=seq_len, image_masks=image_masks, ) - + def _process_image_input( self, image_input: MolmoImageInputs, ) -> torch.Tensor: - + image_features = 
self.vision_backbone( images=image_input["images"], image_masks=image_input["image_masks"], ) return image_features - + def _merge_multimodal_embeddings( self, inputs_embeds: torch.Tensor, @@ -1070,27 +1106,34 @@ def _merge_multimodal_embeddings( image_features = image_features.to(inputs_embeds.device) seq_len = seq_len.to(inputs_embeds.device) - + # insert the image feature into the embedding. - image_features = image_features.view(batch_size, num_image * num_patch, -1) - image_input_idx = image_input_idx.view(batch_size, num_image * num_patch) + image_features = image_features.view(batch_size, num_image * num_patch, + -1) + image_input_idx = image_input_idx.view(batch_size, + num_image * num_patch) valid = image_input_idx >= 0 - image_features = image_features * valid[:, :, None].to(image_features.dtype) - image_features = image_features.view(batch_size * num_image * num_patch, -1).contiguous() + image_features = image_features * valid[:, :, None].to( + image_features.dtype) + image_features = image_features.view( + batch_size * num_image * num_patch, -1).contiguous() image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) - offset = torch.cat([seq_len.new_zeros((1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None] + offset = torch.cat( + [seq_len.new_zeros( + (1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None] image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) image_input_idx = image_input_idx.flatten()[:, None] - mat = image_input_idx == torch.arange(seq_len.sum().item(), device=inputs_embeds.device)[None, :] + mat = image_input_idx == torch.arange( + seq_len.sum().item(), device=inputs_embeds.device)[None, :] mat = mat.to(image_features.dtype) - - inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md', image_features, mat) + + inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md', + image_features, mat) return inputs_embeds - def forward( self, input_ids: torch.LongTensor, @@ -1105,7 +1148,7 @@ def forward( if image_input is not None: inputs_embeds = self.model.embed_tokens(input_ids) image_features = self._process_image_input(image_input) - + inputs_embeds = self._merge_multimodal_embeddings( inputs_embeds, image_features, @@ -1116,7 +1159,7 @@ def forward( input_ids = None else: inputs_embeds = None - + hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -1126,7 +1169,7 @@ def forward( ) return hidden_states - + def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, @@ -1136,7 +1179,7 @@ def compute_logits(self, hidden_states: torch.Tensor, def sample( self, logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens @@ -1167,11 +1210,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if "wte.embedding" in name: embedding_weight["embedding"] = loaded_weight continue - + if "wte.new_embedding" in name: embedding_weight["new_embedding"] = loaded_weight continue - + if "vision_backbone" in name: if name.startswith("model"): name = name[len("model."):] @@ -1183,7 +1226,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): elif 'w2' in name: projector_weight['down_proj'] = loaded_weight else: - raise ValueError(f"Unexpected projector weight: {name}") + raise ValueError( + f"Unexpected projector weight: {name}") continue else: if 
"transformer.blocks" in name: @@ -1200,8 +1244,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace("ff_out", "mlp.down_proj") else: # lm head - name = name.replace("model.transformer.ff_out", "lm_head") - + name = name.replace("model.transformer.ff_out", + "lm_head") + else: for (param_name, weight_name) in params_mapping: if param_name in name: @@ -1216,15 +1261,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): except KeyError: print(params_dict.keys()) raise - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) try: weight_loader(param, loaded_weight) except: raise - + gate_up_proj_weight = torch.cat( - [projector_weight["gate_proj"], projector_weight["up_proj"]], dim=0 - ) + [projector_weight["gate_proj"], projector_weight["up_proj"]], + dim=0) name = "vision_backbone.image_projector.gate_up_proj.weight" param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -1237,9 +1283,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, down_proj_weight) embedding_weight = torch.cat( - [embedding_weight["embedding"], embedding_weight["new_embedding"]], dim=0 - ) + [embedding_weight["embedding"], embedding_weight["new_embedding"]], + dim=0) name = "model.embed_tokens.weight" param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, embedding_weight) \ No newline at end of file + weight_loader(param, embedding_weight) From 81c759d29c50990c025ca73c3af233ff49ddc498 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 13 Oct 2024 22:07:46 -0700 Subject: [PATCH 05/12] delete old example --- examples/offline_inference_molmo.py | 104 ---------------------------- 1 file changed, 104 deletions(-) delete mode 100644 examples/offline_inference_molmo.py diff --git a/examples/offline_inference_molmo.py b/examples/offline_inference_molmo.py deleted file mode 100644 index e7d4bea3e2fa..000000000000 --- a/examples/offline_inference_molmo.py +++ /dev/null @@ -1,104 +0,0 @@ -import argparse -import requests -from io import BytesIO -from PIL import Image, ImageFile - -from vllm import LLM -from vllm.sampling_params import SamplingParams - - -ImageFile.LOAD_TRUNCATED_IMAGES = True - - -def download_image(url: str): - # Send a GET request to the URL - response = requests.get(url) - - # Check if the request was successful - if response.status_code == 200: - # Open the image from the response content - image = Image.open(BytesIO(response.content)).convert("RGB") - - return image - else: - raise Exception(f"Failed to download image. Status code: {response.status_code}") - - -def vllm_generate(): - inputs = [ - { - "prompt": "Describe this image.", - "multi_modal_data": {"image": download_image("https://picsum.photos/id/9/1080/720")} - }, - { - "prompt": "Describe what you see in this image.", - "multi_modal_data": {"image": download_image("https://picsum.photos/id/23/1080/720")} - }, - ] - - outputs = llm.generate( - inputs, - sampling_params=sampling_params - ) - - for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) - - -def vllm_chat(): - url_1 = "https://picsum.photos/id/9/1080/720" - messages = [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "Describe the image." 
- }, - { - "type": "image_url", - "image_url": { - "url": url_1 - } - }, - ], - }, - { - "role": "assistant", - "content": "The image shows some objects.", - }, - { - "role": "user", - "content": "What objects do you exactly see in the image?", - }, - ] - - outputs = llm.chat( - messages, - sampling_params=sampling_params - ) - - for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="vllm example for Molmo-7B-O-0924, Molmo-7B-D-0924, Molmo-72B-0924" - ) - parser.add_argument("--model_path", type=str, default="allenai/Molmo-7B-D-0924") - args = parser.parse_args() - llm = LLM( - model=args.model_path, - trust_remote_code=True, - gpu_memory_utilization=0.95, - dtype="bfloat16", - ) - sampling_params = SamplingParams( - max_tokens=768, - temperature=0, - ) - vllm_generate() - vllm_chat() \ No newline at end of file From beecc01f9ffa2f40793734e1564c6e4d4de24f72 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 13 Oct 2024 22:09:04 -0700 Subject: [PATCH 06/12] add molmo to correct registry --- vllm/model_executor/models/registry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8caaab997466..b06d3d612dbc 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -104,6 +104,7 @@ "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 "MiniCPMV": ("minicpmv", "MiniCPMV"), + "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), From 877e71fb160e292db6db18d7915977af44961864 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 13 Oct 2024 22:22:35 -0700 Subject: [PATCH 07/12] a lot of format cleanup --- vllm/model_executor/models/molmo.py | 75 ++++++++++++++--------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index e4fac9900a08..503eaad54279 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1,58 +1,51 @@ -from array import array -from functools import lru_cache, partial -import re import logging -from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union, Any, Mapping, TypedDict -from PIL import Image import math +import re +from array import array +from dataclasses import dataclass +from functools import lru_cache, partial +from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict, + Union) import torch +from einops import rearrange +from PIL import Image from torch import nn from torch.nn import functional as F -from einops import rearrange from transformers import PretrainedConfig + import vllm.envs as envs from vllm.attention import Attention, AttentionMetadata from vllm.attention.selector import (_Backend, backend_name_to_enum, get_global_forced_attn_backend) from vllm.config import CacheConfig, MultiModalConfig -from vllm.distributed import ( - get_pp_group, - get_tensor_model_parallel_world_size, - get_tensor_model_parallel_rank, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, -) 
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - ColumnParallelLinear, +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.multimodal import ( - MULTIMODAL_REGISTRY, - MultiModalInputs, -) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.utils import make_layers from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import ( - VLLM_TOKEN_ID_ARRAY_TYPE, - SequenceData, -) -from vllm.model_executor import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processor import get_processor +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.utils import make_layers +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.platforms import current_platform +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) +from vllm.transformers_utils.processor import get_processor log = logging.getLogger(__name__) @@ -407,7 +400,8 @@ def __init__( assert self.total_num_heads % self.tp_size == 0 self.num_heads = self.total_num_heads // self.tp_size - self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads \ + or self.total_num_heads if self.total_num_kv_heads >= self.tp_size: assert self.total_num_kv_heads % self.tp_size == 0 else: @@ -690,7 +684,7 @@ def forward( self, images: torch.Tensor, image_masks: torch.Tensor ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) + # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501 batch_size, num_image = images.shape[:2] images = images.to(device=self.device, dtype=self.dtype) image_features = self.encode_image(images) @@ -760,7 +754,8 @@ def __init__( quant_config=quant_config, ) - decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer + decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \ + else MolmoDecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: decoder_layer(config, cache_config, quant_config), @@ -897,14 +892,14 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int, left_margin, 
right_margin = image_processor.overlap_margins max_crops = image_processor.max_crops - # Assume: prompt_token_ids always starts with bos_token_id followed image tokens + # Assume: prompt_token_ids always starts with bos_token_id followed image tokens # noqa: E501 max_llm_image_tokens = get_max_molmo_image_tokens(ctx) if seq_len - max_llm_image_tokens - 1 < 0: raise RuntimeError( f"Molmo cannot process {max_crops} crops in a prompt, " "please increase max_model_len or reduce number of crops") - # The vertical image has the maximum number of image tokens due to column tokens. + # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501 tiling = (max_crops, 1) total_margin_pixels = base_image_input_d * (right_margin + left_margin) crop_patches = image_processor.base_image_input_size[ @@ -984,7 +979,10 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): ) n_pixels = image_patch_size * image_patch_size * 3 n_patches = image_num_patch[0] * image_num_patch[1] - tokens_per_image = image_processor.image_token_length_w * image_processor.image_token_length_h + + image_length_w = image_processor.image_token_length_w + image_length_h = image_processor.image_token_length_h + tokens_per_image = image_length_w * image_length_h images = torch.full( (max_total_crops, n_patches, n_pixels), -1, @@ -1202,7 +1200,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): projector_weight = dict() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: - log.info(f"Skipping {name}") continue if self.config.tie_word_embeddings and "lm_head.weight" in name: continue From 46bad99485606a7b4c5a77d72b3ef9125a5d4e3b Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 13 Oct 2024 22:24:07 -0700 Subject: [PATCH 08/12] fix comment change --- examples/offline_inference_vision_language.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 8f35ca78a3b6..4c88dcc2f087 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -317,7 +317,7 @@ def run_molmo(question, modality): return llm, prompt, stop_token_ids -# GLM4V +# GLM-4v def run_glm4v(question: str, modality: str): assert modality == "image" model_name = "THUDM/glm-4v-9b" From ed41e0b5c69750bae248b77422852fb988ea651e Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 14 Oct 2024 00:48:02 -0700 Subject: [PATCH 09/12] add NOTE --- vllm/model_executor/models/molmo.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 503eaad54279..12588f8426b1 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -874,6 +874,8 @@ def get_max_molmo_image_tokens(ctx: InputContext) -> int: return max_llm_image_tokens +# NOTE: preprocessing for the image data has been included in the +# 'input_processor_for_molmo' function def image_input_mapper_for_molmo( ctx: InputContext, data: object, @@ -951,6 +953,9 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): trust_remote_code=True, revision=ctx.model_config.code_revision) + # NOTE: message formatting for raw text prompt is only applied for + # offline inference; for online inference, the prompt is always in + # instruction format and tokenized. 
if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", prompt): out = processor.process(prompt, image, message_format="none") From aff4040ccc2937865ee5c6bdbbe42e6ea4684458 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 14 Oct 2024 00:58:53 -0700 Subject: [PATCH 10/12] cleanup weight loading and error handling --- vllm/model_executor/models/molmo.py | 9 +++------ vllm/model_executor/models/qwen2_vl.py | 3 +-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 12588f8426b1..ccfee165368e 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1261,14 +1261,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] except KeyError: - print(params_dict.keys()) - raise + raise ValueError(f"Unexpected weight: {name}") from None + weight_loader = getattr(param, "weight_loader", default_weight_loader) - try: - weight_loader(param, loaded_weight) - except: - raise + weight_loader(param, loaded_weight) gate_up_proj_weight = torch.cat( [projector_weight["gate_proj"], projector_weight["up_proj"]], diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 24fd5152ecd0..4a39b3fbe5a4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1167,8 +1167,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue param = params_dict[name] except KeyError: - print(params_dict.keys()) - raise + raise ValueError(f"Unexpected weight: {name}") from None weight_loader = getattr(param, "weight_loader", default_weight_loader) From d48a28f5a238a493b2c57a8e1c2c549245d96531 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Mon, 14 Oct 2024 01:28:52 -0700 Subject: [PATCH 11/12] add to documentation --- docs/source/models/supported_models.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index bf86a72e20b5..926ffab6d928 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -399,6 +399,12 @@ Text Generation - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - + * - :code:`MolmoForCausalLM` + - Molmo + - Image + - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. 
+ - + - ✅︎ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - Image\ :sup:`E+` From e4c74b355d47fd5966a0aaef6fde2229e9a62db4 Mon Sep 17 00:00:00 2001 From: sanghol Date: Wed, 16 Oct 2024 00:22:53 +0000 Subject: [PATCH 12/12] bug fix for text-only examples --- vllm/model_executor/models/molmo.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index ccfee165368e..b04916f17088 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -946,9 +946,12 @@ def pad_images( def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): - prompt = llm_inputs["prompt"] - multi_modal_data = llm_inputs.get("multi_modal_data") - image = multi_modal_data.get("image") + prompt = llm_inputs.get("prompt", None) + multi_modal_data = llm_inputs.get("multi_modal_data", None) + if multi_modal_data is not None: + image = multi_modal_data.get("image", None) + else: + image = None processor = cached_get_processor(ctx.model_config.model, trust_remote_code=True, revision=ctx.model_config.code_revision)
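
A minimal offline-inference sketch follows, adapted from the example script removed in PATCH 05; the model name, dtype, and sampling parameters mirror that script, the image request follows its prompt/`multi_modal_data` shape, and the text-only request is an assumed usage that relies on the guard added in PATCH 12 (its prompt string is illustrative only).

# Sketch adapted from the deleted examples/offline_inference_molmo.py;
# settings mirror that script and are not prescriptive.
from io import BytesIO

import requests
from PIL import Image

from vllm import LLM
from vllm.sampling_params import SamplingParams

llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    gpu_memory_utilization=0.95,
    dtype="bfloat16",
)
sampling_params = SamplingParams(max_tokens=768, temperature=0)

# Image + text request, as in the original example script.
image = Image.open(
    BytesIO(requests.get("https://picsum.photos/id/9/1080/720").content)
).convert("RGB")
outputs = llm.generate(
    [{
        "prompt": "Describe this image.",
        "multi_modal_data": {"image": image},
    }],
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)

# Text-only request (illustrative prompt); this path depends on
# input_processor_for_molmo tolerating a missing multi_modal_data (PATCH 12).
outputs = llm.generate(
    [{"prompt": "User: Give a one-sentence definition of a transformer. Assistant:"}],
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)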