Add fire positional encoding
This is just a draft commit

Because the repo is moving so quickly, it would be easier to apply the
changes on top of the latest version.

Also includes stability improvements for using FIRE with constantmax and
polymax. We might also want to try adding normalization between the FIRE
and softmax stages to test whether this fixes the compatibility issues.
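
Since variations/position_encoding_variations.py is not part of this diff, the following is a minimal sketch of the idea described above: a FIRE-style learned relative-position bias per head, with an optional normalization stage between the bias and the softmax. The FIRESketch class name, the MLP sizes, the log-scaled distance feature, and the LayerNorm placement are all assumptions for illustration, not the repo's actual FIRE implementation.

# Minimal FIRE-style sketch (an assumption for illustration; NOT the FIRE class
# imported from variations/position_encoding_variations.py in the diff below).
import torch
import torch.nn as nn

class FIRESketch(nn.Module):
    def __init__(self, num_heads, hidden_dim=32, normalize_bias=False, eps=1e-6):
        super().__init__()
        # small MLP maps a scalar relative-position feature to one bias per head
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_heads),
        )
        # optional stabilization between the FIRE and softmax stages (assumed placement)
        self.norm = nn.LayerNorm(num_heads) if normalize_bias else None
        self.eps = eps

    def forward(self, x):
        # x: (B, T, C); only sequence length and device are needed to build the bias
        T = x.size(1)
        pos = torch.arange(T, device=x.device, dtype=torch.float32)
        rel = (pos[:, None] - pos[None, :]).clamp(min=0)   # causal distances, (T, T)
        rel = torch.log1p(rel)                             # log-scale the distances
        denom = torch.log1p(pos.clamp(min=1))[:, None]     # per-query progressive normalizer
        feat = (rel / (denom + self.eps)).unsqueeze(-1)    # (T, T, 1)
        bias = self.mlp(feat)                              # (T, T, num_heads)
        if self.norm is not None:
            bias = self.norm(bias)                         # normalize the bias before it reaches softmax
        return bias.permute(2, 0, 1).unsqueeze(0)          # (1, num_heads, T, T), broadcasts over batch

Usage would mirror the attention change in the diff: add the returned bias to the attention scores (att = att + fire_bias) before the softmax variant (constantmax, polymax, ...) is applied.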
gkielian committed Apr 7, 2024
1 parent c3259c2 commit d8d6502
Showing 4 changed files with 300 additions and 22 deletions.
34 changes: 28 additions & 6 deletions model.py
@@ -18,7 +18,7 @@
 # Variations
 from variations.softmax_variations import Softermax, Constantmax, Constantmax_quan, Strongermax, Polymax, SigSoftmax
 from variations.normalization_variations import LayerNorm, RMSNorm
-from variations.position_encoding_variations import RotaryEmbedding, ShortRope
+from variations.position_encoding_variations import RotaryEmbedding, ShortRope, FIRE
 from variations.activation_variations import SquaredReLU, activation_dictionary


@@ -39,6 +39,7 @@ def __init__(self, config):
         self.dropout = config.dropout
         self.window_size = config.window_size
         self.gate = config.gate
+        self.fire_pos_enc = FIRE(num_heads=config.n_head)

         # Rotary Positional Embeddings
         self.rotary_emb = None
@@ -86,8 +87,7 @@ def __init__(self, config):
                                             .view(1, 1, config.block_size, config.block_size))

     def forward(self, x):
-        if self.rotary_emb is not None:
-            x = self.rotary_emb(x)
+
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
         if self.window_size is not None:
             window_mask = torch.ones((1, 1, T, T), device=x.device)
@@ -124,9 +124,21 @@ def forward(self, x):
             # regular lower triangle attention
             att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))

+        if self.use_fire_emb:
+            fire_bias = self.fire_pos_enc(x) # Add this line
+            att = att + fire_bias # Add this line
+
+        # if torch.sum(att.isnan()==True):
+        #     print("fire bias + att nan")
+        #     input()
+
+
         # softmax variation
         if self.softmax_variant_attn != 'softmax':
             att = self.softmax_layer(att)
+            # if torch.sum(att.isnan()==True):
+            #     print("softmax nan")
+            #     input()
         else:
             att = F.softmax(att, dim=-1)
@@ -345,7 +357,8 @@ def _init_weights(self, module):
     def forward(self, idx, targets=None):
         device = idx.device
         b, t = idx.size()
-        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        # assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        print(t)
         pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

         # forward the GPT model itself
@@ -483,15 +496,19 @@ def estimate_mfu(self, fwdbwd_per_iter, dt):
         return mfu

     @torch.no_grad()
-    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None,
+                 block_size=None):
         """
         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
         the sequence max_new_tokens times, feeding the predictions back into the model each time.
         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
         """
         for _ in range(max_new_tokens):
             # if the sequence context is growing too long we must crop it at block_size
-            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+            block_size = 800
+            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
+
             # forward the model to get the logits for the index in the sequence
             logits, _ = self(idx_cond)
             # pluck the logits at the final step and scale by desired temperature
@@ -512,3 +529,8 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
             idx = torch.cat((idx, idx_next), dim=1)

         return idx
+
+
+# Enable anomaly detection
+torch.autograd.set_detect_anomaly(True)
+
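
One note on the generate() change above: the draft adds a block_size keyword but then hard-codes block_size = 800 inside the loop. Below is a minimal sketch of how that parameter might eventually be wired in (an assumption about intent, not part of this commit; the crop_context helper name is hypothetical).

import torch

def crop_context(idx: torch.Tensor, block_size: int) -> torch.Tensor:
    """Keep only the most recent block_size token indices of a (B, T) tensor."""
    return idx if idx.size(1) <= block_size else idx[:, -block_size:]

# Inside generate(), something like:
#     effective_block = block_size if block_size is not None else self.config.block_size
#     idx_cond = crop_context(idx, effective_block)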
5 changes: 3 additions & 2 deletions train.py
@@ -105,8 +105,9 @@ def parse_args():
                              default="softmax", choices=["constantmax_quan", "constantmax", "polymax", "strongermax", "softermax", "sigsoftmax", "softmax"])

     ## Custom Softmax Variation Options
-    model_group.add_argument("--constantmax_initial_beta", type=float, default=0.0)
-    model_group.add_argument("--constantmax_initial_gamma", type=float, default=100.0)
+    model_group.add_argument("--constantmax_initial_beta", type=float, default=6.5)
+    model_group.add_argument("--constantmax_initial_gamma", type=float,
+                             default=100.0)
     model_group.add_argument('--constantmax_use_euler_base', default=True, action=argparse.BooleanOptionalAction)
     model_group.add_argument("--constantmax_base", type=float, default=2.0)
