diff --git a/CHANGELOG.md b/CHANGELOG.md
index e25540ad00..dc63c5d573 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,10 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## TBD
 ### Fixed
+- fMHA: Fixed BW pass on Sm86/Sm89 GPUs when `K > 64` (RTX 3090, RTX 4090, A6000, ..) [facebookresearch/xformers#631]
+
 ### Added
-- Added tensor attn bias support to CUTLASS FlashAttention
-- Added tensor attn bias grad support to CUTLASS FlashAttention
-- Added dropout support to CUTLASS FlashAttention
+- fMHA/CUTLASS: Added tensor attn bias support [facebookresearch/xformers#587] - contribution from [@jfc4050](https://github.com/jfc4050)
+- fMHA/CUTLASS: Added tensor attn bias grad support [facebookresearch/xformers#587] - contribution from [@jfc4050](https://github.com/jfc4050)
+- fMHA/CUTLASS: Added dropout support [facebookresearch/xformers#587] - contribution from [@jfc4050](https://github.com/jfc4050)
+- fMHA: Added support for varying sequence lengths [facebookresearch/xformers#500]
 
 ## [0.0.16] - 2023-01-31
diff --git a/docs/source/_static/block_diag_bias.png b/docs/source/_static/block_diag_bias.png
new file mode 100644
index 0000000000..d18054a308
Binary files /dev/null and b/docs/source/_static/block_diag_bias.png differ
diff --git a/docs/source/_static/block_diag_cat_split.png b/docs/source/_static/block_diag_cat_split.png
new file mode 100644
index 0000000000..ddec7fb147
Binary files /dev/null and b/docs/source/_static/block_diag_cat_split.png differ
diff --git a/docs/source/components/ops.rst b/docs/source/components/ops.rst
index af51f41e76..5f98fdcb52 100644
--- a/docs/source/components/ops.rst
+++ b/docs/source/components/ops.rst
@@ -1,8 +1,47 @@
-Operators
-======================
+xFormers optimized operators
+============================================================
+
+Memory-efficient attention
+---------------------------
 
 .. automodule:: xformers.ops
+    :members: memory_efficient_attention, AttentionOpBase, AttentionBias
+    :show-inheritance:
+    :imported-members:
+
+
+Available implementations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: xformers.ops.fmha.cutlass
+    :members: FwOp, BwOp
+    :member-order: bysource
+
+.. automodule:: xformers.ops.fmha.flash
+    :members: FwOp, BwOp
+    :member-order: bysource
+
+.. automodule:: xformers.ops.fmha.triton
+    :members: FwOp, BwOp
+    :member-order: bysource
+
+.. automodule:: xformers.ops.fmha.small_k
+    :members: FwOp, BwOp
+    :member-order: bysource
+
+Attention biases
+~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: xformers.ops.fmha.attn_bias
     :members:
     :show-inheritance:
+    :member-order: bysource
+
+
+Non-autograd implementations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: xformers.ops.fmha
+    :members: memory_efficient_attention_forward, memory_efficient_attention_forward_requires_grad, memory_efficient_attention_backward
+    :show-inheritance:
     :imported-members:
     :member-order: bysource
diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index b1057dd1c6..3a1be8e694 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -1311,3 +1311,32 @@ def test_attn_bias_from_seqlens() -> None:
     out = bias.split(torch.randn([1, 3 + 5 + 1, 16]))
     assert len(out) == 3
     assert tuple(out[0].shape) == (1, 3, 16)
+
+
+@cuda_only
+def test_attn_bias_blockdiag_doc() -> None:
+    """IMPORTANT:
+    This is the example in the doc for `BlockDiagonalMask`.
+    If this example needs to be updated, please also update the doc
+    """
+    import torch
+
+    from xformers.ops import fmha
+
+    K = 16
+    dtype = torch.float16
+    device = "cuda"
+    list_x = [
+        torch.randn([1, 3, 1, K], dtype=dtype, device=device),
+        torch.randn([1, 6, 1, K], dtype=dtype, device=device),
+        torch.randn([1, 2, 1, K], dtype=dtype, device=device),
+    ]
+    attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
+
+    linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)  # type: ignore
+
+    q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
+    out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+    list_out = attn_bias.split(out)
+    print(list_out[0].shape)  # [1, 3, 1, K]
+    assert tuple(list_out[0].shape) == (1, 3, 1, K)
diff --git a/xformers/ops/__init__.py b/xformers/ops/__init__.py
index a1943cb72e..f0169774d4 100644
--- a/xformers/ops/__init__.py
+++ b/xformers/ops/__init__.py
@@ -60,6 +60,7 @@ def masked_matmul(a, b, mask=None):
 
 
 __all__ = [
+    "memory_efficient_attention",
     "AttentionBias",
     "AttentionMask",
     "AttentionOp",
@@ -71,7 +72,6 @@ def masked_matmul(a, b, mask=None):
     "MemoryEfficientAttentionFlashAttentionOp",
     "MemoryEfficientAttentionOp",
     "MemoryEfficientAttentionTritonFwdFlashBwOp",
-    "memory_efficient_attention",
     "memory_efficient_attention_backward",
     "memory_efficient_attention_forward",
     "memory_efficient_attention_forward_requires_grad",
diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py
index 762b14ea42..743fcfff82 100644
--- a/xformers/ops/fmha/__init__.py
+++ b/xformers/ops/fmha/__init__.py
@@ -176,17 +176,18 @@ def memory_efficient_attention(
 
     Raises:
         NotImplementedError: if there is no operator available to compute the MHA
+        ValueError: if inputs are invalid
 
     :parameter query: Tensor of shape ``[B, Mq, H, K]``
     :parameter key: Tensor of shape ``[B, Mkv, H, K]``
     :parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
     :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
-        For causal attention, use :attr:`xformers.ops.LowerTriangularMask`. \
-        This can also be a :attr:`torch.Tensor` for an arbitrary mask.
+        For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
+        This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
     :parameter p: Dropout probability. Disabled if set to ``0.0``
-    :parameter scale: The scale to query_state weights. If set to ``None``, the default \
+    :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
         scale (q.shape[-1]**-0.5) will be used.
-    :parameter op: The operator to use - see :attr:`xformers.ops.AttentionOpBase`. \
+    :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
         If set to ``None`` (recommended), xFormers \
         will dispatch to the best available operator, depending on the inputs \
         and options.
diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py
index 4458185538..5e5939083f 100644
--- a/xformers/ops/fmha/attn_bias.py
+++ b/xformers/ops/fmha/attn_bias.py
@@ -129,7 +129,40 @@ def split(
 
 @dataclass
 class BlockDiagonalMask(AttentionBias):
-    """A block-diagonal mask - can be used to handle batch elements with different sequence length"""
+    """
+    A block-diagonal mask that can be passed as ``attn_bias``
+    argument to :attr:`xformers.ops.memory_efficient_attention`.
+
+    .. figure:: /_static/block_diag_bias.png
+
+    This bias can be used to handle a batch of sequences of
+    different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`.
+
+    :Example:
+
+    .. code-block:: python
+
+        import torch
+        from xformers.ops import fmha
+
+        K = 16
+        dtype = torch.float16
+        device = "cuda"
+        list_x = [
+            torch.randn([1, 3, 1, K], dtype=dtype, device=device),
+            torch.randn([1, 6, 1, K], dtype=dtype, device=device),
+            torch.randn([1, 2, 1, K], dtype=dtype, device=device),
+        ]
+        attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
+        linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
+
+        q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
+        out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        list_out = attn_bias.split(out)
+        print(list_out[0].shape)  # [1, 3, 1, K]
+        assert tuple(list_out[0].shape) == (1, 3, 1, K)
+
+    """
 
     q_seqinfo: _SeqLenInfo
     k_seqinfo: _SeqLenInfo
@@ -153,6 +186,7 @@ def materialize(
         dtype: torch.dtype = torch.float32,
         device: Union[str, torch.device] = "cpu",
     ) -> torch.Tensor:
+        """Materialize the attention bias - for debugging & testing"""
         assert shape[-1] == self.k_seqinfo.cu_seqlen_py[-1]
         assert shape[-2] == self.q_seqinfo.cu_seqlen_py[-1]
         mask = torch.empty(shape[-2:], dtype=dtype, device=device)
@@ -178,6 +212,15 @@ def from_seqlens(
         q_seqlen: Sequence[int],
         kv_seqlen: Optional[Sequence[int]] = None,
     ) -> "BlockDiagonalMask":
+        """Creates a :attr:`BlockDiagonalMask` from lists of sequence lengths for query and key/value.
+
+        Args:
+            q_seqlen (Sequence[int]): List of sequence lengths for query tensors
+            kv_seqlen (Sequence[int], optional): List of sequence lengths for key/value. Defaults to ``q_seqlen``.
+
+        Returns:
+            BlockDiagonalMask
+        """
         assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
         q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
         if kv_seqlen is None or q_seqlen == kv_seqlen:
@@ -191,6 +234,23 @@ def from_tensor_list(
         cls,
         tensors: Sequence[torch.Tensor],
     ) -> Tuple["BlockDiagonalMask", torch.Tensor]:
+        """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
+        concatenated on the sequence length dimension
+
+        .. figure:: /_static/block_diag_cat_split.png
+
+        See also :attr:`BlockDiagonalMask.split` to split the returned
+        :attr:`torch.Tensor` back to a list of tensors of varying sequence length
+
+        Args:
+            tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
+                All tensors should have the same number of dimensions and the same batch size ``B``, but
+                they can have different sequence length ``M``.
+
+        Returns:
+            Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
+            along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
+        """
         batch_sizes = [tensor.shape[0] for tensor in tensors]
         seqlens = []
         for x in tensors:
@@ -236,6 +296,14 @@ def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
         return self.k_seqinfo.split(tensor, self._batch_sizes)
 
     def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
+        """The inverse operation of :attr:`BlockDiagonalMask.from_tensor_list`
+
+        Args:
+            tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
+
+        Returns:
+            Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
+        """
         assert self.q_seqinfo is self.k_seqinfo
         return self.q_seqinfo.split(tensor, self._batch_sizes)
diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py
index 3b172e2211..1df234f5d2 100644
--- a/xformers/ops/fmha/common.py
+++ b/xformers/ops/fmha/common.py
@@ -132,13 +132,14 @@ class AttentionOpBase(BaseOperator):
 
     See:
 
-    - :attr:`xformers.ops.MemoryEfficientAttentionOp`
-
-    - :attr:`xformers.ops.MemoryEfficientAttentionCutlassOp`
-
-    - :attr:`xformers.ops.MemoryEfficientAttentionFlashAttentionOp`
-
-    - :attr:`xformers.ops.MemoryEfficientAttentionCutlassFwdFlashBwOp`
+    - :attr:`xformers.ops.fmha.cutlass.FwOp`
+    - :attr:`xformers.ops.fmha.cutlass.BwOp`
+    - :attr:`xformers.ops.fmha.flash.FwOp`
+    - :attr:`xformers.ops.fmha.flash.BwOp`
+    - :attr:`xformers.ops.fmha.triton.FwOp`
+    - :attr:`xformers.ops.fmha.triton.BwOp`
+    - :attr:`xformers.ops.fmha.small_k.FwOp`
+    - :attr:`xformers.ops.fmha.small_k.BwOp`
     """
 
     OPERATOR: Any
diff --git a/xformers/ops/fmha/cutlass.py b/xformers/ops/fmha/cutlass.py
index d5181ac60c..3d37f6d830 100644
--- a/xformers/ops/fmha/cutlass.py
+++ b/xformers/ops/fmha/cutlass.py
@@ -175,6 +175,8 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
 
 @register_operator
 class BwOp(AttentionBwOpBase):
+    __doc__ = FwOp.__doc__
+
     OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index acb8585ed4..c2c32e9945 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -184,11 +184,6 @@ class FwOp(AttentionFwOpBase):
     """Operator that computes memory-efficient attention using \
         `Flash-Attention `_ \
         implementation.
-
-
-    This is a wrapper to make FlashAttention compatible with xformers's API
-    Most of this code was taken from:
-    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_interface.py
     """
 
     OPERATOR = get_operator("xformers_flash", "flash_fwd")
@@ -261,6 +256,8 @@ def apply(
 
 @register_operator
 class BwOp(AttentionBwOpBase):
+    __doc__ = FwOp.__doc__
+
     OPERATOR = get_operator("xformers_flash", "flash_bwd")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
     CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
diff --git a/xformers/ops/fmha/small_k.py b/xformers/ops/fmha/small_k.py
index 218ae85de8..a2192f1ba5 100644
--- a/xformers/ops/fmha/small_k.py
+++ b/xformers/ops/fmha/small_k.py
@@ -113,6 +113,8 @@ def apply(
 
 @register_operator
 class BwOp(AttentionBwOpBase):
+    __doc__ = FwOp.__doc__
+
     OPERATOR = get_xformers_operator("efficient_attention_backward_small_k")
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
     SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
diff --git a/xformers/ops/fmha/triton.py b/xformers/ops/fmha/triton.py
index e50a489eec..3b7ab54bcd 100644
--- a/xformers/ops/fmha/triton.py
+++ b/xformers/ops/fmha/triton.py
@@ -52,6 +52,12 @@ def _prepare_inputs(inp: Inputs) -> Inputs:
 
 @register_operator
 class FwOp(AttentionFwOpBase):
+    """Operator that computes memory-efficient attention using \
+        `Tri Dao's `_ \
+        implementation, based on
+        `Phil Tillet's code `_
+    """
+
     OPERATOR = triton_flash_forward
     SUPPORTED_DEVICES = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
@@ -101,6 +107,8 @@ def apply(
 
 @register_operator
 class BwOp(AttentionBwOpBase):
+    __doc__ = FwOp.__doc__
+
     OPERATOR = triton_flash_backward
     SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
     CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
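
For reference, a minimal usage sketch of the varying-sequence-length support introduced by this patch (not part of the diff itself). It relies only on APIs that appear above - fmha.BlockDiagonalMask.from_seqlens, BlockDiagonalMask.split, BlockDiagonalMask.materialize and fmha.memory_efficient_attention - but the sequence lengths, head count and head dimension are illustrative, not taken from the patch:

# Minimal sketch (assumed sizes): attention over a batch of sequences with
# different lengths, packed into a single [1, sum(seqlens), H, K] tensor.
import torch

from xformers.ops import fmha

K = 16               # head dimension (illustrative)
seqlens = [3, 5, 1]  # three sequences of different lengths
device, dtype = "cuda", torch.float16

total = sum(seqlens)
q = torch.randn([1, total, 1, K], device=device, dtype=dtype)
k = torch.randn([1, total, 1, K], device=device, dtype=dtype)
v = torch.randn([1, total, 1, K], device=device, dtype=dtype)

# Block-diagonal bias: each token only attends to tokens of its own sequence.
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)

out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)

# Recover one output tensor per original sequence.
list_out = attn_bias.split(out)
assert [t.shape[1] for t in list_out] == seqlens

# For debugging, the mask can be materialized as a dense tensor
# (its last two dimensions must match the total query / key lengths).
dense = attn_bias.materialize((total, total))
print(dense.shape)  # (sum(seqlens), sum(seqlens))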