Add Llama Flax Implementation #24587
Conversation
Very cool @vvvm23! Scanned through the PR and it looks very nice already - happy to do a full review when it's close to completion. Just drop me a line and I'll have a look! 🚀 Likewise if you have any questions or queries, I'm on hand to help :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Hi @vvvm23 and @sanchit-gandhi, do you have a timeline for this effort? Asking because I would love to import FlaxLlama from Hugging Face, but if it is going to take a while, I will probably build my own pipeline to import the model. Not sure if this helps at all, but here you can find an implementation of Llama in Flax (plus some other library-specific methods that you probably won't need).
Hi @gianlucadetommaso, I haven't had the time to work on this since this draft PR went live, but I am blocking out time this weekend to continue.
Cool to see community interest around running Flax Llama! Feel free to ping me here when you need a review @vvvm23!
Thanks @sanchit-gandhi, I found a little time to continue today. One issue I am noticing is that the tolerance needed when comparing against the ground-truth PyTorch implementation is higher than expected. I think some numerical differences are expected, but I'm not sure to what degree. I am also testing with bfloat16, which may be a factor.
Update: I now have a full model working. I haven't checked whether the pretrained weight loading wrappers (provided by the Flax GPTNeo implementation) work yet, but once they are working, it will be ready for review. I'll simultaneously clean it up and add some missing features whilst it is being reviewed.
Hey! Thanks for the progress update here @vvvm23 and great questions regarding numerical equivalence between models. Generally, for any model less than 1B params, we should be able to get equivalence to within 1e-5 between Flax and PyTorch. It's quite likely that you won't get this equivalence running the matmuls in bfloat16 on TPU, but you should be able to when running the matmuls in float32; see #15754 and jax-ml/jax#10413 (comment) for details.

Here's a script that I used previously for checking PT / Flax equivalence for BLOOM: https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb You can ignore the bits about JIT'ing the forward pass for the time being. You can also uncomment the check to run it on CPU to force the highest precision, or use the decorator as provided.

If we don't get 1e-5 precision, it's usually an indicator that we have a divergence in our model. Here, going through layer-by-layer and checking the hidden states might be required to pinpoint it.
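For anyone following along, here is a minimal sketch of such an equivalence check, assuming the `FlaxLlamaForCausalLM` class being added in this PR and a small placeholder checkpoint; run it with `JAX_PLATFORM_NAME=cpu` to force the CPU backend and full-precision matmuls:

```python
import numpy as np
import torch
from transformers import FlaxLlamaForCausalLM, LlamaForCausalLM

ckpt = "afmck/testing-llama-tiny"  # placeholder: any small Llama checkpoint

pt_model = LlamaForCausalLM.from_pretrained(ckpt)
fx_model = FlaxLlamaForCausalLM.from_pretrained(ckpt, from_pt=True)

input_ids = np.arange(1, 9)[None, :]  # dummy prompt of 8 token ids

with torch.no_grad():
    pt_logits = pt_model(torch.from_numpy(input_ids)).logits.numpy()
fx_logits = np.asarray(fx_model(input_ids).logits)

# for sub-1B models in float32 we expect agreement to within ~1e-5
print("max abs diff:", np.abs(pt_logits - fx_logits).max())
```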
Okay, thanks for the guidance and helper scripts 🔥 I suspected that this lack of precision was not normal 😅 I'll get the pretrained wrappers working first and then focus on debugging the numerical divergence. I'm aiming for the end of this week to fix those numerical issues, but my responsibilities elsewhere are pulling me away a lot, so fingers crossed 🤞
I've begun my hunt for numerical bugs 🐛 The first one I squashed was rather strange: it seems `jax.lax.rsqrt` does not numerically match `torch.rsqrt`, whereas manually computing `1 / jnp.sqrt(x)` does.

So the fix there was just to replace the `jax.lax.rsqrt` call with `1 / jnp.sqrt(x)`. The models still mismatch, so I'll keep digging.
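For reference, a minimal sketch of the workaround inside the RMSNorm forward pass (the function name and signature here are illustrative, not the exact PR code):

```python
import jax.numpy as jnp

def rms_norm(hidden_states, weight, eps=1e-6):
    # compute the variance in float32 for stability
    variance = jnp.power(hidden_states.astype(jnp.float32), 2).mean(-1, keepdims=True)
    # `jax.lax.rsqrt(variance + eps)` gave small per-element differences vs
    # the PyTorch reference; `1 / jnp.sqrt(...)` matches it closely
    return weight * (hidden_states * (1 / jnp.sqrt(variance + eps)))
```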
@sanchit-gandhi The model now numerically matches in fp32 on CPU. The issue was that my backend had changed from CPU to GPU at some point after fixing the `rsqrt` mismatch; forcing the CPU backend makes the test pass, and some mismatch on GPU is expected. What are the next steps to take? I am guessing some model tests, as well as trying it out on a real model checkpoint rather than random weights. However, my dev machine goes OOM when attempting to load the checkpoint on CPU.
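(A quick note for anyone reproducing this: two standard ways to pin JAX to the CPU backend while debugging, so the matmuls stay in full float32.)

```python
import jax

# configure the platform before running any computation...
jax.config.update("jax_platform_name", "cpu")

# ...or set an environment variable before launching the script:
#   JAX_PLATFORM_NAME=cpu python debug_script.py

print(jax.devices())  # should now list CPU devices only
```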
Hey @vvvm23! Excellent work on pinpointing the difference between torch and jax.lax! Adding some tests and updating the docs would be the most sensible next steps. Again, you can refer to the Flax GPT Neo model to see the relevant tests to add: https://github.com/huggingface/transformers/blob/main/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py
That's interesting - are we loading the weights on GPU by accident? There shouldn't be any GPU OOM if running on CPU. We might see our RAM fill up if loading extremely large weights, but the GPU memory shouldn't be affected. What model size are you loading? We can try the smallest 7b checkpoint: https://huggingface.co/meta-llama/Llama-2-7b
Awesome, thanks - tests and docs it is! I am currently on leave, so I won't be progressing on this until the 31st.
Actually, in the end, no. By OOM on my dev machine, I meant out of CPU memory. Switching to a GPU backend meant I could load the model without running out of memory. So, nothing to worry about 😅
Awesome - thanks for the update @vvvm23. Looking forward to doing a full review of the PR on your return!
How's it looking @vvvm23? Let me know if I can help in any way! Otherwise, feel free to ping me here as soon as we're ready for a review; very excited to add this Flax model for the community!
Hi, I've currently been pretty split responsibility-wise (moving house and job!) so have only made a small bit of progress. Most of the tests pass; however, there seems to be some matmul shape mismatch in the remaining failing tests. I'll have some time to work on this Thursday and Friday (10th and 11th), but then probably nothing for another week 🤯 If you are in a rush and fancy trying to get the remaining tests to pass, please try! Sorry for the slowness on my part also!
The final tests ended up being easy to fix: I had simply forgotten to swap the attention mask and position ids in the pretrained model wrapper. @sanchit-gandhi I haven't re-run the slow tests locally (as my laptop is slow), but later today I can run them, then tidy the code a bit. If all goes well, it should be good for review later today or early tomorrow 👍
@sanchit-gandhi all tests pass locally 🎉 I've also run the model using a pretrained checkpoint, and the generations look sensible.
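A rough sketch of that kind of generation sanity check, assuming the new `FlaxLlamaForCausalLM` class; the checkpoint name is a placeholder for whichever Llama weights are to hand:

```python
from transformers import AutoTokenizer, FlaxLlamaForCausalLM

ckpt = "openlm-research/open_llama_3b_v2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = FlaxLlamaForCausalLM.from_pretrained(ckpt, from_pt=True)

inputs = tokenizer("The capital of France is", return_tensors="np")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```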
Seems good to me! I think this is ready for review. I would like to draw your attention to a few points I was unsure about:

Firstly, the model currently throws a missing-parameter warning when loading pretrained checkpoints. This has no effect on the outputs; it is simply that the Flax version of the model does not store the rotary embedding `inv_freq` buffer, so it is flagged when loading weights from the PyTorch checkpoint.

Secondly, please double-check the licensing. I just copied this from the PyTorch version of Llama and updated the year.

Third, I use the checkpoint `afmck/testing-llama-tiny` in the tests and docs; please check whether this is an acceptable choice.

Fourth, as we discussed, a lot of the code is copied from the Flax implementation of GPT-Neo. There may be some leftover parts from there that we don't need in Llama, and I may have missed some best practices for Llama, as GPT-Neo is (relatively) old now. In particular, there is one code block carried over from GPT-Neo that I have flagged in the review comments.
Finally, the tests pass, but please check whether they have sufficient coverage 🤗 Generally speaking, I have a lot of experience writing model code but no experience making large contributions to the Hugging Face ecosystem, so there is almost certainly a lot wrong! Apologies in advance, and I will do my best to help you bring this model to the finish line 💪 Thanks for your work so far!
Looking very nice already @vvvm23 - there's a lot that's right here! Great to hear that the tests are passing and that the generations look accurate. It should be quite fast to see this one to the finish line: just a small refactoring suggestion for the RMS norm layer, and two slow integration tests to check we have numerical correctness! Let's go!
Thanks for your additional comments; I have some time to work on the more involved points today 🤗
@sanchit-gandhi I think everything except the missing weight issue is resolved now (see my comment). Trying to resolve some remaining CI issues, I noticed that the line defining `LLAMA_INPUTS_DOCSTRING` was being overwritten by the file-level `# Copied from` tag.
That's correct behaviour @vvvm23! What we need to do is create the variable `LLAMA_INPUTS_DOCSTRING` ourselves for the Flax model.
Yeah, it is correct behaviour - what I meant though is that I did have a docstring defined already. I get that we need the docstring itself; it's just that currently the CI won't pass with both that docstring and the `# Copied from` tag.
@sanchit-gandhi fixed the CI issue (I think) by just adding more fine-grained `# Copied from` tags.
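For anyone hitting the same thing: the idea is to move from a file-level tag to per-definition tags, so `make fix-copies` only tracks the tagged definitions and leaves hand-written docstring variables alone. An illustrative example of the convention (the source path and rename mapping are placeholders):

```python
import flax.linen as nn

# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoMLP with GPTNeo->Llama
class FlaxLlamaMLP(nn.Module):
    ...
```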
@sanchit-gandhi the CI still fails, for two reasons. Could you assist me with resolving them?

All the tests pass and the model is pretty much ready; I'm just not sure how to get through these last two blockers. Help would be much appreciated!
PRs on the Hub would be better; however, I cannot merge these PRs 😅 We could always create our own internal copy and point the test at that?
I'll open a PR to support revisions anyway, as another contributor is also stuck because of this.
Sorry for the delay, #27645 will help 😉

Nice! Left some comments on that PR.

Merged 😉
The PR documentation test seems to time out. Is this an intermittent or known issue? Or is something wrong with the code 🤔
@ArthurZucker @sanchit-gandhi could you advise on the above? I'm not intimately familiar with how your test workers operate.
Sorry for the delay; either use a smaller checkpoint, or just use a dummy checkpoint, with something like this:

```python
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
    output_type=BaseModelOutputWithPast,
    config_class=_CONFIG_FOR_DOC,
)
```

See here:

```python
_CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
```
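For context, the pattern pairs a tiny dummy checkpoint, which the doc-example tests actually execute, with the real checkpoint rendered for readers. A sketch of what that pairing could look like for Flax Llama (both names are ones discussed in this thread; treat them as placeholders):

```python
# dummy checkpoint: small enough for the doc-example CI job to run quickly
_CHECKPOINT_FOR_DOC = "afmck/testing-llama-tiny"
# real checkpoint: what readers see in the rendered documentation
_REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_llama_3b_v2"
```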
Hi @vvvm23, I believe that in order to pass the tests you should update https://huggingface.co/afmck/testing-llama-tiny/tree/main to include the Flax checkpoints and also the tokenizer. I created https://huggingface.co/ksmcg/Mistral-tiny/tree/main for Mistral and it has passed the tests.
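For reference, one way to add Flax weights and a tokenizer to an existing PyTorch-only repo is to load with `from_pt=True` and push everything back; this is a sketch that assumes write access to the repo, and the tokenizer source below is a placeholder:

```python
from transformers import AutoTokenizer, FlaxLlamaForCausalLM

repo = "afmck/testing-llama-tiny"

# convert the PyTorch weights to Flax, then upload them alongside the originals
model = FlaxLlamaForCausalLM.from_pretrained(repo, from_pt=True)
model.push_to_hub(repo)

# upload tokenizer files too, if the repo doesn't have them yet
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.push_to_hub(repo)
```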
Ah, good catch - I will fix that tomorrow.

Finally, green light 😁

Let's just resolve the conflicts and good to go IMO!

Done!

Thanks for the PR 🤗
* Copies `modeling_flax_gpt_neo.py` to start
* MLP Block. WIP Attention and Block
* Adds Flax implementation of `LlamaMLP`. Validated with in-file test. Some slight numeric differences, but assuming it isn't an issue
* Adds `FlaxLlamaRMSNorm` layer. `flax.linen` includes `RMSNorm` layer but not necessarily in all versions. Hence, we add in-file.
* Adds FlaxLlamaAttention. Copied from GPT-J as it has efficient caching implementation as well as rotary embeddings. Notice numerically different, but not by a huge amount. Needs investigating
* Adds `FlaxLlamaDecoderLayer`. numerically inaccurate, debugging..
* debugging rotary mismatch. gptj uses interleaved whilst llama uses contiguous. i think they match now but still final result is wrong. maybe drop back to just debugging attention layer?
* fixes bug with decoder layer. still somewhat numerically inaccurate, but close enough for now
* adds markers for what to implement next. the structure here diverges a lot from the PT version. not a big fan of it, but just get something working for now
* implements `FlaxLlamaBlockCollection`. tolerance must be higher than expected, kinda disconcerting
* Adds `FlaxLlamaModule`. equivalent PyTorch model is `LlamaModel`. yay! a language model 🤗
* adds `FlaxLlamaForCausalLMModule`. equivalent to `LlamaForCausalLM`. still missing returning dict or tuple, will add later
* start porting pretrained wrappers. realised it probably needs return dict as a prereq
* cleanup, quality, style
* readds `return_dict` and model output named tuples
* (tentatively) pretrained wrappers work 🔥
* fixes numerical mismatch in `FlaxLlamaRMSNorm`. seems `jax.lax.rsqrt` does not match `torch.sqrt`. manually computing `1 / jax.numpy.sqrt` results in matching values.
* [WIP] debugging numerics
* numerical match. I think issue was accidental change of backend. forcing CPU fixes test. We expect some mismatch on GPU.
* adds in model and integration tests for Flax Llama. summary of failing: mul invalid combination of dimensions; one numerical mismatch; bf16 conversion (maybe my local backend issue); params are not FrozenDict
* adds missing TYPE_CHECKING import and `make fixup`
* adds back missing docstrings. needs review on quality of docstrings, not sure what is required. Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO
* commenting out equivalence test as can just use common
* debugging
* Fixes bug where mask and pos_ids were swapped in pretrained models. This results in all tests passing now 🔥
* cleanup of modeling file
* cleanup of test file
* Resolving simpler review comments
* addresses more minor review comments
* fixing introduced pytest errors from review
* wip additional slow tests
* wip tests. need to grab a GPU machine to get real logits for comparison. otherwise, slow tests should be okay
* `make quality`, `make style`
* adds slow integration tests: checking logits, checking hidden states, checking generation outputs
* `make fix-copies`
* fix mangled function following `make fix-copies`
* adds missing type checking imports
* fixes missing parameter checkpoint warning
* more finegrained 'Copied from' tags. avoids issue of overwriting `LLAMA_INPUTS_DOCSTRING`
* swaps import guards. ??? how did these get swapped initially?
* removing `inv_freq` again as pytorch version has now removed
* attempting to get CI to pass
* adds doc entries for llama flax models
* fixes typo in `__init__.py` imports
* adds back special equivalence tests. these come from the gpt neo flax tests. there is special behaviour for these models that needs to override the common version
* overrides tests with dummy to see if CI passes. need to fill in these tests later
* adds my contribution to docs
* `make style; make quality`
* replaces random masking with fixed to work with flax version
* `make quality; make style`
* Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* updates `x`->`tensor` in `rotate_half`
* addresses smaller review comments
* Update docs/source/en/model_doc/llama.md Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* adds integration test class
* adds `dtype` to rotary embedding to cast outputs
* adds type to flax llama rotary layer
* `make style`
* `make fix-copies`
* Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* applies suggestions from review
* Update modeling_flax_llama.py
* `make fix-copies`
* Update tests/models/llama/test_modeling_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* fixes shape mismatch in FlaxLlamaMLP
* applies some suggestions from reviews
* casts attn output logits to f32 regardless of dtype
* adds attn bias using `LlamaConfig.attention_bias`
* adds Copied From comments to Flax Llama test
* mistral and persimmon test change - copy from llama
* updates docs index
* removes Copied from in tests. it was preventing `make fix-copies` from succeeding
* quality and style
* ignores FlaxLlama input docstring
* adds revision to `_CHECKPOINT_FOR_DOC`
* repo consistency and quality
* removes unused import
* removes copied from from Phi test. now diverges from llama tests following FlaxLlama changes
* adds `_REAL_CHECKPOINT_FOR_DOC`
* removes refs from pr tests
* reformat to make ruff happy

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
What does this PR do?
Fixes #26809. This is a work-in-progress port of Llama to Flax, leaving it as a draft PR for now.
The implementation is based heavily on the GPT-Neo and GPT-J Flax implementations.
Currently, the submodules are ready; I just need to assemble them into a full model, check weight loading, add tests, and update the documentation.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. Mentioned in this issue comment.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sanchit-gandhi