Optimize ForCausalLMLoss by removing unnecessary contiguous() call to reduce memory overhead #35646
Conversation
Sounds good but before merging could you add some performance data? 🤗
Condition: Without clearing the cache, the …
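For context on what such performance data might look like to collect, here is a hypothetical micro-benchmark sketch; the function name, tensor sizes, and setup are illustrative assumptions, not the script or numbers used in this PR.

```python
import torch

def peak_memory_mb(loss_fn, batch=2, seq_len=2048, vocab=32000, device="cuda"):
    """Run one forward/backward pass of loss_fn and report peak GPU memory in MiB.

    Hypothetical harness: loss_fn is any callable taking (logits, labels).
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    logits = torch.randn(batch, seq_len, vocab, device=device, requires_grad=True)
    labels = torch.randint(0, vocab, (batch, seq_len), device=device)
    loss_fn(logits, labels).backward()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20
```

Running the old and new loss implementations through such a harness would make the memory saving concrete.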
Thank you very much. I have some questions: can you explain your test case? I don't understand the meaning of `~create_mask(max_seqlen, half_prompts_lens)`, `mask`, `start_pos_id`, and `end_pos_id`. Can you give me your personal email? @efsotr
I like this PR!

Explaining it (for myself and reviewers): `logits` is (batch, seq_len, vocab_size) and `labels` is just (batch, seq_len). Also, `logits` is a float tensor that carries gradient, and `labels` is not. Therefore, you want to avoid small manipulations of `logits` as much as possible, because they are slow and because they pollute the graph with extra temporary tensors that have to be saved for backprop.

Before this PR, we shifted `logits` and `labels` one position each, in opposite directions. After this PR, we keep `logits` static and pad+shift `labels`. This is not exactly equivalent, so the code compensates by using the `ignore_index` of -100 as the pad value, which restores equivalence.

This is a great change. Some models compute loss internally and don't use this function, so if you want to do a follow-up PR to search the codebase for `shift_logits` and update the logic there too, that would be a nice speed boost for them as well. Thank you!

(Merging without @ArthurZucker approval because the code has the same output, it's just faster + less memory)
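For illustration, here is a minimal sketch of the before/after behaviour described in that comment. It is not the exact transformers implementation; the function names are made up, and it assumes the standard `ignore_index` of -100 used by PyTorch's cross-entropy.

```python
import torch.nn.functional as F

def causal_lm_loss_before(logits, labels, ignore_index=-100):
    # Old approach: shift logits and labels one position in opposite
    # directions; slicing logits forces a contiguous() copy of the large
    # (batch, seq_len, vocab_size) float tensor.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )

def causal_lm_loss_after(logits, labels, ignore_index=-100):
    # New approach: leave logits untouched; pad the small integer labels
    # tensor with ignore_index on the right and shift it instead, so the
    # final position is ignored by the loss and no logits copy is made.
    shift_labels = F.pad(labels, (0, 1), value=ignore_index)[..., 1:].contiguous()
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )
```

Both variants yield the same loss value; the second simply avoids materialising an extra copy of the logits for the forward and backward pass.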
LGTM anyways, thanks all 🤗
Optimize ForCausalLMLoss by removing unnecessary contiguous() calls to reduce memory overhead (huggingface#35646)
What does this PR do?
This PR removes an unnecessary `contiguous()` call from `ForCausalLMLoss`, reducing memory overhead during loss computation.
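As a usage note, a sketch assuming a recent transformers version in which causal LMs route their loss through `ForCausalLMLoss` (the model name is just an example): the optimized path is exercised whenever labels are passed to a causal LM.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Hello world", return_tensors="pt")
# Passing labels triggers the internal loss computation touched by this PR.
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)
```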
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.