PyTorch implementation of Easy Consistency Tuning (ECT).
ECT unlocks state-of-the-art (SoTA) few-step generative abilities through a simple yet principled approach. With minimal tuning costs, ECT demonstrates promising early results and scales with training FLOPs and model sizes.
Try your own Consistency Models! You only need to fine-tune a bit. :D
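At its core, consistency tuning enforces self-consistency between two points on the same noising trajectory, starting from a pretrained diffusion model and shrinking the noise-level gap over training. Below is a minimal sketch of the idea, assuming an EDM-style denoiser `model(x, sigma)`; the `ect_loss` helper and the plain squared-error metric are illustrative, and the actual objective adds loss weighting and a gap schedule:

```python
import torch

def ect_loss(model, x0, sigma_t, sigma_r):
    """One consistency tuning step (illustrative sketch, not the repo API).

    Predictions at two noise levels on the same trajectory should agree:
    f(x0 + sigma_t * eps, sigma_t) is pulled toward the stop-gradient
    target f(x0 + sigma_r * eps, sigma_r), with sigma_r < sigma_t.
    ECT shrinks the gap sigma_t - sigma_r as tuning progresses.
    """
    eps = torch.randn_like(x0)                 # shared noise for both points
    pred = model(x0 + sigma_t * eps, sigma_t)  # prediction at higher noise
    with torch.no_grad():                      # stop-gradient target at lower noise
        target = model(x0 + sigma_r * eps, sigma_r)
    return (pred - target).pow(2).mean()
```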
This repository is organized in a multi-branch structure, with each branch offering a minimal implementation for a specific purpose. The current branches support the following training protocols:
- `main`: ECT on CIFAR-10. Best for understanding CMs and fast prototyping.
- `amp`: Mixed precision training with GradScaler on CIFAR-10.
- `imgnet`: ECT on ImageNet 64x64.
Baking more in the oven. 🙃
- 2024.10.12 - Add ECT code for ImageNet 64x64. Switch to the `imgnet` branch: `git checkout imgnet`.
- 2024.09.23 - Add GradScaler for mixed precision training. To use mixed precision with GradScaler, switch to the `amp` branch: `git checkout amp`.
- 2024.04.27 - Upgrade environment to PyTorch 2.3.0.
- 2024.04.12 - ECMs can now surpass SoTA GANs using 1 model step and SoTA diffusion models using 2 model steps on CIFAR-10. Checkpoints available.
Run the following command to set up the Python environment through conda. PyTorch 2.3.0 and Python 3.9.18 will be installed.
```bash
conda env create -f env.yml
```
Prepare the dataset in EDM's format. See a reference here.
Run the following command to tune your SoTA 2-step ECM and match Consistency Distillation (CD) within 1 A100 GPU hour.
```bash
bash run_ecm_1hour.sh 1 <PORT> --desc bs128.1hour
```
Run the following command to train ECT with batch size 128 for 200k iterations. Using 2 or 4 GPUs is recommended.
```bash
bash run_ecm.sh <NGPUs> <PORT> --desc bs128.200k
```
Replace `<NGPUs>` and `<PORT>` with the number of GPUs used for training and the port number for DDP sync.
In this branch, we enable fp16 training with AMP GradScaler for more stable training dynamics. To enable fp16 and GradScaler, add the following arguments to your script:
```bash
bash run_ecm_1hour.sh 1 <PORT> --desc bs128.1hour --fp16=True --enable_amp=True
```
For more information, please refer to this PR. Full support for Automatic Mixed Precision (AMP) will be added later.
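Under the hood, GradScaler-based fp16 training follows the standard `torch.cuda.amp` pattern. A minimal self-contained sketch (the toy model and data are placeholders, not the repo's training loop):

```python
import torch
from torch import nn

device = "cuda"
model = nn.Linear(64, 64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(128, 64, device=device)
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():    # forward pass in mixed precision
        loss = (model(x) - x).pow(2).mean()
    scaler.scale(loss).backward()      # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscales grads; skips the step on inf/NaN
    scaler.update()                    # adapts the scale factor for the next step
```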
Run the following command to calculate the FID of a pretrained checkpoint.
```bash
bash eval_ecm.sh <NGPUs> <PORT> --resume <CKPT_PATH>
```
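For reference, FID (and the $\mathrm{FD}_{\text{DINOv2}}$ metric reported below) is the Fréchet distance between Gaussians fitted to features of real and generated images; the evaluation script takes care of sampling and feature extraction. A minimal NumPy/SciPy sketch of the underlying formula:

```python
import numpy as np
from scipy import linalg

def frechet_distance(feats_real: np.ndarray, feats_gen: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (N, D) feature sets."""
    mu1, mu2 = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    cov1 = np.cov(feats_real, rowvar=False)
    cov2 = np.cov(feats_gen, rowvar=False)
    covmean = linalg.sqrtm(cov1 @ cov2).real   # matrix sqrt; drop tiny imaginary parts
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(cov1 + cov2 - 2.0 * covmean))
```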
Taking the models trained by ECT as ECMs, we compare their unconditional image generation capabilities with SoTA generative models on CIFAR-10, including popular diffusion models with advanced samplers, diffusion distillation methods, and consistency models.
Method | FID | NFE | Model | Params | Batch Size | Schedule
---|---|---|---|---|---|---
Score SDE | 2.38 | 2000 | NCSN++ | 56.4M | 128 | ~1600k
Score SDE-deep | 2.20 | 2000 | NCSN++ (2× deeper) | >100M | 128 | ~1600k
EDM | 2.01 | 35 | DDPM++ | 56.4M | 512 | 400k
PD | 8.34 | 1 | DDPM++ | 56.4M | 512 | 800k
Diff-Instruct | 4.53 | 1 | DDPM++ | 56.4M | 512 | 800k
CD (LPIPS) | 3.55 | 1 | NCSN++ | 56.4M | 512 | 800k
CD (LPIPS) | 2.93 | 2 | NCSN++ | 56.4M | 512 | 800k
iCT-deep | 2.51 | 1 | NCSN++ (2× deeper) | >100M | 1024 | 400k
iCT-deep | 2.24 | 2 | NCSN++ (2× deeper) | >100M | 1024 | 400k
ECM (100k) | 4.54 | 1 | DDPM++ | 55.7M | 128 | 100k
ECM (200k) | 3.86 | 1 | DDPM++ | 55.7M | 128 | 200k
ECM (400k) | 3.60 | 1 | DDPM++ | 55.7M | 128 | 400k
ECM (100k) | 2.20 | 2 | DDPM++ | 55.7M | 128 | 100k
ECM (200k) | 2.15 | 2 | DDPM++ | 55.7M | 128 | 200k
ECM (400k) | 2.11 | 2 | DDPM++ | 55.7M | 128 | 400k
Since DINOv2 can produce evaluations better aligned with human vision, we also evaluate image fidelity using the Fréchet Distance in the latent space of the SoTA open-source representation model DINOv2, denoted as $\mathrm{FD}_{\text{DINOv2}}$. Using dgm-eval, we obtain:
Method | $\mathrm{FD}_{\text{DINOv2}}$ | NFE
---|---|---
EDM | 145.20 | 35 |
StyleGAN-XL | 204.60 | 1 |
ECM | 198.51 | 1 |
ECM | 128.63 | 2 |
Without combining with other generative mechanisms such as GANs or diffusion distillation techniques like Score Distillation, ECT generates high-quality samples much faster than SoTA diffusion models and with higher fidelity than SoTA GANs.
- CIFAR10 $\mathrm{FD}_{\text{DINOv2}}$ checkpoint.
Feel free to drop me an email at zhengyanggeng@gmail.com if you have additional questions or are interested in collaboration. You can find me on Twitter or WeChat.
```bibtex
@article{ect,
  title={Consistency Models Made Easy},
  author={Geng, Zhengyang and Pokle, Ashwini and Luo, William and Lin, Justin and Kolter, J Zico},
  journal={arXiv preprint arXiv:2406.14548},
  year={2024}
}
```