KV cache quantized to q8_0 #2969

Closed
wants to merge 1 commit

Conversation

JohannesGaessler
Collaborator

This PR aims to implement quantizing the KV cache as q8_0. As of right now it is a proof of concept and I would like to discuss the best way of doing a proper implementation ahead of time.

Potential benefits:

  • Roughly 50% less RAM/VRAM usage for the KV cache, which is particularly relevant for long contexts and batched evaluation and may become more important due to speculative sampling (PoC for speeding-up inference via speculative sampling #2926); a rough size estimate follows this list.
  • Faster evaluation due to integer SIMD instructions and less I/O.
  • Less reliance on external BLAS libraries, which at this point are mostly needed for floating-point GEMM; for quantized models the KV cache is as of right now their only remaining use case.
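
For a rough sense of scale (a back-of-the-envelope estimate, assuming standard LLaMA 7B dimensions of n_layer = 32 and n_embd = 4096 with a 4096-token context): an f16 KV cache needs 2 (K and V) * 32 * 4096 * 4096 values * 2 bytes = 2 GiB, whereas q8_0 stores every 32 values in 34 bytes (32 int8 quants plus one f16 scale), i.e. about 1.0625 bytes per value, or roughly 1.06 GiB. That is in line with the ~1 GiB saving measured further down for a 4096 context.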

The current state of this PR is that only the CPU implementation works (with some hacks). These are the performance results on my system:

| Model | CUDA? | ngl | test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|---|
| 7B Q4_0 | No | 0 | pp 512 | 25.05 ± 0.04 | 25.12 ± 0.20 | 1.00 |
| 7B Q4_0 | Yes | 0 | pp 512 | 309.63 ± 1.71 | 289.90 ± 0.60 | 0.94 |
| 7B Q4_0 | Yes | 33 | pp 512 | 585.59 ± 2.64 | 519.90 ± 1.97 | 0.89 |
| 7B Q4_0 | No | 0 | tg 128 | 9.99 ± 0.03 | 10.03 ± 0.03 | 1.00 |
| 7B Q4_0 | Yes | 0 | tg 128 | 9.98 ± 0.04 | 9.97 ± 0.05 | 1.00 |
| 7B Q4_0 | Yes | 33 | tg 128 | 89.74 ± 0.13 | 84.01 ± 0.02 | 0.94 |

The GPU is an RTX 3090. Using CUDA with up to 33 layers is possible because the KV cache only starts being offloaded at 34 layers. In its current form there seems to be a small performance regression. The perplexity on wikitext-2 changed from 5.9679 +/- 0.03352 to 5.9668 +/- 0.03349, which is actually a very small improvement, likely due to randomness (calculated with 33 layers offloaded).

The implementation for the K component of the KV cache was fairly straightforward: the values copied to the K component always have a row size of n_embd_head, which is a multiple of 32 (the q8_0 block size). So you can essentially just change the datatype of the K component to q8_0 and the preexisting implementations for copying and matrix multiplication work out-of-the-box.
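
To make the "multiple of 32" point concrete, here is a self-contained sketch of q8_0 row quantization (illustrative only: the struct mirrors ggml's q8_0 block, but ggml stores the scale as fp16 and uses SIMD code for the actual conversion):

#include <math.h>
#include <stdint.h>

#define QK8_0 32

typedef struct {
    float  d;            // per-block scale (fp16 in the real block_q8_0)
    int8_t qs[QK8_0];    // 32 quantized values
} block_q8_0_sketch;

// Quantize one K row of n values. Because n_embd_head is a multiple of
// QK8_0, every row maps onto whole blocks and needs no special handling.
static void quantize_row_sketch(const float * x, block_q8_0_sketch * y, int n) {
    for (int ib = 0; ib < n/QK8_0; ++ib) {
        float amax = 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            amax = fmaxf(amax, fabsf(x[ib*QK8_0 + j]));
        }
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f/d : 0.0f;
        y[ib].d = d;
        for (int j = 0; j < QK8_0; ++j) {
            y[ib].qs[j] = (int8_t) roundf(x[ib*QK8_0 + j] * id);
        }
    }
}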

The implementation for the V component was trickier because the row size is equal to the number of processed tokens, so new values need to be copied to an arbitrary position inside a q8_0 block. I did the following things to make it work:

  • Extend ggml_compute_forward_dup_f32 with an extra case that had not been implemented so far: copying non-contiguous f32 values to non-contiguous q8_0 values. My solution is to fill a zero-initialized temporary buffer behind the q8_0 block with the unquantized values, and to re-create the q8_0 block from that buffer either when it contains enough values for an entire block or when there are no more new unquantized values to copy for the row (see the sketch after this list). This requires an additional 128 bytes per V component row (hacked by just increasing the size by 128 elements).
  • The aforementioned extra case needs to know the starting position inside the q8_0 block in order to determine the location of the temporary buffer and when it should be converted to q8_0. I hacked this by writing this information to the op_params of the corresponding tensor in llm_build_llama.
  • The matrix multiplication code for q8_0 expects the row size to be a multiple of 32 (the q8_0 block size). For the V component of the KV cache this can be satisfied by simply taking a larger slice of it; the hidden state, however, needs to be padded with extra values. I hacked this by creating a new tensor with the padded size, creating a view of that tensor, and copying the hidden state to that view.
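
A minimal sketch of the staging idea for the V component (illustrative only, reusing block_q8_0_sketch and quantize_row_sketch from the sketch above; the real code lives in ggml_compute_forward_dup_f32, handles strides, and the scratch area presumably corresponds to the extra 128 bytes, i.e. 32 floats, per row mentioned in the first point):

#include <stdbool.h>

// Append n new f32 values at intra-block position blck_pos of block y.
// scratch points at the row's zero-initialized full-precision staging area.
static void append_to_block_sketch(block_q8_0_sketch * y, float * scratch,
                                   int blck_pos, const float * x, int n, bool flush) {
    for (int i = 0; i < n; ++i) {
        scratch[blck_pos + i] = x[i];   // keep the unquantized copy around
    }
    // Re-quantize the block from the full-precision staging copy once it is
    // complete, or when no more values will be copied for this row for now,
    // so partially filled blocks never lose precision.
    if (blck_pos + n == QK8_0 || flush) {
        quantize_row_sketch(scratch, y, QK8_0);
    }
}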

I plan to do a proper implementation of this PR in the following way:

  • Modify ggml_view so that the position inside a block is written to op_params if the underlying tensor is quantized.
  • Add a GGML_OP like GGML_OP_PAD that pads a tensor with a constant value. Alternatively the code for quantizing the hidden state could be modified which would probably have better performance. But this would also be more complicated and require changes across a much wider range of hardware (to some of which I don't have access for testing). For the CUDA implementation only very minimal changes would be necessary to avoid padding the hidden state in an extra GGML_OP.
  • I will implement q8_0 KV cache quantization for CPU and CUDA.
  • I will replace the --memory-f32 CLI argument with a new one that lets you choose between q8_0, f16, and f32 for the KV cache type.

If there are issues with the way I plan to do the implementation, please let me know ahead of time.

@JohannesGaessler
Collaborator Author

Add a GGML_OP like GGML_OP_PAD that pads a tensor with a constant value. Alternatively the code for quantizing the hidden state could be modified which would probably have better performance.

On second thought, maybe not. The order of ops would be matrix multiplication -> scale -> mask -> softmax -> pad -> matrix multiplication. For optimal performance those ops should be fused anyway, and adding one more unary op to the pipeline would not make a difference.

Owner

@ggerganov left a comment


I don't see anything immediately wrong with this approach, so it would be interesting to see if we can make it work all the way.

Not sure if we would win much from a dedicated PAD op. We can have it as a combination of view + cpy as you have proposed. Either way would be fine, I guess.

Why do you add 128 again?
Also, probably better to utilize the wdata buffer that each op uses for writing work data

@JohannesGaessler
Collaborator Author

Why do you add 128 again?

I'm adding 128 so each row has some extra space to temporarily store the unquantized values as the context fills up.
It's to avoid precision loss when only part of a q8_0 block has been filled.

@KerfuffleV2
Collaborator

KerfuffleV2 commented Sep 2, 2023

This seems to work on ROCM (RX 6600) but weirdly uses considerably more memory than master for the same context size. Generation speed was about the same (perplexity not tested). This was a 7B LLaMA1 Q4_K model with -ngl 33 -c 2048. For master, it used ~4,500 MiB VRAM. This pull ~6,500 MiB.

@JohannesGaessler
Collaborator Author

Also, probably better to utilize the wdata buffer that each op uses for writing work data

From what I can tell wdata would not be suitable because I need persistence across multiple eval calls.

This seems to work on ROCM (RX 6600) but weirdly uses considerably more memory than master for the same context size.

I think this is due to the hack that I'm currently using to pad one of the tensors. It should not be an issue in the final version.

@kiratp

kiratp commented Sep 2, 2023

Less reliance on external BLAS libraries, which at this point are mostly needed for floating-point GEMM; for quantized models the KV cache is as of right now their only remaining use case.

Unless the plan is to call the new AMX (and VNNI) instructions manually, not using MKL would be slower on any newer (post Sapphire Rapids) Intel CPUs. These new instructions perform tiled Int8 multiplies in HW.

@JohannesGaessler
Collaborator Author

I added new functions like ggml_view_blck_1d that allow you to set the starting position in a block (I think this is preferable to changing the interface of e.g. ggml_view_1d and adding an extra parameter).

I replaced the padding hack with an extension to the ggml matrix multiplication code. The solution that I came up with is to quantize the last block of the hidden state without any SIMD instructions if it is incomplete. The rest of the tensor is still quantized using SIMD instructions. This should keep implementation/maintenance overhead low with negligible impact on performance.
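
A rough sketch of that split (again reusing the helpers from the sketch further up; in the actual ggml code the full blocks go through the existing SIMD row quantization and only the tail is handled with scalar code):

// Quantize a hidden-state row of arbitrary length n: full blocks use the
// normal row quantization, a trailing incomplete block is zero-padded into a
// temporary block and quantized separately.
static void quantize_row_padded_sketch(const float * x, block_q8_0_sketch * y, int n) {
    const int n_full = (n/QK8_0)*QK8_0;
    quantize_row_sketch(x, y, n_full);                       // fast (SIMD) path

    if (n_full < n) {
        float tmp[QK8_0] = {0.0f};                           // zero padding
        for (int i = n_full; i < n; ++i) {
            tmp[i - n_full] = x[i];
        }
        quantize_row_sketch(tmp, y + n_full/QK8_0, QK8_0);   // scalar tail
    }
}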

I pushed CUDA implementations that produce correct results but the performance is bad.

@cebtenzzre
Collaborator

cebtenzzre commented Sep 5, 2023

test-opt and test-grad0 are failing an assertion:

GGML_ASSERT: ggml.c:11303: from_float_to_vec_dot != NULL

edit: type_traits[GGML_TYPE_F32].from_float is NULL in ggml_compute_forward_mul_mat.

@JohannesGaessler
Collaborator Author

JohannesGaessler commented Sep 5, 2023

Alright, this PR should now be feature complete. I added a new CLI argument --kv-type that lets users set the type the KV cache is stored as. By default q8_0 is used since the memory savings are quite substantial: for a 4096 context with 7b q4_0 the VRAM usage goes down from 6478 MiB to 5536 MiB (a 942 MiB reduction). Token generation seems to be slightly slower, prompt processing on average a little faster:

| Model | GPU | test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|
| 7b q4_0 | RTX 3090 | pp 512 | 2187.37 ± 34.18 | 2303.08 ± 24.19 | 1.05 |
| 7b q4_0 | RTX 3090 | tg 128 | 141.11 ± 0.30 | 140.19 ± 0.31 | 0.99 |
| 7b q4_0 | P40 | pp 512 | 875.56 ± 4.72 | 870.33 ± 2.09 | 0.99 |
| 7b q4_0 | P40 | tg 128 | 59.45 ± 0.07 | 57.23 ± 0.01 | 0.96 |
| 7b q4_0 | RX 6800 | pp 512 | 1061.05 ± 6.51 | 1159.28 ± 7.16 | 1.09 |
| 7b q4_0 | RX 6800 | tg 128 | 77.89 ± 0.02 | 72.95 ± 0.01 | 0.94 |

The perplexity changes as follows:

| Model | PPL master | PPL PR | PPL diff |
|---|---|---|---|
| 7b f16 | 5.7962 | 5.7989 | 2.7e-3 |
| 7b q4_0 | 5.9670 | 5.9658 | -1.2e-3 |
| 7b q4_1 | 6.0035 | 6.0038 | 3e-4 |
| 7b q5_0 | 5.8292 | 5.8327 | 3.5e-3 |
| 7b q5_1 | 5.8538 | 5.8560 | 2.2e-3 |
| 7b q8_0 | 5.8014 | 5.8018 | 4e-4 |
| 7b q2_k | 6.4466 | 6.4514 | 4.8e-3 |
| 7b q3_k_s | 6.2950 | 6.2955 | 5e-4 |
| 7b q3_k_m | 6.0264 | 6.0297 | 3.3e-3 |
| 7b q3_k_l | 5.9854 | 5.9911 | 5.7e-3 |
| 7b q4_k_s | 5.8891 | 5.8934 | 4.3e-3 |
| 7b q4_k_m | 5.8798 | 5.8840 | 4.2e-3 |
| 7b q5_k_s | 5.8215 | 5.8260 | 4.5e-3 |
| 7b q5_k_m | 5.8292 | 5.8332 | 4e-3 |
| 7b q6_k | 5.8098 | 5.8141 | 4.3e-3 |

The increase in perplexity is on the same order as the perplexity differences between f16 and q8_0 or q8_0 and q6_K.

@cebtenzzre
Collaborator

cebtenzzre commented Sep 5, 2023

make LLAMA_CUBLAS=1 is failing on my computer:

ggml-cuda.cu(2107): error: the size of an array must be greater than zero
      __attribute__((shared)) int tile_x_qs[mmq_y * (32) + mmq_y];
                                            ^
          detected during:
            instantiation of "void allocate_tiles_q4_0<mmq_y>(int **, half2 **, int **, int **) [with mmq_y=-1]" at line 3483
            instantiation of "void mul_mat_q4_0<need_check>(const void *, const void *, float *, int, int, int, int, int, int, int, int) [with need_check=false]" at line 4662

I have a GTX 970 installed (compute 5.2), maybe that's the problem? You can test with make LLAMA_CUBLAS=1 CUDA_DOCKER_ARCH=sm_52.

@JohannesGaessler
Collaborator Author

The failing asserts were added by me and I removed them again since the assumptions that went into them seem to have been incorrect. The issue with CUDA compilation was caused by bad fallback values (which get used by cmake for the CC 5.2 PTX code) and should be fixed now.

@KerfuffleV2
Collaborator

Running perplexity with ROCm, I just get NaNs with -kvt q8_0. It doesn't seem to make a difference whether I use -ngl 0 or offload the whole model.

ggml_init_cublas: found 1 ROCm devices:
  Device 0: AMD Radeon RX 6600, compute capability 10.3
[...]
llm_load_tensors: ggml ctx size =    0.08 MB
llm_load_tensors: using ROCm for GPU acceleration
llm_load_tensors: mem required  =  103.84 MB (+   97.12 MB per state)
llm_load_tensors: offloading 26 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloading v cache to GPU
llm_load_tensors: offloading k cache to GPU
llm_load_tensors: offloaded 29/29 layers to GPU
llm_load_tensors: VRAM used: 3466 MB
[...]
llama_new_context_with_model: kv self size  =   97.12 MB
llama_new_context_with_model: compute buffer total size =   70.22 MB
llama_new_context_with_model: VRAM scratch buffer: 68.75 MB
system_info: n_threads = 5 / 16 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 759.39 ms
perplexity: calculating perplexity over 616 chunks, batch_size=512
perplexity: 0.59 seconds per pass - ETA 6.10 minutes
[1]nan,[2]nan,[3]nan,[4]nan,[5]nan,[6]nan,[7]nan,[8]nan,[9]nan,[10]nan,[11]nan,[12]nan,[13]nan

Token generation seemed to work okay.

@JohannesGaessler
Collaborator Author

Do you still get NaN with -nommq?

@KerfuffleV2
Collaborator

Do you still get NaN with -nommq?

Yes, I do.

perplexity: 1.27 seconds per pass - ETA 13.03 minutes
[1]nan,[2]nan,[3]nan

I get them slower! (Not that this information is likely to be too helpful.)

@cebtenzzre
Collaborator

cebtenzzre commented Sep 5, 2023

I get an assertion failure if I set my GTX 970 as the main GPU with -mg 1.
Full command: ./main -ngl 100 -t 1 -m llama-7b.q4_0.gguf -n 128 --ignore-eos -mg 1

GGML_ASSERT: ggml-cuda.cu:5782: buffers_contiguous || ne02 == 1

Backtrace:

#3  0x00007fffcea264b8 in __GI_abort () at abort.c:79
#4  0x000055555568bb21 in ggml_cuda_op_mul_mat_vec<false> (src0=0x20ea29be0, src1=0x20ea29a90, dst=0x20ea29d30, 
    src0_ddq_i=0x93a000000 "", src0_ddf_i=0x0, src1_ddf_i=0x93e408040, dst_ddf_i=0x93e404040, i02=0, i01_low=0, i01_high=1,
    i1=0, cudaStream_main=@0x7fffffff77e0: 0x555555b9d430) at /home/cebtenzzre/src/forks/llama.cpp/ggml-cuda.cu:5782
#5  0x000055555567918b in ggml_cuda_mul_mat_vec_p021 (src0=0x20ea29be0, src1=0x20ea29a90, dst=0x20ea29d30)
    at /home/cebtenzzre/src/forks/llama.cpp/ggml-cuda.cu:6492
#6  0x0000555555679668 in ggml_cuda_mul_mat (src0=0x20ea29be0, src1=0x20ea29a90, dst=0x20ea29d30)
    at /home/cebtenzzre/src/forks/llama.cpp/ggml-cuda.cu:6543
#7  0x000055555567b89e in ggml_cuda_compute_forward (params=0x7fffffff7980, tensor=0x20ea29d30)
    at /home/cebtenzzre/src/forks/llama.cpp/ggml-cuda.cu:7062
#8  0x00005555555ad84f in ggml_compute_forward (params=0x7fffffff7980, tensor=0x20ea29d30) at ggml.c:15760
#9  0x00005555555b1e01 in ggml_graph_compute_thread (data=0x7fffffff79c0) at ggml.c:17335
#10 0x00005555555b3432 in ggml_graph_compute (cgraph=0x20ea00020, cplan=0x7fffffff7ad0) at ggml.c:17845
#11 0x00005555555be540 in ggml_graph_compute_helper (buf=std::vector of length 16384, capacity 32768 = {...},
    graph=0x20ea00020, n_threads=1) at llama.cpp:440
#12 0x00005555555ca37a in llama_eval_internal (lctx=..., tokens=0x555558214510, embd=0x0, n_tokens=1, n_past=0, n_threads=1,
    cgraph_fname=0x0) at llama.cpp:3014
#13 0x00005555555d6109 in llama_eval (ctx=0x555557305580, tokens=0x555558214510, n_tokens=1, n_past=0, n_threads=1)
    at llama.cpp:6164
#14 0x00005555555690aa in main (argc=12, argv=0x7fffffffd998) at examples/main/main.cpp:617

@JohannesGaessler
Collaborator Author

I've pushed a fix for incorrect results with compute capabilities < 6.1. The results should now be correct but the performance will be bad.

@KerfuffleV2
Collaborator

Tested at e0d5c0b - no difference as far as I can see. With -ngl 0, with -nommq, both. All NaNs. Perhaps your compute capability check isn't working correctly for ROCm? The speed seems the same as before.

@JohannesGaessler
Collaborator Author

It's not just a check for compute capability, it's primarily that the old dequantize_mul_mat_vec implementation that was used as a fallback could not correctly handle non-contiguous tensors. I implemented support for non-contiguous tensors but it seems that this was a separate issue from yours.

@KerfuffleV2
Collaborator

Alrighty. Just let me know if/when you want me to try a new version, different settings, etc. I can test out whatever.

@JohannesGaessler
Collaborator Author

@KerfuffleV2 which GPU do you have again?

@KerfuffleV2
Collaborator

ggml_init_cublas: found 1 ROCm devices:
  Device 0: AMD Radeon RX 6600, compute capability 10.3

Arch Linux, the ROCM/ROCBLAS stuff seems to be at 5.6.0 if that makes a difference.

@JohannesGaessler
Collaborator Author

You could maybe try editing line 6593 of ggml-cuda.cu. There should be a boolean nc_okay that indicates whether directly handling non-contiguous data is going to give correct results. I would be interested in whether you still get incorrect results if you manually set nc_okay = false, which should then trigger the fallback implementation where the data is first made contiguous.

@KerfuffleV2
Collaborator

KerfuffleV2 commented Sep 5, 2023

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ebe267f..98847cf 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6590,7 +6590,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 
     // no quantized non-contiguous support for lower CC kernels implemented
-    const bool nc_okay = src0->type == GGML_TYPE_F16 || g_compute_capabilities[g_main_device] >= MIN_CC_DP4A;
+    const bool nc_okay = false;
 
     if (all_on_device && nc_okay && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);

No effect. Also tested with -nommq.

edit: Kind of just randomly changing stuff but I also tried:

-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
+        if (false && src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {

Also no difference.

edit 2: Continuing that approach, I forced it to end up at line 6612:

ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);

Still no difference. (And I verified it was getting there with a GGML_ASSERT(false), though I don't know if it was for the KV stuff specifically.)

@JohannesGaessler
Collaborator Author

This model: https://huggingface.co/pankajmathur/orca_mini_3b - converted/quantized to q8_0

Okay, that's the piece that I was missing. Using the same model I can now reproduce the issue. The problem seems to be that 3b has n_embd_head == 100, which is not a multiple of 32, so the K component of the KV cache does not work properly. The fix is going to be annoying though; padding K to 128 is probably the least bad option.

@Dampfinchen

Dampfinchen commented Sep 14, 2023

I pushed a fix for partial offloading when the input or batch size is not a multiple of 32.

Can confirm this fixes it. Awesome work!

@Green-Sky
Collaborator

Green-Sky commented Sep 16, 2023

EDIT3: seems to be a bug in the setup/build process, tracked by: #3202
Got another problematic model for you: tinyllama (1.1B)
(commit 91523fb)

ppl for f16:

$ perplexity -m models/TinyLlama-1.1B-step-50K-105b/ggml-model-f32.gguf -f wikitext-2-raw/wiki.test.raw -ngl 0 --chunks 300 -kvt f16
[1]212.4852,[2]299.6576,[3]302.9353,[4]341.5807,^C

ppl for q8_0:

$ perplexity -m models/TinyLlama-1.1B-step-50K-105b/ggml-model-f32.gguf -f wikitext-2-raw/wiki.test.raw -ngl 0 --chunks 300 -kvt q8_0
[1]12633.6645,[2]16262.5879,[3]17373.0927,[4]16899.1684,^C

ppl for q8_0 nommq:

$ perplexity -m models/TinyLlama-1.1B-step-50K-105b/ggml-model-f32.gguf -f wikitext-2-raw/wiki.test.raw -ngl 0 --chunks 300 -kvt q8_0 -nommq
[1]212.1701,[2]299.5440,[3]302.8249,[4]341.6177,[5]333.2243,[6]330.2414,^C

edit: also with e0d0a0f

edit2: i realize that this is more of a generic mmq issue

@JohannesGaessler
Collaborator Author

I pushed a fix for small models with n_embd_head % 32 != 0. I fixed it by adding op_params to ggml_cpy: if defined, the rows of the source tensor are padded with zeros to whole blocks of the destination tensor's type.
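
For illustration (an assumption about the arithmetic, not the actual code): the padded row size is then simply ne00 rounded up to a whole number of destination-type blocks, which is what turns n_embd_head = 100 into 128 for the 3b model discussed above:

// round the row size up to a multiple of the destination block size (QK8_0 = 32)
const int ne00_padded = ((ne00 + QK8_0 - 1)/QK8_0)*QK8_0;   // 100 -> 128
// the remaining ne00_padded - ne00 values of each row are filled with zeros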

@cebtenzzre
Collaborator

Could you resolve the merge conflicts? I wanted to try this out on my fork, but there have been a lot of changes to ggml-cuda.cu on master.

@JohannesGaessler
Collaborator Author

Already on it, I'll push it later today.

@JohannesGaessler
Collaborator Author

I pushed the rebased version. From what I can tell it's working correctly. If you do encounter issues, please compare against both master and the branch kv-q8_0-6 which is the version immediately prior to rebasing so I will know whether the issue was caused by the rebase.

char * dst_ptr = (char *) dst->data;
float src0_padded[ne00_padded];
Collaborator Author


I don't understand why the CI runs are failing. From what I can tell the issue is that for some reason it tries to allocate an array of size 0 here. This would imply ne00_padded == 0 which in turn would imply ne00 == 0. Is there a situation in which calling this function with that configuration actually makes sense?

Collaborator


MSVC doesn't support VLA, the size of the array must be a compile-time constant.
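
One possible portable alternative, as a sketch only (the actual fix might instead use a compile-time upper bound or ggml's per-op work buffer; malloc/free here are just for illustration):

// MSVC has no VLAs, so allocate the padded scratch row dynamically instead:
float * src0_padded = (float *) malloc(ne00_padded * sizeof(float));
// ... pad and quantize the row as before ...
free(src0_padded);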

@KerfuffleV2
Collaborator

Tested with my weird 3B Orca model and a LLaMA2-70B for token generation and perplexity. Both seem to work fine now on ROCM.

@Nexesenex
Contributor

Nexesenex commented Sep 19, 2023

CodeLlama 34b (Airoboros, Spicyboros, WizardLM) outputs rubbish, tested both in KoboldCPP chat and with the llama.cpp perplexity test (KV_Q_8_0 branch, around 50,000 perplexity at every context length).

Example : Output: _\\\\:\ \Item ( =,,
,
,,,
,,, =OfN of of <\xl(alias T,
,
,
,, = \AST of the Års
, = \point \bib no, = "\prog
,

No problem with CodeLlama 7b & 13b (although I didn't test all quants).

Also, the Q_8_0 branch again shows the VRAM "leak", while the Q_8_0_6 branch doesn't suffer from the problem and shows flat VRAM usage as the context grows (as does the current KoboldCPP experimental build with the usual KV16 implementation, since my previous occurrence of the "VRAM leak" problem was resolved a week ago).

Sorry for my lack of precise technical terms, I'm just a user testing above my pay grade!

@BarfingLemurs
Contributor

BarfingLemurs commented Sep 20, 2023

On the latest build, I got the same errors as before on Android in Termux. I don't know if the behavior is reproducible on a Mac or an SBC, but I tested with a Pixel and another phone with a Snapdragon SoC.

llm_load_print_meta: LF token  = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.08 MB
llm_load_tensors: mem required  = 1887.56 MB (+  109.21 MB per state)
..............................................................................................
llama_new_context_with_model: kv self size  =  109.21 MB
llama_new_context_with_model: compute buffer total size =   70.22 MB
GGML_ASSERT: /data/data/com.termux/files/home/yomama/2/llama.cpp/ggml.c:3523: nb % 2 == 0
GGML_ASSERT: /data/data/com.termux/files/home/yomama/2/llama.cpp/ggml.c:3523: nb % 2 == 0
GGML_ASSERT: /data/data/com.termux/files/home/yomama/2/llama.cpp/ggml.c:3523: nb % 2 == 0
GGML_ASSERT: /data/data/com.termux/files/home/yomama/2/llama.cpp/ggml.c:3523: nb % 2 == 0
Aborted
~/.../2/llama.cpp $

tested various models: 3b q4_1, q4_0 7b, q2k 7b, q4_k_m 13b

@cebtenzzre
Collaborator

I have been using q8_0 kv cache on my local fork without issues. But I have to disable it for now because there are merge conflicts after #3301.

@Dampfinchen

Dampfinchen commented Oct 1, 2023

Personally, I would definitely like to have this merged as a separate option. I didn't encounter any issues on my end and it does help a lot.

@ann-wzhao

ann-wzhao commented Oct 3, 2023

Alright, this PR should now be feature complete. I added a new CLI argument --kv-type that lets users set the type the KV cache is stored as. [...]

Hi, I used this PR to quantize the KV cache and found that when I offload all 32 repeating layers to the GPU, the model's answers are almost the same as with the fp16 KV cache. But when I keep one or more repeating layers on the CPU, the answers of the int8 KV model for long prompts are completely wrong (for example, it always outputs "!!!!!!!!!!!!!!!!!!!!!!!!!"). How can I fix this bug? Thanks!

I debugged further and found that some intermediate tensors contain NaN values when the q8_0 KV cache is used.

 Assertion failed: !isnan(sp[i]), file D:\llm_inference\gguf\llama.cpp-kv-q8_0-8\ggml.c, line 12412

So how should I debug next?

@cebtenzzre
Collaborator

I merged this onto 45855b3, among some other local changes, and hit this assertion while using it via oobabooga's UI:

GGML_ASSERT: /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/ggml-cuda.cu:5550: i_blck_0 + ne00 < QK8_0

This must be related to this PR, because nothing else in the model I'm testing with should be quantized to 8-bit. I had just edited the prompt, so it rewound the KV cache (eval was called with a smaller n_past than before).

@JohannesGaessler
Collaborator Author

I'm assuming the issue is caused by the assert in ggml_cpy_f32_q8_0_cuda. I think the following patch for cpy_f32_q8_0 would work:

    float val;
    if (first_incomplete) {
        memcpy(&val, src, sizeof(float));
    } else {
        val = *((float *) src);
    }

    if (first_incomplete && last_incomplete) {
        __syncthreads();
    }

    if (save_unquantized && last_incomplete && i0 / QK8_0 == (i_blck_0 + ne00) / QK8_0) {
        memcpy(&dst[1 + iqs/8].qs[sizeof(float) * (iqs % 8)], src, sizeof(float));
    }

Note that I did not test this code, and I currently do not have the time to work on llama.cpp, especially since there is no indication that this PR will be merged.

@Nexesenex
Contributor

Nexesenex commented Oct 9, 2023

@ggerganov This PR really deserves to be promised a merge, and its author to be encouraged. A q8_0 KV cache is a real boon for llama.cpp and its users compared to the competition.

@FNsi
Contributor

FNsi commented Oct 9, 2023

If anyone wants a context window that is long but easy to achieve, say with the long-context llama LoRA merges, a 16k q8_0 KV cache could save 10 GB+ of memory usage.

@KerfuffleV2
Collaborator

This PR really deserves to be promised a merge, and its author to be encouraged.

Note that this is just one random person's opinion: I'd say it's not quite that simple. How using less memory is useful for the KV cache is very easy to see but it's not the only consideration.

One thing to note is the changes are pretty complicated and there have been a lot of issues that have cropped up in testing. It seems like we haven't even reached the end of finding new ones. This isn't an easy thing to get right, even for someone as skilled as JG.

It also isn't just a self-contained set of changes that only affects people that decide to opt into using the Q8_0 quantized KV: the changes touch a lot of stuff that can have an effect whether or not quantized KV is actually being used.

Also another consideration is that supporting quantized KV can limit the other design choices that are possible. Just an example, KV cache fragmentation is currently an active issue which can come up when doing parallel generation. If the KV cache is quantized, it becomes a lot harder (maybe even not possible) to just move the entries around.

Of course, right now it has conflicts and can't be merged. (Also, resolving those conflicts and incorporating other changes can cause new bugs to crop up.)

Anyway, someone maintaining a project can't really think like "I have to merge this pull or the author will be unhappy/discouraged". They have to look at the broader picture, or they'll kill their own project by making bad choices. Personally, I know I couldn't handle it. Of course, it also feels bad to put a lot of work into a pull and have it rejected or not know if will ever be accepted.

@cebtenzzre
Collaborator

I'm assuming the issue is caused by the assert in ggml_cpy_f32_q8_0_cuda. I think the following patch for cpy_f32_q8_0 would work:

Unfortunately, I still hit the same assertion with the patch applied.

@JohannesGaessler
Collaborator Author

What I meant is that with the patch it should be fine to remove the assert.

@FNsi
Contributor

FNsi commented Oct 10, 2023

decide to opt into using the Q8_0 quantized KV: the changes touch a lot of stuff that can have an effect whether or not quantized KV is actually being used.

Forgive me again, but I asked myself: why llama.cpp? Why did a lot of people use it in the very beginning?

I think it's because,

Because we all are GPU poor people.

Because we all want something freaking great.

😂

@cebtenzzre
Collaborator

cebtenzzre commented Oct 12, 2023

What I meant is that with the patch it should be fine to remove the assert.

I have the patch applied and the assertion removed, but now in similar situations the KV cache gets corrupted and the model just starts generating hash marks no matter how many times I regenerate the response or even remove tokens from the prompt. Reloading the model fixes it.

@cebtenzzre
Collaborator

cebtenzzre commented Oct 12, 2023

And with or without replacing the assertion with the syncthreads call, using the 'main' example with spicyboros-c34b-2.2.Q5_0.gguf produces garbage output for me on this PR.

@Green-Sky mentioned this pull request Dec 3, 2023
@cebtenzzre
Collaborator

obsoleted by #4312

@cebtenzzre closed this Dec 10, 2023