From ad19812cda4062c9f154ef16315df41fbe6a770a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 18 Jan 2024 15:33:01 +0200
Subject: [PATCH] perplexity : faster HellaSwag via batching (#5017)

* perplexity : faster HellaSwag

ggml-ci

* perplexity : clean-up

ggml-ci

* perplexity : no need for decode_helper

ggml-ci

* perplexity : add comments

* perplexity : option to specify max batched tasks via `n_parallel`

* perplexity : remove HellaSwag restriction for n_batch
---
 examples/perplexity/perplexity.cpp | 259 ++++++++++++++++-------------
 1 file changed, 148 insertions(+), 111 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 57eaa713ed616..ea2c8026cfcec 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -470,7 +470,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         prompt_lines.push_back(line);
     }
 
-    if( prompt_lines.size() % 6 != 0) {
+    if (prompt_lines.size() % 6 != 0) {
         fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
         return;
     }
@@ -485,7 +485,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
     // Number of tasks to use when computing the score
-    if ( params.hellaswag_tasks < hs_task_count ) {
+    if (params.hellaswag_tasks < hs_task_count) {
         hs_task_count = params.hellaswag_tasks;
     }
 
@@ -502,27 +502,54 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         std::string ending[4];
         size_t ending_logprob_count[4];
         double ending_logprob[4];
+
+        size_t i_batch;         // starting index in the llama_batch
+        size_t common_prefix;   // max number of initial tokens that are the same in all sentences
+        size_t required_tokens; // needed number of tokens to evaluate all 4 endings
+
+        std::vector<llama_token> seq_tokens[4];
     };
 
     fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
 
     // Select and read data from prompt lines
-    hs_data_t *hs_data = new hs_data_t[hs_task_count];
-    for (size_t i=0; i < hs_task_count; i++) {
+    std::vector<hs_data_t> hs_data(hs_task_count);
+    for (size_t i = 0; i < hs_task_count; i++) {
         size_t idx = i;
 
+        auto & hs_cur = hs_data[i];
+
         // Select a random example of those left in the prompt
         if (randomize_tasks) {
             std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
             idx = dist(rng);
         }
 
-        hs_data[i].context = prompt_lines[idx*6];
-        hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
-        for (size_t j=0; j < 4; j++) {
-            hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
+        hs_cur.context = prompt_lines[idx*6];
+        hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
+        for (size_t j = 0; j < 4; j++) {
+            hs_cur.ending[j] = prompt_lines[idx*6+2+j];
+            hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
         }
 
+        // determine the common prefix of the endings
+        hs_cur.common_prefix   = 0;
+        hs_cur.required_tokens = 0;
+        for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
+            if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
+                hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
+                break;
+            }
+            hs_cur.common_prefix++;
+        }
+        hs_cur.required_tokens = hs_cur.common_prefix +
+            hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
+            hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
+
+        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
+
         // Delete the selected random example from the prompt
         if (randomize_tasks) {
             prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
@@ -530,150 +557,160 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     }
 
     fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
 
+    printf("\ntask\tacc_norm\n");
 
     double acc = 0.0f;
+
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
 
-    std::vector<std::vector<int>> ending_tokens(4);
+    const int max_tasks_per_batch = params.n_parallel;
+    const int max_seq = 4*max_tasks_per_batch;
 
-    std::vector<float> tok_logits(n_vocab);
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
-    for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
-        // Tokenize the context to count tokens
-        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
-        size_t context_size = context_embd.size();
-
-        for (int i = 0; i < 4; ++i) {
-            ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
-            for (int k = 0; k < int(context_size); ++k) {
-                if (ending_tokens[i][k] != context_embd[k]) {
-                    fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
-                    break;
-                }
+    std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_ctx*n_vocab);
+
+    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            llama_batch batch_view = {
+                n_tokens,
+                batch.token    + i,
+                nullptr,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
+                0, 0, 0, // unused
+            };
+
+            const int ret = llama_decode(ctx, batch_view);
+            if (ret != 0) {
+                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                return false;
             }
-        }
 
-        // Do the 1st ending
-        // In this case we include the context when evaluating
-        //auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
-        auto query_embd = ending_tokens[0];
-        auto query_size = query_embd.size();
-
-        // Stop if query wont fit the ctx window
-        if (query_size > (size_t)n_ctx) {
-            fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
-            return;
+            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
         }
 
-        // Speedup small evaluations by evaluating atleast 32 tokens
-        if (query_size < 32) {
-            query_embd.resize(32);
-        }
+        return true;
+    };
 
-        // clear the KV cache
-        llama_kv_cache_clear(ctx);
+    for (size_t i0 = 0; i0 < hs_task_count; i0++) {
+        int n_cur = 0;
 
-        auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
-        if (logits.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
-        }
+        size_t i1 = i0;
+        size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+
+        llama_batch_clear(batch);
 
-        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
-        const auto first_probs = softmax(tok_logits);
+        // batch as many tasks as possible into the available context
+        // each task has 4 unique sequence ids - one for each ending
+        // the common prefix is shared among the 4 sequences to save tokens
+        // we extract logits only from the last common token and from all ending tokens of each sequence
+        while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
+            auto & hs_cur = hs_data[i1];
 
-        hs_data[task_idx].ending_logprob_count[0] = 1;
-        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+            const int s0 = 4*(i1 - i0);
+            if (s0 + 4 > max_seq) {
+                break;
+            }
 
-        // Calculate the logprobs over the ending
-        for (size_t j = context_size; j < query_size - 1; j++) {
+            for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
+                llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
 
-            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
 
-            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+            hs_cur.i_batch = i_batch;
+            i_batch += hs_cur.required_tokens;
 
-            hs_data[task_idx].ending_logprob[0] += std::log(prob);
-            hs_data[task_idx].ending_logprob_count[0]++;
+            n_cur += hs_data[i1].required_tokens;
+            if (++i1 == hs_task_count) {
+                break;
+            }
         }
 
-        // Calculate the mean token logprob for acc_norm
-        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }
 
-        // Do the remaining endings
-        // For these, we use the bare ending with n_past = context_size
-        //
-        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
+        llama_kv_cache_clear(ctx);
 
-            // Tokenize the query
-            query_embd.resize(ending_tokens[ending_idx].size() - context_size);
-            std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
-            query_size = query_embd.size();
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, n_batch)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
+            return;
+        }
 
-            // Stop if query wont fit the ctx window
-            if (context_size + query_size > (size_t)n_ctx) {
-                fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
-                return;
-            }
+        // compute the logprobs for each ending of the decoded tasks
+        for (size_t i = i0; i < i1; ++i) {
+            auto & hs_cur = hs_data[i];
 
-            // Speedup small evaluations by evaluating atleast 32 tokens
-            // No, resizing to 32 is actually slightly slower (at least on CUDA)
-            //if (query_size < 32) {
-            //    query_embd.resize(32);
-            //}
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
 
-            // Evaluate the query
-            logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
-            if (logits.empty()) {
-                fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
-            }
+            const auto first_probs = softmax(tok_logits);
 
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
-            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
+            size_t li = hs_cur.common_prefix; // logits index in the batch
 
-            // Calculate the logprobs over the ending
-            for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
+            for (int s = 0; s < 4; ++s) {
+                hs_cur.ending_logprob_count[s] = 1;
+                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
 
-                const float prob = softmax(tok_logits)[query_embd[j + 1]];
+                // Calculate the logprobs over the ending
+                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
+                    std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
 
-                hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
-                hs_data[task_idx].ending_logprob_count[ending_idx]++;
-            }
+                    const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
 
-            // Calculate the mean token logprob for acc_norm
-            hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
+                    hs_cur.ending_logprob[s] += std::log(prob);
+                    hs_cur.ending_logprob_count[s]++;
+                }
 
+                // account that we skip the last token in the ending
+                ++li;
 
-//            printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
-//                task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
-        }
+                // Calculate the mean token logprob for acc_norm
+                hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
+            }
 
-        // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = 0;
-        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
-        for (size_t j = 1; j < 4; j++) {
-            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
-                ending_logprob_max_idx = j;
-                ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
+            // Find the ending with maximum logprob
+            size_t ending_logprob_max_idx = 0;
+            double ending_logprob_max_val = hs_cur.ending_logprob[0];
+            for (size_t s = 1; s < 4; s++) {
+                if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
+                    ending_logprob_max_idx = s;
+                    ending_logprob_max_val = hs_cur.ending_logprob[s];
+                }
             }
-        }
 
-//        printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
+            //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
+
+            // If the gold ending got the maximum logprob add one accuracy point
+            if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
+                acc += 1.0;
+            }
 
-        // If the gold ending got the maximum logprobe add one accuracy point
-        if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
-            acc += 1.0;
+            // Print the accumulated accuracy mean x 100
+            printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
+            fflush(stdout);
         }
 
-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
-        fflush(stdout);
+        i0 = i1 - 1;
     }
 
-    delete [] hs_data;
+    llama_batch_free(batch);
 
     printf("\n");
 }
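
Note (illustration, not part of the patch): the speed-up comes from submitting each task's shared context prefix once under four sequence ids and decoding only the four unique ending tails, which is exactly what hs_cur.common_prefix and hs_cur.required_tokens account for above. The standalone C++ sketch below mirrors that bookkeeping with hypothetical token ids; the helper name common_prefix_len and the example numbers are illustrative and are not part of the llama.cpp API.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// length of the longest prefix shared by all four token sequences
static size_t common_prefix_len(const std::vector<int> (&seq)[4]) {
    size_t limit = seq[0].size();
    for (int s = 1; s < 4; ++s) {
        limit = std::min(limit, seq[s].size());
    }
    size_t n = 0;
    while (n < limit &&
           seq[0][n] == seq[1][n] &&
           seq[0][n] == seq[2][n] &&
           seq[0][n] == seq[3][n]) {
        ++n;
    }
    return n;
}

int main() {
    // hypothetical token ids: a 5-token shared context followed by four different endings
    const std::vector<int> seq[4] = {
        {1, 10, 11, 12, 13, 100, 101},
        {1, 10, 11, 12, 13, 200, 201, 202},
        {1, 10, 11, 12, 13, 300},
        {1, 10, 11, 12, 13, 400, 401},
    };

    const size_t prefix = common_prefix_len(seq);

    size_t batched  = prefix; // prefix decoded once, shared across 4 sequence ids
    size_t separate = 0;      // old scheme: each full sequence decoded on its own
    for (int s = 0; s < 4; ++s) {
        batched  += seq[s].size() - prefix;
        separate += seq[s].size();
    }

    printf("common prefix  : %zu tokens\n", prefix);   // 5
    printf("batched decode : %zu tokens\n", batched);  // 5 + (2 + 3 + 1 + 2) = 13
    printf("separate decode: %zu tokens\n", separate); // 7 + 8 + 6 + 7 = 28
    return 0;
}

Built as a plain C++ program this prints a 5-token common prefix and 13 tokens decoded with the batched scheme versus 28 when the four full sequences are evaluated separately - the same arithmetic the patch performs when filling hs_cur.required_tokens before packing tasks into the llama_batch.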