Skip to content

Commit

Permalink
cuda : wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 5, 2023
1 parent 6b58ae9 commit e8457c9
Showing 1 changed file with 62 additions and 5 deletions.
67 changes: 62 additions & 5 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4582,12 +4582,43 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
}
}

// TODO: generalize for all quants
template <cpy_kernel_t cpy_blck>
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;

float amax = 0.0f;
float max = 0.0f;

for (int j = 0; j < QK4_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}

const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;

y[i].d = d;

for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_0/2 + j]*id;

const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));

dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}

template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*QK8_0;
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;

if (i >= ne) {
return;
Expand All @@ -4600,7 +4631,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,

const int i12 = i / (ne10*ne11);
const int i11 = (i - i12*ne10*ne11) / ne10;
const int i10 = (i - i12*ne10*ne11 - i11*ne10)/QK8_0;
const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;

cpy_blck(cx + x_offset, cdst + dst_offset);
Expand Down Expand Up @@ -5791,7 +5822,29 @@ static void ggml_cpy_f32_q8_0_cuda(

GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0;
cpy_f32_q<cpy_blck_f32_q8_0><<<num_blocks, 1, 0, stream>>>
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}

static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {

GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0;
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}

static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {

GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1;
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}

Expand Down Expand Up @@ -7836,6 +7889,10 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
Expand Down

0 comments on commit e8457c9

Please sign in to comment.