Draft layer skip addition #34240
base: main
Conversation
status: code runs, output is gibberish. Numerical debugging after lunch to figure out what's wrong.
EDIT: Please ignore my comment above. I thought the output of early exit was gibberish, but I believe it was actually the output of self-speculative decoding that was gibberish. Yes, self-speculative decoding should have the same quality as the last layer.
Hi. My name is Mostafa and I am one of the main authors of the LayerSkip paper!
Thanks for working on this PR so quickly! I have provided some comments.
Also, for the future, I have some suggestions to consider:
- The `early_exit` arg in generation could become a callable function, letting researchers experiment with dynamic early exit, i.e., a different condition or heuristic to exit for each token (e.g., cosine similarity between a layer's input and output above a certain threshold). This is done in papers like CALM.
- Adapter modules for early exit. Rather than just exiting by jumping to the model's LM head, users may opt to add their own separate LM head, or even their own adapter layers, when exiting. This is done in a paper like Kangaroo.
- Different types of self-speculative decoding, e.g., a draft stage that uses a subset of the KV cache. This is done in MagicDec.
I am happy to discuss online or offline how we can add more features along this direction to enable researchers to unlock a lot of early exit ideas.
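To make the first suggestion concrete, here is a minimal sketch of what a dynamic early-exit callable could look like. The interface (`should_exit` receiving a layer's input and output vectors) is a hypothetical illustration, not the PR's API; the CALM-style idea is just "exit once a layer stops changing the hidden state much":

```python
import math

def cosine_similarity(a, b):
    # Plain cosine similarity between two vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

def make_early_exit_condition(threshold=0.95):
    """Hypothetical callable that generation could accept: exit once a
    layer's output is nearly parallel to its input, i.e. the layer is
    barely transforming the token anymore."""
    def should_exit(layer_input, layer_output):
        return cosine_similarity(layer_input, layer_output) >= threshold
    return should_exit

should_exit = make_early_exit_condition(threshold=0.95)
print(should_exit([1.0, 0.0], [1.0, 0.05]))  # nearly parallel -> True
print(should_exit([1.0, 0.0], [0.0, 1.0]))   # orthogonal -> False
```

A callable like this could be evaluated per token and per layer, which is exactly what a fixed `early_exit=N` integer cannot express.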
Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).

#### Early Exit (Self-Speculative Decoding)
I would like to mention that "Early Exit" is orthogonal to "Self-Speculative Decoding": there are other self-speculative decoding approaches that use methods other than early exit, e.g.,
- Draft & Verify: the draft stage skips intermediate FFNs and attention layers of the model
- MagicDec: the draft stage attends to a window of the KV cache rather than the full KV cache
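As a toy illustration of the MagicDec idea above (names and shapes are hypothetical simplifications, not the paper's implementation): the draft stage restricts attention to the most recent cache entries, while verification still sees the full cache.

```python
def draft_attention_window(kv_cache, window=4):
    """Sketch of a MagicDec-style draft stage: attend only to the
    `window` most recent KV cache entries instead of the full cache."""
    return kv_cache[-window:]

# Stand-in for cached key/value positions 0..9.
kv_cache = list(range(10))
print(draft_attention_window(kv_cache, window=4))  # [6, 7, 8, 9]
print(len(kv_cache))  # verification still uses all 10 entries
```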
Good point! I'm going to rename it to self-speculative decoding, and mention that we support self-speculative decoding by specifying an exit layer
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
Here `torch.cat` will only be correct if `min(new_positions) == previous_length + 1`, right? If that's correct, should we also add an `assert` statement for that?
yes, that is correct!
I'm not going to add any check for now, though, and rely on internal tests to detect issues: adding a check here would hurt throughput in the forward pass, and a test can immediately detect issues :)
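The contiguity assumption being discussed can be sketched with a toy cache (plain lists standing in for tensors; with 0-indexed positions the invariant reads as `min(new_positions) == previous_length`). This is an illustration of the invariant, not the PR's cache code:

```python
class ToyCache:
    """Toy stand-in for a KV cache: key_cache[layer] is a list of
    per-position key vectors, appended to on every update."""

    def __init__(self, num_layers):
        self.key_cache = [[] for _ in range(num_layers)]

    def update(self, layer_idx, new_positions, new_keys):
        previous_length = len(self.key_cache[layer_idx])
        # Concatenating along the sequence dim (what torch.cat does in
        # the real cache) is only correct when the new positions extend
        # the cache contiguously, starting right after the last entry.
        assert min(new_positions) == previous_length, "non-contiguous cache update"
        self.key_cache[layer_idx].extend(new_keys)

cache = ToyCache(num_layers=2)
cache.update(0, new_positions=[0, 1, 2], new_keys=["k0", "k1", "k2"])
cache.update(0, new_positions=[3], new_keys=["k3"])
print(len(cache.key_cache[0]))  # 4
```

Skipping a position (e.g. updating position 5 next) would trip the assertion, which is exactly the kind of check the comment above proposes, and exactly the per-step cost the reply below wants to avoid in the hot forward path.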
generation_config: "GenerationConfig",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
FYI, I recently added `stopping_criteria` as well, to support integration with the EleutherAI LM Eval Harness:
facebookresearch/LayerSkip@e38784d
@@ -887,7 +887,7 @@ def forward(
     all_self_attns = () if output_attentions else None
     next_decoder_cache = None

-    for decoder_layer in self.layers:
+    for decoder_layer in self.layers[: self.num_hidden_layers]:
Smart! I like that simple change that enables flexibility.
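The effect of that one-line diff can be sketched in isolation (the function and the toy "layers" are hypothetical stand-ins, not the modeling code): slicing the layer list caps how deep the forward pass goes, and `layers[:None]` falls back to running everything.

```python
def forward_layers(layers, hidden, num_hidden_layers=None):
    """Run only the first `num_hidden_layers` layers, mirroring the
    `self.layers[: self.num_hidden_layers]` slice in the diff above.
    None slices to the end, i.e. no early exit."""
    for layer in layers[:num_hidden_layers]:
        hidden = layer(hidden)
    return hidden

# Toy layers that each add 1, so the result counts how many layers ran.
layers = [lambda h: h + 1 for _ in range(8)]
print(forward_layers(layers, 0))                       # 8: all layers
print(forward_layers(layers, 0, num_hidden_layers=4))  # 4: early exit
```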
early_exit_outputs = model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20)
early_exit_decoded = tokenizer.batch_decode(early_exit_outputs, skip_special_tokens=True)
self.assertEqual(early_exit_decoded, [expected_output])
I suggest adding an assertion to ensure that a model pruned to `early_exit` layers produces output identical to the same model run with the `early_exit` arg in generation:

# Remove layers manually
del model.model.layers[4:]
model.num_hidden_layers = 4
manual_early_exit_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
manual_early_exit_decoded = tokenizer.batch_decode(manual_early_exit_outputs, skip_special_tokens=True)
self.assertEqual(early_exit_decoded, manual_early_exit_decoded)
I might have misunderstood the code: does `model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20)` perform static early exit, or self-speculative early-exit decoding?
Personally, I would suggest separating them somehow:
- Static early exit: `model.generate(**inputs, early_exit=4)`
- Self-speculative decoding with early exit: `model.generate(**inputs, assistant_model={"early_exit": 4})`

or something like that.
The interface is indeed confusing -- the demo above was meant to run self-speculative early-exit decoding.
I see two options:
1. `model.generate(**inputs, assistant_early_exit=4)` -- make the name of the argument more precise.
2. `model.generate(**inputs, assistant_model=model, early_exit=4)` -- with `assistant_model` set we know we are doing speculative decoding, so the use of `early_exit` becomes more self-evident.

I was thinking of going with option 2, since we could then do `model.generate(**inputs, early_exit=4)` to run static early exit. WDYT?
(btw, in the long run, we will move ALL assisted generation/speculative decoding args into an `assistant_kwargs` dictionary, otherwise things will get messy soon)
Hi Mostafa (@mostafaelhoushi) 👋 Glad to see you here!

My utmost goal for this PR is to get Layer Skip into the hands of our users with a) good throughput numbers and b) a simple interface. Self-speculative decoding is indeed the best of both worlds for low batch sizes 💪

I appreciate the extra suggestions, but they add significant complexity -- e.g. if we accept a callable for self-speculative decoding, we might want to apply the callable in different positions. Keeping things somewhat simple means others can just fork what we have and implement their idea quickly on top! It also makes our maintenance job doable 🤗 Naturally, if a given technique shows clear advantages and can be applied to pre-trained weights without too much complexity, like Layer Skip, we'll jump straight to implementation. (For instance, a few years ago we implemented a complex constrained decoding method, before JSON generation became popular. However, because the implementation was complex and it was somewhat niche, it quickly became unmaintained -- we got the additional code bloat with no relevant benefits.)

Sorry to be a turn-off -- I really appreciate the ideas coming in!
What does this PR do?
Test script: