Update on "add selective activation checkpointing"

Compared with full activation checkpointing (AC), which always recomputes activations in the backward pass, selective activation checkpointing (SAC) stores selected intermediate activations instead of recomputing them, reducing training time at the cost of higher memory usage.

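For reference, below is a minimal sketch of what SAC looks like with the public PyTorch API. It assumes torch.utils.checkpoint.create_selective_checkpoint_contexts, which landed in later PyTorch releases; this commit instead relies on an experimental torch._dynamo config hook (see the diff below). The transformer_block function and the choice of ops to save are purely illustrative.

```python
# Minimal SAC sketch, assuming the public torch.utils.checkpoint API from
# later PyTorch releases (create_selective_checkpoint_contexts); not the
# exact mechanism used in this commit. transformer_block is a hypothetical
# stand-in, and ops_to_save is only an illustrative policy.
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Outputs of these ops are saved during forward; everything else inside the
# checkpointed region is recomputed in backward, as with full AC.
ops_to_save = [torch.ops.aten.mm.default]
context_fn = partial(create_selective_checkpoint_contexts, ops_to_save)


def transformer_block(x, w1, w2):
    # Hypothetical stand-in for a real transformer block.
    return torch.relu(x @ w1) @ w2


x = torch.randn(8, 64, requires_grad=True)
w1 = torch.randn(64, 256, requires_grad=True)
w2 = torch.randn(256, 64, requires_grad=True)

# context_fn requires the non-reentrant checkpoint implementation.
out = checkpoint(transformer_block, x, w1, w2, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```

The memory/speed tradeoff in the numbers below comes from exactly this knob: the more ops the policy saves, the less recomputation happens in backward (faster iterations), but the more activation memory is held.
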
Here are some test results on Llama 7B.

with full activation checkpointing:
- [rank0]: Average iter time: 4.9126 seconds
- [rank0]: Peak Memory: Reserved 40.61%, Alloc 28.12%, Active: 29.61%

with selective activation checkpointing:
- [rank0]: Average iter time: 4.5459 seconds
- [rank0]: Peak Memory: Reserved 80.45%, Alloc 62.0%, Active: 63.43%

[ghstack-poisoned]
tianyu-l committed Feb 29, 2024
1 parent 89a31e5 commit 7ad1e43
Showing 1 changed file with 4 additions and 0 deletions.
train.py (4 additions, 0 deletions):

```diff
@@ -155,6 +155,10 @@ def main(job_config: JobConfig):
 
     # torch.compile model for improved performance
     if job_config.training.compile:
+        if job_config.training.enable_selective_ac:
+            torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
+                True
+            )
         rank0_log(f"Compiling model {model_name} with torch.compile...")
         model = torch.compile(
             model,
```
