Install ohara:

```bash
pip install ohara
```
To train MLA:

```bash
python train_mla.py --attn_type=mla
```
For the MHA baseline:

```bash
python train_mla.py --attn_type=mha
```
If you want to calculate the number of parameters and see what percentage of KV cache you'll save, visit this link: https://joey00072.github.io/Multi-Head-Latent-Attention-MLA-/
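The KV-cache saving can also be estimated with a quick back-of-envelope calculation: MHA caches full K and V for every head and layer, while MLA caches one compressed latent vector per layer. The sketch below is illustrative only; the function names, config values, and latent dimension are assumptions, not the repo's actual code or defaults.

```python
# Hedged sketch: rough KV-cache comparison, MHA vs. MLA.
# All names and config values below are illustrative assumptions.

def mha_kv_cache_per_token(n_layers: int, n_heads: int, head_dim: int) -> int:
    """Elements cached per token: K and V for every head in every layer."""
    return 2 * n_layers * n_heads * head_dim

def mla_kv_cache_per_token(n_layers: int, latent_dim: int) -> int:
    """MLA caches a single compressed latent per layer instead of full K/V."""
    return n_layers * latent_dim

# Assumed Llama-7B-like config: 32 layers, 32 heads, head_dim 128,
# and a hypothetical latent dimension of 512 for the compressed KV.
mha = mha_kv_cache_per_token(n_layers=32, n_heads=32, head_dim=128)
mla = mla_kv_cache_per_token(n_layers=32, latent_dim=512)
saving = 100 * (1 - mla / mha)
print(f"MHA: {mha} elems/token, MLA: {mla} elems/token, saved {saving:.1f}%")
```

With these assumed numbers the cache shrinks from 262,144 to 16,384 elements per token, a ~93.8% reduction; the linked page computes the exact figure for your own config.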
- Write blog post
- Add JAX version
- Add GQA and MQA to the calculation (index.html)
- Distill Llama into an MLA version (maybe)