From ec266536b7c4d4d308566ac928a69fcb9ef94462 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Tue, 3 Sep 2024 21:37:52 +0800
Subject: [PATCH] [Bugfix][VLM] Add fallback to SDPA for ViT model running on CPU backend (#8061)

---
 vllm/model_executor/models/blip.py       | 25 ++++++--
 vllm/model_executor/models/clip.py       | 28 +++++++--
 vllm/model_executor/models/intern_vit.py | 79 +++++++++++++++++++++---
 vllm/model_executor/models/paligemma.py  | 42 +++++++------
 vllm/model_executor/models/siglip.py     | 27 ++++++--
 5 files changed, 157 insertions(+), 44 deletions(-)

diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index e6acf8cd5d5b..583d5d217903 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 from PIL import Image
 from transformers import Blip2VisionConfig, BlipVisionConfig
-from xformers import ops as xops
+from transformers.models.blip.modeling_blip import BlipAttention
 
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -21,6 +21,12 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
 
+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 
 def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
@@ -156,7 +162,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings
 
 
-class BlipAttention(nn.Module):
+class BlipParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -224,7 +230,7 @@ def forward(
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.projection(out)
 
-        return attn_output
+        return attn_output, None
 
 
 class BlipMLP(nn.Module):
@@ -261,7 +267,16 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
-        self.self_attn = BlipAttention(config, quant_config=quant_config)
+        # fallback to sdpa attention if tp unavailable
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.self_attn = BlipParallelAttention(config,
+                                                   quant_config=quant_config)
+        else:
+            # Blip doesn't have SDPA attention implemented in transformers
+            # use eager attention instead for cpu backend
+            self.self_attn = BlipAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = BlipMLP(config, quant_config=quant_config)
@@ -272,7 +287,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
 
         residual = hidden_states
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index ddfec91d6cab..b581a501e333 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 from PIL import Image
 from transformers import CLIPVisionConfig
-from xformers import ops as xops
+from transformers.models.clip.modeling_clip import CLIPSdpaAttention
 
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -22,6 +22,12 @@
                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
 
+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 
 def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
@@ -162,7 +168,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings
 
 
-class CLIPAttention(nn.Module):
+class CLIPParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -231,7 +237,7 @@ def forward(
         out = out.view(bsz, tgt_len, -1)
         attn_output, _ = self.out_proj(out)
 
-        return attn_output
+        return attn_output, None
 
 
 class CLIPMLP(nn.Module):
@@ -266,7 +272,13 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
-        self.self_attn = CLIPAttention(config, quant_config=quant_config)
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.self_attn = CLIPParallelAttention(config,
+                                                   quant_config=quant_config)
+        else:
+            self.self_attn = CLIPSdpaAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
         self.mlp = CLIPMLP(config, quant_config=quant_config)
@@ -278,7 +290,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
 
         residual = hidden_states
@@ -365,6 +377,10 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  num_hidden_layers_override: Optional[int] = None):
         super().__init__()
+        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = config.num_attention_heads
+        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
+
         self.vision_model = CLIPVisionTransformer(
             config=config,
             quant_config=quant_config,
@@ -386,7 +402,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-        ]
+        ] if self.shard_weight else []
         params_dict = dict(self.named_parameters())
         layer_count = len(self.vision_model.encoder.layers)
 
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index ad5919150cad..33b4a3acaa55 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -10,7 +10,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
-from xformers import ops as xops
 
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -21,6 +20,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 NORM2FN = {
     'rms_norm': RMSNorm,
     'layer_norm': nn.LayerNorm,
 }
@@ -81,7 +86,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         return embeddings
 
 
-class InternAttention(nn.Module):
+class InternParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
@@ -140,18 +145,67 @@ def forward(self, x):
             k = self.k_norm.forward_native(k.flatten(-2,
                                                      -1)).view(B_, N_, H_, D_)
 
-        x = xops.memory_efficient_attention_forward(
-            q,
-            k,
-            v,
-            scale=self.scale,
-        )
+        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)
 
         x, _ = self.proj(x)
         return x
 
 
+class InternSdpaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f'embed_dim must be divisible by num_heads '
+                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
+                f' {self.num_heads}).')
+
+        self.scale = self.head_dim**-0.5
+        self.qkv = nn.Linear(self.embed_dim,
+                             3 * self.embed_dim,
+                             bias=config.qkv_bias)
+
+        self.qk_normalization = config.qk_normalization
+
+        if self.qk_normalization:
+            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x)
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        q = q.view(B, N, self.num_heads, self.head_dim)
+        k = k.view(B, N, self.num_heads, self.head_dim)
+        v = v.view(B, N, self.num_heads, self.head_dim)
+
+        if self.qk_normalization:
+            B_, N_, H_, D_ = q.shape
+            q = self.q_norm.forward_native(q.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+            k = self.k_norm.forward_native(k.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
+        x = x.transpose(1, 2).view(B, N, -1)
+
+        x = self.proj(x)
+        return x
+
+
 class InternMLP(nn.Module):
 
     def __init__(self,
@@ -187,7 +241,14 @@ def __init__(self,
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
 
-        self.attn = InternAttention(config, quant_config=quant_config)
+        # fallback to sdpa attention if tp unavailable
+        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = config.num_attention_heads
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.attn = InternParallelAttention(config,
+                                                quant_config=quant_config)
+        else:
+            self.attn = InternSdpaAttention(config)
         self.mlp = InternMLP(config, quant_config=quant_config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 9b29ff69808a..b6f4275fbc94 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -307,26 +307,30 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
             use_default_weight_loading = False
-            for (param_name, shard_name, shard_id) in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
+            if "vision" not in name or self.vision_tower.shard_weight:
+                for (param_name, shard_name,
+                     shard_id) in stacked_params_mapping:
+                    if shard_name not in name:
+                        continue
+                    name = name.replace(shard_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    # lm_head is not used in vllm as it is tied with
+                    # embed_token. To prevent errors, skip loading
+                    # lm_head.weight.
+                    if "lm_head.weight" in name:
+                        continue
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    use_default_weight_loading = True
             else:
-                # lm_head is not used in vllm as it is tied with
-                # embed_token. To prevent errors, skip loading
-                # lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
                 use_default_weight_loading = True
 
             if use_default_weight_loading:
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index e6f95af0ff49..114dbf09b0c5 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -9,7 +9,7 @@
 from PIL import Image
 from torch import nn
 from transformers import SiglipVisionConfig
-from xformers import ops as xops
+from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
 
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -26,6 +26,12 @@
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
 
+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 
 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     # Since interpolation is applied, the image size need not be divisible
@@ -219,7 +225,7 @@ def forward(self,
         return embeddings
 
 
-class SiglipAttention(nn.Module):
+class SiglipParallelAttention(nn.Module):
 
     def __init__(
@@ -282,7 +288,7 @@ def forward(
         out = out.view(batch_size, q_len, -1)
         attn_output, _ = self.out_proj(out)
 
-        return attn_output
+        return attn_output, None
 
 
 class SiglipMLP(nn.Module):
@@ -327,7 +333,14 @@ def __init__(
         super().__init__()
         self.embed_dim = config.hidden_size
 
-        self.self_attn = SiglipAttention(config, quant_config=quant_config)
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
+            self.self_attn = SiglipParallelAttention(config,
+                                                     quant_config=quant_config)
+        else:
+            self.self_attn = SiglipSdpaAttention(config)
+
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -344,7 +357,7 @@ def forward(
         residual = hidden_states
 
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
 
         residual = hidden_states
@@ -476,6 +489,10 @@ def __init__(
         num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
+        num_heads = config.num_attention_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
+
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,
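
Note: every vision encoder touched in this patch applies the same selection rule. The tensor-parallel xformers attention is used only when xformers imports successfully and the head count divides evenly across tensor-parallel ranks; otherwise the layer falls back to the SDPA module from transformers (or, for BLIP, to the eager BlipAttention, since transformers has no SDPA implementation for it). CLIPVisionModel and SiglipVisionModel also record the decision as self.shard_weight so load_weights only applies the qkv_proj stacking when the sharded attention is in use. Below is a minimal, self-contained sketch of that rule; it is not code from the patch, and select_attention and its arguments are illustrative names.

# Standalone sketch of the fallback rule used throughout this patch
# (illustrative names, not part of the vLLM source).
try:
    from xformers import ops as xops  # noqa: F401  (only importability matters here)
    USE_XFORMERS_OPS = True
except ImportError:
    # e.g. a CPU-only install: xformers is unavailable, so SDPA must be used.
    USE_XFORMERS_OPS = False


def select_attention(num_attention_heads: int, tp_size: int) -> str:
    """Mirror the per-layer check: sharded xformers attention vs. SDPA fallback."""
    if USE_XFORMERS_OPS and num_attention_heads % tp_size == 0:
        return "parallel (xformers memory-efficient attention)"
    return "fallback (torch scaled_dot_product_attention / eager)"


if __name__ == "__main__":
    # A ViT with 16 heads can shard its attention for tp_size 1, 2, or 4,
    # but falls back for tp_size 3 or whenever xformers cannot be imported.
    for tp_size in (1, 3, 4):
        print(tp_size, "->", select_attention(16, tp_size))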