
Merge pull request karpathy#223 from gkielian/add_model_param_section
Add model param section
klei22 authored Aug 9, 2024
2 parents 9e1206a + 2ef1c57 commit c12522e
Showing 3 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -75,7 +75,7 @@ If you are compatible with cu11.8, then use the following:
```bash
python3 -m pip install --upgrade pip
python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
python3 -m pip install numpy transformers datasets tiktoken wandb tqdm tensorboard rich
python3 -m pip install numpy transformers datasets tiktoken wandb tqdm tensorboard rich torchinfo
```

If unsure, visit the PyTorch page and substitute the appropriate line for the `torch` installation line above: https://pytorch.org/get-started/locally/
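
As a quick sanity check that the newly added `torchinfo` dependency is importable after installation, something like the following can be run (a minimal sketch, not part of this commit; the toy `nn.Linear` module is only illustrative):

```python
# Minimal check that torchinfo is installed and working (illustrative only).
import torch.nn as nn
from torchinfo import summary

summary(nn.Linear(8, 4))  # prints a small layer/parameter-count table
```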
1 change: 1 addition & 0 deletions requirements_cpu.txt
@@ -76,3 +76,4 @@ rich==13.5.3
plotly==5.22.0
seaborn==0.13.2
kaleido==0.2.1
torchinfo==1.8.0
17 changes: 15 additions & 2 deletions train.py
@@ -9,6 +9,8 @@
import sys
import time

from torchinfo import summary

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
@@ -373,12 +375,15 @@ def parse_args():
    logging_group.add_argument('--box_plot_statistic', choices=['input', 'output', 'all'],
                               default='', help='Select input or output statistic to display in boxplot')

    # Model Parameter Distribution
    logging_group.add_argument('--print_block_summary', default=False, action=argparse.BooleanOptionalAction)

    args = parser.parse_args()

    if args.load_config_json is not None:
        with open(args.load_config_json, 'r') as config_file:
            config = json.load(config_file)

        # Update the args namespace with values from the JSON file
        for key, value in config.items():
            setattr(args, key, value)
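
As an aside on the new flag: `argparse.BooleanOptionalAction` (Python 3.9+) automatically registers a matching `--no-` form, so the option can also be disabled explicitly. A standalone sketch of that behavior, illustrative only and not code from this commit:

```python
# Demonstrates the paired flags generated by argparse.BooleanOptionalAction.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--print_block_summary', default=False,
                    action=argparse.BooleanOptionalAction)

print(parser.parse_args([]))                             # Namespace(print_block_summary=False)
print(parser.parse_args(['--print_block_summary']))      # Namespace(print_block_summary=True)
print(parser.parse_args(['--no-print_block_summary']))   # Namespace(print_block_summary=False)
```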
@@ -424,7 +429,7 @@ class Trainer:
    def __init__(self, args, model_group, training_group, logging_group):
        self.args = args
        self.model_group = model_group
        self.training_group = training_group
        self.logging_group = logging_group

        # typically make the decay iters equal to max_iters
@@ -538,6 +543,14 @@ def setup(self):

        self.model.to(self.device)

        # Print the model summary
        summary(self.model)

        if self.args.print_block_summary:
            for idx, block in enumerate(self.model.transformer.h):
                print(f"Summary for Block {idx + 1}:")
                summary(block)

        # Optimizer
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.args.dtype == 'float16'))
        self.optimizer = self.model.configure_optimizers(self.args.weight_decay, self.args.learning_rate,
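
To show what the added calls print, here is a standalone sketch that mirrors the whole-model summary and the per-block loop on a small stand-in model. The real `GPT` model and its `transformer.h` block list are assumed to follow the layout used by `train.py`; this is not the project's actual model code:

```python
# Illustrative stand-in for the summary calls added in setup() above.
import torch.nn as nn
from torchinfo import summary

# A toy "transformer" with three identical blocks, standing in for model.transformer.h.
blocks = nn.ModuleList(nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(3))
model = nn.Sequential(*blocks)

summary(model)  # whole-model table, printed unconditionally in setup()

# Equivalent of the --print_block_summary loop: one table per block.
for idx, block in enumerate(blocks):
    print(f"Summary for Block {idx + 1}:")
    summary(block)
```

With the flag enabled on the command line, the per-block tables would presumably be requested as `python3 train.py --print_block_summary` alongside the usual training arguments.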