
This repository includes the official implementation of our paper "Scaling White-Box Transformers for Vision".


UCSC-VLAA/CRATE-alpha


Scaling White-Box Transformers for Vision

This repo contains the official JAX implementation of CRATE-α from our paper: Scaling White-Box Transformers for Vision.

We propose CRATE-α, which makes strategic yet minimal modifications to the sparse coding block of the CRATE architecture and pairs them with a light training recipe, both designed to improve the scalability of CRATE.

One layer of the CRATE-α model architecture. MSSA (Multi-head Subspace Self-Attention) represents the compression block, and ODL (Overcomplete Dictionary Learning) represents the sparse coding block.
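The ODL (sparse coding) block in this caption can be pictured as a sparsify-then-decode step. The sketch below is a loose illustration under stated assumptions, not code from this repo: it assumes ODL takes a single ISTA-style step from a zero sparse code using an overcomplete encoder dictionary (more atoms than feature dimensions), decodes the code with a separately learned dictionary, and adds a residual connection. The names `odl_block`, `d_enc`, `d_dec`, `eta`, and `lam` are hypothetical, and nested Python lists stand in for JAX arrays so the example runs without dependencies.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Nested-list matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def odl_block(z: Matrix, d_enc: Matrix, d_dec: Matrix,
              eta: float = 0.1, lam: float = 0.1):
    """Hypothetical overcomplete dictionary learning (ODL) step (assumed form).

    z:     token features as a (d x N) matrix, one column per token
    d_enc: encoder dictionary of shape (d x C*d); C > 1 makes it overcomplete
    d_dec: separately learned decoder dictionary, also (d x C*d)
    Returns (updated tokens, sparse codes).
    """
    # One ISTA-style step from a zero code: A = ReLU(eta * D^T Z - eta * lam)
    pre = matmul(transpose(d_enc), z)  # (C*d x N)
    codes = [[max(0.0, eta * v - eta * lam) for v in row] for row in pre]
    # Decode with the separate dictionary, then add a residual connection
    recon = matmul(d_dec, codes)  # (d x N)
    out = [[z[i][j] + recon[i][j] for j in range(len(z[0]))] for i in range(len(z))]
    return out, codes


# Toy example: d = 2, C = 2, a single token.
z = [[1.0], [2.0]]
d = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, -1.0]]
out, codes = odl_block(z, d_enc=d, d_dec=d, eta=1.0, lam=0.5)
print(codes)  # the last code entry is clipped to zero by the ReLU threshold
print(out)
```

In a real model `d_enc` and `d_dec` would be learned parameters and these loops would be batched array operations; the toy dictionaries only show how the thresholded ReLU zeroes out weak responses to produce a sparse code.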

Comparison of CRATE, CRATE-α, and ViT

Left: how each modification to the components improves the performance of CRATE on ImageNet-1K. Right: FLOPs versus ImageNet-1K accuracy of our method compared with ViT (Dosovitskiy et al., 2020) and CRATE (Yu et al., 2023). CRATE is trained only on ImageNet-1K, while ours and ViT are pre-trained on ImageNet-21K.

Visualizing the Improved Semantic Interpretability of CRATE-α

Segmentation visualization on COCO val2017 (Lin et al., 2014) with MaskCut (Wang et al., 2023). Top row: our supervised model effectively identifies the main objects in the image. Middle row: CRATE, whose segmentation boundaries are less accurate than ours. Bottom row: supervised ViT fails to identify the main objects in most images. Failed images are marked with a red box.

Experimental Results

| Models (Base) | ImageNet-1K (%) | Models (Large) | ImageNet-1K (%) |
|---|---|---|---|
| CRATE-α-B/32 | 76.5 | CRATE-α-L/32 | 80.2 |
| CRATE-α-B/16 | 81.2 | CRATE-α-L/14 | 83.9 |
| CRATE-α-B/8 | 83.2 | CRATE-α-L/8 | 85.1 |
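The suffixes /32, /16, /14, and /8 follow the usual ViT convention of naming the patch size, so smaller patches mean more tokens per image. The arithmetic below assumes a 224×224 input purely for illustration; the actual training and evaluation resolutions are set by the repo's configs and may differ.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches (tokens, excluding any class token)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return side * side


# Assumed 224x224 input; the real resolution may differ.
for name, patch in [("B/32", 32), ("B/16", 16), ("L/14", 14), ("B/8", 8)]:
    print(name, num_patches(224, patch))
```

At this assumed resolution, going from /16 to /8 quadruples the token count (from 196 to 784), which is one reason the smaller-patch models in the table cost more FLOPs.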

Download Model Weights

You can download model weights from the following link: Model Weights

TPU Usage and Environment Installation

TPU Usage

Our experiments were conducted on TPUs. For instructions on gaining access to and setting up TPU machines, see this brief doc in CLIPA.

Environment Installation

To set up the environment, run the following script:

bash scripts/env/setup_env.sh

Training

Classification Training

We provide scripts for pre-training on ImageNet-21K and fine-tuning on ImageNet-1K.

Pre-training on ImageNet-21K

To start pre-training on ImageNet-21K, run:

bash scripts/in1k/pre_training_in21k.sh

Fine-tuning on ImageNet-1K

To start fine-tuning on ImageNet-1K, run:

bash scripts/in1k/fine_tuning_in1k.sh

Vision-Language Contrastive Learning Training

We provide scripts for pre-training and fine-tuning on Datacomp1B.

Pre-training on Datacomp1B

To start pre-training on Datacomp1B, run:

bash scripts/clipa/pre_train.sh

Fine-tuning on Datacomp1B

To start fine-tuning on Datacomp1B, run:

bash scripts/clipa/fine_tune.sh

PyTorch Inference

To increase accessibility, we have converted the weights from JAX to PyTorch. We provide PyTorch models for CRATE-α-B/16, CRATE-α-L/14, CRATE-α-CLIPA-L/14, and CRATE-α-CLIPA-H/14. You can use the PyTorch code to reproduce the results from our paper.
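Converting JAX checkpoints to PyTorch is largely a matter of renaming parameters and fixing tensor layouts. Assuming the JAX model uses Flax's standard `nn.Dense` and `nn.Conv` layouts (an assumption; this is not the authors' conversion script), a dense kernel stored as (in_features, out_features) must be transposed into PyTorch's `nn.Linear` weight of shape (out_features, in_features), and a conv kernel stored as (H, W, in, out) must be permuted into `nn.Conv2d`'s (out, in, H, W). The helper names below are hypothetical.

```python
from typing import List, Sequence, Tuple


def dense_kernel_to_torch(kernel: Sequence[Sequence[float]]) -> List[List[float]]:
    """Flax-style Dense kernel (in, out) -> torch Linear weight (out, in)."""
    return [list(col) for col in zip(*kernel)]


def conv_kernel_shape_to_torch(shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
    """Flax-style Conv kernel shape (H, W, in, out) -> torch Conv2d weight shape (out, in, H, W)."""
    h, w, c_in, c_out = shape
    return (c_out, c_in, h, w)


# A 2-input, 3-output dense kernel becomes a 3 x 2 Linear weight.
print(dense_kernel_to_torch([[1, 2, 3], [4, 5, 6]]))
# An illustrative 16x16 patch-embedding kernel with 3 input channels and 768 outputs.
print(conv_kernel_shape_to_torch((16, 16, 3, 768)))
```

With NumPy arrays the same moves are `kernel.T` for dense layers and `np.transpose(kernel, (3, 2, 0, 1))` for conv layers.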

Preparing ImageNet-1K Validation Set

You can download the ImageNet-1K validation set using the following commands:

wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar 
wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
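Before pointing the evaluation scripts at the data, it can help to confirm that both archives landed in the same directory. The stdlib helper below is a hypothetical convenience, not part of the repo. If the data is loaded with `torchvision.datasets.ImageNet(root, split="val")`, that class looks for these two archives in `root` and parses them on first use; whether the repo's scripts use it is an assumption to check.

```python
from pathlib import Path
from typing import Iterable, List

# The two archives fetched by the wget commands above.
REQUIRED_FILES = ("ILSVRC2012_img_val.tar", "ILSVRC2012_devkit_t12.tar.gz")


def missing_files(present_names: Iterable[str]) -> List[str]:
    """Return the required archive names that are not in present_names."""
    present = set(present_names)
    return [name for name in REQUIRED_FILES if name not in present]


def missing_imagenet_val_files(root: str) -> List[str]:
    """Check a directory for the ImageNet-1K validation archives."""
    root_path = Path(root)
    present = [p.name for p in root_path.iterdir()] if root_path.is_dir() else []
    return missing_files(present)


print(missing_imagenet_val_files("."))
```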

Dependencies

For the PyTorch environment, the recommended dependencies are as follows:

pip install torch==2.0.0
pip install torchvision==0.15.0
pip install transformers==4.40.2
pip install open-clip-torch==2.24.0

Reproducing Results on ImageNet-1K with PyTorch

| Model | PyTorch Accuracy (%) | JAX (Paper) Accuracy (%) | PyTorch Weights |
|---|---|---|---|
| CRATE-α-B/16 | 81.2 | 81.2 | Download |
| CRATE-α-L/14 | 83.9 | 83.9 | Download |
| CRATE-α-CLIPA-L/14 | 69.8 | 69.8 | Download |
| CRATE-α-CLIPA-H/14 | 72.3 | 72.3 | Download |

PyTorch Weights

Weights for the PyTorch models are available for download. Use the links provided in the table above.

Classification

To run the evaluation code, specify the path to the checkpoints and the ImageNet validation set in the eval_in1k_cls.py file.

python torch_inference/eval_in1k_cls.py
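The classification accuracies in the table above are assumed here to be standard top-1 accuracy: the fraction of validation images whose highest-scoring class matches the label. The function below is a minimal sketch of that metric over plain lists of logits; it is not taken from `eval_in1k_cls.py`, and its name is hypothetical.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of samples whose highest-scoring class equals the label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of logit rows and labels")
    correct = 0
    for row, label in zip(logits, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == label)
    return 100.0 * correct / len(labels)


print(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]], [1, 0, 0, 0]))
```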

Zero-Shot on ImageNet-1K

For the PyTorch version of CLIPA, we follow CLIP.

To run the evaluation code, specify the path to the checkpoints and the ImageNet validation set in the eval_in1k.py and clipa_model.py files. The default model is CRATE-α-CLIPA-L/14.

python torch_inference/eval_in1k.py
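In CLIP-style zero-shot classification, each class name is turned into a text prompt (for example "a photo of a {label}."), both image and text embeddings are L2-normalized, and each image is assigned the class with the highest cosine similarity. The sketch below shows only that final matching step on toy embeddings; the prompt template and function names are illustrative assumptions, and in practice the embeddings would come from the CRATE-α-CLIPA image and text encoders.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def zero_shot_predict(image_embs: Sequence[Sequence[float]],
                      class_text_embs: Sequence[Sequence[float]]) -> List[int]:
    """For each image, return the index of the most cosine-similar class embedding."""
    classes = [l2_normalize(t) for t in class_text_embs]
    preds = []
    for img in image_embs:
        img = l2_normalize(img)
        scores = [sum(a * b for a, b in zip(img, t)) for t in classes]
        preds.append(max(range(len(scores)), key=scores.__getitem__))
    return preds


# Toy 2-D embeddings: class 0 points along x, class 1 along y.
text = [[1.0, 0.0], [0.0, 1.0]]
images = [[0.9, 0.1], [0.2, 3.0], [5.0, 1.0]]
print(zero_shot_predict(images, text))
```

Normalizing first makes the prediction depend only on direction, not on embedding magnitude.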

Acknowledgement

This repo is built on big_vision and CLIPA. Many thanks to the awesome work from the open-source community!

We are also very grateful that this work is supported by a gift from Open Philanthropy, the TPU Research Cloud (TRC) program, and the Google Cloud Research Credits program.

Citation

@article{yang2024cratealpha,
  title   = {Scaling White-Box Transformers for Vision},
  author  = {Yang, Jinrui and Li, Xianhang and Pai, Druv and Zhou, Yuyin and Ma, Yi and Yu, Yaodong and Xie, Cihang},
  journal = {arXiv preprint arXiv:2405.20299},
  year    = {2024}
}

Contact

If you have any questions, please feel free to raise an issue or contact us directly: jyang347@ucsc.edu.
