From 6b58ae98921884fc9a85efde0cdefc6a2f4c73b0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 Dec 2023 16:09:16 +0200 Subject: [PATCH] metal : add F32 -> Q4_1 copy kernel --- ggml-metal.m | 8 +++--- ggml-metal.metal | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 418d3003e0e4a..3023aa6db5bf1 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -120,7 +120,7 @@ GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); GGML_METAL_DECL_KERNEL(cpy_f32_q4_0); - //GGML_METAL_DECL_KERNEL(cpy_f32_q4_1); + GGML_METAL_DECL_KERNEL(cpy_f32_q4_1); //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); GGML_METAL_DECL_KERNEL(cpy_f16_f16); @@ -331,7 +331,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); GGML_METAL_ADD_KERNEL(cpy_f32_q4_0); - //GGML_METAL_ADD_KERNEL(cpy_f32_q4_1); + GGML_METAL_ADD_KERNEL(cpy_f32_q4_1); //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); GGML_METAL_ADD_KERNEL(cpy_f16_f16); @@ -437,7 +437,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f32_q8_0); GGML_METAL_DEL_KERNEL(cpy_f32_q4_0); - //GGML_METAL_DEL_KERNEL(cpy_f32_q4_1); + GGML_METAL_DEL_KERNEL(cpy_f32_q4_1); //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1); GGML_METAL_DEL_KERNEL(cpy_f16_f16); @@ -1578,7 +1578,7 @@ void ggml_metal_graph_compute( case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break; - //case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break; //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break; //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break; default: GGML_ASSERT(false && "not implemented"); diff --git a/ggml-metal.metal b/ggml-metal.metal index 3ca21b8d5652c..9f5ffcbafe8fc 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1586,6 +1586,72 @@ kernel void kernel_cpy_f32_q4_0( } } +kernel void kernel_cpy_f32_q4_1( + device const float * src0, + device void * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + kernel void kernel_concat( device const char * src0, device const char * src1,