Skip to content

Commit

Permalink
Add function to load Lora adapters at initialisation and runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
CTOJunior authored and yacineMTB committed Jun 20, 2023
1 parent 08c34bd commit 33df016
Showing 1 changed file with 30 additions and 2 deletions.
32 changes: 30 additions & 2 deletions examples/addon.node/addon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,17 @@ Napi::Number init(const Napi::CallbackInfo &info)
g_params = new gpt_params;
g_params->model = model;
// load the model and apply lora adapter, if any
// TODO: Create a function that "swaps" the adapter
// TODO: Create a function that holds more than one adapter in memory
if (obj.Has("lora")) {
Napi::String loraNapi = obj.Get("lora").As<Napi::String>();
std::string lora = loraNapi.Utf8Value();
if (!lora.empty()){
fprintf(stderr, "Loading lora from Path: %s\n", lora.c_str());
g_params->lora_adapter = lora;
g_params->use_mmap = false; // with mmap, ggml lora will throw segfault
}
}

g_ctx = llama_init_from_gpt_params(*g_params);
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", g_params->n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
if (g_ctx == NULL)
Expand All @@ -40,9 +49,25 @@ Napi::Number init(const Napi::CallbackInfo &info)
return Napi::Number::New(env, 0);
}


// Global mutex; swapLora() holds it while applying a LoRA adapter to g_ctx.
// NOTE(review): presumably the inference worker also locks this so that an
// adapter swap cannot overlap with generation — confirm against InferenceWorker.
std::mutex worker_mutex;

/**
 * Apply a LoRA adapter to the global llama context at runtime.
 *
 * JavaScript usage: swapLora({ lora: "/path/to/adapter" })
 *
 * @param info  N-API call info; info[0] must be an object whose "lora"
 *              property is a non-empty string holding the adapter path.
 * @return 0 on success; 1 when the arguments are invalid, the context has
 *         not been initialised, or llama_apply_lora_from_file fails.
 */
Napi::Number swapLora(const Napi::CallbackInfo &info)
{
    Napi::Env env = info.Env();

    // Validate the JS arguments before casting: As<T>() performs no type check.
    if (info.Length() < 1 || !info[0].IsObject()) {
        fprintf(stderr, "swapLora: expected an object argument with a \"lora\" property\n");
        return Napi::Number::New(env, 1);
    }

    Napi::Value loraValue = info[0].As<Napi::Object>().Get("lora");
    if (!loraValue.IsString()) {
        fprintf(stderr, "swapLora: \"lora\" must be a string path\n");
        return Napi::Number::New(env, 1);
    }

    const std::string lora = loraValue.As<Napi::String>().Utf8Value();
    if (lora.empty()) {
        fprintf(stderr, "swapLora: \"lora\" path is empty\n");
        return Napi::Number::New(env, 1);
    }

    fprintf(stderr, "Acquiring lock\n");
    // RAII guard: the mutex is released on every return path below.
    std::lock_guard<std::mutex> lock(worker_mutex);

    // init() may have failed or never been called; do not hand a null context to llama.
    if (g_ctx == nullptr) {
        fprintf(stderr, "swapLora: model context is not initialised\n");
        return Napi::Number::New(env, 1);
    }

    fprintf(stderr, "Swapping lora from Path: %s\n", lora.c_str());
    const int err = llama_apply_lora_from_file(g_ctx, lora.c_str(), nullptr, get_num_physical_cores());
    if (err != 0) {
        // Surface the failure to the caller instead of unconditionally reporting success.
        fprintf(stderr, "swapLora: failed to apply lora adapter (error %d)\n", err);
        return Napi::Number::New(env, 1);
    }

    return Napi::Number::New(env, 0);
}

class InferenceWorker {
public:
InferenceWorker(Napi::Env env, Napi::Function listener, std::string prompt)
Expand Down Expand Up @@ -260,6 +285,9 @@ Napi::Object Init(Napi::Env env, Napi::Object exports)
exports.Set(
Napi::String::New(env, "startAsync"),
Napi::Function::New(env, StartAsync));
exports.Set(
Napi::String::New(env, "swapLora"),
Napi::Function::New(env, swapLora));
return exports;
}

Expand Down

0 comments on commit 33df016

Please sign in to comment.