See also #640

Adds support for combinations of different sorts of biases:

- Causal
- Bias (coming with #587)
- Block-diagonal (used for a different seqlen per batch element)

We need to rename "LowerTriangularMask" because, when added to a block-diagonal mask, it is no longer causal:

```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0

# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *

# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *

# A + causal (what most people want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
danthe3rd committed on Jan 13, 2023
1 parent 6a988dd · commit f6f8b71
Showing 1 changed file with 119 additions and 0 deletions.
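A minimal usage sketch of the API this commit adds, combining the causal flag with a block-diagonal mask (the "A + causal" case from the commit message above). This is not part of the commit: the module name `attn_bias` and the toy sequence lengths 2 and 3 are assumptions for illustration.

```python
import torch

# Assumed import path for the file added in this diff.
from attn_bias import AttentionBias, BlockDiagonal, SeqLenInfo, TensorCreateInfo

# Two sequences of lengths 2 and 3 packed into a single 5x5 attention problem.
seqinfo = SeqLenInfo(
    max_seqlen=3,
    cu_seqlen=torch.tensor([0, 2, 5], dtype=torch.int32),
    cu_seqlen_py=[0, 2, 5],
)

bias = AttentionBias(
    causal=True,
    block_diag=BlockDiagonal(q_seqinfo=seqinfo, k_seqinfo=seqinfo),
    create_info=TensorCreateInfo(
        shape=(5, 5), dtype=torch.float32, device=torch.device("cpu")
    ),
)

# Slow reference materialization, for debugging/testing only: the causal and
# block-diagonal terms are summed into one [Mq, Mk] tensor.
print(bias.materialize())
```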
@@ -0,0 +1,119 @@
```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, replace
from typing import List, Optional, Tuple

import torch


@dataclass
class TensorCreateInfo:
    """Shape, dtype and device needed to materialize a mask as a tensor."""

    shape: Tuple[int, ...]
    dtype: torch.dtype
    device: torch.device

    def __post_init__(self) -> None:
        assert len(self.shape) in [2, 3]


@dataclass
class SeqLenInfo:
    """Cumulative sequence lengths, both on-device and as a Python list."""

    max_seqlen: int
    cu_seqlen: torch.Tensor
    cu_seqlen_py: List[int]


@dataclass
class BlockDiagonal:
    """Block-diagonal mask, used when batch elements have different seqlens."""

    q_seqinfo: SeqLenInfo
    k_seqinfo: SeqLenInfo

    def __post_init__(self) -> None:
        assert len(self.q_seqinfo.cu_seqlen_py) == len(self.k_seqinfo.cu_seqlen_py)

    def materialize(self, create_info: TensorCreateInfo) -> torch.Tensor:
        # One [q_len, k_len] block of ones per batch element.
        tensors: List[torch.Tensor] = []
        for q_start, q_end, k_start, k_end in zip(
            self.q_seqinfo.cu_seqlen_py,
            self.q_seqinfo.cu_seqlen_py[1:],
            self.k_seqinfo.cu_seqlen_py,
            self.k_seqinfo.cu_seqlen_py[1:],
        ):
            tensors.append(
                torch.ones(
                    [q_end - q_start, k_end - k_start],
                    dtype=create_info.dtype,
                    device=create_info.device,
                )
            )
        # `torch.block_diag` takes the tensors as positional arguments.
        mask = torch.block_diag(*tensors)
        # Add leading broadcast dimensions until we match `create_info.shape`.
        for _ in range(len(create_info.shape) - 2):
            mask = mask.unsqueeze(0)
        return mask


def _create_causal_mask(create_info: TensorCreateInfo) -> torch.Tensor:
    dtype = create_info.dtype
    # `torch.triu` does not support bf16 everywhere: build in fp32, cast back.
    create_as = dtype if dtype is not torch.bfloat16 else torch.float32
    tensor = torch.full(  # type: ignore
        create_info.shape,
        dtype=create_as,
        fill_value=float("-inf"),
        device=create_info.device,
    )
    return torch.triu(tensor, diagonal=1).to(dtype)  # type: ignore


@dataclass
class AttentionBias:
    causal: bool = False
    block_diag: Optional[BlockDiagonal] = None
    bias: Optional[torch.Tensor] = None
    create_info: Optional[TensorCreateInfo] = None

    def make_causal(self, causal: bool) -> "AttentionBias":
        return replace(self, causal=causal)

    def add_bias(self, bias: torch.Tensor) -> "AttentionBias":
        res = replace(self, bias=self.bias + bias if self.bias is not None else bias)
        assert res.bias is not None
        if res.create_info is None:
            res.create_info = TensorCreateInfo(
                res.bias.shape, dtype=res.bias.dtype, device=res.bias.device
            )
        return res

    def materialize(self) -> torch.Tensor:
        """
        Materializes the bias as a `torch.Tensor`. This is very slow,
        and we don't attempt to make it fast. Only use it for debugging/testing.
        The returned tensor has shape [..., Mq, Mk].
        """
        if self.create_info is None:
            raise ValueError(
                "Can't create a causal mask if no shape/dtype/device is provided"
            )
        tensor = torch.zeros(
            self.create_info.shape,
            dtype=self.create_info.dtype,
            device=self.create_info.device,
        )
        if self.bias is not None:
            tensor += self.bias
        if self.causal:
            tensor = tensor + _create_causal_mask(self.create_info)
        if self.block_diag is not None:
            tensor = tensor + self.block_diag.materialize(self.create_info)
        return tensor


def LowerTriangularMask(*args, **kwargs) -> AttentionBias:
    # The old name is kept as a function so that existing callers keep working.
    if len(args) + len(kwargs):
        return AttentionBias(causal=True, create_info=TensorCreateInfo(*args, **kwargs))
    return AttentionBias(causal=True)
```
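For comparison, a sketch of the renamed entry point: `LowerTriangularMask` survives as a function that returns a plain causal `AttentionBias`. As above, the `attn_bias` module name is an assumption, not part of the commit.

```python
import torch

from attn_bias import LowerTriangularMask  # assumed module name

# Given shape/dtype/device, the mask can be materialized for inspection:
# zeros on and below the diagonal, -inf strictly above it.
mask = LowerTriangularMask((4, 4), dtype=torch.float32, device=torch.device("cpu"))
print(mask.materialize())

# Adding an arbitrary bias (e.g. relative positions) keeps the causal part.
rel = torch.randn(4, 4)
print(mask.add_bias(rel).materialize())
```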