Add Llama Flax Implementation #24587

Merged: 91 commits, Dec 7, 2023

Conversation

Contributor

@vvvm23 vvvm23 commented Jun 30, 2023

What does this PR do?

Fixes #26809. This is a work-in-progress port of Llama to Flax, leaving it as a draft PR for now.

The implementation is based heavily off the GPT-Neo and GPT-J Flax implementations.

Currently, the submodules are ready; I just need to assemble them into a full model, check weight loading, add tests, and update the documentation.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi

@sanchit-gandhi
Contributor

Very cool @vvvm23! Scanned through the PR and it looks very nice already - happy to do a full review when it's close to completion. Just drop me a line and I'll have a look! 🚀 Likewise if you have any questions or queries, I'm on hand to help :)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@gianlucadetommaso

Hi @vvvm23 and @sanchit-gandhi, do you guys have a timeline for this effort? Asking because I would love to import FlaxLlama from Hugging Face, but if it is going to take a while, I will probably build my own pipeline to import the model.

Not sure if this helps at all, but here you can find an implementation of Llama in Flax (plus some other library-specific methods that you probably won't need).

@vvvm23
Contributor Author

vvvm23 commented Jul 6, 2023

Hi @gianlucadetommaso, I haven't had the time to work on this since this draft PR went live, but I am blocking time out this weekend to continue.

@sanchit-gandhi
Contributor

Cool to see community interest around running Flax Llama! Feel free to ping me here when you need a review @vvvm23!

@vvvm23
Contributor Author

vvvm23 commented Jul 10, 2023

Thanks @sanchit-gandhi I found a little time to continue today.

One issue I am noticing is that the tolerance when comparing the ground truth PyTorch implementation (in modeling_llama.py) and my own implementation, is a lot higher than I'd like. For three hidden layers in the decoder stack, I have to raise it to atol=1e-2, rtol=1e-2, with one hidden layer being at atol=1e-3, rtol=1e-3 in order to pass. You can see the scratch test I am using at the bottom of modeling_flax_llama.py

I think some numerical differences are expected, but I'm not sure to what degree. I am also testing with float32, which made me even more suspicious. Would you expect the results to be identical? This is my first time porting a PyTorch model to Flax. Thanks~
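
For context on what `atol` and `rtol` mean in a comparison like this: the closeness rule documented for `numpy.allclose` is |a - b| <= atol + rtol * |b|. Below is a minimal, dependency-free sketch of that rule. The helper names and the sample values are made up for illustration and are not taken from the PR's scratch test.

```python
def max_abs_diff(actual, expected):
    """Largest element-wise absolute difference between two flat sequences."""
    return max(abs(a - e) for a, e in zip(actual, expected))


def all_close(actual, expected, rtol=1e-5, atol=1e-5):
    """Element-wise |a - e| <= atol + rtol * |e| (the rule numpy.allclose documents)."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# Hypothetical logits from a reference model and a port (values are invented).
reference = [0.5, -1.25, 2.0]
ported = [0.5004, -1.2497, 2.0006]

print(max_abs_diff(ported, reference))
print(all_close(ported, reference, rtol=1e-3, atol=1e-3))  # loose tolerance
print(all_close(ported, reference, rtol=1e-5, atol=1e-5))  # tight tolerance
```

Printing the maximum absolute difference alongside the pass/fail result tends to be more informative than the boolean alone, since it shows how far a failing comparison is from the target tolerance.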

@vvvm23
Contributor Author

vvvm23 commented Jul 12, 2023

Update: I now have a full model working. I haven't checked if the pretrained weight loading wrappers (provided by the Flax GPTNeo implementation) work yet, but once they do it will be ready for review. I'll simultaneously clean it up and add some missing features whilst it is being reviewed.

@sanchit-gandhi
Contributor

sanchit-gandhi commented Jul 13, 2023

Hey! Thanks for the progress update here @vvvm23 and great questions regarding numerical equivalence between models.

Generally, for any model less than 1B params we should be able to get equivalence to within 1e-5 between Flax and PyTorch. It's quite likely that you won't get this equivalence running the matmuls in bfloat16 on TPU. But you should be able to when running the matmuls in float32; see #15754 and jax-ml/jax#10413 (comment) for details.

Here's a script that I used previously for checking PT / Flax equivalence for BLOOM: https://github.com/sanchit-gandhi/codesnippets/blob/main/check_flax_bloom_jit_small_testing.ipynb You can ignore the bits about JIT'ing the forward pass for the time being. You can also uncomment the check to run it on CPU to force the highest precision, or use the decorator as provided

If we don't get 1e-5 precision, it's usually an indicator that we have a divergence in our model. Here, going through layer-by-layer and checking the hidden-states might be required to pinpoint it
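
The layer-by-layer check described above could look roughly like the following. This is a sketch that assumes the per-layer hidden states of both models have already been gathered and flattened into plain Python lists; `first_divergent_layer` is a hypothetical helper, not part of transformers.

```python
def first_divergent_layer(reference_states, ported_states, atol=1e-5):
    """Return the index of the first layer whose hidden states differ by more
    than `atol`, or None if every layer matches."""
    for index, (ref, port) in enumerate(zip(reference_states, ported_states)):
        worst = max(abs(r - p) for r, p in zip(ref, port))
        if worst > atol:
            return index
    return None


# Invented per-layer hidden states, flattened to plain lists for illustration.
reference_states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
ported_states = [[0.1, 0.2], [0.3, 0.4001], [0.5, 0.7]]

print(first_divergent_layer(reference_states, ported_states))  # 1
```

Stopping at the first mismatch matters because errors compound through the stack: every layer after the first divergent one will usually mismatch too, so only the earliest index points at the likely bug.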

@vvvm23
Contributor Author

vvvm23 commented Jul 13, 2023

Okay, thanks for the guidance and helper scripts 🔥 I expected that this lack of precision was not normal 😅

I'll get the pretrained wrappers working first and then focus on debugging the numerical divergence.

I'm aiming for end of this week to fix those numerical issues, but my responsibilities elsewhere are pulling me a lot, so fingers crossed 🤞

@vvvm23
Contributor Author

vvvm23 commented Jul 15, 2023

I've begun my hunt for numerical bugs 🐛

The first I squashed was rather strange. It seems torch.rsqrt and jax.lax.rsqrt do not match. This is used in the RMSNorm layers. Simple test to reproduce:

In [19]: a = np.asarray(a, dtype=np.float32)

In [20]: a
Out[20]:
array([1.16661310, 1.46686172, 0.13794081, 1.22346771, 1.17509305],
      dtype=float32)
In [21]: torch.rsqrt(torch.from_numpy(a))
Out[21]: tensor([0.92584139, 0.82566792, 2.69248700, 0.90407354, 0.92249471])

In [22]: jax.lax.rsqrt(a)
Out[22]: Array([0.92584133, 0.82566792, 2.69248700, 0.90407354, 0.92249471],      dtype=float32)

In [23]: 1 / torch.sqrt(torch.from_numpy(a))
Out[23]: tensor([0.92584139, 0.82566792, 2.69248700, 0.90407354, 0.92249471])

In [24]: 1 / jax.numpy.sqrt(a)
Out[24]: Array([0.92584139, 0.82566792, 2.69248700, 0.90407354, 0.92249471],      dtype=float32)

So the fix there was just to replace the jax.lax.rsqrt calls with 1 / jax.numpy.sqrt(...)
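
For readers unfamiliar with RMSNorm, here is a plain-Python sketch of the computation using the explicit `1 / sqrt(...)` form. It assumes the standard RMSNorm definition and an illustrative `eps` default. It is not the PR's `FlaxLlamaRMSNorm` code, which operates on JAX arrays.

```python
import math


def rms_norm(hidden, weight, eps=1e-6):
    """Scale `hidden` by the reciprocal of its root mean square, then by `weight`."""
    variance = sum(x * x for x in hidden) / len(hidden)
    # Explicit 1 / sqrt(...) rather than a fused rsqrt primitive.
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * x * inv_rms for x, w in zip(hidden, weight)]


hidden = [3.0, 4.0]
weight = [1.0, 1.0]
print(rms_norm(hidden, weight, eps=0.0))
```

With unit weights and `eps=0.0`, the output has a mean square of 1, which is a quick sanity check for any RMSNorm port.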

The models still mismatch, so I'll keep digging.

@vvvm23
Contributor Author

vvvm23 commented Jul 16, 2023

@sanchit-gandhi The model now numerically matches in fp32 on CPU. The issue was that my backend had changed from CPU to GPU since fixing the rsqrt issue. I don't think we can expect a perfect match on GPU as the two models use fundamentally different backends. If there is anything you know of that could help remedy this, let me know.

What are the next steps to take? I am guessing some model tests, as well as trying it out on a real model checkpoint rather than random weights. However, my dev machine goes OOM when attempting to load the checkpoint on CPU.

@sanchit-gandhi
Contributor

Hey @vvvm23! Excellent work on pinpointing the difference between torch and jax.lax rsqrt and glad to hear we're within numerical precision using fp32 on CPU - we can be pretty confident we have an accurate Flax implementation based on these results. For GPU, there will be differences between PyTorch and JAX. This is expected since JAX fundamentally works differently to PyTorch with how it computes the matmuls, and is OK since the JAX model will typically generate predictions that are 'as good' as the PyTorch one.

Adding some tests and updating the docs would be the most sensible next steps! Again, you can refer to the Flax GPT Neo model to see the relevant tests to add: https://github.com/huggingface/transformers/blob/main/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py

However, my dev machine goes OOM when attempting to load the checkpoint on CPU.

That's interesting - are we loading the weights on GPU by accident? There shouldn't be any GPU OOM if running on CPU. We might see our RAM get full if loading extremely large weights, but the GPU memory shouldn't be affected. What model size are you loading? We can try the smallest 7b checkpoint: https://huggingface.co/meta-llama/Llama-2-7b

@vvvm23
Contributor Author

vvvm23 commented Jul 26, 2023

Awesome thanks, tests and docs it is! I am currently on leave so won't be progressing on this until the 31st.

That's interesting - are we loading the weights on GPU by accident?

Actually, in the end no. By OOM on my dev machine, I meant out of CPU memory. Switching to a GPU backend meant I could load the model without running out of memory. So, nothing to worry about 😅
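
As a rough guide to why host RAM can run out here, the memory for the raw weights alone is the parameter count times the bytes per parameter. The sketch below uses nominal parameter counts (real checkpoints differ slightly) and ignores any temporary copies made while loading or converting, so treat the numbers as lower bounds. The helper name is illustrative.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gib(num_params, dtype="float32"):
    """Approximate memory for the raw weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


# Nominal sizes for illustration; not exact counts for any specific checkpoint.
for name, params in [("3B", 3e9), ("7B", 7e9)]:
    print(name, round(weight_memory_gib(params, "float32"), 1), "GiB in float32")
```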

@sanchit-gandhi
Contributor

Awesome - thanks for the update @vvvm23. Looking forward to doing a full review of the PR on your return!

@sanchit-gandhi
Contributor

How's it looking @vvvm23? Let me know if I can help in any way! Otherwise feel free to ping me here as soon as we're ready for a review, very excited to add this Flax model for the community!

@vvvm23
Contributor Author

vvvm23 commented Aug 9, 2023

Hi, I've currently been pretty split responsibility-wise (moving house and job!!), so I have only made a small bit of progress.

Most of the tests pass, however, there seems to be some matmul shape mismatch in the generate_* tests. Guessing I didn't implement the KV cache correctly, so I'll need to look at that. I also added back some missing docstrings.

I'll have some time to work on this Thursday, Friday (10th and 11th) but then probably nothing for another week 🤯 If you are in a rush and fancy trying to get the remaining tests to pass, please try! Sorry for the slowness on my part also!

@vvvm23
Contributor Author

vvvm23 commented Aug 10, 2023

The final tests ended up being easy to fix: I had simply forgotten to swap the attention mask and position ids in the pretrained model wrapper.
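
This kind of bug is easy to introduce because a positional call with two same-typed arguments swapped still runs. The toy example below illustrates the failure mode in generic Python and how keyword arguments sidestep it. `apply_model` is a made-up stand-in and says nothing about how the PR itself resolved the issue.

```python
def apply_model(input_ids, attention_mask, position_ids, *, deterministic=True):
    """Toy stand-in for a module call; it only reports what each slot received."""
    return {"attention_mask": attention_mask, "position_ids": position_ids}


mask = [1, 1, 1, 0]
positions = [0, 1, 2, 3]

# Positional call with two arguments accidentally swapped: runs without error,
# but the mask is now interpreted as positions and vice versa.
swapped = apply_model([5, 6, 7, 8], positions, mask)

# Keyword call: the order at the call site no longer matters.
named = apply_model(input_ids=[5, 6, 7, 8], position_ids=positions, attention_mask=mask)

print(swapped["attention_mask"] == mask)  # False
print(named["attention_mask"] == mask)    # True
```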

@sanchit-gandhi I haven't retested the slow tests locally (as my laptop is slow) but later today I can run them, then tidy the code a bit. If all goes well, should be good for review later today or early tomorrow 👍

@vvvm23
Contributor Author

vvvm23 commented Aug 11, 2023

@sanchit-gandhi all tests pass locally 🎉 And I've also run the model using the generate API to see if the outputs make sense:

In [23]: inputs = tokenizer('Aloha, World!', return_tensors='np')

In [24]: tokenizer.decode(model.generate(**inputs, generation_config=model.generation_config, max_length=100).sequences[0])
Out[24]: '<s> Aloha, World!\nI’m back from my trip to Hawaii and I’m feeling great! I’m still trying to get back into the swing of things, but I’m getting there. I’m going to be posting a lot of pictures from my trip, so stay tuned!\nI’m also going to be posting a lot of pictures from my trip to Hawaii, so stay tuned!\nI’m also going to be posting a lot of pictures'

Seems good to me!


I think this is ready for review. I would like to draw your attention to a few points I was unsure about:

Firstly, the model currently throws a warning when loading pretrained checkpoints:

Some weights of the model checkpoint at openlm-research/open_llama_3b_v2 were not used when initializing FlaxLlamaForCausalLM: 
{('model', 'layers', '9', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '1', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '24', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '11', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '7', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '23', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '13', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '5', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '6', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '20', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '21', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '16', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '10', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '4', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '0', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '25', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '12', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '3', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '19', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '14', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '18', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '22', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '8', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '15', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '17', 'self_attn', 'rotary_emb', 'inv_freq'), ('model', 'layers', '2', 'self_attn', 'rotary_emb', 'inv_freq')}

This has no effect on the outputs: the Flax version of the model simply does not store the inv_freq tensor for rotary embeddings in the state dictionary, so these entries get discarded. Is there a way to suppress this warning so as not to scare any users?
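
Discarding `inv_freq` is harmless because it can be recomputed from the configuration. The sketch below assumes the common rotary formulation, in which the inverse frequencies are 1 / base^(2i / d) for i in [0, d / 2), with a base of 10000. The function name and defaults are illustrative rather than copied from the PR.

```python
def rotary_inv_freq(head_dim, base=10000.0):
    """Inverse frequencies 1 / base**(j / head_dim) for even j in [0, head_dim)."""
    return [1.0 / (base ** (j / head_dim)) for j in range(0, head_dim, 2)]


inv_freq = rotary_inv_freq(head_dim=8)
print(len(inv_freq))  # 4
print(inv_freq[0])    # 1.0
```

Since the values depend only on the head dimension and the base, storing them in a checkpoint adds nothing that the model cannot rebuild at initialization.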

Secondly, please double check the licensing. I just copied this from the PyTorch version of Llama and updated the year.

Third, I use the checkpoint openlm-research/open_llama_3b_v2 as it was the smallest, fully open Llama checkpoint I could find. The 'official' Llama checkpoints have gated access, so I am unsure if they are appropriate for testing / documentation purposes. This also means I haven't been able to test the model with the official Llama checkpoint as I still haven't managed to get permission from Meta 😢

Fourth, as we discussed, a lot of the code is copied from the Flax implementation of GPT-Neo. There may be some leftover parts from there that we don't need in Llama, and I may have missed some best practices for Llama as GPT-Neo is (relatively) old now. In particular, see the following code block in FlaxLlamaPreTrainedModel.__call__:

        # TODO: can this handle input tensors being passed as kwargs? I copied GPT-Neo directly here
        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

Finally, the tests pass but please check if they have sufficient coverage 🤗


Generally speaking, I have a lot of experience writing model code but no experience making large contributions to the Hugging Face ecosystem, so there is almost certainly a lot wrong! Apologies in advance and I will do my best to help you bring this model to the finish line 💪 Thanks for your work so far!

@vvvm23 vvvm23 marked this pull request as ready for review August 11, 2023 06:27
@vvvm23 vvvm23 changed the title [WIP] Add Llama Flax Implementation Add Llama Flax Implementation Aug 11, 2023
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Looking very nice already @vvvm23 - there's a lot that's right here! Great to hear that the tests are passing and that the generations look accurate. It should be quite fast to see this one to the finish line: just a small refactoring suggestion for the RMS norm layer, and two slow integration tests to check we have numerical correctness! Let's go!

src/transformers/models/llama/modeling_flax_llama.py (outdated, resolved)
src/transformers/models/llama/modeling_flax_llama.py (outdated, resolved)
src/transformers/models/llama/modeling_flax_llama.py (outdated, resolved)
tests/models/llama/test_modeling_flax_llama.py (outdated, resolved)
tests/models/llama/test_modeling_flax_llama.py (outdated, resolved)
tests/models/llama/test_modeling_flax_llama.py (outdated, resolved)

@vvvm23
Contributor Author

vvvm23 commented Aug 23, 2023

Thanks for your additional comments, I have some time to work on the more involved points today 🤗

@vvvm23
Contributor Author

vvvm23 commented Aug 24, 2023

@sanchit-gandhi I think everything except the missing weight issue is resolved now (see my comment).

Trying to resolve some remaining CI issues, I noticed that the line # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Llama will change the line @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING), and overwrite LLAMA_INPUTS_DOCSTRING. Any idea how to stop this happening? Otherwise the CI won't pass 🤔

@sanchit-gandhi
Contributor

sanchit-gandhi commented Aug 29, 2023

That's correct behaviour @vvvm23! What we need to do is create the variable LLAMA_INPUTS_DOCSTRING in the modelling file that contains the necessary docstring info for LLAMA (more or less copied one-for-one from Flax GPT Neo, but adapted for any different inputs)

@vvvm23
Contributor Author

vvvm23 commented Aug 29, 2023

Yeah, it is correct behaviour - what I meant though is that I did have a LLAMA_INPUTS_DOCSTRING in a previous commit, but running make fix-copies overwrote this docstring with the GPT-Neo version (as you suggested we add that at a class level). I guess my question was, how can we copy everything else in the class but somehow exclude the docstring line?

I get that we need the docstring itself, just currently the CI won't pass with both that docstring and the # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Llama line. Does the issue make sense?

@vvvm23
Contributor Author

vvvm23 commented Aug 31, 2023

@sanchit-gandhi fixed the CI issue (I think) by just adding more Copied from ... comments and deleting the class level comment. I also fixed the merge conflict. We should be good to go once CI passes I think 🙂

@vvvm23
Contributor Author

vvvm23 commented Aug 31, 2023

@sanchit-gandhi the CI still fails, for two reasons. Could you assist me with resolving them?

  1. The documentation test fails as it tries to load the checkpoint _CHECKPOINT_FOR_DOC however it needs from_pt=True to be set.
  2. Flax Llama is based off the GPT Neo implementation. GPT Neo uses special tests to test equivalence between the flax and pytorch implementations. This overrides the common test_equivalence_pt_to_flax test. I copy these special tests (to make my flax tests pass). However, changing the tests for the flax version will cause the pytorch version to fail as it is using the flax version incorrectly.

edit: for the second, please see test_modeling_llama.py:308. These tests need to be overridden somehow; for now I just return directly to get the CI to pass.

All the tests pass and the model is pretty much ready. Just not sure how to get through these last two blockers. Help would be much appreciated!

@vvvm23
Contributor Author

vvvm23 commented Nov 21, 2023

PRs on the hub would be better, however I cannot merge these PRs 😅

We could always create our own internal copy and point the test at that?

@ArthurZucker
Collaborator

I'll open a PR to support revisions anyway, another contributor is also stuck because of this

@ArthurZucker
Collaborator

Sorry for the delay, #27645 will help 😉

@vvvm23
Contributor Author

vvvm23 commented Nov 23, 2023

nice! left some comments on that PR

@ArthurZucker
Collaborator

Merged 😉

@vvvm23
Contributor Author

vvvm23 commented Nov 24, 2023

The PR documentation test seems to time out. Is this an intermittent or known issue? Or is something wrong with the code 🤔

@vvvm23
Contributor Author

vvvm23 commented Dec 3, 2023

@ArthurZucker @sanchit-gandhi could you advise on the above? Not intimately familiar with how your test workers operate.

Collaborator

@ArthurZucker ArthurZucker left a comment

Sorry for the delay, either use a smaller checkpoint, or just use a dummy checkpoint using something like this:

    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )

See here:

_CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"

src/transformers/models/llama/modeling_flax_llama.py (outdated, resolved)

@kiansierra
Contributor

Hi @vvvm23, I believe that in order to pass the tests you should update https://huggingface.co/afmck/testing-llama-tiny/tree/main to include the Flax checkpoints and also the tokenizer. I created https://huggingface.co/ksmcg/Mistral-tiny/tree/main for Mistral and it has passed the tests.
Also, I believe the revision should be removed for the tests to pass.

@vvvm23
Contributor Author

vvvm23 commented Dec 4, 2023

Ah good catch, I will fix that tomorrow

@vvvm23
Contributor Author

vvvm23 commented Dec 5, 2023

Finally, green light 😁

@ArthurZucker
Collaborator

Let's just resolve the conflicts and it's good to go IMO!

@vvvm23
Contributor Author

vvvm23 commented Dec 5, 2023

done!

@ArthurZucker ArthurZucker merged commit 75336c1 into huggingface:main Dec 7, 2023
3 checks passed
@ArthurZucker
Collaborator

Thanks for the PR 🤗

sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Mar 19, 2024
* Copies `modeling_flax_gpt_neo.py` to start

* MLP Block. WIP Attention and Block

* Adds Flax implementation of `LlamaMLP`
Validated with in-file test.
Some slight numeric differences, but assuming it isn't an issue

* Adds `FlaxLlamaRMSNorm` layer
`flax.linen` includes `RMSNorm` layer but not necessarily in all
versions. Hence, we add in-file.

* Adds FlaxLlamaAttention
Copied from GPT-J as it has efficient caching implementation as well as
rotary embeddings.
Notice numerically different, but not by a huge amount. Needs
investigating

* Adds `FlaxLlamaDecoderLayer`
numerically inaccurate, debugging..

* debugging rotary mismatch
gptj uses interleaved whilst llama uses contiguous
i think they match now but still final result is wrong.
maybe drop back to just debugging attention layer?

* fixes bug with decoder layer
still somewhat numerically inaccurate, but close enough for now

* adds markers for what to implement next
the structure here diverges a lot from the PT version.
not a big fan of it, but just get something working for now

* implements `FlaxLlamaBlockCollection`
tolerance must be higher than expected, kinda disconcerting

* Adds `FlaxLlamaModule`
equivalent PyTorch model is `LlamaModel`
yay! a language model🤗

* adds `FlaxLlamaForCausalLMModule`
equivalent to `LlamaForCausalLM`
still missing returning dict or tuple, will add later

* start porting pretrained wrappers
realised it probably needs return dict as a prereq

* cleanup, quality, style

* readds `return_dict` and model output named tuples

* (tentatively) pretrained wrappers work 🔥

* fixes numerical mismatch in `FlaxLlamaRMSNorm`
seems `jax.lax.rsqrt` does not match `torch.rsqrt`.
manually computing `1 / jax.numpy.sqrt` results in matching values.

* [WIP] debugging numerics

* numerical match
I think issue was accidental change of backend. forcing CPU fixes test.
We expect some mismatch on GPU.

* adds in model and integration tests for Flax Llama
summary of failing:
- mul invalid combination of dimensions
- one numerical mismatch
- bf16 conversion (maybe my local backend issue)
- params are not FrozenDict

* adds missing TYPE_CHECKING import and `make fixup`

* adds back missing docstrings
needs review on quality of docstrings, not sure what is required.
Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO

* commenting out equivalence test as can just use common

* debugging

* Fixes bug where mask and pos_ids were swapped in pretrained models
This results in all tests passing now 🔥

* cleanup of modeling file

* cleanup of test file

* Resolving simpler review comments

* addresses more minor review comments

* fixing introduced pytest errors from review

* wip additional slow tests

* wip tests
need to grab a GPU machine to get real logits for comparison
otherwise, slow tests should be okay

* `make quality`, `make style`

* adds slow integration tests
- checking logits
- checking hidden states
- checking generation outputs

* `make fix-copies`

* fix mangled function following `make fix-copies`

* adds missing type checking imports

* fixes missing parameter checkpoint warning

* more finegrained 'Copied from' tags
avoids issue of overwriting `LLAMA_INPUTS_DOCSTRING`

* swaps import guards
??? how did these get swapped initially?

* removing `inv_freq` again as pytorch version has now removed

* attempting to get CI to pass

* adds doc entries for llama flax models

* fixes typo in __init__.py imports

* adds back special equivalence tests
these come from the gpt neo flax tests. there is special behaviour for these models that needs to override the common version

* overrides tests with dummy to see if CI passes
need to fill in these tests later

* adds my contribution to docs

* `make style; make quality`

* replaces random masking with fixed to work with flax version

* `make quality; make style`

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* updates `x`->`tensor` in `rotate_half`

* addresses smaller review comments

* Update docs/source/en/model_doc/llama.md

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adds integration test class

* adds `dtype` to rotary embedding to cast outputs

* adds type to flax llama rotary layer

* `make style`

* `make fix-copies`

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* applies suggestions from review

* Update modeling_flax_llama.py

* `make fix-copies`

* Update tests/models/llama/test_modeling_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_flax_llama.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fixes shape mismatch in FlaxLlamaMLP

* applies some suggestions from reviews

* casts attn output logits to f32 regardless of dtype

* adds attn bias using `LlamaConfig.attention_bias`

* adds Copied From comments to Flax Llama test

* mistral and persimmon test change -copy from llama

* updates docs index

* removes Copied from in tests

it was preventing `make fix-copies` from succeeding

* quality and style

* ignores FlaxLlama input docstring

* adds revision to `_CHECKPOINT_FOR_DOC`

* repo consistency and quality

* removes unused import

* removes copied from from Phi test

now diverges from llama tests following FlaxLlama changes

* adds `_REAL_CHECKPOINT_FOR_DOC`

* removes refs from pr tests

* reformat to make ruff happy

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
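
The commit notes above mention that GPT-J uses an interleaved rotary layout whilst Llama uses a contiguous one. The difference can be sketched on a plain list as follows. The function `rotate_every_two` is named here for illustration, and both functions are simplified stand-ins for array code rather than the PR's implementation.

```python
def rotate_half(x):
    """Contiguous layout: pair element i with element i + len(x) // 2."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rotate_every_two(x):
    """Interleaved layout: pair each even index with the odd index after it."""
    out = []
    for even, odd in zip(x[0::2], x[1::2]):
        out.extend([-odd, even])
    return out


x = [1.0, 2.0, 3.0, 4.0]
print(rotate_half(x))       # [-3.0, -4.0, 1.0, 2.0]
print(rotate_every_two(x))  # [-2.0, 1.0, -4.0, 3.0]
```

The two layouts pair up different feature dimensions, so code written for one layout will run on weights trained with the other but produce different attention outputs, which matches the kind of silent numerical mismatch described in the commit notes.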
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Mar 26, 2024
Successfully merging this pull request may close these issues.

Add Mistral Models to Flax
8 participants