[WIP] Hot swap for LoRA #8056

Closed · wants to merge 14 commits
12 changes: 12 additions & 0 deletions common/common.cpp
@@ -789,6 +789,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
        params.model = argv[i];
        return true;
    }
    if (arg == "-hl" || arg == "--hot-lora") {
        if (++i >= argc) {
            invalid_param = true;
            return true;
        }
        params.hot_lora = argv[i];
        return true;
    }
    if (arg == "-md" || arg == "--model-draft") {
        if (++i >= argc) {
            invalid_param = true;
@@ -2435,6 +2443,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
    cparams.n_ubatch        = params.n_ubatch;
    cparams.n_threads       = params.n_threads;
    cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
    strncpy(cparams.hot_lora, params.hot_lora.c_str(), sizeof(cparams.hot_lora) - 1);
    cparams.hot_lora[sizeof(cparams.hot_lora) - 1] = '\0'; // strncpy does not null-terminate on truncation

    cparams.seed       = params.seed;
    cparams.logits_all = params.logits_all;
    cparams.embeddings = params.embedding;
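A note on the copy above: using strncpy into cparams.hot_lora implies the field is a fixed-size char array rather than a std::string or pointer, so the terminator must be written explicitly when the source fills the buffer. A minimal standalone sketch of the pattern follows; the struct name, the 256-byte size, and set_hot_lora are illustrative assumptions, not part of this PR.

#include <stdio.h>
#include <string.h>

// stand-in for llama_context_params; the real field size is defined elsewhere
struct demo_context_params {
    char hot_lora[256]; // assumed size, for illustration only
};

static void set_hot_lora(struct demo_context_params * cparams, const char * path) {
    // strncpy leaves the array unterminated when the source is at least as
    // long as the buffer, hence the explicit terminator, mirroring the diff above
    strncpy(cparams->hot_lora, path, sizeof(cparams->hot_lora) - 1);
    cparams->hot_lora[sizeof(cparams->hot_lora) - 1] = '\0';
}

int main(void) {
    struct demo_context_params cparams = {0};
    set_hot_lora(&cparams, "lora-adapter.bin");
    printf("hot_lora = %s\n", cparams.hot_lora);
    return 0;
}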
1 change: 1 addition & 0 deletions common/common.h
@@ -100,6 +100,7 @@ struct gpt_params {

    std::string model       = ""; // model path
    std::string model_draft = ""; // draft model for speculative decoding
    std::string hot_lora    = ""; // LoRA adapter path for hot swapping
    std::string model_alias = "unknown"; // model alias
    std::string model_url   = ""; // model url to download
    std::string hf_repo     = ""; // HF repo
2 changes: 2 additions & 0 deletions data/hot-lora.txt
@@ -0,0 +1,2 @@

test data to train adapter
47 changes: 47 additions & 0 deletions ggml.c
@@ -4313,6 +4313,52 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name
    return NULL;
}

//////// LORA

struct lora_tensor_pair * build_lora_weights_map(struct ggml_context * ctx) {
    struct lora_tensor_pair * pair = malloc(sizeof(struct lora_tensor_pair));
    if (!pair) return NULL;
    pair->pairs    = NULL;
    pair->count    = 0;
    pair->capacity = 0;

    struct ggml_object * obj = ctx->objects_begin;
    char * const mem_buffer = ctx->mem_buffer;

    // walk every object in the context and collect tensors whose names
    // end in ".loraA" or ".loraB"
    while (obj != NULL) {
        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
            struct ggml_tensor * tensor = (struct ggml_tensor *)(mem_buffer + obj->offs);
            const char * tensor_name = tensor->name;
            const size_t len = strlen(tensor_name);

            if (len > 6 && (strcmp(tensor_name + len - 6, ".loraA") == 0 ||
                            strcmp(tensor_name + len - 6, ".loraB") == 0)) {
                if (pair->count == pair->capacity) {
                    // grow geometrically; keep the old block if realloc fails
                    const int new_capacity = pair->capacity > 0 ? pair->capacity * 2 : 4;
                    struct lora_tensor_info * new_pairs =
                        realloc(pair->pairs, new_capacity * sizeof(struct lora_tensor_info));
                    if (!new_pairs) {
                        free_lora_tensor_pair(pair);
                        return NULL;
                    }
                    pair->pairs    = new_pairs;
                    pair->capacity = new_capacity;
                }

                pair->pairs[pair->count].name   = strdup(tensor_name);
                pair->pairs[pair->count].tensor = tensor;
                pair->count++;
            }
        }
        obj = obj->next;
    }

    return pair;
}

void free_lora_tensor_pair(struct lora_tensor_pair * pair) {
    if (!pair) return;
    for (int i = 0; i < pair->count; i++) {
        free(pair->pairs[i].name);
    }
    free(pair->pairs);
    free(pair);
}

//////// LORA

////////////////////////////////////////////////////////////////////////////////

// ggml_dup
@@ -5285,6 +5331,7 @@ struct ggml_tensor * ggml_group_norm_inplace(
    return ggml_group_norm_impl(ctx, a, n_groups, true);
}


// ggml_mul_mat

struct ggml_tensor * ggml_mul_mat(
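For orientation, here is a hedged sketch of how a caller might consume the map built by build_lora_weights_map: walk the collected entries and split each name into its base part and its ".loraA"/".loraB" suffix. Only the structs and the two functions declared in this PR are real; print_lora_pairs and its formatting are illustrative assumptions.

#include <stdio.h>
#include <string.h>
#include "ggml.h"

// list every collected LoRA tensor together with its base weight name
static void print_lora_pairs(struct ggml_context * ctx) {
    struct lora_tensor_pair * map = build_lora_weights_map(ctx);
    if (!map) {
        return;
    }
    for (int i = 0; i < map->count; i++) {
        const char * name     = map->pairs[i].name;
        const size_t base_len = strlen(name) - 6; // strip ".loraA" / ".loraB"
        printf("%.*s -> %s [%lld x %lld]\n",
               (int) base_len, name,            // base weight name
               name + base_len,                 // which factor this entry is
               (long long) map->pairs[i].tensor->ne[0],
               (long long) map->pairs[i].tensor->ne[1]);
    }
    free_lora_tensor_pair(map);
}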
19 changes: 19 additions & 0 deletions ggml.h
@@ -835,6 +835,25 @@ extern "C" {
    GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);

    struct lora_tensor_info {
        char * name;
        struct ggml_tensor * tensor;
    };

    struct lora_tensor_pair {
        struct lora_tensor_info * pairs; // dynamic array of collected LoRA tensors
        int count;
        int capacity;
    };

    // collect every tensor in the context whose name ends in ".loraA" or ".loraB"
    struct lora_tensor_pair * build_lora_weights_map(struct ggml_context * ctx);

    // free a lora_tensor_pair, including the names it owns
    void free_lora_tensor_pair(struct lora_tensor_pair * pair);

    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
    GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
    GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
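The declarations above only collect the adapter tensors; this WIP does not yet show how they are merged into the base weights. For reference, the standard LoRA update is W' = W + scale * (B x A), which can be expressed with existing ggml graph ops. The function below is a hedged sketch under that assumption, not code from this PR, and the tensor orientations in particular are untested guesses.

// build a graph node computing W + scale * (loraB x loraA)
static struct ggml_tensor * apply_lora_delta(
        struct ggml_context * ctx,
        struct ggml_tensor  * W,       // base weight matrix
        struct ggml_tensor  * loraA,   // low-rank factor A
        struct ggml_tensor  * loraB,   // low-rank factor B
        float                 scale) { // typically lora_alpha / rank
    struct ggml_tensor * BA     = ggml_mul_mat(ctx, loraA, loraB);
    struct ggml_tensor * scaled = ggml_scale(ctx, BA, scale);
    return ggml_add(ctx, W, scaled);
}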