perf : study batched decoding bottleneck #3726

Closed · wants to merge 3 commits

Conversation

ggerganov (Owner) commented Oct 22, 2023

ref #3479

Description

I wanted to find out what llama.cpp's main limitation currently is when running batched decoding, so I ran a few tests on different hardware, mainly to profile the self attention overhead of the existing unified KV cache implementation (#3228).

Below are the results on 3 different hardware configurations:

  • M2 Ultra
  • RTX 4080
  • A100

I'm using the batched-bench tool to run PP + TG for different numbers of batches:

  • Shared prompt of 512 tokens
  • Each batch generates 128 tokens
  • Batch nums: 1,2,3,4,5,6,7,8,16,32
  • Models: LLaMA 1B and 7B, F16 precision

This PR adds a hack to allow conveniently turning some of the attention ops on and off via environment variables (a minimal sketch of the pattern follows this list):

  • Set env SKIP_KQ_KQV=1 to skip the 2 matrix multiplications KQ and KQV
  • Set env SKIP_KQ_ALL=1 to skip all attention ops (KQ, KQ_scaled, KQ_masked, KQ_soft_max, KQV)
  • Without setting the env this branch runs normally as master
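
For illustration, here is a minimal, self-contained sketch of the env-variable toggle pattern. The `env_enabled` helper and the `main` wrapper are purely illustrative; in the actual branch the flags simply gate the corresponding ggml ops while building the compute graph:

```cpp
// Minimal sketch of the env-variable toggle pattern - illustrative only, not the actual patch.
#include <cstdio>
#include <cstdlib>

// hypothetical helper: treat an unset (or "0") variable as disabled
static bool env_enabled(const char * name) {
    const char * val = std::getenv(name);
    return val != nullptr && val[0] != '\0' && val[0] != '0';
}

int main() {
    const bool skip_kq_kqv = env_enabled("SKIP_KQ_KQV"); // skip the KQ and KQV matmuls
    const bool skip_kq_all = env_enabled("SKIP_KQ_ALL"); // skip all attention ops

    std::printf("skip KQ/KQV: %d, skip all attention ops: %d\n", skip_kq_kqv, skip_kq_all);

    // In the actual branch these flags short-circuit the corresponding ops
    // (KQ, KQ_scaled, KQ_masked, KQ_soft_max, KQV) during graph build.
    return 0;
}
```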

I've also tested 2 custom diffs for Metal and CUDA that run the full computation but force only 1 KV head to be computed during the matrix multiplications:

CUDA diff to force 1 KV head
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 654d3632..28ed98de 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6653,13 +6653,13 @@ static void ggml_cuda_op_mul_mat(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
+    const int64_t ne02 = 1;
     const int64_t ne03 = src0->ne[3];
     const int64_t nrows0 = ggml_nrows(src0);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
+    const int64_t ne12 = 1;
     const int64_t ne13 = src1->ne[3];
     const int64_t nrows1 = ggml_nrows(src1);
Metal diff to force 1 KV head
diff --git a/ggml-metal.m b/ggml-metal.m
index c908106b..4b7a5226 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -736,7 +736,7 @@ void ggml_metal_graph_compute(

                 const int64_t  ne00 = src0 ? src0->ne[0] : 0;
                 const int64_t  ne01 = src0 ? src0->ne[1] : 0;
-                const int64_t  ne02 = src0 ? src0->ne[2] : 0;
+                      int64_t  ne02 = src0 ? src0->ne[2] : 0;
                 const int64_t  ne03 = src0 ? src0->ne[3] : 0;

                 const uint64_t nb00 = src0 ? src0->nb[0] : 0;
@@ -746,7 +746,7 @@ void ggml_metal_graph_compute(

                 const int64_t  ne10 = src1 ? src1->ne[0] : 0;
                 const int64_t  ne11 = src1 ? src1->ne[1] : 0;
-                const int64_t  ne12 = src1 ? src1->ne[2] : 0;
+                      int64_t  ne12 = src1 ? src1->ne[2] : 0;
                 const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);

                 const uint64_t nb10 = src1 ? src1->nb[0] : 0;
@@ -786,6 +786,11 @@ void ggml_metal_graph_compute(
                 //            dst->name);
                 //}

+                if (dst->op == GGML_OP_MUL_MAT) {
+                    ne02 = 1;
+                    ne12 = 1;
+                }
+
                 switch (dst->op) {
                     case GGML_OP_NONE:
                     case GGML_OP_RESHAPE:

All these options allow us to measure the overhead from the following computations individually:

  • KQ and KQV matrix multiplications for all heads
  • KQ and KQV matrix multiplications per attention head
  • KQ_scaled + KQ_masked + KQ_soft_max

Results

These are the raw numbers that I measured. In each file, first are the 7B runs, followed by the 1B runs:

Here I'll inline part of the A100 results for convenience. For the rest of the results, check out the text files above:

normal

LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
| PP | TG | B | N_KV | T_PP (s) | S_PP (t/s) | T_TG (s) | S_TG (t/s) | T (s) | S (t/s) |
|---:|---:|--:|-----:|---------:|-----------:|---------:|-----------:|------:|--------:|
| 512 | 128 |  1 |  640 | 0.126 | 4063.46 | 1.973 |  64.88 | 2.099 | 304.91 |
| 512 | 128 |  2 |  768 | 0.130 | 3941.89 | 6.521 |  39.26 | 6.651 | 115.48 |
| 512 | 128 |  3 |  896 | 0.118 | 4327.61 | 6.988 |  54.95 | 7.106 | 126.09 |
| 512 | 128 |  4 | 1024 | 0.111 | 4597.99 | 6.209 |  82.47 | 6.320 | 162.03 |
| 512 | 128 |  5 | 1152 | 0.110 | 4664.21 | 7.266 |  88.09 | 7.375 | 156.20 |
| 512 | 128 |  6 | 1280 | 0.108 | 4745.09 | 7.300 | 105.20 | 7.408 | 172.78 |
| 512 | 128 |  7 | 1408 | 0.111 | 4632.48 | 7.369 | 121.60 | 7.479 | 188.26 |
| 512 | 128 |  8 | 1536 | 0.111 | 4603.20 | 7.200 | 142.22 | 7.312 | 210.08 |
| 512 | 128 | 16 | 2560 | 0.111 | 4602.00 | 7.168 | 285.71 | 7.279 | 351.68 |
| 512 | 128 | 32 | 4608 | 0.113 | 4521.81 | 7.693 | 532.46 | 7.806 | 590.32 |

SKIP_KQ_ALL=1

LLAMA_CUBLAS=1 make -j && SKIP_KQ_ALL=1 ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
| PP | TG | B | N_KV | T_PP (s) | S_PP (t/s) | T_TG (s) | S_TG (t/s) | T (s) | S (t/s) |
|---:|---:|--:|-----:|---------:|-----------:|---------:|-----------:|------:|--------:|
| 512 | 128 |  1 |  640 | 0.078 | 6574.39 | 1.622 |   78.92 | 1.700 |  376.54 |
| 512 | 128 |  2 |  768 | 0.069 | 7423.63 | 1.649 |  155.25 | 1.718 |  447.05 |
| 512 | 128 |  3 |  896 | 0.059 | 8675.03 | 1.645 |  233.38 | 1.704 |  525.70 |
| 512 | 128 |  4 | 1024 | 0.069 | 7436.13 | 1.723 |  297.17 | 1.792 |  571.50 |
| 512 | 128 |  5 | 1152 | 0.059 | 8691.08 | 1.683 |  380.30 | 1.742 |  661.38 |
| 512 | 128 |  6 | 1280 | 0.061 | 8415.79 | 1.666 |  460.93 | 1.727 |  741.15 |
| 512 | 128 |  7 | 1408 | 0.069 | 7384.22 | 1.674 |  535.26 | 1.743 |  807.66 |
| 512 | 128 |  8 | 1536 | 0.059 | 8710.30 | 1.689 |  606.37 | 1.748 |  878.96 |
| 512 | 128 | 16 | 2560 | 0.070 | 7330.83 | 1.747 | 1172.30 | 1.817 | 1409.04 |
| 512 | 128 | 32 | 4608 | 0.061 | 8420.91 | 1.920 | 2133.05 | 1.981 | 2326.03 |

normal + force 1 KV head

LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
| PP | TG | B | N_KV | T_PP (s) | S_PP (t/s) | T_TG (s) | S_TG (t/s) | T (s) | S (t/s) |
|---:|---:|--:|-----:|---------:|-----------:|---------:|-----------:|------:|--------:|
| 512 | 128 |  1 |  640 | 0.092 | 5589.03 | 1.878 |   68.15 | 1.970 |  324.90 |
| 512 | 128 |  2 |  768 | 0.087 | 5869.74 | 1.843 |  138.87 | 1.931 |  397.79 |
| 512 | 128 |  3 |  896 | 0.077 | 6656.61 | 1.858 |  206.70 | 1.935 |  463.13 |
| 512 | 128 |  4 | 1024 | 0.084 | 6120.96 | 1.845 |  277.43 | 1.929 |  530.81 |
| 512 | 128 |  5 | 1152 | 0.071 | 7204.98 | 1.899 |  337.08 | 1.970 |  584.86 |
| 512 | 128 |  6 | 1280 | 0.086 | 5943.47 | 1.908 |  402.46 | 1.994 |  641.79 |
| 512 | 128 |  7 | 1408 | 0.073 | 7014.95 | 1.936 |  462.79 | 2.009 |  700.82 |
| 512 | 128 |  8 | 1536 | 0.084 | 6120.23 | 1.938 |  528.31 | 2.022 |  759.67 |
| 512 | 128 | 16 | 2560 | 0.072 | 7110.02 | 2.109 |  971.01 | 2.181 | 1173.69 |
| 512 | 128 | 32 | 4608 | 0.090 | 5705.43 | 2.528 | 1620.15 | 2.618 | 1760.19 |

Observations

  • Using the text generation times (T_TG) for 8 batches, with Metal the entire self attention calculation amounts to ~12% of the text generation time for 1B and ~8% for 7B. This means that even with an infinitely fast KV cache and self attention, the Metal performance cannot improve by more than 1.13x for 1B and 1.09x for 7B
  • The same calculations for CUDA show that the self attention amounts to (64% 1B, 44% 7B) of the text generation time for RTX 4080 and (76% 1B, 77% 7B) for A100. So the maximum theoretical improvement in this case is (2.6x 1B, 1.8x 7B) for RTX 4080 and (3.7x 1B, 4.3x 7B) for A100 (the arithmetic behind these bounds is shown after this list)
  • Of course, we can never have an infinitely fast self attention that takes 0 time, but from the above observations we can see that the KQ and KQV matrix multiplications on CUDA take a much larger toll compared to Metal, both for prompt processing and for more than 1 batch
  • Another observation is that on CUDA, the KQ and KQV processing time scales linearly with the number of KV heads for more than 1 batch, while on Metal, where we have a custom matrix-matrix multiplication kernel, the computation scales much better
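
For reference, the bounds above are simple Amdahl-style arithmetic: if self attention accounts for a fraction $f$ of the text generation time, then even an infinitely fast attention caps the speedup at

$$\text{max speedup} = \frac{T_{TG}}{(1 - f)\,T_{TG}} = \frac{1}{1 - f}, \qquad \text{e.g. } f \approx 0.44 \;\Rightarrow\; \frac{1}{0.56} \approx 1.8\times \quad \text{(RTX 4080, 7B)}$$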

If my analysis above is correct, there is a significant speedup to be gained for CUDA - both for batched decoding and for prompt processing. I'm not familiar with the best practices for CUDA, but I think we should either:

  • Utilize CUDA streams for each separate KV head (currently we run n_head cuBLAS GEMMs in a single CUDA stream). If I remember correctly, we have tried utilizing CUDA streams before, but only for single-batch decoding. We probably have to revisit this (a rough sketch of the idea follows this list)
  • Implement a custom 3D matrix-matrix multiplication kernel similar to the one that we have in Metal. Use it for the KQ and KQV ops where ne02 > 1 and ne12 > 1
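
To make the first option a bit more concrete, here is a rough sketch of dispatching one FP16 GEMM per KV head on its own CUDA stream via cublasSetStream. The function name, shapes and strides are simplified placeholders, not the actual ggml-cuda code path:

```cpp
// Sketch only: one FP16 GEMM per KV head, each submitted to its own CUDA stream.
// Shapes and strides are simplified placeholders, not the real ggml-cuda tensor layout.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

void mul_mat_per_head_streams(
        cublasHandle_t handle,
        const half * K, const half * Q, half * KQ,   // device pointers
        int m, int n, int k,                         // per-head GEMM dimensions
        size_t k_stride, size_t q_stride, size_t kq_stride,
        int n_head) {
    std::vector<cudaStream_t> streams(n_head);
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);

    for (int h = 0; h < n_head; ++h) {
        cudaStreamCreate(&streams[h]);
        cublasSetStream(handle, streams[h]);   // route this head's GEMM to its own stream
        cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                     m, n, k,
                     &alpha,
                     K  + h*k_stride,  CUDA_R_16F, k,
                     Q  + h*q_stride,  CUDA_R_16F, k,
                     &beta,
                     KQ + h*kq_stride, CUDA_R_16F, m,
                     CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT);
    }
    for (int h = 0; h < n_head; ++h) {
        cudaStreamSynchronize(streams[h]);     // wait for all heads before using KQ
        cudaStreamDestroy(streams[h]);
    }
    cublasSetStream(handle, 0);                // restore the default stream on the handle
}
```

Whether this actually helps depends on how well the GPU is already utilized by the individual GEMMs.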

These observations could also explain the poor performance observed for speculative decoding on A100 reported here: #3649

Reproducing

If anyone is interested in re-running the CUDA tests above, I used the following script:

Bash script for getting data and running llama.cpp
#!/bin/bash

# setup deps
apt-get update
apt-get install git-lfs cmake cmake-curses-gui vim ruby
git-lfs install

# this is useful to git clone repos without doubling the disk size due to .git
git clone https://github.com/iboB/git-lfs-download
ln -sfn /git-lfs-download/git-lfs-download /usr/local/bin/git-lfs-download

# download data
cd workspace

git-lfs-download https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3
git-lfs-download https://huggingface.co/openlm-research/open_llama_7b

# llama.cpp
cd /
git clone https://github.com/ggerganov/llama.cpp

cd llama.cpp

ln -sfn /workspace/open_llama_7b            /workspace/openllama-7b
ln -sfn /workspace/TinyLlama-1.1B-Chat-v0.3 /workspace/tinyllama-1b

pip install -r requirements.txt
python3 convert.py /workspace/tinyllama-1b --outfile /workspace/tinyllama-1b/ggml-model-f16.gguf --outtype f16
python3 convert.py /workspace/openllama-7b --outfile /workspace/openllama-7b/ggml-model-f16.gguf --outtype f16

LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/tinyllama-1b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32

On RunPod, the above RTX 4080 and A100 tests cost me a total of ~$1.11 to perform. You would need ~40 GB of storage.
Alternatively, you can run them locally - the tests require 16 GB of VRAM.


cc @slaren @JohannesGaessler for any comments and insights

ggerganov added the performance (Speed related topics) and demo (Demonstrate some concept or idea, not intended to be merged) labels on Oct 22, 2023
JohannesGaessler (Collaborator)

> Utilize CUDA streams for each separate KV head (currently we run n_head CUBLAS GEMMs in a single CUDA stream). If I remember correctly, we have tried utilizing CUDA streams, but only for single-batch decoding. Probably we have to revisit?

Check the GPU utilization, especially over time, with Nsight Systems. Multiple CUDA streams can help with better utilization, but if utilization is not the problem then they won't do much.
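
For example, a timeline of the benchmark above could be captured with something along these lines (the report name and traced APIs are only an example):

```bash
# illustrative Nsight Systems invocation; adjust the model path and flags as needed
nsys profile --trace=cuda,nvtx -o batched-bench-report \
    ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
```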

> Implement a custom 3D matrix-matrix multiplication kernel similar to the one that we have in Metal. Use it for the KQ and KQV ops where ne02 > 1 and ne12 > 1

For an FP16 KV cache I don't think this will help. Writing efficient GEMM kernels is very hard, and I very much do not expect that custom FP16 GEMM kernels will be able to outperform cuBLAS. I think cuBLAS has functionality for batched GEMM, so maybe using that would make more sense?
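
For reference, a batched call along these lines could be based on cublasGemmStridedBatchedEx, treating the KV heads as the batch dimension. The sketch below is illustrative only - the shapes and strides are placeholders, not the actual ggml-cuda layout:

```cpp
// Sketch only: one strided-batched FP16 GEMM covering all KV heads,
// instead of n_head separate GEMM calls. Shapes/strides are placeholders.
#include <cublas_v2.h>
#include <cuda_fp16.h>

void mul_mat_batched_heads(
        cublasHandle_t handle,
        const half * K, const half * Q, half * KQ,   // device pointers
        int m, int n, int k,                         // per-head GEMM dimensions
        long long k_stride, long long q_stride, long long kq_stride,
        int n_head) {
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);

    // a single call computes all n_head GEMMs; batch element h is offset by h * stride
    cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                               m, n, k,
                               &alpha,
                               K,  CUDA_R_16F, k, k_stride,
                               Q,  CUDA_R_16F, k, q_stride,
                               &beta,
                               KQ, CUDA_R_16F, m, kq_stride,
                               n_head,
                               CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT);
}
```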

ggerganov (Owner, Author) commented Oct 23, 2023

I've implemented a batched CUBLAS GEMM version and I observe significant improvements across the board.
I'm looking for additional feedback on the scratch branch and verification that the results are correct.

Will try to post updated numbers for the tests above a bit later.

Edit: see #3749 for A100 numbers and the proposed PR
