From 57acd74f9a8762cb3f250c7dc6b5020ea8960814 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 15 Aug 2023 11:22:37 +0300 Subject: [PATCH] llama : no need to pass full file loader to the file saver just gguf_ctx --- gguf-llama.cpp | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/gguf-llama.cpp b/gguf-llama.cpp index e0f35e00daa1c2..e7e9de8516d460 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -696,12 +696,12 @@ struct gguf_file_saver { // we need to calculate the delta in number of bytes written with a counter as a struct member. gguf_file file; - gguf_file_loader * fl; + gguf_context * ctx; // loaded gguf context (used to re-write the KV section (good enough for now)) size_t info_offset; size_t tensor_offset = 0; - gguf_file_saver(const char * fname, gguf_file_loader * fl) - : file(fname, "wb"), fl(fl) { + gguf_file_saver(const char * fname, gguf_context * ctx) + : file(fname, "wb"), ctx(ctx) { fprintf(stderr, "llama.cpp: saving model to %s\n", fname); write_header(); write_kv(); @@ -710,15 +710,15 @@ struct gguf_file_saver { void write_header() { file.write_i32(GGUF_MAGIC); file.write_i32(GGUF_VERSION); - file.write_i32(gguf_get_n_tensors(fl->gguf_ctx)); - file.write_i32(gguf_get_n_kv (fl->gguf_ctx)); + file.write_i32(gguf_get_n_tensors(ctx)); + file.write_i32(gguf_get_n_kv (ctx)); } void write_kv_arr_str(const std::string & key, enum gguf_type type, int i, int n_arr) { std::vector data(n_arr); for (int j = 0; j < n_arr; ++j) { - std::string val = gguf_get_arr_str(fl->gguf_ctx, i, j); + std::string val = gguf_get_arr_str(ctx, i, j); data[j] = val; } @@ -729,7 +729,7 @@ struct gguf_file_saver { std::vector data(n_arr); for (int j = 0; j < n_arr; ++j) { - float val = gguf_get_arr_f32(fl->gguf_ctx, i, j); + float val = gguf_get_arr_f32(ctx, i, j); data[j] = val; } @@ -738,28 +738,28 @@ struct gguf_file_saver { // re-write the key-value section from the loaded file void write_kv() { - const int32_t n_kv = gguf_get_n_kv(fl->gguf_ctx); + const int32_t n_kv = gguf_get_n_kv(ctx); for (int i = 0; i < n_kv; ++i) { - const char * key = gguf_get_key(fl->gguf_ctx, i); + const char * key = gguf_get_key(ctx, i); if (strcmp(key, "general.quantization_version") == 0) { file.write_val("general.quantization_version", GGUF_TYPE_UINT32, GGML_QNT_VERSION); } else { - const gguf_type vtype = gguf_get_kv_type(fl->gguf_ctx, i); + const gguf_type vtype = gguf_get_kv_type(ctx, i); switch (vtype) { - case GGUF_TYPE_BOOL: file.write_val (key, GGUF_TYPE_BOOL, gguf_get_val_bool(fl->gguf_ctx, i)); break; - case GGUF_TYPE_FLOAT32: file.write_val (key, GGUF_TYPE_FLOAT32, gguf_get_val_f32 (fl->gguf_ctx, i)); break; - case GGUF_TYPE_INT16: file.write_val (key, GGUF_TYPE_INT16, gguf_get_val_i16 (fl->gguf_ctx, i)); break; - case GGUF_TYPE_INT32: file.write_val (key, GGUF_TYPE_INT32, gguf_get_val_i32 (fl->gguf_ctx, i)); break; - case GGUF_TYPE_INT8: file.write_val (key, GGUF_TYPE_INT8, gguf_get_val_i8 (fl->gguf_ctx, i)); break; - case GGUF_TYPE_STRING: file.write_str (key, GGUF_TYPE_STRING, gguf_get_val_str (fl->gguf_ctx, i)); break; - case GGUF_TYPE_UINT16: file.write_val(key, GGUF_TYPE_UINT16, gguf_get_val_u16 (fl->gguf_ctx, i)); break; - case GGUF_TYPE_UINT32: file.write_val(key, GGUF_TYPE_UINT32, gguf_get_val_u32 (fl->gguf_ctx, i)); break; - case GGUF_TYPE_UINT8: file.write_val (key, GGUF_TYPE_UINT8, gguf_get_val_u8 (fl->gguf_ctx, i)); break; + case GGUF_TYPE_BOOL: file.write_val (key, GGUF_TYPE_BOOL, gguf_get_val_bool(ctx, i)); break; + case GGUF_TYPE_FLOAT32: file.write_val (key, GGUF_TYPE_FLOAT32, gguf_get_val_f32 (ctx, i)); break; + case GGUF_TYPE_INT16: file.write_val (key, GGUF_TYPE_INT16, gguf_get_val_i16 (ctx, i)); break; + case GGUF_TYPE_INT32: file.write_val (key, GGUF_TYPE_INT32, gguf_get_val_i32 (ctx, i)); break; + case GGUF_TYPE_INT8: file.write_val (key, GGUF_TYPE_INT8, gguf_get_val_i8 (ctx, i)); break; + case GGUF_TYPE_STRING: file.write_str (key, GGUF_TYPE_STRING, gguf_get_val_str (ctx, i)); break; + case GGUF_TYPE_UINT16: file.write_val(key, GGUF_TYPE_UINT16, gguf_get_val_u16 (ctx, i)); break; + case GGUF_TYPE_UINT32: file.write_val(key, GGUF_TYPE_UINT32, gguf_get_val_u32 (ctx, i)); break; + case GGUF_TYPE_UINT8: file.write_val (key, GGUF_TYPE_UINT8, gguf_get_val_u8 (ctx, i)); break; case GGUF_TYPE_ARRAY: { - const gguf_type arr_type = gguf_get_arr_type(fl->gguf_ctx, i); - const int n_arr = gguf_get_arr_n (fl->gguf_ctx, i); + const gguf_type arr_type = gguf_get_arr_type(ctx, i); + const int n_arr = gguf_get_arr_n (ctx, i); if (arr_type == GGUF_TYPE_FLOAT32) { write_kv_arr_f32(key, arr_type, i, n_arr); } else if (arr_type == GGUF_TYPE_STRING) { @@ -776,9 +776,9 @@ struct gguf_file_saver { info_offset = file.tell(); - GGML_ASSERT(gguf_get_data_offset(fl->gguf_ctx) >= info_offset); + GGML_ASSERT(gguf_get_data_offset(ctx) >= info_offset); - size_t count = gguf_get_data_offset(fl->gguf_ctx) - info_offset; + size_t count = gguf_get_data_offset(ctx) - info_offset; file.write_zeros(count); file.seek(info_offset, SEEK_SET); GGML_ASSERT(info_offset == file.tell()); @@ -3219,7 +3219,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } std::unique_ptr model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false)); - gguf_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get()); + gguf_file_saver file_saver(fname_out.c_str(), model_loader->file_loader->gguf_ctx); #ifdef GGML_USE_K_QUANTS int n_attention_wv = 0;