[Misc] Take user preference in attention selector (vllm-project#4960)
comaniac authored May 22, 2024
1 parent a36de68 commit ee3eea0
Showing 3 changed files with 169 additions and 61 deletions.
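Previously the VLLM_ATTENTION_BACKEND override was only consulted after the FlashAttention feasibility checks, so a user preference could be silently ignored; with this change the preference is read first, validated, and only overridden (with a logged reason) when it cannot be used. A minimal usage sketch, assuming only what this commit shows (the VLLM_ATTENTION_BACKEND environment variable and the backend names FLASH_ATTN, XFORMERS, ROCM_FLASH, TORCH_SDPA, and FLASHINFER):

import os

# Ask vLLM to use a specific attention backend. The name must be a
# case-sensitive member of the selector's _Backend enum; unknown names now
# raise a ValueError listing the valid choices.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

If the requested backend is unsupported for the current device, dtype, or KV-cache configuration, the selector logs why and falls back to a supported backend rather than failing silently.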
84 changes: 84 additions & 0 deletions tests/kernels/test_attention_selector.py
@@ -0,0 +1,84 @@
import os
from unittest.mock import patch

import pytest
import torch

from vllm.attention.selector import which_attn_to_use


@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(name: str, device: str):
    """Test that the attention selector can be set via environment variable.
    Note that we do not test FlashAttn because it is the default backend.
    """
    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
    os.environ["VLLM_ATTENTION_BACKEND"] = name

    if device == "cpu":
        with patch("vllm.attention.selector.is_cpu", return_value=True):
            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                        torch.float16, 16)
        assert backend.name == "TORCH_SDPA"
    elif device == "hip":
        with patch("vllm.attention.selector.is_hip", return_value=True):
            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                        torch.float16, 16)
        assert backend.name == "ROCM_FLASH"
    else:
        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
                                    torch.float16, 16)
        assert backend.name == name

    if name_backup is not None:
        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_flash_attn():
    """Test FlashAttn validation."""
    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

    # Unsupported CUDA arch
    with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
        assert backend.name != "FLASH_ATTN"

    # Unsupported data type
    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
    assert backend.name != "FLASH_ATTN"

    # Unsupported kv cache data type
    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
    assert backend.name != "FLASH_ATTN"

    # Unsupported block size
    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
    assert backend.name != "FLASH_ATTN"

    # Unsupported sliding window
    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
    assert backend.name != "FLASH_ATTN"

    # flash-attn is not installed
    with patch.dict('sys.modules', {'vllm_flash_attn': None}):
        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
        assert backend.name != "FLASH_ATTN"

    # Unsupported head size
    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
    assert backend.name != "FLASH_ATTN"

    if name_backup is not None:
        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup


def test_invalid_env():
    """Throw an exception if the backend name is invalid."""
    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
    with pytest.raises(ValueError):
        which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
    os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
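These tests drive which_attn_to_use directly with a dummy model shape (8 heads, head size 16, 8 KV heads, block size 16), patching is_cpu, is_hip, and torch.cuda.get_device_capability so every platform path can be exercised on any machine. A similar probe outside pytest, shown only as a sketch and assuming a CUDA machine where the remaining FlashAttention requirements are met, would be:

import torch

from vllm.attention.selector import which_attn_to_use

# Head size 17 is not among FlashAttention's supported sizes, so the selector
# is expected to report a fallback backend (e.g. XFORMERS) instead of
# FLASH_ATTN, mirroring the "Unsupported head size" case above.
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
print(backend.name)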
1 change: 1 addition & 0 deletions vllm/attention/backends/flashinfer.py
@@ -218,6 +218,7 @@ def forward(
            )

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.block_tables is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                output = flash_attn_varlen_func(
145 changes: 84 additions & 61 deletions vllm/attention/selector.py
@@ -30,24 +30,16 @@ def get_attn_backend(
    kv_cache_dtype: Optional[str],
    block_size: int,
) -> Type[AttentionBackend]:
    backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
                                 sliding_window, dtype, kv_cache_dtype,
                                 block_size)
    """Determine which attention backend to use and only import
    the selected backend module.
    """
    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
                                sliding_window, dtype, kv_cache_dtype,
                                block_size)
    if backend == _Backend.FLASH_ATTN:
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)

        # We check it here not in _which_attn_to_use because we cannot know
        # the head size until we import FlashAttentionBackend.
        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size in supported_head_sizes:
            logger.info("Using FlashAttention-2 backend.")
            return FlashAttentionBackend
        logger.info(
            "Cannot use FlashAttention-2 backend for head size %d. "
            "Using XFormers backend instead.", head_size)
        backend = _Backend.XFORMERS

        return FlashAttentionBackend
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
@@ -64,14 +56,15 @@ def get_attn_backend(
        return TorchSDPABackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
        logger.warning("Eager mode is enforced for the Flashinfer backend.")
        logger.warning("Eager mode is required for the Flashinfer backend. "
                       "Please make sure --enforce-eager is set.")
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    else:
        raise ValueError("Invalid attention backend.")


def _which_attn_to_use(
def which_attn_to_use(
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
@@ -81,54 +74,84 @@ def _which_attn_to_use(
    block_size: int,
) -> _Backend:
    """Returns which flash attention backend to use."""

    # Default case.
    selected_backend = _Backend.FLASH_ATTN

    # Check the environment variable and override if specified
    backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is not None:
        backend_members = _Backend.__members__
        if backend_by_env_var not in backend_members:
            raise ValueError(
                f"Invalid attention backend '{backend_by_env_var}'. "
                f"Available backends: {', '.join(backend_members)} "
                "(case-sensitive).")
        selected_backend = _Backend[backend_by_env_var]

    if is_cpu():
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        return _Backend.TORCH_SDPA

    if is_hip():
        # AMD GPUs.
        if torch.cuda.get_device_capability()[0] != 9:
            # not Instinct series GPUs.
            logger.info("flash_atten is not supported on NAVI GPUs.")
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:
            if torch.cuda.get_device_capability()[0] != 9:
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        return _Backend.ROCM_FLASH

    # NVIDIA GPUs.
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                    "GPUs.")
        return _Backend.XFORMERS

    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return _Backend.XFORMERS

    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
        return _Backend.XFORMERS

    if block_size % 16 != 0:
        logger.info("Cannot use FlashAttention-2 backend for block size not "
                    "divisible by 16.")
        return _Backend.XFORMERS

    if sliding_window is not None:
        logger.info(
            "Cannot use FlashAttention-2 backend due to sliding window.")
        return _Backend.XFORMERS

    try:
        import vllm_flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
            "package is not found. `pip install vllm-flash-attn` for better "
            "performance.")
        return _Backend.XFORMERS

    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
    if backend_by_env_var is not None:
        return _Backend[backend_by_env_var]

    # Default case.
    return _Backend.FLASH_ATTN
    # FlashAttn in NVIDIA GPUs.
    if selected_backend == _Backend.FLASH_ATTN:
        if torch.cuda.get_device_capability()[0] < 8:
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            selected_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            selected_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            selected_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            selected_backend = _Backend.XFORMERS
        elif sliding_window is not None:
            logger.info(
                "Cannot use FlashAttention-2 backend due to sliding window.")
            selected_backend = _Backend.XFORMERS

    # FlashAttn is valid for the model, checking if the package is installed.
    if selected_backend == _Backend.FLASH_ATTN:
        try:
            import vllm_flash_attn  # noqa: F401

            from vllm.attention.backends.flash_attn import (  # noqa: F401
                FlashAttentionBackend)

            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
            if head_size not in supported_sizes:
                logger.info(
                    "Cannot use FlashAttention-2 backend for head size %d.",
                    head_size)
                selected_backend = _Backend.XFORMERS
        except ImportError:
            logger.info(
                "Cannot use FlashAttention-2 backend because the "
                "vllm_flash_attn package is not found. "
                "`pip install vllm-flash-attn` for better performance.")
            selected_backend = _Backend.XFORMERS

    return selected_backend
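Taken together, the rewritten which_attn_to_use proceeds as: default to FLASH_ATTN, override from VLLM_ATTENTION_BACKEND (rejecting unknown names), apply the CPU and ROCm platform rules, then run the FlashAttention feasibility and installation checks, downgrading the selection to XFORMERS with a logged reason whenever a check fails. A small sketch of the new validation path, assuming only the behavior added in this diff:

import os

import torch

from vllm.attention.selector import which_attn_to_use

# An unknown backend name is now rejected up front instead of being ignored.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHATTENTION"  # not a _Backend member
try:
    which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
except ValueError as err:
    print(err)  # expected to list the valid, case-sensitive backend names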
