Explorations into Taylor Series Linear Attention, proposed in the paper Zoology: Measuring and Improving Recall in Efficient Language Models.
This repository offers full self attention, cross attention, and an autoregressive (causal) variant via the CUDA kernel from pytorch-fast-transformers.
Be aware that in linear attention, the quadratic cost is pushed from the sequence length onto the attention head dimension. With the second-order Taylor expansion, the feature dimension grows to roughly D^2, so the per-token cost becomes O(D^3) in the head dimension D, and more research is needed.
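To make the caveat above concrete, here is a minimal sketch (my own illustration, not the library's internals) of non-causal linear attention with a second-order Taylor feature map. The function name taylor_feature_map and the shapes are assumptions for illustration; the feature dimension is 1 + D + D^2, which is where the O(D^3) per-token cost comes from.

import torch

def taylor_feature_map(x):
    # exp(q . k) ~= 1 + q . k + (q . k)^2 / 2, realized as an inner product of
    # features [1, x, flatten(x outer x) / sqrt(2)] of size 1 + D + D^2
    ones = torch.ones((*x.shape[:-1], 1))
    second_order = torch.einsum('... i, ... j -> ... i j', x, x).flatten(-2) * (0.5 ** 0.5)
    return torch.cat((ones, x, second_order), dim = -1)

q, k, v = (torch.randn(1, 4096, 16) for _ in range(3))  # (batch, seq, head dim D = 16)

q, k = map(taylor_feature_map, (q, k))                   # feature dim = 1 + 16 + 16^2 = 273

# associativity lets us contract keys with values first, so cost is linear in sequence length
kv  = torch.einsum('b n d, b n e -> b d e', k, v)
num = torch.einsum('b n d, b d e -> b n e', q, kv)
den = torch.einsum('b n d, b d -> b n', q, k.sum(dim = 1)).unsqueeze(-1)
out = num / den                                          # (1, 4096, 16)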
Update: It works! This is the strongest formulation of linear attention I have come across in the literature.
- A16Z Open Source AI Grant Program and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
$ pip install taylor-series-linear-attention
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn
attn = TaylorSeriesLinearAttn(
dim = 512,
dim_head = 16,
heads = 16
)
x = torch.randn(1, 4096, 512)        # (batch, seq, dim)
mask = torch.ones((1, 4096)).bool()  # boolean mask; True = attend to that position
out = attn(x, mask = mask)
assert x.shape == out.shape
Cross attention
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn
attn = TaylorSeriesLinearAttn(
dim = 512,
dim_head = 16,
heads = 16
)
x = torch.randn(1, 1024, 512)
context = torch.randn(1, 65536, 512)          # long context of 65536 tokens
context_mask = torch.ones((1, 65536)).bool()  # boolean mask over the context sequence
out = attn(x, context = context, mask = context_mask)
assert x.shape == out.shape
For the autoregressive (causal) variant, first pip install pytorch-fast-transformers, then set causal = True.
import torch
from taylor_series_linear_attention import TaylorSeriesLinearAttn
attn = TaylorSeriesLinearAttn(
dim = 512,
dim_head = 16,
heads = 16,
causal = True, # set this to True
rotary_emb = True # rotary embeddings
)
x = torch.randn(1, 8192, 512)
out = attn(x)
assert x.shape == out.shape
- take care of caching for the causal variant (a rough sketch of recurrent state caching follows below)
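For context on that todo item: causal linear attention admits a recurrent form, so decoding can cache a running (feature_dim x value_dim) state and a running key sum instead of reprocessing the prefix. The sketch below is a hypothetical illustration of what such a cache could look like; decode_step and its signature are assumptions, not the library's API, and q and k are assumed to already be feature-mapped.

import torch

def decode_step(q, k, v, cache = None):
    # q, k: (batch, feature_dim) and v: (batch, value_dim) for a single new token
    if cache is None:
        kv = torch.zeros(k.shape[0], k.shape[-1], v.shape[-1])
        k_sum = torch.zeros_like(k)
    else:
        kv, k_sum = cache

    kv = kv + torch.einsum('b d, b e -> b d e', k, v)  # accumulate key-value outer products
    k_sum = k_sum + k                                  # accumulate keys for the normalizer

    num = torch.einsum('b d, b d e -> b e', q, kv)
    den = torch.einsum('b d, b d -> b', q, k_sum).unsqueeze(-1).clamp(min = 1e-6)
    return num / den, (kv, k_sum)

q_t, k_t = (torch.randn(1, 273) for _ in range(2))  # feature dim 1 + 16 + 16^2 for head dim 16
v_t = torch.randn(1, 16)
out, cache = decode_step(q_t, k_t, v_t)             # first token initializes the cache
out, cache = decode_step(q_t, k_t, v_t, cache)      # later tokens reuse the running state

Each decoding step costs O(D^3) time and memory for the cached state, independent of how many tokens came before.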
@inproceedings{Arora2023ZoologyMA,
title = {Zoology: Measuring and Improving Recall in Efficient Language Models},
author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher Ré},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266149332}
}
@inproceedings{Keles2022OnTC,
title = {On The Computational Complexity of Self-Attention},
author = {Feyza Duman Keles and Pruthuvi Maheshakya Wijewardena and Chinmay Hegde},
booktitle = {International Conference on Algorithmic Learning Theory},
year = {2022},
url = {https://api.semanticscholar.org/CorpusID:252198880}
}
@article{Shazeer2019FastTD,
title = {Fast Transformer Decoding: One Write-Head is All You Need},
author = {Noam M. Shazeer},
journal = {ArXiv},
year = {2019},
volume = {abs/1911.02150}
}
@inproceedings{Peng2023RWKVRR,
title = {RWKV: Reinventing RNNs for the Transformer Era},
author = {Bo Peng and Eric Alcaide and Quentin G. Anthony and Alon Albalak and Samuel Arcadinho and Stella Biderman and Huanqi Cao and Xin Cheng and Michael Chung and Matteo Grella and G Kranthikiran and Xuming He and Haowen Hou and Przemyslaw Kazienko and Jan Kocoń and Jiaming Kong and Bartlomiej Koptyra and Hayden Lau and Krishna Sri Ipsit Mantri and Ferdinand Mom and Atsushi Saito and Xiangru Tang and Bolun Wang and Johan Sokrates Wind and Stanislaw Wozniak and Ruichong Zhang and Zhenyuan Zhang and Qihang Zhao and Peng Zhou and Jian Zhu and Rui Zhu},
booktitle = {Conference on Empirical Methods in Natural Language Processing},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:258832459}
}
@inproceedings{Katharopoulos2020TransformersAR,
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
author = {Angelos Katharopoulos and Apoorv Vyas and Nikolaos Pappas and François Fleuret},
booktitle = {International Conference on Machine Learning},
year = {2020},
url = {https://api.semanticscholar.org/CorpusID:220250819}
}
@misc{buckman2024,
author = {Buckman, Jacob and Gelada, Carles and Zhang, Sean},
publisher = {Manifest AI},
title = {Symmetric {Power} {Transformers}},
date = {2024-08-15},
langid = {en}
}
The greatest shortcoming of the human race is man’s inability to understand the exponential function. - Albert A. Bartlett