
Commit

[Bugfix][Kernel] Add head size check for attention backend selection (v…
Isotr0py authored May 21, 2024
1 parent 14772ee commit 99eff67
Showing 2 changed files with 21 additions and 7 deletions.
vllm/attention/backends/flash_attn.py: 12 changes (8 additions, 4 deletions)
@@ -9,11 +9,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata)
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
-
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
     @staticmethod
     def get_name() -> str:
         return "flash-attn"
@@ -237,10 +239,12 @@ def __init__(
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
-        if head_size not in _SUPPORTED_HEAD_SIZES:
+
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
+                f"Supported head sizes are: {support_head_sizes}.")
 
     def forward(
         self,
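For reference, the validation added inside __init__ in the second hunk above boils down to the pattern below. This is a minimal standalone sketch: the class name FlashAttentionBackendSketch and the helper check_head_size are illustrative stand-ins, not part of vLLM; only get_supported_head_sizes and the head-size list come from the diff.

from typing import List


class FlashAttentionBackendSketch:
    """Illustrative stand-in mirroring the new classmethod in the diff."""

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        # Head sizes accepted by the FlashAttention-2 kernels (from the diff).
        return [32, 64, 96, 128, 160, 192, 224, 256]


def check_head_size(head_size: int) -> None:
    # Same shape of check the commit adds to the backend's __init__.
    support_head_sizes = FlashAttentionBackendSketch.get_supported_head_sizes()
    if head_size not in support_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by FlashAttention. "
            f"Supported head sizes are: {support_head_sizes}.")


check_head_size(128)   # passes silently
# check_head_size(80)  # would raise ValueError listing the supported sizes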
vllm/attention/selector.py: 16 changes (13 additions, 3 deletions)
@@ -34,11 +34,21 @@ def get_attn_backend(
                                  sliding_window, dtype, kv_cache_dtype,
                                  block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info(
+            "Cannot use FlashAttention-2 backend for head size %d. "
+            "Using XFormers backend instead.", head_size)
+        backend = _Backend.XFORMERS
+
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
             XFormersBackend)
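To illustrate the selector change, here is a self-contained sketch of the new fallback flow: a FLASH_ATTN choice is only kept when the head size is supported, otherwise it is downgraded to XFORMERS. The _Backend enum, logger setup, and select_backend_name function below are stand-ins for illustration, not vLLM's actual module layout.

import enum
import logging
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _Backend(enum.Enum):
    # Simplified stand-in for the selector's backend enum.
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


def select_backend_name(backend: _Backend, head_size: int,
                        supported_head_sizes: List[int]) -> str:
    # Mirrors the new control flow: check the head size after the tentative
    # choice, and fall back to XFormers when FlashAttention cannot serve it.
    if backend == _Backend.FLASH_ATTN:
        if head_size in supported_head_sizes:
            logger.info("Using FlashAttention-2 backend.")
            return "flash-attn"
        logger.info(
            "Cannot use FlashAttention-2 backend for head size %d. "
            "Using XFormers backend instead.", head_size)
        backend = _Backend.XFORMERS

    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        return "xformers"
    raise ValueError(f"Unhandled backend: {backend}")


# e.g. an unsupported head size such as 80 now falls back to XFormers:
print(select_backend_name(_Backend.FLASH_ATTN, 80,
                          [32, 64, 96, 128, 160, 192, 224, 256]))  # xformers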
