Skip to content

Commit

Permalink
q8_1 half2 ds
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Jul 14, 2023
1 parent 84c38ea commit e395994
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
#define QR8_1 1
#define QI8_1 (QK8_1 / (4 * QR8_1))
typedef struct {
half d; // delta
half s; // unquantized sum
half2 ds; // ds.x = delta, ds.y = sum
int8_t qs[QK8_0]; // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
Expand Down Expand Up @@ -1258,8 +1257,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
return;
}

y[ib].d = d;
y[ib].s = sum;
y[ib].ds.x = d;
y[ib].ds.y = sum;
}

template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
Expand All @@ -1284,18 +1283,18 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
}

static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
const int & vi, const int & ui0, const int & ui1, const float & d4, const float & d8) {
const int & vi, const int & ui0, const int & ui1, const half & d4, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
// subtract 8 from each quantized value
const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
const int vi0 = (vi >> 0) & 0x0F0F0F0F;
const int vi1 = (vi >> 4) & 0x0F0F0F0F;

// SIMD dot product of quantized values
int sumi = __dp4a(vi0, ui0, 0);
sumi = __dp4a(vi1, ui1, sumi);

return sumi*d4*d8;
return __half2float(d4) * (sumi * __half2float(ds8.x) - (8/QI4_0) * __half2float(ds8.y));
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
Expand All @@ -1311,7 +1310,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);

return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, __half2float(bq4_0->d), __half2float(bq8_1->d));
return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds);
}

static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
Expand All @@ -1324,9 +1323,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);

const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
const float d = __half2float(bq4_1->d) * __half2float(bq8_1->ds.x);
const float m = bq4_1->m;
const float s = bq8_1->s;
const float s = bq8_1->ds.y;

const int vi0 = (vi >> 0) & 0x0F0F0F0F;
const int vi1 = (vi >> 4) & 0x0F0F0F0F;
Expand Down Expand Up @@ -1354,7 +1353,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);

const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
const float d = __half2float(bq5_0->d) * __half2float(bq8_1->ds.x);

int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
Expand Down Expand Up @@ -1390,9 +1389,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);

const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
const float d = __half2float(bq5_1->d) * __half2float(bq8_1->ds.x);
const float m = bq5_1->m;
const float s = bq8_1->s;
const float s = bq8_1->ds.y;

int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
Expand Down Expand Up @@ -1424,7 +1423,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);

const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
const float d = __half2float(bq8_0->d) * __half2float(bq8_1->ds.x);

// SIMD dot product of quantized values
int sumi = __dp4a(vi, ui, 0);
Expand Down Expand Up @@ -1456,7 +1455,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
const int sc = bq2_K->scales[scale_offset + 2*i];

const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const float d8i = bq8i->d;
const float d8i = bq8i->ds.x;

const int vi = (v >> (2*i)) & 0x03030303;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
Expand Down Expand Up @@ -1507,7 +1506,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(

const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const float d8i = bq8i->ds.x;

const int vil = (vl >> (2*i)) & 0x03030303;

Expand Down Expand Up @@ -1548,7 +1547,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(

const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const float d8i = bq8i->ds.x;

const int vi = (v >> (4*i)) & 0x0F0F0F0F;

Expand Down Expand Up @@ -1588,7 +1587,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(

const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const float d8i = bq8i->ds.x;

const int vil = (vl >> (4*i)) & 0x0F0F0F0F;

Expand Down Expand Up @@ -1631,7 +1630,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(

const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]);
const float d8i = bq8i->d;
const float d8i = bq8i->ds.x;

const int vil = (vl >> (4*i)) & 0x0F0F0F0F;

Expand Down Expand Up @@ -1673,7 +1672,7 @@ static __global__ void mul_mat_q(
__shared__ int tile_x_qs[WARP_SIZE][WARP_SIZE + 1];
__shared__ half tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
__shared__ int tile_y_qs[WARP_SIZE][2*WARP_SIZE];
__shared__ half tile_y_d[WARP_SIZE][2*WARP_SIZE/QI8_1];
__shared__ half2 tile_y_ds[WARP_SIZE][2*WARP_SIZE/QI8_1];
float sum[4] = {0.0f};

for (int ib0 = 0; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
Expand All @@ -1694,12 +1693,12 @@ static __global__ void mul_mat_q(
const block_q8_1 * __restrict__ by0 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby0];

tile_y_qs[tid_y + i][tid_x] = *((int *) &by0->qs[iqsy]);
tile_y_d[tid_y + i][iby0] = by0->d;
tile_y_ds[tid_y + i][iby0] = by0->ds;

const block_q8_1 * __restrict__ by1 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby1];

tile_y_qs[tid_y + i][tid_x + WARP_SIZE] = *((int *) &by1->qs[iqsy]);
tile_y_d[tid_y + i][iby1] = by1->d;
tile_y_ds[tid_y + i][iby1] = by1->ds;
}

__syncthreads();
Expand All @@ -1709,7 +1708,7 @@ static __global__ void mul_mat_q(
for (int j = 0; j < WARP_SIZE; j += 8) {
sum[j/8] += vec_dot_q4_0_q8_1_impl(
tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
tile_x_d[tid_x][k / QI4_0], tile_y_d[tid_y + j][2 * k / QI8_1]);
tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
}
}

Expand Down

0 comments on commit e395994

Please sign in to comment.