Basenji2 in PyTorch

This repo provides a PyTorch re-implementation of the Basenji2 model published in "Cross-species regulatory sequence activity prediction" by David Kelley. The implementation was checked by verifying that the TensorFlow and PyTorch versions yield the same output on random data. Small deviations were found, likely due to differences in the underlying algorithms used by TensorFlow and PyTorch (e.g. different matrix multiplication algorithms).
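A check of this kind can be sketched as follows. This is not the script used for this repo: the helper name `compare_outputs`, the tolerances, and the array shape are assumptions chosen for illustration, and the two arrays are random stand-ins for predictions exported from the TensorFlow and PyTorch models.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-4, atol=1e-5):
    """Summarize how closely two prediction arrays agree.

    Hypothetical helper; the tolerances are illustrative defaults.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(
            f"shape mismatch: {reference.shape} vs {candidate.shape}"
        )
    abs_diff = np.abs(reference - candidate)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        "allclose": bool(
            np.allclose(candidate, reference, rtol=rtol, atol=atol)
        ),
    }


# Stand-in data: in a real check these would be the outputs of both
# models for the same random input. The shape here is arbitrary.
rng = np.random.default_rng(0)
tf_like = rng.random((2, 896, 8), dtype=np.float32)
# Simulate tiny numerical drift between the two frameworks.
noise = rng.normal(scale=1e-7, size=tf_like.shape).astype(np.float32)
torch_like = tf_like + noise

report = compare_outputs(tf_like, torch_like)
print(report)
```

Reporting the maximum and mean absolute difference alongside the pass/fail flag makes it easier to judge whether a deviation looks like floating-point noise or a genuine mismatch between the two implementations.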

Installation

On Linux with conda/mamba:

  1. Clone the repository.
  2. Add it to your PYTHONPATH environment variable (e.g. in your .bashrc file).
  3. Use conda/mamba to install dependencies from the environment.yml found in the repo.
  4. Download the PyTorch weights.

Usage

import json
import torch
from basenji2_pytorch import Basenji2, params # or PLBasenji2 to use training parameters from Kelley et al. 2020

model_weights = 'path/to/basenji2.pth'

with open(params) as params_open:
    model_params = json.load(params_open)['model']

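# To use a headless model (e.g. for transfer learning), uncomment the next line
# to drop the human output head from the parameters before building the model.
# model_params.pop("head_human", None)

basenji2 = Basenji2(model_params)
# strict=False tolerates keys that do not match between the checkpoint and the
# model, which is needed when the head has been removed.
basenji2.load_state_dict(torch.load(model_weights), strict=False)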
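Because `strict=False` skips mismatched keys without raising, it can be worth inspecting what `load_state_dict` reports. The sketch below uses a toy module rather than Basenji2 itself: `ToyModel`, its layer sizes, and the `trunk`/`head` attribute names are hypothetical and only stand in for a full and a headless model. The return value with `missing_keys` and `unexpected_keys` is standard PyTorch behavior.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for a model with an optional output head."""

    def __init__(self, with_head=True):
        super().__init__()
        self.trunk = nn.Linear(4, 8)
        if with_head:
            self.head = nn.Linear(8, 3)


full = ToyModel(with_head=True)       # plays the role of the checkpoint
headless = ToyModel(with_head=False)  # plays the role of the headless model

# With strict=False, keys present only in the checkpoint are skipped
# and listed in the returned result instead of raising an error.
result = headless.load_state_dict(full.state_dict(), strict=False)

print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)

# The shared trunk weights were still copied over.
trunk_loaded = torch.equal(headless.trunk.weight, full.trunk.weight)
print("trunk weights loaded:", trunk_loaded)
```

When loading the real weights, an empty `missing_keys` list with only head-related entries in `unexpected_keys` would suggest the trunk loaded as intended. Any other entries may point to a mismatch between the parameters and the checkpoint.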