diff --git a/model.py b/model.py
index b59c60053b..14408b2457 100644
--- a/model.py
+++ b/model.py
@@ -18,7 +18,7 @@
 # Variations
 from variations.softmax_variations import Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax
 from variations.normalization_variations import LayerNorm, RMSNorm
-from variations.position_encoding_variations import RotaryEmbedding, ShortRope
+from variations.position_encoding_variations import RotaryEmbedding, ShortRope, FIRE
 from variations.activation_variations import SquaredReLU, activation_dictionary
@@ -39,6 +39,7 @@ def __init__(self, config):
         self.dropout = config.dropout
         self.window_size = config.window_size
         self.gate = config.gate
+        self.fire_pos_enc = FIRE(num_heads=config.n_head)
 
         # Rotary Positional Embeddings
         self.rotary_emb = None
@@ -86,8 +87,7 @@ def __init__(self, config):
                                         .view(1, 1, config.block_size, config.block_size))
 
     def forward(self, x):
-        if self.rotary_emb is not None:
-            x = self.rotary_emb(x)
+        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         if self.window_size is not None:
             window_mask = torch.ones((1, 1, T, T), device=x.device)
@@ -124,9 +124,21 @@ def forward(self, x):
             # regular lower triangle attention
             att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
 
+        if self.use_fire_emb:
+            fire_bias = self.fire_pos_enc(x) # FIRE relative position bias, shape (1, n_head, T, T)
+            att = att + fire_bias # add the bias to the attention logits before the softmax
+
+            # if torch.sum(att.isnan()==True):
+            #     print("fire bias + att nan")
+            #     input()
+
+
         # softmax variation
         if self.softmax_variant_attn != 'softmax':
             att = self.softmax_layer(att)
+            # if torch.sum(att.isnan()==True):
+            #     print("softmax nan")
+            #     input()
         else:
             att = F.softmax(att, dim=-1)
@@ -345,7 +357,8 @@ def _init_weights(self, module):
     def forward(self, idx, targets=None):
         device = idx.device
         b, t = idx.size()
-        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        # assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        print(t)
         pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
 
         # forward the GPT model itself
@@ -483,7 +496,8 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
         return mfu
 
     @torch.no_grad()
-    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None,
+                 block_size=None):
         """
         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
         the sequence max_new_tokens times, feeding the predictions back into the model each time.
@@ -491,7 +505,10 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
         """
         for _ in range(max_new_tokens):
             # if the sequence context is growing too long we must crop it at block_size
-            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            block_size = block_size if block_size is not None else self.config.block_size
+            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
+
             # forward the model to get the logits for the index in the sequence
             logits, _ = self(idx_cond)
             # pluck the logits at the final step and scale by desired temperature
@@ -512,3 +529,8 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
             idx = torch.cat((idx, idx_next), dim=1)
 
         return idx
+
+
+# Enable anomaly detection
+torch.autograd.set_detect_anomaly(True)
+
diff --git a/train.py b/train.py
index bddd926910..efa62b314a 100644
--- a/train.py
+++ b/train.py
@@ -105,8 +105,9 @@ def parse_args():
                              default="softmax", choices=["constantmax_quan", "constantmax", "polymax", "strongermax", "softermax", "sigsoftmax", "softmax"])
 
     ## Custom Softmax Variation Options
-    model_group.add_argument("--constantmax_initial_beta", type=float, default=0.0)
-    model_group.add_argument("--constantmax_initial_gamma", type=float, default=100.0)
+    model_group.add_argument("--constantmax_initial_beta", type=float, default=6.5)
+    model_group.add_argument("--constantmax_initial_gamma", type=float,
+                             default=100.0)
     model_group.add_argument('--constantmax_use_euler_base', default=True, action=argparse.BooleanOptionalAction)
     model_group.add_argument("--constantmax_base", type=float, default=2.0)
diff --git a/variations/position_encoding_variations.py b/variations/position_encoding_variations.py
index 5cd5dee999..befaee4ada 100644
--- a/variations/position_encoding_variations.py
+++ b/variations/position_encoding_variations.py
@@ -1,14 +1,254 @@
 import torch
 import torch.nn as nn
 
+def safe_log(x, eps=1e-6):
+    """
+    Computes a safe logarithm of x by clamping the minimum value to eps.
+    This prevents log(0) and log(negative) scenarios.
+ """ + return torch.log(torch.clamp(x, min=eps)) + +class FIRE(nn.Module): + def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512.0, eps=1e-6): + super(FIRE, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(1, mlp_width), nn.ReLU(), nn.Linear(mlp_width, num_heads) + ) + self.c = nn.Parameter(torch.tensor(init_c, dtype=torch.float)) + self.init_L = nn.Parameter(torch.tensor(init_L, dtype=torch.float), requires_grad=False) + self.L_multiplier = nn.Parameter(torch.tensor(1.0, dtype=torch.float)) + self.eps = eps + + def forward(self, x: torch.Tensor): + seq_length = x.size(1) + positions = torch.arange(seq_length, dtype=torch.float, device=x.device) + rel_distance = positions[:, None] - positions[None, :] + + # Apply absolute value and ensure positive before log + abs_rel_distance = torch.abs(rel_distance) + self.eps + + threshold = torch.abs(self.L_multiplier * self.init_L) + pos_normalizer = torch.max(positions, threshold) + pos_normalizer = pos_normalizer[:, None] + self.eps # Ensure pos_normalizer is never zero + + # Use safe log operation + log_rel_distance = torch.log(abs_rel_distance * self.c + self.eps) + log_pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + self.eps) + + normalized_distance = log_rel_distance - log_pos_normalizer # Subtraction instead of division + + fire_bias = self.mlp(normalized_distance.unsqueeze(-1)) + fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2) + + return fire_bias + +class FIRE_safe(nn.Module): + def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512.0, eps=1e-3): + super(FIRE, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(1, mlp_width), nn.ReLU(), nn.Linear(mlp_width, num_heads) + ) + self.c = nn.Parameter(torch.tensor(init_c)) + self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False) + self.L_multiplier = nn.Parameter(torch.tensor(1.0)) + self.eps = eps + + def forward(self, x: torch.Tensor): + seq_length = x.size(1) + positions = torch.arange(seq_length, dtype=torch.float, device=x.device) + rel_distance = positions[:, None] - positions[None, :] + + threshold = torch.abs(self.L_multiplier * self.init_L) + pos_normalizer = torch.max(positions, threshold) + pos_normalizer = pos_normalizer[:, None] + + # Using safe_log for logging operations + rel_distance = safe_log(torch.abs(self.c * rel_distance) + 1, self.eps) + pos_normalizer = safe_log(torch.abs(self.c * pos_normalizer) + 1, self.eps) + + normalized_distance = rel_distance / pos_normalizer + + fire_bias = self.mlp(normalized_distance.unsqueeze(-1)) + fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2) + + return fire_bias + +class FIRE_(nn.Module): + def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512.0, eps=1e-3): + super(FIRE, self).__init__() + self.mlp = nn.Sequential( + nn.Linear(1, mlp_width), nn.ReLU(), nn.Linear(mlp_width, num_heads) + ) + self.c = nn.Parameter(torch.tensor(init_c)) + self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False) + self.L_multiplier = nn.Parameter(torch.tensor(1.0)) + self.eps = eps + + def forward(self, x: torch.Tensor): + seq_length = x.size(1) # Adjusted to correct dimension + positions = torch.arange(seq_length, dtype=torch.float, device=x.device) + rel_distance = positions[:, None] - positions[None, :] + + threshold = torch.abs(self.L_multiplier * self.init_L) + pos_normalizer = torch.max(positions, threshold) + pos_normalizer = pos_normalizer[:, None] + + # Check for nan + if torch.isnan(rel_distance).any(): + print("Nan found in 
rel_distance") + + rel_distance = torch.log(torch.abs(self.c * rel_distance) + 1) + pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + 1) + self.eps + + # Check for nan after transformations + if torch.isnan(rel_distance).any(): + print("Nan found in log-transformed rel_distance") + if torch.isnan(pos_normalizer).any(): + print("Nan found in log-transformed pos_normalizer") + + normalized_distance = rel_distance / pos_normalizer + + if torch.isnan(normalized_distance).any(): + print("Nan found in normalized_distance") + + fire_bias = self.mlp(normalized_distance.unsqueeze(-1)) + fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2) + + if torch.isnan(fire_bias).any(): + print("Nan found in fire_bias") + + return fire_bias + + +class FIRE2(nn.Module): + def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512.0, + eps=1e-3): + """ + FIRE attention bias module. + + Args: + num_heads: number of attention heads. + mlp_width: Width of MLP. + init_c: initial value of log transformation parameter + init_L: initial value of thresholding parameter + eps: small constant for numerical stability + """ + super(FIRE, self).__init__() + + # Define the MLP layers + self.mlp = nn.Sequential( + nn.Linear(1, mlp_width), nn.ReLU(), nn.Linear(mlp_width, num_heads) + ) + + # Initialize c (log transformation parameter) + self.c = nn.Parameter(torch.tensor(init_c)) + + # Initialize L (threshold) + self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False) + # Learn a multiplier to L + self.L_multiplier = nn.Parameter(torch.tensor(1.0)) + + self.eps = eps + + def forward(self, x: torch.Tensor): + """ + Compute FIRE attention bias. + + Args: + x: input sequence, + shape [bsz, num_heads, seq_len, hidden_dim] + + Returns: + attention bias, + shape [1, num_heads, seq_len, seq_len] + """ + seq_length = x.size(1) + positions = torch.arange(seq_length, dtype=torch.float, device=x.device) + rel_distance = positions[:, None] - positions[None, :] + + # Thresholding the normalizer + threshold = torch.abs(self.L_multiplier * self.init_L) + pos_normalizer = torch.max(positions, threshold) + pos_normalizer = pos_normalizer[:, None] + + # Amplifying differences among local positions + # with log transform + rel_distance = torch.log(torch.abs(self.c * rel_distance) + 1) + pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + 1) + self.eps + + # Progressive interpolation + normalized_distance = rel_distance / pos_normalizer + fire_bias = self.mlp(normalized_distance.unsqueeze(-1)) + fire_bias = fire_bias.unsqueeze(0).permute(0, 3, 1, 2) + return fire_bias + + +class FIRE2kOu(nn.Module): + def __init__(self, num_heads=6, mlp_width=32, init_c=0.1, init_L=256.0, eps=1e-6): + super(FIRE, self).__init__() + + self.mlp = nn.Sequential( + nn.Linear(1, mlp_width), nn.ReLU(), nn.Linear(mlp_width, num_heads) + ) + + self.c = nn.Parameter(torch.tensor(init_c)) + self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False) + self.L_multiplier = nn.Parameter(torch.tensor(1.0)) + self.eps = eps + + def forward(self, x: torch.Tensor): + """ + Compute FIRE attention bias. 
+
+        Args:
+            x: Input tensor with shape [batch_size, seq_len, embedding_dim]
+
+        Returns:
+            Attention bias of shape [batch_size, num_heads, seq_len, seq_len]
+        """
+
+        batch_size, seq_len, _ = x.shape
+
+        # Generate relative positions
+        positions = torch.arange(seq_len, dtype=torch.float, device=x.device)
+        rel_distance = (
+            positions[None, :, None] - positions[None, None, :]
+        )  # [1, seq_len, seq_len]
+
+        # Thresholding
+        threshold = torch.abs(self.L_multiplier * self.init_L)
+        pos_normalizer = torch.max(rel_distance, threshold.expand_as(rel_distance))
+
+        # Amplification with log transform
+        rel_distance = torch.log(torch.abs(self.c * rel_distance) + 1)
+        pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + 1) + self.eps
+
+        # Normalization and MLP
+        normalized_distance = rel_distance / pos_normalizer
+        fire_bias = self.mlp(
+            normalized_distance.unsqueeze(-1)
+        )  # [1, seq_len, seq_len, num_heads]
+
+        # Expand batch dimension and permute for proper shape
+        fire_bias = fire_bias.expand(batch_size, -1, -1, -1).permute(0, 3, 1, 2)
+
+        return fire_bias
+
+
 class RotaryEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dim = config.n_embd
 
         # Register frequencies directly as buffers
-        self.register_buffer('freq_left', (10000 ** (torch.arange(0, self.dim//2).float() / self.dim//2)))
-        self.register_buffer('freq_right',(10000 ** (torch.arange(0, self.dim//2).float() / self.dim//2)))
+        self.register_buffer(
+            "freq_left",
+            (10000 ** (torch.arange(0, self.dim // 2).float() / self.dim // 2)),
+        )
+        self.register_buffer(
+            "freq_right",
+            (10000 ** (torch.arange(0, self.dim // 2).float() / self.dim // 2)),
+        )
 
     def forward(self, x):
         seq_len = x.shape[-2]
@@ -17,11 +257,11 @@ def forward(self, x):
         t = torch.arange(seq_len, device=device)
 
         # Get separate frequencies for left and right
-        freqs_left = torch.einsum('i,j->ij', t, self.freq_left)
-        freqs_right = torch.einsum('i,j->ij', t, self.freq_right)
+        freqs_left = torch.einsum("i,j->ij", t, self.freq_left)
+        freqs_right = torch.einsum("i,j->ij", t, self.freq_right)
 
         # Apply frequencies
-        x_left, x_right = x[..., :self.dim//2], x[..., self.dim//2:]
+        x_left, x_right = x[..., : self.dim // 2], x[..., self.dim // 2 :]
         x_left = x_left * freqs_left.cos() - x_right * freqs_left.sin()
         x_right = x_left * freqs_right.sin() + x_right * freqs_right.cos()
@@ -30,6 +270,7 @@ def forward(self, x):
 
         return x
 
+
 class ShortRope(nn.Module):
 
     def __init__(self, config):
@@ -38,21 +279,26 @@ def __init__(self, config):
         self.dim = config.n_embd
 
         # Generate freqs of size n rather than full dim
-        self.register_buffer('freq_left', (10000 ** (torch.arange(0, self.n//2).float() / self.n//2)))
-        self.register_buffer('freq_right', (10000 ** (torch.arange(0, self.n//2).float() / self.n//2)))
+        self.register_buffer(
+            "freq_left", (10000 ** (torch.arange(0, self.n // 2).float() / self.n // 2))
+        )
+        self.register_buffer(
+            "freq_right",
+            (10000 ** (torch.arange(0, self.n // 2).float() / self.n // 2)),
+        )
 
     def forward(self, x):
         # Step 1: Get the input tensor shape
         batch_size, seq_len, _ = x.shape
 
         # Step 2: Split the input tensor into unrotated and rotated sections
-        x_unrotated = x[..., :-self.n]  # All but the last n dimensions
-        x_rotated = x[..., -self.n:]  # Only the last n dimensions
+        x_unrotated = x[..., : -self.n]  # All but the last n dimensions
+        x_rotated = x[..., -self.n :]  # Only the last n dimensions
 
         # Step 3: Generate rotation frequencies
         t = torch.arange(self.n, device=x.device)
-        freqs_left = torch.einsum('i,j->ij', t, self.freq_left)
-        freqs_right = torch.einsum('i,j->ij', t, self.freq_right)
+        freqs_left = torch.einsum("i,j->ij", t, self.freq_left)
+        freqs_right = torch.einsum("i,j->ij", t, self.freq_right)
 
         # Calculate how many times to repeat freqs along the sequence length
         num_repeats = seq_len // self.n + int(seq_len % self.n != 0)
@@ -66,8 +312,8 @@ def forward(self, x):
         freqs_right = freqs_right[:, :seq_len, :]
 
         # Step 4: Process the x_rotated section
-        x_left = x_rotated[..., :self.n//2]
-        x_right = x_rotated[..., self.n//2:]
+        x_left = x_rotated[..., : self.n // 2]
+        x_right = x_rotated[..., self.n // 2 :]
 
         # Apply the cosine and sine rotations
         x_left = x_left * freqs_left.cos() - x_right * freqs_left.sin()
@@ -83,4 +329,3 @@ def forward(self, x):
         x = torch.cat([x_unrotated, x_rotated], dim=-1)
 
         return x
-
diff --git a/variations/softmax_variations.py b/variations/softmax_variations.py
index 2367faecaf..ee1dd885e0 100644
--- a/variations/softmax_variations.py
+++ b/variations/softmax_variations.py
@@ -37,6 +37,8 @@ def __init__(self, config, dim=-1):
 
     def forward(self, x):
         x = x - self.beta
+        k = 10.0  # cap the shifted logits so the exponentiation below cannot overflow
+        x = torch.clamp(x, max=k)
         e_x = torch.pow(self.constantmax_base, x)
         return e_x / self.gamma
@@ -90,12 +92,16 @@ def forward(self, x):
             #print('fake_gamma:', self.fake_gamma)
             self.fake_beta, self.fake_gamma=_const_quan(self.beta, self.gamma)
             x = x - self.fake_beta
+            k = 10.0
+            x = torch.clamp(x, max=k)
             e_x = torch.exp(x)
             return e_x / self.fake_gamma
         else:
             scale_beta=100 #scaling factor for quantization, should make it as parameter
             scale_gamma=10
             x = x - dequantize(quantize(self.beta,scale_beta), scale_beta)
+            k = 10.0
+            x = torch.clamp(x, max=k)
             e_x = torch.exp(x)
             return e_x/dequantize(quantize(self.gamma,scale_gamma), scale_gamma)
@@ -144,6 +150,10 @@ def forward(self, x):
         # Polynomial section
         poly_piece = torch.where(x > 0, x**self.power + self.y_intercept, torch.tensor(0.0, device=x.device))
 
+        # eps = 1e-6 # Small constant to ensure x is never zero when raised to a power
+        # safe_x = torch.where(x > 0, x, torch.full_like(x, eps))
+        # poly_piece = torch.where(x > 0, safe_x ** self.power + self.y_intercept, torch.tensor(0.0, device=x.device))
+
        # Combine sections
        return (poly_piece + linear_piece + flat_piece)/self.divisor
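
Note on usage (not part of the patch): the sketch below shows how the bias produced by the new FIRE module combines with causal self-attention, mirroring the model.py hunk above. The toy sizes (B, T, C, n_head) and the lack of learned q/k/v projections are illustrative assumptions, not the repository's actual attention code; it assumes the variations package is importable from the repo root.

import torch
import torch.nn.functional as F
from variations.position_encoding_variations import FIRE

B, T, C, n_head = 2, 16, 48, 6           # toy dimensions (assumed, not from the diff)
head_dim = C // n_head

x = torch.randn(B, T, C)                 # (B, T, C), as in the attention forward pass
q = k = v = x.view(B, T, n_head, head_dim).transpose(1, 2)  # (B, n_head, T, head_dim); real code projects q/k/v first

att = (q @ k.transpose(-2, -1)) / head_dim ** 0.5           # raw attention logits, (B, n_head, T, T)
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
att = att.masked_fill(~causal_mask, float('-inf'))          # lower-triangular masking

fire = FIRE(num_heads=n_head)
att = att + fire(x)                      # FIRE bias is (1, n_head, T, T) and broadcasts over the batch

y = F.softmax(att, dim=-1) @ v           # (B, n_head, T, head_dim)

Because the bias depends only on relative positions normalized by a learned, thresholded length scale, the same module can be evaluated on sequences longer than the training block size, which is why the block-size assert and the fixed crop in generate() are relaxed elsewhere in this diff.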