[Model] Support SDPA attention for Molmo vision backbone (#9410)
Isotr0py authored Oct 16, 2024
1 parent 59230ef commit cf1d62a
Showing 3 changed files with 61 additions and 78 deletions.
52 changes: 15 additions & 37 deletions vllm/model_executor/models/molmo.py
@@ -1,4 +1,3 @@
-import logging
 import math
 import re
 from array import array
@@ -14,10 +13,8 @@
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-import vllm.envs as envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -43,12 +40,11 @@
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.utils import make_layers
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
-from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor
 
-log = logging.getLogger(__name__)
+from .utils import get_vit_attn_backend
 
 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -190,35 +186,12 @@ def __init__(
         )
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.get_device_capability()[0] >= 8
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    log.warning(
-                        "Current Molmo implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Molmo does not support {selected_backend} backend now.")
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Molmo does not support {self.attn_backend} backend now.")
 
     def forward(self,
                 inputs_q: torch.Tensor,
@@ -240,10 +213,15 @@ def forward(self,
         xk = xk.view(*kv_shape)
         xv = xv.view(*kv_shape)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             from flash_attn import flash_attn_func
             output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
-        else:
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
+                          for x in (xq, xk, xv))
+            output = F.scaled_dot_product_attention(xq, xk, xv)
+            output = rearrange(output, "b h s d -> b s h d ")
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
 
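For readers skimming the Molmo hunk above: `torch.nn.functional.scaled_dot_product_attention` expects `(batch, heads, seq, head_dim)`, whereas the flash-attn and xformers paths operate on `(batch, seq, heads, head_dim)`, so the new TORCH_SDPA branch transposes in and out around the call. A minimal standalone sketch of just that branch, with made-up shapes (not vLLM code):

```python
# Standalone sketch of the SDPA branch added to Molmo's vision attention above.
# Shapes are hypothetical for illustration: batch=2, seq=16, heads=8, head_dim=64.
import torch
import torch.nn.functional as F
from einops import rearrange

xq = torch.randn(2, 16, 8, 64)  # (batch, seq, heads, head_dim), the layout
xk = torch.randn(2, 16, 8, 64)  # shared with the flash-attn / xformers paths
xv = torch.randn(2, 16, 8, 64)

# SDPA wants (batch, heads, seq, head_dim), so move the head axis forward...
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in (xq, xk, xv))
out = F.scaled_dot_product_attention(q, k, v)  # non-causal, no dropout
# ...and restore the original layout so downstream reshapes stay unchanged.
out = rearrange(out, "b h s d -> b s h d")
assert out.shape == (2, 16, 8, 64)
```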
52 changes: 12 additions & 40 deletions vllm/model_executor/models/qwen2_vl.py
@@ -39,10 +39,8 @@
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
 
-import vllm.envs as envs
 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
@@ -63,14 +61,13 @@
                              MultiModalInputs)
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import get_processor
-from vllm.utils import is_cpu
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, get_vit_attn_backend,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
 
 logger = init_logger(__name__)
@@ -215,37 +212,12 @@ def __init__(
                                      quant_config=quant_config)
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.has_device_capability(80)
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    logger.warning(
-                        "Current Qwen2-VL implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Qwen2-VL does not support {selected_backend} backend now."
-                )
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
     def forward(
         self,
@@ -274,7 +246,7 @@ def forward(
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             # flash_attn_varlen_func)
             from flash_attn import flash_attn_varlen_func
@@ -295,7 +267,7 @@ def forward(
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
-        elif is_cpu():
+        elif self.attn_backend == _Backend.TORCH_SDPA:
             seq_length = q.size(1)
             q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
             attention_mask = torch.zeros([1, seq_length, seq_length],
@@ -310,7 +282,7 @@
                                                     attention_mask,
                                                     dropout_p=0.0)
             context_layer = rearrange(output, "b h s d -> b s h d ")
-        else:
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
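The Qwen2-VL vision tower packs the patches of every image in the batch into one long sequence and tracks the per-image boundaries (its `cu_seqlens`, hidden in the collapsed lines above), so the TORCH_SDPA branch builds a block-diagonal boolean mask to keep patches attending only within their own image. A simplified, self-contained sketch of that idea; names and shapes are illustrative, not the exact vLLM code:

```python
# Simplified sketch of running SDPA over packed variable-length sequences,
# in the spirit of the TORCH_SDPA branch above. cu_seqlens here is a plain
# Python list of cumulative boundaries, e.g. [0, 5, 12] for two sequences
# of lengths 5 and 7.
import torch
import torch.nn.functional as F

def packed_sdpa(q, k, v, cu_seqlens):
    # q, k, v: (1, heads, total_seq, head_dim) - all sequences concatenated.
    total = q.size(-2)
    # Boolean block-diagonal mask: tokens may only attend within their own sequence.
    mask = torch.zeros(1, total, total, dtype=torch.bool, device=q.device)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[..., start:end, start:end] = True
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = k = v = torch.randn(1, 8, 12, 64)
out = packed_sdpa(q, k, v, [0, 5, 12])
print(out.shape)  # torch.Size([1, 8, 12, 64])
```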
35 changes: 34 additions & 1 deletion vllm/model_executor/models/utils.py
@@ -8,15 +8,22 @@
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
+from vllm.attention.selector import (_Backend, backend_name_to_enum,
+                                     get_global_forced_attn_backend)
 from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available
+from vllm.utils import is_cpu, is_pin_memory_available
 
+logger = init_logger(__name__)
+
 WeightsMapping = Mapping[str, Optional[str]]
 """If a key maps to a value of `None`, the corresponding weight is ignored."""
@@ -487,3 +494,29 @@ def __getattr__(self, key: str):
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         llm = super().__getattr__(self.model_name)
         return llm(*args, **kwargs)
+
+
+def get_vit_attn_backend() -> _Backend:
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs, use xformers instead.
+        device_available = current_platform.has_device_capability(80)
+        if device_available:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                logger.warning(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif is_cpu():
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
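The new `get_vit_attn_backend()` helper centralizes the selection logic the two model files previously duplicated: a globally forced backend or `VLLM_ATTENTION_BACKEND` takes precedence, otherwise flash-attn is chosen on Ampere-or-newer GPUs when installed, SDPA on CPU, and xformers everywhere else. A condensed sketch of the caller-side pattern; the class below is hypothetical, the real callers being the Molmo and Qwen2-VL vision attention modules changed above:

```python
# Hypothetical caller mirroring the pattern used by the Molmo and Qwen2-VL
# vision attention classes in this commit.
from vllm.attention.selector import _Backend
from vllm.model_executor.models.utils import get_vit_attn_backend


class VisionAttentionStub:
    """Illustrative stand-in for a vision attention module."""

    _SUPPORTED = {_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS}

    def __init__(self) -> None:
        # Resolve the backend once at construction time; forward() then
        # dispatches on self.attn_backend, as in the diffs above.
        self.attn_backend: _Backend = get_vit_attn_backend()
        if self.attn_backend not in self._SUPPORTED:
            raise RuntimeError(
                f"{self.attn_backend} is not supported for this vision backbone.")
```

Because the forced/global backend is consulted first, setting e.g. `VLLM_ATTENTION_BACKEND=XFORMERS` still overrides the automatic choice for these vision backbones.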
