From e0d0a0f6af5abca012a126f84950849d09d4ccbb Mon Sep 17 00:00:00 2001
From: JohannesGaessler
Date: Fri, 15 Sep 2023 23:57:46 +0200
Subject: [PATCH] fix q8_0 for model with n_embd_head % 32 != 0

---
 ggml-cuda.cu | 52 +++++++++++++++++++++----------
 ggml.c       | 47 +++++++++++++++++++++++-----
 ggml.h       |  6 ++++
 llama.cpp    | 87 +++++++++++++++++++++++++++-------------------------
 4 files changed, 126 insertions(+), 66 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 6f83785c01fb7..74c4e2c6863d2 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4044,7 +4044,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-template <bool first_incomplete, bool last_incomplete>
+template <bool first_incomplete, bool last_incomplete, bool save_unquantized>
 static __global__ void cpy_f32_q8_0(
     const char * cx, char * cdst, const int i_blck_0, const int ne00, const int ne01, const int ne02,
     const int nb00, const int nb01, const int nb02, const int nb11, const int nb12) {
@@ -4075,7 +4075,7 @@ static __global__ void cpy_f32_q8_0(
         val = *((float *) src);
     }
 
-    if (last_incomplete && i0 / QK8_0 == (i_blck_0 + ne00) / QK8_0) {
+    if (save_unquantized && last_incomplete && i0 / QK8_0 == (i_blck_0 + ne00) / QK8_0) {
         memcpy(&dst[1 + iqs/8].qs[sizeof(float) * (iqs % 8)], src, sizeof(float));
     }
 
@@ -5114,7 +5114,7 @@ static void ggml_cpy_f32_f16_cuda(
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int i_blck_0, const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02, const int nb11, const int nb12, cudaStream_t stream) {
+    const int nb00, const int nb01, const int nb02, const int nb11, const int nb12, const bool pad, cudaStream_t stream) {
 
     const int num_blocks_x = (i_blck_0 + ne00 + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(num_blocks_x, ne01, ne02);
@@ -5125,17 +5125,27 @@ static void ggml_cpy_f32_q8_0_cuda(
 
     if (first_incomplete && last_incomplete) {
         GGML_ASSERT(i_blck_0 + ne00 < QK8_0); // otherwise there would be a race condition
-        cpy_f32_q8_0<true, true><<<block_nums, WARP_SIZE, 0, stream>>>
+        GGML_ASSERT(pad == false);
+        cpy_f32_q8_0<true, true, true><<<block_nums, WARP_SIZE, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
     } else if (first_incomplete && !last_incomplete) {
-        cpy_f32_q8_0<true, false><<<block_nums, WARP_SIZE, 0, stream>>>
+        GGML_ASSERT(pad == false);
+        cpy_f32_q8_0<true, false, true><<<block_nums, WARP_SIZE, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
-    } else if (!first_incomplete && last_incomplete) {
-        cpy_f32_q8_0<false, true><<<block_nums, WARP_SIZE, 0, stream>>>
+    } else if (!first_incomplete && last_incomplete && pad) {
+        cpy_f32_q8_0<false, true, false><<<block_nums, WARP_SIZE, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
-    } else if (!first_incomplete && !last_incomplete) {
-        cpy_f32_q8_0<false, false><<<block_nums, WARP_SIZE, 0, stream>>>
+    } else if (!first_incomplete && last_incomplete && !pad) {
+        cpy_f32_q8_0<false, true, true><<<block_nums, WARP_SIZE, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
+    } else if (!first_incomplete && !last_incomplete && pad) {
+        cpy_f32_q8_0<false, false, false><<<block_nums, WARP_SIZE, 0, stream>>>
+            (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
+    } else if (!first_incomplete && !last_incomplete && !pad) {
+        cpy_f32_q8_0<false, false, true><<<block_nums, WARP_SIZE, 0, stream>>>
+            (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
+    } else {
+        GGML_ASSERT(false);
     }
 }
 
@@ -6626,9 +6636,6 @@ void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_te
 }
 
 void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const int64_t ne = ggml_nelements(src0);
-    GGML_ASSERT(ne == ggml_nelements(src1));
-
     GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
 
@@ -6652,6 +6659,16 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
 
+    const int64_t blck_size = ggml_blck_size(src1->type);
+    const int64_t ne00_padded = ((ne00 + blck_size - 1) / blck_size) * blck_size;
+    const int64_t ne = ggml_nelements(src0);
+    const bool pad = dst->op_params[0] & 1;
+    if (pad) {
+        GGML_ASSERT(ne00_padded * ggml_nrows(src0) == ggml_nelements(src1));
+    } else {
+        GGML_ASSERT(ne == ggml_nelements(src1));
+    }
+
     CUDA_CHECK(cudaSetDevice(g_main_device));
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
@@ -6670,16 +6687,19 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         GGML_ASSERT(nb10 == sizeof(block_q8_0));
 
-        const size_t * op_params = (const size_t *) src1->op_params;
-        const size_t i_blck_0 = op_params[1];
+        size_t i_blck_0 = 0;
+        if (src1->op == GGML_OP_VIEW) {
+            const size_t * op_params = (const size_t *) src1->op_params;
+            i_blck_0 = op_params[1];
+        }
 
         if (ggml_is_contiguous(src1)) {
             ggml_cpy_f32_q8_0_cuda(
                 src0_ddc, src1_ddc, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02,
-                ne00*sizeof(block_q8_0)/QK8_0, ne00*ne01*sizeof(block_q8_0)/QK8_0, cudaStream_main);
+                ne00_padded*sizeof(block_q8_0)/QK8_0, ne00_padded*ne01*sizeof(block_q8_0)/QK8_0, pad, cudaStream_main);
         } else {
             ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, i_blck_0, ne00, ne01, ne02,
-                nb00, nb01, nb02, nb11, nb12, cudaStream_main);
+                nb00, nb01, nb02, nb11, nb12, pad, cudaStream_main);
         }
 
     } else {
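The contiguous branch of ggml_cuda_cpy above now derives the Q8_0 row strides from ne00_padded rather than ne00, so a source row whose length is not a multiple of QK8_0 still occupies a whole number of destination blocks. A small standalone sketch of that stride arithmetic, assuming ggml's QK8_0 == 32 and the 34-byte block_q8_0 layout; round_up_to_blck() is an illustrative helper, not something the patch adds:

#include <inttypes.h>
#include <stdio.h>

#define QK8_0           32  // elements per Q8_0 block, as in ggml
#define BLOCK_Q8_0_SIZE 34  // sizeof(block_q8_0): fp16 scale + 32 int8 quants

// illustrative helper, not part of the patch
static int64_t round_up_to_blck(int64_t n, int64_t blck_size) {
    return ((n + blck_size - 1) / blck_size) * blck_size;
}

int main(void) {
    const int64_t ne00        = 100;                               // row length with ne00 % QK8_0 != 0
    const int64_t ne00_padded = round_up_to_blck(ne00, QK8_0);     // 128
    const int64_t nb11        = ne00_padded*BLOCK_Q8_0_SIZE/QK8_0; // 136 bytes per quantized row
    printf("ne00_padded = %" PRId64 ", row stride = %" PRId64 " bytes\n", ne00_padded, nb11);
    return 0;
}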
diff --git a/ggml.c b/ggml.c
index 8a56a1bdb9c2d..6d39fb20b6ded 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6312,8 +6312,15 @@ static struct ggml_tensor * ggml_cpy_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
-        bool inplace) {
-    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+        const bool inplace,
+        const bool pad) {
+    if (pad) {
+        const int64_t blck_size = ggml_blck_size(b->type);
+        const int64_t ne00_padded = ((a->ne[0] + blck_size - 1) / blck_size) * blck_size;
+        GGML_ASSERT(ne00_padded*ggml_nrows(a) == ggml_nelements(b));
+    } else {
+        GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+    }
 
     bool is_node = false;
 
@@ -6329,6 +6336,8 @@ static struct ggml_tensor * ggml_cpy_impl(
         ggml_format_name(result, "%s (copy)", a->name);
     }
 
+    ggml_set_op_params_i32(result, 0, pad ? 1 : 0);
+
     result->op   = GGML_OP_CPY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
@@ -6341,14 +6350,21 @@ struct ggml_tensor * ggml_cpy(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    return ggml_cpy_impl(ctx, a, b, false);
+    return ggml_cpy_impl(ctx, a, b, false, false);
 }
 
 struct ggml_tensor * ggml_cpy_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    return ggml_cpy_impl(ctx, a, b, true);
+    return ggml_cpy_impl(ctx, a, b, true, false);
+}
+
+struct ggml_tensor * ggml_cpy_pad(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_cpy_impl(ctx, a, b, false, true);
 }
 
 // ggml_cont
@@ -8233,6 +8249,8 @@ static void ggml_compute_forward_dup_f16(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
+    GGML_ASSERT(dst->op_params[0] == 0);
+
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
@@ -8496,14 +8514,21 @@ static void ggml_compute_forward_dup_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
+    const bool pad = dst->op_params[0] & 1;
+    const int blck_size = ggml_blck_size(dst->type);
+    const int ne00_padded = ((ne00 + blck_size - 1) / blck_size) * blck_size;
+    if (pad) {
+        GGML_ASSERT(ggml_nelements(dst) == ne00_padded*ggml_nrows(src0));
+    } else {
+        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    }
+
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
@@ -8561,15 +8586,20 @@ static void ggml_compute_forward_dup_f32(
                 ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
 
                 size_t id = 0;
-                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                const size_t rs = nb0 * ne00_padded / blck_size;
                 char * dst_ptr = (char *) dst->data;
+                float src0_padded[ne00_padded];
 
                 for (int i03 = 0; i03 < ne03; i03++) {
                     for (int i02 = 0; i02 < ne02; i02++) {
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            if (ne00 != ne00_padded) {
+                                memcpy(src0_padded, src0_ptr, ne00*sizeof(float));
+                                memset(src0_padded + ne00, 0, (ne00_padded - ne00) * sizeof(float));
+                            }
+                            quantize_row_q(ne00 == ne00_padded ? src0_ptr : src0_padded, dst_ptr + id, ne00_padded);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);
@@ -8737,6 +8767,7 @@ static void ggml_compute_forward_dup_f32(
             }
         }
     } else if (type_traits[dst->type].from_float) {
+        GGML_ASSERT(!pad);
         GGML_ASSERT(ne00 == ne0);
         GGML_ASSERT(ne01 == ne1);
         GGML_ASSERT(ne02 == ne2);
diff --git a/ggml.h b/ggml.h
index f731915b0e3f3..6bbd73cac3f60 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1053,6 +1053,12 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // a -> b, pad row size of a to a multiple of block size of b, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // make contiguous
     GGML_API struct ggml_tensor * ggml_cont(
             struct ggml_context * ctx,
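ggml_cpy_pad() declared above behaves like ggml_cpy() except for the size check: the destination must provide ne00 rounded up to its block size for every source row, and the dup path then zero-pads each row before quantizing it. A hypothetical usage sketch, not part of the patch; the sizes are made up, and the graph helpers (ggml_build_forward, ggml_graph_compute_with_ctx) are assumed to be the ones already present in this tree:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a 100-element f32 row (100 % 32 != 0) quantized into 128 Q8_0 elements (4 blocks)
    struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 100);
    struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, 128);

    for (int i = 0; i < 100; i++) {
        ((float *) src->data)[i] = 0.01f*i;
    }

    // pad the row of src with zeros up to the block size of dst, then quantize into dst
    struct ggml_tensor * cpy = ggml_cpy_pad(ctx, src, dst);

    struct ggml_cgraph gf = ggml_build_forward(cpy);
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads =*/ 1);

    printf("copied %d floats into %d padded Q8_0 elements\n", 100, 128);

    ggml_free(ctx);
    return 0;
}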
diff --git a/llama.cpp b/llama.cpp
index e611ad0aca55e..fd6ec4e221690 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -925,12 +925,16 @@ struct llama_hparams {
     }
 
     size_t kv_size_k(ggml_type type) const {
+        const int64_t blck_size = ggml_blck_size(type);
+        const int64_t n_embd_head_padded = ((n_embd_head() + blck_size - 1) / blck_size) * blck_size;
+
         size_t result = 1ull;
-        result *= (size_t) n_embd_gqa();
+        result *= (size_t) n_embd_head_padded;
+        result *= (size_t) n_head_kv;
         result *= (size_t) n_ctx;
         result *= (size_t) n_layer;
         result *= ggml_type_size(type);
-        result /= ggml_blck_size(type);
+        result /= blck_size;
 
         return result;
     }
@@ -1167,8 +1171,11 @@ static bool llama_kv_cache_init(
                          ggml_type   wtype,
                                int   n_ctx,
                                int   n_gpu_layers) {
-    const int n_embd  = hparams.n_embd_gqa();
-    const int n_layer = hparams.n_layer;
+    const int blck_size          = ggml_blck_size(wtype);
+    const int n_embd_head        = hparams.n_embd_head();
+    const int n_embd_head_padded = ((n_embd_head + blck_size - 1) / blck_size) * blck_size;
+    const int n_head_kv          = hparams.n_head_kv;
+    const int n_layer            = hparams.n_layer;
 
     if (n_ctx % ggml_blck_size(wtype) != 0) {
         LLAMA_LOG_ERROR("error: for KV type %s n_ctx must be a multiple of %d but received n_ctx=%d\n",
@@ -1176,20 +1183,13 @@ static bool llama_kv_cache_init(
         return false;
     }
 
-    if (n_embd % ggml_blck_size(wtype) != 0) {
-        LLAMA_LOG_ERROR("error: for KV type %s n_ctx must be a multiple of %d but received n_embd=%d\n",
-                        ggml_type_name(wtype), ggml_blck_size(wtype), n_embd);
-        return false;
-    }
-
-    const int64_t n_mem      = n_layer*n_ctx;
-    const int64_t n_elements = n_embd*n_mem;
-
     // if the KV cache is quantized we need a little extra space for each row to store the
     // unquantized values between evals (this avoids precision loss when rebuilding the block)
-    const int64_t v_quant_buffer = wtype == GGML_TYPE_Q8_0 ? 128*n_layer*n_embd : 0;
+    const int64_t n_mem        = n_layer*n_ctx;
+    const int64_t n_elements_k = n_embd_head_padded * n_head_kv * n_mem;
+    const int64_t n_elements_v = n_embd_head * n_head_kv * (n_mem + (wtype == GGML_TYPE_Q8_0 ? 128*n_layer : 0));
 
-    cache.buf.resize((2u*n_elements + v_quant_buffer)*ggml_type_size(wtype)/ggml_blck_size(wtype) + 2u*MB);
+    cache.buf.resize((n_elements_k + n_elements_v)*ggml_type_size(wtype)/ggml_blck_size(wtype) + 2u*MB);
     cache.n = 0;
 
     struct ggml_init_params params;
@@ -1204,8 +1204,8 @@ static bool llama_kv_cache_init(
         return false;
     }
 
-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements + v_quant_buffer);
+    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements_k);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements_v);
     ggml_set_name(cache.k, "cache_k");
     ggml_set_name(cache.v, "cache_v");
 
@@ -2258,13 +2258,17 @@ static struct ggml_cgraph * llm_build_llama(
 
     GGML_ASSERT(!!kv_self.ctx);
 
-    const int64_t n_embd      = hparams.n_embd;
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
-    const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int64_t blck_size_k = ggml_blck_size(kv_self.k->type);
+    const int64_t blck_size_v = ggml_blck_size(kv_self.v->type);
+
+    const int64_t n_embd             = hparams.n_embd;
+    const int64_t n_layer            = hparams.n_layer;
+    const int64_t n_ctx              = hparams.n_ctx;
+    const int64_t n_head             = hparams.n_head;
+    const int64_t n_head_kv          = hparams.n_head_kv;
+    const int64_t n_embd_head        = hparams.n_embd_head();
+    const int64_t n_embd_head_padded = ((n_embd_head + blck_size_k - 1) / blck_size_k) * blck_size_k;
+    const int64_t n_embd_gqa         = hparams.n_embd_gqa();
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
@@ -2402,22 +2406,22 @@ static struct ggml_cgraph * llm_build_llama(
                 ggml_set_name(Vcur, "Vcur");
 
                 struct ggml_tensor * k = ggml_view_1d(
-                    ctx0, kv_self.k, N*n_embd_gqa,
-                    (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)/ggml_blck_size(kv_self.k->type));
+                    ctx0, kv_self.k, N*n_embd_head_padded*n_head_kv,
+                    (ggml_element_size(kv_self.k)*n_embd_head_padded*n_head_kv)*(il*n_ctx + n_past)/blck_size_k);
                 offload_func_kq(k);
                 ggml_set_name(k, "k");
 
                 const int64_t v_row_size = kv_self.v->type == GGML_TYPE_Q8_0 ? n_ctx + 128 : n_ctx;
                 struct ggml_tensor * v = ggml_view_blck_2d(ctx0, kv_self.v, N, n_embd_gqa,
-                        (   v_row_size)*ggml_element_size(kv_self.v)/ggml_blck_size(kv_self.v->type),
-                        (il*v_row_size)*ggml_element_size(kv_self.v)*n_embd_gqa/ggml_blck_size(kv_self.v->type) + ggml_element_size(kv_self.v)*(n_past/ggml_blck_size(kv_self.v->type)),
-                        n_past % ggml_blck_size(kv_self.v->type));
+                        (   v_row_size)*ggml_element_size(kv_self.v)/blck_size_v,
+                        (il*v_row_size)*ggml_element_size(kv_self.v)*n_embd_gqa/blck_size_v + ggml_element_size(kv_self.v)*(n_past/blck_size_v),
+                        n_past % blck_size_v);
                 offload_func_v(v);
                 ggml_set_name(v, "v");
 
                 // important: storing RoPE-ed version of K in the KV cache!
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                ggml_build_forward_expand(gf, ggml_cpy_pad(ctx0, Kcur, k));
+                ggml_build_forward_expand(gf, ggml_cpy    (ctx0, Vcur, v));
             }
 
             struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@@ -2426,10 +2430,10 @@
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_past + N, n_head_kv,
-                        ggml_element_size(kv_self.k)*n_embd_gqa/ggml_blck_size(kv_self.k->type),
-                        ggml_element_size(kv_self.k)*n_embd_head/ggml_blck_size(kv_self.k->type),
-                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il/ggml_blck_size(kv_self.k->type));
+                        n_embd_head_padded, n_past + N, n_head_kv,
+                        ggml_element_size(kv_self.k)*n_embd_head_padded*n_head_kv/blck_size_k,
+                        ggml_element_size(kv_self.k)*n_embd_head_padded/blck_size_k,
+                        ggml_element_size(kv_self.k)*n_embd_head_padded*n_head_kv*n_ctx*il/blck_size_k);
             offload_func_kq(K);
             ggml_set_name(K, "K");
@@ -2456,15 +2460,14 @@
 
             // split cached V into n_head heads
-            int64_t v_nelements_padded = n_past + N + ggml_blck_size(kv_self.v->type) - 1;
-            v_nelements_padded -= v_nelements_padded % ggml_blck_size(kv_self.v->type);
-            const int64_t v_row_size = kv_self.v->type == GGML_TYPE_Q8_0 ? n_ctx + 128 : n_ctx;
+            const int64_t v_ne0_padded = ((n_past + N + blck_size_v - 1) / blck_size_v) * blck_size_v; // ne0 padded to multiple of blck_size_v
+            const int64_t v_row_size = kv_self.v->type == GGML_TYPE_Q8_0 ? n_ctx + 128 : n_ctx; // maximum ne0 + space for temporarily storing unquantized values
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        v_nelements_padded, n_embd_head, n_head_kv,
-                        ggml_element_size(kv_self.v)*v_row_size/ggml_blck_size(kv_self.v->type),
-                        ggml_element_size(kv_self.v)*v_row_size*n_embd_head/ggml_blck_size(kv_self.v->type),
-                        ggml_element_size(kv_self.v)*v_row_size*n_embd_gqa*il/ggml_blck_size(kv_self.v->type));
+                        v_ne0_padded, n_embd_head, n_head_kv,
+                        ggml_element_size(kv_self.v)*v_row_size/blck_size_v,
+                        ggml_element_size(kv_self.v)*v_row_size*n_embd_head/blck_size_v,
+                        ggml_element_size(kv_self.v)*v_row_size*n_embd_gqa*il/blck_size_v);
             offload_func_v(V);
             ggml_set_name(V, "V");
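Taken together, the K cache is now sized from the per-head row padded to the block size, while the V cache keeps its unpadded rows plus the existing 128 extra elements per row that Q8_0 uses to store unquantized values between evals. A self-contained sketch of the resulting size computation, mirroring kv_size_k() and the n_elements_k / n_elements_v expressions above (the hyperparameter values are illustrative, not read from any particular model, and the 2*MB slack from llama_kv_cache_init is omitted):

#include <inttypes.h>
#include <stdio.h>

int main(void) {
    const int64_t blck_size   = 32;   // QK8_0
    const int64_t type_size   = 34;   // sizeof(block_q8_0): fp16 scale + 32 int8 quants
    const int64_t n_embd_head = 100;  // head size with n_embd_head % 32 != 0
    const int64_t n_head_kv   = 32;
    const int64_t n_ctx       = 2048;
    const int64_t n_layer     = 26;

    // K rows are padded to a whole number of blocks before quantization
    const int64_t n_embd_head_padded = ((n_embd_head + blck_size - 1) / blck_size) * blck_size; // 128

    const int64_t n_mem        = n_layer*n_ctx;
    const int64_t n_elements_k = n_embd_head_padded * n_head_kv * n_mem;

    // V rows run along the context dimension; for Q8_0 each row gets 128 spare elements
    // to keep the unquantized values of a partially filled block between evals
    const int64_t n_elements_v = n_embd_head * n_head_kv * (n_mem + 128*n_layer);

    const int64_t bytes = (n_elements_k + n_elements_v)*type_size/blck_size;
    printf("KV cache: %" PRId64 " MiB\n", bytes/(1024*1024));
    return 0;
}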