This repository contains the official implementation of "LeaPformer: Enabling Linear Transformers for Autoregressive and Simultaneous Tasks via Learned Proportions," the preprint for which can be found here. LeaPformers are, fundamentally, a novel modification of specific re-weighting functions for linear attention mechanisms that enables them for a wider range of tasks. Thanks to this added flexibility, LeaPformers are often more accurate than alternative linear attention mechanisms while adding only a small amount of latency.
Set-up for the various parts of this repo is somewhat separated, as they were occasionally validated in different environments (i.e. the environment for the LRA tests was not necessarily identical to the environment for LM or SimulST due to some compatibility issues). Instructions for set-up are provided in pytorch-lra and fairseq-leapformer.
Our slightly modified version of the Skyformer PyTorch LRA benchmark can be found in pytorch-lra; it contains several additional linear attention mechanisms compared to the original implementation. Details for running the LRA benchmark are also provided there, including some example scripts.
As a note, this particular set-up focuses on extremely small models, which makes it possible to test quadratic, softmax attention on long-sequence tasks even on medium-to-low-quality hardware.
We validated LeaPformers on small-scale autoregressive language modeling (i.e. around 140M parameters) via an older, private fork of Fairseq, provided in fairseq-leapformer. Scripts are available in fairseq-leapformer/leapformer-scripts/lm, and, should one want to use a more up-to-date version of Fairseq, it can be found here.
Similarly, we validated LeaPformers on SimulST using that same Fairseq fork. Unlike the autoregressive language modeling example, the changes for SimulST also appear in the eval agent here and in the custom attention formulations for SimulST here, where some custom encoder-decoder masking occurs and the SimulEval agent is modified. Scripts are available in fairseq-leapformer/leapformer-scripts/simulst.
This part of the repo is still being cleaned up and will be finished soon.
As mentioned in this work, our implementations (especially the causal ones) are not optimized. A number of works have demonstrated the importance of hardware-aware implementations for maximizing performance. An obvious next step here would be a Triton-based LeaPformer implementation (à la Flash Linear Attention, or FLA). In fact, integration with FLA is likely simple, especially for decoder-only applications (e.g. autoregressive language modeling), since it only requires applying the LeaPformer transforms to the queries and keys before calling FLA's specialized kernels.
LeaPformers were originally conceived back in mid-2023, and a number of interesting works have been published since then that contain elements which could be applied to LeaPformers. For example:
- There are no RNN-like gating mechanisms in this work, despite concurrent work like Gated Linear Attention (GLA) using them to great effect.
- Moreover, several works have skipped the time-dependent normalization term in linear attention in favor of normalization blocks (e.g. LayerNorm or GroupNorm, seen in papers here and here), as similarly done in GLA; a minimal sketch of this idea follows this list. In our experiments this made no real difference, but it might matter at scale.
- Finally, the scale of the experiments in this work is ultimately small for modern applications, and it would be very attractive to experiment at larger scales (i.e. from roughly 300M up to several billion parameters).
If you found our work insightful or useful, please consider citing us as:
@inproceedings{
agostinelli2024leapformer,
title={LeaPformer: Enabling Linear Transformers for Autoregressive and Simultaneous Tasks via Learned Proportions},
author={Victor Agostinelli and Sanghyun Hong and Lizhong Chen},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=XhH1OKLANY}
}