Skip to content

Commit

Permalink
RFC: AttentionBias v2
Browse files Browse the repository at this point in the history
See also #640

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added with a block-diagonal mask it's no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0
# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *
# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *
# A + causal (what most ppl want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
  • Loading branch information
danthe3rd committed Jan 13, 2023
1 parent 6a988dd commit f6f8b71
Showing 1 changed file with 119 additions and 0 deletions.
119 changes: 119 additions & 0 deletions xformers/ops/fmha/attn_bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass, replace
from typing import List, Optional, Tuple

import torch


@dataclass
class TensorCreateInfo:
shape: Tuple[int, ...]
dtype: torch.dtype
device: torch.device

def __post_init__(self) -> None:
assert len(self.shape) in [2, 3]


@dataclass
class SeqLenInfo:
max_seqlen: int
cu_seqlen: torch.Tensor
cu_seqlen_py: List[int]


@dataclass
class BlockDiagonal:
q_seqinfo: SeqLenInfo
k_seqinfo: SeqLenInfo

def __post_init__(self) -> None:
assert len(self.q_seqinfo.cu_seqlen_py) == len(self.k_seqinfo.cu_seqlen_py)

def materialize(self, create_info: TensorCreateInfo) -> torch.Tensor:
tensors: List[torch.Tensor] = []
for q_start, q_end, k_start, k_end in zip(
self.q_seqinfo.cu_seqlen_py,
self.q_seqinfo.cu_seqlen_py[1:],
self.k_seqinfo.cu_seqlen_py,
self.k_seqinfo.cu_seqlen_py[1:],
):
tensors.append(
torch.ones(
[q_end - q_start, k_end - k_start],
dtype=create_info.dtype,
device=create_info.device,
)
)
mask = torch.block_diag(tensors)
for _ in range(len(create_info.shape) - 2):
mask = mask.unsqueeze(0)
return mask


def _create_causal_mask(create_info: TensorCreateInfo) -> torch.Tensor:
dtype = create_info.dtype
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
tensor = torch.full( # type: ignore
create_info.shape,
dtype=create_as,
fill_value=float("-inf"),
device=create_info.device,
)
return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore


@dataclass
class AttentionBias:
causal: bool = False
block_diag: Optional[BlockDiagonal] = None
bias: Optional[torch.Tensor] = None
create_info: Optional[TensorCreateInfo] = None

def make_causal(self, causal: bool) -> "AttentionBias":
return replace(self, causal=causal)

def add_bias(self, bias: torch.Tensor) -> "AttentionBias":
res = replace(self, bias=self.bias + bias if self.bias is not None else bias)
assert res.bias is not None
if res.create_info is None:
res.create_info = TensorCreateInfo(
res.bias.shape, dtype=res.bias.dtype, device=res.bias.device
)
return res

def materialize(self) -> torch.Tensor:
"""
Materializes the bias as a `torch.Tensor`. This is very slow
and we don't attempt to make it fast. Only use for debugging/testing.
Returned tensor has shape [..., Mq, Mk]
"""
if self.create_info is None:
raise ValueError(
"Can't create a causal mask if no dimension/shape/device provided"
)
tensor = torch.zeros(
self.create_info.shape,
dtype=self.create_info.dtype,
device=self.create_info.device,
)
if self.bias is not None:
tensor += self.bias
if self.causal:
tensor = tensor + _create_causal_mask(self.create_info)
if self.block_diag is not None:
tensor = tensor + self.block_diag.materialize(self.create_info)
if tensor is None:
raise ValueError("This mask is empty")
return tensor


def LowerTriangularMask(*args, **kwargs):
if len(args) + len(kwargs):
return AttentionBias(causal=True, create_info=TensorCreateInfo(*args, **kwargs))
return AttentionBias(causal=True)

0 comments on commit f6f8b71

Please sign in to comment.