1,2,3 finished
SenmiaoORZ committed May 28, 2024
1 parent 5f2dd55 commit 0b71970
Showing 4 changed files with 95 additions and 87 deletions.
72 changes: 0 additions & 72 deletions KALnet.py

This file was deleted.

35 changes: 21 additions & 14 deletions model.py
@@ -23,8 +23,7 @@
from variations.norm_variations import norm_dictionary, LayerNorm, RMSNorm, pRMSNorm, kRMSNorm
from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions, FIRE
from variations.activation_variations import SquaredReLU, activation_dictionary
from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary
from KALnet import KAL_Net as KAN
from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, KAL_Net as KAN, linear_dictionary

def create_shared_param_group(layer_type, config):
shared_size = None
@@ -84,8 +83,10 @@ def __init__(self, config, fire_pos_enc=None):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
# self.c_attn_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.c_attn_q = KAN([config.n_embd, config.n_embd])
if config.linear_variant == "KAN":
self.c_attn_q = KAN([config.n_embd, config.n_embd])
else:
self.c_attn_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

self.n_head = config.n_head
if config.n_kv_group == None:
@@ -95,13 +96,14 @@ def __init__(self, config, fire_pos_enc=None):
self.n_kv_group = config.n_kv_group

self.kv_dim = (config.n_embd // config.n_head) * self.n_kv_group
self.c_attn_k = KAN([config.n_embd, self.kv_dim])
self.c_attn_v = KAN([config.n_embd, self.kv_dim])
# self.c_attn_k = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
# self.c_attn_v = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
# output projection
self.c_proj = KAN([config.n_embd, config.n_embd])
# self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
if config.linear_variant == "KAN":
self.c_attn_k = KAN([config.n_embd, self.kv_dim])
self.c_attn_v = KAN([config.n_embd, self.kv_dim])
self.c_proj = KAN([config.n_embd, config.n_embd])
else:
self.c_attn_k = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
self.c_attn_v = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
# regularization
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
@@ -256,12 +258,17 @@ def __init__(self, config):

# Select linear variant
self.linear_variant = linear_dictionary[config.linear_variant]
self.c_fc = self.linear_variant(config.n_embd, 4 * config.n_embd, bias=config.bias)
if config.linear_variant == "KAN":
self.c_fc = self.linear_variant([config.n_embd, 4 * config.n_embd])
else:
self.c_fc = self.linear_variant(config.n_embd, 4 * config.n_embd, bias=config.bias)

# Select activation variant
self.activation_variant = activation_dictionary[config.activation_variant]

self.c_proj = self.linear_variant(4 * config.n_embd, config.n_embd, bias=config.bias)
if config.linear_variant == "KAN":
self.c_proj = self.linear_variant([4 * config.n_embd, config.n_embd])
else:
self.c_proj = self.linear_variant(4 * config.n_embd, config.n_embd, bias=config.bias)
self.dropout = nn.Dropout(config.dropout)

def forward(self, x):
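Note: a minimal shape check (an illustrative sketch, not part of this commit) showing that KAL_Net built from an [in, out] size list is a drop-in for nn.Linear(in, out) in the attention projections above:

import torch
import torch.nn as nn
from variations.linear_variations import KAL_Net as KAN

# Toy dimensions; kv_dim mirrors the (n_embd // n_head) * n_kv_group computation above.
n_embd, n_head, n_kv_group = 64, 4, 2
kv_dim = (n_embd // n_head) * n_kv_group

x = torch.randn(2, 8, n_embd)                     # (batch, sequence, embedding)
kan_proj = KAN([n_embd, kv_dim])                  # KAL_Net takes a list of layer sizes
lin_proj = nn.Linear(n_embd, kv_dim, bias=False)  # nn.Linear takes (in, out)
assert kan_proj(x).shape == lin_proj(x).shape == (2, 8, kv_dim)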
3 changes: 2 additions & 1 deletion train.py
@@ -104,12 +104,13 @@ def parse_args():
model_group.add_argument(
"--linear_variant",
type=str,
default="linear",
default="KAN",
choices=[
"linear",
"bitlinear",
"bitlinear_1p58",
"bitlinear_optimized",
"KAN",
],
)

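A standalone sketch (hypothetical, mirroring the argparse setup above) of how the new choice parses; "KAN" is accepted and used when no flag is given:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--linear_variant",
    type=str,
    default="KAN",
    choices=["linear", "bitlinear", "bitlinear_1p58", "bitlinear_optimized", "KAN"],
)
print(parser.parse_args([]).linear_variant)                              # KAN
print(parser.parse_args(["--linear_variant", "linear"]).linear_variant)  # linear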
72 changes: 72 additions & 0 deletions variations/linear_variations.py
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from functools import lru_cache

class BitLinear1p58(nn.Linear):
""" BitLinear from Era of 1.58 LLMs Paper
@@ -213,11 +215,81 @@ def forward(self, input):
output = self.quantize_activations_groupwise(output)

return output


class KAL_Net(nn.Module): # Kolmogorov Arnold Legendre Network (KAL-Net)
def __init__(self, layers_hidden, polynomial_order=3, base_activation=nn.SiLU):
super(KAL_Net, self).__init__() # Initialize the parent nn.Module class

# layers_hidden: A list of integers specifying the number of neurons in each layer
self.layers_hidden = layers_hidden
# polynomial_order: Order up to which Legendre polynomials are calculated
self.polynomial_order = polynomial_order
# base_activation: Activation function used after each layer's computation
self.base_activation = base_activation()

# ParameterList for the base weights of each layer
self.base_weights = nn.ParameterList()
# ParameterList for the polynomial weights for Legendre expansion
self.poly_weights = nn.ParameterList()
# ModuleList for layer normalization for each layer's output
self.layer_norms = nn.ModuleList()

# Initialize network parameters
for i, (in_features, out_features) in enumerate(zip(layers_hidden, layers_hidden[1:])):
# Base weight for linear transformation in each layer
self.base_weights.append(nn.Parameter(torch.randn(out_features, in_features)))
# Polynomial weight for handling Legendre polynomial expansions
self.poly_weights.append(nn.Parameter(torch.randn(out_features, in_features * (polynomial_order + 1))))
# Layer normalization to stabilize learning and outputs
self.layer_norms.append(nn.LayerNorm(out_features))

# Initialize weights using Kaiming uniform distribution for better training start
for weight in self.base_weights:
nn.init.kaiming_uniform_(weight, nonlinearity='linear')
for weight in self.poly_weights:
nn.init.kaiming_uniform_(weight, nonlinearity='linear')

@lru_cache(maxsize=128) # Cache to avoid recomputation of Legendre polynomials
def compute_legendre_polynomials(self, x, order):
# Base case polynomials P0 and P1
P0 = x.new_ones(x.shape) # P0 = 1 for all x
if order == 0:
return P0.unsqueeze(-1)
P1 = x # P1 = x
legendre_polys = [P0, P1]

# Compute higher order polynomials using recurrence
for n in range(1, order):
Pn = ((2.0 * n + 1.0) * x * legendre_polys[-1] - n * legendre_polys[-2]) / (n + 1.0)
legendre_polys.append(Pn)

return torch.stack(legendre_polys, dim=-1)

def forward(self, x):
x = x.to(self.base_weights[0].device)
batch_size, seq_len, feature_dim = x.size()

for i, (base_weight, poly_weight, layer_norm) in enumerate(zip(self.base_weights, self.poly_weights, self.layer_norms)):
base_output = F.linear(self.base_activation(x), base_weight)

# Normalize x to range [-1, 1] for Legendre polynomial computation
x_normalized = 2 * (x - x.min(dim=1, keepdim=True)[0]) / (x.max(dim=1, keepdim=True)[0] - x.min(dim=1, keepdim=True)[0]) - 1
legendre_basis = self.compute_legendre_polynomials(x_normalized, self.polynomial_order)
legendre_basis = legendre_basis.view(batch_size * seq_len, -1) # Flatten for linear layer

poly_output = F.linear(legendre_basis, poly_weight)
poly_output = poly_output.view(batch_size, seq_len, -1) # Reshape back to match base_output

x = self.base_activation(layer_norm(base_output + poly_output))

return x


linear_dictionary = {
"linear": nn.Linear,
"bitlinear": BitLinear,
"bitlinear_optimized": BitLinearOptimized,
"bitlinear_1p58": BitLinear1p58,
"KAN": KAL_Net,
}
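The Legendre basis used by KAL_Net follows Bonnet's recursion, (n + 1) * P_{n+1}(x) = (2n + 1) * x * P_n(x) - n * P_{n-1}(x); a quick sanity check (an illustrative sketch, not part of the file) against the closed forms P_2(x) = (3x^2 - 1) / 2 and P_3(x) = (5x^3 - 3x) / 2:

import torch
from variations.linear_variations import KAL_Net

kan = KAL_Net([4, 4], polynomial_order=3)
x = torch.linspace(-1.0, 1.0, steps=5)
basis = kan.compute_legendre_polynomials(x, 3)   # stacked P0..P3, shape (5, 4)
assert torch.allclose(basis[..., 2], (3 * x**2 - 1) / 2)
assert torch.allclose(basis[..., 3], (5 * x**3 - 3 * x) / 2)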
