diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox.cpp
index 23aff936e94..698a7e7976f 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox.cpp
@@ -113,6 +113,28 @@ static bool gptneox_model_eval_internal(model_context& lctx, const model_token*
   ne_cgraph gf = {};
   gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;
 
+  const bool run_mha_reordered = kv_self.k->type == NE_TYPE_JBLAS;
+  kv_cache_info_t kv_cache_info = {};
+  if (run_mha_reordered) {
+    NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_JBLAS));
+    attn_shape_t attn_shape = {
+        /* .batch_size = */ 1,
+        /* .head_num = */ n_head,
+        /* .heads_kv = */ n_head,
+        /* .head_size = */ head_dim,
+        /* .sl_q = */ N,  // Note: make sure that jblas reordered attn supports next token inference
+        /* .sl_kv = */ n_past + N,
+    };
+
+    NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
+               jblas_reordered_attn_fp32_support(&attn_shape)));
+    kv_shape_t kv_shape{
+        /* .heads_kv = */ static_cast<uint32_t>(n_head),
+        /* .head_size = */ static_cast<uint32_t>(head_dim),
+        /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
+    };
+    jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
+  }
   struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size);
   ne_set_name(embd, "embd");
   for (int i = 0; i < batch_size; ++i) {
@@ -151,78 +173,120 @@ static bool gptneox_model_eval_internal(model_context& lctx, const model_token*
       // using mode = 2 for GPT-NeoX mode
       Qcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0);
       Kcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Kcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0);
-
+      const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
       // store key and value to memory
-      {
-        std::vector<ne_tensor*> Kcur_bs(batch_size);
-        std::vector<ne_tensor*> Vcur_bs(batch_size);
-        std::vector<ne_tensor*> k_bs(batch_size);
-        std::vector<ne_tensor*> v_bs(batch_size);
-        for (int i = 0; i < batch_size; ++i) {
-          // batch K
-          Kcur_bs[i] = ne_permute(ctx0,
-                                  ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
-                                             ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
-                                             i * ne_element_size(Kcur) * n_embd * N),
-                                  0, 2, 1, 3);
-          k_bs[i] = ne_view_4d(
-              ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
-              ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
-              ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
-               i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));
-
-          // batch V
-          Vcur_bs[i] = ne_permute(ctx0,
-                                  ne_reshape_4d(ctx0,
-                                                ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
-                                                           i * ne_element_size(Vcur) * n_embd * N),
-                                                head_dim, n_head, N, 1),
-                                  1, 2, 0, 3);
-          v_bs[i] =
-              ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
-                         n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
-                         ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
-                          i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
-          ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
-          ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
+      if (!run_mha_reordered) {
+        {
+          std::vector<ne_tensor*> Kcur_bs(batch_size);
+          std::vector<ne_tensor*> Vcur_bs(batch_size);
+          std::vector<ne_tensor*> k_bs(batch_size);
+          std::vector<ne_tensor*> v_bs(batch_size);
+          for (int i = 0; i < batch_size; ++i) {
+            // batch K
+            Kcur_bs[i] = ne_permute(ctx0,
+                                    ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
+                                               ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
+                                               i * ne_element_size(Kcur) * n_embd * N),
+                                    0, 2, 1, 3);
+            k_bs[i] = ne_view_4d(
+                ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
+                ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
+                ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
+                 i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));
+
+            // batch V
+            Vcur_bs[i] = ne_permute(ctx0,
+                                    ne_reshape_4d(ctx0,
+                                                  ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
+                                                             i * ne_element_size(Vcur) * n_embd * N),
+                                                  head_dim, n_head, N, 1),
+                                    1, 2, 0, 3);
+            v_bs[i] =
+                ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
+                           n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
+                           ((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
+                            i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
+            ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
+            ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
+          }
+        }
+        // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+        struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);
+
+        // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+        struct ne_tensor* K =
+            ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
+                       ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
+                       il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);
+
+        // K * Q
+        struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);
+
+        // KQ_scaled = KQ / sqrt(n_embd/n_head)
+        struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));
+
+        // KQ_masked = mask_past(KQ_scaled)
+        struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+
+        // KQ = soft_max(KQ_masked)
+        struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);
+
+        // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+        struct ne_tensor* V =
+            ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
+                       n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
+                       il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);
+
+        // KQV = transpose(V) * KQ_soft_max
+        struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
+
+        // KQV_merged = KQV.permute(0, 2, 1, 3)
+        struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);
+
+        // cur = KQV_merged.contiguous().view(n_embd, N)
+        cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));
+      } else {
+        const auto seq_kv = n_past + N;
+        const auto k_size = kv_cache_info.k_bytes;
+        const auto v_size = kv_cache_info.v_bytes;
+
+        // store key and value to memory
+        {
+          const auto k_cache = ne_view_3d(ctx0, kv_self.k,          // tensor
+                                          head_dim, n_ctx, n_head,  // ne
+                                          0, 0,                     // nb (jblas managed)
+                                          il * k_size);             // offset
+          ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
+          const auto v_cache = ne_view_3d(ctx0, kv_self.v,          // tensor
+                                          head_dim, n_ctx, n_head,  // ne
+                                          0, 0,                     // nb (jblas managed)
+                                          il * v_size);             // offset
+          ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past));
         }
-      }
-      // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
-      struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);
-
-      // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
-      struct ne_tensor* K =
-          ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
-                     ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
-                     il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);
-
-      // K * Q
-      struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);
-
-      // KQ_scaled = KQ / sqrt(n_embd/n_head)
-      struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));
-
-      // KQ_masked = mask_past(KQ_scaled)
-      struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-
-      // KQ = soft_max(KQ_masked)
-      struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);
-
-      // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-      struct ne_tensor* V =
-          ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
-                     n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
-                     il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);
-
-      // KQV = transpose(V) * KQ_soft_max
-      struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
-
-      // KQV_merged = KQV.permute(0, 2, 1, 3)
-      struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);
-
-      // cur = KQV_merged.contiguous().view(n_embd, N)
-      cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));
+        struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
+        ne_set_name(Q, "Q");
+
+        struct ne_tensor* K =
+            ne_view_3d(ctx0, kv_self.k,                                             // tensor
+                       head_dim, seq_kv, n_head,                                    // ne
+                       kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,  // nb (jblas managed)
+                       il * k_size);                                                // offset
+        *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout;    // use nb0 for layout
+        ne_set_name(K, "K");
+        struct ne_tensor* V =
+            ne_view_3d(ctx0, kv_self.v,                                                    // tensor
+                       seq_kv, head_dim, n_head,                                           // ne
+                       kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,  // nb (jblas managed)
+                       il * v_size);                                                       // offset
+        *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout;           // use nb0 for layout
+        ne_set_name(V, "V");
+
+        ne_attn_flags_t attn_flags = 0;
+        if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL;  // no causal mask on next-token cases
+        struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
+        cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
+      }
       // projection
       {
         cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
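
For readers comparing the two branches above: the `!run_mha_reordered` path spells attention out as KQ, scale, causal mask, softmax, and a weighted sum over V, while the reordered path hands the same computation to a single fused `ne_flash_attn` call over jblas-managed K/V views. The sketch below is a scalar, single-head reference of that computation only; it is not the jblas kernel, assumes plain row-major Q/K/V buffers, and all function and variable names are illustrative rather than part of the runtime. It also shows why the causal flag is only set when `n_past == 0`: for next-token decoding (`sl_q == 1`) every cached key is already visible.

// Scalar reference for one attention head: KQ -> scale -> causal mask -> softmax -> weighted sum over V.
// Illustrative names only; q is sl_q x d, k and v are sl_kv x d, out is sl_q x d (all row-major).
#include <algorithm>
#include <cmath>
#include <vector>

void naive_attention(const std::vector<float>& q, const std::vector<float>& k, const std::vector<float>& v,
                     std::vector<float>& out, int sl_q, int sl_kv, int d, int n_past) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));  // same 1/sqrt(head_dim) as attn_scale above
  std::vector<float> scores(sl_kv);
  for (int i = 0; i < sl_q; ++i) {
    // causal mask: query i sits at absolute position n_past + i and may only see keys 0..n_past+i;
    // with sl_q == 1 (next-token case) this is every cached key, so no mask is needed.
    const int visible = n_past + i + 1;
    float max_s = -INFINITY;
    for (int j = 0; j < visible; ++j) {
      float s = 0.0f;
      for (int x = 0; x < d; ++x) s += q[i * d + x] * k[j * d + x];
      scores[j] = s * scale;
      max_s = std::max(max_s, scores[j]);
    }
    float sum = 0.0f;
    for (int j = 0; j < visible; ++j) {
      scores[j] = std::exp(scores[j] - max_s);  // numerically stable softmax
      sum += scores[j];
    }
    for (int x = 0; x < d; ++x) {
      float acc = 0.0f;
      for (int j = 0; j < visible; ++j) acc += scores[j] * v[j * d + x];
      out[i * d + x] = acc / sum;
    }
  }
}
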
diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox_utils.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox_utils.cpp
index 7a518f618ca..af78fd26f2c 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox_utils.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/models/gptneox/gptneox_utils.cpp
@@ -285,6 +285,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
   std::unique_ptr<GPTNEOX> ms(new GPTNEOX());
   ms->init(fname.c_str(), lctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
   ms->load(lctx, progress_callback, progress_callback_user_data);
+  lctx.support_jblas_kv = true;
   if (lctx.beam_search) {
     lctx.bs_kv_reorder = std::make_shared(&lctx);
 #ifdef NE_BEAM_SEARCH_VERBOSE_ON
diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/llama/llama.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/llama/llama.cpp
index db3ee5682f9..1fedc13d174 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/models/llama/llama.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/models/llama/llama.cpp
@@ -264,17 +264,18 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
       const auto v_size = kv_cache_info.v_bytes;
       // store key and value to memory
       {
-        const auto k_cache = ne_view_3d(ctx0, kv_self.k,          // tensor
-                                        head_size, n_ctx, n_head,  // ne
-                                        0, 0,                      // nb (jblas managed)
-                                        il * k_size);              // offset
+        const auto k_cache = ne_view_3d(ctx0, kv_self.k,              // tensor
+                                        head_size, n_ctx, n_head_kv,  // ne
+                                        0, 0,                         // nb (jblas managed)
+                                        il * k_size);                 // offset
         ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
-        const auto v_cache = ne_view_3d(ctx0, kv_self.v,          // tensor
-                                        head_size, n_ctx, n_head,  // ne
-                                        0, 0,                      // nb (jblas managed)
-                                        il * v_size);              // offset
+        const auto v_cache = ne_view_3d(ctx0, kv_self.v,              // tensor
+                                        head_size, n_ctx, n_head_kv,  // ne
+                                        0, 0,                         // nb (jblas managed)
+                                        il * v_size);                 // offset
         // jblas always views V as (D, n_head, seq)
-        const auto Vcur_plain = ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd * N, 0), n_embd / n_head, n_head, N);
+        const auto Vcur_plain =
+            ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd_gqa * N, 0), n_embd_gqa / n_head_kv, n_head_kv, N);
         ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur_plain, n_past));
       }
@@ -283,14 +284,14 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
       struct ne_tensor* K =
           ne_view_3d(ctx0, kv_self.k,                                             // tensor
-                     head_size, n_cached, n_head,                                 // ne
+                     head_size, n_cached, n_head_kv,                              // ne
                      kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,  // nb (jblas managed)
                      il * k_size);                                                // offset
      *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout;     // use nb0 for layout
      ne_set_name(K, "K");
      struct ne_tensor* V =
          ne_view_3d(ctx0, kv_self.v,                                                    // tensor
-                    n_cached, head_size, n_head,                                        // ne
+                    n_cached, head_size, n_head_kv,                                     // ne
                     kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,  // nb (jblas managed)
                     il * v_size);                                                       // offset
      *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout;           // use nb0 for layout
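
The llama.cpp hunks above only swap `n_head` for `n_head_kv` (and `n_embd` for `n_embd_gqa`) in the cache views, which is the grouped-query-attention case where several query heads share one cached K/V head. Below is a minimal sketch of the head mapping this implies; the helper name and the example sizes (32 query heads, 8 KV heads, head size 128) are made up for illustration and are not taken from any model config in this patch.

// Grouped-query bookkeeping sketch (illustrative only): n_head query heads share n_head_kv cached
// K/V heads, so consecutive groups of n_head / n_head_kv query heads read the same cache slot,
// and only n_embd_gqa = head_size * n_head_kv elements per token are actually cached.
#include <cassert>
#include <cstdio>

int kv_head_for_query_head(int query_head, int n_head, int n_head_kv) {
  assert(n_head % n_head_kv == 0);
  const int group_size = n_head / n_head_kv;  // query heads per shared kv head
  return query_head / group_size;
}

int main() {
  const int n_head = 32, n_head_kv = 8, head_size = 128;  // hypothetical sizes
  const int n_embd_gqa = head_size * n_head_kv;           // per-token K (or V) elements cached
  // e.g. query heads 0..3 -> kv head 0, 4..7 -> kv head 1, ...
  std::printf("query head 5 reads kv head %d, cached elements per token: %d\n",
              kv_head_for_query_head(5, n_head, n_head_kv), n_embd_gqa);
  return 0;
}
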
diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder.cpp
index 80424c3a616..c8796d3c7a6 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder.cpp
@@ -67,6 +67,7 @@ static bool starcoder_model_eval_internal(model_context& lctx, const model_token
   const int n_head = hparams.n_head;
   const int n_vocab = hparams.n_vocab;
   const int n_rot = hparams.n_rot;
+  const int head_dim = n_embd / n_head;
 
   auto& mem_per_token = lctx.mem_per_token;
   auto& buf_compute = lctx.buf_compute;
@@ -84,8 +85,28 @@ static bool starcoder_model_eval_internal(model_context& lctx, const model_token
   ne_cgraph gf = {};
   gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;
 
-  const bool kv_mem_jblas = kv_self.k->type == NE_TYPE_JBLAS;
-  NE_ASSERT(("jblas managed kv-cache is not yet supported; use `--memory-f16 / --memory-f32` instead", !kv_mem_jblas));
+  const bool run_mha_reordered = kv_self.k->type == NE_TYPE_JBLAS;
+  kv_cache_info_t kv_cache_info = {};
+  if (run_mha_reordered) {
+    NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_JBLAS));
+    attn_shape_t attn_shape = {
+        /* .batch_size = */ 1,
+        /* .head_num = */ n_head,
+        /* .heads_kv = */ n_head,
+        /* .head_size = */ head_dim,
+        /* .sl_q = */ N,  // Note: make sure that jblas reordered attn supports next token inference
+        /* .sl_kv = */ n_past + N,
+    };
+
+    NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
+               jblas_reordered_attn_fp32_support(&attn_shape)));
+    kv_shape_t kv_shape{
+        /* .heads_kv = */ static_cast<uint32_t>(n_head),
+        /* .head_size = */ static_cast<uint32_t>(head_dim),
+        /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
+    };
+    jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
+  }
 
   struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
   ne_set_name(embd, "embd");
@@ -133,92 +154,132 @@ static bool starcoder_model_eval_internal(model_context& lctx, const model_token
     // self-attention
     {
       size_t fused_qkv_row_nb = (3 * n_embd) * sizeof(float);
-      size_t head_dim = n_embd / n_head;
       struct ne_tensor* Qcur = ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
                                           0 * sizeof(float) * n_embd);
       // head_dim, n_head, N --> head_dim, N, n_head
-      struct ne_tensor* Kcur = ne_permute(ctx0,
-                                          ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
-                                                     fused_qkv_row_nb, 1 * sizeof(float) * n_embd),
-                                          0, 2, 1, 3);
+      struct ne_tensor* Kcur = ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+                                          1 * sizeof(float) * n_embd);
       // head_dim, n_head, N --> N, head_dim, n_head
-      struct ne_tensor* Vcur = ne_permute(ctx0,
-                                          ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
-                                                     fused_qkv_row_nb, 2 * sizeof(float) * n_embd),
-                                          1, 2, 0, 3);
-
+      struct ne_tensor* Vcur = ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+                                          2 * sizeof(float) * n_embd);
+      const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
       // store transposed key and value to memory (k_v cache)
-      if (N >= 1) {
-        // n_embd / n_head as col
-        struct ne_tensor* k = ne_view_3d(
-            ctx0, kv_self.k, n_embd / n_head, N, n_head, ne_element_size(kv_self.k) * n_embd / n_head,
-            ne_element_size(kv_self.k) * n_embd / n_head * n_ctx,
-            il * n_ctx * ne_element_size(kv_self.k) * n_embd + n_past * ne_element_size(kv_self.k) * n_embd / n_head);
-        // N as col, n_embd as row
-        struct ne_tensor* v =
-            ne_view_3d(ctx0, kv_self.v, N, n_embd / n_head, n_head, n_ctx * ne_element_size(kv_self.v),
-                       n_ctx * ne_element_size(kv_self.v) * head_dim,
-                       il * n_ctx * ne_element_size(kv_self.v) * n_embd + n_past * ne_element_size(kv_self.v));
-        // concat
-        ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur, k));
-        ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur, v));
+      if (!run_mha_reordered) {
+        struct ne_tensor* Kcur_permuted = ne_permute(ctx0, Kcur, 0, 2, 1, 3);
+        // head_dim, n_head_kv, N --> N, head_dim, n_head_kv
+        struct ne_tensor* Vcur_permuted = ne_permute(ctx0, Vcur, 1, 2, 0, 3);
+        if (N >= 1) {
+          // n_embd / n_head as col
+          struct ne_tensor* k = ne_view_3d(
+              ctx0, kv_self.k, head_dim, N, n_head, ne_element_size(kv_self.k) * head_dim,
+              ne_element_size(kv_self.k) * head_dim * n_ctx,
+              il * n_ctx * ne_element_size(kv_self.k) * n_embd + n_past * ne_element_size(kv_self.k) * head_dim);
+          // N as col, n_embd as row
+          struct ne_tensor* v =
+              ne_view_3d(ctx0, kv_self.v, N, head_dim, n_head, n_ctx * ne_element_size(kv_self.v),
+                         n_ctx * ne_element_size(kv_self.v) * head_dim,
+                         il * n_ctx * ne_element_size(kv_self.v) * n_embd + n_past * ne_element_size(kv_self.v));
+          // concat
+          ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_permuted, k));
+          ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_permuted, v));
+        }
+
+        // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+        // [64, N, 12]
+        struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
+
+        // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+        // [64, n_past + N, 12]
+        struct ne_tensor* K =
+            ne_view_3d(ctx0, kv_self.k, head_dim, N + n_past, n_head, ne_element_size(kv_self.k) * head_dim,
+                       ne_element_size(kv_self.k) * head_dim * n_ctx, il * n_ctx * ne_element_size(kv_self.k) * n_embd);
+
+        // GG: flash attention
+        // struct ne_tensor * V =
+        //     ne_cpy(ctx0,
+        //            ne_permute(ctx0,
+        //                       ne_reshape_3d(ctx0,
+        //                                     ne_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd,
+        //                                                il*n_ctx*ne_element_size(kv_self.v)*n_embd), n_embd/n_head, n_head, n_past + N),
+        //                       1, 2, 0, 3),
+        //            ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_past + N, n_embd/n_head, n_head, NE_SIZE_CALC));
+
+        // struct ne_tensor * KQV = ne_flash_attn(ctx0, Q, K, V, NE_ATTN_FLAG_IS_CAUSAL);
+
+        // K * Q
+        // [n_past + N, N, 12]
+        struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);  // TODO: check if it broadcasts
+
+        // KQ_scaled = KQ / sqrt(n_embd/n_head)
+        // [n_past + N, N, 12]
+        struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));
+
+        // KQ_masked = mask_past(KQ_scaled)
+        // [n_past + N, N, 12]
+        struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+
+        // KQ = soft_max(KQ_masked)
+        // [n_past + N, N, 12]
+        struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);
+
+        // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+        // [n_past + N, 64, 12]
+        struct ne_tensor* V_trans =
+            ne_view_3d(ctx0, kv_self.v, N + n_past, head_dim, n_head, n_ctx * ne_element_size(kv_self.v),
+                       n_ctx * ne_element_size(kv_self.v) * head_dim, il * n_ctx * ne_element_size(kv_self.v) * n_embd);
+
+        // KQV = transpose(V) * KQ_soft_max
+        // [64, N, 12]
+        struct ne_tensor* KQV = ne_mul_mat(ctx0, V_trans, KQ_soft_max);
+
+        // KQV_merged = KQV.permute(0, 2, 1, 3)
+        // [64, 12, N]
+        struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);
+
+        // cur = KQV_merged.contiguous().view(n_embd, N)
+        // [768, N]
+        cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC));
+      } else {
+        const auto seq_kv = n_past + N;
+        const auto k_size = kv_cache_info.k_bytes;
+        const auto v_size = kv_cache_info.v_bytes;
+        // store key and value to memory
+        {
+          const auto k_cache = ne_view_3d(ctx0, kv_self.k,          // tensor
+                                          head_dim, n_ctx, n_head,  // ne
+                                          0, 0,                     // nb (jblas managed)
+                                          il * k_size);             // offset
+          ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
+          const auto v_cache = ne_view_3d(ctx0, kv_self.v,          // tensor
+                                          head_dim, n_ctx, n_head,  // ne
+                                          0, 0,                     // nb (jblas managed)
+                                          il * v_size);             // offset
+          ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past));
+        }
+
+        struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
+        ne_set_name(Q, "Q");
+
+        struct ne_tensor* K =
+            ne_view_3d(ctx0, kv_self.k,                                             // tensor
+                       head_dim, seq_kv, n_head,                                    // ne
+                       kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,  // nb (jblas managed)
+                       il * k_size);                                                // offset
+        *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout;    // use nb0 for layout
+        ne_set_name(K, "K");
+        struct ne_tensor* V =
+            ne_view_3d(ctx0, kv_self.v,                                                    // tensor
+                       seq_kv, head_dim, n_head,                                           // ne
+                       kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,  // nb (jblas managed)
+                       il * v_size);                                                       // offset
+        *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout;           // use nb0 for layout
+        ne_set_name(V, "V");
+
+        ne_attn_flags_t attn_flags = 0;
+        if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL;  // no causal mask on next-token cases
+        struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
+        cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
+      }
-
-      // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
-      // [64, N, 12]
-      struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
-
-      // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
-      // [64, n_past + N, 12]
-      struct ne_tensor* K = ne_view_3d(
-          ctx0, kv_self.k, n_embd / n_head, N + n_past, n_head, ne_element_size(kv_self.k) * n_embd / n_head,
-          ne_element_size(kv_self.k) * n_embd / n_head * n_ctx, il * n_ctx * ne_element_size(kv_self.k) * n_embd);
-
-      // GG: flash attention
-      // struct ne_tensor * V =
-      //     ne_cpy(ctx0,
-      //            ne_permute(ctx0,
-      //                       ne_reshape_3d(ctx0,
-      //                                     ne_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd,
-      //                                                il*n_ctx*ne_element_size(kv_self.v)*n_embd), n_embd/n_head, n_head, n_past + N),
-      //                       1, 2, 0, 3),
-      //            ne_new_tensor_3d(ctx0, NE_TYPE_F32, n_past + N, n_embd/n_head, n_head, NE_SIZE_CALC));
-
-      // struct ne_tensor * KQV = ne_flash_attn(ctx0, Q, K, V, NE_ATTN_FLAG_IS_CAUSAL);
-
-      // K * Q
-      // [n_past + N, N, 12]
-      struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);  // TODO: check if it broadcasts
-
-      // KQ_scaled = KQ / sqrt(n_embd/n_head)
-      // [n_past + N, N, 12]
-      struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));
-
-      // KQ_masked = mask_past(KQ_scaled)
-      // [n_past + N, N, 12]
-      struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-
-      // KQ = soft_max(KQ_masked)
-      // [n_past + N, N, 12]
-      struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);
-
-      // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-      // [n_past + N, 64, 12]
-      struct ne_tensor* V_trans = ne_view_3d(
-          ctx0, kv_self.v, N + n_past, n_embd / n_head, n_head, n_ctx * ne_element_size(kv_self.v),
-          n_ctx * ne_element_size(kv_self.v) * n_embd / n_head, il * n_ctx * ne_element_size(kv_self.v) * n_embd);
-
-      // KQV = transpose(V) * KQ_soft_max
-      // [64, N, 12]
-      struct ne_tensor* KQV = ne_mul_mat(ctx0, V_trans, KQ_soft_max);
-
-      // KQV_merged = KQV.permute(0, 2, 1, 3)
-      // [64, 12, N]
-      struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);
-
-      // cur = KQV_merged.contiguous().view(n_embd, N)
-      // [768, N]
-      cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N, NE_SIZE_CALC));
     }
     // projection
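
In the starcoder.cpp hunk above, `Qcur`, `Kcur`, and `Vcur` become plain strided views into the fused QKV projection: each token's row holds `3 * n_embd` floats laid out as [Q | K | V], with row stride `fused_qkv_row_nb` and byte offsets `0`, `n_embd`, and `2 * n_embd`, and any permutation is left to the branch that needs it. Below is a small self-contained sketch of that indexing; the type and member names are illustrative and not part of the runtime.

// Fused-QKV view arithmetic sketch (illustrative only): element x of head h for token t of, e.g., K
// lives at float offset t * (3 * n_embd) + 1 * n_embd + h * head_dim + x, which is what the
// ne_view_3d calls above express with fused_qkv_row_nb and the 0 / n_embd / 2 * n_embd byte offsets.
#include <cstddef>

struct QkvView {
  const float* base;  // fused projection output: N rows of 3 * n_embd floats, laid out [Q | K | V]
  int n_embd, n_head, head_dim;

  float at(int which,  // 0 = Q, 1 = K, 2 = V
           int token, int head, int x) const {
    const std::size_t row = static_cast<std::size_t>(token) * 3 * n_embd;
    return base[row + static_cast<std::size_t>(which) * n_embd +
                static_cast<std::size_t>(head) * head_dim + x];
  }
};
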
diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder_utils.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder_utils.cpp
index 2e21b75cc1a..de50de99e98 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder_utils.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/models/starcoder/starcoder_utils.cpp
@@ -45,6 +45,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
   std::unique_ptr<STARCODER> ms(new STARCODER());
   ms->init(fname.c_str(), lctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
   ms->load(lctx, progress_callback, progress_callback_user_data);
+  lctx.support_jblas_kv = true;
 }
 
 void STARCODER::init(const char* path_model, model_context& lctx, int n_gpu_layer_, bool use_mmap_, bool use_mlock_,