make performer work as encoder / decoder with cross attention
lucidrains committed Nov 11, 2020
1 parent 1d2cdbb commit d86231b
Showing 3 changed files with 66 additions and 12 deletions.
README.md: 33 changes (33 additions & 0 deletions)
@@ -66,6 +66,39 @@ x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)
```

Full encoder / decoder

```python
import torch
from performer_pytorch import PerformerLM

enc = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,
    dim = 512,
    depth = 6,
    heads = 8
).cuda()

dec = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,
    dim = 512,
    depth = 6,
    heads = 8,
    causal = True,
    cross_attend = True
).cuda()

src = torch.randint(0, 20000, (1, 2048)).cuda()
tgt = torch.randint(0, 20000, (1, 2048)).cuda()
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

encodings = enc(src, mask = src_mask, return_encodings = True)
logits = dec(tgt, context = encodings, mask = tgt_mask, context_mask = src_mask) # (1, 2048, 20000)
```
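
Note that passing `return_encodings = True` to the encoder returns the final layer-normalized token representations rather than logits, so they can be handed to the decoder as `context`.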

Standalone self-attention layer with linear complexity with respect to sequence length, for replacing trained full-attention transformer self-attention layers.

```python
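# a minimal usage sketch; it assumes SelfAttention is exported by the package and
# accepts the dim / heads / causal arguments shown in the performer_pytorch.py changes below
import torch
from performer_pytorch import SelfAttention

attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False
).cuda()

x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)
```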
performer_pytorch/performer_pytorch.py: 43 changes (32 additions & 11 deletions)
@@ -269,24 +269,30 @@ def __init__(self, dim, causal = False, heads = 8, local_heads = 0, local_window
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask = None):
    def forward(self, x, context = None, mask = None, context_mask = None):
        b, n, _, h, gh = *x.shape, self.heads, self.global_heads
        qkv = map(lambda fn: fn(x), (self.to_q, self.to_k, self.to_v))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
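        # if a context tensor is passed in, keys and values are projected from it (cross attention); otherwise the layer falls back to plain self-attention over x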
        cross_attend = exists(context)
        context = default(context, x)
        context_mask = default(context_mask, mask)

        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        attn_outs = []

        if not empty(q):
            if exists(mask):
                global_mask = mask[:, None, :, None]
            if exists(context_mask):
                global_mask = context_mask[:, None, :, None]
                k.masked_fill_(~global_mask, 0)

            out = self.fast_attention(q, k, v)
            attn_outs.append(out)

        if not empty(lq):
            assert not cross_attend, 'local attention is not compatible with cross attention'
            out = self.local_attn(lq, lk, lv, input_mask = mask)
            attn_outs.append(out)

@@ -296,7 +302,7 @@ def forward(self, x, mask = None):
        return self.dropout(out)

class Performer(nn.Module):
    def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, ff_glu = False, ff_dropout = 0., attn_dropout = 0.):
    def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, ff_glu = False, ff_dropout = 0., attn_dropout = 0., cross_attend = False):
        super().__init__()
        layers = nn.ModuleList([])
        local_attn_heads = cast_tuple(local_attn_heads)
@@ -317,16 +323,27 @@ def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size =
                wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
            ]))

            if not cross_attend:
                continue

            layers.append(nn.ModuleList([
                wrapper_fn(SelfAttention(dim, heads = heads, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, dropout = attn_dropout)),
                wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
            ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence
        route_attn = ((True, False),) * depth

        route_attn = ((True, False),) * depth * (2 if cross_attend else 1)
        route_context = ((False, False), (True, False)) * depth
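        # each (True/False, True/False) entry routes a keyword argument to the (attention, feed-forward) pair of a layer:
        # 'mask' reaches every attention block, while 'context' / 'context_mask' only reach the cross-attention blocks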
        attn_route_map = {'mask': route_attn}
        self.net = execute_type(layers, args_route = {**attn_route_map})
        context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {}
        self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map})

    def forward(self, x, **kwargs):
        return self.net(x, **kwargs)

class PerformerLM(nn.Module):
    def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0., generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False):
    def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0., generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, cross_attend = False):
        super().__init__()
        local_attn_heads = cast_tuple(local_attn_heads)

@@ -338,7 +355,7 @@ def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_hea
        nn.init.normal_(self.token_emb.weight, std = 0.02)
        nn.init.normal_(self.pos_emb.weight, std = 0.02)

        self.performer = Performer(dim, depth, heads, local_attn_heads, local_window_size, causal, ff_mult, nb_features, reversible, ff_chunks, generalized_attention, kernel_fn, qr_uniform_q, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout)
        self.performer = Performer(dim, depth, heads, local_attn_heads, local_window_size, causal, ff_mult, nb_features, reversible, ff_chunks, generalized_attention, kernel_fn, qr_uniform_q, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend)
        self.norm = nn.LayerNorm(dim)

    def fix_projection_matrices_(self):
@@ -347,7 +364,7 @@ def fix_projection_matrices_(self):
        for fast_attention in fast_attentions:
            fast_attention.set_projection_matrix(device)

    def forward(self, x, **kwargs):
    def forward(self, x, return_encodings = False, **kwargs):
        b, n, device = *x.shape, x.device
        # token and positional embeddings
        x = self.token_emb(x)
@@ -359,4 +376,8 @@ def forward(self, x, **kwargs):

        # norm and to logits
        x = self.norm(x)
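        # when return_encodings is set, these normalized representations are handed back directly (the decoder consumes them as context)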

        if return_encodings:
            return x

        return x @ self.token_emb.weight.t()
setup.py: 2 changes (1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
  name = 'performer-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.8.1',
  version = '0.9.0',
  license='MIT',
  description = 'Performer - Pytorch',
  author = 'Phil Wang',
