Add Mistral Models to Flax #26809
Comments
I think this could be interesting! Feel free to open a PR and ping @sanchit-gandhi 😉
Hey @kiansierra - there's already a PR for Flax LLaMA that is pretty much ready to be merged: #24587. Feel free to check it out! But we'd love contributions for other LLMs in the library where there's only PyTorch support and not Flax 🤗 If there are particular checkpoints on the HF Hub that you see getting a lot of usage (downloads) where there's only PyTorch support but not Flax, definitely let us know here and we can get going with a PR! 🚀
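(As a rough illustration of the "check the download counts" suggestion, not something from this thread: a minimal sketch using `huggingface_hub` might look like the following. The filter string, the sort field, and the attribute names are assumptions about the API and may differ slightly between versions.)

```python
# Hypothetical sketch: list popular Mistral checkpoints on the Hub by download count
# and inspect which frameworks they are tagged with (e.g. "pytorch" vs "jax").
# The filter string and sort/direction arguments are illustrative assumptions.
from huggingface_hub import HfApi

api = HfApi()
models = api.list_models(filter="mistral", sort="downloads", direction=-1, limit=10)
for m in models:
    # downloads may be missing on older huggingface_hub versions unless full=True is passed
    print(m.id, getattr(m, "downloads", None), m.tags)
```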
Thanks for the heads up @sanchit-gandhi, I'll see if there is any other model I think I can add to Flax and tag you on the next issue.
Oops, I even reviewed the PR 😅 sorry @kiansierra 🤗
@kiansierra sorry to scoop Flax Llama from you! If you want any suggestions, I think Mistral is a pretty popular model right now without a Flax port.
Hey, no worries. I think I will give Mistral a go; it seems some of the work can be ported.
Happy to see that a couple of people are interested in porting these models to Flax! I was also interested in contributing! Is there any other model that would be interesting? On a side note: I guess flash-attention only works for the PyTorch models atm(?). Is there any fundamental reason why porting the flash-attention implementation to JAX would be difficult?
Hello, I have created both Llama and Mistral models in Flax. If you want, you can use them: modelling_mistral_flax.py
Yes, Flash Attention relies on dispatching optimised CUDA kernels, which as far as I'm aware haven't been implemented in JAX. You could look into Pallas and see if someone's written Flash Attention kernels for JAX using this library: https://jax.readthedocs.io/en/latest/pallas/design.html
Indeed, there's an effort to write FlashAttention in Pallas (https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py), although it's still a work in progress: jax-ml/jax#17328. @sanchit-gandhi I'd be happy to try to port another model. For example, Yarn-Mistral seems to have some traction, though it's not part of the transformers library atm. Any other suggestions are welcome!
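(For anyone who wants to try the experimental Pallas kernel linked above, a minimal sketch might look like the following. The import path and the `mha` signature are assumptions based on `jax/experimental/pallas/ops/attention.py` at the time of writing and may change between JAX releases; the kernel also targets GPUs via Triton, so it won't run on CPU.)

```python
# Hedged sketch of calling the experimental Pallas multi-head attention kernel.
# Requires a GPU with Triton support; signature and import path are assumptions.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention

batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
k = jax.random.normal(k2, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
v = jax.random.normal(k3, (batch, seq_len, num_heads, head_dim), dtype=jnp.float16)

# segment_ids=None means one unsegmented sequence; causal=True gives decoder-style masking.
out = attention.mha(q, k, v, segment_ids=None, causal=True)
print(out.shape)  # (batch, seq_len, num_heads, head_dim)
```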
Feature request
I would like to implement the Mistral model in Flax
Motivation
I've been trying to get familiar with JAX, and as such I started migrating the Llama model. I think I am at a point where both models are quite comparable in outcome.
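(Since the motivation is about getting comparable outputs from the PyTorch and Flax implementations, a typical cross-framework equivalence check might look roughly like the sketch below. It assumes the Flax LLaMA port from #24587 is available as `FlaxLlamaForCausalLM`; the checkpoint name is purely illustrative.)

```python
# Hedged sketch of a PyTorch vs Flax logits comparison; checkpoint is illustrative.
import numpy as np
import torch
from transformers import AutoTokenizer, FlaxLlamaForCausalLM, LlamaForCausalLM

checkpoint = "openlm-research/open_llama_3b"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
pt_model = LlamaForCausalLM.from_pretrained(checkpoint)
# from_pt=True converts the PyTorch weights to Flax on the fly
fx_model = FlaxLlamaForCausalLM.from_pretrained(checkpoint, from_pt=True)

text = "Porting models to Flax"
with torch.no_grad():
    pt_logits = pt_model(**tokenizer(text, return_tensors="pt")).logits.numpy()
fx_logits = np.asarray(fx_model(**tokenizer(text, return_tensors="np")).logits)

# The two implementations should agree up to small numerical differences.
print(np.max(np.abs(pt_logits - fx_logits)))
```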
Your contribution
Yes I could submit a PR with the model implementation