diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
index 613f94113..a75c77d62 100644
--- a/torchtrain/config_manager.py
+++ b/torchtrain/config_manager.py
@@ -94,17 +94,17 @@ def init_args_from_command_line(
         help="collect profiler traces every x iterations",
     )
     # metrics configs
+    parser.add_argument(
+        "--metrics.enable_tensorboard",
+        action="store_true",
+        help="whether to log metrics to TensorBoard",
+    )
     parser.add_argument(
         "--metrics.log_freq",
         type=int,
         default=10,
         help="how often to log metrics to TensorBoard",
     )
-    parser.add_argument(
-        "--metrics.enable_tensorboard",
-        action="store_true",
-        help="how often to log metrics to TensorBoard",
-    )
     parser.add_argument(
         "--metrics.save_tb_folder",
         type=str,
@@ -215,4 +215,9 @@ def init_args_from_command_line(
             "is an empty string, checkpointing is disabled."
         ),
     )
+    parser.add_argument(
+        "--training.enable_selective_ac",
+        action="store_true",
+        help="whether to enable selective activation checkpointing",
+    )
     return parser.parse_args(args_list)
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 805bfa87c..7f04faa92 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -33,7 +33,6 @@
     RowwiseParallel,
 )
 from torchtrain.config_manager import JobConfig
-
 from torchtrain.logging_utils import rank0_log
 
 logger = logging.getLogger(__name__)
@@ -67,12 +66,54 @@ def partition_fn(name, module, device_mesh):
         )
 
 
+# AC/selective AC
+no_recompute_list = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.c10d_functional.reduce_scatter_tensor.default,
+}
+
 # Uses PTD FSDP AC wrapper
-# TODO: why is config needed here?
-def checkpoint_wrapper(module, job_config: JobConfig):
-    return ptd_checkpoint_wrapper(
-        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
-    )
+def checkpoint_wrapper(module, enable_selective_ac=False):
+    if enable_selective_ac:
+        from torch.utils.checkpoint import (
+            _pt2_selective_checkpoint_context_fn_gen,
+            checkpoint,
+        )
+
+        def _get_custom_policy(meta):
+            def _custom_policy(mode, func, *args, **kwargs):
+                mm_count_key = f"{mode}_mm_count"
+                if mm_count_key not in meta:
+                    meta[mm_count_key] = 0
+                if func == torch.ops.aten.mm.default:
+                    meta[mm_count_key] += 1
+                # Saves output of all compute ops, except every second mm
+                return func in no_recompute_list and not (
+                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                )
+
+            return _custom_policy
+
+        def selective_checkpointing_context_fn():
+            meta = {}
+            return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
+
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            checkpoint_fn=checkpoint,
+            context_fn=selective_checkpointing_context_fn,
+            use_reentrant=False,
+            preserve_rng_state=False,
+        )
+    else:
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            preserve_rng_state=False,
+        )
 
 
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
@@ -168,10 +209,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
 
         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to each layer
                 # before wrapping with FSDP, we need to make sure the layer is on GPU
                 transformer_block = transformer_block.cuda()
-                transformer_block = checkpoint_wrapper(transformer_block, job_config)
+
+                # apply selective AC
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.training.enable_selective_ac
+                )
 
                 # Wraps each layer with FSDP
                 model.layers[layer_id] = wrap(transformer_block)
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 1cca38b09..a7a28e222 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -37,3 +37,4 @@ checkpoint_interval = 3600
 checkpoint_interval_type = "steps"
 checkpoint_folder = ""
 dataset = "alpaca"
+enable_selective_ac = false
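
For context, a minimal usage sketch (an illustration, not part of the patch), assuming the new `enable_selective_ac` key lives in the `[training]` section of the TOML, which is what the `job_config.training.enable_selective_ac` lookup above implies:

    # train_configs/debug_model.toml (hypothetical excerpt)
    [training]
    enable_selective_ac = true

    # or equivalently via the command-line flag
    --training.enable_selective_ac

When the flag is set, each transformer block is wrapped with the selective policy in `checkpoint_wrapper`: outputs of the ops in `no_recompute_list` (mm, the two scaled-dot-product attention variants, and reduce-scatter) are saved, except every second `mm`, which is recomputed during the backward pass to trade some compute for lower activation memory.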