
Commit

pre-transpose key, rather than transposing it then undoing the transpose during the matmul
Birch-san committed Dec 30, 2022
1 parent 0eafb95 commit 3c92600
Showing 1 changed file with 5 additions and 4 deletions.
src/diffusers/models/cross_attention.py (5 additions, 4 deletions)
@@ -303,19 +303,20 @@ def __call__(
         value = attn.to_v(encoder_hidden_states)
 
         query = query.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
-        key = key.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
+        key_t = key.transpose(1,2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1)
+        del key
         value = value.unflatten(-1, (attn.heads, -1)).transpose(1,2).flatten(end_dim=1)
 
         dtype = query.dtype
         # TODO: do we still need to do *everything* in float32, given how we delay the division?
         # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
         if attn.upcast_attention:
             query = query.float()
-            key = key.float()
+            key_t = key_t.float()
 
         bytes_per_token = torch.finfo(query.dtype).bits//8
         batch_x_heads, q_tokens, _ = query.shape
-        _, k_tokens, _ = key.shape
+        _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
         query_chunk_size = self.query_chunk_size
@@ -329,7 +330,7 @@ def __call__(
 
         hidden_states = efficient_dot_product_attention(
             query,
-            key,
+            key_t,
             value,
             query_chunk_size=query_chunk_size,
             kv_chunk_size=kv_chunk_size,
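For context, here is a minimal standalone sketch (not code from this commit; the sizes batch, heads, head_dim, q_tokens, k_tokens are illustrative) of what the pre-transposed layout buys: the per-head key is built once as [batch*heads, head_dim, k_tokens], so the attention-score matmul can consume it directly instead of transposing the key back inside every call (or every chunk).

import torch

# Illustrative shapes (assumed for this sketch)
batch, heads, head_dim = 2, 8, 64
q_tokens, k_tokens = 77, 77
inner_dim = heads * head_dim

query = torch.randn(batch, q_tokens, inner_dim)
key = torch.randn(batch, k_tokens, inner_dim)

# per-head query: [batch*heads, q_tokens, head_dim]
q = query.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)

# old layout: [batch*heads, k_tokens, head_dim]; the matmul has to transpose it back
k = key.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
scores_old = q @ k.transpose(-1, -2)

# new layout: pre-transposed once to [batch*heads, head_dim, k_tokens],
# so the matmul consumes it as-is, with no per-call transpose
k_t = key.transpose(1, 2).unflatten(1, (heads, -1)).flatten(end_dim=1)
scores_new = q @ k_t

# same attention scores either way; only the key layout differs
assert torch.allclose(scores_old, scores_new, atol=1e-5)

Presumably efficient_dot_product_attention accepts the key in this already-transposed layout (that change is not shown in this diff); the sketch only demonstrates that the two layouts produce identical scores.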
