Skip to content

Commit

Permalink
metal: use template to reduce size
Browse files Browse the repository at this point in the history
The template mul_vec_q_n_f32 has codes that aim to maxmize the q4_0 and
q4_1 throughput, but it shouldn't affect future q5_0 and q5_1
implementations.
  • Loading branch information
lshzh-ww committed Jul 17, 2023
1 parent 4088df1 commit 8ba11ac
Showing 1 changed file with 50 additions and 123 deletions.
173 changes: 50 additions & 123 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -384,27 +384,44 @@ kernel void kernel_rms_norm(
}
}

// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(block_q4_0 qb_curr, float sumy, thread float * yl) {
float d = qb_curr.d;
float acc = sumy * -8.f;
for (int i = 0; i < 16; i+=2) {
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
}
return d * acc;
}

// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(block_q4_1 qb_curr, float sumy, thread float * yl) {
float d = qb_curr.d;
float m = qb_curr.m;
float acc = 0.f;
for (int i = 0; i < 16; i+=2) {
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
}
return d * acc + m * sumy;
}

// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
kernel void kernel_mul_mat_q4_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {

template<typename block_q_type>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
uint2 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const float * y = (device const float *) src1 + r1*ne10;
block_q4_0 qb_curr, qb_next;
block_q_type qb_curr, qb_next;
float4 y_curr[8]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum;
thread float * yl=(thread float *)y_curr;
Expand All @@ -419,25 +436,15 @@ kernel void kernel_mul_mat_q4_0_f32(
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
// this design is q4_0 and q4_1 centered, but I think most of the people use these two quantizations.
for (int i = 0; i < 32; i++) {
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
}

for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];

// calculate
float d = qb_curr.d;
float acc = sumy;
for (int i = 0; i < 16; i+=2) {
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
}
sumf[row] += d * acc;
qb_curr = qb_next;
sumf[row] += block_q_n_dot_y(qb_curr, sumy, yl);
qb_curr = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
}
}

Expand All @@ -449,32 +456,20 @@ kernel void kernel_mul_mat_q4_0_f32(
}
}
} else {

float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
for (int i = 0; i < 32; i++) {
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
}

for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];

// calculate
float d = qb_curr.d;
float acc = sumy;
for (int i = 0; i < 16; i+=2) {
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * acc;
sumf[row] += block_q_n_dot_y(qb_curr, sumy, yl);
}
qb_curr = qb_next;
qb_curr = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];

all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
Expand All @@ -484,7 +479,7 @@ kernel void kernel_mul_mat_q4_0_f32(
}
}

kernel void kernel_mul_mat_q4_1_f32(
kernel void kernel_mul_mat_q4_0_f32(
device const void * src0,
device const float * src1,
device float * dst,
Expand All @@ -495,89 +490,21 @@ kernel void kernel_mul_mat_q4_1_f32(
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const float * y = (device const float *) src1 + r1*ne10;
block_q4_1 qb_curr, qb_next;
float4 y_curr[8]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum;
thread float * yl=(thread float *)y_curr;

// bootstrap
qb_curr = x[tiisg];
// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {

float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
for (int i = 0; i < 32; i++) {
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
}

for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];

// calculate
const float d = qb_curr.d;
const float m = qb_curr.m;
float acc = 0.f;
for (int i = 0; i < 16; i+=2) {
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
}
sumf[row] += d * acc + m * sumy;
qb_curr = qb_next;
}
}

if (nb % N_SIMDWIDTH == 0) {
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
} else {

float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
for (int i = 0; i < 32; i++) {
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
}

for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];

// calculate
const float d = qb_curr.d;
const float m = qb_curr.m;
float acc = 0.f;
for (int i = 0; i < 16; i+=2) {
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * acc + m * sumy;
}
qb_curr = qb_next;
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}

all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
}
kernel void kernel_mul_mat_q4_1_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_f16_f32(
Expand Down

0 comments on commit 8ba11ac

Please sign in to comment.