
Fix requires_grad of input must be true for activation checkpoint layer in pipeline train. #4128

Closed
inkcherry wants to merge 11 commits

Conversation

@inkcherry (Contributor) commented Aug 10, 2023

Mechanism FYI: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb?short_path=e1e8ee7#L282C244-L282C257

For discrete input, this makes activation checkpointing work for the embedding layer.
For image input, there is no need to calculate and save grads for the input.

Verified on Megatron-DeepSpeed training and the cifar10 training unit test (test_pipe.py). Under the same seed, the parameter and gradient updates are exactly equal to those before this change.

It can slightly reduce memory on rank 0 (stage 0 contains the embedding layer).
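For context, the limitation can be reproduced with stock PyTorch (a minimal sketch, assuming a version that still exposes the reentrant checkpoint path):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)  # e.g. an image batch: requires_grad is False

# With the reentrant checkpoint, the layer's parameters are invisible to
# autograd at the segment boundary; gradients flow through the segment
# only if some *input* requires grad.
out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)  # False: backward cannot flow through this segment
```

PyTorch also emits a warning here ("None of the inputs have requires_grad=True"), which is exactly the situation this PR works around.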

@tjruwase tjruwase requested a review from tohtana August 10, 2023 09:50
@tohtana (Contributor) left a comment:

Thank you for submitting this PR, @inkcherry!
Is the goal of this to stop creating gradients of inputs when it is unnecessary?

Can you also clarify intentions of some changes?

@@ -638,7 +638,6 @@ def _exec_forward_pass(self, buffer_id):

# Zero out the gradients each time we use the tensor because only the data in
# tensor changes across batches
self._zero_grads(inputs)

Can you explain why we can delete this? We could also delete the comment before this line.

@inkcherry (Author) replied Aug 17, 2023:

if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):
inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])
else:
inputs = self.pipe_buffers['inputs'][buffer_id].clone()
Here `inputs` has become a non-leaf tensor via the `clone()` op, and its gradient is not saved, so there is no need to call `_zero_grads` again.

And if `self.is_pipe_partitioned and not self.is_first_stage()` is False, meaning no new leaf tensor named `inputs` is created, accessing the gradient of the current (non-leaf) `inputs` in `_zero_grads` triggers a warning.
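The non-leaf behavior can be checked in isolation, independent of DeepSpeed:

```python
import torch
import warnings

leaf = torch.randn(2, requires_grad=True)
cloned = leaf.clone()  # clone() is an autograd op, so its result is non-leaf
print(leaf.is_leaf, cloned.is_leaf)  # True False

# Non-leaf tensors do not retain .grad; merely reading the attribute
# emits the "not a leaf Tensor" UserWarning and returns None.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    grad = cloned.grad
print(grad is None, len(caught) > 0)  # True True
```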

class CkptLayer_Enum(Enum):
not_ckpt_layer = 0
normal_ckpt_layer = 1
warp_ckpt_layer = 2

Does warp mean wrap?

@inkcherry (Author):

fixed.

@@ -320,7 +343,7 @@ def forward(self, forward_input):
# will see a different offset.
self.micro_offset += 1

def exec_range_func(start, end):
def exec_range_func(start, end, warp_layer=False):

Do you mean wrap, not warp?

@inkcherry (Author):

fixed.

@@ -342,7 +365,29 @@ def exec_func(*inputs):
inputs = layer(inputs)
return inputs

return exec_func
def exec_func_warp(*inputs):

Can we consolidate this with exec_func?
Most of the code is duplicated.

@inkcherry (Author):

merged.

@inkcherry (Author) commented Aug 17, 2023

Thank you for submitting this PR, @inkcherry! Is the goal of this to stop creating gradients of inputs when it is unnecessary?

Can you also clarify intentions of some changes?

@tohtana, thanks for your review, and apologies for not explaining clearly.

Yes. Currently, `requires_grad` of the input to the first checkpoint layer must be true. I think this is a limitation, and the current code has the following two workarounds for it. Please correct me if I'm wrong.

  1. For image input, set requires_grad=True for the input loaded from the dataset:

    loaded.requires_grad = loaded.is_floating_point()

    This calculates and saves grads for the input image, which I think may cost memory and compute time with large inputs. Related issue: "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed" CERC-AAI/multimodal#16

  2. Secondly, for LLMs, notice the comment

    # This is an unfortunate hack related to torch and deepspeed activation checkpoint implementations.
    # Some layers like torch.nn.Embedding will not receive grads if checkpointed, which breaks things.
    # I presume it's related to the discrete inputs that cannot require_grad? Need to revisit.

    Here the transformer layer is hooked: the embedding layer is forced to be a non-checkpoint layer, which takes an input with requires_grad=False and outputs a tensor with requires_grad=True.

To remove this limitation, we can pass a dummy input that requires grad but isn't necessarily used in the computation. FYI:
https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb?short_path=e1e8ee7#L282C244-L282C257

This PR wraps the first layer in a WrapModule to achieve this.
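The dummy-input trick can be sketched as follows (illustrative only; this `WrapModule` is a simplified stand-in for the PR's actual wrapper, not its exact code):

```python
import torch
from torch.utils.checkpoint import checkpoint

class WrapModule(torch.nn.Module):
    """Thread a grad-requiring dummy tensor through the checkpointed call
    so autograd records the segment even for discrete inputs."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.dummy = torch.ones(1, requires_grad=True)

    def forward(self, dummy, x):
        # dummy exists only to satisfy checkpoint's requirement; 0 * dummy
        # keeps it in the graph without changing the result
        return self.module(x) + 0 * dummy

emb = torch.nn.Embedding(10, 4)
wrapped = WrapModule(emb)
tokens = torch.tensor([[1, 2, 3]])  # discrete input, cannot require grad

out = checkpoint(wrapped, wrapped.dummy, tokens, use_reentrant=True)
out.sum().backward()
print(emb.weight.grad is not None)  # True: the embedding receives grads
```

The extra multiply is essentially free, and the checkpointed embedding now participates in backward without forcing `requires_grad=True` on the data itself.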

@inkcherry (Author):

@tohtana, I've made some changes to the code. Any suggestions? I would appreciate your feedback.

@tohtana (Contributor) commented Aug 23, 2023

Hi @inkcherry, thank you for the fixes. Overall, this looks okay to me.

On the other hand, I just reviewed #4118 and this might solve a part of the issue you addressed. I didn't notice this PR until very recently, and I'm sorry that I couldn't share it with you.

We want to merge the PR, and I am wondering if we can simplify the changes in this PR using the new features proposed in #4118. Can you share your thoughts on this?

@inkcherry (Author) commented Aug 23, 2023

@tohtana, yes, I agree; #4118 is great!
I could retry based on #4118 once it is merged.

@tjruwase (Contributor):

@inkcherry, #4118 is now merged. Also, could you please add a unit test when you retry? Thanks!

@inkcherry (Author):

Moved to #4224; closing this one.

@inkcherry inkcherry closed this Aug 26, 2023