fix q8_0 for model with n_embd_head % 32 != 0
JohannesGaessler committed Sep 16, 2023
1 parent 91523fb commit e0d0a0f
Showing 4 changed files with 126 additions and 66 deletions.
52 changes: 36 additions & 16 deletions ggml-cuda.cu
@@ -4044,7 +4044,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }

-template <bool first_incomplete, bool last_incomplete>
+template <bool first_incomplete, bool last_incomplete, bool save_unquantized>
 static __global__ void cpy_f32_q8_0(
     const char * cx, char * cdst, const int i_blck_0, const int ne00, const int ne01, const int ne02,
     const int nb00, const int nb01, const int nb02, const int nb11, const int nb12) {
@@ -4075,7 +4075,7 @@ static __global__ void cpy_f32_q8_0(
         val = *((float *) src);
     }

-    if (last_incomplete && i0 / QK8_0 == (i_blck_0 + ne00) / QK8_0) {
+    if (save_unquantized && last_incomplete && i0 / QK8_0 == (i_blck_0 + ne00) / QK8_0) {
         memcpy(&dst[1 + iqs/8].qs[sizeof(float) * (iqs % 8)], src, sizeof(float));
     }

@@ -5114,7 +5114,7 @@ static void ggml_cpy_f32_f16_cuda(

 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int i_blck_0, const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02, const int nb11, const int nb12, cudaStream_t stream) {
+    const int nb00, const int nb01, const int nb02, const int nb11, const int nb12, const bool pad, cudaStream_t stream) {

     const int num_blocks_x = (i_blck_0 + ne00 + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(num_blocks_x, ne01, ne02);
@@ -5125,17 +5125,27 @@ static void ggml_cpy_f32_q8_0_cuda(

     if (first_incomplete && last_incomplete) {
         GGML_ASSERT(i_blck_0 + ne00 < QK8_0); // otherwise there would be a race condition
-        cpy_f32_q8_0<true, true><<<block_nums, block_dims, 0, stream>>>
+        GGML_ASSERT(pad == false);
+        cpy_f32_q8_0<true, true, false><<<block_nums, block_dims, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
     } else if (first_incomplete && !last_incomplete) {
-        cpy_f32_q8_0<true, false><<<block_nums, block_dims, 0, stream>>>
+        GGML_ASSERT(pad == false);
+        cpy_f32_q8_0<true, false, false><<<block_nums, block_dims, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
-    } else if (!first_incomplete && last_incomplete) {
-        cpy_f32_q8_0<false, true><<<block_nums, block_dims, 0, stream>>>
+    } else if (!first_incomplete && last_incomplete && pad) {
+        cpy_f32_q8_0<false, true, false><<<block_nums, block_dims, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
-    } else if (!first_incomplete && !last_incomplete) {
-        cpy_f32_q8_0<false, false><<<block_nums, block_dims, 0, stream>>>
+    } else if (!first_incomplete && last_incomplete && !pad) {
+        cpy_f32_q8_0<false, true, true><<<block_nums, block_dims, 0, stream>>>
             (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
+    } else if (!first_incomplete && !last_incomplete && pad) {
+        cpy_f32_q8_0<false, false, true><<<block_nums, block_dims, 0, stream>>>
+            (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
+    } else if (!first_incomplete && !last_incomplete && !pad) {
+        cpy_f32_q8_0<false, false, true><<<block_nums, block_dims, 0, stream>>>
+            (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12);
     } else {
         GGML_ASSERT(false);
     }
 }
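The two booleans specialized on above describe whether the copied row starts or ends in the middle of a Q8_0 block. The new save_unquantized template flag makes the kernel additionally keep the raw float values of a trailing partial block, so a later copy into the same block can re-quantize it from the originals; the dispatch above passes false whenever pad is set, since a padded tail is known to be zero. A rough illustration of where those raw floats end up, as a hypothetical C helper (not the kernel itself; block_q8_0 is a stand-in for ggml's layout):

    #include <stdint.h>
    #include <string.h>

    #define QK8_0 32

    typedef struct {
        uint16_t d;          // fp16 scale bits (stand-in for ggml_fp16_t)
        int8_t   qs[QK8_0];  // quantized values
    } block_q8_0;

    // Element iqs (0..31) of a trailing partial block has its raw float value
    // stored in the qs[] bytes of the blocks that follow it: 8 floats fit per
    // block, so element iqs lands in block 1 + iqs/8 at byte offset 4*(iqs % 8),
    // mirroring the memcpy in the kernel hunk further up.
    static void stash_unquantized(block_q8_0 * dst, const int iqs, const float val) {
        memcpy(&dst[1 + iqs/8].qs[sizeof(float) * (iqs % 8)], &val, sizeof(float));
    }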

@@ -6626,9 +6636,6 @@ void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_te
 }

 void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const int64_t ne = ggml_nelements(src0);
-    GGML_ASSERT(ne == ggml_nelements(src1));
-
     GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

@@ -6652,6 +6659,16 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];

+    const int64_t blck_size = ggml_blck_size(src1->type);
+    const int64_t ne00_padded = ((ne00 + blck_size - 1) / blck_size) * blck_size;
+    const int64_t ne = ggml_nelements(src0);
+    const bool pad = dst->op_params[0] & 1;
+    if (pad) {
+        GGML_ASSERT(ne00_padded * ggml_nrows(src0) == ggml_nelements(src1));
+    } else {
+        GGML_ASSERT(ne == ggml_nelements(src1));
+    }
+
     CUDA_CHECK(cudaSetDevice(g_main_device));
     cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

Expand All @@ -6670,16 +6687,19 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
GGML_ASSERT(nb10 == sizeof(block_q8_0));

const size_t * op_params = (const size_t *) src1->op_params;
const size_t i_blck_0 = op_params[1];
size_t i_blck_0 = 0;
if (src1->op == GGML_OP_VIEW) {
const size_t * op_params = (const size_t *) src1->op_params;
i_blck_0 = op_params[1];
}

if (ggml_is_contiguous(src1)) {
ggml_cpy_f32_q8_0_cuda(
src0_ddc, src1_ddc, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02,
ne00*sizeof(block_q8_0)/QK8_0, ne00*ne01*sizeof(block_q8_0)/QK8_0, cudaStream_main);
ne00_padded*sizeof(block_q8_0)/QK8_0, ne00_padded*ne01*sizeof(block_q8_0)/QK8_0, pad, cudaStream_main);
} else {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, i_blck_0, ne00, ne01, ne02,
nb00, nb01, nb02, nb11, nb12, cudaStream_main);
nb00, nb01, nb02, nb11, nb12, pad, cudaStream_main);
}

} else {
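The net effect on the destination layout: a Q8_0 row is now sized by the padded element count rather than the raw one, which is exactly the ne00_padded*sizeof(block_q8_0)/QK8_0 row stride passed above for the contiguous case. A standalone sketch of that arithmetic (QK8_0 and the block_q8_0 layout mirror ggml's definitions; the helper name and the fp16 stand-in type are assumptions made for this example):

    #include <stdint.h>
    #include <stdio.h>

    #define QK8_0 32

    // One Q8_0 block as in ggml: an fp16 scale followed by 32 int8 quants
    // (uint16_t used here as a stand-in for ggml_fp16_t).
    typedef struct {
        uint16_t d;
        int8_t   qs[QK8_0];
    } block_q8_0;

    // Row size in bytes after rounding ne00 up to whole Q8_0 blocks
    // (hypothetical helper; matches the stride computation above).
    static size_t q8_0_row_size_padded(int64_t ne00) {
        const int64_t ne00_padded = ((ne00 + QK8_0 - 1) / QK8_0) * QK8_0;
        return (size_t)(ne00_padded * sizeof(block_q8_0) / QK8_0);
    }

    int main(void) {
        // n_embd_head = 80 is not a multiple of 32: the row is padded to 96
        // values, i.e. 3 blocks of 34 bytes each.
        printf("%zu\n", q8_0_row_size_padded(80)); // prints 102
        return 0;
    }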
47 changes: 39 additions & 8 deletions ggml.c
@@ -6312,8 +6312,15 @@ static struct ggml_tensor * ggml_cpy_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b,
-        bool inplace) {
-    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+        const bool inplace,
+        const bool pad) {
+    if (pad) {
+        const int64_t blck_size = ggml_blck_size(b->type);
+        const int64_t ne00_padded = ((a->ne[0] + blck_size - 1) / blck_size) * blck_size;
+        GGML_ASSERT(ne00_padded*ggml_nrows(a) == ggml_nelements(b));
+    } else {
+        GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+    }

     bool is_node = false;

@@ -6329,6 +6336,8 @@ static struct ggml_tensor * ggml_cpy_impl(
         ggml_format_name(result, "%s (copy)", a->name);
     }

+    ggml_set_op_params_i32(result, 0, pad ? 1 : 0);
+
     result->op = GGML_OP_CPY;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
@@ -6341,14 +6350,21 @@ struct ggml_tensor * ggml_cpy(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    return ggml_cpy_impl(ctx, a, b, false);
+    return ggml_cpy_impl(ctx, a, b, false, false);
 }

 struct ggml_tensor * ggml_cpy_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         struct ggml_tensor * b) {
-    return ggml_cpy_impl(ctx, a, b, true);
+    return ggml_cpy_impl(ctx, a, b, true, false);
 }

+struct ggml_tensor * ggml_cpy_pad(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_cpy_impl(ctx, a, b, false, true);
+}

 // ggml_cont
@@ -8233,6 +8249,8 @@ static void ggml_compute_forward_dup_f16(

     GGML_TENSOR_UNARY_OP_LOCALS;

+    GGML_ASSERT(dst->op_params[0] == 0);
+
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads

@@ -8496,14 +8514,21 @@ static void ggml_compute_forward_dup_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
-
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }

     GGML_TENSOR_UNARY_OP_LOCALS;

+    const bool pad = dst->op_params[0] & 1;
+    const int blck_size = ggml_blck_size(dst->type);
+    const int ne00_padded = ((ne00 + blck_size - 1) / blck_size) * blck_size;
+    if (pad) {
+        GGML_ASSERT(ggml_nelements(dst) == ne00_padded*ggml_nrows(src0));
+    } else {
+        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    }
+
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads

Expand Down Expand Up @@ -8561,15 +8586,20 @@ static void ggml_compute_forward_dup_f32(
ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;

size_t id = 0;
size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
const size_t rs = nb0 * ne00_padded / blck_size;
char * dst_ptr = (char *) dst->data;
float src0_padded[ne00_padded];

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += rs * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
if (ne00 != ne00_padded) {
memcpy(src0_padded, src0_ptr, ne00*sizeof(float));
memset(src0_padded + ne00, 0, (ne00_padded - ne00) * sizeof(float));
}
quantize_row_q(ne00 == ne00_padded ? src0_ptr : src0_padded, dst_ptr + id, ne00_padded);
id += rs;
}
id += rs * (ne01 - ir1);
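Why zero padding is harmless for Q8_0: the block scale is derived from the maximum absolute value in the block, and appended zeros can never raise that maximum, so the real values quantize exactly as they would in an unpadded row. A self-contained sketch of quantizing one block, mirroring ggml's reference (non-SIMD) Q8_0 path in spirit (function name is ours):

    #include <math.h>
    #include <stdint.h>

    #define QK8_0 32

    // Quantize one block of 32 floats to int8 with a single scale d.
    // Zeros in a padded tail leave amax, and therefore d, unchanged.
    static void quantize_block_q8_0_ref(const float * x, float * d_out, int8_t qs[QK8_0]) {
        float amax = 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            const float v = fabsf(x[j]);
            if (v > amax) {
                amax = v;
            }
        }
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f/d : 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            qs[j] = (int8_t) roundf(x[j]*id);
        }
        *d_out = d;
    }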
@@ -8737,6 +8767,7 @@ static void ggml_compute_forward_dup_f32(
             }
         }
     } else if (type_traits[dst->type].from_float) {
+        GGML_ASSERT(!pad);
         GGML_ASSERT(ne00 == ne0);
         GGML_ASSERT(ne01 == ne1);
         GGML_ASSERT(ne02 == ne2);
6 changes: 6 additions & 0 deletions ggml.h
@@ -1053,6 +1053,12 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * b);

+    // a -> b, pad row size of a to a multiple of block size of b, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b);
+
     // make contiguous
     GGML_API struct ggml_tensor * ggml_cont(
             struct ggml_context * ctx,
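To close, a hedged usage sketch of the new ggml_cpy_pad API: copying an F32 tensor whose rows are 80 elements long (n_embd_head = 80, not a multiple of QK8_0 = 32) into a Q8_0 tensor whose rows are padded to 96 elements. The tensor sizes and the surrounding setup are illustrative assumptions, not taken from the commit:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int n_tokens = 4;
        // 80 floats per row on the source side ...
        struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  80, n_tokens);
        // ... 96 values (3 whole Q8_0 blocks) per row on the destination side.
        struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, 96, n_tokens);

        // Checks that the padded row length of src matches the element count of
        // dst and returns a view of dst; build it into a graph to run the copy.
        struct ggml_tensor * cur = ggml_cpy_pad(ctx, src, dst);
        (void) cur;

        ggml_free(ctx);
        return 0;
    }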
