forked from karpathy/nanoGPT
Commit: Merge branch 'master' into add_training_estimates
Showing 19 changed files with 670 additions and 176 deletions.
New configuration file (@@ -0,0 +1,28 @@): a ConSmax v2 exploration grid on shakespeare_char.
```json
[
  {
    "max_iters": ["3500"],
    "n_layer": ["12"],
    "n_head": ["6"],
    "n_embd": ["384"],
    "block_size": ["256"],
    "device": ["cuda"],
    "dtype": ["bfloat16"],
    "dataset": ["shakespeare_char"],
    "use_gradient_checkpointing": [true],
    "use_rotary_embeddings": [false],
    "use_abs_pos_embeddings": [true],
    "compile": [false],
    "softmax_variant_attn": ["consmax_v2"],
    "consmax_initial_beta": ["2.5"],
    "consmax_initial_gamma": ["100.0"],
    "softmax_io_logging": [true],
    "create_statistics": [true],
    "plot_statistics": [true],
    "consmax_per_head": {
      "conditions": [["softmax_variant_attn", "consmax_v2"]],
      "options": [true, false]
    },
    "use_post_ln": [true, false]
  }
]
```
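In this grid every key maps to a list of values to sweep, and a nested object such as `consmax_per_head` is swept only when its conditions hold for the current combination. The sweep harness itself is not part of this diff, so the following is a minimal, hypothetical sketch of how such a grid could be expanded into concrete run configurations (function and file names are illustrative only):

```python
import itertools
import json

def expand_grid(grid):
    """Yield one flat run config per combination in a sweep grid."""
    # Plain lists are unconditional sweep axes.
    simple = {k: v for k, v in grid.items() if isinstance(v, list)}
    # Dict entries carry "conditions" that gate an extra sweep axis.
    conditional = {k: v for k, v in grid.items() if isinstance(v, dict)}
    keys = list(simple)
    for values in itertools.product(*(simple[k] for k in keys)):
        configs = [dict(zip(keys, values))]
        for name, spec in conditional.items():
            # Sweep this option only if every (key, value) condition matches.
            if all(configs[0].get(k) == v for k, v in spec["conditions"]):
                configs = [dict(c, **{name: opt})
                           for c in configs for opt in spec["options"]]
        yield from configs

# Hypothetical path; the scraped diff does not preserve the file name.
with open("explorations/consmax_v2_sweep.json") as f:
    for run in expand_grid(json.load(f)[0]):
        print(run)
```

Under this reading the grid above expands to four runs: the cross product of `consmax_per_head` in {true, false} (its condition on `softmax_variant_attn` is met) and `use_post_ln` in {true, false}.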
Modified configuration file (@@ -1,27 +1,27 @@): narrowing the strongermax sweep to a single overflow-checking run.
```diff
 [
   {
-    "max_iters": ["6000"],
+    "max_iters": ["3500"],
     "n_layer": ["6"],
     "n_kv_group": ["6"],
     "n_head": ["6"],
     "n_embd": ["384"],
     "block_size": ["256"],
     "device": ["cuda"],
-    "dtype": ["float16"],
+    "dtype": ["bfloat16"],
     "dataset": ["shakespeare_char"],
     "use_rotary_embeddings": [false],
     "use_abs_pos_embeddings": [true],
-    "compile": [true],
+    "compile": [false],
     "softmax_variant_attn": ["strongermax"],
-    "strongermax_strength": ["1.5", "2", "2.719", "3", "4", "5"],
-    "strongermax_divisor": ["1.0", "10.0", "100.0", "1000.0"],
-    "use_post_ln": [true, false],
-    "strongermax_use_xmax": [true, false],
-    "strongermax_sum_to_1": [true, false],
-    "statistic": ["all_stats"],
-    "graph_type": ["all"],
-    "box_plot_interval": ["1000"],
-    "box_plot_statistic": ["all"],
-    "patience": ["1000"]
+    "strongermax_strength": ["2.719"],
+    "strongermax_divisor": ["256"],
+    "strongermax_use_xmax": [true],
+    "strongermax_xmax_guess": ["-50"],
+    "strongermax_overflow_recompute": [true],
+    "softmax_io_logging": [false],
+    "create_statistics": [false],
+    "plot_statistics": [false],
+    "statistic": ["input_max"],
+    "strongermax_sum_to_1": [false]
   }
 ]
```
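The replacement values pin the strength at 2.719 (approximately e, recovering ordinary exponentiation), set the divisor to the block size (256), and exercise the fixed xmax-guess path with overflow recompute enabled, while switching off the io logging and statistics plots the previous sweep produced. For orientation only, here is a minimal sketch of what a strongermax-style variant plausibly computes, assuming it generalizes softmax's base e to a configurable strength and normalizes by a fixed divisor instead of the row sum; the repository's actual implementation may differ:

```python
import torch

def strongermax(logits: torch.Tensor,
                strength: float = 2.719,   # ~e reproduces exp()
                divisor: float = 256.0,    # fixed denominator (block size here)
                xmax_guess: float = -50.0,
                sum_to_1: bool = False) -> torch.Tensor:
    # Sketch only: subtract a fixed guess at the maximum for numerical
    # stability instead of computing the true row max; if the guess is
    # exceeded, the full implementation can recompute
    # (strongermax_overflow_recompute in the config above).
    shifted = logits - xmax_guess
    powers = strength ** shifted
    if sum_to_1:
        # Classic softmax-style normalization over the last dimension.
        return powers / powers.sum(dim=-1, keepdim=True)
    # Otherwise divide by a fixed constant, as this config selects.
    return powers / divisor
```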
New file (@@ -0,0 +1,8 @@): a generation smoke test for the exported model.
```python
from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer

# Load the converted checkpoint and tokenizer from the local
# "gpt2-custom" directory.
model = GPT2LMHeadModel.from_pretrained("gpt2-custom")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-custom")

# Generate one short continuation as a sanity check.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = generator("Once upon a time", max_length=50, num_return_sequences=1)
print(output)
```
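If the export worked, this prints a list with one dict per returned sequence, e.g. `[{'generated_text': 'Once upon a time ...'}]`.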
New file (@@ -0,0 +1,8 @@): pushing the converted checkpoint to the Hugging Face Hub.
```python
from transformers import AutoTokenizer, AutoConfig, AutoModel

# Upload the local "gpt2-custom" checkpoint to the Hugging Face Hub
# as a repository named "custom_gpt2".
config = AutoConfig.from_pretrained("gpt2-custom")
config.push_to_hub("custom_gpt2")
model = AutoModel.from_pretrained("gpt2-custom")
model.push_to_hub("custom_gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2-custom")
tokenizer.push_to_hub("custom_gpt2")
```
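Note that `push_to_hub` only succeeds with an authenticated Hugging Face session (e.g. after `huggingface-cli login`); it creates the `custom_gpt2` repository under the logged-in account if it does not already exist.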
New file (@@ -0,0 +1,38 @@): console utilities for inspecting model structure.
```python
import torch
from torchinfo import summary
from rich import print
from rich.console import Console
from rich.text import Text

console = Console()


def print_summary(model):
    """Print a torchinfo summary of the whole model."""
    block_header = Text("High Level Parameters:", style="bold underline purple")
    console.print(block_header)
    summary(model)


def print_model_blocks(model, block_range=1):
    """Summarize the first `block_range` transformer blocks."""
    for idx, block in enumerate(model.transformer.h):
        block_header = Text(f"Summary for Block {idx + 1}:", style="bold underline green")
        console.print(block_header)
        summary(block)
        if (idx + 1) == block_range:
            break


def print_module_structure(module):
    """List a module's direct children."""
    console.print("-" * 50, style="dim")
    for name, submodule in module.named_children():
        console.print(f"{name}: {submodule}", style="yellow")
    console.print("-" * 50, style="dim")


def print_model_tree(model, indent="", print_params=False):
    """Recursively print the module tree, optionally with parameter names."""
    for name, module in model.named_children():
        print(indent + name + ": " + str(module.__class__.__name__))
        if isinstance(module, torch.nn.Module):
            if print_params:
                # Print parameters for this level only.
                for param_name, _ in module.named_parameters():
                    print(indent + "  " + param_name)
            else:
                # Recurse into submodules without printing parameters.
                print_model_tree(module, indent + "  ")
```
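A hypothetical usage sketch, assuming `model` is a nanoGPT-style GPT instance whose transformer blocks live in `model.transformer.h` (the layout `print_model_blocks` expects):

```python
# `model` is assumed here, not defined anywhere in this diff.
print_summary(model)                            # whole-model torchinfo summary
print_model_blocks(model, block_range=2)        # summaries of the first two blocks
print_module_structure(model.transformer.h[0])  # direct children of block 0
print_model_tree(model, print_params=False)     # recursive module tree
```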