Skip to content

Commit

Permalink
fix half2 decomposition
Browse files: browse the repository at this point in the history
  • Loading branch information
ardfork authored and SlyEcho committed Jul 31, 2023
1 parent c1cb70d commit d91456a
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions ggml-cuda.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -472,8 +472,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;

const dfloat d = x[ib].dm.x;
const dfloat m = x[ib].dm.y;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);

const int vui = x[ib].qs[iqs];

Expand Down Expand Up @@ -515,8 +515,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;

const dfloat d = x[ib].dm.x;
const dfloat m = x[ib].dm.y;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);

uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
Expand Down Expand Up @@ -568,8 +568,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;

float dall = x[i].dm.x;
float dmin = x[i].dm.y;
float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm);
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
Expand All @@ -579,8 +579,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il;
float dall = x[i].dm.x;
float dmin = x[i].dm.y;
float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm);
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif
Expand Down Expand Up @@ -666,8 +666,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float

float * y = yy + i*QK_K + 64*il + n*ir;

const float dall = x[i].dm.x;
const float dmin = x[i].dm.y;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);

const uint8_t * q = x[i].qs + 32*il + n*ir;

Expand Down Expand Up @@ -705,8 +705,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float

float * y = yy + i*QK_K + 64*il + 2*ir;

const float dall = x[i].dm.x;
const float dmin = x[i].dm.y;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);

const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir;
Expand Down Expand Up @@ -818,8 +818,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;

const float dall = x[i].dm.x;
const float dmin = x[i].dm.y;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);

const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f;
Expand Down Expand Up @@ -1039,8 +1039,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;

const float dall = x[i].dm.x;
const float dmin = x[i].dm.y;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);

const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
Expand Down Expand Up @@ -1172,8 +1172,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;

const float dall = x[i].dm.x;
const float dmin = x[i].dm.y;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);

const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
Expand Down

0 comments on commit d91456a

Please sign in to comment.