Skip to content

Commit

Permalink
Add backprop intermediate value recompute option
Browse files — browse the repository at this point in the history
  • Loading branch information
gkielian committed Aug 28, 2024
1 parent f4c0781 commit f64a039
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions gpt_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class GPTConfig:
# Training options
## Gradient Checkpointing - More memory efficient (can do long contexts), but is slower
use_gradient_checkpointing: bool = False
recompute_backward_pass: bool = False

# MLP Options
use_parallel_mlp: bool = False
Expand Down
2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def forward(self, idx, targets=None):

for block in self.transformer.h:
if self.config.use_gradient_checkpointing:
x = checkpoint.checkpoint(block, x, use_reentrant=False)
x = checkpoint.checkpoint(block, x, use_reentrant=self.recompute_backward_pass)
else:
x = block(x)

Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def parse_args():

# Gradient Checkpointing
training_group.add_argument('--use_gradient_checkpointing', default=False, action=argparse.BooleanOptionalAction, help="Memory efficient training, but takes longer time to train due to trading compute time for memory efficiency. For best memory tradeoff omit the --compile flag. For medium memory tradeoff add --compile.")
training_group.add_argument('--recompute_backward_pass', default=False, action=argparse.BooleanOptionalAction, help="Recomputes for the backward pass, must use with --use_gradient_checkpointing")

# Optimizer args
training_group.add_argument('--max_iters', default=3500, type=int)
Expand Down

0 comments on commit f64a039

Please sign in to comment.