Skip to content

Commit

Permalink
metal : fix ggml_get_rows to work with non-cont src1
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 10, 2023
1 parent 0710b0f commit 016f9bb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
9 changes: 5 additions & 4 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1584,11 +1584,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];

const int64_t n = ggml_nelements(src1);

[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
Expand Down
75 changes: 49 additions & 26 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3219,69 +3219,89 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
device const int * src1,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];
uint3 tptg [[threads_per_threadgroup]]) {
//const int64_t i = tgpig;
//const int64_t r = ((device int32_t *) src1)[i];

const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;

for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];

const int64_t i02 = i11;

for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp;
dequantize_func(
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
}
}

kernel void kernel_get_rows_f32(
device const void * src0,
device const int * src1,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;

const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];

const int64_t i02 = i/ne10;
const int64_t i02 = i11;

for (int ind = tiitg; ind < ne00; ind += tptg) {
((device float *) ((device char *) dst + i*nb1))[ind] =
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}

kernel void kernel_get_rows_f16(
device const void * src0,
device const int * src1,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
constant uint64_t & nb2,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];
uint3 tptg [[threads_per_threadgroup]]) {
const int64_t i10 = tgpig.x;
const int64_t i11 = tgpig.y;

const int64_t i02 = i/ne10;
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];

for (int ind = tiitg; ind < ne00; ind += tptg) {
((device float *) ((device char *) dst + i*nb1))[ind] =
const int64_t i02 = i11;

for (int ind = tiitg; ind < ne00; ind += tptg.x) {
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
Expand Down Expand Up @@ -3543,14 +3563,17 @@ kernel void kernel_mul_mm_id(

typedef void (get_rows_t)(
device const void * src0,
device const int * src1,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb1,
uint, uint, uint);
constant uint64_t & nb2,
uint3, uint, uint3);

//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
Expand Down

0 comments on commit 016f9bb

Please sign in to comment.