Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding nnx Gemma2-2b (including overall fixes) to examples/gemma #4587

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

mdda
Copy link

@mdda mdda commented Feb 28, 2025

Adds Gemma2-2b (including GQA to Attention and fixes to Block)

Fixes #4567

It includes:

  • New TransformerConfig for gemma2-2b model
  • Renames existing configs to make them uniform
    • NB: This cannot affect existing code, since this module was previously unusable
  • Added additional params reading key adjustment (key in gemma2-2b needs remapping from kaggle download)
  • Adds GQA to Attention module
  • Reorders the operations in Block module so that logits output from overall Transformer are not gibberish
    • logits confirmed to (approximately) match those from GDE gemma (flax linen) model
  • No new documentation provided
    • This change would make the example in the nnx documentation work (did not work before)
  • No additional tests provided

@cgarciae
Copy link
Collaborator

cgarciae commented Mar 4, 2025

Hey Martin! Thanks for doing this.
Some folks are internally also improving the model. Let me merge their changes first and make sure there are no conflicts with those changes.

@mdda
Copy link
Author

mdda commented Mar 5, 2025

Ahhh - Brings back memories of PRs for TensorFlow. Good times! /s

I'll attempt fix the white-space check failures, when you let me know that it isn't a waste of my time.

@@ -128,30 +128,31 @@ def gemma_7b(cls):
)

@classmethod
def gemma_27b(cls):
num_layers = 46
def gemma2_2b(cls):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we add a new method for gemma_27b configuration instead of removing the gemma2_2b one?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are both in the file - it's just the git that didn't pick up the diff properly.

But also, it makes sense to rename the different 'generations' of gemma, gemma2, gemma3, etc as separate classes, rather than relying on implicit knowledge about which size came from where.

Moreover, the gemma2_27 didn't have the right normalisation in the attention - will need a separate fix.

I also see that the Google-generated PR adopted the same fix as mine to the Block module. Good for you!

@cgarciae
Copy link
Collaborator

Thanks @mdda for doing this, it took a while because there where other changes to the model in the background that were pending. I think this is great. Can you please add a test to check that the GQA configuration works and matches the base version?

Also now need to solve the conflicts, sorry about this, this code is being used by a couple of users internally.

@mdda
Copy link
Author

mdda commented Mar 15, 2025

Surprised that the code was being used internally prior to my PR, since the Block module was entirely borked.

@cgarciae
Copy link
Collaborator

@mdda please take a look at CI, you probably need to run:

pip install pre-commit
pre-commit run --all-files

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Some surprises when adding Gemma2-2b to flax/examples/gemma
2 participants