add rotary positional embeddings
lucidrains committed Aug 25, 2021
1 parent 2c651fc commit 1072a19
Showing 3 changed files with 74 additions and 18 deletions.
README.md: 3 changes (2 additions, 1 deletion)
@@ -20,7 +20,8 @@ model = FastTransformer(
num_tokens = 20000,
dim = 512,
depth = 2,
max_seq_len = 4096
max_seq_len = 4096,
absolute_pos_emb = True # default uses relative (rotary) positional encoding; if that is not working well, set this to True to use absolute positional embeddings instead
)

x = torch.randint(0, 20000, (1, 4096))
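
For context, the new flag slots into the usual forward pass as sketched below; the mask construction and the final logits shape are assumptions based on the FastTransformer signature in the diff further down, not lines from this commit.

import torch
from fast_transformer_pytorch import FastTransformer

model = FastTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 2,
    max_seq_len = 4096,
    absolute_pos_emb = True   # leave at the default of False to use the rotary relative encoding added in this commit
)

x = torch.randint(0, 20000, (1, 4096))
mask = torch.ones(1, 4096).bool()   # assumed boolean padding mask, True = keep

logits = model(x, mask = mask)      # assumed shape: (1, 4096, 20000)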
fast_transformer_pytorch/fast_transformer_pytorch.py: 86 changes (70 additions, 16 deletions)
@@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange

from einops import rearrange, reduce
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# helper functions

@@ -38,7 +40,9 @@ def __init__(
dim,
*,
heads = 8,
dim_head = 64
dim_head = 64,
max_seq_len = None,
pos_emb = None
):
super().__init__()
inner_dim = heads * dim_head
@@ -47,22 +51,44 @@ def __init__(

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

# rotary positional embedding

assert not (exists(pos_emb) and not exists(max_seq_len)), 'max_seq_len must be passed in to use rotary positional embeddings'

self.pos_emb = pos_emb
self.max_seq_len = max_seq_len

# if using rotary (relative) positional encoding, pairs of consecutive feature dimensions are summed before the projection to attention logits

kv_attn_proj_divisor = 1 if not exists(pos_emb) else 2

self.to_q_attn_logits = nn.Linear(dim_head, 1, bias = False) # for projecting queries to query attention logits
self.to_k_attn_logits = nn.Linear(dim_head, 1, bias = False) # for projecting keys to key attention logits
self.to_k_attn_logits = nn.Linear(dim_head // kv_attn_proj_divisor, 1, bias = False) # for projecting keys to key attention logits

# final transformation of values to "r" as in the paper

self.to_r = nn.Linear(dim_head, dim_head)
self.to_r = nn.Linear(dim_head // kv_attn_proj_divisor, dim_head)

self.to_out = nn.Linear(inner_dim, dim)

def forward(self, x, mask = None):
h = self.heads
n, device, h, use_rotary_emb = x.shape[1], x.device, self.heads, exists(self.pos_emb)

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

mask_value = -torch.finfo(x.dtype).max
mask = rearrange(mask, 'b n -> b () n')

# if relative positional encoding is needed

if use_rotary_emb:
freqs = self.pos_emb(torch.arange(self.max_seq_len, device = device), cache_key = self.max_seq_len)
freqs = rearrange(freqs[:n], 'n d -> () () n d')
q_aggr, k_aggr, v_aggr = map(lambda t: apply_rotary_emb(freqs, t), (q, k, v))
else:
q_aggr, k_aggr, v_aggr = q, k, v

# calculate query attention logits

q_attn_logits = rearrange(self.to_q_attn_logits(q), 'b h n () -> b h n') * self.scale
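
As a rough picture of what apply_rotary_emb is doing to q, k and v here: rotary embeddings rotate consecutive feature pairs by a position-dependent angle, so the dot product between a rotated query and key depends only on their relative offset. The sketch below is a generic RoPE rotation with the standard base of 10000, assuming an interleaved pair layout; the rotary-embedding-torch library may differ in layout and defaults.

import torch

def rotate_pairs(t, angles):
    # hypothetical helper, not library code: t is (..., n, d) with d even,
    # angles is (n, d // 2), one rotation angle per consecutive feature pair
    x1, x2 = t[..., 0::2], t[..., 1::2]                  # split interleaved pairs (x1, x2)
    cos, sin = angles.cos(), angles.sin()
    r1 = x1 * cos - x2 * sin                             # rotate each 2D pair by its angle
    r2 = x1 * sin + x2 * cos
    return torch.stack((r1, r2), dim = -1).flatten(-2)   # re-interleave back to (..., n, d)

n, d = 6, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))   # (d // 2,) per-pair frequencies
angles = torch.arange(n).float()[:, None] * inv_freq              # (n, d // 2) position * frequency
q = torch.randn(1, 8, n, d)                                       # (batch, heads, seq, dim_head)
q_rotated = rotate_pairs(q, angles)                               # same shape as q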
@@ -71,13 +97,18 @@ def forward(self, x, mask = None):

# calculate global query token

global_q = einsum('b h n, b h n d -> b h d', q_attn, q)
global_q = einsum('b h n, b h n d -> b h d', q_attn, q_aggr)
global_q = rearrange(global_q, 'b h d -> b h () d')

# bias keys with global query token

k = k * global_q

# if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension

if use_rotary_emb:
k = reduce(k, 'b h n (d r) -> b h n d', 'sum', r = 2)

# now calculate key attention logits

k_attn_logits = rearrange(self.to_k_attn_logits(k), 'b h n () -> b h n') * self.scale
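
The einsum on the global query line above (and the global key line below) is simply an attention-weighted sum over the sequence dimension; a standalone check with made-up shapes, not code from this repo:

import torch

b, h, n, d = 1, 2, 4, 8
q_attn = torch.randn(b, h, n).softmax(dim = -1)   # attention weights over the sequence
q_aggr = torch.randn(b, h, n, d)

global_q = torch.einsum('b h n, b h n d -> b h d', q_attn, q_aggr)
reference = (q_attn.unsqueeze(-1) * q_aggr).sum(dim = 2)   # weight each position, then sum over n

assert torch.allclose(global_q, reference)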
@@ -86,17 +117,27 @@ def forward(self, x, mask = None):

# calculate global key token

global_k = einsum('b h n, b h n d -> b h d', k_attn, k)
global_k = einsum('b h n, b h n d -> b h d', k_attn, k_aggr)
global_k = rearrange(global_k, 'b h d -> b h () d')

# bias the values

v = v * global_k
r = self.to_r(v)
u = v_aggr * global_k

# if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension

if use_rotary_emb:
u = reduce(u, 'b h n (d r) -> b h n d', 'sum', r = 2)

# transformation step

r = self.to_r(u)

# paper then says to add the queries as a residual

r = r + q # paper says to add the queries as a residual
r = r + q

# aggregate
# combine heads

r = rearrange(r, 'b h n d -> b n (h d)')
return self.to_out(r)
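
The two reduce calls in this forward pass implement the "inner product between adjacent pairs" the comments describe: multiplying elementwise by the global token and then summing consecutive feature pairs is a 2D dot product per pair, which also halves the feature dimension. A standalone sanity check (not repo code):

import torch
from einops import reduce

b, h, n, d = 1, 2, 4, 8                      # illustrative sizes; d stands in for dim_head

k = torch.randn(b, h, n, d)
global_q = torch.randn(b, h, 1, d)

# elementwise bias followed by summing consecutive feature pairs, as in the forward pass above
paired = reduce(k * global_q, 'b h n (d r) -> b h n d', 'sum', r = 2)

# equivalent: a 2D dot product between each adjacent feature pair of k and global_q
reference = (k.reshape(b, h, n, d // 2, 2) * global_q.reshape(b, h, 1, d // 2, 2)).sum(dim = -1)

assert torch.allclose(paired, reference)
assert paired.shape == (b, h, n, d // 2)     # halved features, hence dim_head // kv_attn_proj_divisor

That halving is why to_k_attn_logits and to_r take dim_head // kv_attn_proj_divisor input features whenever a rotary embedding is passed in.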
@@ -113,16 +154,27 @@ def __init__(
max_seq_len,
heads = 8,
dim_head = 64,
ff_mult = 4
ff_mult = 4,
absolute_pos_emb = False
):
super().__init__()
self.token_emb = nn.Embedding(num_tokens, dim)
self.pos_emb = nn.Embedding(max_seq_len, dim)

# positional embeddings

self.abs_pos_emb = nn.Embedding(max_seq_len, dim) if absolute_pos_emb else None

layer_pos_emb = None
if not absolute_pos_emb:
assert (dim_head % 4) == 0, 'dimension of the head must be divisible by 4 to use rotary embeddings'
layer_pos_emb = RotaryEmbedding(dim_head // 2)

# layers

self.layers = nn.ModuleList([])

for _ in range(depth):
attn = FastAttention(dim, dim_head = dim_head, heads = heads)
attn = FastAttention(dim, dim_head = dim_head, heads = heads, pos_emb = layer_pos_emb, max_seq_len = max_seq_len)
ff = FeedForward(dim, mult = ff_mult)

self.layers.append(nn.ModuleList([
@@ -151,8 +203,10 @@ def forward(
):
n, device = x.shape[1], x.device
x = self.token_emb(x)
pos_emb = self.pos_emb(torch.arange(n, device = device))
x = x + rearrange(pos_emb, 'n d -> () n d')

if exists(self.abs_pos_emb):
pos_emb = self.abs_pos_emb(torch.arange(n, device = device))
x = x + rearrange(pos_emb, 'n d -> () n d')

for attn, ff in self.layers:
x = attn(x, mask = mask) + x
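Taken together, leaving absolute_pos_emb at its default of False builds one shared RotaryEmbedding(dim_head // 2) and hands it, along with max_seq_len, to every FastAttention layer. A minimal sketch of the constraint that adds (values are illustrative):

from fast_transformer_pytorch import FastTransformer

# default path: absolute_pos_emb = False uses rotary relative positional encoding,
# so dim_head must be divisible by 4 (the assertion in __init__ above)
model = FastTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 2,
    max_seq_len = 4096,
    heads = 8,
    dim_head = 64        # 64 % 4 == 0, so RotaryEmbedding(dim_head // 2) gets an even dimension
)
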
setup.py: 3 changes (2 additions, 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'fast-transformer-pytorch',
packages = find_packages(),
version = '0.0.3',
version = '0.0.4',
license='MIT',
description = 'Fast Transformer - Pytorch',
author = 'Phil Wang',
@@ -16,6 +16,7 @@
],
install_requires=[
'einops>=0.3',
'rotary-embedding-torch',
'torch>=1.6'
],
classifiers=[
