Merge pull request karpathy#252 from gkielian/add_mlp_expansion_factor
Add MLP Expansion factor control and sweep
klei22 authored Sep 3, 2024
2 parents 863c54d + 981c8dd commit 37ca368
Showing 4 changed files with 39 additions and 18 deletions.
21 changes: 21 additions & 0 deletions explorations/mlp_expansion_factor_sweep.json
@@ -0,0 +1,21 @@
[
{
"max_iters": ["3500"],
"n_layer": ["4", "6", "8"],
"n_head": ["6"],
"n_embd": ["384"],
"block_size":["256"],
"device": ["cuda"],
"dataset": ["shakespeare_char"],
"compile": [true],
"save_nan_checkpoint": [true],
"use_rotary_embeddings": [false],
"use_abs_pos_embeddings": [true],
"mlp_expansion_factor": ["2", "3", "4", "5"],
"softmax_variant_attn": ["softmax"],
"dtype": ["bfloat16"],
"use_parallel_mlp": [true, false],
"mlp_variant": ["mlp", "swiglu"]
}
]
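
Each key in this sweep file lists the values to try, and one training run is launched per combination (48 combinations here: 3 layer counts, 4 expansion factors, 2 parallel-MLP settings, 2 MLP variants). The snippet below is a hypothetical expander, not the repository's own sweep runner, sketching how such a file maps onto train.py flags:

```python
# Hypothetical sweep expander (sketch only): prints one train.py command per
# combination listed in explorations/mlp_expansion_factor_sweep.json.
import itertools
import json

with open("explorations/mlp_expansion_factor_sweep.json") as f:
    sweeps = json.load(f)

for sweep in sweeps:
    keys = list(sweep)
    for combo in itertools.product(*(sweep[k] for k in keys)):
        flags = []
        for key, val in zip(keys, combo):
            if isinstance(val, bool):
                # assumes boolean options use argparse.BooleanOptionalAction
                # (--flag / --no-flag), as --use_parallel_mlp does in train.py
                flags.append(f"--{key}" if val else f"--no-{key}")
            else:
                flags.append(f"--{key} {val}")
        print("python3 train.py " + " ".join(flags))
```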

1 change: 1 addition & 0 deletions gpt_conf.py
@@ -30,6 +30,7 @@ class GPTConfig:
# MLP Options
use_parallel_mlp: bool = False
mlp_variant: str = "mlp"
mlp_expansion_factor: int = 4

## KAN Option
kan_poly_order: int = 3
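
The default of 4 preserves the hidden width that was previously hard-coded as 4 * n_embd. As a rough guide to what the swept values mean in practice, here is a back-of-the-envelope per-layer parameter count at n_embd = 384 (bias terms and quantization wrappers ignored; swiglu assumed to use three projection matrices, as in model.py):

```python
# Per-layer MLP parameter counts as a function of mlp_expansion_factor (rough estimate).
n_embd = 384
for factor in (2, 3, 4, 5):
    hidden = factor * n_embd
    mlp_params = 2 * n_embd * hidden      # c_fc (n_embd -> hidden) + c_proj (hidden -> n_embd)
    swiglu_params = 3 * n_embd * hidden   # c_fc_in1 + c_fc_in2 + c_fc_out
    print(f"factor={factor}: mlp={mlp_params:,}  swiglu={swiglu_params:,}")
# factor=4 reproduces the old hard-coded width (1,179,648 params for the plain mlp variant).
```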
34 changes: 16 additions & 18 deletions model.py
@@ -109,7 +109,7 @@ def __init__(self, config, fire_pos_enc=None):
config.n_kv_group = config.n_head
else:
assert config.n_embd % config.n_kv_group == 0


self.quantization_attn_dict = {}
self.quantization_attn_dict["activations_quant_method"] = config.activations_quant_method
@@ -126,7 +126,7 @@ def __init__(self, config, fire_pos_enc=None):
self.quantization_attn_dict[arg] = set_variant(val, config.quantize_linear_bits)
elif arg.startswith("quantize_") and "linear_attn" in arg and arg.endswith("_method"):
self.quantization_attn_dict[arg] = set_variant(val, config.quantize_linear_method)

self.linear_variant_q = linear_dictionary[set_variant(config.linear_variant_q, config.linear_variant_attn)]
self.linear_variant_k = linear_dictionary[set_variant(config.linear_variant_k, config.linear_variant_attn)]
self.linear_variant_v = linear_dictionary[set_variant(config.linear_variant_v, config.linear_variant_attn)]
@@ -343,7 +343,7 @@ def forward(self, x):
class MLP(nn.Module):
def __init__(self, config):
super().__init__()

# Select "mlp variant"
self.mlp_variant = config.mlp_variant

@@ -357,10 +357,10 @@ def __init__(self, config):
# Sets the class of linear for MLP
self.linear_variant_mlp_up = linear_dictionary[set_variant(config.linear_variant_mlp_up, config.linear_variant_mlp)]
self.linear_variant_mlp_down = linear_dictionary[set_variant(config.linear_variant_mlp_down, config.linear_variant_mlp)]

self.quantization_mlp_dict = {}
self.quantization_mlp_dict["activations_quant_method"] = config.activations_quant_method

# Set quantization parameters for MLP
for arg, val in vars(config).items():
# Set MLP Activation precision and quantization method
@@ -375,15 +375,15 @@ def __init__(self, config):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_linear_bits)
elif arg.startswith("quantize_") and "linear_mlp" in arg and arg.endswith("_method"):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_linear_method)

# Instantiate Linear Layers
if self.mlp_variant == "mlp":
-self.c_fc = self.linear_variant_mlp_up(config.n_embd, 4 * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"], bias=config.bias)
-self.c_proj = self.linear_variant_mlp_down(4 * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"], bias=config.bias)
+self.c_fc = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"], bias=config.bias)
+self.c_proj = self.linear_variant_mlp_down(config.mlp_expansion_factor * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"], bias=config.bias)
elif self.mlp_variant == "swiglu":
-self.c_fc_in1 = self.linear_variant_mlp_up(config.n_embd, 4 * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
-self.c_fc_in2 = self.linear_variant_mlp_up(config.n_embd, 4 * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
-self.c_fc_out = self.linear_variant_mlp_down(4 * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"])
+self.c_fc_in1 = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
+self.c_fc_in2 = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
+self.c_fc_out = self.linear_variant_mlp_down(config.mlp_expansion_factor * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"])

self.dropout = nn.Dropout(config.dropout)

@@ -395,7 +395,7 @@ def forward(self, x):

if self.mlp_variant == "kan":
x = self.kan(x)

elif self.mlp_variant == "mlp":
x = self.c_fc(x)

@@ -412,7 +412,7 @@ def forward(self, x):
x = fake_quantize_act(self, "mlp_act_activation_output", x, num_bits, quant_method)

x = self.c_proj(x)

elif self.mlp_variant == "swiglu":
x_in1 = self.c_fc_in1(x)

@@ -433,7 +433,7 @@ def forward(self, x):
x = self.c_fc_out(x_out)

x = self.dropout(x)

if self.quantization_mlp_dict["quantize_mlp_act_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
@@ -646,9 +646,9 @@ def crop_block_size(self, block_size):
def from_pretrained(cls, config, model_type):
# assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
from transformers import GPT2LMHeadModel

print(f"loading weights from pretrained gpt: {model_type}")

# create a from-scratch initialized minGPT model
model = GPT(config)
model_hf = GPT2LMHeadModel.from_pretrained(model_type)
@@ -843,5 +843,3 @@ def forward(self, x):
# print(f"final_output.shape = {final_output.shape}\n")
return final_output
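
The substantive change in model.py above is that both MLP variants now derive their hidden width from config.mlp_expansion_factor instead of a hard-coded factor of 4. Below is a minimal sketch of that wiring, using plain nn.Linear in place of the repository's configurable linear and quantization variants, and with assumed GELU/SiLU activations:

```python
# Illustrative stand-in for model.py's MLP (not the repo's exact class):
# hidden width = mlp_expansion_factor * n_embd for both "mlp" and "swiglu".
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLP(nn.Module):
    def __init__(self, n_embd, mlp_expansion_factor=4, mlp_variant="mlp", dropout=0.0, bias=False):
        super().__init__()
        hidden = mlp_expansion_factor * n_embd
        self.mlp_variant = mlp_variant
        if mlp_variant == "mlp":
            self.c_fc = nn.Linear(n_embd, hidden, bias=bias)
            self.c_proj = nn.Linear(hidden, n_embd, bias=bias)
        elif mlp_variant == "swiglu":
            self.c_fc_in1 = nn.Linear(n_embd, hidden, bias=bias)
            self.c_fc_in2 = nn.Linear(n_embd, hidden, bias=bias)
            self.c_fc_out = nn.Linear(hidden, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if self.mlp_variant == "mlp":
            x = self.c_proj(F.gelu(self.c_fc(x)))  # activation assumed; quantization hooks omitted
        elif self.mlp_variant == "swiglu":
            x = self.c_fc_out(F.silu(self.c_fc_in1(x)) * self.c_fc_in2(x))
        return self.dropout(x)

x = torch.randn(2, 16, 384)
print(SimpleMLP(384, mlp_expansion_factor=2, mlp_variant="swiglu")(x).shape)  # torch.Size([2, 16, 384])
```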



1 change: 1 addition & 0 deletions train.py
@@ -86,6 +86,7 @@ def parse_args():
## MLP Options
model_group.add_argument('--use_parallel_mlp', default=False, action=argparse.BooleanOptionalAction)
model_group.add_argument("--mlp_variant", type=str, default="mlp", choices=["mlp", "kan", "swiglu"], help="MLP variation type")
model_group.add_argument("--mlp_expansion_factor", type=int, default=4, help="If MLP like variant is used, set the expansion factor for the linear transformations, default is 4.")

## KAN Options
model_group.add_argument("--kan_poly_order", type=int, default=3, help="Order of KAN non-linearity")
