Commit 84c38ea
Fixed k-quant kernels
JohannesGaessler committed Jul 14, 2023
1 parent 7fa2d80 commit 84c38ea
Showing 1 changed file with 12 additions and 6 deletions.
ggml-cuda.cu: 18 changes (12 additions & 6 deletions)

@@ -1303,6 +1303,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
 
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
     int vi;
@@ -1313,7 +1314,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, __half2float(bq4_0->d), __half2float(bq8_1->d));
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
@@ -1340,6 +1343,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
@@ -1376,6 +1380,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
@@ -1411,6 +1416,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
@@ -1430,7 +1436,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q2_K * bq2_K = (const block_q2_K *) vbq;
@@ -1466,7 +1472,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q3_K * bq3_K = (const block_q3_K *) vbq;
@@ -1519,7 +1525,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
@@ -1557,7 +1563,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;
@@ -1601,7 +1607,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q6_K * bq6_K = (const block_q6_K *) vbq;
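
The substantive change: the k-quant vec_dot_*_q8_1 functions previously declared their last parameter as const int iqs, whereas the other vec_dot_*_q8_1 functions take const int & iqs; this commit brings the k-quant signatures in line with the rest (and adds a blank line after each opening brace for consistent formatting). A plausible reason the exact parameter type matters is that these routines are selected through one shared signature, such as a function-pointer typedef or a template parameter; int and const int & are distinct parameter types, so a by-value k-quant function would not match. The sketch below illustrates only that general point; the typedef name vec_dot_q_t and the two dummy functions are hypothetical stand-ins, not code from ggml-cuda.cu.

// Minimal, self-contained sketch (plain C++): why taking iqs as const int versus
// const int & matters when routines are dispatched through one shared signature.
// vec_dot_q_t, vec_dot_by_ref and vec_dot_by_value are hypothetical, not the
// actual ggml-cuda.cu code.

#include <cstdio>

struct block_q8_1 {};  // placeholder; the real block layout is irrelevant here

// Hypothetical shared signature through which a kernel would pick its dot-product routine.
typedef float (*vec_dot_q_t)(const void * vbq, const block_q8_1 * bq8_1, const int & iqs);

// Takes iqs by const reference, like the non-k-quant functions: matches vec_dot_q_t.
static float vec_dot_by_ref(const void *, const block_q8_1 *, const int & iqs) {
    return static_cast<float>(iqs);
}

// Takes iqs by value, like the k-quant functions before this commit: a different
// function type, so it cannot be assigned to vec_dot_q_t.
static float vec_dot_by_value(const void *, const block_q8_1 *, const int iqs) {
    return static_cast<float>(iqs);
}

int main() {
    const block_q8_1 blk{};
    const int iqs = 3;

    vec_dot_q_t f = vec_dot_by_ref;       // OK: parameter types agree exactly
    // vec_dot_q_t g = vec_dot_by_value;  // compile error: int vs const int & parameter

    std::printf("%f\n", f(nullptr, &blk, iqs));
    return 0;
}

Under that assumption, enabling the commented-out assignment reproduces the type mismatch at compile time; with it disabled, the program prints 3.000000.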
