
ESM2quinox

An implementation of ESM2 in Equinox+JAX

Installation

pip install esm2quinox

Public API

See the docstrings for details:

esm2quinox
    .ESM2
        .__init__(self, num_layers: int, embed_size: int, num_heads: int, token_dropout: bool, key: PRNGKeyArray)
        .__call__(self, tokens: Int[np.ndarray | jax.Array, " length"]) -> esm2quinox.ESM2Result

    .ESM2Result
        .hidden: Float[Array, "length embed_size"]
        .logits: Float[Array, "length alphabet_size"]

    .tokenise(proteins: list[str], length: None | int = None, key: None | PRNGKeyArray = None)

    .from_torch(torch_esm2: esm.ESM2) -> esm2quinox.ESM2
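Because __call__ takes a single unbatched token sequence, tokenise returns one rectangular array covering all of the input proteins, which can then be batched over with jax.vmap (see the quick examples below). A minimal sketch of the expected shape; exactly how padding and any special tokens are handled is an assumption here, so check the docstring:

import esm2quinox

tokens = esm2quinox.tokenise(["MKT", "MKTAYIAKQR"])
tokens.shape  # (2, L): one row per protein, padded to a common length L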

Quick examples

Load an equivalent pretrained model from PyTorch:

import esm  # pip install fair-esm==2.0.0
import esm2quinox

torch_model, _ = esm.pretrained.esm2_t6_8M_UR50D()
model = esm2quinox.from_torch(torch_model)
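As a sanity check, the converted model can be compared against the PyTorch original on the same tokens. A hedged sketch: it uses fair-esm's batch converter for tokenisation, and it assumes both that those token IDs match what esm2quinox expects and that a 1e-4 tolerance is appropriate:

import esm
import esm2quinox
import numpy as np
import torch

torch_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
torch_model.eval()
model = esm2quinox.from_torch(torch_model)

batch_converter = alphabet.get_batch_converter()
_, _, torch_tokens = batch_converter([("protein", "MKTAYIAKQR")])

with torch.no_grad():
    torch_logits = torch_model(torch_tokens)["logits"][0].numpy()

out = model(np.asarray(torch_tokens[0]))  # unbatched: a single sequence of tokens
assert np.allclose(np.asarray(out.logits), torch_logits, atol=1e-4)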

Create a randomly-initialised model:

import esm2quinox
import jax.random as jr

key = jr.key(1337)
model = esm2quinox.ESM2(num_layers=3, embed_size=32, num_heads=2, token_dropout=False, key=key)
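Like any Equinox module, an ESM2 instance is a PyTree, so its weights can be saved and restored with Equinox's serialisation utilities. A minimal sketch (the filename is illustrative):

import equinox as eqx
import esm2quinox
import jax.random as jr

model = esm2quinox.ESM2(num_layers=3, embed_size=32, num_heads=2, token_dropout=False, key=jr.key(1337))
eqx.tree_serialise_leaves("esm2_small.eqx", model)  # write the weights to disk

# Reloading requires a skeleton with the same hyperparameters, whose leaves are then overwritten.
skeleton = esm2quinox.ESM2(num_layers=3, embed_size=32, num_heads=2, token_dropout=False, key=jr.key(0))
model = eqx.tree_deserialise_leaves("esm2_small.eqx", skeleton)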

Forward pass (note the model operates on unbatched data):

import jax

tokens = esm2quinox.tokenise(["SPIDERMAN", "FOO"])
out = jax.vmap(model)(tokens)
out.hidden  # hidden representation from the last layer
out.logits  # per-position logits over the alphabet
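For repeated forward passes it is usually worth JIT-compiling; eqx.filter_jit handles the non-array parts of the model that jax.jit alone would reject. A minimal sketch, reusing model and tokens from above:

import equinox as eqx
import jax

@eqx.filter_jit
def forward(model, tokens):
    # Batch the unbatched model over the leading axis of `tokens`.
    return jax.vmap(model)(tokens)

out = forward(model, tokens)
predictions = out.logits.argmax(axis=-1)  # most likely token index at each position

Note that JIT recompiles whenever the token array's shape changes, so padding to a fixed size (the length argument of tokenise) helps avoid recompilation.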
