Skip to content

Commit

Permalink
metal : add indirect mat-vec kernels for all quantization types
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 10, 2023
1 parent 016f9bb commit 6cfb31f
Show file tree
Hide file tree
Showing 2 changed files with 1,255 additions and 82 deletions.
210 changes: 199 additions & 11 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
Expand Down Expand Up @@ -354,6 +369,21 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
Expand Down Expand Up @@ -454,6 +484,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
Expand Down Expand Up @@ -1491,17 +1536,22 @@ void ggml_metal_graph_compute(

// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
int ne11_mm_min = 0;
int ne11_mm_min = 1;

const int idx = ((int32_t *) dst->op_params)[0];

// batch size
GGML_ASSERT(ne01 == ne11);

const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory

// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne11 > ne11_mm_min) {
// !!!
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
// indirect matrix multiplication
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
Expand All @@ -1517,7 +1567,6 @@ void ggml_metal_graph_compute(
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
Expand Down Expand Up @@ -1549,14 +1598,153 @@ void ggml_metal_graph_compute(

[encoder setThreadgroupMemoryLength:8192 atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake( (1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
//[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
//for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) {
// [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0];
// [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1];
// [encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2];
// TODO: processing one row at a time (ne11 -> 1) is not efficient
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
int nrows = 1;
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

// use custom matrix x vector kernel
switch (src2t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
nrows = 4;
} break;
case GGML_TYPE_F16:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
nth0 = 32;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
} break;
case GGML_TYPE_Q4_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
} break;
case GGML_TYPE_Q4_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
} break;
case GGML_TYPE_Q5_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
} break;
case GGML_TYPE_Q5_1:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
} break;
case GGML_TYPE_Q8_0:
{
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
} break;
case GGML_TYPE_Q2_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
nth0 = 4; //1;
nth1 = 8; //32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
GGML_ASSERT(false && "not implemented");
}
};

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
// TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j];

size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);

[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
}

//}
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src2t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
const int64_t ny = (_ne1 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
case GGML_OP_GET_ROWS:
Expand Down
Loading

0 comments on commit 6cfb31f

Please sign in to comment.