Skip to content

Commit

Permalink
ggml : fix multi-threaded ggml_compute_forward_diag_mask_f32()
Browse files (browse the repository at this point in the history)
  • Loading branch information
ggerganov committed May 14, 2023
1 parent 788381e commit a483bb2
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions src/ggml.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -10372,22 +10372,34 @@ static void ggml_compute_forward_diag_mask_f32(
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 2);

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
const int n_past = ((int32_t *) src1->data)[0];
const bool inplace = (bool)((int32_t *) src1->data)[1];

if (params->type == GGML_TASK_INIT) {
// TODO: this hack is not good, need a better way to handle this
if (!inplace) {
// use the init task to copy src -> dst
struct ggml_compute_params params_cpy = *params;

params_cpy.ith = 0;
params_cpy.nth = 1;
params_cpy.type = GGML_TASK_COMPUTE;

ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
}

return;
}

if (params->type == GGML_TASK_FINALIZE) {
return;
}

const int ith = params->ith;
const int nth = params->nth;

const int n_past = ((int32_t *) src1->data)[0];
const bool inplace = (bool)((int32_t *) src1->data)[1];

assert(n_past >= 0);

if (!inplace) {
ggml_compute_forward_dup_same_cont(params, src0, dst);
}

// TODO: handle transposed/permuted matrices

const int n = ggml_nrows(src0);
Expand Down Expand Up @@ -10474,7 +10486,7 @@ static void ggml_compute_forward_soft_max_f32(

for (int i1 = ir0; i1 < ir1; i1++) {
float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);

#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
Expand Down

0 comments on commit a483bb2

Please sign in to comment.