[VLM][Model] TP support for ViTs (vllm-project#7186)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
3 people authored and siddharth9820 committed Sep 30, 2024
1 parent 293fce0 commit 8e7e982
Showing 9 changed files with 340 additions and 289 deletions.
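The change shards the ViT attention heads across tensor-parallel ranks by swapping the Hugging Face attention modules for vLLM's QKVParallelLinear and RowParallelLinear layers (see the new BlipAttention and CLIPAttention classes in the diff below). As a minimal sketch of the head-partition arithmetic, with illustrative sizes roughly matching CLIP ViT-L/14 rather than values taken from this diff:

# Minimal sketch of the head partitioning used by the new attention classes.
# Sizes are illustrative (roughly CLIP ViT-L/14), not taken from this commit.
hidden_size = 1024          # embed_dim of the vision tower
num_heads = 16              # total attention heads
head_dim = hidden_size // num_heads

for tp_size in (1, 2, 4):
    assert num_heads % tp_size == 0, "heads must divide evenly across ranks"
    heads_per_rank = num_heads // tp_size
    # QKVParallelLinear produces the fused [Q | K | V] for this rank's heads only.
    qkv_out_per_rank = 3 * heads_per_rank * head_dim
    print(f"tp_size={tp_size}: {heads_per_rank} heads/rank, "
          f"fused QKV width {qkv_out_per_rank}")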
3 changes: 1 addition & 2 deletions tests/models/test_intern_vit.py
@@ -6,8 +6,6 @@
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

from vllm.model_executor.models.intern_vit import InternVisionModel

from ..conftest import _ImageAssets, cleanup

pytestmark = pytest.mark.vlm
@@ -49,6 +47,7 @@ def run_intern_vit_test(
for pixel_value in pixel_values
]

from vllm.model_executor.models.intern_vit import InternVisionModel
vllm_model = InternVisionModel(config)
vllm_model.load_weights(hf_model.state_dict().items())

63 changes: 31 additions & 32 deletions tests/models/test_internvl.py
@@ -6,9 +6,6 @@
from PIL.Image import Image
from transformers import AutoConfig

from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
IMG_START,
image_to_pixel_values)
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu

@@ -33,35 +30,6 @@
]


class InternVLProcessor:
"""A simple processor for InternVL2 HF model which misses a processor."""

def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype

self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Image, **kwargs):
pixel_values = image_to_pixel_values(images, self.image_size,
self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate(
self,
@@ -127,6 +95,37 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

class InternVLProcessor:
"""A simple processor for InternVL2 which misses a processor."""

def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype

self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Image, **kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
pixel_values = image_to_pixel_values(
images, self.image_size, self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
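For reference, here is a standalone sketch of the '<image>' expansion that InternVLProcessor.__call__ performs in the test above. The constants are hard-coded stand-ins for the IMG_START/IMG_END/IMG_CONTEXT values imported from vllm.model_executor.models.internvl, and the patch count and num_image_token are illustrative, so the snippet runs without vLLM installed:

# Standalone sketch of the '<image>' expansion done by InternVLProcessor.
# Constants mirror vllm.model_executor.models.internvl but are hard-coded
# here; the numbers are illustrative placeholders.
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"
num_image_token = 256                 # context tokens per image patch
num_patches_list = [2]                # one image split into 2 dynamic patches

text = "USER: <image>\nDescribe the picture. ASSISTANT:"
for num_patches in num_patches_list:
    context_tokens = IMG_CONTEXT * num_image_token * num_patches
    image_tokens = IMG_START + context_tokens + IMG_END
    text = text.replace("<image>", image_tokens, 1)

# The placeholder is now 2 * 256 repeated context tokens wrapped in image tags.
print(text.count(IMG_CONTEXT))        # 512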
79 changes: 76 additions & 3 deletions vllm/model_executor/models/blip.py
@@ -7,12 +7,14 @@
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention
from xformers import ops as xops

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -154,6 +156,77 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class BlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout

self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
)

self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()

qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)

return attn_output


class BlipMLP(nn.Module):

def __init__(self,
@@ -188,7 +261,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = BlipAttention(config)
self.self_attn = BlipAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -199,7 +272,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
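A shape-only sketch of the forward pass in the new attention classes on a single rank: the fused QKV projection is chunked into Q, K, and V and viewed as [batch, seq, heads_per_rank, head_dim] before attention. torch.nn.functional.scaled_dot_product_attention is used here as a stand-in for xops.memory_efficient_attention_forward so the snippet runs without xFormers or any TP setup; all dimensions are illustrative:

import torch
import torch.nn.functional as F

# Shape-only sketch of the new attention forward on one TP rank.
bsz, seq_len = 2, 577                 # 577 = 1 CLS token + 24*24 patches
heads_per_rank, head_dim = 8, 64      # e.g. 16 heads split across tp_size=2

# qkv_states mimics the output of QKVParallelLinear on this rank:
# fused [Q | K | V] along the last dimension.
qkv_states = torch.randn(bsz, seq_len, 3 * heads_per_rank * head_dim)
q, k, v = qkv_states.chunk(3, dim=-1)
q, k, v = (t.view(bsz, seq_len, heads_per_rank, head_dim) for t in (q, k, v))

# xops.memory_efficient_attention_forward takes [batch, seq, heads, head_dim];
# SDPA wants [batch, heads, seq, head_dim], hence the transposes. Its default
# scale of 1/sqrt(head_dim) matches the module's self.scale.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
).transpose(1, 2)

out = out.reshape(bsz, seq_len, -1)   # -> [bsz, seq_len, heads_per_rank * head_dim]
print(out.shape)                      # torch.Size([2, 577, 512])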
3 changes: 1 addition & 2 deletions vllm/model_executor/models/blip2.py
@@ -714,8 +714,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
# We only do sharding for language model and
# not vision model for now.
# BlipVisionModel does not need sharding
use_default_weight_loading = True
else:
for (param_name, weight_name,
105 changes: 98 additions & 7 deletions vllm/model_executor/models/clip.py
@@ -7,12 +7,14 @@
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention
from xformers import ops as xops

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -160,6 +162,78 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout

self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
)

self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
)

self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()

qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)

return attn_output


class CLIPMLP(nn.Module):

def __init__(self,
@@ -192,7 +266,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = CLIPAttention(config)
self.self_attn = CLIPAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -204,7 +278,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
@@ -304,7 +378,15 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None):
def device(self):
return next(self.parameters()).device

# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)

@@ -318,7 +400,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if layer_idx >= layer_count:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
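The stacked_params_mapping loop added above is what routes the separate q_proj/k_proj/v_proj checkpoint tensors into the fused qkv_proj parameter. A toy illustration of just the name rewriting, with hypothetical checkpoint names and the actual weight_loader calls replaced by prints:

# Toy illustration of how stacked_params_mapping rewrites checkpoint names
# so separate q/k/v weights land in the fused qkv_proj parameter.
# The checkpoint names below are hypothetical examples.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_names = [
    "vision_model.encoder.layers.0.self_attn.q_proj.weight",
    "vision_model.encoder.layers.0.self_attn.k_proj.weight",
    "vision_model.encoder.layers.0.self_attn.out_proj.weight",
]

for name in checkpoint_names:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        target = name.replace(weight_name, param_name)
        print(f"{name} -> {target} (shard {shard_id!r})")
        break
    else:
        # Non-stacked weights (e.g. out_proj) load into a same-named parameter.
        print(f"{name} -> {name} (default loader)")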