Generate: speculative decoding #27979
@patrickvonplaten tagging you here for a 2nd set of eyes on the speculative decoding method (changes in |
src/transformers/generation/utils.py
Outdated
if do_sample and candidate_logits is not None:
    # Gets the probabilities from the logits. q_i and p_i denote the assistant and model (respectively)
    # probabilities of the tokens selected by the assistant.
    q = candidate_logits.softmax(dim=-1)
These are not the best variable names, but it's hard to compare against the original algorithm if they don't match 🤔 As such, I've decided to keep the original names
I'm fine with it, as there are good comments and the other variables are well named, e.g. is_rejected :)
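To make the p/q notation concrete, here is a minimal, framework-free sketch of the accept/reject step used in speculative sampling. It follows the snippet above in taking q from the assistant's candidate logits and p from the main model. The function name `speculative_accept` and the list-based inputs are assumptions made for illustration; this is not the code in this PR, which works on batched tensors.

```python
import random


def speculative_accept(candidate_tokens, q_probs, p_probs, rng):
    """Verify assistant candidates against the model (illustrative sketch).

    candidate_tokens: token ids proposed by the assistant, one per position.
    q_probs[i]: assistant distribution at position i (list of floats).
    p_probs[i]: model distribution at position i. p_probs holds one extra
                entry: the model's distribution after the last candidate.
    """
    accepted = []
    for i, token in enumerate(candidate_tokens):
        q_i = q_probs[i][token]  # assistant probability of its own pick
        p_i = p_probs[i][token]  # model probability of the same token
        # Accept the candidate with probability min(1, p_i / q_i).
        if rng.random() < min(1.0, p_i / q_i):
            accepted.append(token)
            continue
        # Rejected: resample from the residual max(0, p - q) and stop.
        residual = [max(0.0, p - q) for p, q in zip(p_probs[i], q_probs[i])]
        accepted.append(rng.choices(range(len(residual)), weights=residual)[0])
        return accepted
    # Every candidate was accepted: draw one extra token from the model.
    bonus = p_probs[len(candidate_tokens)]
    accepted.append(rng.choices(range(len(bonus)), weights=bonus)[0])
    return accepted
```

When the two distributions agree, every candidate is kept and one extra token is drawn from the model; when the model assigns zero probability to a candidate, it is always rejected and replaced by a token from the residual distribution.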
Thanks for adding this! Can we split this up into two separate PRs: one changing the assisted generation and the other adding speculative decoding?
@amyeroberts pulled the assisted generation changes into this PR: #28030. After it is merged, I will rebase this one and ping you again -- this one will become exclusively about speculative decoding 🤗
@amyeroberts I've rerun the slow tests, and I can confirm they are passing. Ready for a review :)
Thanks for adding this!
Can we add some tests, in particular one which checks case 1. and one which makes sure the correct logic branch is being selected e.g. checking candidate_logits is None when expected (might be a test on the candidate generator instead)?
if do_sample:
    probs = new_logits.softmax(dim=-1)
    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
    selected_tokens = new_logits.argmax(dim=-1)
It's probably time to factor this out soon into something like:
selected_tokens = Categorical(logits=new_logits / temperature).sample()
everywhere in generate.
Yes! Then equivalent sampling/non-sampling methods (e.g. greedy decoding/sampling) could be merged into a single function, facilitating maintenance. I'm going to leave it to a follow-up PR, though, to keep this PR exclusively about speculative decoding.
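As a rough illustration of the refactor discussed in this thread, the sketch below merges the greedy and sampling branches into one helper. It is written in plain Python so it runs without PyTorch; the name `select_token` and its parameters are hypothetical and do not exist in transformers. If this were written with `torch.distributions.Categorical`, note that its first positional argument is `probs`, so logits would be passed with the `logits=` keyword.

```python
import math
import random


def select_token(logits, do_sample, temperature=1.0, rng=None):
    """Pick one token id from a list of logits (hypothetical helper).

    Greedy decoding is an argmax; sampling draws from
    softmax(logits / temperature).
    """
    if not do_sample:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - top) for x in scaled]
    # random.choices accepts unnormalised weights.
    return rng.choices(range(len(weights)), weights=weights)[0]
```

With a single entry point like this, the `if do_sample: ... else: ...` block quoted above would collapse into one call per decoding step.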
else:
    selected_tokens = new_logits.argmax(dim=-1)
if do_sample:
    probs = new_logits.softmax(dim=-1)
Is this case still relevant? I'm not sure it's a good idea to have two "assisted decoding" do_sample=True cases in our generate. Should we maybe just deprecate this case?
Super cool addition!
Not really related to this PR, but I feel like we should start putting all the generation submethods (assisted decoding, greedy & sample (guess we can merge these two), beam search, ...) into their own files by now.
My only important comment here is that I don't think it's great that we now have 2 assisted generation cases where do_sample=True. Can we deprecate the "non-official" one?
@patrickvonplaten the two types of sampling are needed :D New candidate-based methods are popping up (e.g. #27775), and they don't necessarily have logits. As such, speculative decoding, which needs the candidates' logits, can't be applied to those methods.
But shouldn't they just be their own method now? I.e. I don't think we should put #27775 into the speculative decoding method, no?
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@patrickvonplaten #27775 does not introduce changes to assisted generation 🤗 In #28030 I've abstracted the candidate generation part of assisted generation. We now load candidate generators the same way as we load the logits processors (transformers/src/transformers/generation/utils.py, lines 899 to 919 in e6dcf8a).
In assisted generation, we call the candidate generator to get candidate sequences, which may or may not contain associated logits, depending on the method (transformers/src/transformers/generation/utils.py, line 4588 in e6dcf8a).
The technique in #27775 can thus be added by adding a new candidate generator in [...]. Because needing the logits (for speculative decoding) is a very limiting constraint, I'd rather keep the two sampling paths.
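The argument for keeping two sampling paths can be sketched as follows. Both generator classes, their placeholder outputs, and `verification_path` are hypothetical names invented for this example; only the `do_sample and candidate_logits is not None` condition is taken from the diff under review.

```python
from typing import List, Optional, Tuple

Candidates = Tuple[List[int], Optional[List[List[float]]]]


class AssistantModelCandidates:
    """Hypothetical generator backed by an assistant model: returns logits."""

    def get_candidates(self, input_ids: List[int]) -> Candidates:
        # Placeholder values standing in for an assistant model's output.
        return [7, 8], [[0.1, 0.9], [0.8, 0.2]]


class CopyLastTokensCandidates:
    """Hypothetical model-free generator: proposes tokens without logits."""

    def get_candidates(self, input_ids: List[int]) -> Candidates:
        return input_ids[-2:], None


def verification_path(generator, input_ids: List[int], do_sample: bool) -> str:
    """Name the branch that would verify the candidates."""
    _, candidate_logits = generator.get_candidates(input_ids)
    if do_sample and candidate_logits is not None:
        # Candidate logits are available, so p/q rejection sampling applies.
        return "speculative decoding"
    if do_sample:
        # No logits: sample from the model and compare against candidates.
        return "sample from the model, then match candidates"
    return "greedy match"
```

A check along these lines would also cover the reviewer's request for a test that the correct branch is selected, e.g. that candidate_logits is None when a logit-free generator is used.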
@amyeroberts PR comments addressed 🤗 @patrickvonplaten Unless you strongly oppose, I'd like to keep the two sampling paths, for the reasons I've written here -- I think it will be beneficial in the long run! :) (otherwise, a whole new generation method has to be written for #27775)
@amyeroberts -- @patrickvonplaten and I had a chat about whether to keep the two sampling paths or not. For context, here's what we agreed on:
|
Thanks for iterating!
@gante |
@jmamou speculative decoding with |
@gante |
@gante |
@gante In current implementation (4.38), Is it intentional? If that's a bug, I can open a PR to fix it. |
Not sure if this is a good idea
This is a good point! A PR to revert to the previous behaviour (with a test) would be appreciated 🙏
What does this PR do?
Useful context:
In a recent PR (#27750), the candidate generation in assisted generation got abstracted, so we can host new candidate generation techniques (such as #27722).
This PR:
Reworks assisted candidate generation to call .generate(), instead of having its own custom generation loop. For most models this is nothing more than a nice abstraction. However, for models with a custom generate() function, this means the assistant model will now make use of it! (🤔 does this mean that DistilWhisper gets better numbers with this refactor?) Edit: moved to "Generate: assisted decoding now uses generate for the assistant" (#28030).
The following tests were run locally and are passing:
RUN_SLOW=1 py.test tests/models/whisper/ -k speculative
py.test tests/ -k test_assisted (which now triggers speculative decoding)
TODO: