From e8da10b77631c3e72242995fc857e87b7ed87e01 Mon Sep 17 00:00:00 2001 From: icpp Date: Wed, 22 Jan 2025 07:30:14 -0500 Subject: [PATCH] implement remove_prompt_cache --- native/test_tiny_stories.cpp | 12 +++++ src/llama_cpp.did | 65 +++++++++++++------------- src/run.cpp | 90 ++++++++++++++++++++++++++++++++++++ src/run.h | 1 + 4 files changed, 136 insertions(+), 32 deletions(-) diff --git a/native/test_tiny_stories.cpp b/native/test_tiny_stories.cpp index 2ce28ed..ae5742e 100644 --- a/native/test_tiny_stories.cpp +++ b/native/test_tiny_stories.cpp @@ -127,6 +127,18 @@ void test_tiny_stories(MockIC &mockIC) { run_update, "4449444c026c01dd9ad28304016d7101000d0e2d2d70726f6d70742d6361636865156d795f63616368652f70726f6d70742e6361636865122d2d70726f6d70742d63616368652d616c6c0a2d2d73616d706c65727305746f705f70062d2d74656d7003302e31072d2d746f702d7003302e39022d6e023230022d7000", "", silent_on_trap, my_principal); + + // ----------------------------------------------------------------------------- + // Remove the prompt-cache file if it exists + // '(record { args = vec {"--prompt-cache"; "my_cache/prompt.cache"} })' -> + // '(variant { Ok = record { status_code = 200 : nat16; output = "Ready to start a new chat for cache file .canister_cache/expmt-gtxsw-inftj-ttabj-qhp5s-nozup-n3bbo-k7zvn-dg4he-knac3-lae/sessions/my_cache/prompt.cache"; input = ""; error=""; prompt_remaining=""; generated_eog=false : bool } })' + mockIC.run_test( + std::string(__func__) + ": " + "remove_prompt_cache " + std::to_string(i) + + " - " + model, + remove_prompt_cache, + "4449444c026c01dd9ad28304016d710100020e2d2d70726f6d70742d6361636865156d795f63616368652f70726f6d70742e6361636865", + 
"4449444c026c06819e846471838fe5800671c897a79907719aa1b2f90c7adb92a2c90d71cdd9e6b30e7e6b01bc8a01000101009701526561647920746f2073746172742061206e6577206368617420666f722063616368652066696c65202e63616e69737465725f63616368652f6578706d742d67747873772d696e66746a2d747461626a2d71687035732d6e6f7a75702d6e3362626f2d6b377a766e2d64673468652d6b6e6163332d6c61652f73657373696f6e732f6d795f63616368652f70726f6d70742e63616368650000c8000000", + silent_on_trap, my_principal); } } } \ No newline at end of file diff --git a/src/llama_cpp.did b/src/llama_cpp.did index d6a7a49..6c14fef 100644 --- a/src/llama_cpp.did +++ b/src/llama_cpp.did @@ -1,24 +1,24 @@ // Always return a single record wrapped in a variant type StatusCode = nat16; -type InputRecord = record { +type InputRecord = record { args : vec text; // the CLI args of llama.cpp/examples/main, as a list of strings }; // to avoid hitting the IC's instructions limit // 0 = no limit (default) -type MaxTokensRecord = record { +type MaxTokensRecord = record { max_tokens_update : nat64; max_tokens_query : nat64; }; -type RunOutputRecord = record { - status_code: StatusCode; - output: text; - conversation: text; - error: text; - prompt_remaining: text; - generated_eog: bool; +type RunOutputRecord = record { + status_code : StatusCode; + output : text; + conversation : text; + error : text; + prompt_remaining : text; + generated_eog : bool; }; type OutputRecordResult = variant { Ok : RunOutputRecord; @@ -39,36 +39,36 @@ type StatusCodeRecordResult = variant { type StatusCodeRecord = record { status_code : nat16 }; // ----------------------------------------------------- -type FileDownloadInputRecord = record { - filename: text; - chunksize: nat64; - offset: nat64; +type FileDownloadInputRecord = record { + filename : text; + chunksize : nat64; + offset : nat64; }; type FileDownloadRecordResult = variant { Err : ApiError; Ok : FileDownloadRecord; }; type FileDownloadRecord = record { - chunk : vec nat8; // the chunk read from the file, as a 
vec of bytes - chunksize : nat64; // the chunksize in bytes - filesize : nat64; // the total filesize in bytes - offset: nat64; // the chunk starts here (bytes from beginning) - done : bool; // true if there are no more bytes to read + chunk : vec nat8; // the chunk read from the file, as a vec of bytes + chunksize : nat64; // the chunksize in bytes + filesize : nat64; // the total filesize in bytes + offset : nat64; // the chunk starts here (bytes from beginning) + done : bool; // true if there are no more bytes to read }; // ----------------------------------------------------- -type FileUploadInputRecord = record { - filename: text; - chunk : vec nat8; // the chunk being uploaded, as a vec of bytes - chunksize: nat64; // the chunksize (allowing sanity check) - offset: nat64; // the offset where to write the chunk +type FileUploadInputRecord = record { + filename : text; + chunk : vec nat8; // the chunk being uploaded, as a vec of bytes + chunksize : nat64; // the chunksize (allowing sanity check) + offset : nat64; // the offset where to write the chunk }; type FileUploadRecordResult = variant { Err : ApiError; Ok : FileUploadRecord; }; type FileUploadRecord = record { - filesize : nat64; // the total filesize in bytes after writing chunk at offset + filesize : nat64; // the total filesize in bytes after writing chunk at offset }; // ----------------------------------------------------- @@ -78,8 +78,8 @@ type GetChatsRecordResult = variant { }; type GetChatsRecord = record { chats : vec record { - timestamp: text; - chat: text; + timestamp : text; + chat : text; }; }; @@ -87,8 +87,8 @@ type GetChatsRecord = record { // Access level // 0 = only controllers // 1 = all except anonymous -type AccessInputRecord = record { - level: nat16; +type AccessInputRecord = record { + level : nat16; }; type AccessRecordResult = variant { Err : ApiError; @@ -110,7 +110,7 @@ service : { load_model : (InputRecord) -> (OutputRecordResult); set_max_tokens : (MaxTokensRecord) -> 
(StatusCodeRecordResult); get_max_tokens : () -> (MaxTokensRecord) query; - + // up & down load of files file_download_chunk : (FileDownloadInputRecord) -> (FileDownloadRecordResult) query; file_upload_chunk : (FileUploadInputRecord) -> (FileUploadRecordResult); @@ -118,7 +118,8 @@ service : { // Inference endpoints new_chat : (InputRecord) -> (OutputRecordResult); run_query : (InputRecord) -> (OutputRecordResult) query; - run_update: (InputRecord) -> (OutputRecordResult); + run_update : (InputRecord) -> (OutputRecordResult); + remove_prompt_cache : (InputRecord) -> (OutputRecordResult); // Chats retrieval get_chats : () -> (GetChatsRecordResult) query; @@ -130,5 +131,5 @@ service : { // Other admin endpoints whoami : () -> (text) query; - -} \ No newline at end of file + +}; diff --git a/src/run.cpp b/src/run.cpp index 7b0f216..f0ff599 100644 --- a/src/run.cpp +++ b/src/run.cpp @@ -93,6 +93,11 @@ void new_chat() { } else { msg = "Cache file " + path_session + " not found. Nothing to delete."; } + } else { + error_msg = "ERROR: path_session is empty "; + send_output_record_result_error_to_wire( + ic_api, Http::StatusCode::InternalServerError, error_msg); + return; } std::cout << msg << std::endl; @@ -110,6 +115,91 @@ void new_chat() { ic_api.to_wire(CandidTypeVariant{"Ok", r_out}); } +void remove_prompt_cache() { + IC_API ic_api(CanisterUpdate{std::string(__func__)}, false); + std::string error_msg; + if (!is_caller_whitelisted(ic_api, false)) { + error_msg = "Access Denied."; + send_output_record_result_error_to_wire( + ic_api, Http::StatusCode::Unauthorized, error_msg); + return; + } + + CandidTypePrincipal caller = ic_api.get_caller(); + std::string principal_id = caller.get_text(); + + auto [argc, argv, args] = get_args_for_main(ic_api); + + // Get the cache filename from --prompt-cache in args + gpt_params params; + if (!gpt_params_parse(argc, argv.data(), params)) { + error_msg = "Cannot parse args."; + send_output_record_result_error_to_wire( + ic_api, 
Http::StatusCode::InternalServerError, error_msg); return; } + + // // Create a new file to save this chat for this principal + // if (!db_chats_new(principal_id, error_msg)) { + //   send_output_record_result_error_to_wire( + //       ic_api, Http::StatusCode::InternalServerError, error_msg); + //   return; + // } + + // // Each principal can only save N chats + // if (!db_chats_clean(principal_id, error_msg)) { + //   send_output_record_result_error_to_wire( + //       ic_api, Http::StatusCode::InternalServerError, error_msg); + //   return; + // } + + // Each principal has their own cache folder + std::string path_session = params.path_prompt_cache; + std::string canister_path_session; + if (!get_canister_path_session(path_session, principal_id, + canister_path_session, error_msg)) { + send_output_record_result_error_to_wire( + ic_api, Http::StatusCode::InternalServerError, error_msg); + return; + } + path_session = canister_path_session; + + std::string msg; + if (!path_session.empty()) { + // Remove the file if it exists + if (std::filesystem::exists(path_session)) { + bool success = std::filesystem::remove(path_session); + if (success) { + msg = "Cache file " + path_session + " deleted successfully"; + } else { + error_msg = "Error deleting cache file " + path_session; + send_output_record_result_error_to_wire( + ic_api, Http::StatusCode::InternalServerError, error_msg); + return; + } + } else { + msg = "Cache file " + path_session + " not found. 
Nothing to delete."; + } + } else { + error_msg = "ERROR: path_session is empty "; + send_output_record_result_error_to_wire( + ic_api, Http::StatusCode::InternalServerError, error_msg); + return; + } + std::cout << msg << std::endl; + + // Return output over the wire + CandidTypeRecord r_out; + r_out.append("status_code", CandidTypeNat16{Http::StatusCode::OK}); // 200 + r_out.append("conversation", CandidTypeText{""}); + r_out.append("output", CandidTypeText{msg}); + r_out.append("error", CandidTypeText{""}); + r_out.append("prompt_remaining", CandidTypeText{""}); + r_out.append("generated_eog", CandidTypeBool{false}); + ic_api.to_wire(CandidTypeVariant{"Ok", r_out}); +} + + void run(IC_API &ic_api, const uint64_t &max_tokens) { std::string error_msg; if (!is_caller_whitelisted(ic_api, false)) { diff --git a/src/run.h b/src/run.h index 7519bd1..e1c25fc 100644 --- a/src/run.h +++ b/src/run.h @@ -6,6 +6,7 @@ void new_chat() WASM_SYMBOL_EXPORTED("canister_update new_chat"); void run_query() WASM_SYMBOL_EXPORTED("canister_query run_query"); void run_update() WASM_SYMBOL_EXPORTED("canister_update run_update"); +void remove_prompt_cache() WASM_SYMBOL_EXPORTED("canister_update remove_prompt_cache"); bool get_canister_path_session(const std::string &path_session, const std::string &principal_id,