add selective activation checkpointing
ghstack-source-id: 2fbae95768f06b1b35af3f69eb0d39777b214089
Pull Request resolved: #97
tianyu-l committed Feb 28, 2024
1 parent 96d1cb1 commit 180b213
Showing 3 changed files with 63 additions and 13 deletions.
15 changes: 10 additions & 5 deletions torchtrain/config_manager.py
@@ -94,17 +94,17 @@ def init_args_from_command_line(
             help="collect profiler traces every x iterations",
         )
         # metrics configs
+        parser.add_argument(
+            "--metrics.enable_tensorboard",
+            action="store_true",
+            help="whether to log metrics to TensorBoard",
+        )
         parser.add_argument(
             "--metrics.log_freq",
             type=int,
             default=10,
             help="how often to log metrics to TensorBoard",
         )
-        parser.add_argument(
-            "--metrics.enable_tensorboard",
-            action="store_true",
-            help="how often to log metrics to TensorBoard",
-        )
         parser.add_argument(
             "--metrics.save_tb_folder",
             type=str,
@@ -215,4 +215,9 @@ def init_args_from_command_line(
                 "is an empty string, checkpointing is disabled."
             ),
         )
+        parser.add_argument(
+            "--training.enable_selective_ac",
+            action="store_true",
+            help="whether to enable selective activation checkpointing",
+        )
         return parser.parse_args(args_list)
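A quick way to sanity-check the new flag's behavior (a standalone sketch that only mirrors the add_argument call above, not the repository's full JobConfig/TOML pipeline): argparse keeps the dotted option name as the attribute name, and with action="store_true" selective AC stays off unless the flag is passed.

# Standalone sketch of the flag's behavior; it only mirrors the add_argument call
# above and is not the repo's actual config pipeline, which also merges TOML files.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--training.enable_selective_ac",
    action="store_true",
    help="whether to enable selective activation checkpointing",
)

# Off by default; enabled only when the flag is passed on the command line.
assert getattr(parser.parse_args([]), "training.enable_selective_ac") is False
assert (
    getattr(
        parser.parse_args(["--training.enable_selective_ac"]),
        "training.enable_selective_ac",
    )
    is True
)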
60 changes: 52 additions & 8 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -33,7 +33,6 @@
     RowwiseParallel,
 )
 from torchtrain.config_manager import JobConfig
-
 from torchtrain.logging_utils import rank0_log
 
 logger = logging.getLogger(__name__)
@@ -67,12 +66,54 @@ def partition_fn(name, module, device_mesh):
     )


+# AC/selective AC
+no_recompute_list = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.c10d_functional.reduce_scatter_tensor.default,
+}
+
 # Uses PTD FSDP AC wrapper
-# TODO: why is config needed here?
-def checkpoint_wrapper(module, job_config: JobConfig):
-    return ptd_checkpoint_wrapper(
-        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
-    )
+def checkpoint_wrapper(module, enable_selective_ac=False):
+    if enable_selective_ac:
+        from torch.utils.checkpoint import (
+            _pt2_selective_checkpoint_context_fn_gen,
+            checkpoint,
+        )
+
+        def _get_custom_policy(meta):
+            def _custom_policy(mode, func, *args, **kwargs):
+                mm_count_key = f"{mode}_mm_count"
+                if mm_count_key not in meta:
+                    meta[mm_count_key] = 0
+                if func == torch.ops.aten.mm.default:
+                    meta[mm_count_key] += 1
+                # Saves output of all compute ops, except every second mm
+                return func in no_recompute_list and not (
+                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                )
+
+            return _custom_policy
+
+        def selective_checkpointing_context_fn():
+            meta = {}
+            return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
+
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            checkpoint_fn=checkpoint,
+            context_fn=selective_checkpointing_context_fn,
+            use_reentrant=False,
+            preserve_rng_state=False,
+        )
+    else:
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            preserve_rng_state=False,
+        )


 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
@@ -168,10 +209,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to each layer
                 # before wrapping with FSDP, we need to make sure the layer is on GPU
                 transformer_block = transformer_block.cuda()
-                transformer_block = checkpoint_wrapper(transformer_block, job_config)
+
+                # apply selective AC
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.training.enable_selective_ac
+                )

                 # Wraps each layer with FSDP
                 model.layers[layer_id] = wrap(transformer_block)
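For reference, below is a minimal, self-contained sketch of the same "save every second mm" policy, applied with plain torch.utils.checkpoint to a toy MLP instead of the FSDP wrapper used above. The toy module and tensor shapes are invented for illustration, and _pt2_selective_checkpoint_context_fn_gen is a private helper in the PyTorch build this commit targets, so it may be renamed or removed in later releases.

# Minimal sketch (toy module and shapes invented for illustration) of the selective
# policy above, applied with plain torch.utils.checkpoint rather than the FSDP wrapper.
# _pt2_selective_checkpoint_context_fn_gen is a private helper in the PyTorch version
# this commit targets and may change in later releases.
import torch
import torch.nn as nn
from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint

no_recompute_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def _get_custom_policy(meta):
    def _custom_policy(mode, func, *args, **kwargs):
        mm_count_key = f"{mode}_mm_count"
        meta[mm_count_key] = meta.get(mm_count_key, 0)
        if func == torch.ops.aten.mm.default:
            meta[mm_count_key] += 1
        # Save outputs of the listed compute ops, except every second mm,
        # which is recomputed in backward to trade compute for memory.
        return func in no_recompute_list and not (
            func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
        )

    return _custom_policy


def selective_checkpointing_context_fn():
    return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy({}))


# bias=False keeps each Linear a single aten.mm call, so the mm counter is exercised.
toy_block = nn.Sequential(
    nn.Linear(16, 16, bias=False), nn.ReLU(), nn.Linear(16, 16, bias=False)
)
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(
    toy_block,
    x,
    use_reentrant=False,
    context_fn=selective_checkpointing_context_fn,
)
out.sum().backward()

In the diff itself, this policy is packaged into ptd_checkpoint_wrapper and applied to each transformer block before the block is wrapped with FSDP.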
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -37,3 +37,4 @@ checkpoint_interval = 3600
 checkpoint_interval_type = "steps"
 checkpoint_folder = ""
 dataset = "alpaca"
+enable_selective_ac = false
