turn on Shaw's relative positional encoding for local attention
lucidrains committed Nov 8, 2020
1 parent 9fbfe59 commit 4cd1755
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions performer_pytorch/performer_pytorch.py
@@ -256,11 +256,12 @@ class SelfAttention(nn.Module):
     def __init__(self, dim, causal = False, heads = 8, local_heads = 0, local_window_size = 256, nb_features = None, redraw_projection = True, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0.):
         super().__init__()
         assert dim % heads == 0, 'dimension must be divisible by number of heads'
-        self.fast_attention = FastAttention(dim // heads, nb_features, redraw_projection, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q)
+        dim_head = dim // heads
+        self.fast_attention = FastAttention(dim_head, nb_features, redraw_projection, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q)
 
         self.heads = heads
         self.global_heads = heads - local_heads
-        self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal)) if local_heads > 0 else None
+        self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
 
         self.to_q = nn.Linear(dim, dim)
         self.to_k = nn.Linear(dim, dim)
@@ -348,7 +349,7 @@ def fix_projection_matrices_(self):
 
     def forward(self, x, **kwargs):
         b, n, device = *x.shape, x.device
-        # token and positoinal embeddings
+        # token and positional embeddings
         x = self.token_emb(x)
         x += self.pos_emb(torch.arange(n, device = device))
         x = self.dropout(x)
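
For context, a minimal usage sketch of the changed SelfAttention module (the dims, head counts, and import path below are illustrative assumptions, not part of this commit): with local_heads > 0, the LocalAttention branch is now built with rel_pos_emb_config = (dim_head, local_heads), so the local heads attend with Shaw-style relative positional embeddings while the remaining global heads keep using the Performer fast attention.

import torch
from performer_pytorch.performer_pytorch import SelfAttention  # assumed import path

# 8 heads total: 4 global heads use fast attention,
# 4 local heads use windowed attention, now with Shaw relative positional embeddings
attn = SelfAttention(
    dim = 512,
    causal = True,
    heads = 8,
    local_heads = 4,
    local_window_size = 256
)

x = torch.randn(1, 1024, 512)   # (batch, seq, dim) -- illustrative sizes
out = attn(x)                   # (1, 1024, 512)
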
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.7.5',
+  version = '0.8.0',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',