
Commit 78dedb9
[Misc] Pass attention to impl backend
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan committed Jan 20, 2025
1 parent 5c89a29 commit 78dedb9
Showing 11 changed files with 55 additions and 71 deletions.
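The change is mechanical across backends: each attention backend's forward now receives the owning attention layer module as its first argument and reads the per-layer KV-cache scales from that module (layer._k_scale, layer._v_scale) instead of taking k_scale/v_scale floats. The sketch below illustrates the resulting backend-side pattern; apart from the forward parameter order and the _k_scale/_v_scale attribute names, everything (class name, metadata type, body) is illustrative and not taken from this diff.

# Sketch only: a toy backend impl mirroring the new forward signature.
# The class name and body are assumptions, not vLLM code.
from typing import Optional

import torch


class DummyBackendImpl:

    def forward(
        self,
        layer: torch.nn.Module,  # the owning attention layer
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: object,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Per-layer KV-cache scales now come off the layer object instead of
        # being threaded through as k_scale / v_scale arguments.
        k_scale = layer._k_scale
        v_scale = layer._v_scale
        # ... a real backend would write key/value into kv_cache and run
        # attention using k_scale / v_scale; this toy version just echoes ...
        assert k_scale is not None and v_scale is not None
        return output if output is not None else query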
3 changes: 1 addition & 2 deletions vllm/attention/backends/abstract.py
@@ -244,13 +244,12 @@ def __init__(
     @abstractmethod
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
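The caller side of this contract is not visible in the hunks below (one changed file is collapsed), but presumably the wrapper attention layer passes itself into impl.forward and carries the scales as _k_scale/_v_scale attributes. A hedged sketch of that assumption; the class layout and defaults are illustrative only.

# Assumed caller-side wiring; names other than forward/_k_scale/_v_scale
# are hypothetical.
import torch


class AttentionLayer(torch.nn.Module):
    """Toy wrapper that owns per-layer KV-cache scales and a backend impl."""

    def __init__(self, impl) -> None:
        super().__init__()
        self.impl = impl
        # Scales live on the layer so backends can read them directly.
        self._k_scale = 1.0
        self._v_scale = 1.0

    def forward(self, query, key, value, kv_cache, attn_metadata,
                output=None):
        # Hand the layer itself to the backend; it reads _k_scale / _v_scale.
        return self.impl.forward(self, query, key, value, kv_cache,
                                 attn_metadata, output=output)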
11 changes: 5 additions & 6 deletions vllm/attention/backends/blocksparse_attn.py
@@ -358,13 +358,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +400,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +438,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
                 tp_rank=self.tp_rank,
                 blocksparse_local_blocks=self.local_blocks,
                 blocksparse_vert_stride=self.vert_stride,
9 changes: 4 additions & 5 deletions vllm/attention/backends/flash_attn.py
@@ -634,13 +634,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -657,7 +656,7 @@ def forward(
         NOTE: It in-place updates the output tensor.
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
         assert output is not None, "Output tensor must be provided."
@@ -709,8 +708,8 @@ def forward(
                 kv_cache[1],
                 updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         (num_prefill_query_tokens, num_prefill_kv_tokens,
15 changes: 7 additions & 8 deletions vllm/attention/backends/flashinfer.py
@@ -792,13 +792,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
@@ -826,8 +825,8 @@ def forward(
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
             # to process the cache when the kv_cache_dtype is fp8
@@ -886,8 +885,8 @@ def forward(
                 kv_cache,
                 logits_soft_cap=logits_soft_cap,
                 causal=True,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
         if decode_meta := attn_metadata.decode_metadata:
             assert decode_meta is not None
@@ -897,8 +896,8 @@ def forward(
                 kv_cache,
                 sm_scale=softmax_scale,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
 
         if prefill_output is None and decode_output is not None:
3 changes: 1 addition & 2 deletions vllm/attention/backends/hpu_attn.py
@@ -152,13 +152,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
17 changes: 8 additions & 9 deletions vllm/attention/backends/ipex_attn.py
@@ -171,13 +171,12 @@ def split_kv_cache(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +192,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +209,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if attn_metadata.is_prompt:
@@ -296,8 +295,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -329,8 +328,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.
5 changes: 2 additions & 3 deletions vllm/attention/backends/pallas.py
@@ -150,13 +150,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -173,7 +172,7 @@ def forward(
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
19 changes: 9 additions & 10 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -414,13 +414,12 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +457,8 @@ def forward(
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +566,8 @@ def forward(
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window[0],
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -613,8 +612,8 @@ def forward(
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +627,8 @@ def forward(
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.
17 changes: 7 additions & 10 deletions vllm/attention/backends/torch_sdpa.py
@@ -429,13 +429,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
        query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -451,7 +450,7 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -493,11 +492,9 @@ def forward(
                 # Update self-attention KV cache (prefill/decode)
                 updated_slot_mapping = attn_metadata.slot_mapping
 
-                PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                    value_cache,
-                                                    updated_slot_mapping,
-                                                    self.kv_cache_dtype,
-                                                    k_scale, v_scale)
+                PagedAttention.write_to_paged_cache(
+                    key, value, key_cache, value_cache, updated_slot_mapping,
+                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -571,8 +568,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.
19 changes: 8 additions & 11 deletions vllm/attention/backends/xformers.py
@@ -412,13 +412,12 @@ def __init__(
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -524,11 +523,9 @@ def forward(
                 # If kv_cache is not provided, the new key and value tensors are
                 # not cached. This happens during the initial memory
                 # profiling run.
-                PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                    value_cache,
-                                                    updated_slot_mapping,
-                                                    self.kv_cache_dtype,
-                                                    k_scale, v_scale)
+                PagedAttention.write_to_paged_cache(
+                    key, value, key_cache, value_cache, updated_slot_mapping,
+                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
         (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
             get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -580,8 +577,8 @@ def forward(
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
                 assert output[:num_prefill_query_tokens].shape == out.shape
                 output[:num_prefill_query_tokens] = out
@@ -607,8 +604,8 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.
(The eleventh changed file is not shown here.)