From 5f2dd55b1a8c936428f297469381b332012ebe5a Mon Sep 17 00:00:00 2001
From: SenmiaoORZ
Date: Wed, 22 May 2024 01:46:58 -0400
Subject: [PATCH 1/2] replace linear with KAN

---
 KALnet.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 model.py  | 14 +++++++----
 2 files changed, 81 insertions(+), 5 deletions(-)
 create mode 100644 KALnet.py

diff --git a/KALnet.py b/KALnet.py
new file mode 100644
index 0000000000..aa6f4dcef3
--- /dev/null
+++ b/KALnet.py
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import lru_cache
+
+class KAL_Net(nn.Module): # Kolmogorov Arnold Legendre Network (KAL-Net)
+    def __init__(self, layers_hidden, polynomial_order=3, base_activation=nn.SiLU):
+        super(KAL_Net, self).__init__() # Initialize the parent nn.Module class
+
+        # layers_hidden: A list of integers specifying the number of neurons in each layer
+        self.layers_hidden = layers_hidden
+        # polynomial_order: Order up to which Legendre polynomials are calculated
+        self.polynomial_order = polynomial_order
+        # base_activation: Activation function used after each layer's computation
+        self.base_activation = base_activation()
+
+        # ParameterList for the base weights of each layer
+        self.base_weights = nn.ParameterList()
+        # ParameterList for the polynomial weights for Legendre expansion
+        self.poly_weights = nn.ParameterList()
+        # ModuleList for layer normalization for each layer's output
+        self.layer_norms = nn.ModuleList()
+
+        # Initialize network parameters
+        for i, (in_features, out_features) in enumerate(zip(layers_hidden, layers_hidden[1:])):
+            # Base weight for linear transformation in each layer
+            self.base_weights.append(nn.Parameter(torch.randn(out_features, in_features)))
+            # Polynomial weight for handling Legendre polynomial expansions
+            self.poly_weights.append(nn.Parameter(torch.randn(out_features, in_features * (polynomial_order + 1))))
+            # Layer normalization to stabilize learning and outputs
+            self.layer_norms.append(nn.LayerNorm(out_features))
+
+        # Initialize weights using Kaiming uniform distribution for better training start
+        for weight in self.base_weights:
+            nn.init.kaiming_uniform_(weight, nonlinearity='linear')
+        for weight in self.poly_weights:
+            nn.init.kaiming_uniform_(weight, nonlinearity='linear')
+
+    @lru_cache(maxsize=128) # Cache to avoid recomputation of Legendre polynomials
+    def compute_legendre_polynomials(self, x, order):
+        # Base case polynomials P0 and P1
+        P0 = x.new_ones(x.shape) # P0 = 1 for all x
+        if order == 0:
+            return P0.unsqueeze(-1)
+        P1 = x # P1 = x
+        legendre_polys = [P0, P1]
+
+        # Compute higher order polynomials using recurrence
+        for n in range(1, order):
+            Pn = ((2.0 * n + 1.0) * x * legendre_polys[-1] - n * legendre_polys[-2]) / (n + 1.0)
+            legendre_polys.append(Pn)
+
+        return torch.stack(legendre_polys, dim=-1)
+
+    def forward(self, x):
+        x = x.to(self.base_weights[0].device)
+        batch_size, seq_len, feature_dim = x.size()
+
+        for i, (base_weight, poly_weight, layer_norm) in enumerate(zip(self.base_weights, self.poly_weights, self.layer_norms)):
+            base_output = F.linear(self.base_activation(x), base_weight)
+
+            # Normalize x to range [-1, 1] for Legendre polynomial computation
+            x_normalized = 2 * (x - x.min(dim=1, keepdim=True)[0]) / (x.max(dim=1, keepdim=True)[0] - x.min(dim=1, keepdim=True)[0]) - 1
+            legendre_basis = self.compute_legendre_polynomials(x_normalized, self.polynomial_order)
+            legendre_basis = legendre_basis.view(batch_size * seq_len, -1) # Flatten for linear layer
+
+            poly_output = F.linear(legendre_basis, poly_weight)
+            poly_output = poly_output.view(batch_size, seq_len, -1) # Reshape back to match base_output
+
+            x = self.base_activation(layer_norm(base_output + poly_output))
+
+        return x
\ No newline at end of file
diff --git a/model.py b/model.py
index f4e07d3988..f0c90a0166 100644
--- a/model.py
+++ b/model.py
@@ -24,6 +24,7 @@
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions, FIRE
 from variations.activation_variations import SquaredReLU, activation_dictionary
 from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary
+from KALnet import KAL_Net as KAN
 
 def create_shared_param_group(layer_type, config):
     shared_size = None
@@ -83,7 +84,8 @@ def __init__(self, config, fire_pos_enc=None):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
-        self.c_attn_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+        # self.c_attn_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+        self.c_attn_q = KAN([config.n_embd, config.n_embd])
 
         self.n_head = config.n_head
         if config.n_kv_group == None:
@@ -93,10 +95,13 @@ def __init__(self, config, fire_pos_enc=None):
             self.n_kv_group = config.n_kv_group
         self.kv_dim = (config.n_embd // config.n_head) * self.n_kv_group
 
-        self.c_attn_k = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
-        self.c_attn_v = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
+        self.c_attn_k = KAN([config.n_embd, self.kv_dim])
+        self.c_attn_v = KAN([config.n_embd, self.kv_dim])
+        # self.c_attn_k = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
+        # self.c_attn_v = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
         # output projection
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+        self.c_proj = KAN([config.n_embd, config.n_embd])
+        # self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         # regularization
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
@@ -161,7 +166,6 @@ def __init__(self, config, fire_pos_enc=None):
 
     def forward(self, x):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
-
         q = self.c_attn_q(x)
         k = self.c_attn_k(x)
         v = self.c_attn_v(x)
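A minimal way to exercise the new layer on its own, against the tree as of this first patch; the sizes below are illustrative rather than repository defaults, and the snippet only checks that KAL_Net preserves the (batch, seq_len, n_embd) shape that CausalSelfAttention expects:

    import torch
    from KALnet import KAL_Net

    # Illustrative dimensions; n_embd=64 is an assumption, not a repo default.
    kan = KAL_Net([64, 64], polynomial_order=3)
    x = torch.randn(2, 16, 64)   # (batch, seq_len, n_embd)
    y = kan(x)
    print(y.shape)               # expected: torch.Size([2, 16, 64])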
From 0b719707ce62ef5cb9a2c1aff5966dffc330b775 Mon Sep 17 00:00:00 2001
From: SenmiaoORZ
Date: Mon, 27 May 2024 23:49:53 -0400
Subject: [PATCH 2/2] 1,2,3 finished

---
 KALnet.py                       | 72 ---------------------------------
 model.py                        | 35 +++++++++-------
 train.py                        |  3 +-
 variations/linear_variations.py | 72 +++++++++++++++++++++++++++++++++
 4 files changed, 95 insertions(+), 87 deletions(-)
 delete mode 100644 KALnet.py

diff --git a/KALnet.py b/KALnet.py
deleted file mode 100644
index aa6f4dcef3..0000000000
--- a/KALnet.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from functools import lru_cache
-
-class KAL_Net(nn.Module): # Kolmogorov Arnold Legendre Network (KAL-Net)
-    def __init__(self, layers_hidden, polynomial_order=3, base_activation=nn.SiLU):
-        super(KAL_Net, self).__init__() # Initialize the parent nn.Module class
-
-        # layers_hidden: A list of integers specifying the number of neurons in each layer
-        self.layers_hidden = layers_hidden
-        # polynomial_order: Order up to which Legendre polynomials are calculated
-        self.polynomial_order = polynomial_order
-        # base_activation: Activation function used after each layer's computation
-        self.base_activation = base_activation()
-
-        # ParameterList for the base weights of each layer
-        self.base_weights = nn.ParameterList()
-        # ParameterList for the polynomial weights for Legendre expansion
-        self.poly_weights = nn.ParameterList()
-        # ModuleList for layer normalization for each layer's output
-        self.layer_norms = nn.ModuleList()
-
-        # Initialize network parameters
-        for i, (in_features, out_features) in enumerate(zip(layers_hidden, layers_hidden[1:])):
-            # Base weight for linear transformation in each layer
-            self.base_weights.append(nn.Parameter(torch.randn(out_features, in_features)))
-            # Polynomial weight for handling Legendre polynomial expansions
-            self.poly_weights.append(nn.Parameter(torch.randn(out_features, in_features * (polynomial_order + 1))))
-            # Layer normalization to stabilize learning and outputs
-            self.layer_norms.append(nn.LayerNorm(out_features))
-
-        # Initialize weights using Kaiming uniform distribution for better training start
-        for weight in self.base_weights:
-            nn.init.kaiming_uniform_(weight, nonlinearity='linear')
-        for weight in self.poly_weights:
-            nn.init.kaiming_uniform_(weight, nonlinearity='linear')
-
-    @lru_cache(maxsize=128) # Cache to avoid recomputation of Legendre polynomials
-    def compute_legendre_polynomials(self, x, order):
-        # Base case polynomials P0 and P1
-        P0 = x.new_ones(x.shape) # P0 = 1 for all x
-        if order == 0:
-            return P0.unsqueeze(-1)
-        P1 = x # P1 = x
-        legendre_polys = [P0, P1]
-
-        # Compute higher order polynomials using recurrence
-        for n in range(1, order):
-            Pn = ((2.0 * n + 1.0) * x * legendre_polys[-1] - n * legendre_polys[-2]) / (n + 1.0)
-            legendre_polys.append(Pn)
-
-        return torch.stack(legendre_polys, dim=-1)
-
-    def forward(self, x):
-        x = x.to(self.base_weights[0].device)
-        batch_size, seq_len, feature_dim = x.size()
-
-        for i, (base_weight, poly_weight, layer_norm) in enumerate(zip(self.base_weights, self.poly_weights, self.layer_norms)):
-            base_output = F.linear(self.base_activation(x), base_weight)
-
-            # Normalize x to range [-1, 1] for Legendre polynomial computation
-            x_normalized = 2 * (x - x.min(dim=1, keepdim=True)[0]) / (x.max(dim=1, keepdim=True)[0] - x.min(dim=1, keepdim=True)[0]) - 1
-            legendre_basis = self.compute_legendre_polynomials(x_normalized, self.polynomial_order)
-            legendre_basis = legendre_basis.view(batch_size * seq_len, -1) # Flatten for linear layer
-
-            poly_output = F.linear(legendre_basis, poly_weight)
-            poly_output = poly_output.view(batch_size, seq_len, -1) # Reshape back to match base_output
-
-            x = self.base_activation(layer_norm(base_output + poly_output))
-
-        return x
\ No newline at end of file
diff --git a/model.py b/model.py
index f0c90a0166..b98de22066 100644
--- a/model.py
+++ b/model.py
@@ -23,8 +23,7 @@
 from variations.norm_variations import norm_dictionary, LayerNorm, RMSNorm, pRMSNorm, kRMSNorm
 from variations.position_encoding_variations import RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions, FIRE
 from variations.activation_variations import SquaredReLU, activation_dictionary
-from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized, linear_dictionary
-from KALnet import KAL_Net as KAN
+from variations.linear_variations import BitLinear1p58, BitLinear, BitLinearOptimized,KAL_Net as KAN, linear_dictionary
 
 def create_shared_param_group(layer_type, config):
     shared_size = None
@@ -84,8 +83,10 @@ def __init__(self, config, fire_pos_enc=None):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
-        # self.c_attn_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
-        self.c_attn_q = KAN([config.n_embd, config.n_embd])
+        if config.linear_variant == "KAN":
+            self.c_attn_q = KAN([config.n_embd, config.n_embd])
+        else:
+            self.c_attn_q = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
 
         self.n_head = config.n_head
         if config.n_kv_group == None:
@@ -95,13 +96,14 @@ def __init__(self, config, fire_pos_enc=None):
             self.n_kv_group = config.n_kv_group
         self.kv_dim = (config.n_embd // config.n_head) * self.n_kv_group
 
-        self.c_attn_k = KAN([config.n_embd, self.kv_dim])
-        self.c_attn_v = KAN([config.n_embd, self.kv_dim])
-        # self.c_attn_k = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
-        # self.c_attn_v = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
-        # output projection
-        self.c_proj = KAN([config.n_embd, config.n_embd])
-        # self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+        if config.linear_variant == "KAN":
+            self.c_attn_k = KAN([config.n_embd, self.kv_dim])
+            self.c_attn_v = KAN([config.n_embd, self.kv_dim])
+            self.c_proj = KAN([config.n_embd, config.n_embd])
+        else:
+            self.c_attn_k = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
+            self.c_attn_v = nn.Linear(config.n_embd, self.kv_dim, bias=config.bias)
+            self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
         # regularization
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
@@ -256,12 +258,17 @@ def __init__(self, config):
 
         # Select linear variant
         self.linear_variant = linear_dictionary[config.linear_variant]
-        self.c_fc = self.linear_variant(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        if config.linear_variant == "KAN":
+            self.c_fc = self.linear_variant([config.n_embd, 4 * config.n_embd])
+        else:
+            self.c_fc = self.linear_variant(config.n_embd, 4 * config.n_embd, bias=config.bias)
 
         # Select activation variant
         self.activation_variant = activation_dictionary[config.activation_variant]
-
-        self.c_proj = self.linear_variant(4 * config.n_embd, config.n_embd, bias=config.bias)
+        if config.linear_variant == "KAN":
+            self.c_proj = self.linear_variant([4 * config.n_embd, config.n_embd])
+        else:
+            self.c_proj = self.linear_variant(4 * config.n_embd, config.n_embd, bias=config.bias)
         self.dropout = nn.Dropout(config.dropout)
 
     def forward(self, x):
diff --git a/train.py b/train.py
index 365d1b8932..50c060ab50 100644
--- a/train.py
+++ b/train.py
@@ -104,12 +104,13 @@ def parse_args():
     model_group.add_argument(
         "--linear_variant",
         type=str,
-        default="linear",
+        default="KAN",
         choices=[
             "linear",
             "bitlinear",
             "bitlinear_1p58",
             "bitlinear_optimized",
+            "KAN",
         ],
     )
 
diff --git a/variations/linear_variations.py b/variations/linear_variations.py
index 141ccfec3e..22bcaafbbc 100644
--- a/variations/linear_variations.py
+++ b/variations/linear_variations.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import math
+import torch.nn.functional as F
+from functools import lru_cache
 
 class BitLinear1p58(nn.Linear):
     """ BitLinear from Era of 1.58 LLMs Paper
@@ -213,6 +215,75 @@ def forward(self, input):
         output = self.quantize_activations_groupwise(output)
 
         return output
+
+
+class KAL_Net(nn.Module): # Kolmogorov Arnold Legendre Network (KAL-Net)
+    def __init__(self, layers_hidden, polynomial_order=3, base_activation=nn.SiLU):
+        super(KAL_Net, self).__init__() # Initialize the parent nn.Module class
+
+        # layers_hidden: A list of integers specifying the number of neurons in each layer
+        self.layers_hidden = layers_hidden
+        # polynomial_order: Order up to which Legendre polynomials are calculated
+        self.polynomial_order = polynomial_order
+        # base_activation: Activation function used after each layer's computation
+        self.base_activation = base_activation()
+
+        # ParameterList for the base weights of each layer
+        self.base_weights = nn.ParameterList()
+        # ParameterList for the polynomial weights for Legendre expansion
+        self.poly_weights = nn.ParameterList()
+        # ModuleList for layer normalization for each layer's output
+        self.layer_norms = nn.ModuleList()
+
+        # Initialize network parameters
+        for i, (in_features, out_features) in enumerate(zip(layers_hidden, layers_hidden[1:])):
+            # Base weight for linear transformation in each layer
+            self.base_weights.append(nn.Parameter(torch.randn(out_features, in_features)))
+            # Polynomial weight for handling Legendre polynomial expansions
+            self.poly_weights.append(nn.Parameter(torch.randn(out_features, in_features * (polynomial_order + 1))))
+            # Layer normalization to stabilize learning and outputs
+            self.layer_norms.append(nn.LayerNorm(out_features))
+
+        # Initialize weights using Kaiming uniform distribution for better training start
+        for weight in self.base_weights:
+            nn.init.kaiming_uniform_(weight, nonlinearity='linear')
+        for weight in self.poly_weights:
+            nn.init.kaiming_uniform_(weight, nonlinearity='linear')
+
+    @lru_cache(maxsize=128) # Cache to avoid recomputation of Legendre polynomials
+    def compute_legendre_polynomials(self, x, order):
+        # Base case polynomials P0 and P1
+        P0 = x.new_ones(x.shape) # P0 = 1 for all x
+        if order == 0:
+            return P0.unsqueeze(-1)
+        P1 = x # P1 = x
+        legendre_polys = [P0, P1]
+
+        # Compute higher order polynomials using recurrence
+        for n in range(1, order):
+            Pn = ((2.0 * n + 1.0) * x * legendre_polys[-1] - n * legendre_polys[-2]) / (n + 1.0)
+            legendre_polys.append(Pn)
+
+        return torch.stack(legendre_polys, dim=-1)
+
+    def forward(self, x):
+        x = x.to(self.base_weights[0].device)
+        batch_size, seq_len, feature_dim = x.size()
+
+        for i, (base_weight, poly_weight, layer_norm) in enumerate(zip(self.base_weights, self.poly_weights, self.layer_norms)):
+            base_output = F.linear(self.base_activation(x), base_weight)
+
+            # Normalize x to range [-1, 1] for Legendre polynomial computation
+            x_normalized = 2 * (x - x.min(dim=1, keepdim=True)[0]) / (x.max(dim=1, keepdim=True)[0] - x.min(dim=1, keepdim=True)[0]) - 1
+            legendre_basis = self.compute_legendre_polynomials(x_normalized, self.polynomial_order)
+            legendre_basis = legendre_basis.view(batch_size * seq_len, -1) # Flatten for linear layer
+
+            poly_output = F.linear(legendre_basis, poly_weight)
+            poly_output = poly_output.view(batch_size, seq_len, -1) # Reshape back to match base_output
+
+            x = self.base_activation(layer_norm(base_output + poly_output))
+
+        return x
 
 
 linear_dictionary = {
@@ -220,4 +291,5 @@ def forward(self, input):
     "bitlinear": BitLinear,
     "bitlinear_optimized": BitLinearOptimized,
     "bitlinear_1p58": BitLinear1p58,
+    "KAN": KAL_Net,
 }
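With this second patch, the KAN variant is selected through linear_dictionary and the --linear_variant flag (whose default is now "KAN"), e.g. python train.py --linear_variant KAN. Because KAL_Net is constructed from a list of layer sizes rather than the (in_features, out_features, bias=...) signature shared by the other entries, call sites branch on the variant name; the sketch below mirrors that pattern with illustrative values, where only the dictionary keys and constructor shapes come from the patch:

    from variations.linear_variations import linear_dictionary

    # Illustrative values; not taken from the repository's config.
    n_embd, variant, bias = 64, "KAN", False
    layer_cls = linear_dictionary[variant]
    if variant == "KAN":
        c_fc = layer_cls([n_embd, 4 * n_embd])           # KAL_Net: list of layer widths
    else:
        c_fc = layer_cls(n_embd, 4 * n_embd, bias=bias)  # nn.Linear-style signature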