Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ggml-cuda : update rope implementation for parallel decoding #3254

Merged
merged 5 commits into from
Sep 19, 2023

Conversation

slaren
Copy link
Collaborator

@slaren slaren commented Sep 18, 2023

No description provided.

@slaren slaren marked this pull request as draft September 18, 2023 21:53
@slaren
Copy link
Collaborator Author

slaren commented Sep 18, 2023

I realized that there is an issue with the cudaMemcpyAsync, since the host buffer will be freed by the time the copy actually happens. It's working because the copy is not really async since it is not from a pinned buffer, but I'll try to find a better solution.

@slaren slaren marked this pull request as ready for review September 18, 2023 22:25
ggml-cuda.cu Outdated
Comment on lines 6107 to 6111
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
if (!src1_extra->copied) {
CUDA_CHECK(cudaMemcpyAsync(src1_extra->data_device[id], src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
src1_extra->copied = true;
}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not very good, is there a better way to sync the tensor to VRAM?

@slaren
Copy link
Collaborator Author

slaren commented Sep 18, 2023

Device 0: NVIDIA GeForce RTX 3090 Ti, compute capability 8.6

model size backend test master t/s PR t/s speedup
models 7B mostly Q2_K 2.63 GiB CUDA pp 512 1805.63 ± 10.64 1789.82 ± 4.21 0.99
models 7B mostly Q3_K - Small 2.75 GiB CUDA pp 512 1907.34 ± 5.34 1891.82 ± 7.03 0.99
models 7B mostly Q3_K - Medium 3.07 GiB CUDA pp 512 1997.51 ± 2.89 1979.72 ± 9.24 0.99
models 7B mostly Q3_K - Large 3.35 GiB CUDA pp 512 1928.15 ± 4.06 1912.53 ± 11.26 0.99
models 7B mostly Q4_0 3.56 GiB CUDA pp 512 2404.71 ± 3.79 2384.81 ± 11.22 0.99
models 7B mostly Q4_K - Small 3.59 GiB CUDA pp 512 2202.34 ± 3.30 2204.59 ± 2.09 1.00
models 7B mostly Q4_K - Medium 3.80 GiB CUDA pp 512 2207.27 ± 3.36 2211.19 ± 6.55 1.00
models 7B mostly Q4_1 3.95 GiB CUDA pp 512 2102.01 ± 3.47 2102.79 ± 6.97 1.00
models 7B mostly Q5_0 4.33 GiB CUDA pp 512 2172.19 ± 4.07 2173.15 ± 10.29 1.00
models 7B mostly Q5_K - Small 4.33 GiB CUDA pp 512 2035.79 ± 4.42 2048.62 ± 5.26 1.00
models 7B mostly Q5_K - Medium 4.45 GiB CUDA pp 512 2057.20 ± 4.03 2065.67 ± 2.91 1.00
models 7B mostly Q5_1 4.72 GiB CUDA pp 512 1931.67 ± 2.74 1938.51 ± 2.89 1.00
models 7B mostly Q6_K 5.15 GiB CUDA pp 512 2107.53 ± 1.66 2113.62 ± 2.37 1.00
models 7B mostly Q8_0 6.67 GiB CUDA pp 512 2353.88 ± 3.43 2349.37 ± 7.20 1.00
models 7B mostly F16 12.55 GiB CUDA pp 512 1659.42 ± 3.10 1652.41 ± 1.44 1.00
models 7B mostly Q2_K 2.63 GiB CUDA tg 128 105.91 ± 0.33 102.95 ± 0.33 0.97
models 7B mostly Q3_K - Small 2.75 GiB CUDA tg 128 101.85 ± 0.14 99.44 ± 0.61 0.98
models 7B mostly Q3_K - Medium 3.07 GiB CUDA tg 128 108.07 ± 0.30 105.98 ± 0.29 0.98
models 7B mostly Q3_K - Large 3.35 GiB CUDA tg 128 105.34 ± 0.25 103.45 ± 0.89 0.98
models 7B mostly Q4_0 3.56 GiB CUDA tg 128 131.00 ± 0.30 127.74 ± 0.49 0.98
models 7B mostly Q4_K - Small 3.59 GiB CUDA tg 128 119.67 ± 0.23 115.26 ± 1.15 0.96
models 7B mostly Q4_K - Medium 3.80 GiB CUDA tg 128 114.40 ± 0.14 110.66 ± 0.63 0.97
models 7B mostly Q4_1 3.95 GiB CUDA tg 128 124.94 ± 0.04 120.15 ± 0.40 0.96
models 7B mostly Q5_0 4.33 GiB CUDA tg 128 112.05 ± 0.43 109.10 ± 0.73 0.97
models 7B mostly Q5_K - Small 4.33 GiB CUDA tg 128 110.85 ± 0.11 107.65 ± 0.32 0.97
models 7B mostly Q5_K - Medium 4.45 GiB CUDA tg 128 107.38 ± 0.06 104.46 ± 0.22 0.97
models 7B mostly Q5_1 4.72 GiB CUDA tg 128 108.43 ± 0.23 105.39 ± 0.37 0.97
models 7B mostly Q6_K 5.15 GiB CUDA tg 128 92.17 ± 0.07 90.08 ± 0.26 0.98
models 7B mostly Q8_0 6.67 GiB CUDA tg 128 88.62 ± 0.04 86.85 ± 0.23 0.98
models 7B mostly F16 12.55 GiB CUDA tg 128 56.14 ± 0.07 54.99 ± 0.04 0.98
With parallel
$ ./build/bin/parallel -m models/7B/ggml-model-f16.gguf -n 256 -t 8 -ngl 99 -c 4096 -b 512
[...]
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type  f16:  226 tensors
llm_load_print_meta: format         = GGUF V2 (latest)
llm_load_print_meta: arch           = llama
llm_load_print_meta: vocab type     = SPM
llm_load_print_meta: n_vocab        = 32000
llm_load_print_meta: n_merges       = 0
llm_load_print_meta: n_ctx_train    = 2048
llm_load_print_meta: n_ctx          = 4096
llm_load_print_meta: n_embd         = 4096
llm_load_print_meta: n_head         = 32
llm_load_print_meta: n_head_kv      = 32
llm_load_print_meta: n_layer        = 32
llm_load_print_meta: n_rot          = 128
llm_load_print_meta: n_gqa          = 1
llm_load_print_meta: f_norm_eps     = 1.0e-05
llm_load_print_meta: f_norm_rms_eps = 1.0e-06
llm_load_print_meta: n_ff           = 11008
llm_load_print_meta: freq_base      = 10000.0
llm_load_print_meta: freq_scale     = 1
llm_load_print_meta: model type     = 7B
llm_load_print_meta: model ftype    = mostly F16
llm_load_print_meta: model params   = 6.74 B
llm_load_print_meta: model size     = 12.55 GiB (16.00 BPW)
llm_load_print_meta: general.name   = models
llm_load_print_meta: BOS token = 1 '<s>'
llm_load_print_meta: EOS token = 2 '</s>'
llm_load_print_meta: UNK token = 0 '<unk>'
llm_load_print_meta: LF token  = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.09 MB
llm_load_tensors: using CUDA for GPU acceleration
llm_load_tensors: mem required  =  250.09 MB (+ 2048.00 MB per state)
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloading v cache to GPU
llm_load_tensors: offloading k cache to GPU
llm_load_tensors: offloaded 35/35 layers to GPU
llm_load_tensors: VRAM used: 14652 MB
...................................................................................................
llama_new_context_with_model: kv self size  = 2048.00 MB
llama_new_context_with_model: compute buffer total size =  256.49 MB
llama_new_context_with_model: VRAM scratch buffer: 255.02 MB


Client  0, seq    0, prompt  126 t, response    8 t, time  1.24 s, speed: PP 376.95 t/s, TG  8.80 t/s, AVG 107.77 t/s :

Input:    Recommend some interesting books to read.
Response: I will be glad to help you.

Client  1, seq    1, prompt  130 t, response    8 t, time  1.24 s, speed: PP 388.74 t/s, TG  8.80 t/s, AVG 110.98 t/s :

Input:    If you could have any superpower, what would it be?
Response: I would have the power to fly.

Client  2, seq    2, prompt  125 t, response   11 t, time  1.68 s, speed: PP 373.48 t/s, TG  8.16 t/s, AVG 80.78 t/s :

Input:    How to get a job at Google?
Response: Search for job opening at Google and follow the instructions.

Client  2, seq    6, prompt  130 t, response   11 t, time  1.40 s, speed: PP 680.19 t/s, TG  9.08 t/s, AVG 100.56 t/s :

Input:    If you could have any superpower, what would it be?
Response: I would like to have the ability to read minds.

Client  3, seq    3, prompt  126 t, response   26 t, time  3.47 s, speed: PP 376.12 t/s, TG  8.30 t/s, AVG 43.81 t/s :

Input:    Recommend some interesting books to read.
Response: There are over 50 million books on Amazon.com, and that is only a fraction of the number in circulation.

Client  0, seq    4, prompt  128 t, response   21 t, time  2.72 s, speed: PP 608.53 t/s, TG  8.38 t/s, AVG 54.84 t/s :

Input:    What is the best way to cook a steak?
Response: For medium rare, remove it from the refrigerator about 30 minutes before cooking.

Client  2, seq    7, prompt  124 t, response   11 t, time  1.48 s, speed: PP 765.89 t/s, TG  8.34 t/s, AVG 91.21 t/s :

Input:    What is the meaning of life?
Response: You are not ready for such a profound question.

Client  0, seq    9, prompt  124 t, response   12 t, time  1.57 s, speed: PP 766.31 t/s, TG  8.51 t/s, AVG 86.54 t/s :

Input:    What is the population of Europe?
Response: The population of Europe is about 740 million.

Client  3, seq    8, prompt  125 t, response   17 t, time  2.27 s, speed: PP 769.10 t/s, TG  8.08 t/s, AVG 62.67 t/s :

Input:    How to get a job at Google?
Response: The best way to get a job at Google is by using the Google search engine.

Client  0, seq   11, prompt  126 t, response    8 t, time  1.17 s, speed: PP 635.71 t/s, TG  8.24 t/s, AVG 114.62 t/s :

Input:    Recommend some interesting books to read.
Response: The books I read this year were:

Client  1, seq    5, prompt  126 t, response   44 t, time  5.66 s, speed: PP 598.56 t/s, TG  8.08 t/s, AVG 30.04 t/s :

Input:    Recommend some interesting books to read.
Response: I recommend “The Secret Garden”, by Frances Hodgson Burnett, “Pride and Prejudice”, by Jane Austen, and “A Tale of Two Cities”, by Charles Dickens.

Client  2, seq   10, prompt  127 t, response   24 t, time  3.20 s, speed: PP 646.08 t/s, TG  7.98 t/s, AVG 47.12 t/s :

Input:    List all planets in the Solar System.
Response: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune.

Client  1, seq   14, prompt  127 t, response    8 t, time  1.18 s, speed: PP 644.43 t/s, TG  8.16 t/s, AVG 114.67 t/s :

Input:    I want to learn how to play the piano.
Response: You may find some information on YouTube.

Client  3, seq   12, prompt  127 t, response   25 t, time  3.30 s, speed: PP 647.63 t/s, TG  8.06 t/s, AVG 46.10 t/s :

Input:    List all planets in the Solar System.
Response: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune.

Client  1, seq   16, prompt  124 t, response    8 t, time  1.15 s, speed: PP 735.70 t/s, TG  8.15 t/s, AVG 114.77 t/s :

Input:    What is the meaning of life?
Response: It's a meaningless question.
output.mp4

@KerfuffleV2
Copy link
Collaborator

I tested the parallel example with ROCM and it seems to work fine.

@ggerganov ggerganov merged commit 7e2b997 into custom-attention-mask Sep 19, 2023
@slaren slaren deleted the cam-cuda branch September 19, 2023 08:40
Copy link
Collaborator

@JohannesGaessler JohannesGaessler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On my systems:

Model GPU test t/s master t/s PR Speedup
LLaMA 7B mostly Q4_0 Tesla P40 pp512 878.43 877.12 1.00
LLaMA 7B mostly Q4_0 Tesla P40 tg128 59.42 58.13 0.98
LLaMA 7B mostly Q4_0 AMD Radeon RX 6800 pp512 1208.80 1247.10 1.03
LLaMA 7B mostly Q4_0 AMD Radeon RX 6800 tg128 77.88 74.88 0.96
LLaMA 7B mostly Q4_0 NVIDIA GeForce RTX 3090 pp512 2217.74 2210.58 1.00
LLaMA 7B mostly Q4_0 NVIDIA GeForce RTX 3090 tg128 141.74 135.34 0.95

Comment on lines +4372 to 4376
const int p = pos != nullptr ? pos[i2] : 0;
const float p0 = p * freq_scale;
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a conditional statement I think it would be faster to either pass zerod memory or to do the check via a template. In the latter case you could also simplify this code since p == 0 implies sin_theta == 0 and cos_theta == 1.

Comment on lines +6105 to +6113
int * pos = nullptr;
if ((mode & 1) == 0) {
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
pos = (int *) src1_extra->data_device[id];
if (!src1_extra->copied) {
CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
src1_extra->copied = true;
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current codebase I don't think there's much you can do to avoid this. The codebase currently covers constant data being copied to VRAM only before the eval directly from the model file. In all other cases the data is written to VRAM as the output of a tensor. You could of course just not offload src1 which would cause its data to be copied to VRAM automatically in ggml_cuda_op_flatten but that would induce more copies.

ggerganov added a commit that referenced this pull request Sep 28, 2023
…3228)

* tests : verify that RoPE is "additive"

* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)

* ggml : ggml_rope now takes a vector with positions instead of n_past

* metal : add rope_f16 kernel + optimize cpy kernels

* llama : unified KV cache + batch inference API

* llama : add new llama_decode() API that works with llama_batch

* llama : add cell_max heuristic for more efficient kv_cache

* llama : extend llama_kv_cache API

* llama : more robust cell_max heuristic + wip shift

* metal : disable concurrency optimization

* llama : add llama_kv_cache_shift_seq + no more context swaps

* llama : apply K-cache roping for Falcon and Baichuan

* speculative : fix KV cache management

* parallel : example for serving multiple users in parallel

* parallel : disable hot-plug to avoid cache fragmentation

* fixes : speculative KV cache + llama worst-case graph

* llama : extend batch API to select which logits to output

* llama : fix worst case graph build

* ggml-cuda : update rope implementation for parallel decoding (#3254)

* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* make : add parallel to build + fix static functions in llama.cpp

* simple : fix token counting

* parallel : various improvements

* llama : fix cell_max logic + rename functions

* parallel : try smaller batches when the KV cache is fragmented

* parallel : fix sequence termination criteria

* llama : silence errors KV cache errors

* parallel : remove new line from prompt

* parallel : process system prompt once + configurable paramters + llama API

* parallel : remove question with short answers

* parallel : count cache misses

* parallel : print misses on each request

* parallel : minor

* llama : fix n_kv to never become 0

* parallel : rename hot-plug to continuous-batching

* llama : improve llama_batch API + simplify parallel example

* simple : add parallel decoding support

* simple : improve comments + free batch

* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : disable MPI for now

ggml-ci

* train : make KQ_pos memory buffer permanent via dummy scale op

* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)

ggml-ci

* parallel : fix bug (extra BOS) + smaller token_prev array

* parallel : fix cases where the input prompts can overflow the batch

* parallel : add disabled experimental batch chunking in powers of two

* llama : llama.h formatting + comments

* simple : add README.md

* llama : fix kv cache heuristic when context is less than 32

* parallel : fix crash when `-n -1`

* llama : simplify returns if/else branches

* metal : use mm kernels for batch size > 2

* examples : utilize new llama_get_logits_ith()

* examples : add example for batched decoding

* examples : do not eval prompt 2 times (close #3348)

* server : clear the KV cache beyond n_past before llama_decode

* server : avoid context swaps by shifting the KV cache

---------

Co-authored-by: slaren <slarengh@gmail.com>
yusiwen pushed a commit to yusiwen/llama.cpp that referenced this pull request Oct 7, 2023
…gerganov#3228)

* tests : verify that RoPE is "additive"

* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)

* ggml : ggml_rope now takes a vector with positions instead of n_past

* metal : add rope_f16 kernel + optimize cpy kernels

* llama : unified KV cache + batch inference API

* llama : add new llama_decode() API that works with llama_batch

* llama : add cell_max heuristic for more efficient kv_cache

* llama : extend llama_kv_cache API

* llama : more robust cell_max heuristic + wip shift

* metal : disable concurrency optimization

* llama : add llama_kv_cache_shift_seq + no more context swaps

* llama : apply K-cache roping for Falcon and Baichuan

* speculative : fix KV cache management

* parallel : example for serving multiple users in parallel

* parallel : disable hot-plug to avoid cache fragmentation

* fixes : speculative KV cache + llama worst-case graph

* llama : extend batch API to select which logits to output

* llama : fix worst case graph build

* ggml-cuda : update rope implementation for parallel decoding (ggerganov#3254)

* ggml-cuda : update rope implementation for parallel decoding

* better solution for p0 computation

* fix rope

* simpler rope implementation

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* make : add parallel to build + fix static functions in llama.cpp

* simple : fix token counting

* parallel : various improvements

* llama : fix cell_max logic + rename functions

* parallel : try smaller batches when the KV cache is fragmented

* parallel : fix sequence termination criteria

* llama : silence errors KV cache errors

* parallel : remove new line from prompt

* parallel : process system prompt once + configurable paramters + llama API

* parallel : remove question with short answers

* parallel : count cache misses

* parallel : print misses on each request

* parallel : minor

* llama : fix n_kv to never become 0

* parallel : rename hot-plug to continuous-batching

* llama : improve llama_batch API + simplify parallel example

* simple : add parallel decoding support

* simple : improve comments + free batch

* ggml-cuda : add rope f16, restore performance with parallel decoding (ggerganov#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* llama : disable MPI for now

ggml-ci

* train : make KQ_pos memory buffer permanent via dummy scale op

* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (ggerganov#3275)

ggml-ci

* parallel : fix bug (extra BOS) + smaller token_prev array

* parallel : fix cases where the input prompts can overflow the batch

* parallel : add disabled experimental batch chunking in powers of two

* llama : llama.h formatting + comments

* simple : add README.md

* llama : fix kv cache heuristic when context is less than 32

* parallel : fix crash when `-n -1`

* llama : simplify returns if/else branches

* metal : use mm kernels for batch size > 2

* examples : utilize new llama_get_logits_ith()

* examples : add example for batched decoding

* examples : do not eval prompt 2 times (close ggerganov#3348)

* server : clear the KV cache beyond n_past before llama_decode

* server : avoid context swaps by shifting the KV cache

---------

Co-authored-by: slaren <slarengh@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants