Commit
implement remove_prompt_cache
icppWorld committed Jan 22, 2025
1 parent 17356ea commit e8da10b
Showing 4 changed files with 136 additions and 32 deletions.
12 changes: 12 additions & 0 deletions native/test_tiny_stories.cpp
@@ -127,6 +127,18 @@ void test_tiny_stories(MockIC &mockIC) {
run_update,
"4449444c026c01dd9ad28304016d7101000d0e2d2d70726f6d70742d6361636865156d795f63616368652f70726f6d70742e6361636865122d2d70726f6d70742d63616368652d616c6c0a2d2d73616d706c65727305746f705f70062d2d74656d7003302e31072d2d746f702d7003302e39022d6e023230022d7000",
"", silent_on_trap, my_principal);

// -----------------------------------------------------------------------------
// Remove the prompt-cache file if it exists
// '(record { args = vec {"--prompt-cache"; "my_cache/prompt.cache"} })' ->
// '(variant { Ok = record { status_code = 200 : nat16; output = "Ready to start a new chat for cache file .canister_cache/expmt-gtxsw-inftj-ttabj-qhp5s-nozup-n3bbo-k7zvn-dg4he-knac3-lae/sessions/my_cache/prompt.cache"; input = ""; error=""; prompt_remaining=""; generated_eog=false : bool } })'
mockIC.run_test(
std::string(__func__) + ": " + "remove_prompt_cache " + std::to_string(i) +
" - " + model,
remove_prompt_cache,
"4449444c026c01dd9ad28304016d710100020e2d2d70726f6d70742d6361636865156d795f63616368652f70726f6d70742e6361636865",
"4449444c026c06819e846471838fe5800671c897a79907719aa1b2f90c7adb92a2c90d71cdd9e6b30e7e6b01bc8a01000101009701526561647920746f2073746172742061206e6577206368617420666f722063616368652066696c65202e63616e69737465725f63616368652f6578706d742d67747873772d696e66746a2d747461626a2d71687035732d6e6f7a75702d6e3362626f2d6b377a766e2d64673468652d6b6e6163332d6c61652f73657373696f6e732f6d795f63616368652f70726f6d70742e63616368650000c8000000",
silent_on_trap, my_principal);
}
}
}
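Note: the raw hex strings passed to mockIC.run_test are Candid wire encodings of the records shown in the comments. A throwaway decoder (not part of this commit) makes their readable parts visible:

// Throwaway decoder, not part of the commit: print the printable bytes of a
// Candid hex string. The encoding starts with the "DIDL" magic, followed by a
// type table and the values, so the arg strings appear verbatim.
#include <cctype>
#include <iostream>
#include <string>

int main() {
  const std::string hex = // input of the remove_prompt_cache test above
      "4449444c026c01dd9ad28304016d710100020e2d2d70726f6d70742d6361636865"
      "156d795f63616368652f70726f6d70742e6361636865";
  std::string out;
  for (size_t i = 0; i + 1 < hex.size(); i += 2) {
    char c = static_cast<char>(std::stoi(hex.substr(i, 2), nullptr, 16));
    out += std::isprint(static_cast<unsigned char>(c)) ? c : '.';
  }
  // Shows "DIDL", then "--prompt-cache" and "my_cache/prompt.cache" between
  // non-printable type-table bytes.
  std::cout << out << std::endl;
  return 0;
}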
65 changes: 33 additions & 32 deletions src/llama_cpp.did
@@ -1,24 +1,24 @@
// Always return a single record wrapped in a variant
type StatusCode = nat16;

type InputRecord = record {
args : vec text; // the CLI args of llama.cpp/examples/main, as a list of strings
};

// to avoid hitting the IC's instructions limit
// 0 = no limit (default)
type MaxTokensRecord = record {
max_tokens_update : nat64;
max_tokens_query : nat64;
};

type RunOutputRecord = record {
status_code : StatusCode;
output : text;
conversation : text;
error : text;
prompt_remaining : text;
generated_eog : bool;
};
type OutputRecordResult = variant {
Ok : RunOutputRecord;
@@ -39,36 +39,36 @@ type StatusCodeRecordResult = variant {
type StatusCodeRecord = record { status_code : nat16 };

// -----------------------------------------------------
type FileDownloadInputRecord = record {
filename : text;
chunksize : nat64;
offset : nat64;
};
type FileDownloadRecordResult = variant {
Err : ApiError;
Ok : FileDownloadRecord;
};
type FileDownloadRecord = record {
chunk : vec nat8; // the chunk read from the file, as a vec of bytes
chunksize : nat64; // the chunksize in bytes
filesize : nat64; // the total filesize in bytes
offset : nat64; // the chunk starts here (bytes from beginning)
done : bool; // true if there are no more bytes to read
};

// -----------------------------------------------------
type FileUploadInputRecord = record {
filename : text;
chunk : vec nat8; // the chunk being uploaded, as a vec of bytes
chunksize : nat64; // the chunksize (allowing sanity check)
offset : nat64; // the offset where to write the chunk
};
type FileUploadRecordResult = variant {
Err : ApiError;
Ok : FileUploadRecord;
};
type FileUploadRecord = record {
filesize : nat64; // the total filesize in bytes after writing chunk at offset
};

// -----------------------------------------------------
@@ -78,17 +78,17 @@ type GetChatsRecordResult = variant {
};
type GetChatsRecord = record {
chats : vec record {
timestamp : text;
chat : text;
};
};

// -----------------------------------------------------
// Access level
// 0 = only controllers
// 1 = all except anonymous
type AccessInputRecord = record {
level : nat16;
};
type AccessRecordResult = variant {
Err : ApiError;
@@ -110,15 +110,16 @@ service : {
load_model : (InputRecord) -> (OutputRecordResult);
set_max_tokens : (MaxTokensRecord) -> (StatusCodeRecordResult);
get_max_tokens : () -> (MaxTokensRecord) query;

// up & down load of files
file_download_chunk : (FileDownloadInputRecord) -> (FileDownloadRecordResult) query;
file_upload_chunk : (FileUploadInputRecord) -> (FileUploadRecordResult);

// Inference endpoints
new_chat : (InputRecord) -> (OutputRecordResult);
run_query : (InputRecord) -> (OutputRecordResult) query;
run_update : (InputRecord) -> (OutputRecordResult);
remove_prompt_cache : (InputRecord) -> (OutputRecordResult);

// Chats retrieval
get_chats : () -> (GetChatsRecordResult) query;
@@ -130,5 +131,5 @@ service : {

// Other admin endpoints
whoami : () -> (text) query;
};
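The new remove_prompt_cache endpoint reuses InputRecord, so the same `args : vec text` travels to the canister and is handed to llama.cpp's CLI parser. A hedged sketch of that conversion (the repo's get_args_for_main, used in src/run.cpp below, does the real work; this only illustrates the shape):

#include <string>
#include <utility>
#include <vector>

// Illustrative only: turn the Candid `args : vec text` into the (argc, argv)
// pair that gpt_params_parse() expects. The returned pointers alias `args`,
// so `args` must outlive the argv vector.
std::pair<int, std::vector<char *>>
to_argc_argv(std::vector<std::string> &args) {
  std::vector<char *> argv;
  argv.push_back(const_cast<char *>("main")); // conventional argv[0]
  for (auto &arg : args)
    argv.push_back(arg.data()); // e.g. "--prompt-cache", "my_cache/prompt.cache"
  return {static_cast<int>(argv.size()), std::move(argv)};
}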
90 changes: 90 additions & 0 deletions src/run.cpp
@@ -93,6 +93,11 @@ void new_chat() {
} else {
msg = "Cache file " + path_session + " not found. Nothing to delete.";
}
} else {
error_msg = "ERROR: path_session is empty ";
send_output_record_result_error_to_wire(
ic_api, Http::StatusCode::InternalServerError, error_msg);
return;
}
std::cout << msg << std::endl;

@@ -110,6 +115,91 @@ void new_chat() {
ic_api.to_wire(CandidTypeVariant{"Ok", r_out});
}

void remove_prompt_cache() {
IC_API ic_api(CanisterUpdate{std::string(__func__)}, false);
std::string error_msg;
if (!is_caller_whitelisted(ic_api, false)) {
error_msg = "Access Denied.";
send_output_record_result_error_to_wire(
ic_api, Http::StatusCode::Unauthorized, error_msg);
return;
}

CandidTypePrincipal caller = ic_api.get_caller();
std::string principal_id = caller.get_text();

auto [argc, argv, args] = get_args_for_main(ic_api);

// Get the cache filename from --prompt-cache in args
gpt_params params;
if (!gpt_params_parse(argc, argv.data(), params)) {
error_msg = "Cannot parse args.";
send_output_record_result_error_to_wire(
ic_api, Http::StatusCode::InternalServerError, error_msg);
return;
}

// // Create a new file to save this chat for this principal
// if (!db_chats_new(principal_id, error_msg)) {
// send_output_record_result_error_to_wire(
// ic_api, Http::StatusCode::InternalServerError, error_msg);
// return;
// }

// // Each principal can only save N chats
// if (!db_chats_clean(principal_id, error_msg)) {
// send_output_record_result_error_to_wire(
// ic_api, Http::StatusCode::InternalServerError, error_msg);
// return;
// }

// Each principal has their own cache folder
std::string path_session = params.path_prompt_cache;
std::string canister_path_session;
if (!get_canister_path_session(path_session, principal_id,
canister_path_session, error_msg)) {
send_output_record_result_error_to_wire(
ic_api, Http::StatusCode::InternalServerError, error_msg);
return;
}
path_session = canister_path_session;

std::string msg;
if (!path_session.empty()) {
// Remove the file if it exists
if (std::filesystem::exists(path_session)) {
bool success = std::filesystem::remove(path_session);
if (success) {
msg = "Cache file " + path_session + " deleted successfully";
} else {
error_msg = "Error deleting cache file " + path_session;
send_output_record_result_error_to_wire(
ic_api, Http::StatusCode::InternalServerError, error_msg);
return;
}
} else {
msg = "Cache file " + path_session + " not found. Nothing to delete.";
}
} else {
error_msg = "ERROR: path_session is empty ";
send_output_record_result_error_to_wire(
ic_api, Http::StatusCode::InternalServerError, error_msg);
return;
}
std::cout << msg << std::endl;

// Return output over the wire
CandidTypeRecord r_out;
r_out.append("status_code", CandidTypeNat16{Http::StatusCode::OK}); // 200
r_out.append("conversation", CandidTypeText{""});
r_out.append("output", CandidTypeText{msg});
r_out.append("error", CandidTypeText{""});
r_out.append("prompt_remaining", CandidTypeText{""});
r_out.append("generated_eog", CandidTypeBool{false});
ic_api.to_wire(CandidTypeVariant{"Ok", r_out});
}


void run(IC_API &ic_api, const uint64_t &max_tokens) {
std::string error_msg;
if (!is_caller_whitelisted(ic_api, false)) {
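The per-principal isolation comes from get_canister_path_session, whose effect the expected test output above pins down: `my_cache/prompt.cache` becomes `.canister_cache/<principal>/sessions/my_cache/prompt.cache`. A minimal sketch of that mapping, inferred from the test output (the real helper, declared in src/run.h below, also reports errors via an out-parameter and may do more, such as directory creation):

#include <string>

// Sketch inferred from the expected test output; not the actual
// implementation of get_canister_path_session.
std::string canister_path_session_sketch(const std::string &path_session,
                                         const std::string &principal_id) {
  return ".canister_cache/" + principal_id + "/sessions/" + path_session;
}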
1 change: 1 addition & 0 deletions src/run.h
@@ -6,6 +6,7 @@
void new_chat() WASM_SYMBOL_EXPORTED("canister_update new_chat");
void run_query() WASM_SYMBOL_EXPORTED("canister_query run_query");
void run_update() WASM_SYMBOL_EXPORTED("canister_update run_update");
void remove_prompt_cache() WASM_SYMBOL_EXPORTED("canister_update remove_prompt_cache");

bool get_canister_path_session(const std::string &path_session,
const std::string &principal_id,
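WASM_SYMBOL_EXPORTED is what ties the C++ function to the method name the IC dispatches on, e.g. "canister_update remove_prompt_cache". Its real definition lives in the icpp-pro toolchain headers; a common shape for such a macro (an assumption, shown for orientation only) is a clang wasm export attribute:

// Assumed shape only — the actual macro is defined by the icpp-pro toolchain.
// With clang targeting wasm32, export_name exposes the symbol to the IC under
// the given method name.
#define WASM_SYMBOL_EXPORTED(name) __attribute__((export_name(name)))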
