From 2a82ee0fc85aed3b06a5f79a21bbf7d844f16891 Mon Sep 17 00:00:00 2001 From: zhentaoyu Date: Tue, 1 Aug 2023 16:15:30 +0800 Subject: [PATCH] [Graph] Falcon-7B optimization (#1199) --- .../application/ChatFALCON/main_falcon.cpp | 50 ++-- .../application/ChatFALCON/quant_falcon.cpp | 2 +- .../neural_engine/graph/core/ne_layers.c | 217 ++++++++---------- 3 files changed, 122 insertions(+), 147 deletions(-) diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/main_falcon.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/main_falcon.cpp index 7dcf8d9f752..1d858105de3 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/main_falcon.cpp +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/main_falcon.cpp @@ -165,7 +165,7 @@ bool falcon_model_load(const std::string& fname, falcon_model& model, gpt_vocab& const int n_vocab = hparams.n_vocab; const int head_dim = hparams.n_embd / hparams.n_head; - ctx_size += n_embd * n_vocab * ne_type_sizef(NE_TYPE_F32); // tok_embeddings + ctx_size += n_embd * n_vocab * ne_type_sizef(wtype); // tok_embeddings ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // output_norm ctx_size += n_embd * ne_type_sizef(NE_TYPE_F32); // output_norm_b @@ -218,7 +218,7 @@ bool falcon_model_load(const std::string& fname, falcon_model& model, gpt_vocab& model.layers.resize(n_layer); - model.tok_embeddings = ne_new_tensor_2d(ctx, NE_TYPE_F32, n_embd, n_vocab, NE_SIZE_CALC); + model.tok_embeddings = ne_new_tensor_2d(ctx, wtype, n_embd, n_vocab, NE_SIZE_CALC); model.output_norm = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd, NE_SIZE_CALC); model.output_norm_b = ne_new_tensor_1d(ctx, NE_TYPE_F32, n_embd, NE_SIZE_CALC); model.lm_head = ne_new_tensor_2d(ctx, NE_TYPE_F32, n_embd, n_vocab, NE_SIZE_CALC); @@ -411,7 +411,6 @@ bool falcon_eval(const falcon_model& model, const int n_threads, const int n_pas // wte struct ne_tensor* inpL = ne_get_rows(ctx0, model.tok_embeddings, embd); - struct ne_tensor* repeat_dummy = ne_new_tensor_3d(ctx0, inpL->type, head_dim, N + n_past, n_head, NE_SIZE_CALC); for (int il = 0; il < n_layer; ++il) { struct ne_tensor* cur; @@ -453,28 +452,34 @@ bool falcon_eval(const falcon_model& model, const int n_threads, const int n_pas // store key and value to memory { - struct ne_tensor* k = ne_view_1d(ctx0, model.memory_k, N * head_dim, - (ne_element_size(model.memory_k) * head_dim) * (il * n_ctx + n_past)); - struct ne_tensor* v = ne_view_1d(ctx0, model.memory_v, N * head_dim, - (ne_element_size(model.memory_v) * head_dim) * (il * n_ctx + n_past)); - - ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k)); - ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v)); + // head_dim, 1 (head_num), N --> head_dim, N, 1 (head_num) + struct ne_tensor* Kcur_permuted = ne_permute(ctx0, Kcur, 0, 2, 1, 3); + // head_dim, 1 (head_num), N --> N, head_dim, 1 (head_num) + struct ne_tensor* Vcur_permuted = ne_permute(ctx0, Vcur, 1, 2, 0, 3); + + struct ne_tensor* k = + ne_view_3d(ctx0, model.memory_k, head_dim, N, 1, ne_element_size(model.memory_k) * head_dim, + ne_element_size(model.memory_k) * head_dim * n_ctx, + il * n_ctx * ne_element_size(model.memory_k) * head_dim + + n_past * ne_element_size(model.memory_k) * head_dim); + struct ne_tensor* v = ne_view_3d( + ctx0, model.memory_v, N, head_dim, 1, n_ctx * ne_element_size(model.memory_v), + n_ctx * ne_element_size(model.memory_v) * head_dim, + il * n_ctx * 
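/* The offset expression continues on the next line. Each layer owns an
   n_ctx-sized slab of the KV cache; `k` keeps head_dim innermost while `v` is
   stored transposed (token position innermost), so the attention matmuls below
   can read the cache through strided 3D views instead of the ne_repeat +
   ne_cont copies that the removed `repeat_dummy` code path relied on. */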
ne_element_size(model.memory_v) * head_dim + n_past * ne_element_size(model.memory_v)); + + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_permuted, k)); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_permuted, v)); } // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3); struct ne_tensor* K = - ne_permute(ctx0, - ne_reshape_3d(ctx0, - ne_view_1d(ctx0, model.memory_k, (n_past + N) * head_dim, - il * n_ctx * ne_element_size(model.memory_k) * head_dim), - head_dim, 1, n_past + N), - 0, 2, 1, 3); + ne_view_3d(ctx0, model.memory_k, head_dim, N + n_past, 1, ne_element_size(model.memory_k) * head_dim, + ne_element_size(model.memory_k) * head_dim * n_ctx, + il * n_ctx * ne_element_size(model.memory_k) * head_dim * 1); // K * Q - K = ne_cont(ctx0, ne_repeat(ctx0, K, repeat_dummy)); struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); // KQ_scaled = KQ / sqrt(n_embd/n_head) @@ -488,14 +493,9 @@ bool falcon_eval(const falcon_model& model, const int n_threads, const int n_pas // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() struct ne_tensor* V = - ne_permute(ctx0, - ne_reshape_3d(ctx0, - ne_view_1d(ctx0, model.memory_v, (n_past + N) * head_dim, - il * n_ctx * ne_element_size(model.memory_v) * head_dim), - head_dim, 1, n_past + N), - 0, 2, 1, 3); - - V = ne_cont(ctx0, ne_transpose(ctx0, ne_repeat(ctx0, V, repeat_dummy))); + ne_view_3d(ctx0, model.memory_v, N + n_past, head_dim, 1, ne_element_size(model.memory_v) * n_ctx, + ne_element_size(model.memory_v) * n_ctx * head_dim, + il * n_ctx * ne_element_size(model.memory_v) * head_dim * 1); // KQV = transpose(V) * KQ_soft_max struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/quant_falcon.cpp b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/quant_falcon.cpp index a75f5d89019..9b2c51dccc4 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/quant_falcon.cpp +++ b/intel_extension_for_transformers/backends/neural_engine/graph/application/ChatFALCON/quant_falcon.cpp @@ -118,7 +118,7 @@ bool falcon_model_quantize(const std::string& fname_inp, const std::string& fnam ".*weight", }; - if (!ne_common_quantize_0(finp, fout, params, to_quant, {"transformer.word_embeddings.weight", "lm_head.weight"})) { + if (!ne_common_quantize_0(finp, fout, params, to_quant, {"lm_head.weight"})) { fprintf(stderr, "%s: failed to quantize model '%s'\n", __func__, fname_inp.c_str()); return false; } diff --git a/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.c b/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.c index a1d5559633b..72de9d32b1f 100644 --- a/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.c +++ b/intel_extension_for_transformers/backends/neural_engine/graph/core/ne_layers.c @@ -624,7 +624,8 @@ static inline bool ne_is_matrix(const struct ne_tensor* tensor) { static inline bool ne_can_mul_mat(const struct ne_tensor* t0, const struct ne_tensor* t1) { static_assert(NE_MAX_DIMS == 4, "NE_MAX_DIMS is not 4 - update this function"); - return (t0->ne[0] == t1->ne[0]) && (t0->ne[2] == t1->ne[2]) && (t0->ne[3] == t1->ne[3]); + // verify t0 is broadcastable + return (t0->ne[0] == t1->ne[0]) && (t1->ne[2] % t0->ne[2] == 0) && (t1->ne[3] % t0->ne[3] == 0); } bool ne_is_quantized(enum ne_type type) 
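/* Note on the ne_can_mul_mat change above: t1's batch dims (ne[2], ne[3]) now
   only need to be integer multiples of t0's, so Falcon's single shared KV head
   can be multiplied against all n_head query heads directly (multi-query
   attention) instead of first repeating the KV tensor per head. */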
{ return NE_IS_QUANTIZED[type]; } @@ -1939,8 +1940,8 @@ struct ne_tensor* ne_mul_mat(struct ne_context* ctx, struct ne_tensor* a, struct is_node = true; } - const int64_t ne[4] = {a->ne[1], b->ne[1], a->ne[2], b->ne[3]}; - struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, MIN(a->n_dims, b->n_dims), ne, NE_SIZE_CALC); + const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + struct ne_tensor * result = ne_new_tensor(ctx, NE_TYPE_F32, MAX(a->n_dims, b->n_dims), ne, NE_SIZE_CALC); result->op = NE_OP_MUL_MAT; result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; @@ -5528,7 +5529,7 @@ static void ne_compute_forward_mul_mat_f32(const struct ne_compute_params* param const int64_t ne03 = src0->ne[3]; const int64_t ne11 = src1->ne[1]; -#ifndef NDEBUG + const int64_t ne12 = src1->ne[2]; const int64_t ne13 = src1->ne[3]; @@ -5538,17 +5539,17 @@ static void ne_compute_forward_mul_mat_f32(const struct ne_compute_params* param const int64_t ne3 = dst->ne[3]; const int nb00 = src0->nb[0]; -#endif + const int nb01 = src0->nb[1]; const int nb02 = src0->nb[2]; const int nb03 = src0->nb[3]; -#ifndef NDEBUG + const int nb10 = src1->nb[0]; -#endif - const int nb11 = src1->nb[1]; - const int nb12 = src1->nb[2]; - const int nb13 = src1->nb[3]; + + const int nb11 = src1->nb[1]; UNUSED(nb11); + const int nb12 = src1->nb[2]; UNUSED(nb12); + const int nb13 = src1->nb[3]; UNUSED(nb13); const int nb0 = dst->nb[0]; const int nb1 = dst->nb[1]; @@ -5558,25 +5559,20 @@ static void ne_compute_forward_mul_mat_f32(const struct ne_compute_params* param const int ith = params->ith; const int nth = params->nth; - assert(ne02 == ne12); - assert(ne03 == ne13); - assert(ne2 == ne12); - assert(ne3 == ne13); + NE_ASSERT(ne0 == ne01); + NE_ASSERT(ne1 == ne11); + NE_ASSERT(ne2 == ne12); + NE_ASSERT(ne3 == ne13); // we don't support permuted src0 or src1 - assert(nb00 == sizeof(float)); - assert(nb10 == sizeof(float)); + NE_ASSERT(nb00 == sizeof(float)); + NE_ASSERT(nb10 == sizeof(float)); // dst cannot be transposed or permuted - assert(nb0 == sizeof(float)); - assert(nb0 <= nb1); - assert(nb1 <= nb2); - assert(nb2 <= nb3); - - assert(ne0 == ne01); - assert(ne1 == ne11); - assert(ne2 == ne02); - assert(ne3 == ne03); + NE_ASSERT(nb0 == sizeof(float)); + NE_ASSERT(nb0 <= nb1); + NE_ASSERT(nb1 <= nb2); + NE_ASSERT(nb2 <= nb3); // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -5589,39 +5585,35 @@ static void ne_compute_forward_mul_mat_f32(const struct ne_compute_params* param return; } - // parallelize by src0 rows using ne_vec_dot_f32 + // parallelize by src0 rows + const int64_t dr = (ne01 + nth - 1) / nth; - // total rows in src0 - const int nr = ne01 * ne02 * ne03; + const int64_t ir10 = dr * ith; + const int64_t ir11 = MIN(ir10 + dr, ne01); - // rows per thread - const int dr = (nr + nth - 1) / nth; + // src1 rows + const int64_t nr1 = ne11 * ne12 * ne13; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); + for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { + const int64_t i13 = (ir1 / (ne12 * ne11)); + const int64_t i12 = (ir1 - i13 * ne12 * ne11) / ne11; + const int64_t i11 = (ir1 - i13 * ne12 * ne11 - i12 * ne11); - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir / (ne02 * ne01); - const int i02 = (ir - i03 * ne02 * ne01) / ne01; - const int i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); + const int64_t i03 = (ir0 / (ne02)); + const int64_t i02 = (ir0 - i03 * ne02); - for 
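/* The removed loop below did one ne_vec_dot_f32 per (src0 row, src1 column)
   pair with matching batch indices. The replacement above splits src0 rows
   [ir10, ir11) across threads and walks all nr1 = ne11 * ne12 * ne13 src1
   rows, recovering (i11, i12, i13) from the flat index ir1 by div/mod so
   src0's smaller batch dims can be broadcast. */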
(int64_t ic = 0; ic < ne11; ++ic) { - // src1 indices - const int i13 = i03; - const int i12 = i02; - const int i11 = ic; + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; + char* src0_row = (char*) src0->data + (0 + i02 * nb02 + i03 * nb03); + char* src1_col = (char*) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); - ne_vec_dot_f32(ne00, (float*)((char*)dst->data + (i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3)), - (float*)((char*)src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03)), - (float*)((char*)src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13))); + float* dst_col = (float*) ((char*) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + for (int64_t ir = ir10; ir < ir11; ++ir) { + ne_vec_dot_f32(ne00, &dst_col[ir], (float*) (src0_row + ir * nb01), (float*) src1_col); } } @@ -5679,8 +5671,8 @@ static void ne_compute_forward_mul_mat_f16_f32(const struct ne_compute_params* p const int ith = params->ith; const int nth = params->nth; - NE_ASSERT(ne02 == ne12); - NE_ASSERT(ne03 == ne13); + NE_ASSERT(ne0 == ne01); + NE_ASSERT(ne1 == ne11); NE_ASSERT(ne2 == ne12); NE_ASSERT(ne3 == ne13); @@ -5693,11 +5685,6 @@ static void ne_compute_forward_mul_mat_f16_f32(const struct ne_compute_params* p NE_ASSERT(nb1 <= nb2); NE_ASSERT(nb2 <= nb3); - NE_ASSERT(ne0 == ne01); - NE_ASSERT(ne1 == ne11); - NE_ASSERT(ne2 == ne02); - NE_ASSERT(ne3 == ne03); - // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -5729,40 +5716,38 @@ static void ne_compute_forward_mul_mat_f16_f32(const struct ne_compute_params* p // TODO: do not support transposed src1 assert(nb10 / 2 == sizeof(ne_fp16_t)); - // parallelize by src0 rows using ne_vec_dot_f16 - - // total rows in src0 - const int nr = ne01 * ne02 * ne03; + // parallelize by src0 rows + const int64_t dr = (ne01 + nth - 1) / nth; - // rows per thread - const int dr = (nr + nth - 1) / nth; + const int64_t ir10 = dr * ith; + const int64_t ir11 = MIN(ir10 + dr, ne01); - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); + // src1 rows + const int64_t nr1 = ne11 * ne12 * ne13; - ne_fp16_t* wdata = params->wdata; + void* wdata = params->wdata; + const size_t row_size = ne10 * NE_TYPE_SIZE[NE_TYPE_F16]; - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir / (ne02 * ne01); - const int i02 = (ir - i03 * ne02 * ne01) / ne01; - const int i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { + const int64_t i13 = (ir1 / (ne12 * ne11)); + const int64_t i12 = (ir1 - i13 * ne12 * ne11) / ne11; + const int64_t i11 = (ir1 - i13 * ne12 * ne11 - i12 * ne11); - const int i13 = i03; - const int i12 = i02; + const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); + const int64_t i03 = (ir0 / (ne02)); + const int64_t i02 = (ir0 - i03 * ne02); - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; - ne_fp16_t* src0_row = (ne_fp16_t*)((char*)src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03)); - ne_fp16_t* src1_col = wdata + (0 + i12 * ne11 + i13 * ne12 * ne11) * ne00; + char* src0_row = (char*) src0->data + (0 + i02 * nb02 + i03 * nb03); + char* src1_col = (char*) wdata + (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size; - float* dst_col = (float*)((char*)dst->data + (i0 * nb0 + 0 * nb1 + i2 * nb2 + i3 * nb3)); + float* dst_col = (float*) ((char*) dst->data + 
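/* dst_col's offset continues on the next line. In this f16 path, wdata holds
   src1 pre-converted to fp16 with row_size = ne10 * NE_TYPE_SIZE[NE_TYPE_F16];
   each thread dots its src0 row range against that single src1 column and
   writes the results to dst_col[ir]. */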
(i1 * nb1 + i2 * nb2 + i3 * nb3)); - for (int64_t ic = 0; ic < ne11; ++ic) { - ne_vec_dot_f16(ne00, &dst_col[ic * ne0], src0_row, src1_col + ic * ne00); + for (int64_t ir = ir10; ir < ir11; ++ir) { + ne_vec_dot_f16(ne00, &dst_col[ir], (ne_fp16_t*) (src0_row + ir * nb01), (ne_fp16_t*) src1_col); } } @@ -5818,16 +5803,16 @@ static void ne_compute_forward_mul_mat_q_f32(const struct ne_compute_params* par const int ith = params->ith; const int nth = params->nth; - NE_ASSERT(ne02 == ne12); - NE_ASSERT(ne03 == ne13); - NE_ASSERT(ne2 == ne12); - NE_ASSERT(ne3 == ne13); - const enum ne_type type = src0->type; quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; enum ne_type const vec_dot_type = quantize_fns[type].vec_dot_type; + NE_ASSERT(ne0 == ne01); + NE_ASSERT(ne1 == ne11); + NE_ASSERT(ne2 == ne12); + NE_ASSERT(ne3 == ne13); + // we don't support permuted src0 or src1 NE_ASSERT(nb00 == (int)NE_TYPE_SIZE[type]); NE_ASSERT(nb10 == sizeof(float)); @@ -5838,11 +5823,6 @@ static void ne_compute_forward_mul_mat_q_f32(const struct ne_compute_params* par NE_ASSERT(nb1 <= nb2); NE_ASSERT(nb2 <= nb3); - NE_ASSERT(ne0 == ne01); - NE_ASSERT(ne1 == ne11); - NE_ASSERT(ne2 == ne02); - NE_ASSERT(ne3 == ne03); - // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -5866,43 +5846,38 @@ static void ne_compute_forward_mul_mat_q_f32(const struct ne_compute_params* par return; } - // parallelize by src0 rows using ne_vec_dot_q + // parallelize by src0 rows + const int64_t dr = (ne01 + nth - 1) / nth; - // total rows in src0 - const int nr = ne01 * ne02 * ne03; + const int64_t ir10 = dr * ith; + const int64_t ir11 = MIN(ir10 + dr, ne01); - // rows per thread - const int dr = (nr + nth - 1) / nth; + // src1 rows + const int64_t nr1 = ne11 * ne12 * ne13; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); + const void * wdata = params->wdata; + const size_t row_size = ne10 * NE_TYPE_SIZE[vec_dot_type] / NE_BLCK_SIZE[vec_dot_type]; - void* wdata = params->wdata; - const size_t row_size = ne00 * NE_TYPE_SIZE[vec_dot_type] / NE_BLCK_SIZE[vec_dot_type]; + for (int64_t ir1 = 0; ir1 < nr1; ++ir1) { + const int64_t i13 = (ir1 / (ne12 * ne11)); + const int64_t i12 = (ir1 - i13 * ne12 * ne11) / ne11; + const int64_t i11 = (ir1 - i13 * ne12 * ne11 - i12 * ne11); - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir / (ne02 * ne01); - const int i02 = (ir - i03 * ne02 * ne01) / ne01; - const int i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03); + const int64_t i03 = (ir0 / (ne02)); + const int64_t i02 = (ir0 - i03 * ne02); - const int i13 = i03; - const int i12 = i02; + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - void* src0_row = (void*)((char*)src0->data + (i01 * nb01 + i02 * nb02 + i03 * nb03)); - char* src1_col = ((char*)wdata + ((0 + i12 * ne11 + i13 * ne12 * ne11) * row_size)); + const char* src0_row = (const char*) src0->data + (0 + i02 * nb02 + i03 * nb03); + const char* src1_col = (const char*) wdata + (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size; - float* dst_col = (float*)((char*)dst->data + (i0 * nb0 + 0 * nb1 + i2 * nb2 + i3 * nb3)); - - assert(ne00 % 32 == 0); + float* dst_col = (float*) ((char*) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - for (int64_t ic = 0; ic < ne11; ++ic) { - 
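/* Same restructuring for the quantized path: src1 rows are quantized into
   wdata at vec_dot_type (row_size accounts for the quantization block size),
   and the new loop calls vec_dot_q once per src0 row in [ir10, ir11) rather
   than once per src1 column, as in the removed lines here. */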
vec_dot_q(ne00, &dst_col[ic * ne0], src0_row, (void*)(src1_col + ic * row_size)); + for (int64_t ir = ir10; ir < ir11; ++ir) { + vec_dot_q(ne00, &dst_col[ir], src0_row + ir * nb01, src1_col); } }
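--
All three rewritten ne_compute_forward_mul_mat_* kernels share the same
broadcast index math. The following standalone C++ sketch mirrors that math
for the single-threaded case (ith = 0, nth = 1), assuming dense row-major
float tensors; the shapes mimic Falcon's multi-query attention, one KV head
(ne02 = ne03 = 1) shared by several query heads (ne12 > 1). It is
illustrative only; none of the names below come from the library.

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      // src0 plays the role of the K/V tensor, src1 the query activations.
      const int64_t ne00 = 4;                        // shared inner (dot) length
      const int64_t ne01 = 3, ne02 = 1, ne03 = 1;    // src0: a single "KV head"
      const int64_t ne11 = 2, ne12 = 5, ne13 = 1;    // src1: five "query heads"
      assert(ne12 % ne02 == 0 && ne13 % ne03 == 0);  // relaxed ne_can_mul_mat test

      std::vector<float> src0(ne00 * ne01 * ne02 * ne03);
      std::vector<float> src1(ne00 * ne11 * ne12 * ne13);
      for (size_t i = 0; i < src0.size(); ++i) src0[i] = (float)(i % 7) - 3.0f;
      for (size_t i = 0; i < src1.size(); ++i) src1[i] = (float)(i % 5) - 2.0f;

      // dst takes src1's batch dims, as in the patched ne_mul_mat:
      // ne = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }
      std::vector<float> dst(ne01 * ne11 * ne12 * ne13, 0.0f);

      const int64_t nr1 = ne11 * ne12 * ne13;        // total src1 rows
      for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
        // unflatten ir1 into src1 indices (i11, i12, i13)
        const int64_t i13 = ir1 / (ne12 * ne11);
        const int64_t i12 = (ir1 - i13 * ne12 * ne11) / ne11;
        const int64_t i11 = ir1 - i13 * ne12 * ne11 - i12 * ne11;

        // broadcast: map this src1 row onto one of src0's ne02 * ne03 planes
        const int64_t ir0 = (ir1 / ne11) % (ne02 * ne03);
        const int64_t i03 = ir0 / ne02;
        const int64_t i02 = ir0 - i03 * ne02;

        const float* src0_plane = src0.data() + (i02 + i03 * ne02) * ne01 * ne00;
        const float* src1_row = src1.data() + ir1 * ne00;
        float* dst_col = dst.data() + ir1 * ne01;  // column (i11, i12, i13) of dst

        // single thread, so the row range [ir10, ir11) covers all of src0's rows
        for (int64_t ir = 0; ir < ne01; ++ir) {
          float acc = 0.0f;
          for (int64_t k = 0; k < ne00; ++k)
            acc += src0_plane[ir * ne00 + k] * src1_row[k];
          dst_col[ir] = acc;  // the kernels' ne_vec_dot_* call, written out
        }
      }

      printf("dst[0..3] = %g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);
      return 0;
    }

Compiled with `g++ -std=c++11 sketch.cpp`, it computes one dot product per
(src0 row, src1 row) pair and prints the first few results. Changing ne12
changes only the number of src1 rows, never the src0 storage, which is
exactly what lets the Falcon graph drop its ne_repeat calls.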