diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 0361dd3bd4ea..0f4568070cfc 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -9,11 +9,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
-
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
     @staticmethod
     def get_name() -> str:
         return "flash-attn"
@@ -237,10 +239,12 @@ def __init__(
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
-        if head_size not in _SUPPORTED_HEAD_SIZES:
+
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
+                f"Supported head sizes are: {support_head_sizes}.")
 
     def forward(
         self,
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 5140c3cc86a3..51c25a81b413 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -34,11 +34,21 @@ def get_attn_backend(
                                 sliding_window, dtype, kv_cache_dtype,
                                 block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info(
+            "Cannot use FlashAttention-2 backend for head size %d. "
+            "Using XFormers backend instead.", head_size)
+        backend = _Backend.XFORMERS
+
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
             XFormersBackend)
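
For reviewers who want to sanity-check the new control flow outside of vLLM, here is a minimal, self-contained sketch of the fallback path that the selector hunk introduces. The select_backend function and the toy _Backend enum below are illustrative stand-ins, not vLLM's real get_attn_backend signature; head size 80 is just an example of a size absent from the supported list.

import logging
from enum import Enum, auto
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()


class FlashAttentionBackend:
    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]


def select_backend(backend: _Backend, head_size: int) -> _Backend:
    # Mirrors the new selector logic: keep FlashAttention only when the
    # head size is supported; otherwise fall through to XFormers.
    if backend == _Backend.FLASH_ATTN:
        if head_size in FlashAttentionBackend.get_supported_head_sizes():
            logger.info("Using FlashAttention-2 backend.")
            return backend
        logger.info(
            "Cannot use FlashAttention-2 backend for head size %d. "
            "Using XFormers backend instead.", head_size)
        backend = _Backend.XFORMERS
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
    return backend


assert select_backend(_Backend.FLASH_ATTN, 128) is _Backend.FLASH_ATTN
assert select_backend(_Backend.FLASH_ATTN, 80) is _Backend.XFORMERS

Note the design choice called out in the diff's own comment: the head-size check lives in get_attn_backend rather than _which_attn_to_use because the supported sizes are only known once FlashAttentionBackend has been imported, and that import is deliberately deferred until the FLASH_ATTN branch is taken.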