[VLM][Model] TP support for ViTs #7186

Merged
merged 33 commits on Aug 30, 2024
Commits (33)
0919bd1
feat: replace siglipattention with tp'ed one
ChristopherCho Aug 6, 2024
7cfc98c
feat: tp blip attention
ChristopherCho Aug 6, 2024
079c53f
feat: clip attention replaced
ChristopherCho Aug 6, 2024
9a9af50
fix: style
ChristopherCho Aug 6, 2024
8176c8e
fix: provide quantization config
ChristopherCho Aug 6, 2024
b3bdbef
fix: return value of attention
ChristopherCho Aug 6, 2024
8927414
fix: add tp config in clip attention
ChristopherCho Aug 6, 2024
22a0a84
feat: tp attention in intern vit
ChristopherCho Aug 6, 2024
f6063b1
fix: style fix
ChristopherCho Aug 6, 2024
8e54ef6
feat: weight loading for clip based models
ChristopherCho Aug 6, 2024
87043e4
feat: weight loading for siglip based models
ChristopherCho Aug 6, 2024
2c18f3a
feat: weight loading for blip based models
ChristopherCho Aug 6, 2024
c6015f5
fix: bug in clip weight loading
ChristopherCho Aug 6, 2024
be0e190
fix: bug in clip attention
ChristopherCho Aug 6, 2024
0121445
fix: bug in blip attention
ChristopherCho Aug 6, 2024
414040f
fix: blip does not require sharding
ChristopherCho Aug 6, 2024
f28aec3
fix: style
ChristopherCho Aug 7, 2024
734fcb1
fix: phi3v weight loading logic fixed
ChristopherCho Aug 7, 2024
f1329c9
fix: make intern vit work
ChristopherCho Aug 7, 2024
c776662
fix: fix for tp input
ChristopherCho Aug 7, 2024
5ad0d22
fix: minor refactoring
ChristopherCho Aug 8, 2024
0a9968e
Merge branch 'main' into tp-support-for-vit
ywang96 Aug 26, 2024
70580a6
format
ywang96 Aug 27, 2024
0593206
cleanup TODO
ywang96 Aug 28, 2024
ccc49dc
Merge branch 'main' into tp-support-for-vit
ChristopherCho Aug 30, 2024
31af673
doc: Todo for adding prefix in clip load weights
ChristopherCho Aug 30, 2024
3652cb9
Fix: use view rather than reshape and contiguous
ChristopherCho Aug 30, 2024
21a146d
Fix: use view rather than reshape and contiguous in clip
ChristopherCho Aug 30, 2024
2d1f639
Fix: use view rather than reshape and contiguous in siglip
ChristopherCho Aug 30, 2024
300e8a9
Merge branch 'tp-support-for-vit' of https://github.com/ChristopherCh…
ChristopherCho Aug 30, 2024
659adc5
feat: option for disabling bias in blip
ChristopherCho Aug 30, 2024
ebf3503
Merge branch 'vllm-project:main' into tp-support-for-vit
ChristopherCho Aug 30, 2024
ffb176b
patch internvl
ywang96 Aug 30, 2024
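
The pattern these commits apply across the CLIP, SigLIP, BLIP, and InternViT vision encoders is to shard attention heads across tensor-parallel ranks: a fused QKVParallelLinear owns a column slice of the q/k/v projections (laid out as a contiguous [q | k | v] block per rank), attention runs over the local heads only, and a RowParallelLinear output projection all-reduces the partial results. The following is a minimal, framework-free sketch of the per-rank shape arithmetic; the dimensions are made-up example values, and a plain nn.Linear plus an explicit softmax stand in for vLLM's parallel layers and the xformers kernel used in the actual diff.

# Sketch only: emulates the per-rank shapes of the TP'ed ViT attention in this
# PR, without vLLM's QKVParallelLinear/RowParallelLinear layers.
import torch

embed_dim, num_heads, tp_size = 1024, 16, 4           # example values
head_dim = embed_dim // num_heads
assert num_heads % tp_size == 0
heads_per_rank = num_heads // tp_size                  # divide(num_heads, tp_size)

bsz, seq_len = 2, 577                                  # e.g. ViT patch tokens + CLS
hidden_states = torch.randn(bsz, seq_len, embed_dim)

# Each rank owns a column slice of the fused QKV projection:
# output width = 3 * heads_per_rank * head_dim, laid out as [q | k | v].
qkv_rank = torch.nn.Linear(embed_dim, 3 * heads_per_rank * head_dim)
q, k, v = qkv_rank(hidden_states).chunk(3, dim=-1)     # matches qkv_states.chunk(3, dim=-1)
q = q.view(bsz, seq_len, heads_per_rank, head_dim)
k = k.view(bsz, seq_len, heads_per_rank, head_dim)
v = v.view(bsz, seq_len, heads_per_rank, head_dim)

# Attention over the local heads only (xformers' memory_efficient_attention in
# the real code); the row-parallel out_proj then all-reduces across ranks.
attn = torch.softmax(
    torch.einsum("bqhd,bkhd->bhqk", q, k) * head_dim**-0.5, dim=-1)
out = torch.einsum("bhqk,bkhd->bqhd", attn, v).reshape(bsz, seq_len, -1)
print(out.shape)  # torch.Size([2, 577, 256]) == heads_per_rank * head_dim per rank
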
3 changes: 1 addition & 2 deletions tests/models/test_intern_vit.py
@@ -6,8 +6,6 @@
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

from vllm.model_executor.models.intern_vit import InternVisionModel

from ..conftest import _ImageAssets, cleanup

pytestmark = pytest.mark.vlm
@@ -49,6 +47,7 @@ def run_intern_vit_test(
for pixel_value in pixel_values
]

from vllm.model_executor.models.intern_vit import InternVisionModel
vllm_model = InternVisionModel(config)
vllm_model.load_weights(hf_model.state_dict().items())

63 changes: 31 additions & 32 deletions tests/models/test_internvl.py
@@ -6,9 +6,6 @@
from PIL.Image import Image
from transformers import AutoConfig

from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
IMG_START,
image_to_pixel_values)
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu

@@ -33,35 +30,6 @@
]


class InternVLProcessor:
"""A simple processor for InternVL2 HF model which misses a processor."""

def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype

self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Image, **kwargs):
pixel_values = image_to_pixel_values(images, self.image_size,
self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate(
self,
@@ -127,6 +95,37 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

class InternVLProcessor:
"""A simple processor for InternVL2 which misses a processor."""

def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype

self.config = AutoConfig.from_pretrained(hf_runner.model_name)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size

def __call__(self, text: str, images: Image, **kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
pixel_values = image_to_pixel_values(
images, self.image_size, self.min_num, self.max_num,
self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
79 changes: 76 additions & 3 deletions vllm/model_executor/models/blip.py
@@ -7,12 +7,14 @@
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention
from xformers import ops as xops

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -154,6 +156,77 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class BlipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout

self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
self.projection = RowParallelLinear(
self.embed_dim,
self.embed_dim,
quant_config=quant_config,
)

self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()

qkv_states, _ = self.qkv(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.projection(out)

return attn_output


class BlipMLP(nn.Module):

def __init__(self,
@@ -188,7 +261,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = BlipAttention(config)
self.self_attn = BlipAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -199,7 +272,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
3 changes: 1 addition & 2 deletions vllm/model_executor/models/blip2.py
@@ -714,8 +714,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
# We only do sharding for language model and
# not vision model for now.
# BlipVisionModel does not need sharding
use_default_weight_loading = True
else:
for (param_name, weight_name,
105 changes: 98 additions & 7 deletions vllm/model_executor/models/clip.py
@@ -7,12 +7,14 @@
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPAttention
from xformers import ops as xops

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -160,6 +162,78 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return embeddings


class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(
self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout

self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
)

self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
)

self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()

def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()

qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)

query_states = query_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(bsz, tgt_len,
self.num_heads_per_partition,
self.head_dim)

out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(bsz, tgt_len, -1)
attn_output, _ = self.out_proj(out)

return attn_output


class CLIPMLP(nn.Module):

def __init__(self,
@@ -192,7 +266,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()

self.self_attn = CLIPAttention(config)
self.self_attn = CLIPAttention(config, quant_config=quant_config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -204,7 +278,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states

residual = hidden_states
@@ -304,7 +378,15 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None):
def device(self):
return next(self.parameters()).device

# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)

@@ -318,7 +400,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if layer_idx >= layer_count:
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
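
The stacked_params_mapping introduced above exists because the checkpoint still stores separate q_proj/k_proj/v_proj tensors while the model now has a single fused qkv_proj parameter; each matching weight name is rewritten to the fused parameter and handed to that parameter's weight_loader together with a shard id. Below is a minimal sketch of just that name-routing step; the actual copy into the fused tensor is performed by vLLM's weight_loader and is not reproduced here.

# Sketch only: mirrors the name routing in the load_weights change above.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def resolve(name: str):
    """Return (parameter name to load into, shard_id or None)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # unfused weights fall through to default_weight_loader

print(resolve("vision_model.encoder.layers.0.self_attn.q_proj.weight"))
# -> ('vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'q')
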