add selective activation checkpointing #97

Merged: 4 commits merged on Feb 29, 2024

Conversation

@tianyu-l (Contributor) commented on Feb 28, 2024:

Stack from ghstack (oldest at bottom):

Compared with full activation checkpointing (AC), which always recomputes activations, selective activation checkpointing (SAC) selectively stores some intermediate activations to save training time, at the cost of higher memory usage.

Here are some test results on Llama 7B.

with full activation checkpointing:

  • [rank0]: Average iter time: 4.9126 seconds
  • [rank0]: Peak Memory: Reserved 40.61%, Alloc 28.12%, Active: 29.61%

with selective activation checkpointing:

  • [rank0]: Average iter time: 4.5459 seconds
  • [rank0]: Peak Memory: Reserved 80.45%, Alloc 62.0%, Active: 63.43%
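
(For context, here is a minimal sketch of how a custom SAC policy can be wired up. It uses the public torch.utils.checkpoint.create_selective_checkpoint_contexts API available in recent PyTorch releases, rather than the lower-level context_fn helper used in this PR, and the op list and toy module below are illustrative, not this PR's actual no_recompute_list.)

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative policy: always save matmul outputs, recompute everything else.
_save_list = {torch.ops.aten.mm.default}

def _policy_fn(ctx, op, *args, **kwargs):
    # MUST_SAVE keeps the activation in memory; PREFER_RECOMPUTE drops it and
    # recomputes it during backward.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def _context_fn():
    return create_selective_checkpoint_contexts(_policy_fn)

# Wrap a toy block with SAC; in the real model this would be a transformer block.
block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False, context_fn=_context_fn)
out.sum().backward()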

tianyu-l added a commit that referenced this pull request Feb 28, 2024
ghstack-source-id: 2fbae95768f06b1b35af3f69eb0d39777b214089
Pull Request resolved: #97
@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Feb 28, 2024
@tianyu-l linked an issue on Feb 28, 2024 that may be closed by this pull request
@lessw2020 (Contributor) left a comment:

Great addition to the codebase! LGTM! Two minor proposals:

  • propose a more Pythonic/cleaner use of defaultdict for the meta dict
  • longer term, the AC policy probably wants to live at a more global level for future re-use and/or customization.

@@ -67,12 +66,54 @@ def partition_fn(name, module, device_mesh):
)


# AC/selective AC
no_recompute_list = {
lessw2020 (Contributor):
Not vital for now, but longer term I think this policy should be exposed at a higher level if we expect other models to be added and/or expect people to customize this policy list. I.e., if we have parallelize_gpt or similar, then it's awkward to pull the recompute list from parallelize_llama.

def _get_custom_policy(meta):
    def _custom_policy(mode, func, *args, **kwargs):
        mm_count_key = f"{mode}_mm_count"
        if mm_count_key not in meta:
lessw2020 (Contributor):
I believe it is cleaner to use a defaultdict here:
meta = defaultdict(int)
Then there is no need for this check for whether the key is present and for the init:
if mm_count_key not in meta: meta[mm_count_key] = 0

return _custom_policy

def selective_checkpointing_context_fn():
    meta = {}
lessw2020 (Contributor):
propose
meta = defaultdict(int)
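
(A minimal sketch of the proposed simplification, following the policy structure quoted above; the mm-counting rule shown here is illustrative, not necessarily this PR's exact policy.)

from collections import defaultdict

import torch

def _get_custom_policy(meta):
    # With defaultdict(int), missing keys default to 0, so there is no need for
    # the `if mm_count_key not in meta: meta[mm_count_key] = 0` initialization.
    def _custom_policy(mode, func, *args, **kwargs):
        mm_count_key = f"{mode}_mm_count"
        if func == torch.ops.aten.mm.default:
            meta[mm_count_key] += 1
        # Illustrative rule: store every other mm, recompute the rest.
        return func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 1

    return _custom_policy

# Tiny smoke test of the counting behavior.
meta = defaultdict(int)
policy = _get_custom_policy(meta)
print(policy("forward", torch.ops.aten.mm.default))  # True: 1st mm is stored
print(policy("forward", torch.ops.aten.mm.default))  # False: 2nd mm is recomputed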

@wanchaol (Contributor) left a comment:

Please see inline comments.

@@ -215,4 +215,9 @@ def init_args_from_command_line(
            "is an empty string, checkpointing is disabled."
        ),
    )
    parser.add_argument(
        "--metrics.enable_selective_ac",
wanchaol (Contributor):
this is not a metrics flag, but rather a training flag

tianyu-l (Contributor, Author):
This is a mistake. I'll correct this.

@@ -37,3 +37,4 @@ checkpoint_interval = 3600
checkpoint_interval_type = "steps"
checkpoint_folder = ""
dataset = "alpaca"
enable_selective_ac = false
wanchaol (Contributor):
Hmm, you are putting it in the training section, but the command-line arg parser doesn't register it under training; I think this will not work as expected. (If it does, then we need to figure out why the TOML parsing works wrongly.)

tianyu-l (Contributor, Author):
In our setting, the command-line arg parser is a backup way of providing args. The TOML file has it in the training section, and in the code training.enable_selective_ac is used, so no problem there -- the cmd arg metrics.enable_selective_ac is parsed but not used.

tianyu-l added a commit that referenced this pull request Feb 28, 2024
ghstack-source-id: 1ab8e19df172354cb3cd7f10a7c45e7c7c1ceb51
Pull Request resolved: #97
    return ptd_checkpoint_wrapper(
        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )

def checkpoint_wrapper(module, enable_selective_ac=False):
Contributor:
super nit: prefer not to make enable_selective_ac a default arg if we always pass it explicitly.

checkpoint,
)

def _get_custom_policy(meta):
Contributor:
For my understanding, does this policy also run in eager mode?

tianyu-l (Contributor, Author):
Actually, it only works in eager mode; with the compiler we got:

[rank0]:[rank0]:     assert torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint, \
[rank0]:[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:[rank0]: AssertionError: Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile

tianyu-l (Contributor, Author):
NVM. Wanchao told me we need to set torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True for it to work. Now SAC works with compile.
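
(For reference, a minimal sketch of setting the experimental flag mentioned above before building the model; the flag name comes from the traceback and comment in this thread.)

import torch._dynamo

# Allow passing a context_fn to torch.utils.checkpoint under torch.compile,
# so SAC and compile can be used together (experimental flag named above).
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = True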

tianyu-l added a commit that referenced this pull request Feb 29, 2024
ghstack-source-id: 1d082ae8c9639ac0dffbbfac1d325fef86a2a880
Pull Request resolved: #97
tianyu-l added a commit that referenced this pull request Feb 29, 2024
ghstack-source-id: f7ee3c867bfcdcae5dbb490982920606191e6f40
Pull Request resolved: #97
@tianyu-l merged commit 7ad1e43 into gh/tianyu-l/2/base on Feb 29, 2024; 4 checks passed
tianyu-l added a commit that referenced this pull request Feb 29, 2024
ghstack-source-id: f7ee3c867bfcdcae5dbb490982920606191e6f40
Pull Request resolved: #97
@tianyu-l deleted the gh/tianyu-l/2/head branch on February 29, 2024, 22:18
@tianyu-l mentioned this pull request on Mar 20, 2024
lessw2020 pushed a commit that referenced this pull request Apr 18, 2024
ghstack-source-id: f7ee3c867bfcdcae5dbb490982920606191e6f40
Pull Request resolved: #97
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request Aug 17, 2024
ghstack-source-id: f7ee3c867bfcdcae5dbb490982920606191e6f40
Pull Request resolved: pytorch#97
Labels: CLA Signed (this label is managed by the Meta Open Source bot)
Projects: None yet
Development: Successfully merging this pull request may close this issue: "add AC/selective AC to the model"
5 participants