
Add Mistral Models to Flax #26809

Closed
kiansierra opened this issue Oct 14, 2023 · 10 comments · Fixed by #24587 or #26943
Labels
Feature request Request for a new feature

Comments

@kiansierra
Contributor

kiansierra commented Oct 14, 2023

Feature request

I would like to implement the Mistral model in Flax (this issue originally covered Llama).

Motivation

I've been trying to get familiar with JAX, so I started migrating the Llama model. I think I'm at a point where both implementations produce quite comparable outputs.

Your contribution

Yes, I could submit a PR with the model implementation.
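The output comparison mentioned above is usually done by feeding the same input to both implementations and checking that the logits agree within a tolerance. A minimal NumPy-only sketch of that pattern (the helper name, dummy arrays, and tolerance are illustrative; in practice the arrays would come from the PyTorch and Flax models' `logits`):

```python
import numpy as np

def assert_close(pt_logits, flax_logits, atol=1e-4):
    """Check that two framework outputs agree within tolerance.

    Returns the maximum absolute difference so it can be logged.
    """
    pt = np.asarray(pt_logits)
    fx = np.asarray(flax_logits)
    max_diff = float(np.abs(pt - fx).max())
    assert np.allclose(pt, fx, atol=atol), f"max diff {max_diff}"
    return max_diff

# Stand-in arrays; in a real port these would be
# pt_model(**inputs).logits and flax_model(**inputs).logits.
rng = np.random.default_rng(0)
a = rng.normal(size=(1, 8, 32)).astype(np.float32)
b = a + rng.normal(scale=1e-6, size=a.shape).astype(np.float32)
print(assert_close(a, b) < 1e-4)  # True
```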

@ArthurZucker ArthurZucker added the Feature request Request for a new feature label Oct 16, 2023
@ArthurZucker
Collaborator

I think this could be interesting! Feel free to open a PR and ping @sanchit-gandhi 😉

@sanchit-gandhi
Contributor

Hey @kiansierra - there's already a PR for Flax LLaMA that is pretty much ready to be merged: #24587 Feel free to check it out!

But we'd love contributions for other LLMs in the library where there's only PyTorch support and not Flax 🤗 If there are particular checkpoints on the HF Hub that get a lot of usage (downloads) but only have PyTorch support, definitely let us know here and we can get going with a PR! 🚀

@kiansierra
Contributor Author

Thanks for the heads-up @sanchit-gandhi, I'll see if there is any other model I think I can add to Flax and tag you on the next issue.

@ArthurZucker
Collaborator

Oops, I even reviewed the PR 😅 sorry @kiansierra 🤗

@vvvm23
Contributor

vvvm23 commented Oct 17, 2023

@kiansierra sorry to scoop Flax Llama from you! If you want any suggestions, I think Mistral is a pretty popular model right now without a Flax port.

@kiansierra
Contributor Author

Hey, no worries! I think I will give Mistral a go; it seems some of the work can be ported.

@kiansierra kiansierra changed the title Add Llama Models to Flax Add Mistral Models to Flax Oct 17, 2023
@kiansierra kiansierra reopened this Oct 17, 2023
@konstantinos-p

Happy to see that a couple of people are interested in porting these models to Flax! I was also interested in contributing. Is there any other model that would be interesting?

On a side note: I guess FlashAttention only works for the PyTorch models at the moment? Is there any fundamental reason why porting the FlashAttention implementation to JAX would be difficult?

@erfanzar

Hello guys, I have created both Llama and Mistral models in Flax; if you want, you can use them: modelling_mistral_flax.py

@kiansierra kiansierra mentioned this issue Oct 19, 2023
@sanchit-gandhi
Contributor

Yes, FlashAttention relies on dispatching optimised CUDA kernels, which as far as I'm aware haven't been implemented in JAX. You could look into Pallas and see if someone's written FlashAttention kernels for JAX using this library: https://jax.readthedocs.io/en/latest/pallas/design.html
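For context: FlashAttention computes exactly the same function as standard scaled dot-product attention; it only changes *how*, using tiled, fused kernels that never materialise the full attention-score matrix. A minimal NumPy sketch of the reference computation any such kernel must reproduce (shapes are illustrative):

```python
import numpy as np

def attention(q, k, v):
    """Standard scaled dot-product attention: softmax(QK^T / sqrt(d)) V.

    FlashAttention produces the same result, but computes it in tiles
    so the (seq, seq) score matrix is never stored in full.
    """
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)    # (seq, seq) scores
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 16)).astype(np.float32) for _ in range(3))
out = attention(q, k, v)
print(out.shape)  # (4, 16)
```

The online-softmax trick (tracking the running row max and sum while streaming over key/value tiles) is what lets the fused kernel avoid the quadratic memory cost while matching this reference bit-for-bit up to floating-point tolerance.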

@konstantinos-p

konstantinos-p commented Nov 3, 2023

Indeed, there's an effort to write FlashAttention in Pallas (https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py), although it's still a work in progress (jax-ml/jax#17328). @sanchit-gandhi I'd be happy to try to port another model. For example, Yarn-Mistral seems to have some traction, though it's not part of the transformers library at the moment. Any other suggestions are welcome!

6 participants