Skip to content

Commit

Permalink
metal: fix bugs for GQA and perplexity test.
Browse files Browse the repository at this point in the history
I mixed up ne02 and nb02 in previous commit.
  • Loading branch information
lshzh-ww committed Aug 15, 2023
1 parent bfa455d commit a527ecc
Showing 1 changed file with 39 additions and 30 deletions.
69 changes: 39 additions & 30 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -343,16 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr;
const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16]; // src1 vector cache
float sumf[nr]={0.f};

Expand Down Expand Up @@ -383,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < ne01) {
dst[r1*ne0 + im*ne12 + first_row + row] = tot;
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
}
}
}
Expand All @@ -398,11 +398,12 @@ kernel void kernel_mul_mat_q4_0_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_q4_1_f32(
Expand All @@ -415,11 +416,12 @@ kernel void kernel_mul_mat_q4_1_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_f16_f32(
Expand Down Expand Up @@ -800,6 +802,7 @@ kernel void kernel_mul_mat_q2_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
Expand All @@ -812,9 +815,9 @@ kernel void kernel_mul_mat_q2_K_f32(

const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(ne02/QK_K);
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;

Expand Down Expand Up @@ -927,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
Expand All @@ -943,6 +946,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
Expand All @@ -955,9 +959,9 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r2 = tgpig.z;

const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
const uint offset0 = r2/gqa*(ne02/QK_K);
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

float yl[16];

Expand Down Expand Up @@ -1045,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
const float tot = simd_sum(sumf);
if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
}
}
}
Expand All @@ -1060,6 +1064,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
Expand All @@ -1072,9 +1077,9 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r2 = tgpig.z;

const int row = 2 * r0 + sgitg;
const uint offset0 = r2/gqa*(ne02/QK_K);
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
const int im = il/8; // 0, 0, 1, 1
Expand Down Expand Up @@ -1113,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(

const float tot = simd_sum(sumf);
if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + row] = tot;
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
}

}
Expand All @@ -1130,6 +1135,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
Expand All @@ -1150,9 +1156,9 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(ne02/QK_K);
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16];
float yh[16];
float sumf[N_DST]={0.f}, all_sum;
Expand Down Expand Up @@ -1219,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
Expand All @@ -1234,6 +1240,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
Expand All @@ -1248,9 +1255,9 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(ne02/QK_K);
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[8];
float yh[8];
float sumf[N_DST]={0.f}, all_sum;
Expand Down Expand Up @@ -1306,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0+ r2*ne12 + first_row + row] = all_sum;
dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
Expand All @@ -1322,6 +1329,7 @@ kernel void kernel_mul_mat_q5_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
Expand All @@ -1334,9 +1342,9 @@ kernel void kernel_mul_mat_q5_K_f32(
const int r2 = tgpig.z;

const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
const uint offset0 = r2/gqa*(ne02/QK_K);
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

float sumf[2]={0.f};

Expand Down Expand Up @@ -1470,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
}
}

Expand All @@ -1486,6 +1494,7 @@ kernel void kernel_mul_mat_q6_K_f32(
constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]],
constant int64_t & ne1[[buffer(16)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
Expand All @@ -1503,9 +1512,9 @@ kernel void kernel_mul_mat_q6_K_f32(
const int r2 = tgpig.z;

const int row = 2 * r0 + sgitg;
const uint offset0 = r2/gqa*(ne02/QK_K);
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

float sumf = 0;

Expand Down Expand Up @@ -1571,7 +1580,7 @@ kernel void kernel_mul_mat_q6_K_f32(

const float tot = simd_sum(sumf);
if (tiisg == 0) {
dst[r1*ne0 + r2*ne12 + row] = tot;
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
}
}

Expand Down Expand Up @@ -1835,7 +1844,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;

for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
//load data and store to threadgroup memory
Expand Down Expand Up @@ -1880,7 +1889,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne12;
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
}
Expand All @@ -1893,7 +1902,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
}

threadgroup_barrier(mem_flags::mem_threadgroup);
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne12;
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
if (sgitg==0) {
for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
Expand Down

0 comments on commit a527ecc

Please sign in to comment.