Commit
Merge branch 'add_rotary_position_embeddings' of github.com:gkielian/nanoGPT into gkielian-add_rotary_position_embeddings
gkielian committed Nov 17, 2023
2 parents 027980a + 7e86704 commit bba2048
Showing 2 changed files with 113 additions and 2 deletions.
109 changes: 107 additions & 2 deletions model.py
@@ -149,6 +149,91 @@ def forward(self, inputs):

return numerator / denominator

class RotaryEmbedding(nn.Module):

def __init__(self, config):
super().__init__()
self.dim = config.n_embd

# Register frequencies directly as buffers
self.register_buffer('freq_left', (10000 ** (torch.arange(0, self.dim//2).float() / (self.dim//2))))
self.register_buffer('freq_right', (10000 ** (torch.arange(0, self.dim//2).float() / (self.dim//2))))

def forward(self, x):
seq_len = x.shape[-2]
device = x.device

t = torch.arange(seq_len, device=device).float()  # float so einsum matches the frequency buffers

# Get separate frequencies for left and right
freqs_left = torch.einsum('i,j->ij', t, self.freq_left)
freqs_right = torch.einsum('i,j->ij', t, self.freq_right)

# Apply frequencies
x_left, x_right = x[..., :self.dim//2], x[..., self.dim//2:]
# Rotate both halves from the original (pre-rotation) values
x_left, x_right = (x_left * freqs_left.cos() - x_right * freqs_left.sin(),
                   x_left * freqs_right.sin() + x_right * freqs_right.cos())

# Combine the left and right parts back
x = torch.cat([x_left, x_right], dim=-1)

return x
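
As a quick sanity check, the module above can be exercised on its own. This is a minimal sketch, assuming model.py is importable; the DummyRopeConfig stub is hypothetical and only carries the n_embd field the constructor reads:

import torch
from dataclasses import dataclass
from model import RotaryEmbedding  # assumes this file is saved as model.py

@dataclass
class DummyRopeConfig:  # hypothetical stub, not part of this commit
    n_embd: int = 64

rope = RotaryEmbedding(DummyRopeConfig())
x = torch.randn(2, 16, 64)   # (batch, seq_len, n_embd)
out = rope(x)
print(out.shape)             # torch.Size([2, 16, 64]): rotation mixes channels but keeps the shape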

class ShortRope(nn.Module):

def __init__(self, config):
super().__init__()
self.n = config.shortrope_length
self.dim = config.n_embd

# Generate freqs of size n rather than full dim
self.register_buffer('freq_left', (10000 ** (torch.arange(0, self.n//2).float() / (self.n//2))))
self.register_buffer('freq_right', (10000 ** (torch.arange(0, self.n//2).float() / (self.n//2))))

def forward(self, x):
# Step 1: Get the input tensor shape
batch_size, seq_len, _ = x.shape

# Step 2: Split the input tensor into unrotated and rotated sections
x_unrotated = x[..., :-self.n] # All but the last n dimensions
x_rotated = x[..., -self.n:] # Only the last n dimensions

# Step 3: Generate rotation frequencies
t = torch.arange(self.n, device=x.device).float()  # float so einsum matches the frequency buffers
freqs_left = torch.einsum('i,j->ij', t, self.freq_left)
freqs_right = torch.einsum('i,j->ij', t, self.freq_right)

# Calculate how many times to repeat freqs along the sequence length
num_repeats = seq_len // self.n + int(seq_len % self.n != 0)

# Repeat the frequency tensors to match the sequence length
freqs_left = freqs_left.repeat(batch_size, num_repeats, 1)
freqs_right = freqs_right.repeat(batch_size, num_repeats, 1)

# Trim the excess elements so the freqs tensors match the sequence length
freqs_left = freqs_left[:, :seq_len, :]
freqs_right = freqs_right[:, :seq_len, :]

# Step 4: Process the x_rotated section
x_left = x_rotated[..., :self.n//2]
x_right = x_rotated[..., self.n//2:]

# Apply the cosine and sine rotations
x_left, x_right = (x_left * freqs_left.cos() - x_right * freqs_left.sin(),
                   x_left * freqs_right.sin() + x_right * freqs_right.cos())  # both computed from the pre-rotation halves

# Invert the order of the right tensor's last dimension and negate it
x_right = torch.flip(x_right, dims=[-1]) * -1

# Combine the left and right rotated sections
x_rotated = torch.cat([x_left, x_right], dim=-1)

# Step 5: Combine the rotated and unrotated sections
x = torch.cat([x_unrotated, x_rotated], dim=-1)

return x
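
And a matching sketch for ShortRope, which rotates only the trailing shortrope_length channels and passes the rest through unchanged. The DummyShortRopeConfig stub is hypothetical; only n_embd and shortrope_length are read:

import torch
from dataclasses import dataclass
from model import ShortRope  # assumes this file is saved as model.py

@dataclass
class DummyShortRopeConfig:  # hypothetical stub, not part of this commit
    n_embd: int = 64
    shortrope_length: int = 8

srope = ShortRope(DummyShortRopeConfig())
x = torch.randn(2, 16, 64)                        # (batch, seq_len, n_embd)
out = srope(x)
print(out.shape)                                  # torch.Size([2, 16, 64])
print(torch.equal(out[..., :-8], x[..., :-8]))    # True: leading channels are untouched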


class CausalSelfAttention(nn.Module):

def __init__(self, config):
@@ -167,6 +252,14 @@ def __init__(self, config):
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

# Rotary Positional Embeddings
self.rotary_emb = None
if config.use_rotary_embeddings:
if config.rope_variant == "rope":
self.rotary_emb = RotaryEmbedding(config)
if config.rope_variant == "shortrope":
self.rotary_emb = ShortRope(config)

# Softmax Variant Selection
self.use_softmax_variant = config.use_softmax_variant
self.softmax_variant = config.softmax_variant
@@ -208,6 +301,8 @@ def __init__(self, config):
.view(1, 1, config.block_size, config.block_size))

def forward(self, x):
if self.rotary_emb is not None:
x = self.rotary_emb(x)
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -303,6 +398,12 @@ class GPTConfig:
constantmax_constant: int = 1000 # denominator to utilize for Constantmax
strongermax_strength: int = 2 # Softermax option, active if softermax selected - True: uses (x - x_max) normalization; False: removes normalization (potential overflow)

# Positional Embeddings Variations
use_rotary_embeddings: bool = True # If True, applies rotary position embeddings in attention; if False, they are skipped
rope_variant: str = "shortrope" # options: "shortrope", "rope"
shortrope_length: int = 8 # number of trailing embedding channels rotated by shortrope; must be even
use_abs_pos_embeddings: bool = True # If True, adds the learned absolute position embeddings (wpe); if False, they are omitted

# Layernorm Alternatives and Options
use_rmsnorm: bool = True # Add option for RMSNorm in place of LayerNorm: https://arxiv.org/abs/1910.07467
use_relu: bool = False #True: relu squared, False: do not utilize
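
As a usage sketch for the new fields (illustrative values only; the remaining GPTConfig fields keep their defaults, and the import assumes the usual nanoGPT layout):

from model import GPTConfig, GPT  # assumes the usual nanoGPT layout

config = GPTConfig(
    use_rotary_embeddings=True,
    rope_variant="shortrope",      # or "rope" to rotate the full embedding width
    shortrope_length=8,            # even, and no larger than n_embd
    use_abs_pos_embeddings=False,  # skip the learned wpe lookup in forward()
)
model = GPT(config)
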
@@ -399,8 +500,12 @@ def forward(self, idx, targets=None):

# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
x = None
if self.config.use_abs_pos_embeddings:
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
else:
x = self.transformer.drop(tok_emb)
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
6 changes: 6 additions & 0 deletions train.py
@@ -54,6 +54,12 @@ def parse_args():
# Norm variations
model_group.add_argument('--use_rmsnorm', default=True, action=argparse.BooleanOptionalAction)

# Positional Embedding variations
model_group.add_argument('--use_rotary_embeddings', default=True, action=argparse.BooleanOptionalAction)
model_group.add_argument("--rope_variant", type=str, default="rope", choices=["shortrope", "rope"])
model_group.add_argument("--shortrope_length", type=int, default="16", help="number of embeddings to use with rope, must be <= length, and be even")
model_group.add_argument('--use_abs_pos_embeddings', default=False, action=argparse.BooleanOptionalAction)

# Softmax variations
model_group.add_argument('--use_softmax_variant', default=False, action=argparse.BooleanOptionalAction)
model_group.add_argument("--softmax_variant", type=str, default="softermax", choices=["polymax", "strongermax", "softermax", "sigsoftmax", "sigsoftmax_base2"])
