diff --git a/common/arg.cpp b/common/arg.cpp index ecc296485cb47..98c452e78e905 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1935,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.n_ctx_checkpoints = value; } ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--cache-ram", "-cram"}, "N", + string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib), + [](common_params & params, int value) { + params.cache_ram_mib = value; + } + ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--kv-unified", "-kvu"}, string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" diff --git a/common/chat.h b/common/chat.h index a1afe574bd0ca..f7b36ec711df4 100644 --- a/common/chat.h +++ b/common/chat.h @@ -33,8 +33,8 @@ struct common_chat_msg_content_part { struct common_chat_msg { std::string role; std::string content; - std::vector content_parts = {}; - std::vector tool_calls = {}; + std::vector content_parts; + std::vector tool_calls; std::string reasoning_content; std::string tool_name; std::string tool_call_id; @@ -44,7 +44,7 @@ struct common_chat_msg { bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); } - void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { + void set_tool_call_ids(std::vector & ids_cache, const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { auto id = tool_calls[i].id; diff --git a/common/common.h b/common/common.h index 8a8ecd667f2cc..03d095198b516 100644 --- a/common/common.h +++ b/common/common.h @@ -378,7 +378,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool no_perf = false; // disable performance metrics - bool ctx_shift = false; // context shift on infinite text generation + bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache @@ -425,7 +425,8 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot + int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot + int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc. 
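// ---------------------------------------------------------------------------
// Reviewer sketch, not part of the patch: how the new `cache_ram_mib` value
// (default 8192 MiB) is interpreted, mirroring the `--cache-ram` help text
// above and the `server_prompt_cache` constructor added later in this diff.
// The helper name below is illustrative only.
//   cache_ram_mib  < 0  -> no size limit (stored internally as a 0-byte limit)
//   cache_ram_mib == 0  -> the prompt cache is not created at all
//   cache_ram_mib  > 0  -> limit of N MiB, converted to bytes
// (a separate token limit, set to n_ctx when the cache is constructed, is
// applied on top of the byte limit)
#include <cstddef>
#include <cstdint>

static size_t cache_limit_bytes(int32_t cache_ram_mib) {
    // negative means "no limit"; the cache encodes that as a 0-byte limit
    return 1024ull * 1024ull * static_cast<size_t>(cache_ram_mib < 0 ? 0 : cache_ram_mib);
}
// ---------------------------------------------------------------------------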
std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 816f2d5de592b..736693e174527 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index de6e1a322b2c2..41ecb279feb89 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -9,7 +9,6 @@ #include "sampling.h" #include "speculative.h" #include "mtmd.h" -#include "mtmd-helper.h" // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" @@ -158,7 +157,6 @@ struct slot_params { if (only_metrics) { return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -181,7 +179,8 @@ struct slot_params { {"mirostat", sampling.mirostat}, {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? {"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -209,7 +208,6 @@ struct slot_params { } return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -234,7 +232,8 @@ struct slot_params { {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, {"stop", antiprompt}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? 
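// ---------------------------------------------------------------------------
// Reviewer sketch, not part of the patch: after this change both views of
// slot_params::to_json() report the user-configured value under two keys,
// "max_tokens" and "n_predict", instead of the old split between the
// server-configured and user-configured n_predict. Presumably both keys are
// kept for client compatibility until the "TODO: deduplicate?" is resolved.
#include <nlohmann/json.hpp>

static nlohmann::json n_predict_fragment(int n_predict) {
    return nlohmann::json {
        {"max_tokens", n_predict},
        {"n_predict",  n_predict}  // same value, intentionally duplicated for now
    };
}
// ---------------------------------------------------------------------------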
{"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -265,15 +264,15 @@ struct server_task { int id = -1; // to be filled by server_queue int index = -1; // used when there are multiple prompts (batch request) - server_task_type type; - // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; + int id_slot = -1; // used by SERVER_TASK_TYPE_INFERENCE slot_params params; - server_tokens prompt_tokens; - int id_selected_slot = -1; + server_tokens tokens; + + server_task_type type; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE struct slot_action { @@ -289,6 +288,8 @@ struct server_task { // used by SERVER_TASK_TYPE_SET_LORA std::vector set_lora; + server_task() = default; + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( @@ -305,6 +306,7 @@ struct server_task { defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server @@ -323,32 +325,32 @@ struct server_task { params.n_discard = json_value(data, "n_discard", defaults.n_discard); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", 
defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); @@ -690,7 +692,7 @@ struct server_task_result { // using shared_ptr for polymorphism of server_task_result using server_task_result_ptr = std::unique_ptr; -inline std::string stop_type_to_str(stop_type type) { +static inline std::string stop_type_to_str(stop_type type) { switch (type) { case STOP_TYPE_EOS: return "eos"; case STOP_TYPE_WORD: return "word"; @@ -764,13 +766,6 @@ struct completion_token_output { } }; -struct ctx_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; -}; - struct server_task_result_cmpl_final : server_task_result { int index = 0; @@ -797,11 
+792,12 @@ struct server_task_result_cmpl_final : server_task_result { slot_params generation_params; // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; virtual int get_index() override { @@ -1373,17 +1369,17 @@ struct server_task_result_slot_save_load : server_task_result { { "save_ms", t_ms } }}, }; - } else { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; } + + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; } }; @@ -1404,15 +1400,218 @@ struct server_task_result_apply_lora : server_task_result { } }; +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto & checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.size(); + } + + return res; + } + + size_t n_tokens() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.n_tokens(); + } + + return res; + } + + server_prompt * alloc(const server_prompt & prompt, size_t state_size) { + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); + + if (cur_lcp_len == (int) prompt.tokens.size()) { + SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + const int len = it->tokens.get_common_prefix(prompt.tokens); + + if (len == (int) it->tokens.size()) { + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } catch (const std::bad_alloc & e) { + SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4*size()); + + SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = 
states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; + } + + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { + const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + float sim_best = float(lcp_best) / tokens_new.size(); + + SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + const float sim_cur = float(lcp_cur) / tokens_new.size(); + + // don't trash large prompts + if (f_keep_cur < 0.25f) { + continue; + } + + if (f_keep_best < f_keep_cur && sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; + } + + void update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + if (limit_tokens > 0) { + while (states.size() > 1 && n_tokens() > limit_tokens) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } + } +}; + struct server_slot { int id; - int id_task = -1; - - // only used for completion/embedding/infill/rerank - server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; llama_batch batch_spec = {}; + // TODO: change to unique_ptrs for consistency: llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1421,15 +1620,8 @@ struct server_slot { common_speculative * spec = nullptr; - std::vector lora; - int32_t alora_invocation_start = -1; - - // the index relative to completion multi-task request - size_t index = 0; - - struct slot_params params; - - slot_state state = SLOT_STATE_IDLE; + std::unique_ptr task; + std::unique_ptr task_prev; // used for debugging // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -1437,38 +1629,66 @@ struct 
server_slot { // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; + int32_t n_keep = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated - int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; - // input prompt tokens - server_tokens prompt_tokens; + int32_t n_prompt_tokens() const { + return task->tokens.size(); + } size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; - common_chat_msg chat_msg; - server_tokens cache_tokens; + common_chat_msg chat_msg; std::vector generated_token_probs; - std::vector ctx_checkpoints; - bool has_next_token = true; bool has_new_line = false; bool truncated = false; + stop_type stop; std::string stopping_word; + // state + slot_state state = SLOT_STATE_IDLE; + + server_prompt prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + assert(prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + } + + void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + bool res = prompt_cache.load(prompt, tokens, ctx, id); + if (!res) { + SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); + } + } + + std::vector lora; + int32_t alora_invocation_start = -1; + // sampling json json_schema; @@ -1480,7 +1700,7 @@ struct server_slot { std::vector generated_tool_call_ids; // stats - size_t n_sent_text = 0; // number of sent text character + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; @@ -1497,19 +1717,17 @@ struct server_slot { void reset() { SLT_DBG(*this, "%s", "\n"); - n_prompt_tokens = 0; n_prompt_tokens_cache = 0; - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - task_type = SERVER_TASK_TYPE_COMPLETION; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; generated_tokens.clear(); generated_token_probs.clear(); @@ -1521,16 +1739,23 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; + task.reset(); + task_prev.reset(); + // clear alora start alora_invocation_start = -1; } bool need_embd() const { - return server_task_type_need_embd(task_type); + GGML_ASSERT(task); + + return server_task_type_need_embd(task->type); } bool need_logits() const { - return server_task_type_need_logits(task_type); + GGML_ASSERT(task); + + return server_task_type_need_logits(task->type); } // if the context does not have a memory module then all embeddings have to be computed within a single ubatch @@ -1542,18 +1767,22 @@ struct server_slot { } bool can_batch_with(server_slot & other_slot) const { - return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); + 
GGML_ASSERT(task); + + return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); } bool has_budget(const common_params & global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) { + GGML_ASSERT(task); + + if (task->params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (params.n_predict != -1) { - n_remaining = params.n_predict - n_decoded; + if (task->params.n_predict != -1) { + n_remaining = task->params.n_predict - n_decoded; } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } @@ -1566,7 +1795,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + return ctx_dft; } void add_token(const completion_token_output & token) { @@ -1579,11 +1808,17 @@ struct server_slot { void release() { if (is_processing()) { + GGML_ASSERT(task); + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); t_last_used = ggml_time_us(); t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; state = SLOT_STATE_IDLE; + + task_prev = std::move(task); + task.reset(); + callback_on_release(id); } } @@ -1592,19 +1827,19 @@ struct server_slot { result_timings timings; timings.cache_n = n_prompt_tokens_cache; - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; // Add speculative metrics if (n_draft_total > 0) { - timings.draft_n = n_draft_total; + timings.draft_n = n_draft_total; timings.draft_n_accepted = n_draft_accepted; } @@ -1612,14 +1847,16 @@ struct server_slot { } const common_chat_msg & update_chat_msg(std::vector & diffs) { + GGML_ASSERT(task); + auto previous_msg = chat_msg; SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); auto new_msg = common_chat_parse( generated_text, /* is_partial= */ stop != STOP_TYPE_EOS, - params.oaicompat_chat_syntax); + task->params.oaicompat_chat_syntax); if (!new_msg.empty()) { - new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? 
previous_msg : new_msg); } @@ -1627,9 +1864,11 @@ struct server_slot { } size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + GGML_ASSERT(task); + size_t stop_pos = std::string::npos; - for (const std::string & word : params.antiprompt) { + for (const std::string & word : task->params.antiprompt) { size_t pos; if (is_full_stop) { @@ -1682,43 +1921,36 @@ struct server_slot { } json to_json(bool only_metrics = false) const { - if (only_metrics) { - return json { - {"id", id}, - {"id_task", id_task}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - {"params", params.to_json(true)}, - {"next_token", - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }, - }; - } + json res; - return json { + res = { {"id", id}, - {"id_task", id_task}, {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, - {"params", params.to_json()}, - {"prompt", prompt_tokens.detokenize(ctx, true)}, - {"next_token", + }; + + const auto & ptask = task ? task : task_prev; + + if (ptask) { + res["id_task"] = ptask->id; + res["params"] = ptask->params.to_json(only_metrics); + res["next_token"] = { { {"has_next_token", has_next_token}, {"has_new_line", has_new_line}, {"n_remain", n_remaining}, {"n_decoded", n_decoded}, - {"stopping_word", stopping_word}, } - }, - }; + }; + + if (!only_metrics) { + res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["generated"] = generated_text; + } + } + + return res; } }; @@ -2109,11 +2341,14 @@ struct server_context { // slots / clients std::vector slots; - json default_generation_settings_for_props; + + int slots_debug = 0; server_queue queue_tasks; server_response queue_results; + std::unique_ptr prompt_cache; + server_metrics metrics; // Necessary similarity of prompt for slot selection @@ -2268,9 +2503,8 @@ struct server_context { slot.id = i; slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.n_predict = params_base.n_predict; slot.mctx = mctx; - slot.cache_tokens.has_mtmd = mctx != nullptr; + slot.prompt.tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2286,16 +2520,13 @@ struct server_context { SRV_ERR("%s", "failed to create speculator\n"); return; } - for (auto &pair : params_base.speculative.replacements) { + for (auto & pair : params_base.speculative.replacements) { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - slot.params.sampling = params_base.sampling; - slot.params.n_keep = params_base.n_keep; - slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); }; @@ -2305,7 +2536,14 @@ struct server_context { slots.push_back(std::move(slot)); } - default_generation_settings_for_props = slots[0].to_json(); + { + const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); + slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; + + if (slots_debug) { + SRV_WRN("slots debug = %d\n", slots_debug); + } + } // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. 
for non-causal attention models such as BERT where the KV cache is not used) @@ -2316,11 +2554,25 @@ struct server_context { metrics.init(); + if (params_base.cache_ram_mib != 0) { + if (params_base.cache_ram_mib < 0) { + SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); + } else { + SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); + } + SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + + prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); + } else { + SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("Enable thinking? %d\n", enable_thinking); + SRV_INF("thinking = %d\n", enable_thinking); oai_parser_opt = { /* use_jinja */ params_base.use_jinja, @@ -2347,10 +2599,11 @@ struct server_context { server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; + bool update_cache = false; + // find the slot that has at least n% prompt similarity if (ret == nullptr && slot_prompt_similarity != 0.0f) { - int lcs_len = 0; - float similarity = 0; + float sim_best = 0; for (server_slot & slot : slots) { // skip the slot if it is not available @@ -2358,27 +2611,34 @@ struct server_context { continue; } + const auto & tokens = slot.prompt.tokens; + // skip the slot if it does not contains cached tokens - if (slot.cache_tokens.empty()) { + if (tokens.empty()) { continue; } - // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); - - // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + // fraction of the Longest Common Prefix length with respect to the input prompt length + const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); // select the current slot if the criteria match - if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { - lcs_len = cur_lcs_len; - similarity = cur_similarity; + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; + ret = &slot; } } if (ret != nullptr) { - SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity); + const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); + + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < 0.5f) { + update_cache = true; + } } } @@ -2401,6 +2661,36 @@ struct server_context { if (ret != nullptr) { SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); + + update_cache = true; + } + } + + if (ret) { + const auto & tokens = ret->prompt.tokens; + + update_cache = update_cache && prompt_cache; + + // cache prompts only for completion tasks + update_cache = update_cache && 
task.type == SERVER_TASK_TYPE_COMPLETION; + + // don't update the cache if the slot's context is empty + update_cache = update_cache && tokens.size() > 0; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + if (update_cache) { + SRV_WRN("%s", "updating prompt cache\n"); + + const int64_t t_start = ggml_time_us(); + + ret->prompt_save(*prompt_cache); + ret->prompt_load(*prompt_cache, task.tokens); + + prompt_cache->update(); + + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); } } @@ -2409,27 +2699,21 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); - slot.id_task = task.id; - slot.index = task.index; - slot.task_type = task.type; - slot.params = std::move(task.params); - slot.prompt_tokens = std::move(task.prompt_tokens); - if (!are_lora_equal(slot.params.lora, slot.lora)) { + if (!are_lora_equal(task.params.lora, slot.lora)) { // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, slot.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size()); - slot.cache_tokens.clear(); + if (lora_should_clear_cache(slot.lora, task.params.lora)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size()); + SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); } - slot.lora = slot.params.lora; + slot.lora = task.params.lora; } // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = slot.prompt_tokens.size(); + size_t alora_invocation_start = task.tokens.size(); if (lora_all_alora(slot.lora)) { - const auto & enabled_ids = lora_get_enabled_ids(slot.lora); // TODO: This will error out if a user requests two aloras, but only // provides the activation string for one. We could, instead search @@ -2448,10 +2732,10 @@ struct server_context { // scan backwards through the prompt tokens to find the last // occurrence of the invocation sequence int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) { + for (int i = task.tokens.size() - 1; i >= 0; --i) { // the token in this position matches the next token to find in // the invocation sequence - if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) { + if (task.tokens[i] == invocation_tokens[match_idx]) { // if it's a full match, we've found the start if (match_idx == 0) { alora_invocation_start = i; @@ -2466,7 +2750,7 @@ struct server_context { } // if the activation string is not found, disable the alora - if (alora_invocation_start == slot.prompt_tokens.size()) { + if (alora_invocation_start == task.tokens.size()) { SLT_DBG(slot, "alora %zu requested, but not found. 
deactivating\n", enabled_ids[0]); slot.lora[enabled_ids[0]].scale = 0.0f; } else { @@ -2475,24 +2759,20 @@ struct server_context { } } - if (!slot.prompt_tokens.validate(ctx)) { + if (!task.tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } - SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { - // Might be better to reject the request with a 400 ? - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); - slot.params.n_predict = slot.n_predict; - } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + // initialize samplers { if (slot.smpl != nullptr) { common_sampler_free(slot.smpl); } - slot.smpl = common_sampler_init(model, slot.params.sampling); + slot.smpl = common_sampler_init(model, task.params.sampling); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -2500,12 +2780,15 @@ struct server_context { } } + // initialize draft batch if (slot.ctx_dft) { llama_batch_free(slot.batch_spec); - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); } + slot.task = std::make_unique(std::move(task)); + slot.state = SLOT_STATE_STARTED; SLT_INF(slot, "%s", "processing task\n"); @@ -2527,7 +2810,7 @@ struct server_context { slot.sampled = result.tok; slot.generated_text += token_str; - if (slot.params.return_tokens) { + if (slot.task->params.return_tokens) { slot.generated_tokens.push_back(result.tok); } slot.has_next_token = true; @@ -2564,7 +2847,7 @@ struct server_context { } slot.add_token(result); - if (slot.params.stream) { + if (slot.task->params.stream) { send_partial_response(slot, result, false); } } @@ -2586,12 +2869,12 @@ struct server_context { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); } if (slot.has_new_line) { // require that each new line has a whitespace prefix (i.e. 
indentation) of at least slot.params.n_indent - if (slot.params.n_indent > 0) { + if (slot.task->params.n_indent > 0) { // check the current indentation // TODO: improve by not doing it more than once for each new line if (slot.last_nl_pos > 0) { @@ -2603,7 +2886,7 @@ struct server_context { pos++; } - if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; @@ -2630,11 +2913,11 @@ struct server_context { slot.has_new_line = true; // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); } } @@ -2645,7 +2928,7 @@ struct server_context { slot.has_next_token = false; SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); } if (llama_vocab_is_eog(vocab, result.tok)) { @@ -2657,7 +2940,7 @@ struct server_context { const auto n_ctx_train = llama_model_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction @@ -2665,7 +2948,7 @@ struct server_context { SLT_WRN(slot, "n_predict (%d) is set for infinite generation. 
" "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.params.n_predict, n_ctx_train); + slot.task->params.n_predict, n_ctx_train); } SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); @@ -2674,7 +2957,7 @@ struct server_context { } void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.params.sampling.n_probs; + size_t n_probs = slot.task->params.sampling.n_probs; size_t n_vocab = llama_vocab_n_tokens(vocab); if (post_sampling) { @@ -2728,7 +3011,7 @@ struct server_context { } void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx); + send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx); } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { @@ -2749,7 +3032,7 @@ struct server_context { } // if multimodal is enabled, send an error and return false - bool ensure_no_mtmd(const int id_task) { + bool check_no_mtmd(const int id_task) { if (mctx) { send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); return false; @@ -2760,14 +3043,14 @@ struct server_context { void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; + res->id = slot.task->id; + res->index = slot.task->index; if (is_progress) { res->is_progress = true; - res->progress.total = slot.n_prompt_tokens; + res->progress.total = slot.n_prompt_tokens(); res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.cache_tokens.size(); + res->progress.processed = slot.prompt.tokens.size(); res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); } else { res->content = tkn.text_to_send; @@ -2777,21 +3060,21 @@ struct server_context { } res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->post_sampling_probs = slot.params.post_sampling_probs; + res->n_prompt_tokens = slot.n_prompt_tokens(); + res->post_sampling_probs = slot.task->params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->verbose = slot.task->params.verbose; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { + if (slot.task->params.sampling.n_probs > 0) { res->prob_output = tkn; // copy the token probs } // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { res->timings = slot.get_timings(); } @@ -2800,36 +3083,37 @@ struct server_context { void send_final_response(server_slot & slot) { auto res = std::make_unique(); - res->id = slot.id_task; - res->id_slot = slot.id; - res->index = slot.index; + 
res->id = slot.task->id; + res->id_slot = slot.id; + + res->index = slot.task->index; res->content = slot.generated_text; res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = slot.prompt_tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.params.response_fields); + res->prompt = slot.task->tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens = slot.n_prompt_tokens(); res->n_tokens_cached = slot.n_past; res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; - res->post_sampling_probs = slot.params.post_sampling_probs; + res->post_sampling_probs = slot.task->params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->stream = slot.params.stream; - res->include_usage = slot.params.include_usage; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + res->verbose = slot.task->params.verbose; + res->stream = slot.task->params.stream; + res->include_usage = slot.task->params.include_usage; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { - if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + if (slot.task->params.sampling.n_probs > 0) { + if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); @@ -2843,17 +3127,17 @@ struct server_context { } } - res->generation_params = slot.params; // copy the parameters + res->generation_params = slot.task->params; // copy the parameters queue_results.send(std::move(res)); } void send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; - res->oaicompat = slot.params.oaicompat; + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.n_prompt_tokens(); + res->oaicompat = slot.task->params.oaicompat; const int n_embd = llama_model_n_embd(model); @@ -2880,12 +3164,12 @@ struct server_context { // normalize only when there is pooling if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize); + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; - } else { - res->embedding.emplace_back(embd, embd + n_embd); } + + res->embedding.emplace_back(embd, embd + n_embd); } SLT_DBG(slot, "%s", "sending embeddings\n"); @@ -2895,9 +3179,9 @@ struct server_context { void send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = 
slot.n_prompt_tokens(); for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { @@ -3034,7 +3318,7 @@ struct server_context { case SERVER_TASK_TYPE_EMBEDDING: case SERVER_TASK_TYPE_RERANK: { - const int id_slot = task.id_selected_slot; + const int id_slot = task.id_slot; server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); @@ -3061,7 +3345,7 @@ struct server_context { { // release slot linked with the task id for (auto & slot : slots) { - if (slot.id_task == task.id_target) { + if (slot.task && slot.task->id == task.id_target) { slot.release(); break; } @@ -3079,7 +3363,7 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = slot.to_json(true); + json slot_data = slot.to_json(slots_debug == 0); if (slot.is_processing()) { n_processing_slots++; @@ -3121,7 +3405,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - if (!ensure_no_mtmd(task.id)) { + if (!check_no_mtmd(task.id)) { break; } @@ -3138,13 +3422,13 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.size(); + const size_t token_count = slot->prompt.tokens.size(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); @@ -3162,7 +3446,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -3186,13 +3470,13 @@ struct server_context { size_t token_count = 0; size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.clear(); // KV may already been invalidated? + slot->prompt.tokens.clear(); // KV may already been invalidated? 
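// ---------------------------------------------------------------------------
// Reviewer sketch, not part of the patch: the SLOT_SAVE / SLOT_RESTORE handlers
// around here are built on the llama.cpp sequence-state file API. A minimal
// round trip, with the server_tokens bookkeeping and error reporting omitted:
#include "llama.h"

#include <vector>

static bool demo_save_restore(llama_context * ctx, llama_seq_id seq_id,
                              const std::vector<llama_token> & tokens,
                              const char * filepath) {
    // write the sequence state plus its token list to a file
    const size_t nwrite = llama_state_seq_save_file(ctx, filepath, seq_id,
                                                    tokens.data(), tokens.size());
    if (nwrite == 0) {
        return false;
    }

    // read it back into the same sequence; n_count receives the stored token count
    std::vector<llama_token> restored(tokens.size());
    size_t n_count = 0;
    const size_t nread = llama_state_seq_load_file(ctx, filepath, seq_id,
                                                   restored.data(), restored.size(), &n_count);
    return nread != 0 && n_count == tokens.size();
}
// ---------------------------------------------------------------------------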
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } tokens.resize(token_count); - slot->cache_tokens.clear(); - slot->cache_tokens.insert(tokens); + slot->prompt.tokens.clear(); + slot->prompt.tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -3209,7 +3493,9 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) { + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -3224,9 +3510,9 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); + const size_t n_erased = slot->prompt.tokens.size(); llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); - slot->cache_tokens.clear(); + slot->prompt.tokens.clear(); auto res = std::make_unique(); res->id = task.id; @@ -3282,8 +3568,8 @@ struct server_context { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() - slot.release(); send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.release(); continue; } @@ -3294,9 +3580,16 @@ struct server_context { } // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; + int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep; + + if (add_bos_token) { + n_keep += 1; + } + + n_keep = std::min(slot.n_ctx - 4, n_keep); + const int n_left = slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + const int n_discard = slot.task->params.n_discard ? 
slot.task->params.n_discard : (n_left / 2); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); @@ -3305,14 +3598,14 @@ struct server_context { // add generated tokens to cache { - llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } - new_tokens.resize(slot.cache_tokens.size() - n_discard); - slot.cache_tokens.clear(); - slot.cache_tokens.insert(new_tokens); + new_tokens.resize(slot.prompt.tokens.size() - n_discard); + slot.prompt.tokens.clear(); + slot.prompt.tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -3328,7 +3621,8 @@ struct server_context { server_slot * slot_batched = nullptr; auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); }; // frist, add sampled tokens from any ongoing sequences @@ -3349,10 +3643,10 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - slot.cache_tokens.push_back(slot.sampled); + slot.prompt.tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated); } // process in chunks of params.n_batch @@ -3375,7 +3669,7 @@ struct server_context { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - auto & prompt_tokens = slot.prompt_tokens; + const auto & input_tokens = slot.task->tokens; // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -3383,104 +3677,64 @@ struct server_context { slot.t_start_generation = 0; slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); slot.state = SLOT_STATE_PROCESSING_PROMPT; - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", + slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens()); // print prompt tokens (for debugging) /*if (1) { // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < std::min(16, input_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } } else { // all - for (int i = 0; i < (int) prompt_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < (int) input_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } }*/ // empty prompt passed -> release the slot and send empty 
response - if (prompt_tokens.empty()) { + if (input_tokens.empty()) { SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - slot.release(); slot.print_timings(); send_final_response(slot); + slot.release(); + continue; } // TODO: support memory-less logits computation if (slot.need_logits() && !llama_get_memory(ctx)) { - slot.release(); send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + slot.release(); continue; } if (!slot.can_split()) { - if (slot.n_prompt_tokens > n_ubatch) { - slot.release(); + if (slot.n_prompt_tokens() > n_ubatch) { send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + slot.release(); continue; } - if (slot.n_prompt_tokens > slot.n_ctx) { - slot.release(); + if (slot.n_prompt_tokens() > slot.n_ctx) { send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); continue; } } else { - if (!params_base.ctx_shift) { - // if context shift is disabled, we make sure prompt size is smaller than KV size - // TODO: there should be a separate parameter that control prompt truncation - // context shift should be applied only during the generation phase - if (slot.n_prompt_tokens >= slot.n_ctx) { - slot.release(); - send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - continue; - } - } - if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.n_prompt_tokens; - } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) { - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); - llama_tokens new_tokens( - curr_tokens.begin(), - curr_tokens.begin() + slot.params.n_keep); - - new_tokens.insert( - new_tokens.end(), - curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - curr_tokens.end()); - - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - - SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + if (slot.n_prompt_tokens() >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size. 
try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; } - if (slot.params.cache_prompt) { + if (slot.task->params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); // if there is an alora invoked, don't cache after the invocation start if (slot.alora_invocation_start >= 0) { @@ -3500,13 +3754,13 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.cache_tokens.size() && - head_p < prompt_tokens.size()) { + while (head_c < slot.prompt.tokens.size() && + head_p < input_tokens.size()) { size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.size() && - head_p + n_match < prompt_tokens.size() && - slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + while (head_c + n_match < slot.prompt.tokens.size() && + head_p + n_match < input_tokens.size() && + slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { n_match++; } @@ -3523,7 +3777,7 @@ struct server_context { llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); + slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); slot.n_past++; } @@ -3547,41 +3801,83 @@ struct server_context { // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, slot.n_past - n_swa); - if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { + if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } + // when the prompt prefix does not match, print the tokens around the mismatch + // this is useful for debugging prompt caching + { + const int np0 = std::max(slot.n_past - 4, 0); + const int np1 = std::min(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + + std::stringstream ss0; + std::stringstream ss1; + + std::stringstream st0; + std::stringstream st1; + + ss0 << "old: ... "; + ss1 << "new: ... 
"; + + for (int i = np0; i < np1; i++) { + if (i == slot.n_past) { + ss0 << " | "; + ss1 << " | "; + } + + { + const auto token = slot.prompt.tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss0 << piece; + st0 << std::setw(8) << token; + } + + { + const auto token = slot.task->tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss1 << piece; + st1 << std::setw(8) << token; + } + } + + SLT_WRN(slot, "%s\n", ss0.str().c_str()); + SLT_WRN(slot, "%s\n", ss1.str().c_str()); + + SLT_WRN(slot, "%s\n", st0.str().c_str()); + SLT_WRN(slot, "%s\n", st1.str().c_str()); + } + if (pos_min > pos_min_thold) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint const auto it = std::find_if( - slot.ctx_checkpoints.rbegin(), - slot.ctx_checkpoints.rend(), + slot.prompt.checkpoints.rbegin(), + slot.prompt.checkpoints.rend(), [&](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] return cur.pos_min < pos_min_thold; } ); - bool do_reset = it == slot.ctx_checkpoints.rend(); - //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false"); + bool do_reset = it == slot.prompt.checkpoints.rend(); if (!do_reset) { // restore the context checkpoint - const size_t ctx_checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != ctx_checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); } } @@ -3595,19 +3891,21 @@ struct server_context { { // erase any checkpoints with pos_min > pos_min_thold - for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) { - const auto & cur = slot.ctx_checkpoints[i]; + for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { + const auto & cur = *it; if (cur.pos_min > pos_min_thold) { SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i); + it = slot.prompt.checkpoints.erase(it); + } 
else { + ++it; } } } } // [TAG_PROMPT_LOGITS] - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens); + if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens()); slot.n_past--; SLT_WRN(slot, "n_past was set to %d\n", slot.n_past); } @@ -3618,7 +3916,7 @@ struct server_context { if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { continue; } } @@ -3636,28 +3934,28 @@ struct server_context { SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); + slot.prompt.tokens.keep_first(slot.n_past); // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image int32_t new_n_past; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); - int32_t n_pos = new_n_past - slot.n_past; - + int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); - slot.release(); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + slot.release(); continue; } // add the image chunk to cache { - const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); - slot.cache_tokens.push_back(chunk.get()); // copy + const auto & chunk = input_tokens.find_chunk(slot.n_past); + slot.prompt.tokens.push_back(chunk.get()); // copy } + const int32_t n_pos = new_n_past - slot.n_past; + slot.n_past += n_pos; slot.n_prompt_tokens_processed += n_pos; } @@ -3678,6 +3976,9 @@ struct server_context { bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + // make a checkpoint of the parts of the memory that cannot be rolled back. 
// checkpoints are created only if: // - the model uses SWA and we are not using `swa_full` @@ -3691,9 +3992,9 @@ struct server_context { ); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) { // get next token to process - llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + llama_token cur_tok = input_tokens[slot.n_past]; if (cur_tok == LLAMA_TOKEN_NULL) { break; // end of text chunk } @@ -3707,36 +4008,33 @@ struct server_context { } // embedding requires all tokens in the batch to be output - const bool need_embd = server_task_type_need_embd(slot.task_type); - - common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); - slot.cache_tokens.push_back(cur_tok); + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd()); + slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; slot.n_past++; // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. - if (do_checkpoint && slot.n_prompt_tokens - slot.n_past == 64) { + if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) { break; } } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens()); // entire prompt has been processed - if (slot.n_past == slot.n_prompt_tokens) { + if (slot.n_past == slot.n_prompt_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - llama_token id = slot.prompt_tokens[i]; + for (int i = 0; i < slot.n_prompt_tokens(); ++i) { + llama_token id = input_tokens[i]; if (id != LLAMA_TOKEN_NULL) { common_sampler_accept(slot.smpl, id, false); } @@ -3757,21 +4055,22 @@ struct server_context { do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.ctx_checkpoints.empty() || pos_max > slot.ctx_checkpoints.back().pos_max + 64); + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); if (do_checkpoint) { - while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { // make room for the new checkpoint, if needed - const auto & cur = slot.ctx_checkpoints.front(); + const auto & cur = slot.prompt.checkpoints.front(); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin()); + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{ + auto & 
cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ /*.pos_min = */ pos_min, /*.pos_max = */ pos_max, /*.data = */ std::vector(checkpoint_size), @@ -3779,8 +4078,8 @@ struct server_context { llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } } @@ -3854,8 +4153,8 @@ struct server_context { if (!err.empty()) { SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); for (auto & slot : slots) { - slot.release(); send_error(slot, err); + slot.release(); } break; } @@ -3878,7 +4177,7 @@ struct server_context { for (auto & slot : slots) { // optionally send prompt processing progress if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.params.stream && slot.params.return_progress) { + if (slot.task->params.stream && slot.task->params.return_progress) { send_partial_response(slot, {}, true); } } @@ -3888,7 +4187,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -3896,7 +4195,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -3934,16 +4233,17 @@ struct server_context { result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - if (slot.params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + continue; } } @@ -3964,7 +4264,7 @@ struct server_context { } // determine the max draft that fits the current slot state - int n_draft_max = slot.params.speculative.n_max; + int n_draft_max = slot.task->params.speculative.n_max; // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -3976,8 +4276,8 @@ struct server_context { SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < slot.params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + if (n_draft_max < slot.task->params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, 
slot.task->params.speculative.n_min); continue; } @@ -3985,16 +4285,16 @@ struct server_context { llama_token id = slot.sampled; struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); continue; } @@ -4023,8 +4323,8 @@ struct server_context { // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); + slot.prompt.tokens.push_back(id); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); @@ -4038,11 +4338,11 @@ struct server_context { // TODO: set result.probs if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + break; } } @@ -4310,18 +4610,18 @@ int main(int argc, char ** argv) { } // TODO: get rid of this dynamic_cast - auto res_metrics = dynamic_cast(result.get()); - GGML_ASSERT(res_metrics != nullptr); + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); // optionally return "fail_on_no_slot" error if (req.has_param("fail_on_no_slot")) { - if (res_metrics->n_idle_slots == 0) { + if (res_task->n_idle_slots == 0) { res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, res_metrics->slots_data); + res_ok(res, res_task->slots_data); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { @@ -4349,56 +4649,56 @@ int main(int argc, char ** argv) { } // TODO: get rid of this dynamic_cast - auto res_metrics = dynamic_cast(result.get()); - GGML_ASSERT(res_metrics != nullptr); + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { {"counter", {{ {"name", "prompt_tokens_total"}, {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total} + {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} }, { {"name", "prompt_seconds_total"}, {"help", "Prompt process time"}, - {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3} + {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} }, { {"name", "tokens_predicted_total"}, {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) res_metrics->n_tokens_predicted_total} + {"value", (uint64_t) 
res_task->n_tokens_predicted_total} }, { {"name", "tokens_predicted_seconds_total"}, {"help", "Predict process time"}, - {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3} + {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} }, { {"name", "n_decode_total"}, {"help", "Total number of llama_decode() calls"}, - {"value", res_metrics->n_decode_total} + {"value", res_task->n_decode_total} }, { {"name", "n_past_max"}, {"help", "Largest observed n_past."}, - {"value", res_metrics->n_past_max} + {"value", res_task->n_past_max} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)} + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, {"help", "Average prompt throughput in tokens/s."}, - {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.} + {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} },{ {"name", "predicted_tokens_seconds"}, {"help", "Average generation throughput in tokens/s."}, - {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.} + {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} },{ {"name", "requests_processing"}, {"help", "Number of requests processing."}, - {"value", (uint64_t) res_metrics->n_processing_slots} + {"value", (uint64_t) res_task->n_processing_slots} },{ {"name", "requests_deferred"}, {"help", "Number of requests deferred."}, - {"value", (uint64_t) res_metrics->n_tasks_deferred} + {"value", (uint64_t) res_task->n_tasks_deferred} }}} }; @@ -4419,7 +4719,7 @@ int main(int argc, char ** argv) { } } - res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start)); + res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start)); res.set_content(prometheus.str(), "text/plain; version=0.0.4"); res.status = 200; // HTTP OK @@ -4543,9 +4843,22 @@ int main(int argc, char ** argv) { }; const auto handle_props = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json default_generation_settings_for_props; + + { + slot_params params; + + params.sampling = ctx_server.params_base.sampling; + + default_generation_settings_for_props = json { + {"params", params.to_json(true)}, + {"n_ctx", ctx_server.slots[0].n_ctx}, + }; + } + // this endpoint is publicly available, please only return what is safe to be exposed json data = { - { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "default_generation_settings", default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model.path }, { "modalities", json { @@ -4650,12 +4963,12 @@ int main(int argc, char ** argv) { task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, data); - task.id_selected_slot = json_value(data, "id_slot", -1); + task.id_slot = 
json_value(data, "id_slot", -1); // OAI-compat task.params.oaicompat = oaicompat; @@ -5024,9 +5337,9 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); // OAI-compat task.params.oaicompat = oaicompat; @@ -5122,10 +5435,10 @@ int main(int argc, char ** argv) { tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tmp); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); tasks.push_back(std::move(task)); } @@ -5383,7 +5696,7 @@ int main(int argc, char ** argv) { #endif LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__, - is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : + is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str()); // this call blocks the main thread until queue_tasks.terminate() is called diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 829af2ebe7bfb..720b136b05175 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -66,8 +66,7 @@ def test_server_slots(): assert len(res.body) == server.n_slots assert server.n_ctx is not None and server.n_slots is not None assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots - assert "params" in res.body[0] - assert res.body[0]["params"]["seed"] == server.seed + assert "params" not in res.body[0] def test_load_split_model(): diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 2979ed4bb7b12..6e5a3488e789b 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -19,8 +19,8 @@ def create_server(): (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, 
"length", False, None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), ] @@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), - ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"), ] ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 11483e679a505..00ba78cf67c09 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -16,7 +16,7 @@ def create_server(): @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), ]) def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): global server @@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), ]) def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): global server diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 92e49f2bb05a4..4adbbde64f594 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -4,6 +4,12 @@ server = ServerPreset.tinyllama2() +SHORT_TEXT = """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +""".strip() + LONG_TEXT = """ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. 
@@ -21,19 +27,18 @@ def create_server(): def test_ctx_shift_enabled(): - # the prompt is 301 tokens + # the prompt is 226 tokens # the slot context is 512/2 = 256 tokens - # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens # 96 tokens are generated thanks to shifting the context when it gets full global server server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 96, - "prompt": LONG_TEXT, + "prompt": SHORT_TEXT, }) assert res.status_code == 200 - assert res.body["timings"]["prompt_n"] == 173 + assert res.body["timings"]["prompt_n"] == 226 assert res.body["timings"]["predicted_n"] == 96 assert res.body["truncated"] is True diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 4ca1423aaf2d4..f175115f4fd6a 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -31,10 +31,10 @@ using json = nlohmann::ordered_json; -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_WRN(fmt, ...) 
LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) @@ -1102,6 +1102,7 @@ struct server_tokens { ~server_tokens() = default; // Prevent copying + // TODO: server_tokens should be copyable - remove this: server_tokens(const server_tokens&) = delete; server_tokens& operator=(const server_tokens&) = delete; @@ -1119,7 +1120,7 @@ struct server_tokens { } } - server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} // for debugging std::string str() const { @@ -1144,9 +1145,8 @@ struct server_tokens { auto it = map_pos_to_media.find(pos); if (it != map_pos_to_media.end()) { return it->second; - } else { - throw std::runtime_error("Chunk not found"); } + throw std::runtime_error("Chunk not found"); } void push_back(llama_token tok) { @@ -1170,7 +1170,7 @@ struct server_tokens { map_pos_to_media[start_pos] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { size_t n_tokens; - auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); for (size_t i = 0; i < n_tokens; ++i) { push_back(text_tokens[i]); } @@ -1190,7 +1190,7 @@ struct server_tokens { // We could also just check, but this will prevent silently dropping MTMD data. GGML_ASSERT(has_mtmd); for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) { - auto chunk = tokens.map_pos_to_media[it->first].get(); + auto * chunk = tokens.map_pos_to_media[it->first].get(); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); map_pos_to_media[start_pos+it->first] = std::move(new_chunk); } @@ -1271,33 +1271,52 @@ struct server_tokens { } size_t get_common_prefix(const server_tokens & b) const { - size_t max_idx = std::min(tokens.size(), b.tokens.size()); + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); + + if (!has_mtmd) { + for (size_t i = 0; i < max_idx; ++i) { + if (tokens[i] == b.tokens[i]) { + continue; + } + + return i; + } + + return max_idx; + } + for (size_t i = 0; i < max_idx; ++i) { - auto & ai = tokens[i]; - auto & bi = b.tokens[i]; + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - GGML_ASSERT(has_mtmd); const auto & a_chunk = find_chunk(i); const auto & b_chunk = b.find_chunk(i); + GGML_ASSERT(a_chunk && b_chunk); - std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); - std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); - size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); - size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); - if (ai_id == bi_id && a_pos == b_pos) { - GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen - i += a_pos - 1; // will be +1 by the for loop + + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); + + const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get()); + const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get()); + + if (id_ai == id_bi && pos_a == pos_b) { + GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen + i += pos_a - 1; // will be +1 by the for loop continue; - } else { - return i; } - } else if (ai == bi) { - continue; - } else { + return i; } + + if (ai == bi) { + continue; + } + + return i; } + return max_idx; // all tokens are equal } @@ -1308,7 +1327,7 @@ 
struct server_tokens { const int32_t n_vocab = llama_vocab_n_tokens(vocab); for (size_t i = 0; i < tokens.size(); ++i) { - auto & t = tokens[i]; + const auto & t = tokens[i]; if (t == LLAMA_TOKEN_NULL) { try { const auto & chunk = find_chunk(i); @@ -1330,8 +1349,8 @@ struct server_tokens { mtmd_context * mctx, llama_pos n_past, int32_t seq_id, - llama_pos & n_pos_out) { - auto & chunk = find_chunk(n_past); + llama_pos & n_pos_out) const { + const auto & chunk = find_chunk(n_past); const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; SRV_INF("processing %s...\n", name);
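
Two of the reworked pieces above are easier to review with a standalone approximation: the text-only fast path of server_tokens::get_common_prefix() (count matching leading tokens and reuse them as slot.n_past), and the eviction of old entries from slot.prompt.checkpoints before a new context checkpoint is saved. The sketch below is a minimal, self-contained C++ illustration of those two behaviors only; the helper names common_prefix_len, prompt_checkpoint and add_checkpoint, and all token ids, are invented for the example and are not part of the patch.

// ---------------------------------------------------------------------------
// Illustrative sketch (reviewer reference only, not code from the patch).
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = std::int32_t; // assumption: llama_token is a 32-bit id

// Text-only common prefix: count matching leading tokens between the cached
// sequence and the incoming prompt; the result is what the server reuses as
// slot.n_past when prompt caching is enabled.
static size_t common_prefix_len(const std::vector<llama_token> & cached,
                                const std::vector<llama_token> & incoming) {
    const size_t max_idx = std::min(cached.size(), incoming.size());
    for (size_t i = 0; i < max_idx; ++i) {
        if (cached[i] != incoming[i]) {
            return i; // first mismatch ends the reusable prefix
        }
    }
    return max_idx; // one sequence is a prefix of the other
}

// Simplified stand-in for a saved context checkpoint (position range + state blob).
struct prompt_checkpoint {
    int pos_min;
    int pos_max;
    std::vector<unsigned char> data;
};

// Keep at most n_ctx_checkpoints entries, dropping the oldest first, mirroring
// the while-loop in the patch that makes room before emplace_back.
static void add_checkpoint(std::vector<prompt_checkpoint> & checkpoints,
                           int n_ctx_checkpoints,
                           prompt_checkpoint cp) {
    while ((int) checkpoints.size() >= n_ctx_checkpoints) {
        checkpoints.erase(checkpoints.begin());
    }
    checkpoints.push_back(std::move(cp));
}

int main() {
    const std::vector<llama_token> cached   = { 1, 15043, 3186, 29991 };      // made-up ids
    const std::vector<llama_token> incoming = { 1, 15043, 3186, 1234, 5678 }; // made-up ids

    // prints 3: only the first three cached tokens are reusable
    std::printf("common prefix = %zu\n", common_prefix_len(cached, incoming));

    std::vector<prompt_checkpoint> checkpoints;
    for (int i = 0; i < 12; ++i) {
        add_checkpoint(checkpoints, /*n_ctx_checkpoints=*/8, { i * 64, i * 64 + 63, {} });
    }

    // prints "kept 8, oldest pos_min = 256": the 4 oldest checkpoints were evicted
    std::printf("kept %zu, oldest pos_min = %d\n", checkpoints.size(), checkpoints.front().pos_min);
    return 0;
}

Dropping the oldest checkpoint first matches the "erasing old context checkpoint" log message in the patch, so the retained window always covers the most recently processed prompt positions.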