From 33df016489cfd51ab635226bd33ecb6927ea7421 Mon Sep 17 00:00:00 2001 From: TDM Date: Tue, 13 Jun 2023 17:28:29 +0530 Subject: [PATCH] Add function to load Lora adapters at initialisation and runtime --- examples/addon.node/addon.cpp | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 3e2d80914e034..cae06cc9e29e0 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -27,8 +27,17 @@ Napi::Number init(const Napi::CallbackInfo &info) g_params = new gpt_params; g_params->model = model; // load the model and apply lora adapter, if any - // TODO: Create a function that "swaps" the adapter // TODO: Create a function that holds more than one adapter in memory + if (obj.Has("lora")) { + Napi::String loraNapi = obj.Get("lora").As(); + std::string lora = loraNapi.Utf8Value(); + if (!lora.empty()){ + fprintf(stderr, "Loading lora from Path: %s\n", lora.c_str()); + g_params->lora_adapter = lora; + g_params->use_mmap = false; // with mmap, ggml lora will throw segfault + } + } + g_ctx = llama_init_from_gpt_params(*g_params); fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", g_params->n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); if (g_ctx == NULL) @@ -40,9 +49,25 @@ Napi::Number init(const Napi::CallbackInfo &info) return Napi::Number::New(env, 0); } - std::mutex worker_mutex; +// Function to load adapter at runtime +Napi::Number swapLora(const Napi::CallbackInfo &info) +{ + Napi::Object obj = info[0].As(); + Napi::String loraNapi = obj.Get("lora").As(); + std::string lora = loraNapi.Utf8Value(); + + fprintf(stderr, "Acquiring lock\n"); + worker_mutex.lock(); + + fprintf(stderr, "Swapping lora from Path: %s\n", lora.c_str()); + llama_apply_lora_from_file(g_ctx, lora.c_str(), NULL, get_num_physical_cores()); + + worker_mutex.unlock(); + return Napi::Number::New(info.Env(), 0); +} + class InferenceWorker { public: InferenceWorker(Napi::Env env, Napi::Function listener, std::string prompt) @@ -260,6 +285,9 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) exports.Set( Napi::String::New(env, "startAsync"), Napi::Function::New(env, StartAsync)); + exports.Set( + Napi::String::New(env, "swapLora"), + Napi::Function::New(env, swapLora)); return exports; }