Add attention and final logit soft-capping, update scaling factor to Gemma2 #8197

Merged: 10 commits, Jun 30, 2024
6 changes: 6 additions & 0 deletions convert-hf-to-gguf.py
@@ -2363,6 +2363,12 @@ def set_gguf_parameters(self):
self.gguf_writer.add_key_length(hparams["head_dim"])
self.gguf_writer.add_value_length(hparams["head_dim"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_attn_logit_softcapping(
self.hparams["attn_logit_softcapping"]
)
self.gguf_writer.add_final_logit_softcapping(
self.hparams["final_logit_softcapping"]
)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid  # unused
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
@@ -50,6 +50,8 @@ class LLM:
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"

class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
6 changes: 6 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -516,6 +516,12 @@ def add_clamp_kqv(self, value: float) -> None:
def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

def add_attn_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

def add_final_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

def add_expert_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

33 changes: 31 additions & 2 deletions src/llama.cpp
@@ -302,6 +302,8 @@ enum llm_kv {
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,

LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },

{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2099,6 +2103,9 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;

float f_attn_logit_softcapping;
float f_final_logit_softcapping;

float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
@@ -2115,8 +2122,9 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;

bool causal_attn = true;
bool use_alibi = false;
bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;

enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
case LLM_ARCH_GEMMA2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping);
hparams.attn_soft_cap = true;

switch (hparams.n_layer) {
case 42: model.type = e_model::MODEL_9B; break;
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
kq = ggml_scale(ctx, kq, 30);
}

if (hparams.attn_soft_cap) {
kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
kq = ggml_tanh(ctx, kq);
kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
}

kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);

@@ -11106,6 +11123,12 @@ struct llm_build_context {

// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);

// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
@eran-medan commented on Jul 2, 2024

Total nitpick that probably should be ignored. I came here out of curiosity, and I know this is merged by now and I have absolutely no place to comment. But isn't this the same logic as lines 7594–7596?

While I'm a proponent of the "rule of 3", I think there's merit in extracting it into something like a separate apply_softcap method, for educational purposes at least: it gives the opportunity to add docs explaining what it does (single responsibility principle and all that), and I know for sure that if I had to fix a bug in it, I'd fix it in one place and forget to update the other.

cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);

cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);
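Following up on the review comment above: one way to address it would be a small helper that folds both tanh-based blocks into one place. The sketch below is not code from this PR; the name llm_build_softcap is hypothetical, and it only reuses the ggml_scale / ggml_tanh calls already present in the diff.

// Hypothetical helper (not in this PR): soft-cap a tensor to (-cap, +cap)
// by computing cap * tanh(x / cap) with the same ggml ops used above.
static struct ggml_tensor * llm_build_softcap(struct ggml_context * ctx, struct ggml_tensor * cur, float cap) {
    cur = ggml_scale(ctx, cur, 1.0f / cap);  // x / cap
    cur = ggml_tanh (ctx, cur);              // tanh(x / cap)
    cur = ggml_scale(ctx, cur, cap);         // cap * tanh(x / cap)
    return cur;
}

// Possible call sites, mirroring the two blocks in this diff:
//   kq  = llm_build_softcap(ctx,  kq,  hparams.f_attn_logit_softcapping);
//   cur = llm_build_softcap(ctx0, cur, hparams.f_final_logit_softcapping);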
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}

if (params.flash_attn && model->hparams.attn_soft_cap) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
params.flash_attn = false;
}


if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
params.flash_attn = false;