[Model] Support SDPA attention for Molmo vision backbone #9410

Merged · 6 commits · Oct 16, 2024 · Changes shown from 4 commits
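For context, a minimal way to exercise the new code path end to end is to load a Molmo checkpoint through vLLM's offline API; the vision backbone then picks its attention backend via the selection logic changed below. This is an illustrative sketch, not part of the PR: the model ID, prompt template, and image path are assumptions.

```python
# Illustrative only (not part of this PR): run Molmo through vLLM so the vision
# backbone exercises the backend selection changed below. The exact prompt
# template for Molmo may differ; consult the model card / vLLM examples.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/Molmo-7B-D-0924", trust_remote_code=True)
image = Image.open("example.jpg")  # any local image

outputs = llm.generate(
    {
        "prompt": "USER: <image>\nDescribe this image. ASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```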
52 changes: 19 additions & 33 deletions vllm/model_executor/models/molmo.py
@@ -14,10 +14,8 @@
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-import vllm.envs as envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -42,11 +40,12 @@
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.utils import make_layers
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
-from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor
 
+from .utils import get_vit_attn_backend
+
 log = logging.getLogger(__name__)
 
 # TODO: hard-coded for now. Consider making it configurable.
@@ -189,35 +188,17 @@ def __init__(
         )
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.get_device_capability()[0] >= 8
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    log.warning(
-                        "Current Molmo implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
+        self._use_flash_attn = self._use_sdpa = self._use_xformers = False
+        selected_backend: _Backend = get_vit_attn_backend()
+        if selected_backend == _Backend.FLASH_ATTN:
+            self._use_flash_attn = True
+        elif selected_backend == _Backend.XFORMERS:
+            self._use_xformers = True
+        elif selected_backend == _Backend.TORCH_SDPA:
+            self._use_sdpa = True
         else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Molmo does not support {selected_backend} backend now.")
+            raise RuntimeError(
+                f"Molmo does not support {selected_backend} backend now.")
 
     def forward(self,
                 inputs_q: torch.Tensor,
@@ -242,7 +223,12 @@ def forward(self,
         if self._use_flash_attn:
             from flash_attn import flash_attn_func
             output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
-        else:
+        elif self._use_sdpa:
+            xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
+                          for x in (xq, xk, xv))
+            output = F.scaled_dot_product_attention(xq, xk, xv)
+            output = rearrange(output, "b h s d -> b s h d ")
+        elif self._use_xformers:
             from xformers import ops as xops
             output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
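The new `elif self._use_sdpa:` branch keeps the (batch, seq, heads, head_dim) layout used by the flash-attn and xformers paths and only transposes around the PyTorch kernel, because `F.scaled_dot_product_attention` expects (batch, heads, seq, head_dim). A standalone sketch of the same pattern (function name and tensor shapes are illustrative, not from the PR):

```python
# Standalone sketch of the SDPA path added above: inputs use the
# (batch, seq, heads, head_dim) layout shared with flash-attn/xformers; we
# transpose only around torch's kernel, which wants (batch, heads, seq, head_dim).
import torch
import torch.nn.functional as F
from einops import rearrange


def sdpa_vision_attention(xq: torch.Tensor, xk: torch.Tensor,
                          xv: torch.Tensor) -> torch.Tensor:
    xq, xk, xv = (rearrange(x, "b s h d -> b h s d") for x in (xq, xk, xv))
    out = F.scaled_dot_product_attention(xq, xk, xv)  # full, non-causal attention
    return rearrange(out, "b h s d -> b s h d")


# Example: 2 images, 576 patches, 16 heads of dim 64 (shapes are illustrative).
q = k = v = torch.randn(2, 576, 16, 64)
print(sdpa_vision_attention(q, k, v).shape)  # torch.Size([2, 576, 16, 64])
```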
53 changes: 15 additions & 38 deletions vllm/model_executor/models/qwen2_vl.py
@@ -37,10 +37,8 @@
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
 
-import vllm.envs as envs
 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
@@ -60,15 +58,14 @@
                              MultiModalInputs)
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
                                                      Qwen2VLVisionConfig)
 from vllm.transformers_utils.processor import get_processor
-from vllm.utils import is_cpu
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, get_vit_attn_backend,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
 
 logger = init_logger(__name__)
@@ -213,37 +210,17 @@ def __init__(
                                       quant_config=quant_config)
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.has_device_capability(80)
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    logger.warning(
-                        "Current Qwen2-VL implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
+        self._use_flash_attn = self._use_sdpa = self._use_xformers = False
+        selected_backend: _Backend = get_vit_attn_backend()
+        if selected_backend == _Backend.FLASH_ATTN:
+            self._use_flash_attn = True
+        elif selected_backend == _Backend.XFORMERS:
+            self._use_xformers = True
+        elif selected_backend == _Backend.TORCH_SDPA:
+            self._use_sdpa = True
         else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Qwen2-VL does not support {selected_backend} backend now."
-                )
+            raise RuntimeError(
+                f"Qwen2-VL does not support {selected_backend} backend now.")
 
     def forward(
         self,
@@ -293,7 +270,7 @@ def forward(
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
-        elif is_cpu():
+        elif self._use_sdpa:
             seq_length = q.size(1)
             q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
             attention_mask = torch.zeros([1, seq_length, seq_length],
@@ -308,7 +285,7 @@
                                                      attention_mask,
                                                      dropout_p=0.0)
             context_layer = rearrange(output, "b h s d -> b s h d ")
-        else:
+        elif self._use_xformers:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
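In Qwen2-VL the SDPA branch (formerly gated on `is_cpu()`) also passes an `attention_mask`, because patches from several images/videos are packed along one sequence and must not attend across `cu_seqlens` boundaries; the mask-building lines sit in the collapsed part of this diff. Below is a hedged sketch of how such a block-diagonal boolean mask can be constructed and fed to SDPA — the helper name and shapes are illustrative, not the PR's exact code.

```python
# Sketch (helper name illustrative): build a block-diagonal boolean mask so that
# packed vision sequences, delimited by cu_seqlens, only attend within themselves.
import torch
import torch.nn.functional as F


def block_diagonal_mask(cu_seqlens: torch.Tensor, seq_length: int,
                        device: torch.device) -> torch.Tensor:
    mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool, device=device)
    for i in range(1, len(cu_seqlens)):
        lo, hi = cu_seqlens[i - 1], cu_seqlens[i]
        mask[..., lo:hi, lo:hi] = True  # allow attention only inside each block
    return mask


# Two packed sequences of 3 and 5 tokens -> cu_seqlens = [0, 3, 8].
cu_seqlens = torch.tensor([0, 3, 8])
q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
mask = block_diagonal_mask(cu_seqlens, seq_length=8, device=q.device)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 4, 8, 16])
```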
35 changes: 34 additions & 1 deletion vllm/model_executor/models/utils.py
@@ -8,15 +8,22 @@
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
+from vllm.attention.selector import (_Backend, backend_name_to_enum,
+                                     get_global_forced_attn_backend)
 from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available
+from vllm.utils import is_cpu, is_pin_memory_available
+
+logger = init_logger(__name__)
 
 WeightsMapping = Mapping[str, Optional[str]]
 """If a key maps to a value of `None`, the corresponding weight is ignored."""
@@ -482,3 +489,29 @@ def __getattr__(self, key: str):
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         llm = super().__getattr__(self.model_name)
         return llm(*args, **kwargs)
+
+
+def get_vit_attn_backend() -> _Backend:
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs, use xformers instead.
+        device_available = current_platform.has_device_capability(80)
+        if device_available:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                logger.warning(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif is_cpu():
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
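The new shared helper centralizes the selection logic previously duplicated per model: a globally forced backend (or `VLLM_ATTENTION_BACKEND`) wins, otherwise FLASH_ATTN on Ampere+ GPUs with flash-attn installed, TORCH_SDPA on CPU, and XFORMERS everywhere else. Callers then turn the enum into the boolean flags seen in the Molmo and Qwen2-VL changes above. A sketch of that consumption pattern, plus the existing environment-variable override, follows; it mirrors the PR's code but is written here only as an illustration.

```python
# Sketch of how a vision module consumes the shared helper (mirrors the
# Molmo/Qwen2-VL __init__ changes above). The environment-variable override is
# the existing VLLM_ATTENTION_BACKEND mechanism, set before starting vLLM, e.g.:
#
#   VLLM_ATTENTION_BACKEND=XFORMERS python my_script.py
#
from vllm.attention.selector import _Backend
from vllm.model_executor.models.utils import get_vit_attn_backend

use_flash_attn = use_sdpa = use_xformers = False
backend = get_vit_attn_backend()
if backend == _Backend.FLASH_ATTN:
    use_flash_attn = True
elif backend == _Backend.TORCH_SDPA:
    use_sdpa = True
elif backend == _Backend.XFORMERS:
    use_xformers = True
else:
    raise RuntimeError(f"Vision attention does not support {backend} backend.")
```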