make sure EPFL's cuda code can work with amp autocast, by setting amp_autocast to True
lucidrains committed Dec 9, 2020
1 parent 53d4b8f commit 607db4e
Showing 3 changed files with 15 additions and 11 deletions.
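A minimal usage sketch (not part of this commit, and assuming the fast_transformers CUDA kernel is installed and a GPU is available): the new amp_autocast flag is threaded from PerformerLM down to the causal attention function, so training under torch.cuda.amp would look roughly like this.

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 1024,
    dim = 512,
    depth = 6,
    heads = 8,
    causal = True,
    amp_autocast = True    # the flag added by this commit
).cuda()

opt = torch.optim.Adam(model.parameters(), lr = 3e-4)
scaler = GradScaler()

x = torch.randint(0, 20000, (1, 1024)).cuda()

opt.zero_grad()
with autocast():
    logits = model(x)                                  # (1, 1024, 20000)
    loss = F.cross_entropy(logits.transpose(1, 2), x)  # toy objective, for illustration only
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()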
2 changes: 2 additions & 0 deletions performer_pytorch/performer_enc_dec.py
@@ -43,6 +43,7 @@ def __init__(
ignore_index = 0,
pad_value = 0,
tie_token_embeds = False,
+ amp_autocast = False,
**kwargs
):
super().__init__()
@@ -51,6 +52,7 @@ def __init__(
assert 'dim' not in dec_kwargs and 'dim' not in enc_kwargs, 'you must set the dim for both encoder and decoder'

enc_kwargs['dim'] = dec_kwargs['dim'] = dim
+ dec_kwargs['amp_autocast'] = amp_autocast
dec_kwargs['causal'] = True
dec_kwargs['cross_attend'] = True

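Only the decoder is set causal in this wrapper (dec_kwargs['causal'] = True), so it alone needs the flag for the causal CUDA path. A rough sketch of enabling it, assuming the usual enc_* / dec_* prefixed kwargs of PerformerEncDec:

from performer_pytorch import PerformerEncDec

enc_dec = PerformerEncDec(
    dim = 512,
    amp_autocast = True,      # forwarded into dec_kwargs by the change above
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)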
22 changes: 12 additions & 10 deletions performer_pytorch/performer_pytorch.py
@@ -2,6 +2,7 @@
import torch
import torch.nn.functional as F
from torch import nn
+ from torch.cuda.amp import autocast
from einops import rearrange, repeat
from functools import partial

@@ -122,9 +123,10 @@ def linear_attention(q, k, v):

# efficient causal linear attention, created by EPFL
# TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back
- def causal_linear_attention(q, k, v):
+ def causal_linear_attention(q, k, v, amp_autocast = False):
from fast_transformers.causal_product import CausalDotProduct
- is_half = isinstance(q, torch.cuda.HalfTensor)
+ is_half = isinstance(q, torch.cuda.HalfTensor) or amp_autocast

if is_half:
q, k, v = map(lambda t: t.float(), (q, k, v))

@@ -146,7 +148,7 @@ def causal_linear_attention_noncuda(q, k, v):
return out

class FastAttention(nn.Module):
- def __init__(self, dim_heads, nb_features = None, feature_redraw_interval = 0, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False):
+ def __init__(self, dim_heads, nb_features = None, feature_redraw_interval = 0, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, amp_autocast = False):
super().__init__()
nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

@@ -168,7 +170,7 @@ def __init__(self, dim_heads, nb_features = None, feature_redraw_interval = 0, o
if causal:
try:
import fast_transformers.causal_product.causal_product_cuda
- self.causal_linear_fn = causal_linear_attention
+ self.causal_linear_fn = partial(causal_linear_attention, amp_autocast = amp_autocast)
except ImportError:
print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
self.causal_linear_fn = causal_linear_attention_noncuda
@@ -264,11 +266,11 @@ def forward(self, x, **kwargs):
return x

class SelfAttention(nn.Module):
- def __init__(self, dim, causal = False, heads = 8, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0.):
+ def __init__(self, dim, causal = False, heads = 8, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., amp_autocast = False):
super().__init__()
assert dim % heads == 0, 'dimension must be divisible by number of heads'
dim_head = dim // heads
- self.fast_attention = FastAttention(dim_head, nb_features, feature_redraw_interval, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q)
+ self.fast_attention = FastAttention(dim_head, nb_features, feature_redraw_interval, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, amp_autocast = amp_autocast)

self.heads = heads
self.global_heads = heads - local_heads
@@ -313,7 +315,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
return self.dropout(out)

class Performer(nn.Module):
- def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, ff_glu = False, ff_dropout = 0., attn_dropout = 0., cross_attend = False):
+ def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, ff_glu = False, ff_dropout = 0., attn_dropout = 0., cross_attend = False, amp_autocast = False):
super().__init__()
layers = nn.ModuleList([])
local_attn_heads = cast_tuple(local_attn_heads)
@@ -330,7 +332,7 @@ def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size =

for _, local_heads in zip(range(depth), local_attn_heads):
layers.append(nn.ModuleList([
- wrapper_fn(SelfAttention(dim, causal = causal, heads = heads, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, dropout = attn_dropout)),
+ wrapper_fn(SelfAttention(dim, causal = causal, heads = heads, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, dropout = attn_dropout, amp_autocast = amp_autocast)),
wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
]))

@@ -354,7 +356,7 @@ def forward(self, x, **kwargs):
return self.net(x, **kwargs)

class PerformerLM(nn.Module):
- def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0., generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, cross_attend = False):
+ def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0., generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, cross_attend = False, amp_autocast = False):
super().__init__()
local_attn_heads = cast_tuple(local_attn_heads)

@@ -366,7 +368,7 @@ def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_hea
nn.init.normal_(self.token_emb.weight, std = 0.02)
nn.init.normal_(self.pos_emb.weight, std = 0.02)

- self.performer = Performer(dim, depth, heads, local_attn_heads, local_window_size, causal, ff_mult, nb_features, reversible, ff_chunks, generalized_attention, kernel_fn, qr_uniform_q, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend)
+ self.performer = Performer(dim, depth, heads, local_attn_heads, local_window_size, causal, ff_mult, nb_features, reversible, ff_chunks, generalized_attention, kernel_fn, qr_uniform_q, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, amp_autocast)
self.norm = nn.LayerNorm(dim)

def fix_projection_matrices_(self):
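The TODO comment in causal_linear_attention above explains the motivation: EPFL's kernel does not yet handle mixed precision, so half-precision inputs are cast up to float32 around the kernel call and the output is cast back down. A rough illustrative sketch of that pattern (not the library code; causal_kernel_fp32 and causal_attention_amp_safe are hypothetical stand-ins):

import torch

def causal_kernel_fp32(q, k, v):
    # hypothetical stand-in for fast_transformers' float32-only CUDA kernel;
    # a naive unnormalized causal linear attention, purely for illustration
    assert q.dtype == torch.float32, 'kernel only supports float32'
    context = torch.einsum('...nd,...ne->...nde', k, v).cumsum(dim = -3)
    return torch.einsum('...nde,...nd->...ne', context, q)

def causal_attention_amp_safe(q, k, v, amp_autocast = False):
    # mirror the pattern above: detect half precision (or trust the amp_autocast
    # flag), run the kernel in float32, then cast the output back to half
    needs_cast = q.dtype == torch.float16 or amp_autocast
    if needs_cast:
        q, k, v = map(lambda t: t.float(), (q, k, v))
    out = causal_kernel_fp32(q, k, v)
    if needs_cast:
        out = out.half()
    return out

# toy check on CPU-sized tensors (in practice these would be CUDA half tensors)
q = k = v = torch.randn(1, 8, 128, 64, dtype = torch.float16)
out = causal_attention_amp_safe(q, k, v, amp_autocast = True)   # comes back as float16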
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'performer-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.12.5',
+ version = '0.12.6',
license='MIT',
description = 'Performer - Pytorch',
author = 'Phil Wang',
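Assuming this release is published to PyPI as usual, the bumped version can be installed with:

pip install performer-pytorch==0.12.6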
