Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: speculative decoding #27979

Merged
merged 8 commits into from
Dec 19, 2023

Conversation

gante
Copy link
Member

@gante gante commented Dec 12, 2023

What does this PR do?

Useful context:
In a recent PR (#27750), the candidate generation in assisted generation got abstracted, so we can host new candidate generation techniques (such as #27722).


This PR:

  1. Reworks assisted candidate generation to call .generate(), instead of having its own custom generation loop. For most models this is nothing more than a nice abstraction. However, for models with a custom generate() function, this means the assistant model will now make use of it! (🤔 does this mean that DistilWhisper gets better numbers with this refactor?) Edit: moved to Generate: assisted decoding now uses generate for the assistant #28030
  2. Adds speculative decoding (paper, see Algorithm 1). This implied a minor interface change in the candidate generation class, which should be okay since it hasn't been released :)

The following tests were run locally and are passing:

  1. RUN_SLOW=1 py.test tests/models/whisper/ -k speculative
  2. py.test tests/ -k test_assisted (which now triggers speculative decoding)

TODO:

  • Benchmark speculative decoding

@gante gante force-pushed the candidate_generate_refactor branch from 18a4eda to 993c9ee Compare December 12, 2023 17:08
@gante gante marked this pull request as ready for review December 12, 2023 17:11
@gante
Copy link
Member Author

gante commented Dec 12, 2023

@patrickvonplaten tagging you here for a 2nd set of eyes on the speculative decoding method (changes in utils.py), which I'm assuming you're familiar with. Feel free to delegate to someone else who is familiar with the method! 🤗

if do_sample and candidate_logits is not None:
# Gets the probabilities from the logits. q_i and p_i denote the model and assistant (respectively)
# probabilities of the tokens selected by the assistant.
q = candidate_logits.softmax(dim=-1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not the best variable names, but it's hard to compare against the original algorithm if they don't match 🤔 As such, I've decided to keep the original names

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with it as there's good comments and other variables are well names e.g. is_rejected :)

@amyeroberts
Copy link
Collaborator

Thanks for adding this! Can we split this up into two separate PRs: one changing the assisted generation and the other adding speculative decoding?

@gante
Copy link
Member Author

gante commented Dec 14, 2023

@amyeroberts pulled the assisted generation changes into this PR: #28030

After it is merged, I will rebase this one and ping you again -- this one will become exclusively about speculative decoding 🤗

@gante gante force-pushed the candidate_generate_refactor branch from 7bf05a9 to e234e1e Compare December 14, 2023 14:03
@gante
Copy link
Member Author

gante commented Dec 14, 2023

@amyeroberts I've rerun the slow tests, and I can confirm they are passing. Ready for a review :)

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this!

Can we add some tests, in particular one which checks case 1. and one which makes sure the correct logic branch is being selected e.g. checking candidate_logits is None when expected (might be a test on the candidate generator instead)?

if do_sample and candidate_logits is not None:
# Gets the probabilities from the logits. q_i and p_i denote the model and assistant (respectively)
# probabilities of the tokens selected by the assistant.
q = candidate_logits.softmax(dim=-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with it as there's good comments and other variables are well names e.g. is_rejected :)

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
Comment on lines +4679 to +4683
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits.argmax(dim=-1)
Copy link
Contributor

@patrickvonplaten patrickvonplaten Dec 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits.argmax(dim=-1)
if do_sample:
probs = new_logits.softmax(dim=-1)
selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
selected_tokens = new_logits.argmax(dim=-1)

It's probably time to soon factor this out into something like:

selected_tokens = Categorical(new_logits / temperature).sample()

everywhere in generate

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Then equivalent sampling/non-sampling methods (e.g. greedy decoding/samplinh) could be merged into a single function, facilitating maintenance. I'm going to leave it to a follow-up PR, though, to keep this PR exclusively about speculative decoding.

else:
selected_tokens = new_logits.argmax(dim=-1)
if do_sample:
probs = new_logits.softmax(dim=-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this case still relevant? Not sure it's a good idea to have two "assisted decoding" do_sample=True cases in our generate. Should we maybe just deprecate this case?

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super cool addition!

Not really related to this PR, but I feel like we should start putting all the generation submethods (assisted decoding, greedy & sample (guess we can merge these two), beam search, ...) into their own files by now

My only important comment here is that I don't think it's great that we have 2 assisted generation cases now where do_sample=True. Can we deprecate the "non-official" one?

@gante
Copy link
Member Author

gante commented Dec 17, 2023

@patrickvonplaten the two types of sampling are needed :D

New candidate-based methods are popping up (e.g. #27775), and they don't necessarily have logits. As such, speculative decoding, which needs the candidates' logits, can't be applied to those methods.

@patrickvonplaten
Copy link
Contributor

@patrickvonplaten the two types of sampling are needed :D

New candidate-based methods are popping up (e.g. #27775), and they don't necessarily have logits. As such, speculative decoding, which needs the candidates' logits, can't be applied to those methods.

But shouldn't they just be the "own" method now? I.e. I don't think we should put #27775 into the speculative decoding method no?

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@gante
Copy link
Member Author

gante commented Dec 18, 2023

@patrickvonplaten #27775 does not introduce changes to assisted generation 🤗 In #28030 I've abstracted the candidate generation part of assisted generation. We now load candidate generators the same way as we load the logits processors:

def _get_candidate_generator(
self,
generation_config: GenerationConfig,
input_ids: torch.LongTensor,
inputs_tensor: torch.Tensor,
assistant_model: "PreTrainedModel",
logits_processor: LogitsProcessorList,
model_kwargs: Dict,
) -> CandidateGenerator:
"""
Returns the candidate generator to be used in `assisted_generation`
"""
candidate_generator = AssistedCandidateGenerator(
input_ids=input_ids,
assistant_model=assistant_model,
logits_processor=logits_processor,
model_kwargs=model_kwargs,
inputs_tensor=inputs_tensor,
eos_token_id=generation_config.eos_token_id,
)
return candidate_generator

In assisted generation, we call the candidate generator to get candidate sequences (which may or may not contain associated logits, depending on the method)

candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)

The technique in #27775 can thus be added by adding a new candidate generator in _get_candidate_generator. Other candidate generators may be added the same way, enabling users to experiment with the concept of candidates!

Because needing the logits (for speculative decoding) is a very limiting constraint, I'd rather keep the two sampling paths.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante
Copy link
Member Author

gante commented Dec 18, 2023

@amyeroberts PR comments addressed 🤗

@patrickvonplaten Unless you don't strongly oppose, I'd like to keep the two sampling paths, for the reasons I've written here -- I think it will be beneficial in the long run! :) (otherwise, a whole new generation method has to be written for #27775)

@gante
Copy link
Member Author

gante commented Dec 18, 2023

@amyeroberts -- @patrickvonplaten and I had a chat about whether to keep the two sampling paths or not. For context, here's what we agreed on:

  • It's okay to leave it as is, and perhaps abstract the different ways we accept candidates into a candidate_checker block.
  • Be conservative on adding new candidate generators, so we don't end up with unused methods
  • [in a follow-up PR] squash other cases where the decoding method is the same except for the token selection, like greedy_decoding + sample
  • [in a follow-up PR] mode each decoding method into its own file. There are several private functions in generation/utils.py that are exclusively used with one generation method.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating!

@gante gante merged commit ac97419 into huggingface:main Dec 19, 2023
21 checks passed
@gante gante deleted the candidate_generate_refactor branch January 9, 2024 16:06
@jmamou
Copy link
Contributor

jmamou commented Jan 18, 2024

@gante
According to experiments reported in Leviathan's paper, speculative decoding (SD) has higher speedup with greedy decoding (temp=0). However, in the current implementation, SD works only with do_sample=True.

@gante
Copy link
Member Author

gante commented Jan 19, 2024

@jmamou speculative decoding with do_sample=False (or temp=0) was already encoded in assisted_generation, before this PR -- try calling model.generate(input_ids, do_sample=False, assistant_model=assistant_model) :)

@jmamou
Copy link
Contributor

jmamou commented Jan 21, 2024

@gante
Since acceptance criteria are different between speculative decoding and assisted generation, I think that it would be great to be able to run both speculative decoding and assisted generation with no sampling.

@jmamou
Copy link
Contributor

jmamou commented Jan 21, 2024

@gante
I implemented it. I can submit a PR.

@jmamou
Copy link
Contributor

jmamou commented Jan 24, 2024

@gante
In previous implementation of assisted generation (4.33) with heuristical update of num_assistant_tokens (or max_assistant_tokens), the value of num_assistant_tokens was preserved between 2 consecutive generate() calls.

In current implementation (4.38), num_assistant_tokens is updated by the candidate_generator during the generation but assistant_model.generation_config.num_assistant_tokens is not updated at the end of the generation. Therefore, next call to generate will start with the initial value of assistant_model.generation_config.num_assistant_tokens (5).

Is it intentional? If that's a bug, I can open a PR to fix it.

@gante
Copy link
Member Author

gante commented Jan 27, 2024

@jmamou

Since acceptance criteria are different between speculative decoding and assisted generation, I think that it would be great to be able to run both speculative decoding and assisted generation with no sampling.

Not sure if this is a good idea

  1. if we see greedy decoding as applying temperature=0, the model probability will be 1 at the most likely token and 0 everywhere else. In turn, this implies that p_i/q_i is >=1 at all positions, and thus all candidate tokens would be accepted 👉 speculative decoding would be the same as simply using the assistant model
  2. If we don't apply temperature=0, then it would be sampling -- in other words, it wouldn't be greedy decoding

In previous implementation of assisted generation (4.33) with heuristical update of num_assistant_tokens (or max_assistant_tokens), the value of num_assistant_tokens was preserved between 2 consecutive generate() calls.
In current implementation (4.38), num_assistant_tokens is updated by the candidate_generator during the generation but assistant_model.generation_config.num_assistant_tokens is not updated at the end of the generation. Therefore, next call to generate will start with the initial value of assistant_model.generation_config.num_assistant_tokens (5).
Is it intentional? If that's a bug, I can open a PR to fix it.

This is a good point! A PR to revert to the previous behaviour (with a test) would be appreciated 🙏

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants