This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[LLM Runtime] enable MHA fusion for gptneox&dolly&starcoder&llama2-70b #567

Merged: 17 commits, Nov 1, 2023
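
For orientation before the diff: the fused multi-head attention (MHA) path added here replaces the explicit attention graph (K * Q, scale by 1/sqrt(head_dim), causal mask, softmax, multiply by V) with a single ne_flash_attn call over a jblas-managed KV cache. The sketch below is a minimal plain-C++ reference of that same scaled dot-product attention for one head, useful as a mental model of what the fused kernel is expected to compute; every name in it is illustrative and none of it is runtime code.

#include <algorithm>
#include <cmath>
#include <vector>

// Q: [sl_q, d], K: [sl_kv, d], V: [sl_kv, d], output: [sl_q, d], all row-major.
// `n_past` is the number of tokens already in the KV cache, as in the diff.
std::vector<float> sdpa_reference(const std::vector<float>& Q, const std::vector<float>& K,
                                  const std::vector<float>& V, int sl_q, int sl_kv, int d,
                                  bool causal, int n_past) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));  // same role as attn_scale in the diff
  std::vector<float> out(sl_q * d, 0.0f);
  std::vector<float> row(sl_kv);
  for (int i = 0; i < sl_q; ++i) {
    float max_s = -INFINITY;
    for (int j = 0; j < sl_kv; ++j) {
      float s = 0.0f;
      for (int k = 0; k < d; ++k) s += Q[i * d + k] * K[j * d + k];  // dot(Q_i, K_j)
      s *= scale;
      if (causal && j > n_past + i) s = -INFINITY;  // query i sits at absolute position n_past + i
      row[j] = s;
      max_s = std::max(max_s, s);
    }
    float denom = 0.0f;
    for (int j = 0; j < sl_kv; ++j) {  // numerically stable softmax
      row[j] = std::exp(row[j] - max_s);
      denom += row[j];
    }
    for (int j = 0; j < sl_kv; ++j) {
      const float p = row[j] / denom;
      for (int k = 0; k < d; ++k) out[i * d + k] += p * V[j * d + k];  // weighted sum of values
    }
  }
  return out;
}

With n_past == 0 and sl_q == sl_kv == N this is the prompt-processing case where the diff sets NE_ATTN_FLAG_IS_CAUSAL; for a next-token step (N == 1) the causal mask would mask nothing, which is why the diff leaves the flag unset there.
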
@@ -113,6 +113,28 @@ static bool gptneox_model_eval_internal(model_context& lctx, const model_token*
ne_cgraph gf = {};
gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;

const bool run_mha_reordered = kv_self.k->type == NE_TYPE_JBLAS;
kv_cache_info_t kv_cache_info = {};
if (run_mha_reordered) {
NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_JBLAS));
attn_shape_t attn_shape = {
/* .batch_size = */ 1,
/* .head_num = */ n_head,
/* .heads_kv = */ n_head,
/* .head_size = */ head_dim,
/* .sl_q = */ N, // Note: make sure that jblas reordered attn supports next token inference
/* .sl_kv = */ n_past + N,
};

NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
jblas_reordered_attn_fp32_support(&attn_shape)));
kv_shape_t kv_shape{
/* .heads_kv = */ static_cast<uint32_t>(n_head),
/* .head_size = */ static_cast<uint32_t>(head_dim),
/* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
};
jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
}
struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size);
ne_set_name(embd, "embd");
for (int i = 0; i < batch_size; ++i) {
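
A note on the hunk above: run_mha_reordered is keyed off the KV-cache dtype (NE_TYPE_JBLAS) because by this point the cache has already been allocated in the jblas layout, so an unsupported attention shape is treated as a hard NE_ASSERT rather than a silent fallback. A build that decides before allocating the cache could fold the support check into the dispatch itself; the sketch below reuses only names visible in this hunk and is an assumption about such a fallback variant, not code from the PR.

// Hedged sketch (assumption, not PR code): choose the fused path only when both
// the cache layout and the kernel's shape support line up, otherwise keep the
// plain f16/f32 KV-cache path. attn_shape / kv_shape are filled as in the hunk above.
const bool run_mha_reordered =
    kv_self.k->type == NE_TYPE_JBLAS &&              // cache allocated in the jblas layout
    jblas_reordered_attn_fp32_support(&attn_shape);  // kernel supports this head/seq configuration
kv_cache_info_t kv_cache_info = {};
if (run_mha_reordered) {
  // reports per-layer k_bytes / v_bytes plus the strides and layout tags
  // consumed by the fused branch further down
  jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
}
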
@@ -151,78 +173,120 @@ static bool gptneox_model_eval_internal(model_context& lctx, const model_token*
// using mode = 2 for GPT-NeoX mode
Qcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0);
Kcur = ne_rope_inplace(ctx0, ne_reshape_4d(ctx0, Kcur, head_dim, n_head, N, batch_size), n_past, n_rot, 2, 0);

const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
// store key and value to memory
{
std::vector<ne_tensor*> Kcur_bs(batch_size);
std::vector<ne_tensor*> Vcur_bs(batch_size);
std::vector<ne_tensor*> k_bs(batch_size);
std::vector<ne_tensor*> v_bs(batch_size);
for (int i = 0; i < batch_size; ++i) {
// batch K
Kcur_bs[i] = ne_permute(ctx0,
ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
i * ne_element_size(Kcur) * n_embd * N),
0, 2, 1, 3);
k_bs[i] = ne_view_4d(
ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));

// batch V
Vcur_bs[i] = ne_permute(ctx0,
ne_reshape_4d(ctx0,
ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
i * ne_element_size(Vcur) * n_embd * N),
head_dim, n_head, N, 1),
1, 2, 0, 3);
v_bs[i] =
ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
if (!run_mha_reordered) {
{
std::vector<ne_tensor*> Kcur_bs(batch_size);
std::vector<ne_tensor*> Vcur_bs(batch_size);
std::vector<ne_tensor*> k_bs(batch_size);
std::vector<ne_tensor*> v_bs(batch_size);
for (int i = 0; i < batch_size; ++i) {
// batch K
Kcur_bs[i] = ne_permute(ctx0,
ne_view_4d(ctx0, Kcur, head_dim, n_head, N, 1, ne_element_size(Kcur) * head_dim,
ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
i * ne_element_size(Kcur) * n_embd * N),
0, 2, 1, 3);
k_bs[i] = ne_view_4d(
ctx0, kv_self.k, head_dim, N, n_head, 1, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));

// batch V
Vcur_bs[i] = ne_permute(ctx0,
ne_reshape_4d(ctx0,
ne_view_2d(ctx0, Vcur, n_embd, N, ne_element_size(Vcur) * n_embd,
i * ne_element_size(Vcur) * n_embd * N),
head_dim, n_head, N, 1),
1, 2, 0, 3);
v_bs[i] =
ne_view_4d(ctx0, kv_self.v, N, head_dim, n_head, 1, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
((il * n_ctx) * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block +
i * n_ctx * n_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
}
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);

// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ne_tensor* K =
ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);

// K * Q
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));

// KQ_masked = mask_past(KQ_scaled)
struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);

// KQ = soft_max(KQ_masked)
struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);

// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ne_tensor* V =
ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);

// KQV = transpose(V) * KQ_soft_max
struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);

// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));
} else {
const auto seq_kv = n_past + N;
const auto k_size = kv_cache_info.k_bytes;
const auto v_size = kv_cache_info.v_bytes;

// store key and value to memory
{
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * k_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_dim, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * v_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past));
}
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);

// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ne_tensor* K =
ne_view_4d(ctx0, kv_self.k, head_dim, n_past + N, n_head, batch_size, ne_element_size(kv_self.k) * head_dim,
ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
il * n_ctx * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block);

// K * Q
struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);

// KQ_scaled = KQ / sqrt(n_embd/n_head)
struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(float(n_embd) / n_head)));

// KQ_masked = mask_past(KQ_scaled)
struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);

// KQ = soft_max(KQ_masked)
struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked);

// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ne_tensor* V =
ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head, batch_size, n_ctx * ne_element_size(kv_self.v),
n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_embd,
il * n_ctx * ne_element_size(kv_self.v) * n_embd * kv_n_ctx_block);

// KQV = transpose(V) * KQ_soft_max
struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);

// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);

// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ne_cpy(ctx0, KQV_merged, ne_new_tensor_2d(ctx0, NE_TYPE_F32, n_embd, N * batch_size, NE_SIZE_CALC));

struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
ne_set_name(Q, "Q");

struct ne_tensor* K =
ne_view_3d(ctx0, kv_self.k, // tensor
head_dim, seq_kv, n_head, // ne
kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (jblas managed)
il * k_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout
ne_set_name(K, "K");
struct ne_tensor* V =
ne_view_3d(ctx0, kv_self.v, // tensor
seq_kv, head_dim, n_head, // ne
kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (jblas managed)
il * v_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // use nb0 for layout
ne_set_name(V, "V");

ne_attn_flags_t attn_flags = 0;
if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
}
// projection
{
cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
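
Two addressing schemes coexist in the hunk above: the plain (non-reordered) branch computes every K/V-cache view by hand from head_dim, n_ctx, n_embd, kv_n_ctx_block and n_past, while the reordered branch only needs il * kv_cache_info.k_bytes (or v_bytes) because the jblas kernel owns the inner layout. The helper below just factors out the plain-path K offset arithmetic used above; the function itself is illustrative and not part of the runtime.

#include <cstddef>

// Byte offset of the K-cache region written for layer `il`, batch slot `i`,
// once `n_past` tokens are already cached; it is the same expression that
// appears inline in the non-reordered branch above.
static size_t k_cache_offset_bytes(int il, int i, int n_past, int n_ctx, int n_embd,
                                   int head_dim, int kv_n_ctx_block, size_t elem_size) {
  const size_t layers  = static_cast<size_t>(il) * n_ctx * n_embd * kv_n_ctx_block;  // skip earlier layers
  const size_t batches = static_cast<size_t>(i) * n_ctx * n_embd;                    // skip earlier batch slots
  const size_t past    = static_cast<size_t>(head_dim) * n_past;                     // skip tokens already cached
  return (layers + batches + past) * elem_size;
}

In the reordered branch the same bookkeeping collapses to il * k_size and il * v_size, with per-token placement delegated to ne_flash_attn_update_k and ne_flash_attn_update_v.
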
@@ -285,6 +285,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
std::unique_ptr<GPTNEOX> ms(new GPTNEOX());
ms->init(fname.c_str(), lctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
ms->load(lctx, progress_callback, progress_callback_user_data);
lctx.support_jblas_kv = true;
if (lctx.beam_search) {
lctx.bs_kv_reorder = std::make_shared<gptneox_beam_search_kv_cache_reorder>(&lctx);
#ifdef NE_BEAM_SEARCH_VERBOSE_ON
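
The single added line above (lctx.support_jblas_kv = true) marks GPT-NeoX as eligible for the jblas-managed KV cache; presumably the other architectures named in the PR title receive the same flag in their own loaders, which are not shown in this excerpt. How the flag is consumed is also outside this diff, so the toy sketch below only illustrates the kind of decision it feeds; every type and name in it is a stand-in defined for the example, not the runtime's API.

#include <cstdio>

// Stand-in types for illustration only; the runtime's own enums and context struct differ.
enum KvType { KV_F32, KV_F16, KV_JBLAS };
struct ToyCtx { bool support_jblas_kv = false; };

// Prefer the jblas cache (and hence the fused MHA path) only when the model
// architecture has been marked as supporting it.
static KvType pick_kv_cache_type(const ToyCtx& lctx, bool prefer_fused_attn, bool memory_f16) {
  if (prefer_fused_attn && lctx.support_jblas_kv) return KV_JBLAS;
  return memory_f16 ? KV_F16 : KV_F32;
}

int main() {
  ToyCtx gptneox;
  gptneox.support_jblas_kv = true;  // what this hunk flips on for GPT-NeoX
  std::printf("kv type: %d\n", pick_kv_cache_type(gptneox, /*prefer_fused_attn=*/true, /*memory_f16=*/true));
  return 0;
}
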
@@ -264,17 +264,18 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to
const auto v_size = kv_cache_info.v_bytes;
// store key and value to memory
{
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_size, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * k_size); // offset
const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
head_size, n_ctx, n_head_kv, // ne
0, 0, // nb (jblas managed)
il * k_size); // offset
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past));
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_size, n_ctx, n_head, // ne
0, 0, // nb (jblas managed)
il * v_size); // offset
const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
head_size, n_ctx, n_head_kv, // ne
0, 0, // nb (jblas managed)
il * v_size); // offset
// jblas always views V as (D, n_head, seq)
const auto Vcur_plain = ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd * N, 0), n_embd / n_head, n_head, N);
const auto Vcur_plain =
ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd_gqa * N, 0), n_embd_gqa / n_head_kv, n_head_kv, N);
ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur_plain, n_past));
}

@@ -283,14 +284,14 @@ static bool llama_model_eval_internal(model_context& lctx, const model_token* to

struct ne_tensor* K =
ne_view_3d(ctx0, kv_self.k, // tensor
head_size, n_cached, n_head, // ne
head_size, n_cached, n_head_kv, // ne
kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (jblas managed)
il * k_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout
ne_set_name(K, "K");
struct ne_tensor* V =
ne_view_3d(ctx0, kv_self.v, // tensor
n_cached, head_size, n_head, // ne
n_cached, head_size, n_head_kv, // ne
kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (jblas managed)
il * v_size); // offset
*reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // use nb0 for layout
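
The llama hunks above replace n_head with n_head_kv (and n_embd with n_embd_gqa) in every KV-cache view and cache update, which is what extends the fused path to grouped-query attention (GQA) models such as Llama-2-70B, where several query heads share one K/V head. The self-contained example below shows the head mapping and the reduced per-token cache width; the concrete numbers are the commonly published Llama-2-70B configuration and are used only as an illustration.

#include <cstdio>

int main() {
  // Published Llama-2-70B attention configuration, used here purely as an example.
  const int n_head = 64;       // query heads
  const int n_head_kv = 8;     // shared K/V heads (GQA)
  const int head_dim = 128;    // per-head size
  const int n_embd = n_head * head_dim;         // 8192: width of Q per token
  const int n_embd_gqa = n_head_kv * head_dim;  // 1024: width actually cached per token for K and V

  const int group = n_head / n_head_kv;  // 8 query heads share each KV head
  for (int q = 0; q < n_head; q += 16) {
    std::printf("query head %2d -> kv head %d\n", q, q / group);
  }
  std::printf("plain KV width per token: %d, GQA KV width per token: %d\n", n_embd, n_embd_gqa);
  return 0;
}

That eight-fold reduction is also why Vcur is re-viewed as n_embd_gqa * N elements (rather than n_embd * N) before ne_flash_attn_update_v stores it into the cache.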