
Add attention and final logit soft-capping, update scaling factor to Gemma2 #8197

Merged: 10 commits into master, Jun 30, 2024

Conversation

abetlen
Collaborator

@abetlen abetlen commented Jun 28, 2024

This PR adds the missing attention-layer and final-logit soft-capping. The implementation is referenced from huggingface transformers. Additionally, Gemma2 applies a pre-attention scaling of hidden_size / num_attention_heads.
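For context, soft-capping squashes logits through a tanh so large values saturate at a fixed cap instead of growing without bound: capped = cap * tanh(x / cap). Below is a minimal standalone sketch (plain C++, not code from this PR); the cap values 50.0 and 30.0 follow the Gemma2 reference config and are used here only for illustration.

#include <cmath>
#include <cstdio>

// softcap(x) = cap * tanh(x / cap): bounded in (-cap, cap), roughly linear for small x
static float softcap(float x, float cap) {
    return cap * std::tanh(x / cap);
}

int main() {
    const float attn_cap  = 50.0f;  // applied to attention scores before softmax
    const float final_cap = 30.0f;  // applied to the final lm_head logits

    for (float x : {1.0f, 10.0f, 100.0f, 1000.0f}) {
        // large values saturate toward the cap instead of growing unbounded
        std::printf("x=%8.1f  attn=%8.4f  final=%8.4f\n",
                    x, softcap(x, attn_cap), softcap(x, final_cap));
    }
    return 0;
}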

NOTE: attention soft-capping is not compatible with flash attention, so flash attention is disabled when loading the model.

Once this PR is finalised / merged, the gguf will need to be regenerated to include the soft-capping scales.

@slaren let me know if any kv names / hparams should be changed or if anything stands out to you.

  • Self-reported review complexity:
    • Low
    • Medium
    • High

@github-actions github-actions bot added the python python script changes label Jun 28, 2024
@N8python

This is absolutely awesome!

@N8python

Will this require updating the commit of llama.cpp, or just the gguf?

@slaren
Collaborator

slaren commented Jun 28, 2024

Perplexity is looking much better now with the 9b base model.

[1]9.2071,[2]9.7469,[3]8.5964,[4]8.9120,[5]9.0974,[6]9.5314,[7]9.5795,[8]9.9935,[9]10.8124,[10]11.5707,[11]11.2326,[12]11.6694,[13]12.1546,[14]11.2626,[15]10.7205,[16]10.8900,[17]10.2102,[18]10.4680,[19]10.3165,[20]10.1336,[21]9.9171,[22]9.8545,[23]9.3609,[24]8.9963,[25]8.7959,[26]8.4272,[27]8.2837,[28]8.2287,[29]8.1524,[30]8.1909,[31]8.1299,[32]8.1604,[33]8.2355,[34]8.3548,[35]8.3592,[36]8.4679,[37]8.5151,[38]8.4315,[39]8.4928,[40]8.5288,[41]8.4178,[42]8.4029,[43]8.4344,[44]8.3349,[45]8.2776,[46]8.2448,[47]8.3508,[48]8.3703,[49]8.3698,[50]8.4112,[51]8.3984,[52]8.4121,[53]8.4798,[54]8.4386,[55]8.5013,[56]8.4724,[57]8.4617,[58]8.5479,[59]8.5988,[60]8.6415,[61]8.6199,[62]8.6748,[63]8.7169,[64]8.7935,[65]8.8943,[66]8.9945,[67]8.9497,[68]8.9609,[69]8.9441,[70]8.9776,[71]9.0380,[72]9.0679,[73]9.0876,[74]9.0596,[75]9.0648,[76]9.0820,[77]9.1173,[78]9.0350,[79]9.0445,[80]8.9748,[81]9.0271,[82]9.0024,[83]9.0165,[84]9.0856,[85]9.1867,[86]9.2471,[87]9.2703,[88]9.2412,[89]9.1967,[90]9.2079,[91]9.1894,[92]9.2655,[93]9.2878,[94]9.3000,[95]9.3255,[96]9.3562,[97]9.3499,[98]9.3618,[99]9.4381,[100]9.4716,
Final estimate: PPL = 9.4716 +/- 0.18491

9b-it also improved significantly:

[1]12.1578,[2]12.9578,[3]11.0636,[4]11.1662,[5]11.5421,[6]11.9280,[7]12.0649,[8]12.6567,[9]13.6820,[10]14.8274,[11]14.4769,[12]15.0507,[13]15.7453,[14]14.5753,[15]13.8343,[16]13.9240,[17]12.9838,[18]13.3013,[19]13.0462,[20]12.9001,[21]12.6406,[22]12.5324,[23]11.8712,[24]11.3990,[25]11.1628,[26]10.6578,[27]10.4992,[28]10.4211,[29]10.2483,[30]10.3309,[31]10.2608,[32]10.2979,[33]10.3751,[34]10.5412,[35]10.5460,[36]10.7030,[37]10.7956,[38]10.6544,[39]10.6946,[40]10.7043,[41]10.5450,[42]10.5358,[43]10.5697,[44]10.4340,[45]10.3420,[46]10.2950,[47]10.4161,[48]10.4460,[49]10.4774,[50]10.5432,[51]10.5373,[52]10.5660,[53]10.6603,[54]10.6175,[55]10.6904,[56]10.6554,[57]10.6355,[58]10.7337,[59]10.7862,[60]10.8272,[61]10.8127,[62]10.8861,[63]10.9615,[64]11.0753,[65]11.2012,[66]11.3123,[67]11.2318,[68]11.2459,[69]11.2298,[70]11.2590,[71]11.3402,[72]11.3831,[73]11.4191,[74]11.3693,[75]11.3663,[76]11.3823,[77]11.4085,[78]11.3111,[79]11.3382,[80]11.2589,[81]11.3194,[82]11.2924,[83]11.2912,[84]11.3787,[85]11.5069,[86]11.5906,[87]11.6226,[88]11.5911,[89]11.5328,[90]11.5461,[91]11.5310,[92]11.6363,[93]11.6774,[94]11.7063,[95]11.7455,[96]11.7893,[97]11.7733,[98]11.7826,[99]11.8833,[100]11.9212,
Final estimate: PPL = 11.9212 +/- 0.24186

@slaren
Collaborator

slaren commented Jun 28, 2024

The chat template keeps adding an extra line on each reply; there is probably something wrong there.

@theo77186

Looks really nice. This will require regenerating gguf files, as existing gguf files don't have the required keys. Attempting to load an old gguf file with this PR results in this error:

llama_model_load: error loading model: error loading model hyperparameters: key not found in model: gemma2.attn_logit_softcapping

@abetlen
Collaborator Author

abetlen commented Jun 28, 2024

@slaren thank you for checking!

The chat template keeps adding an extra line on each reply; there is probably something wrong there.

Do you mean like this?

(screenshot of the chat output)

I do see that most of the time, but occasionally it does generate the eot without adding an extra newline at the end.

@slaren
Collaborator

slaren commented Jun 28, 2024

I am using the built-in chat template support, with -cnv. The result is this:

$ ./llama-cli -m models/gemma-2-9b-it/ggml-model-f16.gguf -ngl 99 -cnv 
[...]
== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to the AI.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.


> say 1
1

> say 2
2

> say 3
3


> say 4
4

> say 5
5


> say 6
6

>

Sometimes there are one or two extra lines, which seems odd.

@abetlen
Collaborator Author

abetlen commented Jun 28, 2024

Looks really nice. This will require regenerating gguf files, as existing gguf files don't have the required keys. Attempting to load an old gguf file with this PR results in this error:

llama_model_load: error loading model: error loading model hyperparameters: key not found in model: gemma2.attn_logit_softcapping

That's when you're loading an older gemma2 gguf, correct? Yes, those will need to be regenerated.

@ngxson
Collaborator

ngxson commented Jun 28, 2024

@slaren Thanks for spotting that error. I made a typo in the code; it will be fixed by #8198

@ngxson
Collaborator

ngxson commented Jun 28, 2024

Re comment: #8198 (comment)

I made a dirty patch to see whether that's the root cause, but it seems like the model still wants to output multiple new lines:

> say hi
Hi there! 👋

> say 123
123 😊  

> say sunflower in emoji
🌻  





> 
'':108, '<start_of_turn>':106, 'user':1645, '':108, 'say':23523, ' sunflower':74207, ' in':575, ' emoji':52810, '<end_of_turn>':107, '':108, '<start_of_turn>':106, 'model':2516, '':108, '':241549, '  ':139, '':111, '<end_of_turn>':107

Explanation: token 108 (\n) is added before <start_of_turn> for user (in other words, after <end_of_turn> of the last turn).

But for some reason, it still outputs token 111 (\n\n\n\n), so I assume it has something to do with the model itself. Please note that I'm using -t 0 for all of my tests.


Edit: this is the 3rd message in the conversation. I did not add \n for the first message, which should be correct.
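For readers following along, this is roughly the turn layout implied by the token dump above (106 = <start_of_turn>, 107 = <end_of_turn>, 108 = \n). It is a hypothetical illustrative sketch, not the actual template code in llama.cpp:

#include <cstdio>
#include <string>

// One Gemma2-style turn: <start_of_turn>{role}\n{content}<end_of_turn>\n
static std::string format_turn(const std::string & role, const std::string & content) {
    return "<start_of_turn>" + role + "\n" + content + "<end_of_turn>\n";
}

int main() {
    std::string prompt;
    prompt += format_turn("user", "say sunflower in emoji");
    // generation prompt: the model is expected to emit its own <end_of_turn>,
    // which is where the extra trailing newlines discussed above show up
    prompt += "<start_of_turn>model\n";
    std::printf("%s", prompt.c_str());
    return 0;
}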

@mofosyne mofosyne added the Review Complexity : Medium Generally require more time to grok but manageable by beginner to medium expertise level label Jun 29, 2024
Owner

@ggerganov ggerganov left a comment


Regarding the new lines - maybe it's best to compare with the reference inference results and see if these are expected

@qnixsynapse
Contributor

qnixsynapse commented Jun 29, 2024

BTW, now I am thinking that this is also probably needed.

From the Gemma2 technical paper:

We alternate between a local sliding window attention (Beltagy et al., 2020a,b) and global attention (Luong et al., 2015) in every other layer. The sliding window size of local attention layers is set to 4096 tokens, while the span of the global attention layers is set to 8192 tokens.

I am not sure if this is implemented or not (I am still weak in C++, but I am trying to find out in the code).

Edit: Looks like HF Transformers has it.

Edit2: We don't even have sliding window attention logic in either the llm_build_kv or llm_build_kv_store functions, nor in the attention implementation for gemma2... Am I missing something?

@abetlen
Collaborator Author

abetlen commented Jun 29, 2024

It seems that there's another difference as well: the attention scaling factor is a custom value of 1 / sqrt(144).

@abetlen abetlen changed the title Add attention and final logit soft-capping to Gemma2 Add attention and final logit soft-capping, custom scaling factor to Gemma2 Jun 29, 2024
@sinand99

I hope this fixes the Gemma and Phi models, which are both broken now. They hit a forever-repeating bug after they exceed their context length. Changing n_predict does not work.

@arlo-phoenix
Contributor

arlo-phoenix commented Jun 29, 2024

Edit2: We don't even have sliding window attention logic in either the llm_build_kv or llm_build_kv_store functions, nor in the attention implementation for gemma2... Am I missing something?

@qnixsynapse No, it's not implemented, but since ggerganov refactored the KQ masking a while ago the change should be simple; the issue for SWA was also just reopened. I already tried implementing it for gemma2, see my comment here

EDIT:

BTW, now I am thinking that this is also probably needed.

From the Gemma2 technical paper:

We alternate between a local sliding window attention (Beltagy et al., 2020a,b) and global attention (Luong et al., 2015) in every other layer. The sliding window size of local attention layers is set to 4096 tokens, while the span of the global attention layers is set to 8192 tokens.

I am not sure if this is implemented or not (I am still weak in C++, but I am trying to find out in the code).

Edit: Looks like HF Transformers has it.

Good catch. Well, that's going to be harder: llama.cpp currently shares the mask between layers. It's not impossible to just add a second one (actually really simple), but I can't think of a clean way to implement this. I'll hack something together and see if it works better than just always doing sliding window.
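For illustration, here is a rough standalone sketch of how the two masks would differ (this is not the patch linked in the next comment, and llama.cpp does not build its masks this way): a plain causal mask for the global-attention layers and a banded mask for the sliding-window layers, with toy sizes standing in for the paper's 4096-token window.

#include <vector>
#include <limits>
#include <cstdio>

using Mask = std::vector<std::vector<float>>;

// window <= 0 builds a plain causal (global) mask; window > 0 additionally
// restricts each query to the most recent `window` positions (sliding window).
static Mask build_mask(int n_tokens, int window) {
    const float NEG_INF = -std::numeric_limits<float>::infinity();
    Mask mask(n_tokens, std::vector<float>(n_tokens, NEG_INF));
    for (int i = 0; i < n_tokens; ++i) {        // query position
        for (int j = 0; j <= i; ++j) {          // causal: only attend to the past
            if (window <= 0 || i - j < window) {
                mask[i][j] = 0.0f;              // allowed; disallowed stays -inf
            }
        }
    }
    return mask;
}

int main() {
    // toy sizes; Gemma2 alternates a 4096-token window with an 8192-token global span
    const int n_tokens = 8;
    Mask global_mask = build_mask(n_tokens, 0);  // global-attention layers
    Mask local_mask  = build_mask(n_tokens, 4);  // alternating sliding-window layers
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_tokens; ++j) {
            std::printf("%c", local_mask[i][j] == 0.0f ? 'x' : '.');
        }
        std::printf("\n");
    }
    (void) global_mask;
    return 0;
}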

@arlo-phoenix
Contributor

@qnixsynapse hacked it in here: https://github.com/arlo-phoenix/llama.cpp/tree/gemma2 Doesn't need new ggufs to work. Output is better at the start, but then runs into a repeating issue, so something is probably still missing or I messed something up. I'll stop for today though.

@abetlen abetlen changed the title Add attention and final logit soft-capping, custom scaling factor to Gemma2 Add attention and final logit soft-capping, update scaling factor to Gemma2 Jun 30, 2024
@abetlen
Collaborator Author

abetlen commented Jun 30, 2024

I've updated the scaling factor approach to compute the value instead of adding yet another key to the gguf. This value is computed using hidden_size / num_attention_heads as in the original pytorch implementation.
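Concretely, a sketch of that computation (using the published Gemma2-27B config values for illustration, which is where the 1 / sqrt(144) mentioned earlier comes from; in llama.cpp these values come from the model hparams):

#include <cmath>
#include <cstdio>

int main() {
    // Gemma2-27B reference config values (illustrative only)
    const int hidden_size         = 4608;
    const int num_attention_heads = 32;

    // pre-attention (KQ) scale computed instead of stored as yet another gguf key
    const int   query_pre_attn_scalar = hidden_size / num_attention_heads;          // 144
    const float kq_scale              = 1.0f / std::sqrt((float) query_pre_attn_scalar);

    std::printf("kq_scale = 1/sqrt(%d) = %f\n", query_pre_attn_scalar, kq_scale);
    return 0;
}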

@qnixsynapse
Contributor

hacked it in here: https://github.com/arlo-phoenix/llama.cpp/tree/gemma2 Doesn't need new ggufs to work. Output is better at the start, but then runs into a repeating issue, so something is probably still missing or I messed something up. I'll stop for today though.

@arlo-phoenix Good work!

I will take a look at it today (Sunday).

Also, just to mention, I still find HF Transformers' Gemma 2 implementation somewhat subpar, since the model hosted on HF chat still gives much lower quality responses than the model hosted on Google AI Studio. I'm not sure whether the transformers library used there is updated with the latest fixes.
So we cannot blindly depend on the HF transformers library's implementation for now, imo.

@abetlen abetlen merged commit 1c5eba6 into master Jun 30, 2024
56 checks passed
@ddh0
Contributor

ddh0 commented Jun 30, 2024

Now that this is merged, if I make a new GGUF, am I good to go? Or are there other fixes I should wait for?

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jun 30, 2024
…tor to Gemma2 (ggerganov#8197)

* Add attention and final logit softcapping.

* fix

* Add custom add_ functions

* Disable flash attention for Gemma2

* Update src/llama.cpp

Co-authored-by: slaren <slarengh@gmail.com>

* Add default value for attention and final logit softcap value

* Add custom kq scaling from Gemma2Attention

* Remove custom pre attention scaling and use computed value instead.

---------

Co-authored-by: slaren <slarengh@gmail.com>
@qnixsynapse
Contributor

qnixsynapse commented Jun 30, 2024

@ddh0 Currently, @arlo-phoenix is using a hardcoded window size for testing. I think it's best to regenerate ggufs once this (alternate SWA) is fixed.

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jun 30, 2024
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jun 30, 2024
@ngxson ngxson mentioned this pull request Jun 30, 2024
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jun 30, 2024
MagnusS0 pushed a commit to MagnusS0/llama.cpp-normistral-tokenizer that referenced this pull request Jul 1, 2024
Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jul 1, 2024
jart pushed a commit to Mozilla-Ocho/llamafile that referenced this pull request Jul 1, 2024
@@ -11106,6 +11123,12 @@ struct llm_build_context {

// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);

// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
@eran-medan eran-medan Jul 2, 2024

Total nitpick that probably should be ignored. I came here out of curiosity, and I know this is merged by now and I have absolutely no place to comment. But isn't this similar logic to lines 7594 - 7596?

While I'm a proponent of the "rule of 3", I think there's merit in extracting it into something like a separate apply_softcap method, for educational purposes at least: it gives the opportunity to add docs explaining what it does (single responsibility principle and all that), and I know for sure that if I had to fix a bug in it, I'd fix it in one place and forget to update the other.
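Something along those lines might look like the sketch below. This is a hypothetical helper, not code from this PR; apply_softcap is simply the name suggested above, and it only composes ggml ops that already exist (ggml_scale, ggml_tanh).

#include "ggml.h"

// Hypothetical helper: softcap(x) = cap * tanh(x / cap), built from existing ggml ops.
static struct ggml_tensor * apply_softcap(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,
        float                 softcap) {
    cur = ggml_scale(ctx, cur, 1.0f / softcap); // x / cap
    cur = ggml_tanh (ctx, cur);                 // tanh(x / cap)
    cur = ggml_scale(ctx, cur, softcap);        // cap * tanh(x / cap)
    return cur;
}

// usage would then be, e.g.:
//   cur = apply_softcap(ctx0, cur, hparams.f_final_logit_softcapping);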

Labels: python (python script changes), Review Complexity : Medium (generally requires more time to grok but manageable by beginner to medium expertise level)