Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[LLM Runtime] make rms_norm_eps and freq_base as parameter (#903)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenwei-intel authored Dec 11, 2023
1 parent 13c0c7f commit 021d66a
Show file tree
Hide file tree
Showing 27 changed files with 136 additions and 80 deletions.
57 changes: 31 additions & 26 deletions intel_extension_for_transformers/llm/runtime/graph/core/ne_layers.c
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,12 @@ void ne_scratch_load(struct ne_context* ctx) { ctx->scratch = ctx->scratch_save;

////////////////////////////////////////////////////////////////////////////////

static void ne_set_op_params(struct ne_tensor* tensor, const void* params, size_t params_size) {
NE_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
// assert(params_size <= NE_MAX_OP_PARAMS);
memcpy(tensor->op_params, params, params_size);
}

struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, int n_dims, const int64_t* ne,
void* data, size_t size) {
// always insert objects at the end of the context's memory pool
Expand Down Expand Up @@ -2065,7 +2071,7 @@ struct ne_tensor* ne_norm_inplace(struct ne_context* ctx, struct ne_tensor* a) {
return ne_norm_impl(ctx, a, true);
}

struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace) {
struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace, float eps) {
bool is_node = false;

if (!inplace && (a->grad)) {
Expand All @@ -2074,20 +2080,21 @@ struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a,

struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a);

ne_set_op_params(result, &eps, sizeof(eps));

result->op = NE_OP_RMS_NORM;
result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL; // TODO: maybe store epsilon here?

return result;
}

struct ne_tensor* ne_rms_norm(struct ne_context* ctx, struct ne_tensor* a) {
return ne_rms_norm_impl(ctx, a, false);
struct ne_tensor* ne_rms_norm(struct ne_context* ctx, struct ne_tensor* a, float eps) {
return ne_rms_norm_impl(ctx, a, false, eps);
}

struct ne_tensor* ne_rms_norm_inplace(struct ne_context* ctx, struct ne_tensor* a) {
return ne_rms_norm_impl(ctx, a, true);
struct ne_tensor* ne_rms_norm_inplace(struct ne_context* ctx, struct ne_tensor* a, float eps) {
return ne_rms_norm_impl(ctx, a, true, eps);
}

struct ne_tensor* ne_rms_norm_back(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b) {
Expand Down Expand Up @@ -2973,7 +2980,7 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*

struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
bool padding_left) {
bool padding_left, float freq_base) {
NE_ASSERT(n_past >= 0 || n_keep >= 0);
NE_ASSERT(padding_left);
bool is_node = false;
Expand Down Expand Up @@ -3013,6 +3020,7 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int

ne_scratch_load(ctx);

ne_set_op_params(result, &freq_base, sizeof(freq_base));
result->op = NE_OP_ROPE;
result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL;
result->src0 = a;
Expand All @@ -3023,18 +3031,18 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
}

struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true);
int prompt_size, float freq_base) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base);
}

struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true);
int prompt_size, float freq_base) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base);
}

struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
int prompt_size, int n_keep, struct ne_tensor* cossin) {
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true);
int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base) {
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base);
}

// ne_rope_back
Expand Down Expand Up @@ -3070,13 +3078,13 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
}

struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, int* n_padding) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true);
int prompt_size, int* n_padding, float freq_base) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base);
}

struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
int mode, int prompt_size, int* n_padding) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true);
int mode, int prompt_size, int* n_padding, float freq_base) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base);
}

// ne_alibi
Expand Down Expand Up @@ -3201,12 +3209,6 @@ struct ne_tensor* ne_conv_1d_2s(struct ne_context* ctx, struct ne_tensor* a, str
return result;
}

static void ne_set_op_params(struct ne_tensor* tensor, const void* params, size_t params_size) {
NE_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
// assert(params_size <= NE_MAX_OP_PARAMS);
memcpy(tensor->op_params, params, params_size);
}

// for ne_conv_1d
static int64_t ne_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
Expand Down Expand Up @@ -6096,7 +6098,8 @@ static void ne_compute_forward_rms_norm_f32(const struct ne_compute_params* para
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];

const float eps = 1e-6f; // TODO: make this a parameter
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
Expand Down Expand Up @@ -7864,7 +7867,8 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
NE_ASSERT(src1->type == NE_TYPE_I32);
NE_ASSERT(ne_nelements(src1) == 5 + bs); // 5 + bs params

static const float freq_base = 10000.0f;
float freq_base = 10000.0f;
memcpy(&freq_base, dst->op_params, sizeof(float));
static const float freq_scale = 1.0f;

const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
Expand Down Expand Up @@ -10011,7 +10015,8 @@ static void ne_compute_backward(struct ne_context* ctx, struct ne_tensor* tensor
const int n_past = ((int32_t*)src1->data)[0];
const int n_dims = ((int32_t*)src1->data)[1];
const int mode = ((int32_t*)src1->data)[2];
src0->grad = ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode, 0), inplace);
src0->grad =
ne_add_impl(ctx, src0->grad, ne_rope(ctx, tensor->grad, n_past, n_dims, mode, 0, 10000.0), inplace);
}
if (src1->grad) {
// noop
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ NE_API struct ne_tensor* ne_silu_back(struct ne_context* ctx, struct ne_tensor*
// TODO: eps is hardcoded to 1e-5 for now
NE_API struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a);

NE_API struct ne_tensor* ne_rms_norm(struct ne_context* ctx, struct ne_tensor* a);
NE_API struct ne_tensor* ne_rms_norm(struct ne_context* ctx, struct ne_tensor* a, float eps);

NE_API struct ne_tensor* ne_rms_norm_inplace(struct ne_context* ctx, struct ne_tensor* a);
NE_API struct ne_tensor* ne_rms_norm_inplace(struct ne_context* ctx, struct ne_tensor* a, float eps);
// a - x
// b - dy
NE_API struct ne_tensor* ne_rms_norm_back(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b);
Expand Down Expand Up @@ -403,27 +403,29 @@ NE_API struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_t
// if mode & 4 == 1, especially for glm
// TODO: avoid creating a new tensor every time
NE_API struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size);
int prompt_size, float freq_base);

// in-place, returns view(a)
NE_API struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size);
int prompt_size, float freq_base);

// shift all tokens by a give p (n_shift)
// Optionally give a 1d tensor of precomputed interleaved cos/sin value of n_shift*scale^k for k \in [0, n_dims)
NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin);
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
float freq_base);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);

NE_API struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
int mode, int prompt_size, int* n_padding);
int mode, int prompt_size, int* n_padding, float freq_base);

// in-place, returns view(a)
NE_API struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past,
int n_dims, int mode, int prompt_size, int* n_padding);
int n_dims, int mode, int prompt_size, int* n_padding,
float freq_base);

// alibi position embedding
// in-place, returns view(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
struct ne_tensor* residual = inpL;

// LayerNorm
cur = ne_rms_norm(ctx0, inpL);
cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
// SelfAttention
{
Expand Down Expand Up @@ -241,7 +241,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
residual = cur;

// post_attention_layernorm
struct ne_tensor* hidden_states = ne_rms_norm(ctx0, cur);
struct ne_tensor* hidden_states = ne_rms_norm(ctx0, cur, hparams.rms_norm_eps);
hidden_states = ne_mul(ctx0, hidden_states, model.layers[il].norm[1]);

// mlp.forward
Expand All @@ -267,7 +267,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
struct ne_tensor* embeddings = NULL;
// norm
{
inpL = ne_rms_norm(ctx0, inpL);
inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
inpL = ne_mul(ctx0, inpL, model.others[1]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
cur->nb[1] * N, 0); // [qlen * bs, 3 * hidden]

ne_set_name(query_layer, "query_layer");
query_layer =
ne_rope_with_padding_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size, n_padding.data());
query_layer = ne_rope_with_padding_inplace(ctx0, query_layer, n_past, rope_dim, 4, first_tokens_size,
n_padding.data(), hparams.freq_base);
query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [bs, heads, qlen, head_size]

ne_tensor* key_layer =
ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size, 3 * head_size * ne_element_size(cur),
cur->nb[1], cur->nb[1] * qlen, head_size * ne_element_size(cur)); // [bs, qlen, heads, head_size]
key_layer =
ne_rope_with_padding_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size, n_padding.data());
key_layer = ne_rope_with_padding_inplace(ctx0, key_layer, n_past, rope_dim, 4, first_tokens_size,
n_padding.data(), hparams.freq_base);

ne_tensor* value_layer = ne_view_4d(ctx0, cur, head_size, num_attention_heads, qlen, batch_size,
3 * head_size * ne_element_size(cur), cur->nb[1], cur->nb[1] * qlen,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
lctx.use_buf(ctx0, 0);

// self-attention
cur = ne_rms_norm(ctx0, inpL);
cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur);
{
// compute QKV
Expand All @@ -146,14 +146,14 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
0); // [N, heads, head_size]
ne_set_name(query_layer, "query_layer");
query_layer = ne_rope_inplace(ctx0, query_layer, std::max(n_cached - N, n_past), n_rot, 0, 0);
query_layer = ne_rope_inplace(ctx0, query_layer, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base);

struct ne_tensor* key_layer =
ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
hidden_size * ne_element_size(cur)); // [N, kv_heads, head_size]
ne_set_name(key_layer, "key_layer");
key_layer = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K
ctx0, key_layer, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0);
ctx0, key_layer, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base);

struct ne_tensor* value_layer =
ne_view_3d(ctx0, cur, head_size, num_kv_heads, N, head_size * ne_element_size(cur), cur->nb[1],
Expand Down Expand Up @@ -198,7 +198,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
// in a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache);
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // perm back
}

Expand Down Expand Up @@ -253,7 +253,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
// in a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache);
key_layer = ne_rope_shift_inplace(ctx0, key_layer, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
}
value_layer =
ne_view_3d(ctx0, model.layers[il].v_cache, // tensor
Expand All @@ -275,7 +275,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
struct ne_tensor* hidden_states = ne_add(ctx0, inpL, cur);

// mlp.forward
struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states);
struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
ne_set_name(mlp_output, "mlp_output");
// mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);
mlp_output = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[1], mlp_output), mlp_output);
Expand All @@ -298,7 +298,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
struct ne_tensor* embeddings = NULL;
// norm
{
inpL = ne_rms_norm(ctx0, inpL);
inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
ne_set_name(inpL, "inpL");
// inpL = ne_mul(ctx0, inpL, model.others[1]);
inpL = ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in
fused_qkv_row_nb, (n_embd + n_head_kv * head_dim) * ne_element_size(cur));

// using mode = 2 for neox mode
Qcur = ne_rope_inplace(ctx0, Qcur, n_past, head_dim, 2, 0);
Kcur = ne_rope_inplace(ctx0, Kcur, n_past, head_dim, 2, 0);
Qcur = ne_rope_inplace(ctx0, Qcur, n_past, head_dim, 2, 0, hparams.freq_base);
Kcur = ne_rope_inplace(ctx0, Kcur, n_past, head_dim, 2, 0, hparams.freq_base);

// self-attention
const float attn_scale = 1.0f / sqrtf(static_cast<float>(head_dim));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
Kcur = ne_reshape_4d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_size, n_head, N, batch_size);
Vcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
}
Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0);
Qcur = ne_rope_inplace(ctx0, Qcur, std::max(n_cached - N, n_past), n_rot, 0, 0, hparams.freq_base);
Kcur = ne_rope_inplace( // n_ctx exceeds but it will be shift-roped back with cached K
ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0);
ctx0, Kcur, (is_ring_full ? n_ctx : n_past), n_rot, 0, 0, hparams.freq_base);
ne_set_name(Qcur, "Qcur");
ne_set_name(Kcur, "Kcur");
ne_set_name(Vcur, "Vcur");
Expand Down Expand Up @@ -293,7 +293,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N
// in a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache);
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
}
const auto v_size = kv_cache_info.v_bytes;
V = ne_view_4d(ctx0, kv_self.v, // tensor
Expand Down Expand Up @@ -321,7 +321,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
// a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache);
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
K = ne_permute(ctx0, K, 0, 2, 1, 3);
}
} else {
Expand All @@ -332,7 +332,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
// Currently we only cache cossin for N == 1 in model-wide; It may be worthwhile to cache cossin for other N in
// a single eval execution
if (N == 1) cossin_cache = kv_self.cossin;
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache);
K = ne_rope_shift_inplace(ctx0, K, -N, n_rot, 0, 0, n_keep, cossin_cache, hparams.freq_base);
K = ne_permute(ctx0, K, 0, 2, 1, 3);
}

Expand Down
Loading

0 comments on commit 021d66a

Please sign in to comment.