perf : study batched decoding bottleneck #3726
Closed
ref #3479
Description
I wanted to find out what is currently `llama.cpp`'s main limitation when running batched decoding, so I ran a few tests on different hardware to profile mainly the self-attention overhead when using the existing unified KV cache implementation (#3228). Below are the results on 3 different hardware configurations.
I'm using the `batched-bench` tool to run PP + TG for different numbers of batches.
This PR adds a hack to allow for conveniently turning on and off some of the attention ops via environment variables:

- `SKIP_KQ_KQV=1` to skip the 2 matrix multiplications `KQ` and `KQV`
- `SKIP_KQ_ALL=1` to skip all attention ops (`KQ`, `KQ_scaled`, `KQ_masked`, `KQ_soft_max`, `KQV`)
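For context, a minimal sketch of how such an environment-variable toggle can work - illustrative only, not necessarily how this PR implements it; the `env_on`/`skip_attn_op` helpers, the exact node names and the dispatch-time check are assumptions:

```cpp
// Illustrative sketch, not the actual patch from this PR: decide at dispatch time
// whether a graph node should be skipped, based on its name and the
// SKIP_KQ_KQV / SKIP_KQ_ALL environment variables. Skipped nodes simply leave
// garbage in their output, which is fine for timing experiments.
#include <cstdlib>
#include <string>

static bool env_on(const char * name) {
    const char * v = std::getenv(name);
    return v != nullptr && std::string(v) != "0";
}

// node_name is the ggml tensor name, e.g. "KQ", "KQ_scaled", "KQ_masked",
// "KQ_soft_max" or "KQV" (exact names assumed here for illustration).
static bool skip_attn_op(const std::string & node_name) {
    static const bool skip_kq_kqv = env_on("SKIP_KQ_KQV"); // skip only the KQ and KQV mat muls
    static const bool skip_kq_all = env_on("SKIP_KQ_ALL"); // skip all attention ops

    const bool is_kq_kqv = node_name == "KQ" || node_name == "KQV";
    const bool is_attn   = is_kq_kqv || node_name == "KQ_scaled" ||
                           node_name == "KQ_masked" || node_name == "KQ_soft_max";

    return (skip_kq_all && is_attn) || (skip_kq_kqv && is_kq_kqv);
}

// usage inside a backend's graph-compute loop (pseudocode):
//   if (skip_attn_op(ggml_get_name(node))) { continue; }
```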
I've also tested 2 custom diffs for Metal and CUDA that run the full computation but force only 1 KV head to be computed during the matrix multiplications:
- CUDA diff to force 1 KV head
- Metal diff to force 1 KV head
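Both diffs follow the same idea: keep the computation graph unchanged, but clamp the head loop inside the attention matrix multiplications to a single KV head. A rough, self-contained illustration of the pattern (plain C++, not the actual CUDA/Metal code; the function and parameter names are made up):

```cpp
// Rough illustration of the "force 1 KV head" experiment (not the actual CUDA/Metal
// diffs linked above): the attention mat muls are performed as one GEMM per KV head,
// so clamping the head loop to a single head keeps the rest of the computation
// identical and isolates how the cost scales with the number of heads.
#include <cstdint>
#include <vector>

// Per-head matmul: C[h] = A[h] * B[h], with A: (m x k), B: (k x n), C: (m x n),
// all stored row-major, one contiguous matrix per head.
static void mul_mat_per_head(const std::vector<std::vector<float>> & A,
                             const std::vector<std::vector<float>> & B,
                             std::vector<std::vector<float>>       & C,
                             int64_t m, int64_t n, int64_t k,
                             bool force_one_kv_head) {
    const int64_t n_head = force_one_kv_head ? 1 : (int64_t) A.size(); // the entire "diff"

    for (int64_t h = 0; h < n_head; ++h) { // normally one GEMM launch per head
        for (int64_t i = 0; i < m; ++i) {
            for (int64_t j = 0; j < n; ++j) {
                float sum = 0.0f;
                for (int64_t l = 0; l < k; ++l) {
                    sum += A[h][i*k + l] * B[h][l*n + j];
                }
                C[h][i*n + j] = sum;
            }
        }
    }
}
```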
All these options allow us to measure the overhead from the following computations individually:

- the `KQ` and `KQV` matrix multiplications for all heads
- the `KQ` and `KQV` matrix multiplications per attention head
- `KQ_scaled` + `KQ_masked` + `KQ_soft_max`
Results
These are the raw numbers that I measured. In each file, first are the 7B runs, followed by the 1B runs:
Here I'll inline part of the A100 results for convenience. For the rest of the results, check out the text files above:
normal
LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
SKIP_KQ_ALL=1
LLAMA_CUBLAS=1 make -j && SKIP_KQ_ALL=1 ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
normal + force 1 KV head
LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
Observations
- Skipping all attention ops yields a relatively small speedup on Metal: `1.13x` for 1B and `1.09x` for 7B
- On CUDA the speedup from skipping the attention ops is much larger: (`2.6x` 1B, `1.8x` 7B) for RTX 4080 and (`3.7x` 1B, `4.3x` 7B) for A100
- The `KQ` and `KQV` matrix multiplications on CUDA take a much larger toll compared to Metal, both for prompt processing and for more than 1 batch
- On CUDA, the `KQ` and `KQV` processing time scales linearly with the number of KV heads for more than 1 batch, while on Metal, where we have a custom matrix-matrix multiplication kernel, the computation scales much better

If my analysis above is correct, there is a significant speedup to be gained for CUDA - both for batched decoding and for prompt processing. I'm not familiar with the best practices for CUDA, but I think we should either:
- find a way to run the per-head attention GEMMs in parallel (currently the `n_head` CUBLAS GEMMs are issued in a single CUDA stream). If I remember correctly, we have tried utilizing CUDA streams, but only for single-batch decoding. Probably we have to revisit? One possible direction is sketched below
- add a custom matrix-matrix multiplication kernel, similar to the one we have for Metal, and use it for the `KQ` and `KQV` ops where `ne02 > 1` and `ne12 > 1`
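For the first option, one alternative to looping over the heads is a strided-batched GEMM, which cuBLAS already exposes. A minimal host-side sketch follows - the shapes, strides and precision here are assumptions for illustration, not the actual llama.cpp layout, and whether this or multiple CUDA streams fits ggml-cuda better would need to be measured:

```cpp
// Minimal sketch: compute KQ = K^T * Q for all heads with one strided-batched GEMM
// instead of issuing n_head separate GEMMs on a single stream.
#include <cublas_v2.h>
#include <cuda_fp16.h>

static void kq_all_heads(cublasHandle_t handle,
                         const half * K,   // per head: n_embd_head x n_kv     (column-major)
                         const half * Q,   // per head: n_embd_head x n_tokens (column-major)
                         float      * KQ,  // per head: n_kv        x n_tokens (column-major)
                         int n_embd_head, int n_kv, int n_tokens, int n_head) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    // one call covering all heads, instead of n_head cublasGemmEx calls
    cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_T, CUBLAS_OP_N,
        /*m=*/n_kv, /*n=*/n_tokens, /*k=*/n_embd_head,
        &alpha,
        K,  CUDA_R_16F, /*lda=*/n_embd_head, /*strideA=*/(long long) n_embd_head * n_kv,
        Q,  CUDA_R_16F, /*ldb=*/n_embd_head, /*strideB=*/(long long) n_embd_head * n_tokens,
        &beta,
        KQ, CUDA_R_32F, /*ldc=*/n_kv,        /*strideC=*/(long long) n_kv * n_tokens,
        /*batchCount=*/n_head,
        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```

The same question applies to the `KQV` multiplication; for per-head layouts that are not uniformly strided, the pointer-array variant `cublasGemmBatchedEx` is the closest equivalent.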
These observations could also explain the poor performance observed for speculative decoding on A100 reported here: #3649
Reproducing
If anyone is interested in re-running the CUDA tests above, I used the following script:
Bash script for getting data and running llama.cpp
On runpod, the above RTX 4080 and A100 tests cost me a total of ~$1.11 to perform. You would need ~40 GB storage.
Alternatively, you can run them locally - the tests require ~16 GB of VRAM.
cc @slaren @JohannesGaessler for any comments and insights