Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama3.1 and kv_cache quantization #738

Merged
merged 4 commits into from
Aug 27, 2024
Merged

Llama3.1 and kv_cache quantization #738

merged 4 commits into from
Aug 27, 2024

Conversation

HDCharles
Copy link
Contributor

@HDCharles HDCharles commented Aug 23, 2024

this PR has support for llama 3.1 and some improvements to kv_cache quantization and general peak memory performance for llama

high level, we can now do inference with 130k context length in 18.9 GB peak memory if we apply kv cache quantization, linear causal mask and int4 weight-only quantization

summary of changes

  1. add 3.1 support for llama
  2. change quantized_kv_cache init so it doesn't create a full precision peak: see below
  3. reorder causal mask init: see below
  4. add option for linear causal mask: see below
  5. add option for cache_size: the default generate.py behavior requires you do generate 32k tokens if you want to haev a size 32k kv_cache/causal_mask, the cache_size option lets you simply set the cache size but generate a smaller number of tokens to make it easier to benchmark
  6. add option to generate memory profile: used to generate the images below

image

context length (tokens) normal peak (GB) kv_quant peak (GB) kv quant+causal fix peak (GB)
8192 17.86 17.52 17.47
16384 19.81 18.75 18.48
32768 23.83 21.72 20.64
65536 33.5 29.54 25.24
131072 59.27 52.62 34.18

Change to quantized kv_cache init

The first change is avoiding creating of the full precision kv_cache, previously we would initialize the kv_cache and then convert it to the quantized form as seen in this memory profile:

image

those horizontal lines from ~16.1 GB to 16.6GB are the normal kv_cache and you can see them being deallocated on the right side of the image as the quantized kv_cache's are instantiated. This created an unnecessary increase in peak memory any time the initialization is the peak (which was the case for very long context lengths).

Change to causal mask

Screenshot 2024-08-26 at 7 24 40 PM

This is a memory profile for 32k context length without kv_cache quantization or any other changes, compare to one with kv_cache quantization

Screenshot 2024-08-26 at 7 24 11 PM

those horizontal bands that run from 16GB to 20.5 GB on the top image and 18.5 on the bottom, are the kv_cache. With quantization its 2 GB smaller which shows the technique is performing as expected, however there is a large blue (top) or (green) blob (with a spike on the left side) that appears in the memory profile, this is the causal mask.

Normally the causal mask is handled by creating a (token length x token length) tensor of ones, then creating a copy that is lower triangular and taking slices from it throughout the model runs. Notice the sharp peak right at the start, this occurs because in order to copy a tensor of ones into a lower triangular matrix requires you to hold 2 instances of this in memory for a moment, thereby doubling its impact in addition to taking up O(context_length^2) memory. The doubling issue was solved by creating the causal mask before the kv_cache, if done like that, the momentary doubling spike doesn't affect the peak memory since the kv_cache will be higher than the spike.

image

Although the earlier instantiation of the causal mask helps (red blob now), it is still taking up a ton of space, especially at even higher context lengths, which is eating into the gains we expect from kv_cache quantization. Why do we need to actually store the causal mask though? A slice of the causal mask is essentually just a sequence of n ones in a row and then
context_length-n zeros in a row where n is the current token being generated. Each slice differs from the next only by a single value. We can just store the slice and update it each iteration instead. Result:

image

tests:

see benchmarks.sh

the 18.9 GB number came from

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask --quantization int4wo-64

Copy link

pytorch-bot bot commented Aug 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/738

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c5e4dcb with merge base 37276d6 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 23, 2024
@HDCharles HDCharles requested a review from msaroufim August 23, 2024 06:19
@msaroufim
Copy link
Member

Mostly looking good!

  1. There's a merge conflict, @gau-nernst recently added training support for gpt-fast
  2. The memory traces you shared seemed compelling, let's have the baseline be gpt-fast as is and the intervention kv-cache + reshuffled mask init + vector mask
  3. cc @Jack-Khuu and @kartikayk since this is landing soon

Summary:

TODO: finish kv_cache testing
generate memory_trace

Added the 3.1 frequency rescaling and model definitions

testing is ongoing

Test Plan: python eval.py --checkpoint_path
$../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --compile

wikitext: {'word_perplexity,none': 7.441690325135099, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.4554823564993407, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.541497351075118, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 16384
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 16384
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 32768
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 32768

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@HDCharles
Copy link
Contributor Author

Mostly looking good!

  1. There's a merge conflict, @gau-nernst recently added training support for gpt-fast
  2. The memory traces you shared seemed compelling, let's have the baseline be gpt-fast as is and the intervention kv-cache + reshuffled mask init + vector mask
  3. cc @Jack-Khuu and @kartikayk since this is landing soon
  1. fixed
  2. gpt-fast errors even at 32k context length, it requires a bunch of fixes to even get it working, I don't have a good way to compare apples to apples. At the moment i'm comparing normal performance (with reordering of causal mask init) v kv_cache quantization v kv_cache quantization + linear causal mask

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@msaroufim
Copy link
Member

msaroufim commented Aug 27, 2024

This feels good to merge to me, fwiw @iseeyuan and @felipemello1 have also noticed a large difference by reducing the memory requirements of logits pytorch/executorch#4688 from O(context length) to O(1)

Also remind me if you also quantized the model? (seems like no?) I'm trying to see if we can hit a 24GB VRAM budget or whether we need to explore int4 kv quantization. It'd be pretty sick do to a full llama 8b inference on a 128K context length, that should for example be enough to fit the entire AO code repo

Also mind adding the top line VRAM requirements at the top, the line chart doesnt have even ranges on the y-axis (log scale?) so a bit hard to eyeball

@HDCharles HDCharles merged commit 86c7b0d into main Aug 27, 2024
16 checks passed
@vadimkantorov
Copy link

vadimkantorov commented Aug 28, 2024

Could the mask be generated via a broadcasting trick (arange broadcasted and compared to another arange broadcasted differently) to alleviate the need for ones and then tril? Or not in this context? Otherwise, does FlexAttention allow to avoid materialization of such masks and compute the masking directly during the attention? (I thought that flash attention supported such materialization-free causal masks too...)

@kir152
Copy link

kir152 commented Aug 29, 2024

Great work on the memory optimizations! Have you measured any impact on model accuracy or perplexity, with this method?

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* MPS on macos-14

* typo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants