add selective activation checkpointing #97
```diff
@@ -5,6 +5,7 @@
 # llama model, i.e. activation checkpoint, etc.

 import logging
+from collections import defaultdict

 import torch
 from torch.distributed._tensor import (
@@ -33,7 +34,6 @@
     RowwiseParallel,
 )
 from torchtrain.config_manager import JobConfig
-
 from torchtrain.logging_utils import rank0_log

 logger = logging.getLogger(__name__)
```
```diff
@@ -67,12 +67,52 @@ def partition_fn(name, module, device_mesh):
     )


+# for selective AC
+no_recompute_list = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.c10d_functional.reduce_scatter_tensor.default,
+}
+
 # Uses PTD FSDP AC wrapper
-# TODO: why is config needed here?
-def checkpoint_wrapper(module, job_config: JobConfig):
-    return ptd_checkpoint_wrapper(
-        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
-    )
+def checkpoint_wrapper(module, enable_selective_ac=False):
```
Review comment: super nit: Prefer to not make …
```diff
+    if enable_selective_ac:
+        from torch.utils.checkpoint import (
+            _pt2_selective_checkpoint_context_fn_gen,
+            checkpoint,
+        )
+
+        def _get_custom_policy(meta):
```
Review comment: For my understanding, does this policy also run in eager mode?
Reply: Actually it only works in eager mode; with the compiler we got: …
Reply: NVM. Wanchao told me we need to set …
```diff
+            def _custom_policy(mode, func, *args, **kwargs):
+                mm_count_key = f"{mode}_mm_count"
+                if func == torch.ops.aten.mm.default:
+                    meta[mm_count_key] += 1
+                # Saves output of all compute ops, except every second mm
+                return func in no_recompute_list and not (
+                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                )
+
+            return _custom_policy
+
+        def selective_checkpointing_context_fn():
+            meta = defaultdict(int)
+            return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
+
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            checkpoint_fn=checkpoint,
+            context_fn=selective_checkpointing_context_fn,
+            use_reentrant=False,
+            preserve_rng_state=False,
+        )
+    else:
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            preserve_rng_state=False,
+        )


 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
```
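To see what this policy actually decides, here is a small self-contained sketch (not part of the diff) that calls the same kind of policy function directly on a made-up sequence of ops. The helper names (`get_custom_policy`, `policy`) and the shortened op list are illustrative only; in the PR the policy is handed to `_pt2_selective_checkpoint_context_fn_gen` as shown above.

```python
from collections import defaultdict

import torch

# Shortened version of the PR's no_recompute_list (the collective op is omitted
# so the snippet runs without any distributed setup).
no_recompute_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def get_custom_policy(meta):
    def custom_policy(mode, func, *args, **kwargs):
        mm_count_key = f"{mode}_mm_count"
        if func == torch.ops.aten.mm.default:
            meta[mm_count_key] += 1
        # Save outputs of the listed compute ops, except every second mm.
        return func in no_recompute_list and not (
            func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
        )

    return custom_policy


if __name__ == "__main__":
    policy = get_custom_policy(defaultdict(int))
    # Pretend the forward pass hit four matmuls and one flash-attention call.
    ops = [torch.ops.aten.mm.default] * 4 + [
        torch.ops.aten._scaled_dot_product_flash_attention.default
    ]
    for op in ops:
        print(op, "->", "save" if policy("forward", op) else "recompute")
    # Expected: mm #1 and #3 saved, mm #2 and #4 recomputed, attention saved.
```

Saving every other mm (and all attention outputs) is the middle ground the diff comment describes: some recompute is traded away compared to full AC, while memory stays well below "save everything".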
```diff
@@ -168,10 +208,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
                 # apply AC to each layer
                 # before wrapping with FSDP, we need to make sure the layer is on GPU
                 transformer_block = transformer_block.cuda()
-                transformer_block = checkpoint_wrapper(transformer_block, job_config)
+                # apply selective AC
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.training.enable_selective_ac
+                )

                 # Wraps each layer with FSDP
                 model.layers[layer_id] = wrap(transformer_block)
```
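Stepping outside the diff, the sketch below shows the same per-block wrapping pattern on a toy model: each block is wrapped with the PTD checkpoint wrapper (the non-selective branch of `checkpoint_wrapper`, since `_pt2_selective_checkpoint_context_fn_gen` is a private API whose availability depends on the PyTorch build), and a forward/backward pass still runs. `ToyBlock`, `ToyModel`, and all shapes are invented; the FSDP `wrap()` step from the PR is omitted because it needs an initialized process group.

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)


class ToyBlock(nn.Module):
    """Stand-in for a transformer block; torchtrain wraps the real llama blocks."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return x + self.w2(torch.relu(self.w1(x)))


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
# Wrap each block so its activations are recomputed during backward instead of
# being kept alive (full AC, i.e. the else-branch of the PR's checkpoint_wrapper).
for layer_id, block in enumerate(model.layers):
    model.layers[layer_id] = ptd_checkpoint_wrapper(
        block, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )

x = torch.randn(8, 64, requires_grad=True)
model(x).sum().backward()
print("input grad norm:", x.grad.norm().item())
```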
```diff
@@ -37,3 +37,4 @@ checkpoint_interval = 3600
 checkpoint_interval_type = "steps"
 checkpoint_folder = ""
 dataset = "alpaca"
+enable_selective_ac = false
```
Review comment: Hmmm, you are putting it in the training section, but the cmd arg parser does not put it under training, so I think this will not work as expected. (If it does work, then we need to figure out why the toml parsing behaves that way.)
Reply: In our setting, the cmd arg parser is a backup way of providing args. The toml file has it in the training section; in the code …
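On the parsing question above, here is a toy sketch, not torchtrain's actual `config_manager`: `load_job_config` and the namespace trick are made up for illustration, showing how a key under the toml `[training]` table can surface as `job_config.training.enable_selective_ac`, which is the attribute the new parallelize_llama code reads.

```python
import tomllib  # Python 3.11+; older versions can use the third-party "tomli" package
from types import SimpleNamespace

TOML_TEXT = """
[training]
dataset = "alpaca"
enable_selective_ac = false
"""


def load_job_config(text: str) -> SimpleNamespace:
    """Toy stand-in for a JobConfig: each toml table becomes an attribute group."""
    raw = tomllib.loads(text)
    return SimpleNamespace(**{k: SimpleNamespace(**v) for k, v in raw.items()})


job_config = load_job_config(TOML_TEXT)
# Same attribute path the PR uses when wrapping each transformer block.
print(job_config.training.enable_selective_ac)  # -> False
```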
Review comment: Not vital for now, but longer term I think this policy should be exposed at a higher level if we expect other models to be added and/or expect people to customize this policy list. I.e., if we have parallelize_gpt or similar, it is awkward to pull the recompute list from parallelize_llama.