diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b8bff24433ad6..5fdcde6da9f1d 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -757,8 +757,8 @@ struct llama_server_context
             result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
             slot.sent_count += result.text_to_send.size();
             // add the token to slot queue and cache
-            slot.addTokenString(result);
         }
+        slot.addTokenString(result);
 
         if (slot.multibyte_pending > 0) {
             slot.multibyte_pending -= token_str.size();
@@ -925,8 +925,8 @@ struct llama_server_context
         }
 
         // context shift takes effect only when there is a single slot
-        if(slots.size() == 1) {
-            llama_client_slot slot = slots[0];
+        if(params.n_parallel == 1) {
+            llama_client_slot &slot = slots[0];
             if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)n_ctx)
             {
                 // Shift context
@@ -1028,22 +1028,16 @@ struct llama_server_context
 
             slot.num_prompt_tokens = prompt_tokens.size();
 
-            slot.n_past = slot.params.cache_prompt ? common_part(slot.cache_tokens, prompt_tokens) : 0;
-
-            slot.cache_tokens = prompt_tokens;
-
-            if (slot.n_past == slot.num_prompt_tokens) {
-                // we have to evaluate at least 1 token to generate logits.
-                printf("we have to evaluate at least 1 token to generate logits\n");
-                slot.n_past--;
-            }
-
-            slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
-
-            if(!slot.params.cache_prompt) {
+            if(!slot.params.cache_prompt) {
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                slot.n_past = 0;
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             } else {
-                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
+                if (params.n_keep < 0 && params.n_parallel == 1)
+                {
+                    params.n_keep = (int)slot.num_prompt_tokens;
+                }
+                params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
 
                 //if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)n_ctx) {
@@ -1059,14 +1053,26 @@ struct llama_server_context
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
+                    slot.num_prompt_tokens = prompt_tokens.size();
                 }
 
                 const size_t ps = slot.num_prompt_tokens;
                 std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
                 std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
+                LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
             }
 
             llama_kv_cache_seq_rm(ctx, slot.id, num_tokens_system + slot.n_past, -1);
+            slot.cache_tokens = prompt_tokens;
+
+            if (slot.n_past == slot.num_prompt_tokens) {
+                // we have to evaluate at least 1 token to generate logits.
+                printf("we have to evaluate at least 1 token to generate logits\n");
+                slot.n_past--;
+            }
+
             LOG_VERBOSE("prompt ingested", {
                             {"n_past", slot.n_past},
                             {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
@@ -1185,7 +1191,7 @@ struct llama_server_context
             }
         }
 
-        if(kv_cache_free < 0) {
+        if(kv_cache_free < 0 && params.n_parallel > 1) {
             LOG_TEE("\nError: kv cache is full, increase context size.");
             return false;
         }
@@ -1581,6 +1587,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
     }
 }
 
+static void slot_print_timings(struct llama_client_slot * slot) {
+    LOG_TEE("\n");
+    LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+        __func__, slot->t_prompt_processing, slot->num_prompt_tokens_processed, slot->t_prompt_processing / slot->num_prompt_tokens_processed, 1e3 / slot->t_prompt_processing * slot->num_prompt_tokens_processed);
+    LOG_TEE("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+        __func__, slot->t_token_generation, slot->n_decoded, slot->t_token_generation / slot->n_decoded, 1e3 / slot->t_token_generation * slot->n_decoded);
+    LOG_TEE("%s:       total time = %10.2f ms\n", __func__, slot->t_prompt_processing + slot->t_token_generation);
+}
+
 static json format_generation_settings(llama_server_context &llama, llama_client_slot* slot)
 {
     const auto eos_bias = slot->sparams.logit_bias.find(llama_token_eos(llama.ctx));
@@ -1606,7 +1621,7 @@ static json format_generation_settings(llama_server_context &llama, llama_client
         {"penalize_nl", slot->sparams.penalize_nl},
         {"stop", slot->params.antiprompt},
         {"n_predict", slot->params.n_predict},
-        // {"n_keep", slot.params.n_keep},
+        {"n_keep", llama.params.n_keep},
         {"ignore_eos", ignore_eos},
         {"stream", slot->params.stream},
         {"logit_bias", slot->sparams.logit_bias},
@@ -1730,7 +1745,7 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
     slot->sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     slot->sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     slot->sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama.params.n_keep = json_value(body, "n_keep", -1);
+    llama.params.n_keep = json_value(body, "n_keep", 0);
     slot->params.seed = json_value(body, "seed", default_params.seed);
     slot->params.grammar = json_value(body, "grammar", default_params.grammar);
     slot->sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
@@ -2089,6 +2104,7 @@ int main(int argc, char **argv)
             }
 
             const json data = format_final_response(llama, slot, completion_text, probs);
+            slot_print_timings(slot);
             slot->release();
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
@@ -2131,6 +2147,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                 );
+                slot_print_timings(slot);
                 const std::string str =
                     "data: " +
                     data.dump(-1, ' ', false, json::error_handler_t::replace) +
@@ -2197,6 +2214,7 @@ int main(int argc, char **argv)
             }
 
             const json data = format_final_response(llama, slot, completion_text, probs);
+            slot_print_timings(slot);
             res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
                             "application/json");
         } else {
@@ -2238,6 +2256,7 @@ int main(int argc, char **argv)
                         slot->generated_token_probs.begin(),
                         slot->generated_token_probs.begin() + sent_token_probs_index)
                 );
+                slot_print_timings(slot);
                 const std::string str =
                     "data: " +
                     data.dump(-1, ' ', false, json::error_handler_t::replace) +