add selective activation checkpointing #97
Conversation
Great addition to the codebase! LGTM!
Two minor proposals:
- Propose a more Pythonic/cleaner use of defaultdict for the meta dict.
- Longer term, the AC policy probably wants to live at a more global level for future re-use and/or customization.
@@ -67,12 +66,54 @@ def partition_fn(name, module, device_mesh):
)


# AC/selective AC
no_recompute_list = {
Not vital for now, but longer term I think this policy should be exposed at a higher level if we expect other models to be added and/or expect people to customize this policy list. I.e., if we have parallelize_gpt or similar, it's awkward to pull the recompute list from parallelize_llama.
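As a rough illustration of this suggestion (module path, function name, and op contents are assumptions, not code from this PR), a shared policy module might look like:

```python
import torch

aten = torch.ops.aten

# e.g. torchtrain/ac_policy.py (hypothetical): one place that owns the SAC op
# list, so parallelize_llama, a future parallelize_gpt, etc. can import it
# instead of each hard-coding its own.
DEFAULT_NO_RECOMPUTE_LIST = {
    aten.mm.default,
}


def get_no_recompute_list(extra_ops=()):
    """Shared set of ops whose outputs are saved (not recomputed),
    optionally extended with model-specific ops."""
    return DEFAULT_NO_RECOMPUTE_LIST | set(extra_ops)
```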
def _get_custom_policy(meta):
    def _custom_policy(mode, func, *args, **kwargs):
        mm_count_key = f"{mode}_mm_count"
        if mm_count_key not in meta:
I believe it is cleaner to use a defaultdict here:
meta = defaultdict(int)
and then there is no need for this check of whether the key is present and the init:
if mm_count_key not in meta: meta[mm_count_key] = 0
    return _custom_policy


def selective_checkpointing_context_fn():
    meta = {}
propose
meta = defaultdict(int)
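Putting the two inline suggestions together, a minimal sketch of the defaultdict version of the policy; the op set and the save-every-other-mm rule are illustrative assumptions, not necessarily the exact policy in this PR:

```python
from collections import defaultdict

import torch

aten = torch.ops.aten

# Ops whose outputs are kept (not recomputed) in backward; contents are
# illustrative, not the PR's full list.
no_recompute_list = {
    aten.mm.default,
}


def _get_custom_policy(meta):
    def _custom_policy(mode, func, *args, **kwargs):
        if func == aten.mm.default:
            # defaultdict(int) starts each counter at 0, so the explicit
            # "if mm_count_key not in meta" initialization goes away.
            meta[f"{mode}_mm_count"] += 1
        # Illustrative rule: save (don't recompute) ops in the list,
        # e.g. every other mm; the real policy in the PR may differ.
        return func in no_recompute_list and meta[f"{mode}_mm_count"] % 2 == 0

    return _custom_policy


# Replaces `meta = {}` plus the manual per-key init inside the policy.
meta = defaultdict(int)
policy = _get_custom_policy(meta)
```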
Please see inline comments
torchtrain/config_manager.py
@@ -215,4 +215,9 @@ def init_args_from_command_line(
            "is an empty string, checkpointing is disabled."
        ),
    )
    parser.add_argument(
        "--metrics.enable_selective_ac",
this is not a metrics flag, but rather a training flag
This is a mistake. I'll correct this.
@@ -37,3 +37,4 @@ checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca"
enable_selective_ac = false
Hmm, you are putting it in the training section, but the cmd arg parser does not register it under training, so I think this will not work as expected. (If it does, then we need to figure out why the TOML parsing works incorrectly.)
In our setting, the cmd arg parser is a backup way of providing args. The toml file has it in the training section, and in the code training.enable_selective_ac is used, so no problem there -- the cmd arg metrics.enable_selective_ac is parsed but not used.
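For illustration of the fix being discussed, a sketch of registering the flag under the training section so the CLI name matches the TOML section; the corrected option name is an assumption about the follow-up, not the PR's final code:

```python
import argparse

parser = argparse.ArgumentParser()
# Registering under "training" keeps the CLI name aligned with the
# [training] section of the TOML config.
parser.add_argument(
    "--training.enable_selective_ac",
    action="store_true",
    help="whether to use selective activation checkpointing instead of full AC",
)

args = parser.parse_args(["--training.enable_selective_ac"])
# argparse keeps the dotted name as-is in the namespace dict.
assert vars(args)["training.enable_selective_ac"] is True
```

With this, `--training.enable_selective_ac` on the command line and `enable_selective_ac = false` under the training section of the TOML file refer to the same config key.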
    return ptd_checkpoint_wrapper(
        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )


def checkpoint_wrapper(module, enable_selective_ac=False):
super nit: prefer not to make enable_selective_ac a default arg if we always pass it explicitly.
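A sketch of the signature with this nit applied; the selective-AC branch is omitted here, this only illustrates dropping the default value:

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)


def checkpoint_wrapper(module, enable_selective_ac):
    # `enable_selective_ac` is now a required argument, since call sites
    # always pass it explicitly. The selective-AC branch from the PR
    # (building a context_fn) is omitted in this sketch; this path is full AC.
    return ptd_checkpoint_wrapper(
        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )
```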
    checkpoint,
)


def _get_custom_policy(meta):
For my understanding, does this policy also run in eager mode?
Actually it only works in eager mode; with the compiler we got:
[rank0]:[rank0]: assert torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint, \
[rank0]:[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:[rank0]: AssertionError: Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile
NVM. Wanchao told me we need to set torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True for it to work. Now SAC works with compile.
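For reference, the workaround from this exchange boils down to setting the experimental flag before compiling; the flag is experimental and may differ across PyTorch versions:

```python
import torch
import torch._dynamo

# This experimental dynamo flag lets torch.utils.checkpoint accept a
# context_fn (which SAC needs) under torch.compile, avoiding the assertion
# shown above.
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True

# model = torch.compile(model)  # with the flag set, SAC + compile works
```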
Stack from ghstack (oldest at bottom):
Selective activation checkpointing (SAC), compared with full AC which always does activation recomputation, selectively stores some intermediate activations to save training time, at the cost of more memory usage.

Here are some test results on llama 7B.

With full activation checkpointing:
- [rank0]: Average iter time: 4.9126 seconds
- [rank0]: Peak Memory: Reserved 40.61%, Alloc 28.12%, Active: 29.61%

With selective activation checkpointing:
- [rank0]: Average iter time: 4.5459 seconds
- [rank0]: Peak Memory: Reserved 80.45%, Alloc 62.0%, Active: 63.43%
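As a self-contained illustration of how a block-level AC wrapper like the one in this PR gets applied, here is a toy model wrapped with full AC only; the toy classes are assumptions, and the selective variant would additionally pass a context_fn built from the custom policy:

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.w2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 2, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
# Wrap each block with (full) activation checkpointing.
for i, block in enumerate(model.layers):
    model.layers[i] = ptd_checkpoint_wrapper(
        block, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )

x = torch.randn(4, 16)
loss = model(x).sum()
loss.backward()  # activations inside each wrapped block are recomputed here
```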