From c7c6b4e5631649512769d19631842805b3d09e84 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 27 Feb 2024 18:14:45 -0800 Subject: [PATCH 1/4] add selective activation checkpointing [ghstack-poisoned] --- torchtrain/config_manager.py | 15 +++-- torchtrain/parallelisms/parallelize_llama.py | 60 +++++++++++++++++--- train_configs/debug_model.toml | 1 + 3 files changed, 63 insertions(+), 13 deletions(-) diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py index 613f9411..a75c77d6 100644 --- a/torchtrain/config_manager.py +++ b/torchtrain/config_manager.py @@ -94,17 +94,17 @@ def init_args_from_command_line( help="collect profiler traces every x iterations", ) # metrics configs + parser.add_argument( + "--metrics.enable_tensorboard", + action="store_true", + help="whether to log metrics to TensorBoard", + ) parser.add_argument( "--metrics.log_freq", type=int, default=10, help="how often to log metrics to TensorBoard", ) - parser.add_argument( - "--metrics.enable_tensorboard", - action="store_true", - help="how often to log metrics to TensorBoard", - ) parser.add_argument( "--metrics.save_tb_folder", type=str, @@ -215,4 +215,9 @@ def init_args_from_command_line( "is an empty string, checkpointing is disabled." ), ) + parser.add_argument( + "--metrics.enable_selective_ac", + action="store_false", + help="whether to enable selective activation checkpointing", + ) return parser.parse_args(args_list) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 805bfa87..7f04faa9 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -33,7 +33,6 @@ RowwiseParallel, ) from torchtrain.config_manager import JobConfig - from torchtrain.logging_utils import rank0_log logger = logging.getLogger(__name__) @@ -67,12 +66,54 @@ def partition_fn(name, module, device_mesh): ) +# AC/selective AC +no_recompute_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.c10d_functional.reduce_scatter_tensor.default, +} + # Uses PTD FSDP AC wrapper -# TODO: why is config needed here? -def checkpoint_wrapper(module, job_config: JobConfig): - return ptd_checkpoint_wrapper( - module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False - ) +def checkpoint_wrapper(module, enable_selective_ac=False): + if enable_selective_ac: + from torch.utils.checkpoint import ( + _pt2_selective_checkpoint_context_fn_gen, + checkpoint, + ) + + def _get_custom_policy(meta): + def _custom_policy(mode, func, *args, **kwargs): + mm_count_key = f"{mode}_mm_count" + if mm_count_key not in meta: + meta[mm_count_key] = 0 + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + return func in no_recompute_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = {} + return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + checkpoint_fn=checkpoint, + context_fn=selective_checkpointing_context_fn, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + return ptd_checkpoint_wrapper( + module, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + preserve_rng_state=False, + ) def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): @@ -168,10 +209,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): with enable_wrap(wrapper_cls=FSDP, **fsdp_config): for layer_id, transformer_block in enumerate(model.layers): - # apply AC to each layer # before wrapping with FSDP, we need to make sure the layer is on GPU transformer_block = transformer_block.cuda() - transformer_block = checkpoint_wrapper(transformer_block, job_config) + + # apply selective AC + transformer_block = checkpoint_wrapper( + transformer_block, job_config.training.enable_selective_ac + ) # Wraps each layer with FSDP model.layers[layer_id] = wrap(transformer_block) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 1cca38b0..a7a28e22 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -37,3 +37,4 @@ checkpoint_interval = 3600 checkpoint_interval_type = "steps" checkpoint_folder = "" dataset = "alpaca" +enable_selective_ac = false From 55ec314c71fc572207be103deab21301ecdd674b Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 28 Feb 2024 12:05:24 -0800 Subject: [PATCH 2/4] Update on "add selective activation checkpointing" Selective activation checkpointing (SAC), compared with full AC, selectively stores some intermediate activations to save training time, at the cost of more memory usage. Here are some test results on llama 7B. with full activation checkpointing: - [rank0]: Average iter time: 4.9126 seconds - [rank0]: Peak Memory: Reserved 40.61%, Alloc 28.12%, Active: 29.61% with selective activation checkpointing: - [rank0]: Average iter time: 4.5459 seconds - [rank0]: Peak Memory: Reserved 80.45%, Alloc 62.0%, Active: 63.43% [ghstack-poisoned] --- torchtrain/config_manager.py | 2 +- torchtrain/parallelisms/parallelize_llama.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py index a75c77d6..48ce33ee 100644 --- a/torchtrain/config_manager.py +++ b/torchtrain/config_manager.py @@ -216,7 +216,7 @@ def init_args_from_command_line( ), ) parser.add_argument( - "--metrics.enable_selective_ac", + "--training.enable_selective_ac", action="store_false", help="whether to enable selective activation checkpointing", ) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 7f04faa9..13ae0b1e 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -5,6 +5,7 @@ # llama model, i.e. activation checkpoint, etc. import logging +from collections import defaultdict import torch from torch.distributed._tensor import ( @@ -66,7 +67,7 @@ def partition_fn(name, module, device_mesh): ) -# AC/selective AC +# for selective AC no_recompute_list = { torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, @@ -85,8 +86,6 @@ def checkpoint_wrapper(module, enable_selective_ac=False): def _get_custom_policy(meta): def _custom_policy(mode, func, *args, **kwargs): mm_count_key = f"{mode}_mm_count" - if mm_count_key not in meta: - meta[mm_count_key] = 0 if func == torch.ops.aten.mm.default: meta[mm_count_key] += 1 # Saves output of all compute ops, except every second mm @@ -97,7 +96,7 @@ def _custom_policy(mode, func, *args, **kwargs): return _custom_policy def selective_checkpointing_context_fn(): - meta = {} + meta = defaultdict(int) return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta)) return ptd_checkpoint_wrapper( From 89a31e5463787af052490f2e88680ff304d542ce Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 29 Feb 2024 10:54:40 -0800 Subject: [PATCH 3/4] Update on "add selective activation checkpointing" Selective activation checkpointing (SAC), compared with full AC, selectively stores some intermediate activations to save training time, at the cost of more memory usage. Here are some test results on llama 7B. with full activation checkpointing: - [rank0]: Average iter time: 4.9126 seconds - [rank0]: Peak Memory: Reserved 40.61%, Alloc 28.12%, Active: 29.61% with selective activation checkpointing: - [rank0]: Average iter time: 4.5459 seconds - [rank0]: Peak Memory: Reserved 80.45%, Alloc 62.0%, Active: 63.43% [ghstack-poisoned] --- torchtrain/parallelisms/parallelize_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 13ae0b1e..ae02d15b 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -76,7 +76,7 @@ def partition_fn(name, module, device_mesh): } # Uses PTD FSDP AC wrapper -def checkpoint_wrapper(module, enable_selective_ac=False): +def checkpoint_wrapper(module, enable_selective_ac): if enable_selective_ac: from torch.utils.checkpoint import ( _pt2_selective_checkpoint_context_fn_gen, From 7ad1e43034c7cb07713da03118a8267c289dd7a4 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Thu, 29 Feb 2024 14:11:23 -0800 Subject: [PATCH 4/4] Update on "add selective activation checkpointing" Selective activation checkpointing (SAC), compared with full AC which always does activation recomputation, selectively stores some intermediate activations to save training time, at the cost of more memory usage. Here are some test results on llama 7B. with full activation checkpointing: - [rank0]: Average iter time: 4.9126 seconds - [rank0]: Peak Memory: Reserved 40.61%, Alloc 28.12%, Active: 29.61% with selective activation checkpointing: - [rank0]: Average iter time: 4.5459 seconds - [rank0]: Peak Memory: Reserved 80.45%, Alloc 62.0%, Active: 63.43% [ghstack-poisoned] --- train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train.py b/train.py index 95d42226..8e4e44d3 100644 --- a/train.py +++ b/train.py @@ -155,6 +155,10 @@ def main(job_config: JobConfig): # torch.compile model for improved performance if job_config.training.compile: + if job_config.training.enable_selective_ac: + torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( + True + ) rank0_log(f"Compiling model {model_name} with torch.compile...") model = torch.compile( model,