This is the official repository for "SKI to go Faster: Accelerating Toeplitz Neural Networks via Asymmetric Kernels" (arXiv:2305.09028).
It also contains a baseline implementation of TNN.
Clone using
git clone --recursive https://github.com/jonathanmei/ski-tnn.git
If you missed the --recursive flag, you can initialize the submodule after the fact:
cd ski-tnn
git submodule update --init
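To verify that the submodule checked out correctly (a leading '-' in the output means it is still uninitialized):
git submodule status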
On a fresh Linux install:
bash setup.sh
Or, if conda is already installed, set up the environments:
conda env create --file tnn.yaml
conda env create -f lra.yaml && conda env update -f lra2.yaml
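This creates the tnn and lra environments used in the sections below. Activate the one the experiment calls for, e.g.:
conda activate tnn   # language modeling experiments
conda activate lra   # LRA experiments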
bash setup_wikitext_data.sh
bash setup_lra_data.sh
In environment tnn: SKI-TNN and FD-TNN are found in laxtnn/. Run via
bash laxtnn/scripts/train_alm.sh
bash laxtnn/scripts/train_blm.sh
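For example, an end-to-end run of the autoregressive model (a minimal sketch, assuming the tnn environment was created as above):
conda activate tnn
bash laxtnn/scripts/train_alm.sh   # use train_blm.sh for the bidirectional model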
In environment lra: running setup_lra_data.sh should have created lra_release/ in the same directory as ski-tnn/. Run via
python script_lra.py
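The expected layout after data setup (a sketch, assuming default paths):
parent_dir/
├── ski-tnn/       # this repository
└── lra_release/   # created by setup_lra_data.sh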
The architectures and tasks are matched according to the data type. Batch size and number of GPUs can be modified to fit your hardware configuration.
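For instance, to restrict a run to a subset of GPUs (a sketch; assumes the framework uses every CUDA device it can see, as PyTorch does by default):
CUDA_VISIBLE_DEVICES=0,1 python script_lra.py   # train on GPUs 0 and 1 only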
This is performed by setup_wikitext_data.sh and comes from fairseq.
Use the following command to train the autoregressive language model:
bash script_alm.sh
You should change data_dir to point to the preprocessed data.
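A minimal sketch of that change, assuming data_dir is a shell variable near the top of the script (the path below is a placeholder):
# in script_alm.sh
data_dir=/path/to/preprocessed/data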
After training, you can run a length extrapolation test with the following command, replacing the model architecture and sequence length as desired:
bash laxtnn/scripts/length_extrapolation.sh laxtnn_alm_tno_fd_3lyrs 4096
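To sweep several evaluation lengths in one go (a sketch reusing the script and architecture name shown above; the lengths are placeholders):
for len in 1024 2048 4096 8192; do
    bash laxtnn/scripts/length_extrapolation.sh laxtnn_alm_tno_fd_3lyrs $len
done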
Data preparation is the same as for the autoregressive language model.
Use the following command to train the bidirectional language model:
bash script_blm.sh
You should change data_dir to point to the preprocessed data, as above.
We provide the setup_lra_data.sh script.
The main branch of this repository points to the main branch of the lra-tnn submodule, which is a minified version of the code that reproduces the paper. The dev branch of the lra-tnn repository contains implementations of other architectures, which allows running more experiments.
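To try those additional architectures, you can switch the submodule to dev (a sketch; this is the standard git submodule workflow):
cd lra-tnn
git fetch origin
git checkout dev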
@inproceedings{moreno2023ski,
title={{SKI to go Faster: Accelerating Toeplitz Neural Networks via Asymmetric Kernels}},
author={Alexander Moreno and Jonathan Mei and Luke Walters},
booktitle={arXiv:2305.09028},
year={2023},
url={https://arxiv.org/abs/2305.09028}
}