
KL-divergence #5076

Merged: 2 commits into master on Jan 22, 2024

Conversation

@ikawrakow (Contributor) commented Jan 22, 2024

There have been several discussions about the potential value of being able to compute KL-divergence as another quantization accuracy test.

There is the Python script that @Ttl provided in PR #4739. But for those who prefer C/C++, this PR adds the ability to perform KL-divergence calculations natively in llama.cpp.

Usage

First get all logits of the fp16 model via

./perplexity -m <fp16_model> -f wiki.test.raw --kl-divergence-base <file_name> [other GPT parameters]

Be warned: the file can become quite large (about 10 GB for wiki.test.raw and context of 512) as all n_vocab logits are stored in the file for each evaluated token.

Then run the KL-divergence calculation for a quantized model (or any model that has the same vocabulary), using the base logits obtained in step 1:

./perplexity -m <quantized_model> --kl-divergence-base <file_name> --kl-divergence [other GPT parameters]

Note: you don't need to provide the test dataset via -f again, as the tokens are taken from the data stored in <file_name> (and if you do provide it, it is simply ignored). This ensures that the base model logits and the quantized model logits are computed on exactly the same set of tokens.

If everything goes well, you will see output such as

system_info: n_threads = 1 / 64 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 
kl_divergence: 0.10 seconds per pass - ETA 1.05 minutes

chunk        PPL          ln(PPL(Q)/PPL(base))          KL-Divergence
   1        4.1138       0.02821 ±    0.01121       0.02080 ±    0.00302
   2        4.6038       0.02404 ±    0.00845       0.02413 ±    0.00268
   3        5.4438       0.02745 ±    0.00773       0.02251 ±    0.00191
   4        6.1767       0.02966 ±    0.00680       0.02109 ±    0.00147
   5        6.1787       0.02765 ±    0.00599       0.02052 ±    0.00125
   6        6.2272       0.02690 ±    0.00534       0.02007 ±    0.00106
   7        6.4234       0.02448 ±    0.00491       0.01964 ±    0.00094
   8        6.5653       0.02621 ±    0.00469       0.01984 ±    0.00085
   9        6.7273       0.02760 ±    0.00441       0.01954 ±    0.00077
  10        7.0274       0.02565 ±    0.00422       0.02056 ±    0.00088
  11        7.2842       0.02377 ±    0.00401       0.02034 ±    0.00080
  12        7.2599       0.02171 ±    0.00385       0.02042 ±    0.00078
  13        7.2981       0.02437 ±    0.00372       0.02160 ±    0.00093
  14        7.3247       0.02374 ±    0.00357       0.02124 ±    0.00087
  15        7.2000       0.02592 ±    0.00347       0.02172 ±    0.00084
  16        7.0290       0.02587 ±    0.00335       0.02185 ±    0.00080
...
 636        5.8173       0.02201 ±    0.00058       0.02315 ±    0.00020
 637        5.8190       0.02200 ±    0.00058       0.02316 ±    0.00020
 638        5.8234       0.02200 ±    0.00058       0.02317 ±    0.00020
 639        5.8272       0.02204 ±    0.00058       0.02317 ±    0.00020
 640        5.8335       0.02206 ±    0.00058       0.02318 ±    0.00020
 641        5.8241       0.02210 ±    0.00058       0.02319 ±    0.00020
 642        5.8189       0.02220 ±    0.00058       0.02320 ±    0.00020


llama_print_timings:        load time =     706.57 ms
llama_print_timings:      sample time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =   62265.21 ms / 328704 tokens (    0.19 ms per token,  5279.10 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =   66075.96 ms / 328705 tokens

I.e., you get the PPL of the quantized model along with the KL-divergence and the logarithm of the ratio of the quantized model PPL to the base model PPL. The statistical uncertainty on the KL-divergence and on ln(PPL(Q)/PPL(base)) is much lower than the uncertainty of the PPL itself.
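For reference, the two reported quantities can be written as follows (a sketch in standard notation, chosen here just for illustration: p_t is the base fp16 token distribution at position t, q_t the quantized one, x_t the reference token, and T the number of evaluated tokens):

\[
\mathrm{KLD} \;=\; \frac{1}{T}\sum_{t=1}^{T}\sum_{i=1}^{n_{\text{vocab}}} p_{t,i}\,\ln\frac{p_{t,i}}{q_{t,i}},
\qquad
\ln\frac{\mathrm{PPL}(Q)}{\mathrm{PPL}(\text{base})} \;=\; \frac{1}{T}\sum_{t=1}^{T}\bigl(\ln p_t(x_t) - \ln q_t(x_t)\bigr).
\]

The second identity follows from PPL = exp(-(1/T) Σ_t ln q_t(x_t)). The per-token log-probabilities of the base and quantized models are strongly correlated, so much of the token-to-token fluctuation cancels in the difference, which is why the uncertainties on these two columns are much smaller than on PPL itself.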

@Artefact2 (Collaborator)

How are you calculating kl divergence? Average or mean? Would it also be possible to show the top1 match frequency and maybe the max divergence (or q99)?

@ggerganov (Owner) left a comment

Cool!

@ikawrakow (Contributor, Author)

How are you calculating kl divergence? Average or mean? Would it also be possible to show the top1 match frequency and maybe the max divergence (or q99)?

I don't think I understand the difference between average and mean.

Top-1 match can be easily added, along with other statistics such as max/q99, etc. But I personally don't like reviewing big PRs, and I assume neither do other people, so my preference would be to merge this and then add additional functionality in subsequent PR(s).

@ikawrakow (Contributor, Author)

To clarify what KL-divergence is being computed:

One can compute

  • The KL-divergence over the token probabilities for each evaluated token, and then the expectation value of these per-token KL-divergence values over the set of evaluated tokens. This is what is being done here.
  • The expectation values of the token probabilities over the set of evaluated contexts for all tokens in the vocabulary, and then the KL-divergence over these average token probabilities.

I.e., in the former case implemented by the PR, we compute KL-divergence for each token and get the average of that.

In the latter case, we first compute average probabilities for the tokens in the vocabulary over the evaluated tokens, and then we compute one KL-divergence based on that.
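In symbols, with p_{t,i} and q_{t,i} denoting the base and quantized probabilities of vocabulary token i at evaluated position t (notation chosen here just for illustration):

\[
\text{(former, this PR):}\quad \frac{1}{T}\sum_{t=1}^{T} D_{\mathrm{KL}}\!\left(p_t \,\middle\|\, q_t\right),
\qquad
\text{(latter):}\quad D_{\mathrm{KL}}\!\left(\bar p \,\middle\|\, \bar q\right)
\ \ \text{with}\ \ \bar p_i = \frac{1}{T}\sum_{t} p_{t,i},\ \ \bar q_i = \frac{1}{T}\sum_{t} q_{t,i}.
\]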

@ikawrakow merged commit 6f9939d into master Jan 22, 2024
47 checks passed
@ikawrakow deleted the ik/kl-divergence branch January 22, 2024 14:10

examples/perplexity/perplexity.cpp:
max_logit = std::max(max_logit, logits[i]);
min_logit = std::min(min_logit, logits[i]);
}
min_logit = std::max(min_logit, max_logit - 16);
Collaborator

Why is this done? Because the value would be 0 anyways due to the scale? A comment would be helpful.

@ikawrakow (Contributor, Author)

So, to reduce the size of the data being stored in the base run, I store the logits as uint16_t (stored as float, the log-probabilities for wiki.test.raw would be 20 GB; this way the file is about 10 GB). The minimum logit can be very small, so I have decided to limit the probability range to e^(-16) ~ 1e-7. This slightly improves the precision of the 16-bit values being stored.
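A minimal sketch of the idea (a hypothetical helper, not the exact code in the PR): clamp each token's logits to a 16-unit range below their maximum and map them linearly onto uint16_t; the per-token min/max (or an equivalent scale) also has to be stored so the values can be decoded later.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical packer: clamps logits to [max_logit - 16, max_logit] and maps that
// range linearly onto [0, 65535]. Probabilities below e^-16 of the maximum contribute
// essentially nothing, so limiting the span improves the 16-bit resolution.
static std::vector<uint16_t> pack_logits_u16(const std::vector<float> & logits,
                                             float & min_logit, float & max_logit) {
    max_logit = -INFINITY;
    min_logit =  INFINITY;
    for (float l : logits) {
        max_logit = std::max(max_logit, l);
        min_logit = std::min(min_logit, l);
    }
    min_logit = std::max(min_logit, max_logit - 16.0f);
    const float range = std::max(max_logit - min_logit, 1e-6f); // guard against a degenerate range
    const float scale = 65535.0f/range;
    std::vector<uint16_t> packed(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        const float l = std::min(std::max(logits[i], min_logit), max_logit);
        packed[i] = (uint16_t) std::lround((l - min_logit)*scale);
    }
    return packed;
}

With the clamp, the stored range spans at most 16 units in log space, so one 16-bit step corresponds to about 16/65535 ≈ 2.4e-4, i.e. the stored values lose very little precision.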

examples/perplexity/perplexity.cpp:
in.read((char *)&n_vocab, sizeof(n_vocab));
in.read((char *)&n_chunk, sizeof(n_chunk));
if (in.fail()) {
fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
Collaborator

rwading -> reading

@ikawrakow (Contributor, Author)

Thanks for noticing. Will fix in another PR.

@Artefact2 (Collaborator)

A very quick benchmark/sanity check against the llama_kl.py script.

% time ./llama_kl.py -m ../models/psyonic-cetacean-20B-f16.gguf -w logits.gz -t <(head -n100 wiki.test.raw) -ngl 15
./llama_kl.py -m ../models/psyonic-cetacean-20B-f16.gguf -w logits.gz -t  -ng  739.12s user 7.58s system 240% cpu 5:10.20 total

% time ./llama_kl.py -m ../models/psyonic-cetacean-20B-Q2_K_S.gguf -r logits.gz -ngl 999
Computing KL-divergence against: logits.gz
[0] kl 0.4626, top1 0.6836
[1] kl 0.4344, top1 0.7168
[2] kl 0.4492, top1 0.7135
[3] kl 0.4359, top1 0.7319
[4] kl 0.4787, top1 0.7191
[5] kl 0.4785, top1 0.7195
[6] kl 0.4702, top1 0.7203
[7] kl 0.4656, top1 0.7177
[8] kl 0.4496, top1 0.7231
[9] kl 0.4535, top1 0.7192
[10] kl 0.4463, top1 0.7222
[11] kl 0.4427, top1 0.7254
[12] kl 0.4387, top1 0.73
Finished reading file: logits.gz

Model: psyonic-cetacean-20B-Q2_K_S.gguf
Size: 6.3 GiB, (BPW 2.70)
Tokens: 6208
KL-divergence:
mean: 0.438655, [0.415236 - 0.462073]
q90: 1.052, [0.9993 - 1.104]
q95: 1.654, [1.562 - 1.752]
q99: 3.744, [3.428 - 4.063]
max: 9.058
Reference top token in eval top-n probability:
ref_top1: 0.73 ± 0.01452
ref_top5: 0.9476 ± 0.007284
ref_top10: 0.974066 ± 0.005198
Eval top token in reference top-n probability:
eval_top5: 0.9514 ± 0.007035
eval_top10: 0.976965 ± 0.004906
errors: 0
./llama_kl.py -m ../models/psyonic-cetacean-20B-Q2_K_S.gguf -r logits.gz -ngl  73.56s user 5.13s system 102% cpu 1:17.12 total
% time ./perplexity -m ../models/psyonic-cetacean-20B-f16.gguf -f <(head -n100 wiki.test.raw) --kl-divergence-base logits.dat -ngl 15
./perplexity -m ../models/psyonic-cetacean-20B-f16.gguf -f   logits.dat -ngl   489.48s user 3.48s system 702% cpu 1:10.20 total

% time ./perplexity -m ../models/psyonic-cetacean-20B-Q2_K_S.gguf --kl-divergence-base logits.dat --kl-divergence -ngl 999
kl_divergence: 2.37 seconds per pass - ETA 0.47 minutes

chunk        PPL          ln(PPL(Q)/PPL(base))          KL-Divergence           Same top
   1        6.8970       0.23242 ±    0.07093       0.33087 ±    0.05203    0.82745 ± 0.02371
   2        8.1708       0.24716 ±    0.04899       0.35713 ±    0.03204    0.75686 ± 0.01901
   3        9.0824       0.19574 ±    0.03812       0.32259 ±    0.02283    0.76340 ± 0.01538
   4       10.8189       0.20556 ±    0.03348       0.34376 ±    0.01865    0.75882 ± 0.01340
   5       11.1750       0.18970 ±    0.02910       0.33630 ±    0.01574    0.75922 ± 0.01198
   6       11.1893       0.19189 ±    0.02669       0.33575 ±    0.01431    0.76209 ± 0.01089
   7       11.0660       0.15851 ±    0.02456       0.33891 ±    0.01290    0.75630 ± 0.01016
   8       11.1685       0.16114 ±    0.02287       0.34443 ±    0.01219    0.75588 ± 0.00951
   9       11.7421       0.15799 ±    0.02167       0.35707 ±    0.01162    0.74815 ± 0.00906
  10       12.4022       0.16578 ±    0.02100       0.36017 ±    0.01114    0.74275 ± 0.00866
  11       12.7933       0.15584 ±    0.02036       0.36672 ±    0.01066    0.73690 ± 0.00832
  12       12.9078       0.15550 ±    0.01949       0.36171 ±    0.01012    0.73922 ± 0.00794

===== KL-divergence statistics
Average:   0.361708 ±  0.010120
Median :   0.177382
Minimum:   0.000008
Maximum:  10.405725
KLD_01 :   0.000064
KLD_99 :   2.608581
KLD_05 :   0.000498
KLD_95 :   1.382345

./perplexity -m ../models/psyonic-cetacean-20B-Q2_K_S.gguf  logits.dat  -ngl   30.61s user 0.50s system 101% cpu 30.768 total

@ikawrakow (Contributor, Author)

The sanity check tells us that either at least one of the two implementations is wrong, or that they are not computing the same thing. From a quick look at the Python script, it seems all tokens in [1...n_ctx] are included in the evaluation there, while in llama.cpp only tokens in [n_ctx/2...n_ctx] are included (to be consistent with the way the perplexity is calculated).
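A sketch of that difference (a hypothetical illustration, not the actual loop in perplexity.cpp, which is organized differently):

#include <cstdio>

// Which positions inside a chunk of n_ctx tokens are counted: the Python script counts
// every prediction in the chunk (j = 0 ... n_ctx-2), while llama.cpp only counts the
// second half (j = n_ctx/2 ... n_ctx-2), i.e. predictions made with at least n_ctx/2
// tokens of context. Logits at position j predict the token at position j + 1.
static void print_counted_positions(int n_ctx) {
    for (int j = n_ctx/2; j < n_ctx - 1; ++j) {
        printf("counted: position %d -> token %d\n", j, j + 1);
    }
}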

@Artefact2 (Collaborator) commented Jan 23, 2024

That would make sense, because a full logits.gz file is over 20GB for the full wiki.test.raw. Is there any benefit to only evaluating half the tokens per chunk?

I guess the difference between the two would average out as the sample size increases, the test was only done on 6K tokens.

@ikawrakow (Contributor, Author)

Is there any benefit to only evaluating half the tokens per chunk?

Are you interested in what token probabilities the model predicts without any context, or with a context consisting of the BOS token, BOS + 1 token, etc.? If so, then you would include all tokens in the evaluation. But if you want to see performance with a context of n_ctx/2 or more tokens, then you would use what is done in llama.cpp.

That would make sense, because a full logits.gz file is over 20GB for the full wiki.test.raw

The llama.cpp logits for a full wiki.test.raw run are 10 GB because I'm storing them with 16 bit precision. This is enough for all practical purposes (if I run the fp16 model against the stored 16-bit logits, I find a mean KL-divergence of 4e-6).

@ikawrakow (Contributor, Author)

I guess the difference between the two would average out as the sample size increases, the test was only done on 6K tokens.

Not really. Based on the standard deviation of 0.01 that llama.cpp reports, the two KL-divergence estimates differ by about 7 standard deviations. This means the likelihood that the two estimates would converge if one ran them on more samples is essentially zero, which is why I made the comment that they are computing different things. The Python script includes the no-context situation in every estimate of the KL-divergence, which means it has a disproportionately large impact on the computed KL-divergence mean.
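Plugging in the numbers from the benchmark above (llama_kl.py mean 0.438655 vs. the llama.cpp average 0.361708 ± 0.010120):

\[
\frac{0.438655 - 0.361708}{0.010120} \;\approx\; 7.6,
\]

i.e. the two estimates are separated by roughly 7.6 of the standard deviations reported by llama.cpp.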

@JohannesGaessler (Collaborator)

I forgot to say: thank you for this PR. It's already useful to me since it helps me to judge the precision loss in #4801 .

@ikawrakow (Contributor, Author)

I forgot to say: thank you for this PR. It's already useful to me since it helps me to judge the precision loss in #4801 .

Glad to hear it is useful.

To me, PPL and KL-divergence are basically the same thing. I wrote about this earlier (see #4739 (comment)). So, just for fun, here is a comparison between ln(PPL(Q)/PPL(fp16)) and KL-divergence for Mistral-7B. The graph shows both as a function of model size for all k-quants. There are some differences here and there (the error bars are smaller than the symbol size, so the differences are real and not just noise), but overall they are pretty much the same.

[graph ppl_vs_size_mistral7B: ln(PPL(Q)/PPL(fp16)) and KL-divergence vs. model size for all Mistral-7B k-quants]

@kalomaze (Contributor) commented Jan 23, 2024

To me PPL and KL-divergence are basically the same thing. […]

Maybe I'm beating a dead horse here, but to me the usefulness of the KL divergence is that you get a much more fine-grained measurement of how much probability mass is shifting. A model might be better at predicting some tokens and worse at others due to quantization loss or other changes caused by lower precision, and this would not necessarily be reflected in the average PPL (not without a much larger number of perplexity calculations).

When you measure the whole distribution instead of one token probability, you don't have to make any of those assumptions, and you don't need as much compute to rule out the margin of error.
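As a small made-up illustration of that point: two distributions can assign the same probability to the top/reference token, so a PPL comparison sees no difference at that position, while the KL-divergence still registers the shifted probability mass elsewhere:

\[
P = (0.5,\ 0.3,\ 0.2), \quad Q = (0.5,\ 0.45,\ 0.05), \quad
D_{\mathrm{KL}}(P\,\|\,Q) = 0.3\ln\frac{0.3}{0.45} + 0.2\ln\frac{0.2}{0.05} \approx 0.156 .
\]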

These are my old charts for Mistral 7B (before the importance matrix, etc.), where I computed KL divergences with my own hacked-together Python script; they show the exponential degradation quite well.

[two charts: KL-divergence results for Mistral 7B quantizations]

@Artefact2 (Collaborator) commented Jan 26, 2024

To me PPL and KL-divergence are basically the same thing.

Here are my results on bagel 34b. There is no clear correlation; it could be bad luck for this particular model on this particular dataset.

[charts: KL-divergence and perplexity results for bagel 34b quantizations]

@ikawrakow (Contributor, Author)

Here's my results on bagel 34b. No clear correlation, could be bad luck for this particular model on this particular dataset.

According to your graph, Q4_K_S has a ~7% lower perplexity than the fp16 model. Given this, my bet is that neither PPL, nor KL-divergence, nor top-token probability measures quantization accuracy in any meaningful way for this LLM on this test dataset. Also, you are plotting the KL-divergence median vs ln(PPL(Q)/PPL(fp16)). It is the KL-divergence average, not the median, that is very similar to ln(PPL(Q)/PPL(fp16)).

@Artefact2 (Collaborator)

Given this, my bet is that neither PPL, nor KL-divergence, nor top token probability measures quantization accuracy in any meaningful way for this LLM on this test dataset

Yes, that's probably correct. However, at least the KL centiles / top-token data are decreasing as expected, so they seem to be a more useful indicator than the KL average / perplexity.

jordankanter pushed a commit to jordankanter/llama.cpp that referenced this pull request Feb 3, 2024
* kl-divergence: be able to save all logits to a file

* Add ability to compute KL-divergence

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
hodlen pushed a commit to hodlen/llama.cpp that referenced this pull request Apr 1, 2024
* kl-divergence: be able to save all logits to a file

* Add ability to compute KL-divergence

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>