
mamba-minimal-jax

Simple, minimal implementation of the Mamba SSM in one file of JAX.

Plan:

  1. Finish model.py (done).
  2. Convert the PyTorch weights to JAX weights (done).
  3. Check that greedy generation gives the same results as PyTorch (done).
  4. Implement the associative scan to speed up the state update (done in the speedup branch; see the discussion in srush/annotated-mamba#1).
  5. Get the weight initialization right so that the model can be trained from scratch.
  6. Implement the step function for Mamba inference.
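Plan items 4 and 6 both revolve around one linear recurrence, h_t = a_t * h_{t-1} + b_t. The sketch below is not the code in model.py or the speedup branch. It shows a per-step update driven by `jax.lax.scan` (the shape a step-based inference function takes) next to the same recurrence computed with `jax.lax.associative_scan`. The names `step`, `sequential_scan`, `combine`, `parallel_scan` and the toy shapes are hypothetical. The sketch covers only the SSM state update; a full Mamba inference step would, as far as we know, also carry a causal-convolution buffer.

```python
import jax
import jax.numpy as jnp


def step(h, inputs):
    # One recurrent update h_t = a_t * h_{t-1} + b_t (elementwise).
    # Hypothetically, a_t plays the role of the discretized decay exp(delta_t * A)
    # and b_t the discretized input term delta_t * B_t * x_t.
    a_t, b_t = inputs
    h = a_t * h + b_t
    return h, h  # (new carry, output for this time step)


def sequential_scan(a, b):
    # Apply `step` along the time axis (axis 0): L sequential updates.
    h0 = jnp.zeros_like(b[0])
    _, hs = jax.lax.scan(step, h0, (a, b))
    return hs


def combine(left, right):
    # Applying h -> a1*h + b1 and then h -> a2*h + b2
    # gives h -> (a1*a2)*h + (a2*b1 + b2); this composition is associative.
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2


def parallel_scan(a, b):
    # Same recurrence (with h0 = 0) as an inclusive prefix scan: O(log L) depth.
    _, hs = jax.lax.associative_scan(combine, (a, b), axis=0)
    return hs


# Toy shapes: (seq_len, d_inner, d_state)
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (16, 4, 8))  # decays in [0, 1)
b = jax.random.normal(key_b, (16, 4, 8))
hs_seq = sequential_scan(a, b)
hs_par = parallel_scan(a, b)
```

Both functions compute the same states up to floating-point reordering. The associative version trades extra multiplications for a parallel tree of combines, which is why it can be faster on accelerators for long sequences.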

From mamba-minimal

Featuring:

  • Numerical output equivalent to the official implementation for both the forward and backward passes
  • Simplified, readable, annotated code

Does NOT include:

  • Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept most of the implementation simple for readability.
  • Proper parameter initialization (though this could be added without sacrificing readability)
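For the initialization gap, here is a minimal sketch of the SSM-specific parameters. It assumes the scheme used in `mamba_simple.py` of the official repository: S4D-real `A_log` (so that A = -exp(A_log) is negative), `D` set to ones, and a `dt` bias equal to the inverse softplus of a log-uniform sample in `[dt_min, dt_max]`. The function name `init_ssm_params`, the dictionary layout, and how these would map onto model.py are hypothetical. The input/output projections and the `dt` projection weights are not covered.

```python
import math

import jax
import jax.numpy as jnp


def init_ssm_params(key, d_inner, d_state, dt_min=1e-3, dt_max=1e-1, dt_init_floor=1e-4):
    # S4D-real init: A[d, n] = n + 1 for every channel d, stored as log(A).
    A = jnp.tile(jnp.arange(1, d_state + 1, dtype=jnp.float32), (d_inner, 1))
    A_log = jnp.log(A)

    # Skip-connection weights start at one.
    D = jnp.ones((d_inner,), dtype=jnp.float32)

    # Sample dt log-uniformly in [dt_min, dt_max), then clamp from below.
    u = jax.random.uniform(key, (d_inner,))
    dt = jnp.exp(u * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
    dt = jnp.maximum(dt, dt_init_floor)

    # Inverse softplus: log(exp(dt) - 1), so softplus(dt_bias) == dt.
    dt_bias = dt + jnp.log(-jnp.expm1(-dt))

    return {"A_log": A_log, "D": D, "dt_bias": dt_bias}


params = init_ssm_params(jax.random.PRNGKey(0), d_inner=4, d_state=8)
```

Storing the inverse softplus as the bias means the initial step sizes land in the intended range once the forward pass applies softplus to the projected `dt`.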

Demo

See demo.ipynb for examples of prompt completions.

from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# `generate` is not imported above; it is assumed to be the helper defined in demo.ipynb
generate(model, tokenizer, 'Mamba is the')

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!
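The demo's `generate` lives in the notebook. As a rough sketch of greedy decoding (the mode plan item 3 compares against PyTorch), assume a hypothetical `logits_fn` that maps a 1-D array of token ids to per-position logits; the actual call signature of the model in model.py may differ. A toy `logits_fn` stands in for the model so the example runs on its own, and `greedy_generate` is likewise a hypothetical name.

```python
import jax
import jax.numpy as jnp


def greedy_generate(logits_fn, input_ids, n_tokens):
    # Greedy decoding that re-runs the full sequence for every new token.
    # `logits_fn` (hypothetical): int array of shape (L,) -> logits (L, vocab_size).
    ids = jnp.asarray(input_ids, dtype=jnp.int32)
    for _ in range(n_tokens):
        logits = logits_fn(ids)           # (L, vocab_size)
        next_id = jnp.argmax(logits[-1])  # most likely token after the last position
        ids = jnp.concatenate([ids, next_id[None].astype(jnp.int32)])
    return ids


VOCAB = 10


def toy_logits(ids):
    # Toy stand-in for the model: always predicts (token + 1) mod VOCAB.
    return jax.nn.one_hot((ids + 1) % VOCAB, VOCAB)


out = greedy_generate(toy_logits, [7], 4)
print(out.tolist())  # [7, 8, 9, 0, 1]
```

Without a step function, each new token re-processes the whole prefix, so generating n tokens costs n full forward passes; a recurrent step function (plan item 6) would instead update a fixed-size state once per token.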

References

The Mamba architecture was introduced in "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" by Albert Gu and Tri Dao.

The official implementation is here: https://github.com/state-spaces/mamba

The minimal PyTorch implementation is here: https://github.com/johnma2006/mamba-minimal