From 8f36df8fc9c82eb2b56257217957a09509170e87 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Tue, 23 Jan 2024 18:13:38 +0100
Subject: [PATCH] server: fix a race condition caused by "request_completion"

Previously, a task only received its id inside llama_server_queue::post(),
so the HTTP handlers could not register themselves as waiting for a result
until after request_completion() had already queued the task. A result could
therefore be dispatched before the corresponding add_waiting_task_id() call.

Generate the task id up front with queue_tasks.get_new_id(), register it via
add_waiting_task_id() before calling request_completion(), and remove the
waiting id once the result has been received or the stream has finished.
post() now only assigns an id when the task does not already have one.

---
 examples/server/server.cpp | 60 +++++++++++++++++++++++++-------------
 examples/server/utils.hpp  |  8 +++--
 2 files changed, 44 insertions(+), 24 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 54dac86910b11..39283613256ed 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1122,9 +1122,10 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    int request_completion(json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
     {
         task_server task;
+        task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
         task.infill_mode = infill;
@@ -1135,11 +1136,11 @@ struct llama_server_context
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
         if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
         {
-            return split_multiprompt_task(task);
+            split_multiprompt_task(task_id, task);
         }
 
         // otherwise, it's a single-prompt task, we actually queue it
-        return queue_tasks.post(task);
+        queue_tasks.post(task);
     }
 
     // for multiple images processing
@@ -1218,25 +1219,30 @@ struct llama_server_context
         queue_tasks.post(task);
     }
 
-    int split_multiprompt_task(task_server& multiprompt_task)
+    void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
     {
         int prompt_count = multiprompt_task.data.at("prompt").size();
         assert(prompt_count > 1);
 
-        int multitask_id = queue_tasks.get_next_id();
+        // generate all the IDs for the subtasks
         std::vector<int> subtask_ids(prompt_count);
         for (int i = 0; i < prompt_count; i++)
+        {
+            subtask_ids[i] = queue_tasks.get_new_id();
+        }
+
+        // queue up the multitask so we can track its subtask progression
+        queue_tasks.add_multitask(multitask_id, subtask_ids);
+
+        // add subtasks
+        for (int i = 0; i < prompt_count; i++)
         {
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            subtask_ids[i] = request_completion(subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
         }
-
-        // queue up the multitask so we can track its subtask progression
-        queue_tasks.add_multitask(multitask_id, subtask_ids);
-        return multitask_id;
     }
 
     void process_single_task(task_server& task)
@@ -2493,8 +2499,9 @@ int main(int argc, char **argv)
                 return;
             }
             json data = json::parse(req.body);
-            const int task_id = llama.request_completion(data, false, false, -1);
+            const int task_id = llama.queue_tasks.get_new_id();
             llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, data, false, false, -1);
             if (!json_value(data, "stream", false)) {
                 std::string completion_text;
                 task_result result = llama.queue_results.recv(task_id);
@@ -2505,9 +2512,8 @@ int main(int argc, char **argv)
                 {
                     res.status = 404;
                     res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                    llama.queue_results.remove_waiting_task_id(task_id);
-                    return;
                 }
+                llama.queue_results.remove_waiting_task_id(task_id);
             } else {
                 const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
@@ -2546,8 +2552,9 @@ int main(int argc, char **argv)
                             break;
                        }
                     }
-                    sink.done();
+                    llama.queue_results.remove_waiting_task_id(task_id);
+                    sink.done();
                     return true;
                 };
@@ -2592,8 +2599,9 @@ int main(int argc, char **argv)
            }
 
            json data = oaicompat_completion_params_parse(json::parse(req.body));
-           const int task_id = llama.request_completion(data, false, false, -1);
+           const int task_id = llama.queue_tasks.get_new_id();
            llama.queue_results.add_waiting_task_id(task_id);
+           llama.request_completion(task_id, data, false, false, -1);
 
            if (!json_value(data, "stream", false)) {
                std::string completion_text;
@@ -2608,9 +2616,8 @@ int main(int argc, char **argv)
                } else {
                    res.status = 500;
                    res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                   llama.queue_results.remove_waiting_task_id(task_id);
-                   return;
                }
+               llama.queue_results.remove_waiting_task_id(task_id);
            } else {
                const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
                    while (true) {
@@ -2671,7 +2678,9 @@ int main(int argc, char **argv)
                 return;
             }
             json data = json::parse(req.body);
-            const int task_id = llama.request_completion(data, true, false, -1);
+            const int task_id = llama.queue_tasks.get_new_id();
+            llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, data, true, false, -1);
             if (!json_value(data, "stream", false)) {
                 std::string completion_text;
                 task_result result = llama.queue_results.recv(task_id);
@@ -2683,8 +2692,8 @@ int main(int argc, char **argv)
                 {
                     res.status = 404;
                     res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                    return;
                 }
+                llama.queue_results.remove_waiting_task_id(task_id);
             } else {
                 const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink & sink) {
                     while (true)
@@ -2700,6 +2709,7 @@ int main(int argc, char **argv)
                             });
                             if (!sink.write(str.c_str(), str.size()))
                             {
+                                llama.queue_results.remove_waiting_task_id(task_id);
                                 return false;
                             }
                             if (result.stop)
@@ -2713,8 +2723,8 @@ int main(int argc, char **argv)
                         }
                     }
 
+                    llama.queue_results.remove_waiting_task_id(task_id);
                     sink.done();
-
                     return true;
                 };
@@ -2788,8 +2798,16 @@ int main(int argc, char **argv)
                 image_data = "";
             }
 
-            const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+            // create and queue the task
+            const int task_id = llama.queue_tasks.get_new_id();
+            llama.queue_results.add_waiting_task_id(task_id);
+            llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+
+            // get the result
             task_result result = llama.queue_results.recv(task_id);
+            llama.queue_results.remove_waiting_task_id(task_id);
+
+            // send the result
             return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
         });
 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 3ea697f59e5b5..e2b6065f734a2 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -203,7 +203,9 @@ struct llama_server_queue {
     // Add a new task to the end of the queue
     int post(task_server task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-        task.id = id++;
+        if (task.id == -1) {
+            task.id = id++;
+        }
         queue_tasks.push_back(std::move(task));
         condition_tasks.notify_one();
         return task.id;
@@ -215,8 +217,8 @@ struct llama_server_queue {
         queue_tasks_deferred.push_back(std::move(task));
     }
 
-    // Get the next task id
-    int get_next_id() {
+    // Get the next id for creating a new task
+    int get_new_id() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         return id++;
     }