
Adding support for lookahead decoding for autoregressive (decoder + encoder-decoder) models #27649

Open
shermansiu opened this issue Nov 22, 2023 · 9 comments
Labels
Feature request Request for a new feature

Comments

@shermansiu
Contributor

Feature request

Fu et al. propose a novel decoding technique that accelerates greedy decoding on Llama 2 and Code-Llama by 1.5-2x across various parameter sizes, without a draft model. This method can be extended to work with beam search decoding.

Blog post: https://lmsys.org/blog/2023-11-21-lookahead-decoding/
Code: https://github.com/hao-ai-lab/LookaheadDecoding
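
For intuition, here is a toy sketch (not the authors' implementation) of the verification step that lookahead decoding shares with speculative decoding: a candidate n-gram drawn from the n-gram pool is checked against the model's own greedy predictions in a single forward pass, and the longest matching prefix is accepted at once. The distinctive part of the method, generating those candidates from the model's own Jacobi-iteration window instead of a draft model, is omitted here for brevity, and the model name is only a placeholder.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-chat-hf"  # placeholder; any causal LM works for this toy example
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

@torch.no_grad()
def verify_ngram(input_ids: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    """Accept the longest prefix of `candidate` that matches the model's greedy choices."""
    full = torch.cat([input_ids, candidate.unsqueeze(0)], dim=-1)
    logits = model(full).logits                        # one forward pass over prompt + candidate
    prompt_len = input_ids.shape[-1]
    preds = logits[0, prompt_len - 1 : -1].argmax(-1)  # greedy prediction for each candidate slot
    n_accepted = int((preds == candidate).long().cumprod(-1).sum())
    # The greedy token right after the last accepted one is correct as well ("bonus" token).
    bonus = logits[0, prompt_len - 1 + n_accepted].argmax(-1, keepdim=True)
    return torch.cat([input_ids, candidate[:n_accepted].unsqueeze(0), bonus.unsqueeze(0)], dim=-1)

prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
guess = tokenizer(" Paris, the", add_special_tokens=False, return_tensors="pt").input_ids[0]
print(tokenizer.decode(verify_ngram(prompt_ids, guess)[0]))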

Motivation

Lookahead decoding provides a massive speedup at a worthwhile tradeoff (namely, a windowed n-gram cache and a custom attention mask). There have been proposals to integrate lookahead decoding into other libraries like TGI or vLLM, but it seems that this specific feature would be best integrated into the core transformers library, in the same way Flash Attention was.

Your contribution

I'm busy with thesis work, but I can submit a PR based on the original implementation here if I have time.

@ArthurZucker added the Feature request label Nov 22, 2023
@ArthurZucker
Collaborator

FYI @gante so we keep track of this. As shared offline, this might be good to tackle after the cache refactoring.

@shermansiu
Contributor Author

shermansiu commented Nov 22, 2023

The current reference implementation builds directly on top of Huggingface transformers, but the authors have mentioned that they plan to release a custom CUDA kernel to speed up the method.

Should we wait for this kernel? (My opinion: No, we shouldn't wait. Plus, I'm skeptical about whether such a kernel would be compatible with Flash Attention's own CUDA kernel, but we'll see.)

@shermansiu
Contributor Author

Cache refactoring PR: #26681

@shermansiu
Contributor Author

While we're waiting for the KV cache refactor to be completed, I think it might be worth considering how exactly to manage the lookahead decoding configuration, especially since there are a few parameters associated with it (e.g. the lookahead window size and the n-gram size).

I suppose it would be better to introduce a LookaheadDecoderConfig dataclass for this?
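
For illustration only, such a dataclass might look like the following; the class name and field names are hypothetical and simply mirror the LEVEL / WINDOW_SIZE / GUESS_SET_SIZE knobs exposed by the reference implementation.

from dataclasses import dataclass

@dataclass
class LookaheadDecodingConfig:
    level: int = 4           # hypothetical: roughly the n-gram size N used for lookahead
    window_size: int = 8     # hypothetical: W, the number of positions looked ahead per step
    guess_set_size: int = 8  # hypothetical: G, the number of candidate n-grams verified per step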

@ArthurZucker
Collaborator

No, I think these can just be passed in the generation config.
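
For example (purely hypothetical keyword arguments, not part of the current generate() API), usage could then look like this, with the values forwarded through the generation config rather than a dedicated dataclass:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
input_ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids

output = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=False,
    lookahead_level=4,           # hypothetical parameter name
    lookahead_window_size=8,     # hypothetical parameter name
    lookahead_guess_set_size=8,  # hypothetical parameter name
)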

@gante
Member

gante commented Nov 23, 2023

Hi @shermansiu 👋

Before commenting here, I spent some time playing with lookahead decoding, in particular using a modified version of their minimal.py so I could benchmark against datasets. I'm pasting an example in the collapsible below:

LADE test script
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import time
import torch
import os
if int(os.environ.get("LOAD_LADE", 0)):
  import lade
  lade.augment_all()
  # lade.config_lade(LEVEL=7, WINDOW_SIZE=20, GUESS_SET_SIZE=20, DEBUG=1)
  lade.config_lade(LEVEL=4, WINDOW_SIZE=8, GUESS_SET_SIZE=8, DEBUG=1)

assert torch.cuda.is_available()

num_samples = 20
device = "cuda:0"
model_name = "meta-llama/Llama-2-7b-chat-hf"
# model_name = "TheBloke/Llama-2-7B-Chat-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, use_flash_attention_2=True)
model.tokenizer = tokenizer

ds = load_dataset("cnn_dailymail", "3.0.0", split="validation", streaming=True)
ds_iterator = iter(ds.take(num_samples))

torch.cuda.reset_peak_memory_stats("cuda")
torch.cuda.empty_cache()
torch.cuda.synchronize()

# warm up
greedy_output = model.generate(torch.ones((1, 10), dtype=torch.long, device=device), max_new_tokens=1)
# end warm up

elapsed_time = 0
generated_tokens = 0
for _ in range(num_samples):
  chat = [
      {"role": "system", "content": "You are a helpful model that summarizes a given article."},
      {"role": "user", "content": next(ds_iterator)["article"]}
  ]

  input_ids = tokenizer.apply_chat_template(chat, return_tensors='pt').to(device)
  start = time.time()
  greedy_output = model.generate(input_ids, max_new_tokens=2048, do_sample=False)
  end = time.time()

  generated_tokens += greedy_output.numel() - input_ids.numel()
  elapsed_time += end - start

max_memory = torch.cuda.max_memory_allocated("cuda")
print("\nMax memory (MB): ", max_memory * 1e-6)
print("AVG Generated Tokens: ", (generated_tokens / num_samples))
print("AVG Generation Speed: ", (generated_tokens / ellapsed_time), " tokens/s")

Here are some findings:
👉 As mentioned in the blog post, you are trading extra FLOPs for additional LLM throughput. All is well if the model is small for your device, but it's hard to achieve speedups with modest models on consumer GPUs (e.g. 7B models on a 3090)
👉 After some fiddling with the LADE parameters, I was able to get a 25% speedup on a 7B model on a 3090, compared to the model without FA2. Running with their default parameterization actually slows the model down by 33%, despite achieving a high compression ratio (i.e. FLOPs are the bottleneck)
👉 Doesn't work correctly with FA2: the output is significantly different
👉 Works with BNB, but I only saw slowdowns on my setup, no speedups
👉 Works with AWQ, with the same findings as in the unquantized case

On top of that, from the blog post we know that:
👉 It requires changes in the modeling code of each model, so it will require a lot of work to add and to maintain
👉 It is limited to greedy decoding, meaning that it doesn't support the most common use case (do_sample=True)
👉 Batching with this technique is much trickier -- just like in speculative decoding/assisted generation, we may have more than one accepted token per forward pass (see the toy sketch after this list)
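
To make the last point concrete, here is a tiny, hypothetical illustration (plain PyTorch, not transformers code) of the ragged-batch problem: each row may accept a different number of tokens per step, so the batch has to be re-padded, and the attention mask / KV-cache positions have to track that padding.

import torch

pad_id = 0
accepted = [                     # tokens accepted in one step, per sequence
    torch.tensor([11, 12, 13]),  # sequence 0 accepted 3 tokens
    torch.tensor([21]),          # sequence 1 accepted only 1 token
]
step_len = max(t.numel() for t in accepted)
step_ids = torch.stack(
    [torch.cat([t, torch.full((step_len - t.numel(),), pad_id)]) for t in accepted]
)
print(step_ids)  # tensor([[11, 12, 13], [21,  0,  0]])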


The idea does look very promising -- it would be amazing to be able to speed up a model without relying on external models. However, the current benefits are limited to GPU-rich users running a GPU that is oversized for the task at hand, and the cost of adding it would be heavy, especially given the model-level changes. The original code is also open-source and transformers-compatible, albeit limited to Llama.

If a model-independent solution can be achieved, more positive benchmarks are found, or upgrades to the technique are released, I'd be happy to reconsider this decision!

Let's keep this issue open for discussion 🤗

@shermansiu
Contributor Author

^ Some of the acronyms in the above response:
LADE = Lookahead Decoding
FA2 = Flash Attention 2
BNB = bitsandbytes
AWQ = Activation-aware Weight Quantization

@shermansiu
Contributor Author

The authors mentioned that they are working on an FA2-compatible CUDA kernel, so hopefully we'll see better results soon!

@knagrecha

BTW, here's a PR where we are looking at adding sampling support.

hao-ai-lab/LookaheadDecoding#6
