Skip to content

Official PyTorch implementation of "On Inference Stability for Diffusion Models" (AAAI'24)

License

Notifications You must be signed in to change notification settings

VinAIResearch/SA-DPM

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

9 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Table of contents
  1. Installation
  2. Dataset preparation
  3. How to run
  4. Models and Hyperparameters
  5. Evaluation
  6. Acknowledgments
  7. Contacts

Official PyTorch implementation of "On Inference Stability for Diffusion Models" (AAAI'24)

Abstract: Denoising Probabilistic Models (DPMs) represent an emerging domain of generative models that excel in generating diverse and high-quality images. However, most current training methods for DPMs often neglect the correlation between timesteps, limiting the model's performance in generating images effectively. Notably, we theoretically point out that this issue can be caused by the cumulative estimation gap between the predicted and the actual trajectory. To minimize that gap, we propose a novel sequence-aware loss that aims to reduce the estimation gap to enhance the sampling quality. Furthermore, we theoretically show that our proposed loss function is a tighter upper bound of the estimation loss in comparison with the conventional loss in DPMs. Experimental results on several benchmark datasets including CIFAR10, CelebA, and CelebA-HQ consistently show a remarkable improvement of our proposed method regarding the image generalization quality measured by FID and Inception Score compared to several DPM baselines.

Details of algorithms and experimental results can be found in our following paper:

@inproceedings{nguyen2024inference,
  title={On Inference Stability for Diffusion Models},
  author={Viet Nguyen and Giang Vu and Tung Nguyen Thanh and Khoat Than and Toan Tran},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence}
  year={2024}
}

Please CITE our paper whenever this repository is used to help produce published results or incorporated into other software.

News

Installation

Python 3.8.0 and Pytorch 2.0.0 are used in this implementation.

Dataset preparation

We trained on five datasets, including CIFAR10, CelebA 64, FFHQ 64, AFHQ 64 and CelebA-HQ 256.

For CIFAR10, they will be automatically downloaded in the first time execution.

For FFHQ 64 and AFHQ 64, please download files here

For CelebA 64 and CelebA HQ 256, please check out here and here for dataset preparation.

Once a dataset is downloaded, please put it in exp/datasets/ directory as follows:

exp/datasets/
├── cifar10
├── celeba
├── celebahq
├── ffhq64
├── afhq

For FID score statistics of the datasets, please download files here and put it in pytorch_fid/ directory

How to run

We provide a bash script for our experiments on different datasets. The syntax is following:

python train_ddp.py --config <DATASET>.yml \
    --doc <MODEL_NAME> --ni \
    --num_consecutive_steps <K> --lamda <lamda> \
    --num_process_per_node <#GPUS>

where:

  • <DATASET>: cifar10, celeba, celebahq, afhq, and ffhq64.
  • <K>: the number of consecutive steps (e.g. 2, 3, 4).
  • <lamda>: the coefficient of SA loss (e.g. 0.2, 0.5).
  • <#GPUS>: the number of used GPUs (e.g. 1, 2, 4, 8).

Models and Hyperparameters

CelebA 64x64 with the SA_DPM loss [checkpoint]

python train_ddp.py --config celeba.yml --doc celeba_sa --ni --num_consecutive_steps 2 --lamda 1 --num_process_per_node 4

AFHQv2 64x64 with the base loss [checkpoint]

python train_ddp.py --config afhq.yml --doc afhq_simple --ni --num_consecutive_steps 0 --lamda 0 --num_process_per_node 4

AFHQv2 64x64 with the SA_DPM loss [checkpoint]

python train_ddp.py --config afhq.yml --doc afhq_sa --ni --num_consecutive_steps 2 --lamda 0.2 --num_process_per_node 4

FFHQ 64x64 with the base loss [checkpoint]

python train_ddp.py --config ffhq64.yml --doc ffhq_simple --ni --num_consecutive_steps 0 --lamda 0 --num_process_per_node 4

FFHQ 64x64 with the SA_DPM loss [checkpoint]

python train_ddp.py --config ffhq64.yml --doc ffhq_sa --ni --num_consecutive_steps 2 --lamda 0.5 --num_process_per_node 4

CelebA-HQ 256x256 with the base loss [checkpoint]

python train_ddp.py --config celebahq.yml --doc celebahq_simple --ni --num_consecutive_steps 0 --lamda 0 --num_process_per_node 4

CelebA-HQ 256x256 with the SA_DPM loss [checkpoint]

python train_ddp.py --config celebahq.yml --doc celebahq_sa --ni --num_consecutive_steps 2 --lamda 0.1 --num_process_per_node 4

Downloaded pre-trained models should be put in exp/logs/<MODEL_NAME> (e.g. <MODEL_NAME>: celeba_sa, afhq_simple, ...)

Evaluation

FID

python generate.py --config <DATASET>.yml --doc <MODEL_NAME> \
    --image_folder </path/to/save/images> --ckpt_id <ckpt_id> --num_samples 50000 \
    --timesteps <#STEPS> --eta <ETA> --ni --seeds=0-49999 \
    --model_ema --num_process_per_node <#GPUS> \
    --fid_log </path/to/file/saving/log>

where:

  • <DATASET>: cifar10, celeba, celebahq, afhq, and ffhq64.
  • <#STEPS>: the number of sampling steps (e.g. 10, 50, 100, 200, 1000).
  • <ckpt_id>: the id of the best checkpoint (e.g. 800, 850, 900, ...).
  • <ETA>: controls the scale of the variance (0 is DDIM, and 1 is one type of DDPM).
  • <#GPUS>: the number of used GPUs (e.g. 1, 2, 4, 8).

Example:

# Generate 50000 images of AFHQv2 dataset using DDIM sampling with 200 steps (4 GPUs)
python generate.py --config afhq.yml --doc afhq_sa --image_folder ./afhq_sa/1500_DDIM_T200 --ckpt_id 1500 --num_samples 50000 \
    --timesteps 200 --eta 0 --ni --seeds=0-49999 --model_ema --num_process_per_node 4 --fid_log fid_afhq_sa.txt

# Generate 10000 images of CelebA-HQ dataset using DDPM sampling with 100 steps (4 GPUs)
python generate.py --config celebahq.yml --doc celebahq_simple --image_folder ./celebahq_simple/700_DDPM_T100 --ckpt_id 700 --num_samples 10000 \
    --timesteps 100 --eta 1 --ni --seeds=0-9999 --model_ema --num_process_per_node 4 --fid_log fid_celebahq_simple.txt

Acknowledgments

This implementation is based on:

Contacts

If you have any problems, please open an issue in this repository or ping an email to viettmab123@gmail.com.

About

Official PyTorch implementation of "On Inference Stability for Diffusion Models" (AAAI'24)

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages