add selective activation checkpointing
ghstack-source-id: f7ee3c867bfcdcae5dbb490982920606191e6f40
Pull Request resolved: pytorch#97
tianyu-l committed Feb 29, 2024
1 parent d5c27a9 commit 2c8cec2
Showing 4 changed files with 66 additions and 13 deletions.
15 changes: 10 additions & 5 deletions torchtrain/config_manager.py
@@ -94,17 +94,17 @@ def init_args_from_command_line(
help="collect profiler traces every x iterations",
)
# metrics configs
parser.add_argument(
"--metrics.enable_tensorboard",
action="store_true",
help="whether to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics.log_freq",
type=int,
default=10,
help="how often to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics.enable_tensorboard",
action="store_true",
help="how often to log metrics to TensorBoard",
)
parser.add_argument(
"--metrics.save_tb_folder",
type=str,
@@ -218,4 +218,9 @@ def init_args_from_command_line(
"is an empty string, checkpointing is disabled."
),
)
parser.add_argument(
"--training.enable_selective_ac",
action="store_false",
help="whether to enable selective activation checkpointing",
)
return parser.parse_args(args_list)
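
For orientation, the new --training.enable_selective_ac option is a plain argparse store_true flag: it defaults to False and flips to True only when passed on the command line (the TOML entry added below supplies the config-file default). The following minimal sketch demonstrates that behavior with a bare argparse parser rather than torchtrain's full JobConfig, so the parser setup here is illustrative only:

import argparse

# Illustrative standalone parser; torchtrain builds its own inside JobConfig.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.enable_selective_ac",
    action="store_true",
    help="whether to enable selective activation checkpointing",
)

# Flag omitted -> False; flag present -> True. argparse keeps the dotted dest name.
assert getattr(parser.parse_args([]), "training.enable_selective_ac") is False
assert getattr(
    parser.parse_args(["--training.enable_selective_ac"]),
    "training.enable_selective_ac",
) is True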
59 changes: 51 additions & 8 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -5,6 +5,7 @@
# llama model, i.e. activation checkpoint, etc.

import logging
from collections import defaultdict

import torch
from torch.distributed._tensor import (
@@ -33,7 +34,6 @@
RowwiseParallel,
)
from torchtrain.config_manager import JobConfig

from torchtrain.logging_utils import rank0_log

logger = logging.getLogger(__name__)
@@ -67,12 +67,52 @@ def partition_fn(name, module, device_mesh):
)


# for selective AC
no_recompute_list = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.c10d_functional.reduce_scatter_tensor.default,
}

# Uses PTD FSDP AC wrapper
# TODO: why is config needed here?
def checkpoint_wrapper(module, job_config: JobConfig):
return ptd_checkpoint_wrapper(
module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
)
def checkpoint_wrapper(module, enable_selective_ac):
if enable_selective_ac:
from torch.utils.checkpoint import (
_pt2_selective_checkpoint_context_fn_gen,
checkpoint,
)

def _get_custom_policy(meta):
def _custom_policy(mode, func, *args, **kwargs):
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
meta[mm_count_key] += 1
# Saves the output of ops in no_recompute_list, except every second mm
return func in no_recompute_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
)

return _custom_policy

def selective_checkpointing_context_fn():
meta = defaultdict(int)
return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))

return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
checkpoint_fn=checkpoint,
context_fn=selective_checkpointing_context_fn,
use_reentrant=False,
preserve_rng_state=False,
)
else:
return ptd_checkpoint_wrapper(
module,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
preserve_rng_state=False,
)


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
@@ -168,10 +208,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
for layer_id, transformer_block in enumerate(model.layers):
# apply AC to each layer
# before wrapping with FSDP, we need to make sure the layer is on GPU
transformer_block = transformer_block.cuda()
transformer_block = checkpoint_wrapper(transformer_block, job_config)

# apply selective AC
transformer_block = checkpoint_wrapper(
transformer_block, job_config.training.enable_selective_ac
)

# Wraps each layer with FSDP
model.layers[layer_id] = wrap(transformer_block)
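To make the policy above concrete: _custom_policy tells the checkpoint context to save the output of any op in no_recompute_list, except that every second aten.mm is still recomputed, using the per-mode counter in meta to keep that state across the forward pass. The snippet below simulates only that decision logic, with plain strings standing in for the torch op overloads (NO_RECOMPUTE, make_policy, and the op names are invented for the illustration); it runs without torch and is not the torch.utils.checkpoint API itself:

from collections import defaultdict

# Stand-ins for the torch op overloads listed in no_recompute_list above.
NO_RECOMPUTE = {"aten.mm", "aten.sdpa_efficient", "aten.sdpa_flash", "c10d.reduce_scatter"}

def make_policy(meta):
    def policy(mode, func):
        if func == "aten.mm":
            meta[f"{mode}_mm_count"] += 1
        # Save outputs of listed ops, except every second mm (recomputed instead).
        return func in NO_RECOMPUTE and not (
            func == "aten.mm" and meta[f"{mode}_mm_count"] % 2 == 0
        )
    return policy

meta = defaultdict(int)
policy = make_policy(meta)
ops = ["aten.mm", "aten.sdpa_flash", "aten.mm", "aten.mm", "other.op", "aten.mm"]
decisions = [policy("forward", op) for op in ops]
# 1st and 3rd mm saved, 2nd and 4th recomputed; ops outside the list are never saved.
assert decisions == [True, True, False, True, False, False]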
4 changes: 4 additions & 0 deletions train.py
@@ -155,6 +155,10 @@ def main(job_config: JobConfig):

# torch.compile model for improved performance
if job_config.training.compile:
if job_config.training.enable_selective_ac:
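# selective AC passes a context_fn to checkpoint(), which torch.compile only supports behind this experimental flag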
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
True
)
rank0_log(f"Compiling model {model_name} with torch.compile...")
model = torch.compile(
model,
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -37,3 +37,4 @@ checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca"
enable_selective_ac = false
