From 8d8b76d469763fca498d55e04c8b10a18a545c3b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 26 Nov 2023 11:26:55 +0200 Subject: [PATCH] lookahead : add comments --- examples/lookahead/lookahead.cpp | 79 +++++++++++++++++++++++++------- 1 file changed, 63 insertions(+), 16 deletions(-) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index ff17f06da146b..4c49a85ebcde7 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -6,7 +6,7 @@ #include #include -struct seq_ngram { +struct ngram_data { bool active = false; llama_seq_id seq_id = -1; @@ -16,11 +16,12 @@ struct seq_ngram { std::vector tokens; }; +// n-gram container struct ngram_container { ngram_container(int n_vocab, int N, int G) { cnt.resize(n_vocab); head.resize(n_vocab); - tokens.resize(n_vocab * (N - 1)*G); + tokens.resize(n_vocab * G * (N - 1)); } int n_total = 0; @@ -28,6 +29,8 @@ struct ngram_container { std::vector cnt; std::vector head; + // [n_vocab][G][N - 1] + // for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1 std::vector tokens; }; @@ -109,6 +112,7 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + // for each decoded batch, we have at most W + G + 1 distinct sequences: // seq_id == 0 : the current input token // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [W + 1, W + G] : verification n-grams @@ -118,7 +122,7 @@ int main(int argc, char ** argv) { struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); // verification n-grams - std::vector ngrams_cur(G); + std::vector ngrams_cur(G); // tokens for the past N - 1 Jacobi iterations std::vector tokens_j_prev(W); @@ -127,21 +131,26 @@ int main(int argc, char ** argv) { tokens_j[j].resize(W); for (int i = 0; i < W; i++) { - // initialize randomly from the prompt tokens - tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; - - // initialize with a sequence of increasing numbers - tokens_j[j][i] = 100 + i; + // there are different ways to init these tokens + if (0) { + // initialize randomly from the prompt tokens + tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; + } else { + // initialize with a sequence of increasing numbers + tokens_j[j][i] = 100 + i; + } } } std::vector seq_id_look; + // the input token belongs both to all sequences std::vector seq_id_all(W + G + 1); for (int i = 0; i < W + G + 1; i++) { seq_id_all[i] = i; } + // here we keep adding new n-grams as we go ngram_container ngrams_observed(llama_n_vocab(model), N, G); // debug @@ -171,13 +180,37 @@ int main(int argc, char ** argv) { } // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/ + // + // Example for W = 5, N = 4, G = 2: + // (I = input, L = lookahead, V = verification) + // + // Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 + // T: -2 -2 -2 -2 -1 -1 -1 -1 -1 0 0 0 0 0 0 + // Info: I L L L L L L L L L L L L L L V V V V V V + // Pos: 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 1 2 3 1 2 3 (+ n_past) + // Logits: 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + // --------------------------------------------------------------------- + // Seq: 0 + // 1 1 1 + // 2 2 2 2 + // 3 3 3 3 3 + // 4 4 4 4 4 4 + // 5 5 5 5 5 5 5 + // 6 6 6 6 + // 7 7 7 7 + // --------------------------------------------------------------------- + // | | | | | | | | | | | + // V V V V V | | | | | | + // j_tokens | | | | | | + // V V V V V V + // id { llama_batch_clear(batch); // current token - first token of the first level llama_batch_add(batch, id, n_past, seq_id_all, true); - // verification n-grams - queue this here for less KV cache fragmentation + // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation { const int g_cur = ngrams_observed.cnt[id]; @@ -233,6 +266,7 @@ int main(int argc, char ** argv) { for (int v = 0; v < N; ++v) { int i_batch = 0; + // if no active ngrams are left, it means the sampled token does not pass the verification if (v > 0) { for (int g = 0; g < (int) ngrams_cur.size(); g++) { if (ngrams_cur[g].active) { @@ -244,16 +278,18 @@ int main(int argc, char ** argv) { } } - // no more matches + // no more matches -> create a new batch if (i_batch == 0) { break; } } + // sample the next token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); llama_sampling_accept(ctx_sampling, ctx, id, true); + // print { const std::string token_str = llama_token_to_piece(ctx, id); @@ -313,7 +349,7 @@ int main(int argc, char ** argv) { } } - // update Jacobi tokens (or whatever these are called) + // update lookahead tokens { for (int i = 0; i < W; i++) { tokens_j_prev[i] = tokens_j[0][i]; @@ -330,11 +366,14 @@ int main(int argc, char ** argv) { } } else { for (int i = 0; i < W; i++) { - // random init - //tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; - - // init from the previous level - tokens_j[N - 2][i] = tokens_j[0][i]; + // there are different ways to init these tokens + if (0) { + // random init + tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; + } else { + // init from the previous level + tokens_j[N - 2][i] = tokens_j[0][i]; + } } } } @@ -398,9 +437,13 @@ int main(int argc, char ** argv) { break; } + // KV cache management + // if no verification token matched, we simply remove all cells from this batch -> no fragmentation llama_kv_cache_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { + // if a verification token matched, we keep the best sequence and remove the rest + // this leads to some KV cache fragmentation llama_kv_cache_seq_keep(ctx, seq_id_best); llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); @@ -418,6 +461,10 @@ int main(int argc, char ** argv) { LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + LOG_TEE("\n"); + LOG_TEE("W = %2d\n", W); + LOG_TEE("N = %2d\n", N); + LOG_TEE("G = %2d\n", G); LOG_TEE("\n"); LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept);