Documentation for fMHA + blockdiag (fairinternal/xformers#459)
Co-authored-by: danthe3rd <danthe3rd>

__original_commit__ = fairinternal/xformers@7475b63
danthe3rd authored and xFormers Bot committed Feb 6, 2023
1 parent 7f4fdce commit 00afc12
Showing 13 changed files with 173 additions and 23 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
@@ -6,10 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## TBD
### Fixed
- fMHA: Fixed BW pass on Sm86/Sm89 GPUs when `K > 64` (RTX 3090, RTX 4090, A6000, ..) [facebookresearch/xformers#631]

### Added
- Added tensor attn bias support to CUTLASS FlashAttention
- Added tensor attn bias grad support to CUTLASS FlashAttention
- Added dropout support to CUTLASS FlashAttention
- fMHA/CUTLASS: Added tensor attn bias support [facebookresearch/xformers#587] - contribution from [@jfc4050](https://github.com/jfc4050)
- fMHA/CUTLASS: Added tensor attn bias grad support [facebookresearch/xformers#587] - contribution from [@jfc4050](https://github.com/jfc4050)
- fMHA/CUTLASS: Added dropout support [facebookresearch/xformers#587] - contribution from [@jfc4050](https://github.com/jfc4050)
- fMHA: Added support for varying sequence lengths [facebookresearch/xformers#500]


## [0.0.16] - 2023-01-31
Binary file added docs/source/_static/block_diag_bias.png
Binary file added docs/source/_static/block_diag_cat_split.png
43 changes: 41 additions & 2 deletions docs/source/components/ops.rst
@@ -1,8 +1,47 @@
Operators
======================
xFormers optimized operators
============================================================

Memory-efficient attention
---------------------------

.. automodule:: xformers.ops
:members: memory_efficient_attention, AttentionOpBase, AttentionBias
:show-inheritance:
:imported-members:


Available implementations
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: xformers.ops.fmha.cutlass
:members: FwOp, BwOp
:member-order: bysource

.. automodule:: xformers.ops.fmha.flash
:members: FwOp, BwOp
:member-order: bysource

.. automodule:: xformers.ops.fmha.triton
:members: FwOp, BwOp
:member-order: bysource

.. automodule:: xformers.ops.fmha.small_k
:members: FwOp, BwOp
:member-order: bysource

Attention biases
~~~~~~~~~~~~~~~~~~~~

.. automodule:: xformers.ops.fmha.attn_bias
:members:
:show-inheritance:
:member-order: bysource


Non-autograd implementations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: xformers.ops.fmha
:members: memory_efficient_attention_forward, memory_efficient_attention_forward_requires_grad, memory_efficient_attention_backward
:show-inheritance:
:imported-members:
:member-order: bysource
29 changes: 29 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -1311,3 +1311,32 @@ def test_attn_bias_from_seqlens() -> None:
out = bias.split(torch.randn([1, 3 + 5 + 1, 16]))
assert len(out) == 3
assert tuple(out[0].shape) == (1, 3, 16)


@cuda_only
def test_attn_bias_blockdiag_doc() -> None:
"""IMPORTANT:
This is the example in the doc for `BlockDiagonalMask`.
If this example needs to be updated, please also update the doc
"""
import torch

from xformers.ops import fmha

K = 16
dtype = torch.float16
device = "cuda"
list_x = [
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
]
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)

linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) # type: ignore

q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
list_out = attn_bias.split(out)
print(list_out[0].shape) # [1, 3, 1, K]
assert tuple(list_out[0].shape) == (1, 3, 1, K)
2 changes: 1 addition & 1 deletion xformers/ops/__init__.py
@@ -60,6 +60,7 @@ def masked_matmul(a, b, mask=None):


__all__ = [
"memory_efficient_attention",
"AttentionBias",
"AttentionMask",
"AttentionOp",
@@ -71,7 +72,6 @@ def masked_matmul(a, b, mask=None):
"MemoryEfficientAttentionFlashAttentionOp",
"MemoryEfficientAttentionOp",
"MemoryEfficientAttentionTritonFwdFlashBwOp",
"memory_efficient_attention",
"memory_efficient_attention_backward",
"memory_efficient_attention_forward",
"memory_efficient_attention_forward_requires_grad",
9 changes: 5 additions & 4 deletions xformers/ops/fmha/__init__.py
@@ -176,17 +176,18 @@ def memory_efficient_attention(
Raises:
NotImplementedError: if there is no operator available to compute the MHA
ValueError: if inputs are invalid
:parameter query: Tensor of shape ``[B, Mq, H, K]``
:parameter key: Tensor of shape ``[B, Mkv, H, K]``
:parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
:parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
For causal attention, use :attr:`xformers.ops.LowerTriangularMask`. \
This can also be a :attr:`torch.Tensor` for an arbitrary mask.
For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
:parameter p: Dropout probability. Disabled if set to ``0.0``
:parameter scale: The scale to query_state weights. If set to ``None``, the default \
:parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
scale (q.shape[-1]**-0.5) will be used.
:parameter op: The operator to use - see :attr:`xformers.ops.AttentionOpBase`. \
:parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
If set to ``None`` (recommended), xFormers \
will dispatch to the best available operator, depending on the inputs \
and options.
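A minimal usage sketch of ``memory_efficient_attention`` following the shapes documented above (``[B, Mq, H, K]`` query, ``[B, Mkv, H, K]`` key, ``[B, Mkv, H, Kv]`` value); the concrete sizes, dtype and the causal :attr:`xformers.ops.LowerTriangularMask` call are illustrative assumptions, not part of this commit:

import torch
import xformers.ops as xops

# Illustrative sizes: B=2, Mq=Mkv=128, H=4 heads, K=Kv=64 (assumed for the example).
q = torch.randn([2, 128, 4, 64], device="cuda", dtype=torch.float16)
k = torch.randn([2, 128, 4, 64], device="cuda", dtype=torch.float16)
v = torch.randn([2, 128, 4, 64], device="cuda", dtype=torch.float16)

# op=None (the default): xFormers dispatches to the best available operator.
out = xops.memory_efficient_attention(q, k, v)

# Causal attention via the bias documented above.
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())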
70 changes: 69 additions & 1 deletion xformers/ops/fmha/attn_bias.py
@@ -129,7 +129,40 @@ def split(

@dataclass
class BlockDiagonalMask(AttentionBias):
"""A block-diagonal mask - can be used to handle batch elements with different sequence length"""
"""
A block-diagonal mask that can be passed as ``attn_bias``
argument to :attr:`xformers.ops.memory_efficient_attention`.
.. figure:: /_static/block_diag_bias.png
This bias can be used to handle a batch of sequences of
different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
:Example:
.. code-block:: python
import torch
from xformers.ops import fmha
K = 16
dtype = torch.float16
device = "cuda"
list_x = [
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
]
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
list_out = attn_bias.split(out)
print(list_out[0].shape) # [1, 3, 1, K]
assert tuple(list_out[0].shape) == (1, 3, 1, K)
"""

q_seqinfo: _SeqLenInfo
k_seqinfo: _SeqLenInfo
@@ -153,6 +186,7 @@ def materialize(
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
"""Materialize the attention bias - for debugging & testing"""
assert shape[-1] == self.k_seqinfo.cu_seqlen_py[-1]
assert shape[-2] == self.q_seqinfo.cu_seqlen_py[-1]
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
@@ -178,6 +212,15 @@ def from_seqlens(
q_seqlen: Sequence[int],
kv_seqlen: Optional[Sequence[int]] = None,
) -> "BlockDiagonalMask":
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
Args:
q_seqlen (Sequence[int]): List of sequence lengths for query tensors
kv_seqlen (Sequence[int], optional): List of sequence lengths for key/value. Defaults to ``q_seqlen``.
Returns:
BlockDiagonalMask
"""
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
if kv_seqlen is None or q_seqlen == kv_seqlen:
@@ -191,6 +234,23 @@ def from_tensor_list(
cls,
tensors: Sequence[torch.Tensor],
) -> Tuple["BlockDiagonalMask", torch.Tensor]:
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
concatenated on the sequence length dimension
.. figure:: /_static/block_diag_cat_split.png
See also :attr:`BlockDiagonalMask.split` to split the returned
:attr:`torch.Tensor` back to a list of tensors of varying sequence length
Args:
tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
All tensors should have the same dimension and the same batch size ``B``, but
they can have different sequence length ``M``.
Returns:
Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
"""
batch_sizes = [tensor.shape[0] for tensor in tensors]
seqlens = []
for x in tensors:
@@ -236,6 +296,14 @@ def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
return self.k_seqinfo.split(tensor, self._batch_sizes)

def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
"""The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list`
Args:
tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
Returns:
Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
"""
assert self.q_seqinfo is self.k_seqinfo
return self.q_seqinfo.split(tensor, self._batch_sizes)

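A short sketch of the :attr:`BlockDiagonalMask.from_seqlens` constructor documented above, using ``materialize`` purely to inspect the resulting block-diagonal bias; the sequence lengths are illustrative assumptions:

import torch
from xformers.ops import fmha

# Three sequences of lengths 3, 6 and 2, packed into a single row of 11 tokens.
attn_bias = fmha.BlockDiagonalMask.from_seqlens([3, 6, 2])

# For debugging/testing only: entries outside the diagonal blocks are masked,
# so tokens never attend across sequences.
mask = attn_bias.materialize((11, 11))
print(mask.shape)  # torch.Size([11, 11])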
15 changes: 8 additions & 7 deletions xformers/ops/fmha/common.py
@@ -132,13 +132,14 @@ class AttentionOpBase(BaseOperator):
See:
- :attr:`xformers.ops.MemoryEfficientAttentionOp`
- :attr:`xformers.ops.MemoryEfficientAttentionCutlassOp`
- :attr:`xformers.ops.MemoryEfficientAttentionFlashAttentionOp`
- :attr:`xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp`
- :attr:`xformers.ops.fmha.cutlass.FwOp`
- :attr:`xformers.ops.fmha.cutlass.BwOp`
- :attr:`xformers.ops.fmha.flash.FwOp`
- :attr:`xformers.ops.fmha.flash.BwOp`
- :attr:`xformers.ops.fmha.triton.FwOp`
- :attr:`xformers.ops.fmha.triton.BwOp`
- :attr:`xformers.ops.fmha.small_k.FwOp`
- :attr:`xformers.ops.fmha.small_k.BwOp`
"""

OPERATOR: Any
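To bypass the automatic dispatch and pick one of the operators listed above, ``op`` can be passed as a forward/backward operator pair; the pair form shown here is an assumption based on the plural "operators" in the updated docstring, and the tensor sizes are illustrative:

import torch
import xformers.ops as xops
from xformers.ops import fmha

q = torch.randn([2, 128, 4, 64], device="cuda", dtype=torch.float16)
k = torch.randn([2, 128, 4, 64], device="cuda", dtype=torch.float16)
v = torch.randn([2, 128, 4, 64], device="cuda", dtype=torch.float16)

# Explicitly select the CUTLASS kernels instead of letting xFormers pick
# (op=None, the recommended default).
out = xops.memory_efficient_attention(q, k, v, op=(fmha.cutlass.FwOp, fmha.cutlass.BwOp))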
2 changes: 2 additions & 0 deletions xformers/ops/fmha/cutlass.py
@@ -175,6 +175,8 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:

@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__

OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
7 changes: 2 additions & 5 deletions xformers/ops/fmha/flash.py
@@ -184,11 +184,6 @@ class FwOp(AttentionFwOpBase):
"""Operator that computes memory-efficient attention using \
`Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
implementation.
This is a wrapper to make FlashAttention compatible with xformers's API
Most of this code was taken from:
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_interface.py
"""

OPERATOR = get_operator("xformers_flash", "flash_fwd")
@@ -261,6 +256,8 @@ def apply(

@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__

OPERATOR = get_operator("xformers_flash", "flash_bwd")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
2 changes: 2 additions & 0 deletions xformers/ops/fmha/small_k.py
@@ -113,6 +113,8 @@ def apply(

@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__

OPERATOR = get_xformers_operator("efficient_attention_backward_small_k")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
8 changes: 8 additions & 0 deletions xformers/ops/fmha/triton.py
@@ -52,6 +52,12 @@ def _prepare_inputs(inp: Inputs) -> Inputs:

@register_operator
class FwOp(AttentionFwOpBase):
"""Operator that computes memory-efficient attention using \
`Tri Dao's <https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py>`_ \
implementation, based on
`Phil Tillet's code <https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py>`_
"""

OPERATOR = triton_flash_forward
SUPPORTED_DEVICES = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
@@ -101,6 +107,8 @@ def apply(

@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__

OPERATOR = triton_flash_backward
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
