Kaixuan Huang, Xudong Guo, Mengdi Wang
Princeton University
ICML 2024 Workshop on Efficient Systems for Foundation Models (ES-FoMo)
We propose SpecDec++, an enhanced version of speculative decoding that adaptively determines the candidate length with the help of a trained acceptance prediction head. Our method boosts the performance of speculative decoding and can be combined with other techniques such as fused kernels, quantization, and advanced KV cache management.
*Tested with llama-2-chat 7B & 70B model pair (bfloat16) on 2 NVIDIA A100-80G GPUs.
- Quick Links
- Overview of Speculative Decoding
- Problem: Determination of the candidate length $K$.
- Using SpecDec++
- Training and Evaluation
In speculative decoding, the draft model first generates a sequence of candidate tokens, which the target model then verifies in parallel.
Following the first rejected token, the algorithm discards the remaining tokens and corrects the rejected token with a fresh sample from a modified distribution.
If all tokens are accepted, a new token is sampled from the next-token probability given by the target model and appended to the sequence of accepted tokens, and then the process moves forward.
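For concreteness, a minimal sketch of this verification rule, assuming `p_target` and `q_draft` are the target and draft next-token distributions (1-D probability tensors) and `x` is a candidate token id; the names are illustrative, not the repository's actual API:

```python
import torch

def verify_token(x: int, p_target: torch.Tensor, q_draft: torch.Tensor):
    # Accept the candidate token with probability min(1, p(x) / q(x)).
    ratio = (p_target[x] / q_draft[x]).item()
    if torch.rand(()).item() < min(1.0, ratio):
        return True, None  # candidate accepted
    # Candidate rejected: draw a fresh sample from the modified (residual)
    # distribution max(p - q, 0), renormalized.
    residual = torch.clamp(p_target - q_draft, min=0.0)
    residual = residual / residual.sum()
    return False, torch.multinomial(residual, 1).item()
```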
SpecDec++ aims to find a theoretically justified answer to the following question: what candidate length generates as many accepted tokens, and wastes as few discarded tokens, as possible?
We formalize the dynamic choice of candidate length in speculative decoding as a Markov Decision Process (MDP). We theoretically show that when the probability that at least one token gets rejected exceeds a threshold, the optimal action is to stop the speculation and submit the candidate tokens for verification.
We augment the draft model with a trained acceptance prediction head to predict the conditional acceptance probability of the candidate tokens. SpecDec++ stops the current speculation round when the predicted probability that at least one token gets rejected exceeds a threshold.
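In code, this stopping rule reduces to a product of the predicted per-token acceptance probabilities; a minimal sketch (tensor and argument names are illustrative):

```python
import torch

def should_stop(accept_probs: torch.Tensor, stop_threshold: float) -> bool:
    # accept_probs: predicted acceptance probabilities of the candidate
    # tokens drafted so far in this speculation round
    p_all_accepted = torch.prod(accept_probs).item()  # P(no rejection)
    return 1.0 - p_all_accepted > stop_threshold      # P(>=1 rejection) too high
```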
SpecDec++ achieves better Pareto frontiers than SpecDec on the in-distribution dataset Alpaca as well as the two out-of-distribution datasets HumanEval and GSM8K. Please check our paper for more details.
Step 0 (optional): Prepare a conda environment with PyTorch installed. If you do not have one, you can create it with the following commands.
conda create -n specdecpp python=3.11
conda activate specdecpp
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
Step 1: Clone the repository and install the required packages.
git clone git@github.com:Kaffaljidhmah2/SpecDec_pp.git
cd SpecDec_pp
pip install -r requirements.txt
The checkpoint of our best acceptance prediction head for the llama-2-chat 7B & 70B model pair is available on the Hugging Face Hub.
Please take a look at specdec_pp/sample.py to see how to use SpecDec++.
Follow the instructions in data/readme.md for dataset preparation. After running the code, you should obtain the Alpaca dataset (data/alpaca_data/train.json, data/alpaca_data/dev.json, data/alpaca_data/test.json), the HumanEval dataset (data/humaneval_data/test.json), and the GSM8K test dataset (data/gsm8k_test_data/test.json) for the llama-2-chat models.
Please modify the following code for training. Here `layer` is the number of layers of the ResNet prediction head, and `weight` is the BCE loss weight for the mismatched tokens (the weight for the matched tokens is 1). The mixing ratio can be set via `--mixing_ratio` (default: 0.15).
layer=3
weight=6
draft_model=meta-llama/Llama-2-7b-chat-hf
WANDB_PROJECT=specdecpp python3 specdec_pp/train.py \
--data_path data/alpaca_data/train.json \
--eval_data_path data/alpaca_data/dev.json \
--output_dir exp-weight${weight}-layer${layer} \
--model_name_or_path ${draft_model} \
--bf16 True \
--per_device_train_batch_size 4 \
--num_train_epochs 3 \
--gradient_accumulation_steps 8 \
--logging_steps 5 \
--evaluation_strategy epoch \
--per_device_eval_batch_size 4 \
--weight_mismatch ${weight} \
--save_strategy no \
--warmup_ratio 0.03 \
--lr_scheduler_type cosine \
--resnet_num_layers ${layer} \
--mixing_ratio 0.15
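For reference, a minimal sketch of the weighted BCE objective described above; tensor names are illustrative and may differ from specdec_pp/train.py:

```python
import torch
import torch.nn.functional as F

def acceptance_head_loss(logits: torch.Tensor, labels: torch.Tensor,
                         weight_mismatch: float = 6.0) -> torch.Tensor:
    # labels: 1.0 for matched (accepted) tokens, 0.0 for mismatched ones.
    # Mismatched tokens are up-weighted by `weight_mismatch` (the --weight_mismatch flag);
    # matched tokens keep weight 1.
    weights = torch.where(labels > 0.5,
                          torch.ones_like(labels),
                          torch.full_like(labels, weight_mismatch))
    return F.binary_cross_entropy_with_logits(logits, labels, weight=weights)
```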
Note: `--num_assistant_tokens_schedule ada` selects the proposed SpecDec++ method, where the checkpoint of the acceptance prediction head should be specified via `--assist_acc_head_dir`. `--stop_threshold` sets the threshold value (between 0 and 1) used to stop the current speculation round; a larger `stop_threshold` leads to longer speculation rounds. `--bound MIN MAX` sets the minimum and maximum numbers of candidate tokens for one speculation round.
layer=3
weight=6
thres=0.3
ckpt=exp-weight${weight}-layer${layer}
target_model=meta-llama/Llama-2-70b-chat-hf
draft_model=meta-llama/Llama-2-7b-chat-hf
data=data/alpaca_data/test.json
SAVEPATH=test-results-alpaca/weight${weight}-layer${layer}-thres${thres}-bound2_20/
python3 specdec_pp/evaluate.py \
--model_name ${target_model} \
--assistant_name ${draft_model} \
--num_assistant_tokens_schedule ada \
--data_path ${data} \
--assist_acc_head_dir ${ckpt} \
--do_sample \
--random_seed 42 \
--save_path ${SAVEPATH} \
--stop_threshold ${thres} \
--bound 2 20
The results will be stored under the folder `${SAVEPATH}`.
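Putting the pieces together, one speculation round under the ada schedule behaves roughly as follows; this is a sketch mirroring the CLI flags above, where `draft_step` and `predict_accept_prob` are hypothetical callables, not the repository's actual API:

```python
def speculation_round(draft_step, predict_accept_prob,
                      min_tokens: int = 2, max_tokens: int = 20,
                      stop_threshold: float = 0.3):
    candidates, p_all_accepted = [], 1.0
    while len(candidates) < max_tokens:
        token = draft_step()                          # next draft token
        candidates.append(token)
        p_all_accepted *= predict_accept_prob(token)  # from the prediction head
        # Stop once P(at least one rejection) exceeds stop_threshold,
        # while respecting the minimum bound on candidate tokens.
        if len(candidates) >= min_tokens and 1.0 - p_all_accepted > stop_threshold:
            break
    return candidates  # submitted to the target model for verification
```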
Note: `--num_assistant_tokens_schedule constant` selects the baseline SpecDec method. `--num_assistant_tokens` sets the constant number of candidate tokens generated per speculation round.
target_model=meta-llama/Llama-2-70b-chat-hf
draft_model=meta-llama/Llama-2-7b-chat-hf
K=4
data=data/alpaca_data/test.json
SAVEPATH=test-results-alpaca/baseline-${K}/
python3 specdec_pp/evaluate.py \
--model_name ${target_model} \
--assistant_name ${draft_model} \
--num_assistant_tokens_schedule constant \
--num_assistant_tokens ${K} \
--data_path ${data} \
--do_sample \
--random_seed 42 \
--save_path ${SAVEPATH}
Note: `--num_assistant_tokens_schedule none` runs the target model (and the draft model) standalone, without speculative decoding.
target_model=meta-llama/Llama-2-70b-chat-hf
draft_model=meta-llama/Llama-2-7b-chat-hf
data=data/alpaca_data/test.json
SAVEPATH=test-results-alpaca/standalone/
python3 specdec_pp/evaluate.py \
--model_name ${target_model} \
--assistant_name ${draft_model} \
--num_assistant_tokens_schedule none \
--data_path ${data} \
--do_sample \
--random_seed 42 \
--save_path ${SAVEPATH}
An example entry of the saved results:
[
{
## key-value pairs for prompt, continuation, prefix, tokens, draft, p_acc, and id
## for SpecDec & SpecDec++
"spec_time": 15.580421447753906,
"num_mismatched_tokens": 20,
"num_LM_call": 67,
"generated_length": 180,
## for standalone target model / draft model
"target_time": 25.6504251956939,
"draft_time": 2.795105218887329,
"generated_length_target": 203,
"generated_length_draft": 134
}
]
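For instance, a quick way to compare decoding speeds from these entries; the file names and paths below are assumptions, so adjust them to your actual `${SAVEPATH}` contents:

```python
import json

# Field names follow the example entry above; file paths are illustrative.
with open("test-results-alpaca/weight6-layer3-thres0.3-bound2_20/results.json") as f:
    spec = json.load(f)
with open("test-results-alpaca/standalone/results.json") as f:
    base = json.load(f)

spec_tps = sum(r["generated_length"] for r in spec) / sum(r["spec_time"] for r in spec)
base_tps = sum(r["generated_length_target"] for r in base) / sum(r["target_time"] for r in base)
print(f"tokens/s: SpecDec++ {spec_tps:.1f}, target-only {base_tps:.1f} "
      f"({spec_tps / base_tps:.2f}x speedup)")
```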
Feel free to send an email to kaixuanh@princeton.edu or create a GitHub issue/pull request.
If you find this useful in your research, please consider citing our paper.
@article{huang2024specdec++,
title={SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths},
author={Huang, Kaixuan and Guo, Xudong and Wang, Mengdi},
journal={arXiv preprint arXiv:2405.19715},
year={2024}
}