
[ECCV 2024] AdaNAT: Exploring Adaptive Policy for Token-Based Image Generation


LeapLabTHU/AdaNAT


AdaNAT (ECCV2024)

This repo contains the official PyTorch implementation of AdaNAT: Exploring Adaptive Policy for Token-Based Image Generation.


Installation

We support PyTorch>=2.0.0 and torchvision>=0.15.1. Please install them following the official instructions.

Clone this repo and install the required packages:

git clone https://github.com/LeapLabTHU/AdaNAT
cd AdaNAT
pip install tqdm loguru numpy pandas pyyaml einops omegaconf Pillow accelerate xformers transformers ninja
  • Prepare FID stats: Download the FID stats from this link and put them in the assets/fid_stats directory.

  • Prepare the pre-trained Inception model for FID calculation: Download the pre-trained Inception model from this link and save it as assets/pt_inception-2015-12-05-6726825d.pth.

  • Prepare the pre-trained VQ tokenizer: Use this link (from MAGE, thanks!) to download the pre-trained VQGAN tokenizer and save it as assets/vqgan_jax_strongaug.ckpt.
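Before launching any job, it can help to confirm that the three assets are in place. The sketch below is a hypothetical helper (`check_assets` is not part of the repo) that checks only for the paths named above; it does not validate file contents or checksums.

```shell
# Hypothetical pre-flight check (not part of the repo): verifies that the
# assets described above exist at the expected paths.
check_assets() {
  local root="${1:-.}"   # repo root; defaults to the current directory
  local missing=0
  # The FID stats live in a directory; the other two assets are single files.
  if [ ! -d "$root/assets/fid_stats" ]; then
    echo "missing directory: $root/assets/fid_stats" >&2
    missing=1
  fi
  for f in assets/pt_inception-2015-12-05-6726825d.pth assets/vqgan_jax_strongaug.ckpt; do
    if [ ! -f "$root/$f" ]; then
      echo "missing file: $root/$f" >&2
      missing=1
    fi
  done
  return "$missing"   # 0 if everything was found, 1 otherwise
}
```

For example, running `check_assets . && echo "assets OK"` from the repo root prints the missing paths to stderr, or "assets OK" when all three are present.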

Class-conditional Generation on ImageNet-256

Data preparation

  • The ImageNet dataset should be prepared as follows:
data
├── train
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
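A quick way to catch a mis-organized dataset is to count the class folders under `train/`. This is a sketch under two assumptions: `--data_root` points at the `data` directory shown above, and the full ImageNet-1k training split (1000 classes) is used. `count_class_dirs` and `check_imagenet_layout` are hypothetical names.

```shell
# Hypothetical sanity check (not part of the repo). Assumes --data_root
# points at the "data" directory in the tree above, i.e. the one holding train/.

# Print the number of immediate subdirectories of train/ (one per class).
count_class_dirs() {
  local data_root="$1"
  find "$data_root/train" -mindepth 1 -maxdepth 1 -type d | wc -l | tr -d ' '
}

# Succeed only if train/ exists and holds the expected number of class folders.
check_imagenet_layout() {
  local data_root="$1" expected="${2:-1000}"  # ImageNet-1k has 1000 classes
  if [ ! -d "$data_root/train" ]; then
    echo "no train/ directory under $data_root" >&2
    return 1
  fi
  local n
  n=$(count_class_dirs "$data_root")
  if [ "$n" -ne "$expected" ]; then
    echo "found $n class folders in $data_root/train, expected $expected" >&2
    return 1
  fi
}
```

Usage: `check_imagenet_layout /path/of/ImageNet/dataset`. Stray files directly under `train/` are ignored because only directories are counted.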

Pre-trained Model & Evaluation

Download the pre-trained NAT model from this link. Download the pre-trained policy network from this link. Then run the following command for evaluation:

torchrun --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 train.py \
--data_root /path/of/ImageNet/dataset \
--config configs/imagenet_256_AdaNAT_L.yaml \
--state_dict_path /path/of/pretrained/imagenet/NAT/model \
--eval_paths /path/of/pretrained/policy/network \
--output_dir ./output/imagenet_256

Training the Policy Network

To train the policy network, run the following command:

torchrun --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 train.py \
--data_root /path/of/ImageNet/dataset \
--config configs/imagenet_256_AdaNAT_L.yaml \
--state_dict_path /path/of/pretrained/imagenet/NAT/model \
--max_training_timesteps 1000 \
--output_dir ./output/imagenet_256 

Text-to-image Generation on CC3M

Data preparation

Please refer to CC3M_data_preparation.md for data preparation. After preparing the dataset, the directory structure should be as follows (cc3m_val.tsv is provided in the assets directory):

data
├── cc3m_train.tsv
├── cc3m_val.tsv
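Since `cc3m_val.tsv` ships in the `assets` directory, it has to be placed next to `cc3m_train.tsv`. The hypothetical helper below (`prepare_cc3m_dir` is not part of the repo) copies it over if it is missing and checks that `cc3m_train.tsv` exists. It only looks at the two files in the tree above and does not validate their contents; a symlink would work just as well as a copy.

```shell
# Hypothetical helper (not part of the repo). Assumes it is run from the repo
# root and that --data_root will point at the given data directory.
prepare_cc3m_dir() {
  local data_root="$1" assets_dir="${2:-assets}"
  mkdir -p "$data_root"
  # cc3m_val.tsv is provided in the assets directory; copy it if absent.
  if [ ! -f "$data_root/cc3m_val.tsv" ]; then
    cp "$assets_dir/cc3m_val.tsv" "$data_root/" || return 1
  fi
  # cc3m_train.tsv must be produced by following CC3M_data_preparation.md.
  if [ ! -f "$data_root/cc3m_train.tsv" ]; then
    echo "missing $data_root/cc3m_train.tsv; see CC3M_data_preparation.md" >&2
    return 1
  fi
}
```

Usage: `prepare_cc3m_dir /path/of/CC3M/dataset`.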

Pre-trained Model & Evaluation

Download the pre-trained NAT model from this link. Download the pre-trained policy network from this link. Then run the following command for evaluation:

torchrun --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 train.py \
--dset cc3m \
--data_root /path/of/CC3M/dataset \
--config configs/cc3m_AdaNAT_muse.yaml \
--reference_image_path assets/fid_stats/fid_stats_cc3m_val.npz \
--c_dim 1280 \
--state_dict_path /path/of/pretrained/cc3m/NAT/model \
--eval_paths /path/of/pretrained/policy/network \
--n_samples 30000 \
--heu 1 \
--batch_size 64 \
--output_dir ./output/cc3m

Training the Policy Network

To train the policy network, run the following command:

torchrun --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 train.py \
--dset cc3m \
--data_root /path/of/CC3M/dataset \
--config configs/cc3m_AdaNAT_muse.yaml \
--reference_image_path assets/fid_stats/fid_stats_cc3m_val.npz \
--c_dim 1280 \
--state_dict_path /path/of/pretrained/cc3m/NAT/model \
--n_samples 30000 \
--heu 1 \
--batch_size 64 \
--output_dir ./output/cc3m \
--max_training_timesteps 150  # early stopping
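The two CC3M commands above differ only in `--eval_paths` (evaluation) and `--max_training_timesteps` (training). One way to keep the shared flags in a single place is a bash array. This is a convenience sketch, not something the repo provides; `CC3M_COMMON_ARGS` and `launch` are hypothetical names, and the paths are the same placeholders as above.

```shell
# Hypothetical wrapper (not part of the repo): flags shared by the CC3M
# evaluation and training commands. Replace the placeholder paths first.
CC3M_COMMON_ARGS=(
  --dset cc3m
  --data_root /path/of/CC3M/dataset
  --config configs/cc3m_AdaNAT_muse.yaml
  --reference_image_path assets/fid_stats/fid_stats_cc3m_val.npz
  --c_dim 1280
  --state_dict_path /path/of/pretrained/cc3m/NAT/model
  --n_samples 30000
  --heu 1
  --batch_size 64
  --output_dir ./output/cc3m
)

# Launch train.py with the shared flags plus any mode-specific ones in "$@".
launch() {
  torchrun --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 \
    train.py "${CC3M_COMMON_ARGS[@]}" "$@"
}

# Evaluation: launch --eval_paths /path/of/pretrained/policy/network
# Training:   launch --max_training_timesteps 150   # early stopping
```

Quoting `"${CC3M_COMMON_ARGS[@]}"` passes each element as a separate argument, so paths containing spaces survive intact. Arrays are a bash feature, so this will not work under a plain POSIX `sh`.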

Citation

If you find our work useful for your research, please consider citing:

@inproceedings{Ni2024AdaNAT,
  title={AdaNAT: Exploring Adaptive Policy for Token-Based Image Generation},
  author={Ni, Zanlin and Wang, Yulin and Zhou, Renping and Lu, Rui and Guo, Jiayi and Hu, Jinyi and Liu, Zhiyuan and Yao, Yuan and Huang, Gao},
  booktitle={ECCV},
  year={2024},
}

Acknowledgements

Our implementation is based on

We thank the authors for their excellent work.

Contact

If you have any questions, feel free to email nzl22@mails.tsinghua.edu.cn.
