diff --git a/model.py b/model.py
index bfbb11d303..887ca56ab4 100644
--- a/model.py
+++ b/model.py
@@ -149,6 +149,91 @@ def forward(self, inputs):
         return numerator / denominator
 
 
+class RotaryEmbedding(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.dim = config.n_embd
+
+        # Register frequencies directly as buffers
+        self.register_buffer('freq_left', (10000 ** (torch.arange(0, self.dim//2).float() / (self.dim//2))))
+        self.register_buffer('freq_right', (10000 ** (torch.arange(0, self.dim//2).float() / (self.dim//2))))
+
+    def forward(self, x):
+        seq_len = x.shape[-2]
+        device = x.device
+
+        t = torch.arange(seq_len, device=device).type_as(self.freq_left) # match buffer dtype for einsum
+
+        # Get separate frequencies for left and right
+        freqs_left = torch.einsum('i,j->ij', t, self.freq_left)
+        freqs_right = torch.einsum('i,j->ij', t, self.freq_right)
+
+        # Apply frequencies, rotating from the original halves rather than the updated ones
+        x_left, x_right = x[..., :self.dim//2], x[..., self.dim//2:]
+        x_left_new = x_left * freqs_left.cos() - x_right * freqs_left.sin()
+        x_right_new = x_left * freqs_right.sin() + x_right * freqs_right.cos()
+
+        # Combine the left and right parts back
+        x = torch.cat([x_left_new, x_right_new], dim=-1)
+
+        return x
+
+class ShortRope(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.n = config.shortrope_length
+        self.dim = config.n_embd
+
+        # Generate freqs of size n rather than full dim
+        self.register_buffer('freq_left', (10000 ** (torch.arange(0, self.n//2).float() / (self.n//2))))
+        self.register_buffer('freq_right', (10000 ** (torch.arange(0, self.n//2).float() / (self.n//2))))
+
+    def forward(self, x):
+        # Step 1: Get the input tensor shape
+        batch_size, seq_len, _ = x.shape
+
+        # Step 2: Split the input tensor into unrotated and rotated sections
+        x_unrotated = x[..., :-self.n] # All but the last n dimensions
+        x_rotated = x[..., -self.n:] # Only the last n dimensions
+
+        # Step 3: Generate rotation frequencies
+        t = torch.arange(self.n, device=x.device).type_as(self.freq_left) # match buffer dtype for einsum
+        freqs_left = torch.einsum('i,j->ij', t, self.freq_left)
+        freqs_right = torch.einsum('i,j->ij', t, self.freq_right)
+
+        # Calculate how many times to repeat freqs along the sequence length
+        num_repeats = seq_len // self.n + int(seq_len % self.n != 0)
+
+        # Repeat the frequency tensors to match the sequence length
+        freqs_left = freqs_left.repeat(batch_size, num_repeats, 1)
+        freqs_right = freqs_right.repeat(batch_size, num_repeats, 1)
+
+        # Trim the excess elements so the freqs tensors match the sequence length
+        freqs_left = freqs_left[:, :seq_len, :]
+        freqs_right = freqs_right[:, :seq_len, :]
+
+        # Step 4: Process the x_rotated section
+        x_left = x_rotated[..., :self.n//2]
+        x_right = x_rotated[..., self.n//2:]
+
+        # Apply the cosine and sine rotations, rotating from the original halves
+        x_left_new = x_left * freqs_left.cos() - x_right * freqs_left.sin()
+        x_right_new = x_left * freqs_right.sin() + x_right * freqs_right.cos()
+
+        # Invert the order of the right tensor's last dimension and negate it
+        x_right_new = torch.flip(x_right_new, dims=[-1]) * -1
+
+        # Combine the left and right rotated sections
+        x_rotated = torch.cat([x_left_new, x_right_new], dim=-1)
+
+        # Step 5: Combine the rotated and unrotated sections
+        x = torch.cat([x_unrotated, x_rotated], dim=-1)
+
+        return x
+
+
 class CausalSelfAttention(nn.Module):
 
     def __init__(self, config):
@@ -167,6 +252,14 @@ def __init__(self, config):
         # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
 
+        # Rotary Positional Embeddings
+        self.rotary_emb = None
+        if config.use_rotary_embeddings:
+            if config.rope_variant == "rope":
+                self.rotary_emb = RotaryEmbedding(config)
+            elif config.rope_variant == "shortrope":
+                self.rotary_emb = ShortRope(config)
+
         # Softmax Variant Selection
         self.use_softmax_variant = config.use_softmax_variant
         self.softmax_variant = config.softmax_variant
@@ -208,6 +301,8 @@ def __init__(self, config):
                                         .view(1, 1, config.block_size, config.block_size))
 
     def forward(self, x):
+        if self.rotary_emb is not None:
+            x = self.rotary_emb(x)
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -303,6 +398,12 @@ class GPTConfig:
     constantmax_constant: int = 1000 # denominator to utilize for Constantmax
     strongermax_strength: int = 2 # Softermax Option active is softermax selected - True: uses (x - x_max) normalization; False: removes normalization (potential overflow)
 
+    # Positional Embeddings Variations
+    use_rotary_embeddings: bool = True # If True, applies rotary embeddings inside attention
+    rope_variant: str = "shortrope" # options: "shortrope", "rope"
+    shortrope_length: int = 8 # number of trailing embedding dimensions to rotate in shortrope; must be even and <= n_embd
+    use_abs_pos_embeddings: bool = True # If True, adds conventional learned absolute position embeddings (wpe), else omits them
+
     # Layernorm Alternatives and Options
     use_rmsnorm: bool = True # Add option for RMSNorm in place of LayerNorm: https://arxiv.org/abs/1910.07467
     use_relu: bool = False #True: relu squared, False: do not utilize
@@ -399,8 +500,12 @@ def forward(self, idx, targets=None):
 
         # forward the GPT model itself
         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
+        x = None
+        if self.config.use_abs_pos_embeddings:
+            pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+            x = self.transformer.drop(tok_emb + pos_emb)
+        else:
+            x = self.transformer.drop(tok_emb)
         for block in self.transformer.h:
             x = block(x)
         x = self.transformer.ln_f(x)
diff --git a/train.py b/train.py
index 84f9745772..59564213eb 100644
--- a/train.py
+++ b/train.py
@@ -54,6 +54,12 @@ def parse_args():
     # Norm variations
     model_group.add_argument('--use_rmsnorm', default=True, action=argparse.BooleanOptionalAction)
 
+    # Positional Embedding variations
+    model_group.add_argument('--use_rotary_embeddings', default=True, action=argparse.BooleanOptionalAction)
+    model_group.add_argument("--rope_variant", type=str, default="rope", choices=["shortrope", "rope"])
+    model_group.add_argument("--shortrope_length", type=int, default=16, help="number of trailing embedding dimensions to rotate with shortrope; must be even and <= n_embd")
+    model_group.add_argument('--use_abs_pos_embeddings', default=False, action=argparse.BooleanOptionalAction)
+
     # Softmax variations
     model_group.add_argument('--use_softmax_variant', default=False, action=argparse.BooleanOptionalAction)
     model_group.add_argument("--softmax_variant", type=str, default="softermax", choices=["polymax", "strongermax", "softermax", "sigsoftmax", "sigsoftmax_base2"])
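
Note for reviewers: a quick standalone sanity check of the RotaryEmbedding math. This is a minimal sketch, not part of the patch; it assumes the patched model.py is importable and PyTorch is installed, and DummyConfig is a made-up stand-in for GPTConfig. Because freq_left and freq_right are identical buffers, each (left, right) channel pair undergoes a plain 2-D rotation, so per-position norms should come out unchanged:

import torch

from model import RotaryEmbedding  # assumes this patch is applied

class DummyConfig:
    # Made-up minimal config; RotaryEmbedding only reads n_embd (must be even).
    n_embd = 8

x = torch.randn(2, 16, DummyConfig.n_embd)  # (batch, seq_len, n_embd)
rope = RotaryEmbedding(DummyConfig())
out = rope(x)

# The transform is shape-preserving ...
assert out.shape == x.shape
# ... and, being a per-pair 2-D rotation, norm-preserving at every position.
assert torch.allclose(x.norm(dim=-1), out.norm(dim=-1), atol=1e-4)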
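
To exercise the new options end to end, a usage sketch; the values here are illustrative, and the remaining GPTConfig fields are assumed to keep their usual defaults:

from model import GPT, GPTConfig

# Illustrative settings: rotate only the last 16 embedding channels.
# shortrope_length must be even and no larger than n_embd.
config = GPTConfig(
    use_rotary_embeddings=True,
    rope_variant="shortrope",
    shortrope_length=16,
    use_abs_pos_embeddings=False,
)
model = GPT(config)

From the command line, argparse.BooleanOptionalAction also generates the negated --no- forms, e.g.:

python3 train.py --rope_variant shortrope --shortrope_length 16 --no-use_abs_pos_embeddings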