Commit

attn refactoring
NouamaneTazi committed Sep 14, 2022
1 parent d30f968 commit 0c70c0e
Showing 1 changed file with 151 additions and 62 deletions.
213 changes: 151 additions & 62 deletions src/diffusers/models/attention.py
@@ -1,4 +1,5 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""

def __init__(
self,
channels,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
eps=1e-5,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
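
For illustration, a minimal usage sketch (the values below are hypothetical, not part of this commit). AttentionBlock treats each spatial position of a (batch, channels, height, width) feature map as a token and lets the positions attend to each other, so the output shape matches the input:

import torch

from diffusers.models.attention import AttentionBlock

attn = AttentionBlock(channels=64, num_head_channels=16)  # 64 / 16 = 4 heads
feature_map = torch.randn(2, 64, 32, 32)                  # (batch, channels, height, width)
out = attn(feature_map)
assert out.shape == feature_map.shape                     # the block is shape-preserving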
@@ -86,10 +95,26 @@ def forward(self, hidden_states):
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image
standard transformer action. Finally, reshape to image.
Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of channels in the context.
"""

def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
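
A similar hedged sketch for SpatialTransformer (illustrative values only; as written, the flatten step below reuses the input channel count, so the sketch picks n_heads * d_head == in_channels). The module projects the input, flattens the grid to a (batch, height * width, channels) sequence, applies the transformer blocks, and reshapes back:

import torch

from diffusers.models.attention import SpatialTransformer

transformer = SpatialTransformer(in_channels=64, n_heads=8, d_head=8, context_dim=768)
image_features = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
text_context = torch.randn(2, 77, 768)       # (batch, tokens, context_dim)
out = transformer(image_features, context=text_context)
assert out.shape == image_features.shape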
@@ -112,22 +137,44 @@ def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def forward(self, x, context=None):
def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
for block in self.transformer_blocks:
x = block(x, context=context)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
x = self.proj_out(x)
return x + x_in
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
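
The permute/reshape pair above is its own inverse; a quick standalone sketch of the round trip:

import torch

batch, channel, height, width = 2, 4, 3, 5
image = torch.arange(batch * channel * height * width).reshape(batch, channel, height, width)
tokens = image.permute(0, 2, 3, 1).reshape(batch, height * width, channel)    # image -> sequence
restored = tokens.reshape(batch, height, width, channel).permute(0, 3, 1, 2)  # sequence -> image
assert torch.equal(restored, image)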


class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
r"""
A basic Transformer block.
Parameters:
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
"""

def __init__(
self,
dim: int,
n_heads: int,
d_head: int,
dropout: float = 0.0,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -145,15 +192,30 @@ def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
def forward(self, hidden_states, context=None):
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
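
The forward pass is the standard pre-norm residual pattern: self-attention, then cross-attention (which, per the note in SpatialTransformer, falls back to self-attention when context is None), then the feed-forward network, each applied to a normalized input with a residual connection. A hedged usage sketch with illustrative sizes:

import torch

from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(dim=64, n_heads=8, d_head=8, context_dim=768)
tokens = torch.randn(2, 256, 64)      # (batch, sequence, dim)
context = torch.randn(2, 77, 768)     # optional cross-attention context
out = block(tokens, context=context)
assert out.shape == tokens.shape      # residual connections keep the shape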


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
r"""
A cross attention layer.
Parameters:
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

def __init__(
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
@@ -174,77 +236,104 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor3

def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor3
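
These two helpers are inverses: one folds the head dimension into the batch so attention can run per head, the other unfolds it. A standalone sketch of the shape bookkeeping (sizes are illustrative):

import torch

batch, seq_len, heads, dim_head = 2, 10, 8, 8
tensor = torch.randn(batch, seq_len, heads * dim_head)

# reshape_heads_to_batch_dim: (b, s, h * d) -> (b * h, s, d)
per_head = tensor.reshape(batch, seq_len, heads, dim_head).permute(0, 2, 1, 3)
per_head = per_head.reshape(batch * heads, seq_len, dim_head)

# reshape_batch_dim_to_heads: (b * h, s, d) -> (b, s, h * d)
merged = per_head.reshape(batch, heads, seq_len, dim_head).permute(0, 2, 1, 3)
merged = merged.reshape(batch, seq_len, heads * dim_head)
assert torch.equal(merged, tensor)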

def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, dim = hidden_states.shape

q = self.to_q(x)
context = context if context is not None else x
k = self.to_k(context)
v = self.to_v(context)
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)

q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of
hidden_states = self._attention(q, k, v, sequence_length, dim)
hidden_states = self._attention(query, key, value, sequence_length, dim)

return self.to_out(hidden_states)

def _attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
)
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])

hidden_states[start_idx:end_idx] = attn_slice
# hidden_states = torch.zeros(
# (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
# )
slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
# for i in range(hidden_states.shape[0] // slice_size):
# start_idx = i * slice_size
# end_idx = (i + 1) * slice_size
# qslice = query[start_idx:end_idx]
qslice = query
# kslice = key[start_idx:end_idx].transpose(1, 2)
kslice = key.transpose(1, 2)
attn_slice = torch.matmul(qslice, kslice) * self.scale
attn_slice = attn_slice.softmax(dim=-1)
# vslice = value[start_idx:end_idx]
vslice = value
hidden_states = torch.matmul(attn_slice, vslice)

# hidden_states = torch.cat(attn_slices, dim=0)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
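
The commented-out loop above is the memory-saving "sliced" path this refactor is reworking: instead of materializing the full (batch * heads, seq, seq) attention matrix at once, it computes attention for slice_size batch-head rows at a time. A self-contained sketch of that idea (not this module's exact code):

import torch

def sliced_attention(query, key, value, slice_size, scale):
    # query, key, value: (batch_heads, seq_len, dim_head)
    out = torch.zeros_like(query)
    for start in range(0, query.shape[0], slice_size):
        end = start + slice_size
        scores = torch.matmul(query[start:end], key[start:end].transpose(1, 2)) * scale
        out[start:end] = torch.matmul(scores.softmax(dim=-1), value[start:end])
    return out

query, key, value = (torch.randn(8, 10, 16) for _ in range(3))
scale = 16 ** -0.5
full = sliced_attention(query, key, value, slice_size=8, scale=scale)
sliced = sliced_attention(query, key, value, slice_size=2, scale=scale)
assert torch.allclose(full, sliced, atol=1e-6)  # slicing changes memory use, not the result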


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
r"""
A feed-forward layer.
Parameters:
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""

def __init__(
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
project_in = GEGLU(dim, inner_dim)

self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

def forward(self, x):
return self.net(x)
def forward(self, hidden_states):
return self.net(hidden_states)
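
FeedForward widens the channel dimension by mult, applies the gated projection and dropout, and projects back to dim_out; a short illustrative sketch (values are hypothetical):

import torch

from diffusers.models.attention import FeedForward

ff = FeedForward(dim=64, mult=4)     # hidden width is 4 * 64 = 256
tokens = torch.randn(2, 10, 64)
out = ff(tokens)
assert out.shape == tokens.shape     # dim_out defaults to dim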


# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""

def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)

def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
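
Functionally, GEGLU projects to twice the output width, splits the result in half along the last dimension, and gates one half with the GELU of the other; the equivalent computation spelled out by hand (illustrative sizes):

import torch
import torch.nn.functional as F

dim_in, dim_out = 64, 128
proj = torch.nn.Linear(dim_in, dim_out * 2)  # project to 2 * dim_out, then split
x = torch.randn(2, dim_in)
hidden, gate = proj(x).chunk(2, dim=-1)      # two (2, dim_out) halves
out = hidden * F.gelu(gate)                  # (2, dim_out)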
