AttentionBias v2
See also facebookresearch#640

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with facebookresearch#587)
- Block-diagonal (used to support a different sequence length per batch element)

We need to rename "LowerTriangularMask": when it is added to a block-diagonal mask, the result is no longer causal. In the diagrams below, `0` marks positions that can be attended and `*` marks positions that are masked out:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0
# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *
# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *
# A + causal (what most people want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
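With the new classes in this diff, the "A + causal" pattern can be built directly; a small sketch (using `materialize` purely as a debugging aid):

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# Two batch elements: block 1 is 2 queries x 3 keys, block 2 is 2 queries x 2 keys.
mask = BlockDiagonalMask.from_seqlens(q_seqlen=[2, 2], kv_seqlen=[3, 2]).make_causal()

# Prints the "A + causal" pattern above (0 = can attend, -inf = masked).
print(mask.materialize((4, 5)))
```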

ghstack-source-id: 44740f71132fa76226fd4c559cc3f09732ff139b
Pull Request resolved: https://github.com/fairinternal/xformers/pull/435

__original_commit__ = fairinternal/xformers@be55fcd21c5dd621831245c5995e1c6fb49d9b77
danthe3rd authored and xFormers Bot committed Jan 19, 2023
1 parent 8dab253 commit d23d04e
Showing 11 changed files with 651 additions and 562 deletions.
455 changes: 221 additions & 234 deletions tests/test_mem_eff_attention.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/utils.py
```diff
@@ -13,7 +13,7 @@ def assert_allclose(
     atol: float = 1e-8,
     rtol: float = 1e-5,
 ) -> None:
-    assert out.shape == ref.shape
+    assert out.shape == ref.shape, f"Shape: {out.shape} (expected: {ref.shape})"
     flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
     max_pos = flatten_diff.argmax()
     max_diff = flatten_diff[max_pos]
```
6 changes: 5 additions & 1 deletion xformers/ops/__init__.py
```diff
@@ -6,7 +6,7 @@
 import torch

 from .fmha import (
-    AttentionMask,
+    AttentionBias,
     AttentionOp,
     AttentionOpBase,
     AttentionOpDispatch,
@@ -33,6 +33,9 @@
 )
 from .unbind import get_stack_strides, stack_or_none, unbind

+# BW compatibility
+AttentionMask = AttentionBias
+

 def masked_matmul(a, b, mask=None):
     if torch.overrides.has_torch_function((a, b, mask)):
@@ -57,6 +60,7 @@ def masked_matmul(a, b, mask=None):


 __all__ = [
+    "AttentionBias",
     "AttentionMask",
     "AttentionOp",
     "AttentionOpBase",
```
22 changes: 10 additions & 12 deletions xformers/ops/fmha/__init__.py
```diff
@@ -8,21 +8,19 @@
 import torch

 from . import cutlass, flash, small_k, triton
+from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask
 from .common import (
     AttentionBwOpBase,
     AttentionFwOpBase,
-    AttentionMask,
     AttentionOp,
     AttentionOpBase,
     AttentionOpDispatch,
     Context,
     Gradients,
     Inputs,
-    LowerTriangularMask,
     bmk2bmhk,
 )
 from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
-from .tensor_with_seqlen import TensorWithSeqLen  # noqa

 MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
 MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
@@ -118,7 +116,7 @@ def memory_efficient_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_bias: Optional[Union[torch.Tensor, AttentionMask]] = None,
+    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
     p: float = 0.0,
     scale: Optional[float] = None,
     *,
@@ -206,7 +204,7 @@ def memory_efficient_attention_forward(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_bias: Optional[Union[torch.Tensor, AttentionMask]] = None,
+    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
     p: float = 0.0,
     scale: Optional[float] = None,
     *,
@@ -225,7 +223,7 @@ def memory_efficient_attention_forward_requires_grad(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_bias: Optional[Union[torch.Tensor, AttentionMask]] = None,
+    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
     p: float = 0.0,
     scale: Optional[float] = None,
     *,
@@ -257,7 +255,7 @@ def memory_efficient_attention_backward(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    attn_bias: Optional[Union[torch.Tensor, AttentionMask]] = None,
+    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
     p: float = 0.0,
     scale: Optional[float] = None,
     *,
@@ -349,18 +347,18 @@ def _memory_efficient_attention_backward(
         ctx.lse.ndim != 3
         # Dim 0
         or (
-            not isinstance(inp.query, TensorWithSeqLen)
+            not isinstance(inp.attn_bias, BlockDiagonalMask)
             and ctx.lse.shape[0] != inp.query.shape[0]
         )
         or (
-            isinstance(inp.query, TensorWithSeqLen)
-            and ctx.lse.shape[0] != inp.query.cu_seqlen.shape[0] - 1
+            isinstance(inp.attn_bias, BlockDiagonalMask)
+            and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.cu_seqlen.shape[0] - 1
         )
         # Dim 1
         or ctx.lse.shape[1] != inp.query.shape[2]
         # Dim 2
         or (
-            not isinstance(inp.query, TensorWithSeqLen)
+            not isinstance(inp.attn_bias, BlockDiagonalMask)
             and ctx.lse.shape[2] < inp.query.shape[1]
         )
     ):
@@ -387,7 +385,7 @@ def _memory_efficient_attention_backward(


 __all__ = [
-    "AttentionMask",
+    "AttentionBias",
     "AttentionOp",
     "AttentionOpBase",
     "AttentionOpDispatch",
```
265 changes: 265 additions & 0 deletions xformers/ops/fmha/attn_bias.py
```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math
from dataclasses import dataclass
from typing import Iterable, List, Optional, Sequence, Tuple, Union

import torch


class AttentionBias:
    """Base class for a custom bias that can be applied \
        in :attr:`xformers.ops.memory_efficient_attention`.
    When using an :attr:`xformers.ops.AttentionBias`
    instead of a :attr:`torch.Tensor`, the mask matrix does
    not need to be materialized, and can be
    hardcoded into some kernels for better performance.
    See:
    - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
    - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
    - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
    - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`
    """

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """
        Materializes the bias as a `torch.Tensor`. This is very slow
        and we don't attempt to make it fast. Only use for debugging/testing.
        Shape should be like `[*, q_seqlen, k_seqlen]`
        """
        raise NotImplementedError()
```
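As the updated signatures above indicate, `attn_bias` in `memory_efficient_attention` now accepts either a plain `torch.Tensor` or any `AttentionBias` subclass. A minimal sketch of the non-materialized path (shapes, dtype and device are illustrative; not every op supports every bias type):

```python
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import LowerTriangularMask

# Illustrative sizes: batch, seqlen, heads, head dim.
B, M, H, K = 2, 16, 4, 32
q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
k = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)

# The causal mask is never materialized: the kernel applies it directly.
out = xops.memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
```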


```python
class LowerTriangularMask(AttentionBias):
    """A lower-triangular (aka causal) mask"""

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        create_as = dtype if dtype is not torch.bfloat16 else torch.float32
        tensor = torch.full(  # type: ignore
            shape,
            dtype=create_as,
            fill_value=float("-inf"),
            device=device,
        )
        return torch.triu(tensor, diagonal=1).to(dtype)  # type: ignore

    def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
        return LowerTriangularMaskWithTensorBias(bias)


class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
    """A lower-triangular (aka causal) mask with an additive bias"""

    def __init__(self, bias: torch.Tensor) -> None:
        self._bias = bias

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        return super().materialize(shape, dtype=dtype, device=device) + self._bias
```
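Here `materialize` builds the additive mask explicitly (zeros on and below the diagonal, `-inf` above, with the bfloat16 case created in float32 and cast back), and `add_bias` layers an arbitrary additive bias on top. A small sketch, intended only for inspection/testing:

```python
import torch
from xformers.ops.fmha.attn_bias import LowerTriangularMask

causal = LowerTriangularMask()
# 4x4 causal mask: position i can attend to positions j <= i.
print(causal.materialize((4, 4)))

# Layer an additive bias (e.g. relative-position scores) on top of the causal mask.
# For materialize(), the bias must broadcast against the [*, q_seqlen, k_seqlen] shape.
bias = torch.randn(4, 4)
print(causal.add_bias(bias).materialize((4, 4)))
```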


```python
@dataclass
class _SeqLenInfo:
    max_seqlen: int
    cu_seqlen: torch.Tensor
    cu_seqlen_py: List[int]

    @classmethod
    def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
        """
        Input tensors are assumed to be in shape [B, M, *]
        """
        cu_seqlen_py = [0]
        max_seqlen = -1
        for seqlen in seqlens:
            max_seqlen = max(max_seqlen, seqlen)
            cu_seqlen_py.append(cu_seqlen_py[len(cu_seqlen_py) - 1] + seqlen)
        cu_seqlen = torch.tensor(cu_seqlen_py, dtype=torch.int32)
        return cls(
            max_seqlen=max_seqlen, cu_seqlen=cu_seqlen, cu_seqlen_py=cu_seqlen_py
        )

    def split(
        self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
    ) -> List[torch.Tensor]:
        if self.cu_seqlen_py[-1] != x.shape[1] or x.shape[0] != 1:
            raise ValueError(
                f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
                f"(B, M, *) with B=1 and M={self.cu_seqlen_py[-1]}\n"
                f" cu_seqlen: {self.cu_seqlen_py}"
            )
        if batch_sizes is None:
            batch_sizes = [1] * (len(self.cu_seqlen_py) - 1)
        split_chunks = []
        it = 0
        for batch_size in batch_sizes:
            split_chunks.append(
                self.cu_seqlen_py[it + batch_size] - self.cu_seqlen_py[it]
            )
            it += batch_size
        return [
            tensor.reshape([bs, -1, *tensor.shape[2:]])
            for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
        ]
```
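`_SeqLenInfo` is an internal helper, but it is where the cumulative sequence-length offsets (`cu_seqlen`) used by the variable-seqlen kernels come from; a quick illustration of `from_seqlens`:

```python
from xformers.ops.fmha.attn_bias import _SeqLenInfo

info = _SeqLenInfo.from_seqlens([2, 3, 4])
print(info.max_seqlen)    # 4
print(info.cu_seqlen_py)  # [0, 2, 5, 9] - start offset of each sequence in the packed dim
```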


```python
@dataclass
class BlockDiagonalMask(AttentionBias):
    """A block-diagonal mask - can be used to handle batch elements with different sequence length"""

    q_seqinfo: _SeqLenInfo
    k_seqinfo: _SeqLenInfo
    _batch_sizes: Optional[Sequence[int]] = None

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        return torch.zeros(
            shape,
            dtype=dtype,
            device=device,
        )

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        assert shape[-1] == self.k_seqinfo.cu_seqlen_py[-1]
        assert shape[-2] == self.q_seqinfo.cu_seqlen_py[-1]
        mask = torch.empty(shape[-2:], dtype=dtype, device=device)
        mask.fill_(-math.inf)
        for q_start, q_end, k_start, k_end in zip(
            self.q_seqinfo.cu_seqlen_py,
            self.q_seqinfo.cu_seqlen_py[1:],
            self.k_seqinfo.cu_seqlen_py,
            self.k_seqinfo.cu_seqlen_py[1:],
        ):
            mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
                (q_end - q_start, k_end - k_start),
                dtype=dtype,
                device=device,
            )
        for _ in range(len(shape) - 2):
            mask = mask.unsqueeze(0)
        return mask.expand(shape)

    @classmethod
    def from_seqlens(
        cls,
        q_seqlen: Sequence[int],
        kv_seqlen: Optional[Sequence[int]] = None,
    ) -> "BlockDiagonalMask":
        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
        q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
        if kv_seqlen is None or q_seqlen == kv_seqlen:
            k_seqinfo = q_seqinfo
        else:
            k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
        return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)

    @classmethod
    def from_tensor_list(
        cls,
        tensors: Sequence[torch.Tensor],
    ) -> Tuple["BlockDiagonalMask", torch.Tensor]:
        batch_sizes = [tensor.shape[0] for tensor in tensors]
        seqlens = []
        for x in tensors:
            for _ in range(x.shape[0]):
                seqlens.append(x.shape[1])
        block_diag = cls.from_seqlens(seqlens)
        block_diag._batch_sizes = batch_sizes
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
        concat_tensors = torch.cat(tensors_bs1, dim=1)
        return block_diag, concat_tensors

    @classmethod
    def from_tensor_lists_qkv(
        cls,
        tensors_q: Sequence[torch.Tensor],
        tensors_k: Sequence[torch.Tensor],
        tensors_v: Optional[Sequence[torch.Tensor]] = None,
    ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        assert len(tensors_q) == len(tensors_k)
        assert tensors_v is None or len(tensors_v) == len(tensors_q)
        batch_sizes = [tensor.shape[0] for tensor in tensors_q]
        q_seqlens, kv_seqlens = [], []
        for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
            assert q.shape[0] == k.shape[0]
            q_seqlens += [q.shape[1]] * q.shape[0]
            kv_seqlens += [k.shape[1]] * k.shape[0]
            assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
        block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
        block_diag._batch_sizes = batch_sizes
        return (
            block_diag,
            torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
            torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
            torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
            if tensors_v is not None
            else None,
        )

    def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
        return self.q_seqinfo.split(tensor, self._batch_sizes)

    def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
        return self.k_seqinfo.split(tensor, self._batch_sizes)

    def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
        assert self.q_seqinfo is self.k_seqinfo
        return self.q_seqinfo.split(tensor, self._batch_sizes)

    def make_causal(self) -> "BlockDiagonalCausalMask":
        """Makes each block causal"""
        return BlockDiagonalCausalMask(
            q_seqinfo=self.q_seqinfo,
            k_seqinfo=self.k_seqinfo,
            _batch_sizes=self._batch_sizes,
        )
```
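`from_tensor_list` packs tensors with different sequence lengths into a single `(1, total_len, ...)` tensor and records the original batch sizes so that `split` can undo the packing; a sketch with illustrative shapes:

```python
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

H, K = 4, 32  # heads, head dim (illustrative)
a = torch.randn(2, 3, H, K)  # 2 sequences of length 3
b = torch.randn(1, 5, H, K)  # 1 sequence of length 5

mask, packed = BlockDiagonalMask.from_tensor_list([a, b])
print(packed.shape)                 # torch.Size([1, 11, 4, 32])
print(mask.q_seqinfo.cu_seqlen_py)  # [0, 3, 6, 11]

# split() restores the original batch shapes.
a2, b2 = mask.split(packed)
print(a2.shape, b2.shape)           # torch.Size([2, 3, 4, 32]) torch.Size([1, 5, 4, 32])
```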


```python
@dataclass
class BlockDiagonalCausalMask(BlockDiagonalMask):
    """Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal"""

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        return LowerTriangularMask().materialize(
            shape,
            dtype=dtype,
            device=device,
        )
```
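Putting it together, this is the intended replacement for the `TensorWithSeqLen` path removed above: pack a variable-length batch, make each block causal, and run attention once over the packed tensor. A hedged sketch (CUDA device, fp16 and the chosen shapes are illustrative, and op support for this bias may vary):

```python
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

H, K = 4, 32
seqs = [
    torch.randn(2, 3, H, K, device="cuda", dtype=torch.float16),  # 2 sequences of length 3
    torch.randn(1, 5, H, K, device="cuda", dtype=torch.float16),  # 1 sequence of length 5
]

# Pack into a single (1, total_len, H, K) tensor and make each block causal:
# this is the "A + causal" pattern from the summary above.
mask, packed = BlockDiagonalMask.from_tensor_list(seqs)
attn_bias = mask.make_causal()

# Self-attention over the packed batch; tokens never attend across blocks.
out = xops.memory_efficient_attention(packed, packed, packed, attn_bias=attn_bias)
out_per_seq = mask.split(out)  # back to the original [batch, seqlen, H, K] shapes
```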
