diff --git a/examples/server/README.md b/examples/server/README.md index e17595fe87f25..930ae15f64d8b 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -368,15 +368,16 @@ node index.js ## API Endpoints -### GET `/health`: Returns the current state of the server +### GET `/health`: Returns heath check result - - 503 -> `{"status": "loading model"}` if the model is still being loaded. - - 500 -> `{"status": "error"}` if the model failed to load. - - 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below. - - 200 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slots are currently available. - - 503 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if the query parameter `fail_on_no_slot` is provided and no slots are currently available. +**Response format** - If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set. +- HTTP status code 503 + - Body: `{"error": {"code": 503, "message": "Loading model", "type": "unavailable_error"}}` + - Explanation: the model is still being loaded. +- HTTP status code 200 + - Body: `{"status": "ok" }` + - Explanation: the model is successfully loaded and the server is ready. ### POST `/completion`: Given a `prompt`, it returns the predicted completion. @@ -639,10 +640,16 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte }' ``` -### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`. +### GET `/slots`: Returns the current slots processing state + +This endpoint can be disabled with `--no-slots` + +If query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there is no available slots. **Response format** +Example: + ```json [ { @@ -702,7 +709,13 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte ] ``` -### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled: +Possible values for `slot[i].state` are: +- `0`: SLOT_STATE_IDLE +- `1`: SLOT_STATE_PROCESSING + +### GET `/metrics`: Prometheus compatible metrics exporter + +This endpoint is only accessible if `--metrics` is set. Available metrics: - `llamacpp:prompt_tokens_total`: Number of prompt tokens processed. @@ -767,6 +780,10 @@ Available metrics: ### GET `/lora-adapters`: Get list of all LoRA adapters +This endpoint returns the loaded LoRA adapters. You can add adapters using `--lora` when starting the server, for example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...` + +By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply` + If an adapter is disabled, the scale will be set to 0. **Response format** diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e073f5813d459..ce711eadd29ac 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -15,6 +15,8 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +// mime type for sending response +#define MIMETYPE_JSON "application/json; charset=utf-8" // auto generated files (update with ./deps.sh) #include "colorthemes.css.hpp" @@ -67,7 +69,6 @@ enum slot_command { enum server_state { SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet SERVER_STATE_READY, // Server is ready and model is loaded - SERVER_STATE_ERROR // An error occurred, load_model failed }; enum server_task_type { @@ -695,6 +696,7 @@ struct server_context { add_bos_token = llama_add_bos_token(model); has_eos_token = !llama_add_eos_token(model); + return true; } @@ -2555,19 +2557,19 @@ int main(int argc, char ** argv) { svr->set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) { + // Access-Control-Allow-Origin is already set by middleware res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Headers", "*"); - return res.set_content("", "application/json; charset=utf-8"); + return res.set_content("", "text/html"); // blank response, no data }); svr->set_logger(log_server_request); auto res_error = [](httplib::Response & res, json error_data) { json final_response {{"error", error_data}}; - res.set_content(final_response.dump(), "application/json; charset=utf-8"); + res.set_content(final_response.dump(), MIMETYPE_JSON); res.status = json_value(error_data, "code", 500); }; @@ -2597,11 +2599,6 @@ int main(int argc, char ** argv) { svr->set_read_timeout (params.timeout_read); svr->set_write_timeout(params.timeout_write); - if (!svr->bind_to_port(params.hostname, params.port)) { - fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port); - return 1; - } - std::unordered_map log_data; log_data["hostname"] = params.hostname; @@ -2617,35 +2614,6 @@ int main(int argc, char ** argv) { // Necessary similarity of prompt for slot selection ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - // load the model - if (!ctx_server.load_model(params)) { - state.store(SERVER_STATE_ERROR); - return 1; - } else { - ctx_server.init(); - state.store(SERVER_STATE_READY); - } - - LOG_INFO("model loaded", {}); - - const auto model_meta = ctx_server.model_meta(); - - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) { - if (!ctx_server.validate_model_chat_template()) { - LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); - params.chat_template = "chatml"; - } - } - - // print sample chat example to make it clear which template is used - { - LOG_INFO("chat template", { - {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, - {"built_in", params.chat_template.empty()}, - }); - } - // // Middlewares // @@ -2689,8 +2657,6 @@ int main(int argc, char ** argv) { } // API key is invalid or not provided - // TODO: make another middleware for CORS related logic - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); LOG_WARNING("Unauthorized: Invalid API Key", {}); @@ -2698,8 +2664,21 @@ int main(int argc, char ** argv) { return false; }; + auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) { + server_state current_state = state.load(); + if (current_state == SERVER_STATE_LOADING_MODEL) { + res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + return false; + } + return true; + }; + // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) { + svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } if (!middleware_validate_api_key(req, res)) { return httplib::Server::HandlerResponse::Handled; } @@ -2710,62 +2689,15 @@ int main(int argc, char ** argv) { // Route handlers (or controllers) // - const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - switch (current_state) { - case SERVER_STATE_READY: - { - // request slots data using task queue - server_task task; - task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - task.id_target = -1; - - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(task); - - // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - const int n_idle_slots = result.data.at("idle"); - const int n_processing_slots = result.data.at("processing"); - - json health = { - {"status", "ok"}, - {"slots_idle", n_idle_slots}, - {"slots_processing", n_processing_slots} - }; - - res.status = 200; // HTTP OK - if (params.endpoint_slots && req.has_param("include_slots")) { - health["slots"] = result.data.at("slots"); - } - - if (n_idle_slots == 0) { - health["status"] = "no slot available"; - if (req.has_param("fail_on_no_slot")) { - res.status = 503; // HTTP Service Unavailable - } - } - - res.set_content(health.dump(), "application/json"); - break; - } - case SERVER_STATE_LOADING_MODEL: - { - res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } break; - case SERVER_STATE_ERROR: - { - res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER)); - } break; - } + const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { + // error and loading states are handled by middleware + json health = {{"status", "ok"}}; + res.set_content(health.dump(), "application/json"); }; - const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { + const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { if (!params.endpoint_slots) { - res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2783,13 +2715,22 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); - res.set_content(result.data.at("slots").dump(), "application/json"); + // optionally return "fail_on_no_slot" error + const int n_idle_slots = result.data.at("idle"); + if (req.has_param("fail_on_no_slot")) { + if (n_idle_slots == 0) { + res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return; + } + } + + res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { if (!params.endpoint_metrics) { - res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2914,7 +2855,7 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; @@ -2944,7 +2885,7 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; @@ -2964,13 +2905,11 @@ int main(int argc, char ** argv) { if (result.error) { res_error(res, result.data); } else { - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); } }; const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - std::string id_slot_str = req.path_params.at("id_slot"); int id_slot; @@ -2994,7 +2933,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { std::string template_key = "tokenizer.chat_template", curr_tmpl; int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); if (tlen > 0) { @@ -3003,7 +2942,6 @@ int main(int argc, char ** argv) { curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); } } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "system_prompt", ctx_server.system_prompt.c_str() }, { "default_generation_settings", ctx_server.default_generation_settings_for_props }, @@ -3011,7 +2949,7 @@ int main(int argc, char ** argv) { { "chat_template", curr_tmpl.c_str() } }; - res.set_content(data.dump(), "application/json; charset=utf-8"); + res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { @@ -3020,8 +2958,6 @@ int main(int argc, char ** argv) { return; } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = json::parse(req.body); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3032,7 +2968,7 @@ int main(int argc, char ** argv) { if (!json_value(data, "stream", false)) { server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3095,9 +3031,7 @@ int main(int argc, char ** argv) { } }; - const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - + const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { json models = { {"object", "list"}, {"data", { @@ -3106,12 +3040,12 @@ int main(int argc, char ** argv) { {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, - {"meta", model_meta} + {"meta", ctx_server.model_meta()} }, }} }; - res.set_content(models.dump(), "application/json; charset=utf-8"); + res.set_content(models.dump(), MIMETYPE_JSON); }; const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { @@ -3119,8 +3053,6 @@ int main(int argc, char ** argv) { res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3135,7 +3067,7 @@ int main(int argc, char ** argv) { if (!result.error && result.stop) { json result_oai = format_final_response_oaicompat(data, result.data, completion_id); - res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3197,8 +3129,6 @@ int main(int argc, char ** argv) { return; } - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - json data = json::parse(req.body); const int id_task = ctx_server.queue_tasks.get_new_id(); @@ -3209,7 +3139,7 @@ int main(int argc, char ** argv) { if (!json_value(data, "stream", false)) { server_task_result result = ctx_server.queue_results.recv(id_task); if (!result.error && result.stop) { - res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); + res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); } else { res_error(res, result.data); } @@ -3257,7 +3187,6 @@ int main(int argc, char ** argv) { }; const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; @@ -3266,11 +3195,10 @@ int main(int argc, char ** argv) { tokens = ctx_server.tokenize(body.at("content"), add_special); } const json data = format_tokenizer_response(tokens); - return res.set_content(data.dump(), "application/json; charset=utf-8"); + return res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; @@ -3280,12 +3208,10 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - return res.set_content(data.dump(), "application/json; charset=utf-8"); + return res.set_content(data.dump(), MIMETYPE_JSON); }; const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const json body = json::parse(req.body); bool is_openai = false; @@ -3331,11 +3257,10 @@ int main(int argc, char ** argv) { json root = is_openai ? format_embeddings_response_oaicompat(body, responses) : responses[0]; - return res.set_content(root.dump(), "application/json; charset=utf-8"); + return res.set_content(root.dump(), MIMETYPE_JSON); }; - const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { auto & la = ctx_server.lora_adapters[i]; @@ -3345,13 +3270,11 @@ int main(int argc, char ** argv) { {"scale", la.scale}, }); } - res.set_content(result.dump(), "application/json"); + res.set_content(result.dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - const std::vector body = json::parse(req.body); int max_idx = ctx_server.lora_adapters.size(); @@ -3379,7 +3302,7 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res.set_content(result.data.dump(), "application/json"); + res.set_content(result.data.dump(), MIMETYPE_JSON); res.status = 200; // HTTP OK }; @@ -3455,35 +3378,75 @@ int main(int argc, char ** argv) { log_data["n_threads_http"] = std::to_string(params.n_threads_http); svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; - LOG_INFO("HTTP server listening", log_data); + // clean up function, to be called before exit + auto clean_up = [&svr]() { + svr->stop(); + llama_backend_free(); + }; - // run the HTTP server in a thread - see comment below - std::thread t([&]() { - if (!svr->listen_after_bind()) { - state.store(SERVER_STATE_ERROR); - return 1; + // bind HTTP listen port, run the HTTP server in a thread + if (!svr->bind_to_port(params.hostname, params.port)) { + LOG_ERROR("couldn't bind HTTP server socket", { + {"hostname", params.hostname}, + {"port", params.port}, + }); + clean_up(); + LOG_ERROR("exiting due to HTTP server error", {}); + return 1; + } + std::thread t([&]() { svr->listen_after_bind(); }); + svr->wait_until_ready(); + + LOG_INFO("HTTP server is listening", log_data); + + // load the model + LOG_INFO("loading model", log_data); + if (!ctx_server.load_model(params)) { + clean_up(); + t.join(); + LOG_ERROR("exiting due to model loading error", {}); + return 1; + } else { + ctx_server.init(); + state.store(SERVER_STATE_READY); + + LOG_INFO("model loaded", {}); + + // if a custom chat template is not supplied, we will use the one that comes with the model (if any) + if (params.chat_template.empty()) { + if (!ctx_server.validate_model_chat_template()) { + LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); + params.chat_template = "chatml"; + } } - return 0; - }); + // print sample chat example to make it clear which template is used + { + LOG_INFO("chat template", { + {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, + {"built_in", params.chat_template.empty()}, + }); + } - ctx_server.queue_tasks.on_new_task(std::bind( - &server_context::process_single_task, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_finish_multitask(std::bind( - &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_update_slots(std::bind( - &server_context::update_slots, &ctx_server)); - ctx_server.queue_results.on_multitask_update(std::bind( - &server_queue::update_multitask, - &ctx_server.queue_tasks, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3 - )); - - shutdown_handler = [&](int) { - ctx_server.queue_tasks.terminate(); - }; + ctx_server.queue_tasks.on_new_task(std::bind( + &server_context::process_single_task, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_finish_multitask(std::bind( + &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1)); + ctx_server.queue_tasks.on_update_slots(std::bind( + &server_context::update_slots, &ctx_server)); + ctx_server.queue_results.on_multitask_update(std::bind( + &server_queue::update_multitask, + &ctx_server.queue_tasks, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3 + )); + + shutdown_handler = [&](int) { + ctx_server.queue_tasks.terminate(); + }; + ctx_server.queue_tasks.start_loop(); + } #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -3499,12 +3462,8 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - ctx_server.queue_tasks.start_loop(); - - svr->stop(); + clean_up(); t.join(); - llama_backend_free(); - return 0; } diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 6705a34fc4696..1ba7b60b69c46 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -205,27 +205,20 @@ def step_start_server(context): async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str): match expecting_status: case 'healthy': - await wait_for_health_status(context, context.base_url, 200, 'ok', - timeout=30) + await wait_for_slots_status(context, context.base_url, 200, + timeout=30) case 'ready' | 'idle': - await wait_for_health_status(context, context.base_url, 200, 'ok', - timeout=30, - params={'fail_on_no_slot': 0, 'include_slots': 0}, - slots_idle=context.n_slots, - slots_processing=0, - expected_slots=[{'id': slot_id, 'state': 0} - for slot_id in - range(context.n_slots if context.n_slots else 1)]) + await wait_for_slots_status(context, context.base_url, 200, + timeout=30, + params={'fail_on_no_slot': 1}, + slots_idle=context.n_slots, + slots_processing=0) case 'busy': - await wait_for_health_status(context, context.base_url, 503, - 'no slot available', - params={'fail_on_no_slot': 0, 'include_slots': 0}, - slots_idle=0, - slots_processing=context.n_slots, - expected_slots=[{'id': slot_id, 'state': 1} - for slot_id in - range(context.n_slots if context.n_slots else 1)]) + await wait_for_slots_status(context, context.base_url, 503, + params={'fail_on_no_slot': 1}, + slots_idle=0, + slots_processing=context.n_slots) case _: assert False, "unknown status" @@ -1187,17 +1180,15 @@ async def gather_tasks_results(context): return n_completions -async def wait_for_health_status(context, - base_url, - expected_http_status_code, - expected_health_status, - timeout=3, - params=None, - slots_idle=None, - slots_processing=None, - expected_slots=None): +async def wait_for_slots_status(context, + base_url, + expected_http_status_code, + timeout=3, + params=None, + slots_idle=None, + slots_processing=None): if context.debug: - print(f"Starting checking for health for expected_health_status={expected_health_status}") + print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}") interval = 0.5 counter = 0 if 'GITHUB_ACTIONS' in os.environ: @@ -1205,26 +1196,19 @@ async def wait_for_health_status(context, async with aiohttp.ClientSession() as session: while True: - async with await session.get(f'{base_url}/health', params=params) as health_response: - status_code = health_response.status - health = await health_response.json() + async with await session.get(f'{base_url}/slots', params=params) as slots_response: + status_code = slots_response.status + slots = await slots_response.json() if context.debug: - print(f"HEALTH - response for expected health status='{expected_health_status}' on " - f"'{base_url}/health'?{params} is {health}\n") - if (status_code == expected_http_status_code - and health['status'] == expected_health_status - and (slots_idle is None or health['slots_idle'] == slots_idle) - and (slots_processing is None or health['slots_processing'] == slots_processing)): - if expected_slots is not None: - assert_slots_status(health['slots'], expected_slots) - return - if (status_code == expected_http_status_code - and health['status'] == expected_health_status - and (slots_idle is None or health['slots_idle'] == slots_idle) - and (slots_processing is None or health['slots_processing'] == slots_processing)): - if expected_slots is not None: - assert_slots_status(health['slots'], expected_slots) + print(f"slots responses {slots}\n") + if status_code == 503 and status_code == expected_http_status_code: return + if status_code == 200 and status_code == expected_http_status_code: + n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots) + n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots) + if ((slots_idle is None or slots_idle == n_slots_idle) + and (slots_processing is None or slots_processing == n_slots_processing)): + return await asyncio.sleep(interval) counter += interval @@ -1238,7 +1222,7 @@ async def wait_for_health_status(context, if n_completions > 0: return - assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}' + assert False, f'slots check timeout exceeded {counter}s>={timeout}' def assert_embeddings(embeddings):