[Platform] Add output for Attention Backend (#11981)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan authored Jan 14, 2025
1 parent 1f18adb commit 2e0e017
Showing 4 changed files with 9 additions and 5 deletions.
4 changes: 4 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -31,6 +31,10 @@ class AttentionType:

 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
+    # For some attention backends, we allocate an output tensor before
+    # calling the custom op. When piecewise cudagraph is enabled, this
+    # makes sure the output tensor is allocated inside the cudagraph.
+    accept_output_buffer: bool = False
 
     @staticmethod
     @abstractmethod
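The contract behind this flag, as a minimal self-contained sketch (DummyBackend and run_attention are illustrative names, not vLLM APIs): when accept_output_buffer is True, the caller allocates the output tensor and hands it to the backend, so that under piecewise cudagraph capture the buffer is created inside the captured region rather than by the custom op.

from typing import Optional

import torch


class DummyBackend:
    """Hypothetical backend used only for illustration."""
    accept_output_buffer: bool = True

    @staticmethod
    def forward(query: torch.Tensor,
                output: Optional[torch.Tensor] = None) -> torch.Tensor:
        if output is None:
            # Backend-managed allocation (the pre-existing behavior).
            output = torch.empty_like(query)
        # A real backend would launch its attention kernel here; the stand-in
        # just fills the buffer so the example runs end to end.
        output.copy_(query)
        return output


def run_attention(backend, query: torch.Tensor) -> torch.Tensor:
    if getattr(backend, "accept_output_buffer", False):
        # Caller-managed allocation: create the output before invoking the
        # backend so the tensor lives inside the captured cudagraph region.
        out = torch.empty_like(query)
        return backend.forward(query, output=out)
    return backend.forward(query)


q = torch.randn(2, 8, 64)
print(run_attention(DummyBackend, q).shape)  # torch.Size([2, 8, 64])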
2 changes: 2 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -29,6 +29,8 @@

 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
6 changes: 1 addition & 5 deletions vllm/attention/layer.py
@@ -110,11 +110,7 @@ def __init__(
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        # For some attention backends, we allocate an output tensor before
-        # calling the custom op. When piecewise cudagraph is enabled, this
-        # makes sure the output tensor is allocated inside the cudagraph.
-        self.use_output = self.backend == _Backend.FLASH_ATTN or \
-            self.backend == _Backend.FLASH_ATTN_VLLM_V1
+        self.use_output = attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
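A condensed sketch of the layer-side pattern this change enables (the _attn_op* stubs below are placeholders, not vLLM's real custom ops): the layer reads accept_output_buffer from the backend class instead of comparing self.backend against _Backend enum values, and uses the flag in forward to decide whether to pre-allocate the output.

import torch


def _attn_op_with_output(q: torch.Tensor, output: torch.Tensor) -> None:
    # Stub for a custom op that writes its result into a caller-provided
    # buffer; included only so the sketch is runnable.
    output.copy_(q)


def _attn_op(q: torch.Tensor) -> torch.Tensor:
    # Stub for a custom op that allocates and returns its own output.
    return q.clone()


class AttentionLayerSketch(torch.nn.Module):

    def __init__(self, attn_backend) -> None:
        super().__init__()
        # The flag replaces the hard-coded backend-enum check removed above.
        self.use_output = attn_backend.accept_output_buffer

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        if self.use_output:
            # Pre-allocate so the buffer is captured inside the piecewise
            # cudagraph rather than allocated by the op itself.
            output = torch.empty_like(query)
            _attn_op_with_output(query, output)
            return output
        return _attn_op(query)

Since the FlashAttention backends in this commit set the flag to True, they take the pre-allocated-output path; any other backend keeps the default of False and continues to allocate its own output.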
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -15,6 +15,8 @@

 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
