From d2b96b5593caef3cece5acce4b08c631aa05e285 Mon Sep 17 00:00:00 2001 From: beiller Date: Sat, 11 Mar 2023 14:23:33 -0500 Subject: [PATCH 1/6] Adding repeat penalization --- main.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/main.cpp b/main.cpp index 2f47480698f1e..f02b5ddbde94d 100644 --- a/main.cpp +++ b/main.cpp @@ -792,7 +792,7 @@ int main(int argc, char ** argv) { printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str()); } printf("\n"); - printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p); + printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); printf("\n\n"); std::vector embd; @@ -801,6 +801,10 @@ int main(int argc, char ** argv) { size_t mem_per_token = 0; llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); + int last_n_size = params.repeat_last_n; + std::vector last_n_tokens(last_n_size); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { // predict if (embd.size() > 0) { @@ -821,6 +825,7 @@ int main(int argc, char ** argv) { // sample next token const float top_p = params.top_p; const float temp = params.temp; + const float repeat_penalty = params.repeat_penalty; const int n_vocab = model.hparams.n_vocab; @@ -829,7 +834,10 @@ int main(int argc, char ** argv) { { const int64_t t_start_sample_us = ggml_time_us(); - id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng); + id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); t_sample_us += ggml_time_us() - t_start_sample_us; } @@ -840,6 +848,8 @@ int main(int argc, char ** argv) { // if here, it means we are still processing the input prompt for (int k = i; k < embd_inp.size(); k++) { embd.push_back(embd_inp[k]); + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(embd_inp[k]); if (embd.size() > params.n_batch) { break; } From 3f6a118d6adb583f350ac9c3ee670259bdda3e00 Mon Sep 17 00:00:00 2001 From: beiller Date: Sat, 11 Mar 2023 14:24:12 -0500 Subject: [PATCH 2/6] Update utils.h --- utils.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/utils.h b/utils.h index bbe8fe823d01e..e331904baa33b 100644 --- a/utils.h +++ b/utils.h @@ -16,11 +16,13 @@ struct gpt_params { int32_t seed = -1; // RNG seed int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t n_predict = 128; // new tokens to predict + int32_t repeat_last_n = 64; // last n tokens to penalize // sampling parameters int32_t top_k = 40; // unused float top_p = 0.95f; float temp = 0.80f; + float repeat_penalty = 1.30f; int32_t n_batch = 8; // batch size for prompt processing @@ -89,6 +91,8 @@ gpt_vocab::id gpt_sample_top_k_top_p( gpt_vocab::id llama_sample_top_p( const gpt_vocab & vocab, const float * logits, + std::vector & last_n_tokens, + double repeat_penalty, double top_p, double temp, std::mt19937 & rng); From 78651d5792616e7df5a68e0397d131da5b138ef6 Mon Sep 17 00:00:00 2001 From: beiller Date: Sat, 11 Mar 2023 14:24:32 -0500 Subject: [PATCH 3/6] Update utils.cpp --- utils.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/utils.cpp b/utils.cpp index abb34756ac026..59fd05a342dbb 100644 --- a/utils.cpp +++ b/utils.cpp @@ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.top_p = std::stof(argv[++i]); } else if (arg == "--temp") { params.temp = std::stof(argv[++i]); + } else if (arg == "--repeat_last_n") { + params.repeat_last_n = std::stoi(argv[++i]); + } else if (arg == "--repeat_penalty") { + params.repeat_penalty = std::stof(argv[++i]); } else if (arg == "-b" || arg == "--batch_size") { params.n_batch = std::stoi(argv[++i]); } else if (arg == "-m" || arg == "--model") { @@ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); + fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); + fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " -m FNAME, --model FNAME\n"); @@ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p( gpt_vocab::id llama_sample_top_p( const gpt_vocab & vocab, const float * logits, + std::vector & last_n_tokens, + double repeat_penalty, double top_p, double temp, std::mt19937 & rng) { @@ -383,7 +391,11 @@ gpt_vocab::id llama_sample_top_p( { const double scale = 1.0/temp; for (int i = 0; i < n_logits; ++i) { - logits_id.push_back(std::make_pair(logits[i]*scale, i)); + if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) { + logits_id.push_back(std::make_pair(logits[i]*(1/repeat_penalty), i)); + } else { + logits_id.push_back(std::make_pair(logits[i]*scale, i)); + } } } From c90e78edc3b87c9d3d91f3f1ef817561bac2a02f Mon Sep 17 00:00:00 2001 From: beiller Date: Sat, 11 Mar 2023 14:55:57 -0500 Subject: [PATCH 4/6] Numeric fix Should probably still scale by temp even if penalized --- utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils.cpp b/utils.cpp index 59fd05a342dbb..eceab4591dfc0 100644 --- a/utils.cpp +++ b/utils.cpp @@ -392,7 +392,7 @@ gpt_vocab::id llama_sample_top_p( const double scale = 1.0/temp; for (int i = 0; i < n_logits; ++i) { if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) { - logits_id.push_back(std::make_pair(logits[i]*(1/repeat_penalty), i)); + logits_id.push_back(std::make_pair(logits[i]*scale*(1/repeat_penalty), i)); } else { logits_id.push_back(std::make_pair(logits[i]*scale, i)); } From 340bff0f0e8d6140a587949bac86ee7b0026507c Mon Sep 17 00:00:00 2001 From: beiller Date: Sat, 11 Mar 2023 21:51:03 -0500 Subject: [PATCH 5/6] Update comments, more proper application I see that numbers can go negative so a fix from a referenced commit --- utils.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/utils.cpp b/utils.cpp index eceab4591dfc0..b6c9493d1c484 100644 --- a/utils.cpp +++ b/utils.cpp @@ -391,8 +391,16 @@ gpt_vocab::id llama_sample_top_p( { const double scale = 1.0/temp; for (int i = 0; i < n_logits; ++i) { + // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) { - logits_id.push_back(std::make_pair(logits[i]*scale*(1/repeat_penalty), i)); + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if(logits[i] < 0.0) { + logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i)); + } else { + logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i)); + } + } else { logits_id.push_back(std::make_pair(logits[i]*scale, i)); } From ebb357f7118f3e1f044976d866d94f9fe891aabe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 12 Mar 2023 11:26:48 +0200 Subject: [PATCH 6/6] Minor formatting --- utils.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/utils.cpp b/utils.cpp index b6c9493d1c484..49023bd7b8626 100644 --- a/utils.cpp +++ b/utils.cpp @@ -393,14 +393,13 @@ gpt_vocab::id llama_sample_top_p( for (int i = 0; i < n_logits; ++i) { // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main - if ( std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end() ) { + if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if(logits[i] < 0.0) { + if (logits[i] < 0.0) { logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i)); } else { logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i)); } - } else { logits_id.push_back(std::make_pair(logits[i]*scale, i)); }