Skip to content

Commit

Permalink
Revert "[Misc] Add FA2 support to ViT MHA layer (#12355)"
Browse files — browse the repository at this point in the history
This reverts commit f1fc051.
  • Loading branch information
WoosukKwon committed Jan 26, 2025
1 parent 324960a commit 1be4877
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 146 deletions.
126 changes: 0 additions & 126 deletions tests/kernels/test_mha_attn.py

This file was deleted.

25 changes: 5 additions & 20 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,22 +210,18 @@ def __init__(
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype=None,
block_size=16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
backend = _Backend.XFORMERS

self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1,
_Backend.TORCH_SDPA, _Backend.XFORMERS
} else _Backend.TORCH_SDPA

def forward(
Expand All @@ -235,26 +231,15 @@ def forward(
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape: batch_size x seq_len x hidden_size"""
# TODO(Isotr0py): Use existing backend implementations and support FA2
bsz, q_len, _ = query.size()
kv_len = key.size(1)

query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)

if self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1,
}:
from vllm.vllm_flash_attn import flash_attn_func

out = flash_attn_func(query, key, value, softmax_scale=self.scale)
elif self.attn_backend == _Backend.XFORMERS:
if self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query,
Expand Down

0 comments on commit 1be4877

Please sign in to comment.