From e5bfdcc60ed72df4e994ab3cb0cba2558970f50d Mon Sep 17 00:00:00 2001
From: kuvaus <22169537+kuvaus@users.noreply.github.com>
Date: Thu, 11 May 2023 20:14:32 +0300
Subject: [PATCH] Add temperature sampling with repetition penalty

Compatible with new llama.cpp
---
 gpt4all-backend/llamamodel.cpp | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 272633c7d404..ad7d8970e651 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -170,10 +170,28 @@ void LLamaModel::prompt(const std::string &prompt,
     int32_t totalPredictions = 0;
     for (int i = 0; i < promptCtx.n_predict; i++) {
         // sample next token
-        llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
-            promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
-            promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
+        float* logits = llama_get_logits(d_ptr->ctx);
+        std::vector<llama_token_data> candidates;
+        candidates.resize(llama_n_vocab(d_ptr->ctx));
+
+        for (llama_token i = 0; i < candidates.size(); i++) {
+            candidates[i] = llama_token_data{
+                i, logits[i], 0.0f,
+            };
+        }
+        llama_token_data_array candidates_data = {
+            candidates.data(), candidates.size(), false,
+        };
+
+        // Temperature sampling with repetition penalty
+        llama_sample_repetition_penalty(
+            d_ptr->ctx, &candidates_data,
+            promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n, promptCtx.repeat_last_n,
             promptCtx.repeat_penalty);
+        llama_sample_top_k(d_ptr->ctx, &candidates_data, promptCtx.top_k, 1);
+        llama_sample_top_p(d_ptr->ctx, &candidates_data, promptCtx.top_p, 1);
+        llama_sample_temperature(d_ptr->ctx, &candidates_data, promptCtx.temp);
+        llama_token id = llama_sample_token(d_ptr->ctx, &candidates_data);

         // Check if the context has run out...
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
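
As a usage reference, the chain above can be lifted into a standalone helper.
This is a minimal sketch built only from the llama.cpp calls already used in
the hunk (llama_get_logits, llama_n_vocab, llama_sample_repetition_penalty,
llama_sample_top_k, llama_sample_top_p, llama_sample_temperature,
llama_sample_token); the helper name sample_next_token and its default
parameter values are illustrative, not part of the patch:

    // Illustrative only: sample_next_token is a hypothetical helper,
    // not part of this patch or of llama.cpp.
    #include <vector>
    #include "llama.h"

    // Build a candidate list from the current logits, penalize recently
    // seen tokens, truncate with top-k / top-p, apply temperature, and
    // draw one token -- the same pipeline as the hunk above.
    static llama_token sample_next_token(
            llama_context *ctx,
            const std::vector<llama_token> &last_tokens, // recent history
            int   top_k          = 40,    // illustrative defaults
            float top_p          = 0.95f,
            float temp           = 0.8f,
            float repeat_penalty = 1.1f)
    {
        float *logits = llama_get_logits(ctx);

        // One llama_token_data per vocabulary entry: {id, logit, p}.
        const int n_vocab = llama_n_vocab(ctx);
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; id++) {
            candidates.push_back(llama_token_data{ id, logits[id], 0.0f });
        }
        llama_token_data_array candidates_data = {
            candidates.data(), candidates.size(), false, // not yet sorted
        };

        // Same order as the patch: penalty, then truncation, then temperature.
        llama_sample_repetition_penalty(ctx, &candidates_data,
            last_tokens.data(), last_tokens.size(), repeat_penalty);
        llama_sample_top_k(ctx, &candidates_data, top_k, 1);
        llama_sample_top_p(ctx, &candidates_data, top_p, 1);
        llama_sample_temperature(ctx, &candidates_data, temp);

        return llama_sample_token(ctx, &candidates_data);
    }

The ordering mirrors the patch: the repetition penalty first rescales the
logits of tokens that appear in the recent history, top-k and top-p then
truncate the candidate list (min_keep = 1 keeps at least one candidate
alive), and temperature scaling happens last, just before the draw.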