
sub-quadratic attention #1

Open. Wants to merge 35 commits into base: main.

Commits (35)
c810c32
initial commit of sub-quadratic attention source from https://github.…
Birch-san Dec 26, 2022
c9b3b9f
invoke efficient_dot_product_attention(). not currently giving correc…
Birch-san Dec 26, 2022
70dc50d
provide a way to skip checkpointing
Birch-san Dec 26, 2022
c794f0b
MPS fixes; now working
Birch-san Dec 26, 2022
04a5cbe
eliminate all einsums. assume 3D tensor [batch * num_heads, tokens, c…
Birch-san Dec 26, 2022
b44fa12
remove the bits that I broke in the pursuit of speed (mask, bias, wei…
Birch-san Dec 26, 2022
8694703
clarify comment; verified that upcast_attention is indeed still helpf…
Birch-san Dec 26, 2022
5bfe96d
add TODO about softmax
Birch-san Dec 26, 2022
da8901b
typings
Birch-san Dec 26, 2022
0c4d82f
simplify protocols
Birch-san Dec 26, 2022
c5e8e31
remove unused
Birch-san Dec 26, 2022
b16edc9
simplify protocol
Birch-san Dec 26, 2022
b7fc3a8
fix tensor shape destructuring
Birch-san Dec 26, 2022
8f003c2
simplify dynamic_slice
Birch-san Dec 26, 2022
1334670
simplify chunk scanning
Birch-san Dec 26, 2022
0676c13
inline sole use of map_pt function
Birch-san Dec 26, 2022
264dfb7
simplify
Birch-san Dec 26, 2022
205f55b
no longer using original utilities from memory-efficient-attention re…
Birch-san Dec 26, 2022
1880c0e
fix query slicing
Birch-san Dec 26, 2022
8603c30
fix kv chunking
Birch-san Dec 26, 2022
96e0d8c
simplify dynamic slicing
Birch-san Dec 26, 2022
63ca66d
removed bias, mask, weights, calc_fn, and the conditions controlling …
Birch-san Dec 26, 2022
f4c0bf4
device arg fix no longer included
Birch-san Dec 26, 2022
624123f
simplify
Birch-san Dec 26, 2022
5b92dab
clarify attributions now that algorithm has been substantially rewritten
Birch-san Dec 26, 2022
60f0a5e
add chunk_threshold_bytes to let you specify your safe memory limit, …
Birch-san Dec 28, 2022
48db711
fast path for when we're just attention-slicing (i.e. chunking query …
Birch-san Dec 28, 2022
ef20fb9
default kv_chunk_size was meant to be sqrt() of global key size, not …
Birch-san Dec 28, 2022
69a8d2e
remove debug notes
Birch-san Dec 28, 2022
db25934
explain kv fast-path
Birch-san Dec 28, 2022
7aa8bac
add fast-path for "1 query chunk"
Birch-san Dec 28, 2022
59002c3
move kv_chunk_size_min concern to callsite, since if caller knows fin…
Birch-san Dec 28, 2022
a3152d8
Revert "move kv_chunk_size_min concern to callsite (1c4f10748e31d1851…
Birch-san Dec 28, 2022
0eafb95
de-duplicate fast-path for "matmul < quota". we can just ask for ever…
Birch-san Dec 28, 2022
9dc6822
pre-transpose key, rather than transposing it then undoing the transp…
Birch-san Dec 30, 2022
114 changes: 113 additions & 1 deletion src/diffusers/models/cross_attention.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from .sub_quadratic_attention import efficient_dot_product_attention

import torch
import torch.nn.functional as F
from torch import nn
from torch import nn, Tensor

from ..utils.import_utils import is_xformers_available

@@ -145,6 +146,29 @@ def set_attention_slice(self, slice_size):
processor = CrossAttnProcessor()

self.set_processor(processor)

def set_subquadratic_attention(
self,
query_chunk_size: int = 1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
chunk_threshold_bytes: Optional[int] = None,
):
r"""
Args:
query_chunk_size (`int`, *optional*, defaults to `1024`)
kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key_tokens) is used.
kv_chunk_size_min (`int`, *optional*, defaults to `None`): only considered when `kv_chunk_size is None`. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
chunk_threshold_bytes (`int`, *optional*, defaults to `None`): if defined, chunk only when the self-attention matmul would allocate more bytes than this. Whenever traditional attention fits into memory, prefer it: the unchunked algorithm is faster.
"""
processor = SubQuadraticCrossAttnProcessor(
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
chunk_threshold_bytes=chunk_threshold_bytes,
)

self.set_processor(processor)

def set_processor(self, processor: "AttnProcessor"):
self.processor = processor
@@ -236,6 +260,94 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No

return hidden_states

class SubQuadraticCrossAttnProcessor:
query_chunk_size: int
kv_chunk_size: Optional[int]
kv_chunk_size_min: Optional[int]
chunk_threshold_bytes: Optional[int]
def __init__(
self,
query_chunk_size: int = 1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
chunk_threshold_bytes: Optional[int] = None,
):
r"""
Args:
query_chunk_size (`int`, *optional*, defaults to `1024`)
kv_chunk_size (`int`, *optional*, defaults to `None`): if None, sqrt(key_tokens) is used.
kv_chunk_size_min (`int`, *optional*, defaults to `None`): only considered when `kv_chunk_size is None`. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
chunk_threshold_bytes (`int`, *optional*, defaults to `None`): if defined, chunk only when the self-attention matmul would allocate more bytes than this. Whenever traditional attention fits into memory, prefer it: the unchunked algorithm is faster.
"""
self.query_chunk_size = query_chunk_size
self.kv_chunk_size = kv_chunk_size
self.kv_chunk_size_min = kv_chunk_size_min
self.chunk_threshold_bytes = chunk_threshold_bytes

def __call__(
self,
attn: CrossAttention,
hidden_states: Tensor,
encoder_hidden_states: Optional[Tensor]=None,
attention_mask: Optional[Tensor]=None,
):
encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states

assert attention_mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
# I don't know what test case can be used to determine whether softmax is computed at sufficient bit-width,
# but sub-quadratic attention has a pretty bespoke softmax (defers computation of the denominator) so this needs some thought.
assert not attn.upcast_softmax or torch.finfo(hidden_states.dtype).bits >= 32, "upcast_softmax was requested, but is not implemented"

query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query = query.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
key_t = key.transpose(1,2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1)
del key
value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)

dtype = query.dtype
# TODO: do we still need to do *everything* in float32, given how we delay the division?
# TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
if attn.upcast_attention:
query = query.float()
key_t = key_t.float()

bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

query_chunk_size = self.query_chunk_size
kv_chunk_size = self.kv_chunk_size

if self.chunk_threshold_bytes is not None and qk_matmul_size_bytes <= self.chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens

hidden_states = efficient_dot_product_attention(
query,
key_t,
value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=self.kv_chunk_size_min,
use_checkpoint=attn.training,
)

hidden_states = hidden_states.to(dtype)

hidden_states = hidden_states.unflatten(0, (-1, attn.heads)).transpose(1,2).flatten(start_dim=2)

out_proj, dropout = attn.to_out
hidden_states = out_proj(hidden_states)
hidden_states = dropout(hidden_states)

return hidden_states


class CrossAttnAddedKVProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
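For reference, here is a minimal usage sketch of the new `set_subquadratic_attention()` API added above. It is illustrative only, not part of the diff: the pipeline class, checkpoint name, and module-traversal pattern are assumptions about a typical diffusers setup of that era.

```python
# Hypothetical usage sketch: enable sub-quadratic attention on every
# CrossAttention module of a Stable Diffusion UNet (not part of this PR).
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import CrossAttention

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

for module in pipe.unet.modules():
    if isinstance(module, CrossAttention):
        module.set_subquadratic_attention(
            query_chunk_size=1024,       # queries processed per chunk
            kv_chunk_size=None,          # None -> sqrt(key_tokens)
            kv_chunk_size_min=None,      # floor for the derived kv chunk size
            chunk_threshold_bytes=None,  # None -> always chunk
        )

image = pipe("a photo of an astronaut riding a horse").images[0]
```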
194 changes: 194 additions & 0 deletions src/diffusers/models/sub_quadratic_attention.py
@@ -0,0 +1,194 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# unspecified
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
# "Self-attention Does Not Need O(n^2) Memory":
# https://arxiv.org/abs/2112.05682v2

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, Protocol, List
from ..utils.dynamic_slice import dynamic_slice

class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor

class SummarizeChunk(Protocol):
@staticmethod
def __call__(
query: Tensor,
key_t: Tensor,
value: Tensor,
) -> AttnChunk: ...

class ComputeQueryChunkAttn(Protocol):
@staticmethod
def __call__(
query: Tensor,
key_t: Tensor,
value: Tensor,
) -> Tensor: ...

def _summarize_chunk(
query: Tensor,
key_t: Tensor,
value: Tensor,
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
[Review comment] Shouldn't torch.zeros() be used here instead of torch.empty()?

[Birch-san (Owner, Author)] nope; it's actually an unused tensor (because beta=0), so we want whatever's the cheapest thing that passes the parameter validation. unfortunately PyTorch complains if you pass None. bad API design.
query,
key_t,
alpha=scale,
beta=0,
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
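The review thread above asks why `torch.empty()` rather than `torch.zeros()` is passed as `baddbmm`'s input tensor; since `beta=0`, PyTorch documents that the input's contents are ignored. A quick check of that behaviour (stock PyTorch only, not part of the diff):

```python
# With beta=0, torch.baddbmm ignores the `input` argument, so an
# uninitialized torch.empty() tensor is safe (and cheap) to pass.
import torch

q = torch.randn(2, 4, 8)     # [batch, q_tokens, channels]
k_t = torch.randn(2, 8, 16)  # [batch, channels, k_tokens]
scale = 8 ** -0.5

a = torch.baddbmm(torch.empty(1, 1, 1), q, k_t, alpha=scale, beta=0)
b = torch.bmm(q, k_t) * scale

print(torch.allclose(a, b))  # expected: True
```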

def _query_chunk_attention(
query: Tensor,
key_t: Tensor,
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape

def chunk_scanner(chunk_idx: int) -> AttnChunk:
key_chunk = dynamic_slice(
key_t,
(0, 0, chunk_idx),
(batch_x_heads, k_channels_per_head, kv_chunk_size)
)
value_chunk = dynamic_slice(
value,
(0, chunk_idx, 0),
(batch_x_heads, kv_chunk_size, v_channels_per_head)
)
return summarize_chunk(query, key_chunk, value_chunk)

chunks: List[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
chunk_values, chunk_weights, chunk_max = acc_chunk

global_max, _ = torch.max(chunk_max, 0, keepdim=True)
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= torch.unsqueeze(max_diffs, -1)
chunk_weights *= max_diffs

all_values = chunk_values.sum(dim=0)
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
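The merge above is the streaming log-sum-exp trick from the paper: each key/value chunk reports a partial numerator (`exp_values`), a partial denominator (`exp_weights_sum`), and its row-wise max, and rescaling each by exp(m_c - m) makes the chunks combinable exactly. A sketch of the identity being used (notation mine, not taken from the diff):

```latex
% s_i: attention score for key i, v_i: value row i,
% m_c: max over i in chunk c, m = \max_c m_c (global max)
\mathrm{out} \;=\; \mathrm{softmax}(s)\,V
  \;=\; \frac{\sum_{c} e^{\,m_c - m} \sum_{i \in c} e^{\,s_i - m_c}\, v_i}
             {\sum_{c} e^{\,m_c - m} \sum_{i \in c} e^{\,s_i - m_c}}
```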

# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
key_t: Tensor,
value: Tensor,
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key_t,
alpha=scale,
beta=0,
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice

class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk

def efficient_dot_product_attention(
query: Tensor,
key_t: Tensor,
value: Tensor,
query_chunk_size: int = 1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint: bool = True,
) -> Tensor:
"""Computes efficient dot-product attention given query, transposed key, and value.
This is the memory-efficient attention presented in
https://arxiv.org/abs/2112.05682v2, which comes with O(sqrt(n)) memory requirements.
Args:
query: queries for calculating attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
key_t: keys for calculating attention with shape of
`[batch * num_heads, channels_per_head, tokens]`.
value: values to be used in attention with shape of
`[batch * num_heads, tokens, channels_per_head]`.
query_chunk_size: int: query chunks size
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
"""
batch_x_heads, q_tokens, q_channels_per_head = query.shape
_, _, k_tokens = key_t.shape
scale = q_channels_per_head ** -0.5

kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
if kv_chunk_size_min is not None:
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

def get_query_chunk(chunk_idx: int) -> Tensor:
return dynamic_slice(
query,
(0, chunk_idx, 0),
(batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
)

summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
_get_attention_scores_no_kv_chunking,
scale=scale
) if k_tokens <= kv_chunk_size else (
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
partial(
_query_chunk_attention,
kv_chunk_size=kv_chunk_size,
summarize_chunk=summarize_chunk,
)
)

if q_tokens <= query_chunk_size:
# fast-path for when there's just 1 query chunk
return compute_query_chunk_attn(
query=query,
key_t=key_t,
value=value,
)

# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
res = torch.cat([
compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key_t=key_t,
value=value,
) for i in range(math.ceil(q_tokens / query_chunk_size))
], dim=1)
return res
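A small self-check of the function above, assuming the new module is importable as `diffusers.models.sub_quadratic_attention` (the tensor sizes are arbitrary; this merely compares the chunked result against naive softmax attention and is not part of the diff):

```python
# Sanity-check sketch: chunked attention should match naive softmax
# attention to within floating-point tolerance.
import torch
from diffusers.models.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, tokens, channels = 16, 1024, 40
q = torch.randn(batch_x_heads, tokens, channels)
k_t = torch.randn(batch_x_heads, channels, tokens)
v = torch.randn(batch_x_heads, tokens, channels)

chunked = efficient_dot_product_attention(
    q, k_t, v,
    query_chunk_size=128,  # force several query chunks
    kv_chunk_size=None,    # None -> sqrt(key_tokens)
    use_checkpoint=False,  # inference
)

scale = channels ** -0.5
naive = torch.bmm((torch.bmm(q, k_t) * scale).softmax(dim=-1), v)

print(torch.allclose(chunked, naive, atol=1e-4))  # expected: True
```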
10 changes: 10 additions & 0 deletions src/diffusers/utils/dynamic_slice.py
@@ -0,0 +1,10 @@
from torch import Tensor
from typing import List

def dynamic_slice(
x: Tensor,
starts: List[int],
sizes: List[int],
) -> Tensor:
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
[Birch-san (Owner, Author)] this attempts to implement jax.lax.dynamic_slice(), but hey is this literally just torch.narrow()?

[brkirch, Jan 5, 2023] Yeah that works also:
brkirch/stable-diffusion-webui@b119815

No notable performance difference that I observed, but it's probably slightly more efficient nonetheless.
return x[slicing]
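As the thread above suggests, the same slicing can be expressed with `torch.narrow()`, one call per dimension. A quick equivalence check (illustrative only; assumes the module is importable as `diffusers.utils.dynamic_slice`):

```python
# Illustrative check (not part of the diff): dynamic_slice() matches a
# chain of torch.narrow() calls applied dimension by dimension.
import torch
from diffusers.utils.dynamic_slice import dynamic_slice

def narrow_slice(x, starts, sizes):
    for dim, (start, size) in enumerate(zip(starts, sizes)):
        x = x.narrow(dim, start, size)
    return x

x = torch.arange(2 * 8 * 16).reshape(2, 8, 16)
a = dynamic_slice(x, (0, 3, 4), (2, 4, 8))
b = narrow_slice(x, (0, 3, 4), (2, 4, 8))
print(torch.equal(a, b))  # expected: True
```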