Draft layer skip addition #34240

Draft · wants to merge 3 commits into main
Conversation

ArthurZucker (Collaborator) commented on Oct 18, 2024

What does this PR do?


Test script:

from transformers import AutoTokenizer, AutoModelForCausalLM
import time

expected_output = [""]

prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)  # warmup

start = time.time()
original_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
end = time.time()
print(f"Original: {end-start}")
print(f"Output text", tokenizer.batch_decode(original_outputs, skip_special_tokens=True))

start = time.time()
early_exit_outputs = model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20)
end = time.time()
print(f"Early Exit: {end-start}")
print(f"Early Exit text", tokenizer.batch_decode(early_exit_outputs, skip_special_tokens=True))

ArthurZucker and others added 2 commits October 18, 2024 11:31
* mvp

* docs and tests
gante (Member) commented on Oct 21, 2024

status: code runs, output is gibberish. Numerical debugging after lunch to figure out what's wrong

ArthurZucker changed the title from "Draft" to "Draft layer skip addition" on Oct 21, 2024
mostafaelhoushi (Contributor) commented on Oct 22, 2024

status: code runs, output is gibberish. Numerical debugging after lunch to figure out what's wrong

Can you show a sample of what the output looks like? It is expected to be of lower quality, but I am interested to see how gibberish it is.

Also, maybe try a later layer like layer 14 or 13 to see if it's still gibberish?

EDIT: Please ignore my comment above. I thought the output of early exit was gibberish, but I think it was the output of self-speculative decoding that was gibberish. Yes, self-speculative decoding should have the same quality as the last layer.

mostafaelhoushi (Contributor) left a comment

Hi. My name is Mostafa and I am one of the main authors of the LayerSkip paper!
Thanks for working on this PR so quickly! I have provided some comments.

Also, for the future, I have some suggestions to consider:

  • The early_exit arg in generation could become a callable function for researchers to experiment with dynamic early exit, i.e., a different condition or heuristic to exit for each token (e.g., cosine similarity between a layer's input and output above a certain threshold). This is done in papers like CALM; see the sketch at the end of this comment.
  • Adapter modules for early exit: rather than just exiting by jumping to the model's LM head, users may opt to add their own separate LM head or even their own adapter layers when exiting. This is done in papers like Kangaroo.
  • Different types of self-speculative decoding, e.g.,
    • Draft stage uses a subset of KV cache. This is done in MagicDec.

I am happy to discuss, online or offline, how we can add more features in this direction to enable researchers to unlock a lot of early-exit ideas.
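
For illustration, a rough sketch of the kind of callable exit condition described in the first bullet (the function name, signature, and threshold are hypothetical, not part of this PR or the transformers API):

import torch

# Hypothetical CALM-style dynamic exit condition: exit once a layer barely
# changes its input, measured by cosine similarity on the last token's hidden state.
def cosine_exit_condition(hidden_in: torch.Tensor, hidden_out: torch.Tensor, threshold: float = 0.99) -> bool:
    sim = torch.nn.functional.cosine_similarity(hidden_in[:, -1, :], hidden_out[:, -1, :], dim=-1)
    return bool((sim > threshold).all())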

Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).

#### Early Exit (Self-Speculative Decoding)
Contributor

I would like to mention that "Early Exit" is orthogonal to "Self-Speculative Decoding". I mean that there are other self-speculative decoding approaches that use methods other than early-exit, e.g.,

  • Draft & Verify: Draft stage is skipping intermediate FFNs and attention layers of the model
  • MagicDec: Draft stage is attending to a window of the KV cache rather than the full KV cache

Member

Good point! I'm going to rename it to self-speculative decoding, and mention that we support self-speculative decoding by specifying an exit layer

Comment on lines +453 to +454
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
Contributor

Here, torch.cat will only be correct if min(new_positions) == previous_length + 1, right? If that's correct, should we also add an assert statement for that?
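
For concreteness, roughly the check being suggested (a sketch only; it assumes the update call receives the new token positions, e.g. a hypothetical cache_position tensor, and that positions are 0-indexed):

# Hypothetical sanity check; cache_position is assumed to hold the 0-indexed
# positions of the incoming key/value states.
previous_length = self.key_cache[layer_idx].shape[-2]
assert int(cache_position.min()) == previous_length, (
    "torch.cat only appends correctly when the new states start right after the cached ones"
)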

Member

yes, that is correct!

I'm not going to add any check for now, though, and rely on internal tests to detect issues: adding a check here would hurt throughput in the forward pass, and a test can immediately detect issues :)

generation_config: "GenerationConfig",
model_kwargs: Dict,
inputs_tensor: Optional[torch.Tensor] = None,
logits_processor: "LogitsProcessorList" = None,
Contributor

FYI, I recently added stopping_criteria as well, to support integration with the EleutherAI LM Eval Harness:
facebookresearch/LayerSkip@e38784d

@@ -887,7 +887,7 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = None

- for decoder_layer in self.layers:
+ for decoder_layer in self.layers[: self.num_hidden_layers]:
Contributor

Smart! I like that simple change that enables flexibility.
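
For illustration, a minimal sketch of how the sliced loop could be exercised (where num_hidden_layers lives, on the inner decoder module vs. the config, is an assumption on my part, not confirmed by the diff):

# Illustrative only: with the sliced loop above, lowering the layer count makes
# the forward pass stop after the first N decoder layers.
model.model.num_hidden_layers = 4  # assumed attribute introduced by this PR
partial_outputs = model(**inputs)  # runs only the first 4 decoder layers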

early_exit_outputs = model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20)
early_exit_decoded = tokenizer.batch_decode(early_exit_outputs, skip_special_tokens=True)
self.assertEqual(early_exit_decoded, [expected_output])

Contributor

I suggest adding an assertion to check that a model pruned to early_exit layers produces output identical to the same model run with the early_exit arg in generation:

Suggested change
# Remove layers manually so the pruned model matches early_exit=4
del model.model.layers[4:]
model.num_hidden_layers = 4
manual_early_exit_outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
manual_early_exit_decoded = tokenizer.batch_decode(manual_early_exit_outputs, skip_special_tokens=True)
self.assertEqual(early_exit_decoded, manual_early_exit_decoded)

Contributor

I might have misunderstood the code: does model.generate(**inputs, early_exit=4, do_sample=False, max_new_tokens=20) perform static early exit, or does it perform self-speculative early-exit decoding?

Personally, I would suggest separating them somehow:

  • Static early exit: model.generate(**inputs, early_exit=4)
  • Self-speculative decoding, early exit: model.generate(**inputs, assistant_model={"early_exit": 4}) or something like that

Member

The interface is indeed confusing -- the demo above was meant to run self-speculative early-exit decoding.

I see two options:

  1. model.generate(**inputs, assistant_early_exit=4) -- make the name of the argument more precise
  2. model.generate(**inputs, assistant_model=model, early_exit=4) -- with assistant_model set we know we are doing speculative decoding, so the use of early_exit becomes more self-evident.

I was thinking of going with option 2, since we could then do model.generate(**inputs, early_exit=4) to run static early exit. WDYT?

(btw, in the long run, we will move ALL assisted generation/speculative decoding args into an assistant_kwargs dictionary, otherwise things will get messy soon)

gante (Member) commented on Oct 22, 2024

Hi Mostafa (@mostafaelhoushi) 👋 Glad to see you here!

My main goal for this PR is to get Layer Skip into the hands of our users with a) good throughput numbers and b) a simple interface. Self-speculative decoding is indeed the best of both worlds for low batch sizes 💪

I appreciate the extra suggestions, but they add significant complexity -- e.g. if we accept callable for self-speculative decoding, we might want to apply the callable in different positions. Keeping things somewhat simple means others can just fork what we have and implement their idea quickly on top! It also makes our maintenance job doable 🤗 Naturally, if a given technique shows clear advantages and can be applied on pre-trained weights without too much complexity, like layer skip, we'll jump straight to implementation.

(For instance, a few years ago we implemented a complex constrained decoding method, before json generation became popular. However, because the implementation was complex and it was somewhat niche, it quickly became unmaintained -- we got the additional code bloat with no relevant benefits)

Sorry to be a turn off -- I really appreciate the ideas coming in!
