This repo contains the official PyTorch code and pre-trained models for Injective Linear Attention (InLine).
- November 12, 2024: Repository initialized.
Widely adopted in modern Vision Transformer designs, Softmax attention can effectively capture long-range visual information; however, it incurs excessive computational cost when dealing with high-resolution inputs. In contrast, linear attention naturally enjoys linear complexity and has great potential to scale up to higher-resolution images. Nonetheless, the unsatisfactory performance of linear attention greatly limits its practical application in various scenarios. In this paper, we take a step forward to close the gap between linear and Softmax attention with novel theoretical analyses that demystify the core factors behind the performance gap. Specifically, we present two key perspectives to understand and alleviate the limitations of linear attention: the injective property and the local modeling ability. Firstly, we prove that linear attention is not injective, i.e., it is prone to assigning identical attention weights to different query vectors, thus causing severe semantic confusion since different queries then correspond to the same outputs. Secondly, we confirm that effective local modeling is essential for the success of Softmax attention, and that linear attention falls short in this respect. These two fundamental differences account for much of the disparity between the two attention paradigms, as demonstrated by substantial empirical validation in the paper. In addition, further experimental results indicate that linear attention, once endowed with these two properties, can outperform Softmax attention across various tasks while maintaining lower computational complexity.
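For reference, the standard formulations of the two attention paradigms compared above are Softmax attention and kernelized linear attention (with kernel function $\phi$; the specific kernel used in this repository may differ):

$$\mathrm{SoftmaxAttn}(Q)_i=\sum_{j}\frac{\exp\!\left(Q_i K_j^{\top}/\sqrt{d}\right)}{\sum_{l}\exp\!\left(Q_i K_l^{\top}/\sqrt{d}\right)}\,V_j,\qquad \mathrm{LinearAttn}(Q)_i=\sum_{j}\frac{\phi(Q_i)\,\phi(K_j)^{\top}}{\sum_{l}\phi(Q_i)\,\phi(K_l)^{\top}}\,V_j$$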
We find that the injectivity of the attention function greatly affects the performance of the model. Specifically, if the attention function is not injective, different queries can induce identical attention distributions, leading to severe semantic confusion within the feature space. We prove that the Softmax attention function is injective, whereas the linear attention function is not. Therefore, linear attention is vulnerable to this semantic confusion problem, which largely accounts for its insufficient expressiveness.
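As a quick hand-rolled illustration of the non-injectivity argument (assuming a ReLU kernel; this is not the repository's InLine implementation), scaling a query by any positive constant leaves the normalized linear attention weights unchanged, while the Softmax attention weights change:

```python
# Minimal sketch (assumed ReLU kernel, single head): two different query vectors
# receive identical linear attention weights but different Softmax attention weights.
import torch

torch.manual_seed(0)
d, n = 64, 16                        # head dimension, number of keys
q = torch.randn(d)
K = torch.randn(n, d)

def softmax_weights(q, K):
    return torch.softmax(K @ q / d ** 0.5, dim=0)

def linear_weights(q, K, phi=torch.relu):
    sim = phi(K) @ phi(q)            # un-normalized similarities
    return sim / sim.sum()           # normalize to a distribution over keys

q_scaled = 3.0 * q                   # a different query vector

print(torch.allclose(linear_weights(q, K), linear_weights(q_scaled, K)))    # True
print(torch.allclose(softmax_weights(q, K), softmax_weights(q_scaled, K)))  # False
```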
Our method, Injective Linear Attention (InLine):
- ImageNet-1K results.
- Real speed measurements. Benefiting from linear complexity and a simple design, our InLine attention delivers much higher inference speed than Softmax attention, especially in high-resolution scenarios (see the rough complexity sketch after this list).
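The speed gap follows from the asymptotic complexities: Softmax attention scales as O(N²d) in the number of tokens N, while kernelized linear attention scales as O(Nd²). Below is a rough, self-contained timing sketch of generic single-head attention; it is illustrative only and does not use the InLine modules or official benchmark from this repository.

```python
# Rough timing sketch: quadratic Softmax attention vs. generic kernelized linear
# attention as the token count N grows (assumed ReLU kernel, head dim 64).
import time
import torch

def softmax_attention(Q, K, V):
    # O(N^2 d): materializes the full N x N attention matrix
    attn = torch.softmax(Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5, dim=-1)
    return attn @ V

def linear_attention(Q, K, V, phi=torch.relu):
    # O(N d^2): computes K^T V first, never forming an N x N matrix
    Qp, Kp = phi(Q), phi(K)
    kv = Kp.transpose(-2, -1) @ V                                    # (d, d)
    z = Qp @ Kp.sum(dim=0, keepdim=True).transpose(-2, -1) + 1e-6    # (N, 1)
    return (Qp @ kv) / z

for n in (1024, 2048, 4096):                        # growing token counts
    Q = K = V = torch.randn(n, 64)
    for fn in (softmax_attention, linear_attention):
        t0 = time.perf_counter()
        fn(Q, K, V)
        print(f"{fn.__name__:>18} N={n}: {time.perf_counter() - t0:.4f}s")
```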
- Python 3.9
- PyTorch == 1.11.0
- torchvision == 0.12.0
- numpy
- timm == 0.4.12
- yacs
The ImageNet dataset should be prepared as follows:
imagenet
├── train
│ ├── class1
│ │ ├── img1.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img2.jpeg
│ │ └── ...
│ └── ...
└── val
├── class1
│ ├── img3.jpeg
│ └── ...
├── class2
│ ├── img4.jpeg
│ └── ...
└── ...
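This layout follows the standard torchvision `ImageFolder` convention (class-named sub-directories). A minimal loading sketch is shown below; the repository's actual data pipeline is built from the config files and may differ.

```python
# Minimal sketch: the directory tree above matches torchvision's ImageFolder
# convention, where each sub-directory name becomes a class label.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
val_set = datasets.ImageFolder("imagenet/val", transform=transform)
print(len(val_set), "validation images,", len(val_set.classes), "classes")
```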
Model | Resolution | #Params | FLOPs | Acc@1 (%) | Config | Pretrained Weights |
---|---|---|---|---|---|---|
InLine-DeiT-T | 224 | 6.5M | 1.1G | 74.5 | config | TsinghuaCloud |
InLine-DeiT-S | 288 | 16.7M | 5.0G | 80.2 | config | TsinghuaCloud |
InLine-DeiT-B | 448 | 23.8M | 17.2G | 82.3 | config | TsinghuaCloud |
InLine-PVT-T | 224 | 12.0M | 2.0G | 78.2 | config | TsinghuaCloud |
InLine-PVT-S | 224 | 21.6M | 3.9G | 82.0 | config | TsinghuaCloud |
InLine-PVT-M | 224 | 37.6M | 6.9G | 83.2 | config | TsinghuaCloud |
InLine-PVT-L | 224 | 50.2M | 10.2G | 83.6 | config | TsinghuaCloud |
InLine-Swin-T | 224 | 30M | 4.5G | 82.4 | config | TsinghuaCloud |
InLine-Swin-S | 224 | 50M | 8.7G | 83.6 | config | TsinghuaCloud |
InLine-Swin-B | 224 | 88M | 15.4G | 84.1 | config | TsinghuaCloud |
InLine-CSwin-T | 224 | 25M | 4.3G | 83.2 | config | TsinghuaCloud |
InLine-CSwin-S | 224 | 43M | 6.8G | 83.8 | config | TsinghuaCloud |
InLine-CSwin-B | 224 | 96M | 14.9G | 84.5 | config | TsinghuaCloud |
- Evaluate InLine-DeiT/PVT/Swin on ImageNet:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
- To train InLine-DeiT/PVT/Swin on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --amp
- Evaluate InLine-CSwin on ImageNet:
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>
- To train InLine-CSwin on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --amp
This code is developed on top of Swin Transformer.
If you find this repo helpful, please consider citing us.
@inproceedings{han2024inline,
title={Bridging the Divide: Reconsidering Softmax and Linear Attention},
author={Han, Dongchen and Pu, Yifan and Xia, Zhuofan and Han, Yizeng and Pan, Xuran and Li, Xiu and Lu, Jiwen and Song, Shiji and Huang, Gao},
booktitle={NeurIPS},
year={2024},
}
If you have any questions, please feel free to contact the authors.
Dongchen Han: hdc23@mails.tsinghua.edu.cn