Skip to content

Commit

Permalink
ggml_add: Add more checks
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Apr 16, 2023
1 parent 0a6d5ad commit 8d37db3
Showing 1 changed file with 34 additions and 16 deletions.
50 changes: 34 additions & 16 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -5893,27 +5893,36 @@ static void ggml_compute_forward_add_f16_f32(
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];

//const size_t nb00 = src0->nb[0];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];

const size_t nb10 = src1->nb[0];
const size_t nb11 = src1->nb[1];

//const size_t nb0 = dst->nb[0];
const size_t nb0 = dst->nb[0];
const size_t nb1 = dst->nb[1];

GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F16);

for (int j = ith; j < n; j += nth) {
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
for (int i = 0; i < nc; i++) {
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));

if (nb10 == sizeof(float)) {
for (int j = ith; j < n; j += nth) {
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
for (int i = 0; i < nc; i++) {
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
}
}
}
else {
// src1 is not contiguous
GGML_ASSERT(false);
}
}

static void ggml_compute_forward_add_f16_f16(
Expand All @@ -5933,27 +5942,36 @@ static void ggml_compute_forward_add_f16_f16(
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];

//const size_t nb00 = src0->nb[0];
const size_t nb00 = src0->nb[0];
const size_t nb01 = src0->nb[1];

const size_t nb10 = src1->nb[0];
const size_t nb11 = src1->nb[1];

//const size_t nb0 = dst->nb[0];
const size_t nb0 = dst->nb[0];
const size_t nb1 = dst->nb[1];

GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F16);
GGML_ASSERT(dst->type == GGML_TYPE_F16);

for (int j = ith; j < n; j += nth) {
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
for (int i = 0; i < nc; i++) {
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));

if (nb10 == sizeof(ggml_fp16_t)) {
for (int j = ith; j < n; j += nth) {
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
for (int i = 0; i < nc; i++) {
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
}
}
}
else {
// src1 is not contiguous
GGML_ASSERT(false);
}
}

static void ggml_compute_forward_add_q_f32(
Expand Down

0 comments on commit 8d37db3

Please sign in to comment.