-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix checkpointable_layers
Logic
#6881
Merged
loadams
merged 8 commits into
deepspeedai:master
from
Quentin-Anthony:qanthony/fix-act-recomp
Jan 4, 2025
Merged
Fix checkpointable_layers
Logic
#6881
loadams
merged 8 commits into
deepspeedai:master
from
Quentin-Anthony:qanthony/fix-act-recomp
Jan 4, 2025
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
tjruwase
approved these changes
Dec 17, 2024
loadams
reviewed
Dec 17, 2024
tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py
Show resolved
Hide resolved
siqi654321
pushed a commit
to siqi654321/DeepSpeed
that referenced
this pull request
Feb 7, 2025
**Problem** There's an edge-case in DeepSpeed, where if all three of the following are true: 1. Deepspeed activation checkpointing is applied 2. The user passes `checkpointable_layers` (e.g. https://github.com/EleutherAI/gpt-neox/blob/f5325805678c2b9e35aae4528283e0132c5f5bbc/megatron/model/gpt2_model.py#L175) 3. The user's model class contains `GPT2ModelPipe` or GPTModelPipe` Then the `checkpointable_layers` will not be activation checkpointed. **Reason** This is because in the current logic, `_is_checkpointable` will short-circuit to just return layers matching `ParallelTransformerLayerPipe` in the case of `self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe')`. See https://github.com/microsoft/DeepSpeed/blob/da771ed42e41a44d5047813ca4672f1cfe9d1731/deepspeed/runtime/pipe/module.py#L653 **Proposed Fixes** I think that `checkpointable_layers` should always be checked for, and added logic to this effect. I also found the documentation for `checkpointable_layers` confusing and contradictory, so I updated the docstring. Lastly, I added a unit test for `checkpointable_layers`. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Signed-off-by: siqi <siqi@tecorigin.com>
traincheck-team
pushed a commit
to traincheck-team/DeepSpeed
that referenced
this pull request
Feb 9, 2025
**Problem** There's an edge-case in DeepSpeed, where if all three of the following are true: 1. Deepspeed activation checkpointing is applied 2. The user passes `checkpointable_layers` (e.g. https://github.com/EleutherAI/gpt-neox/blob/f5325805678c2b9e35aae4528283e0132c5f5bbc/megatron/model/gpt2_model.py#L175) 3. The user's model class contains `GPT2ModelPipe` or GPTModelPipe` Then the `checkpointable_layers` will not be activation checkpointed. **Reason** This is because in the current logic, `_is_checkpointable` will short-circuit to just return layers matching `ParallelTransformerLayerPipe` in the case of `self.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe')`. See https://github.com/microsoft/DeepSpeed/blob/da771ed42e41a44d5047813ca4672f1cfe9d1731/deepspeed/runtime/pipe/module.py#L653 **Proposed Fixes** I think that `checkpointable_layers` should always be checked for, and added logic to this effect. I also found the documentation for `checkpointable_layers` confusing and contradictory, so I updated the docstring. Lastly, I added a unit test for `checkpointable_layers`. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Problem
There's an edge-case in DeepSpeed, where if all three of the following are true:
checkpointable_layers
(e.g. https://github.com/EleutherAI/gpt-neox/blob/f5325805678c2b9e35aae4528283e0132c5f5bbc/megatron/model/gpt2_model.py#L175)GPT2ModelPipe
or GPTModelPipe`Then the
checkpointable_layers
will not be activation checkpointed.Reason
This is because in the current logic,
_is_checkpointable
will short-circuit to just return layers matchingParallelTransformerLayerPipe
in the case ofself.__class__.__name__ in ('GPTModelPipe', 'GPT2ModelPipe')
. See https://github.com/microsoft/DeepSpeed/blob/da771ed42e41a44d5047813ca4672f1cfe9d1731/deepspeed/runtime/pipe/module.py#L653Proposed Fixes
I think that
checkpointable_layers
should always be checked for, and added logic to this effect. I also found the documentation forcheckpointable_layers
confusing and contradictory, so I updated the docstring. Lastly, I added a unit test forcheckpointable_layers
.