
Optimize ForCausalLMLoss by removing unnecessary contiguous() call to reduce memory overhead #35646

Merged · 1 commit merged into huggingface:main on Jan 16, 2025

Conversation

@efsotr (Contributor) commented Jan 13, 2025

What does this PR do?

Removes the unnecessary contiguous() calls from ForCausalLMLoss so the large float logits tensor is never copied, as sketched below.
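A minimal before/after sketch of the change (my paraphrase of the diff as described in this thread; plain F.cross_entropy stands in for the library's internal loss helper, and the signature is simplified):

```python
import torch.nn.functional as F

# Before: both tensors were shifted, and the shifted slice of the large
# float logits tensor was materialized with an extra contiguous() copy.
def causal_lm_loss_old(logits, labels, vocab_size, ignore_index=-100):
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous()  # full-size logits copy
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.reshape(-1, vocab_size),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
    )

# After: logits stay untouched; only the small integer labels tensor is
# padded with ignore_index and shifted, so no logits copy is created.
def causal_lm_loss_new(logits, labels, vocab_size, ignore_index=-100):
    logits = logits.float()
    labels = F.pad(labels, (0, 1), value=ignore_index)
    shift_labels = labels[..., 1:].contiguous()  # cheap: labels are small ints
    return F.cross_entropy(
        logits.reshape(-1, vocab_size),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
    )
```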

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@efsotr efsotr requested a review from ArthurZucker as a code owner January 13, 2025 03:24
@ArthurZucker (Collaborator) left a comment

Sounds good, but before merging could you add some performance data? 🤗

@efsotr (Contributor, Author) commented Jan 13, 2025

Conditions: the cache is not cleared between measurements, the dtype is bfloat16, and the logits tensor is 0.5 GB.
The old ForCausalLMLoss requires 3 GB for the forward pass, plus an additional 0.5 GB to hold the original logits.
The new ForCausalLMLoss reduces the forward pass to 2 GB, while still needing the 0.5 GB to hold the original logits.
For the backward pass, both versions require 3 GB, plus the 0.5 GB needed to hold the original logits.
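A back-of-envelope reading of those numbers (my own accounting, not the author's: I'm assuming the 0.5 GB figure is the bfloat16 logits, that logits.float() makes a full fp32 copy, and that cross_entropy keeps one fp32-sized activation for backward):

```python
bf16_logits = 0.5              # GB, original logits, kept alive in both versions
fp32_upcast = 2 * bf16_logits  # logits.float() doubles the footprint -> 1 GB
contig_copy = fp32_upcast      # old path only: shift_logits.contiguous() -> 1 GB
saved_for_bwd = fp32_upcast    # activation cross_entropy saves for backward -> 1 GB

old_forward = fp32_upcast + contig_copy + saved_for_bwd  # 3.0 GB
new_forward = fp32_upcast + saved_for_bwd                # 2.0 GB
print(old_forward, new_forward)  # 3.0 2.0, each plus the 0.5 GB bf16 logits
```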

@v-lmn commented Jan 16, 2025

Thank you very much. I'd like to understand a few things. Can you explain your test case? I don't understand the meaning of ~create_mask(max_seqlen, half_prompts_lens), mask, start_pos_id, and end_pos_id. Can you give me your personal email? @efsotr

@Rocketknight1 (Member) left a comment

I like this PR!

Explaining it (for myself and reviewers): logits is (batch, seq_len, vocab_size) and labels is just (batch, seq_len). Also, logits is a float tensor that carries gradient, and labels is not. Therefore, you want to avoid small manipulations of logits as much as possible, because they are slow and because they pollute the graph with extra temporary tensors that have to be saved for backprop.

Before this PR, we shifted logits and labels one position each, in opposite directions. After this PR, we keep logits static and pad+shift labels. This is not exactly equivalent, so the code compensates by using the ignore_index of -100 as the pad value, which restores equivalence.
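A quick equivalence check (my own sketch, not the PR's test suite): with the default mean reduction, cross_entropy excludes -100 positions from both the numerator and the denominator, so the padded variant matches the old double-shift exactly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Old: drop the last logit position and the first label position.
old = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab), labels[:, 1:].reshape(-1)
)

# New: keep logits intact; pad labels with -100 and shift them left, so the
# final (unsupervised) position is ignored by cross_entropy.
shifted = F.pad(labels, (0, 1), value=-100)[..., 1:]
new = F.cross_entropy(logits.reshape(-1, vocab), shifted.reshape(-1))

print(torch.allclose(old, new))  # True
```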

This is a great change. Some models compute loss internally and don't use this function, so if you want to do a follow-up PR to search the codebase for shift_logits and update the logic there too, that would be a nice speed boost for them as well. Thank you!

@Rocketknight1 (Member) commented

(Merging without @ArthurZucker's approval because the code produces the same output; it's just faster and uses less memory.)

@Rocketknight1 Rocketknight1 merged commit 8ebe9d7 into huggingface:main Jan 16, 2025
24 of 25 checks passed
@ArthurZucker (Collaborator) commented

LGTM anyways, thanks all 🤗

bursteratom pushed a commit to bursteratom/transformers that referenced this pull request Jan 31, 2025
… reduce memory overhead (huggingface#35646)

Optimize ForCausalLMLoss by removing unnecessary contiguous() calls to reduce memory overhead
elvircrn pushed a commit to elvircrn/transformers that referenced this pull request Feb 13, 2025
… reduce memory overhead (huggingface#35646)

Optimize ForCausalLMLoss by removing unnecessary contiguous() calls to reduce memory overhead