Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich
Jeff Pool, Jan Kautz, Pavlo Molchanov, Xinchao Wang
NVIDIA Research, National University of Singapore
[ArXiv] | [Project Page] | [License]
This work introduces MaskLLM, a learnable pruning method that establishes semi-structured (or "N:M") sparsity in LLMs to reduce computational overhead during inference.
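N:M sparsity means that at most N of every M consecutive weights are non-zero. For the 2:4 pattern used throughout this README, that works out to 50% sparsity. The minimal sketch below checks this constraint on a single weight row. The function name `satisfies_n_m` is hypothetical and not part of this repository. It assumes the groups run along the input dimension of each row.

```python
from typing import Sequence


def satisfies_n_m(row: Sequence[float], n: int = 2, m: int = 4) -> bool:
    """Return True if every consecutive group of m weights has at most n non-zeros."""
    if len(row) % m != 0:
        raise ValueError("row length must be a multiple of m")
    for start in range(0, len(row), m):
        group = row[start:start + m]
        if sum(1 for w in group if w != 0) > n:
            return False
    return True


# 2:4 pattern: each group of 4 weights keeps at most 2 non-zeros.
print(satisfies_n_m([0.3, 0.0, -1.2, 0.0, 0.0, 0.7, 0.0, 0.1]))  # both groups have 2 non-zeros
print(satisfies_n_m([0.3, 0.5, -1.2, 0.0, 0.0, 0.7, 0.0, 0.1]))  # first group has 3 non-zeros
```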
The following section provides an example for MaskLLM-LLaMA-2/3 on a single node with 8 GPUs. The LLaMA model is sharded across the 8 GPUs with tensor parallelism, which requires about 40 GB of memory per GPU for end-to-end training.
Docker is required for Megatron-LM. We use the official PyTorch Docker image pytorch:24.01-py3 from NVIDIA NGC as the base image. If you cannot use Docker, please refer to the official setup instructions of Megatron-LM. Run the following command to download and start the Docker container with your home directory mounted.
docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v $HOME:$HOME -it --rm nvcr.io/nvidia/pytorch:24.01-py3
In the container, we need to download the LLaMA checkpoints and convert them to Megatron format.
Install basic dependencies.
pip install transformers accelerate datasets SentencePiece wandb tqdm ninja tensorboardx==2.6 pulp timm einops
The following scripts download the Hugging Face (HF) checkpoints and save them to ./assets/checkpoints.
python scripts/tools/download_llama2_7b_hf.py
python scripts/tools/download_llama2_13b_hf.py
python scripts/tools/download_llama3_8b_hf.py
assets
└── checkpoints
    ├── llama2_13b_hf
    ├── llama2_7b_hf
    └── llama3_8b_hf
Convert the downloaded HF checkpoints to Megatron format with a tensor-parallel size of 8 (tp=8).
bash scripts/tools/convert_llama2_7b_hf_to_megatron.sh
bash scripts/tools/convert_llama2_13b_hf_to_megatron.sh
bash scripts/tools/convert_llama3_8b_hf_to_megatron.sh
assets/
└── checkpoints
    ├── llama2_13b_hf
    ├── llama2_13b_megatron_tp8  # <= Megatron format
    ├── llama2_7b_hf
    ├── llama2_7b_megatron_tp8
    ├── llama3_8b_hf
    └── llama3_8b_megatron_tp8
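The `_tp8` suffix indicates that each layer's weights are split into 8 shards, one per GPU. The toy sketch below illustrates column-parallel splitting in the style of Megatron-LM's column-parallel linear layers. It is not this repository's conversion code, and `shard_rows` and `matvec` are hypothetical helpers. Each rank holds a contiguous block of output rows, and concatenating the per-rank outputs reproduces the dense result.

```python
from typing import List

Matrix = List[List[float]]


def shard_rows(weight: Matrix, tp: int) -> List[Matrix]:
    """Split an (out_features x in_features) weight into tp contiguous row blocks."""
    out_features = len(weight)
    if out_features % tp != 0:
        raise ValueError("out_features must be divisible by tp")
    size = out_features // tp
    return [weight[r * size:(r + 1) * size] for r in range(tp)]


def matvec(weight: Matrix, x: List[float]) -> List[float]:
    """Dense y = W x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


# Toy 16x4 weight split across tp=8 "ranks" (2 output rows per rank).
weight = [[float(i * 4 + j) for j in range(4)] for i in range(16)]
x = [1.0, 0.5, -1.0, 2.0]
shards = shard_rows(weight, tp=8)

# Each rank computes its slice of the output; concatenating the slices gives the full output.
partial = [y for shard in shards for y in matvec(shard, x)]
print(partial == matvec(weight, x))
```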
Evaluate the dense models by passing the checkpoint path followed by the model size (7b, 8b, or 13b), the tensor-parallel size (8), and the sparsity mode (dense or sparse).
bash scripts/ppl/evaluate_llama2_wikitext2.sh assets/checkpoints/llama2_7b_megatron_tp8 7b 8 dense
bash scripts/ppl/evaluate_llama2_wikitext2.sh assets/checkpoints/llama2_13b_megatron_tp8 13b 8 dense
bash scripts/ppl/evaluate_llama3_wikitext2.sh assets/checkpoints/llama3_8b_megatron_tp8 8b 8 dense
# Outputs for LLaMA-2 7B:
validation results on WIKITEXT2 | avg loss: 1.6323E+00 | ppl: 5.1155E+00 | adjusted ppl: 5.1155E+00 | token ratio: 1.0 |
# Outputs for LLaMA-2 13B:
validation results on WIKITEXT2 | avg loss: 1.5202E+00 | ppl: 4.5730E+00 | adjusted ppl: 4.5730E+00 | token ratio: 1.0 |
# Outputs for LLaMA-3 8B:
validation results on WIKITEXT2 | avg loss: 1.7512E+00 | ppl: 5.7615E+00 | adjusted ppl: 5.7615E+00 | token ratio: 1.0 |
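The reported perplexity is the exponential of the average per-token cross-entropy loss. The sketch below reproduces the logged values from the losses, up to rounding. It assumes that "adjusted ppl" rescales the loss by the token ratio and that the exponent is capped at 20, which matches our understanding of Megatron-LM's zero-shot evaluation. With a token ratio of 1.0, the two numbers coincide, as in the logs. The `perplexity` helper is hypothetical.

```python
import math


def perplexity(avg_loss: float, token_ratio: float = 1.0) -> tuple:
    """Return (ppl, adjusted_ppl) from an average natural-log cross-entropy loss."""
    # Assumption: the exponent is capped at 20 to avoid overflow, as in Megatron-LM's zero-shot eval.
    ppl = math.exp(min(20.0, avg_loss))
    adjusted_ppl = math.exp(min(20.0, avg_loss * token_ratio))
    return ppl, adjusted_ppl


for name, loss in [("LLaMA-2 7B", 1.6323), ("LLaMA-2 13B", 1.5202), ("LLaMA-3 8B", 1.7512)]:
    ppl, adjusted = perplexity(loss)
    print(f"{name}: ppl = {ppl:.4f}, adjusted ppl = {adjusted:.4f}")
```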
Our paper uses blended internal data for training. For reproducibility, we provide an example of learning masks on a subset of the public allenai/c4 dataset; the corresponding results are reported in Appendix D of our paper. See docs/preprocess_c4.md for the preprocessing instructions.
We recommend starting training from a prior mask generated by SparseGPT, Wanda, or magnitude pruning. The following scripts prune a LLaMA-2 7B model with the 2:4 pattern. For SparseGPT, weight updates are disabled by default; add the --update-weight argument if needed. More scripts for LLaMA-2 13B and LLaMA-3 8B are available in scripts/oneshot.
# <= SparseGPT mask
bash scripts/oneshot/run_llama2_7b_prune_tp8.sh hessian # --update-weight
# <= Magnitude mask
bash scripts/oneshot/run_llama2_7b_prune_tp8.sh magnitude # --update-weight
# <= Wanda mask
bash scripts/oneshot/run_llama2_7b_prune_tp8.sh wanda # --update-weight
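Magnitude and Wanda priors both score each weight and keep the 2 highest-scoring weights in every group of 4 inputs. Magnitude pruning uses |W_ij|. Wanda uses |W_ij| multiplied by the L2 norm of input feature j over calibration activations. The minimal sketch below shows only these two scoring rules; SparseGPT's Hessian-based selection is omitted. All function names are hypothetical, and the repository's scripts/oneshot implementation may differ, for example in how activation norms are collected.

```python
from typing import List

Matrix = List[List[float]]


def nm_mask_from_scores(scores: Matrix, n: int = 2, m: int = 4) -> List[List[int]]:
    """Keep the n highest-scoring weights in each group of m consecutive inputs, per output row."""
    mask = []
    for row in scores:
        if len(row) % m != 0:
            raise ValueError("row length must be a multiple of m")
        row_mask = [0] * len(row)
        for start in range(0, len(row), m):
            keep = sorted(range(start, start + m), key=lambda i: row[i], reverse=True)[:n]
            for i in keep:
                row_mask[i] = 1
        mask.append(row_mask)
    return mask


def magnitude_scores(weight: Matrix) -> Matrix:
    return [[abs(w) for w in row] for row in weight]


def wanda_scores(weight: Matrix, act_norms: List[float]) -> Matrix:
    # act_norms[j]: L2 norm of input feature j over calibration tokens (assumed precomputed).
    return [[abs(w) * act_norms[j] for j, w in enumerate(row)] for row in weight]


w = [[0.5, -0.1, 0.3, -0.8, 0.2, 0.9, -0.4, 0.05]]
norms = [1.0, 10.0, 1.0, 0.1, 1.0, 1.0, 1.0, 20.0]
print(nm_mask_from_scores(magnitude_scores(w)))
print(nm_mask_from_scores(wanda_scores(w, norms)))  # large activations change which weights survive
```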
The pruned LLaMA model contains additional .mask parameters in its sparse linear layers, such as module.language_model.encoder.layers.31.mlp.dense_h_to_4h.mask.
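The minimal sketch below illustrates how a binary mask affects a linear layer. It assumes that each .mask parameter has the same shape as its weight and is multiplied element-wise with the weight before the matrix product, so pruned weights contribute nothing. This is an illustration of the idea, not the repository's layer code, and `masked_linear` is a hypothetical name.

```python
from typing import List


def masked_linear(x: List[float], weight: List[List[float]], mask: List[List[int]]) -> List[float]:
    """Compute y = (W * M) x, where * is element-wise and M is a binary 2:4 mask."""
    return [
        sum(w * m * xi for w, m, xi in zip(w_row, m_row, x))
        for w_row, m_row in zip(weight, mask)
    ]


weight = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 1.0]]
mask = [[0, 1, 0, 1], [1, 0, 1, 0]]  # each row keeps 2 of 4 weights
print(masked_linear([1.0, 1.0, 1.0, 1.0], weight, mask))
```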
output/
└── oneshot_pruning
    ├── checkpoint
    │   ├── llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0
    │   ├── llama2-7b-tp8.sparse.nmprune.sp0.5magnitude.ex0
    │   └── llama2-7b-tp8.sparse.nmprune.sp0.5wanda.ex0
    ├── llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0.log
    ├── llama2-7b-tp8.sparse.nmprune.sp0.5magnitude.ex0.log
    └── llama2-7b-tp8.sparse.nmprune.sp0.5wanda.ex0.log
To evaluate the pruned model:
bash scripts/ppl/evaluate_llama2_wikitext2.sh output/oneshot_pruning/checkpoint/llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0 7b 8 sparse
(Figure: Mask Sampling | Visualization)
By default, the script loads the SparseGPT prior; modify the path in the script to load other masks. The script takes one argument: 0 starts the initial training from the prior mask, and 1 resumes training from the latest checkpoint.
# Initial training with a prior mask.
# By default, the script will load output/oneshot_pruning/checkpoint/llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0 as the mask prior
bash scripts/learnable_sparsity/llama2_7b_mask_only_tp8_c4.sh 0
# Pass the argument 1 to continue the training from the latest checkpoint
bash scripts/learnable_sparsity/llama2_7b_mask_only_tp8_c4.sh 1
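As described in the MaskLLM paper, each group of 4 weights chooses one of the C(4,2) = 6 candidate 2:4 masks. That choice is modeled as a learnable categorical distribution and sampled with Gumbel Softmax, so the mask can be trained end to end. The simplified sketch below shows one group. The names are hypothetical. Details of the real training code, such as temperature scheduling and how the prior mask initializes the logits, are omitted.

```python
import itertools
import math
import random
from typing import List

# All 2:4 candidate masks: choose 2 kept positions out of 4 -> 6 candidates.
CANDIDATES: List[List[int]] = [
    [1 if i in keep else 0 for i in range(4)] for keep in itertools.combinations(range(4), 2)
]


def gumbel_softmax(logits: List[float], tau: float, rng: random.Random) -> List[float]:
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    # Gumbel(0, 1) noise: -log(-log(U)), with U clamped away from 0.
    gumbels = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in logits]
    scores = [(l + g) / tau for l, g in zip(logits, gumbels)]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_mask(logits: List[float], tau: float, rng: random.Random) -> List[float]:
    """Probability-weighted mixture of candidate masks for one group of 4 weights."""
    probs = gumbel_softmax(logits, tau, rng)
    return [sum(p * cand[i] for p, cand in zip(probs, CANDIDATES)) for i in range(4)]


rng = random.Random(0)
logits = [0.0, 0.0, 3.0, 0.0, 0.0, 0.0]  # learnable; here candidate 2 is favored
print(soft_mask(logits, tau=0.5, rng=rng))  # entries in [0, 1] that sum to 2
```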
For inference, only the winner masks (those with the highest probability) are needed. The following command trims the checkpoint by removing unnecessary components.
python tool_trim_learnable_sparsity.py --ckpt_dir output/checkpoints/llama2-7b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/iter_0002000
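Trimming reduces each learned distribution to a single hard mask. The minimal sketch below shows the idea for one group of 4 weights. It selects the candidate with the largest logit, which is also the one with the largest softmax probability. The names are hypothetical, and this is not the implementation of tool_trim_learnable_sparsity.py.

```python
import itertools
from typing import List

# The 6 candidate 2:4 masks, in itertools.combinations order.
CANDIDATES: List[List[int]] = [
    [1 if i in keep else 0 for i in range(4)] for keep in itertools.combinations(range(4), 2)
]


def winner_mask(logits: List[float]) -> List[int]:
    """Return the candidate mask with the highest logit (argmax of the softmax probabilities)."""
    best = max(range(len(logits)), key=lambda k: logits[k])
    return CANDIDATES[best]


print(winner_mask([0.1, -0.3, 2.5, 0.0, 0.7, 0.2]))  # candidate 2 keeps positions 0 and 3
```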
To load the trimmed checkpoint, change the content of latest_checkpointed_iteration.txt to release. This yields a clean checkpoint with an additional .mask parameter for each sparse layer.
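The tracker file can be edited by hand or with a small helper like the one below. This is a sketch that assumes the tracker file sits in the checkpoint directory passed to the evaluation scripts (for example, .../train_iters_2000/ckpt/). In Megatron-LM, the value release tells the loader to read a release checkpoint instead of a numbered iteration. `mark_release` is a hypothetical name.

```python
from pathlib import Path


def mark_release(ckpt_dir: str) -> Path:
    """Overwrite Megatron-LM's tracker file so the checkpoint loads as a 'release' checkpoint."""
    tracker = Path(ckpt_dir) / "latest_checkpointed_iteration.txt"
    tracker.write_text("release")
    return tracker


# Example (path taken from the evaluation commands in this README):
# mark_release("output/checkpoints/llama2-7b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt")
```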
# For llama2 7b & 13b
bash scripts/ppl/evaluate_llama2_wikitext2.sh output/checkpoints/llama2-7b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/ 7b 8 sparse
bash scripts/ppl/evaluate_llama2_wikitext2.sh output/checkpoints/llama2-13b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/ 13b 8 sparse
# For llama3 8b
bash scripts/ppl/evaluate_llama3_wikitext2.sh output/checkpoints/llama3-8b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/ 8b 8 sparse
Please see docs/export_hf.md for instructions on exporting sparse models to Hugging Face.
@article{fang2024maskllm,
  title={MaskLLM: Learnable Semi-structured Sparsity for Large Language Models},
  author={Fang, Gongfan and Yin, Hongxu and Muralidharan, Saurav and Heinrich, Greg and Pool, Jeff and Kautz, Jan and Molchanov, Pavlo and Wang, Xinchao},
  journal={Advances in Neural Information Processing Systems},
  year={2024}
}