
Codebase for "CtrLoRA: An Extensible and Efficient Framework for Controllable Image Generation"


xyfJASON/ctrlora


To create your customized ControlNet in an easy and low-cost manner 🎉

banner

style-transfer

The images are compressed for loading speed.

CtrLoRA: An Extensible and Efficient Framework for Controllable Image Generation
Yifeng Xu¹,², Zhenliang He¹, Shiguang Shan¹,², Xilin Chen¹,²
¹Key Lab of AI Safety, Institute of Computing Technology, CAS, China
²University of Chinese Academy of Sciences, China

base-conditions

We first train a Base ControlNet along with condition-specific LoRAs on base conditions using a large-scale dataset. Our Base ControlNet can then be adapted efficiently to novel conditions by training new LoRAs, with as few as 1,000 images and less than 1 hour on a single GPU.
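Because each new condition only adds a LoRA on top of the shared Base ControlNet, the per-condition cost stays small. As a rough illustration, here is a generic LoRA sketch in pure Python (not code from this repository; the helper names and toy matrices are hypothetical). A frozen weight `W` receives a low-rank update `B @ A`, which can be applied on the fly or merged into `W`, and only `r * (d_in + d_out)` parameters are trained instead of `d_in * d_out`.

```python
# Generic LoRA sketch in pure Python (hypothetical helpers, not the CtrLoRA code).
# A frozen weight W (d_out x d_in) is adapted by a low-rank update B @ A,
# where B is d_out x r and A is r x d_in, with r much smaller than d_in and d_out.


def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def matvec(M, v):
    """Multiply matrix M by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def lora_forward(W, A, B, x, scale=1.0):
    """Compute W x + scale * B (A x) without materializing B @ A."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + scale * u for b, u in zip(base, update)]


def merge_lora(W, A, B, scale=1.0):
    """Fold the low-rank update into a single weight: W + scale * (B @ A)."""
    BA = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]


def lora_param_count(d_out, d_in, r):
    """Return (trainable LoRA parameters, full-weight parameters) for one layer."""
    return r * (d_in + d_out), d_in * d_out


if __name__ == "__main__":
    W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # frozen 2x3 weight
    A = [[1.0, 1.0, 0.0]]                     # rank-1 down-projection (1x3)
    B = [[0.5], [-1.0]]                       # rank-1 up-projection (2x1)
    x = [1.0, 2.0, 3.0]
    print(lora_forward(W, A, B, x))           # on-the-fly LoRA
    print(matvec(merge_lora(W, A, B), x))     # merged weight gives the same output
    print(lora_param_count(1280, 1280, 128))
```

For a hypothetical 1280×1280 layer, a rank-128 LoRA has 327,680 trainable parameters versus 1,638,400 for the full weight (20%); the actual layer shapes used in CtrLoRA are not listed in this README.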

📜 Content

🎨 Visual Results

🎨 Controllable generation on "base conditions"

base-conditions

🎨 Controllable generation on "novel conditions"

novel-conditions

🎨 Integration into community models & Multi-conditional generation

integration

🎨 Application to style transfer

style-transfer

🛠️ Installation

Clone this repo:

git clone --depth 1 https://github.com/xyfJASON/ctrlora.git
cd ctrlora

Create and activate a new conda environment:

conda create -n ctrlora python=3.10
conda activate ctrlora

Install PyTorch and other dependencies:

pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt
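To confirm that the pinned PyTorch packages were installed as intended, a standard-library sketch like the following can be used. `check_pinned` is a hypothetical helper built on `importlib.metadata`; it compares only the release number and ignores the `+cu117` local suffix, so it does not verify which CUDA build was installed.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the install command above (hypothetical check, not part of the repo).
# The "+cu117" local tag is stripped so that only the release number is compared.
PINNED = {"torch": "1.13.1", "torchvision": "0.14.1", "torchaudio": "0.13.1"}


def check_pinned(pinned=PINNED, get_version=version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pinned()
    print("all pinned versions found" if not problems else problems)
```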

🤖️ Download Pretrained Models

We provide our pretrained models here. Please put the Base ControlNet (ctrlora_sd15_basecn700k.ckpt) into ./ckpts/ctrlora-basecn and the LoRAs into ./ckpts/ctrlora-loras. The naming convention of the LoRAs is ctrlora_sd15_<basecn>_<condition>.ckpt for base conditions and ctrlora_sd15_<basecn>_<condition>_<images>_<steps>.ckpt for novel conditions.
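The naming convention can also be checked programmatically. The sketch below is a hypothetical helper (not part of the repository) that splits a LoRA filename into its fields. It assumes, based on the example filenames in the Python API section, that novel-condition names may also carry a `rank<r>` field, that `<images>` and `<steps>` look like `1kimgs` and `1ksteps`, and that condition names may themselves contain underscores (e.g. `inpainting_brush`).

```python
import re

# Hypothetical parser for LoRA checkpoint names. The field patterns are assumptions
# inferred from example names such as
# ctrlora_sd15_basecn700k_inpainting_brush_rank128_1kimgs_1ksteps.ckpt
_NOVEL = re.compile(
    r"^ctrlora_sd15_(?P<basecn>basecn[^_]+)_(?P<condition>.+?)"
    r"(?:_(?P<rank>rank\d+))?_(?P<images>\d+k?imgs)_(?P<steps>\d+k?steps)\.ckpt$"
)
_BASE = re.compile(r"^ctrlora_sd15_(?P<basecn>basecn[^_]+)_(?P<condition>.+)\.ckpt$")


def parse_lora_name(filename):
    """Return a dict of name fields, or None if the name matches neither pattern."""
    m = _NOVEL.match(filename)
    if m:
        return {"kind": "novel", **m.groupdict()}
    m = _BASE.match(filename)
    if m:
        return {"kind": "base", **m.groupdict()}
    return None


if __name__ == "__main__":
    print(parse_lora_name("ctrlora_sd15_basecn700k_inpainting_brush_rank128_1kimgs_1ksteps.ckpt"))
    # Hypothetical base-condition name, used only to illustrate the shorter pattern.
    print(parse_lora_name("ctrlora_sd15_basecn700k_canny.ckpt"))
```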

You also need to download the SD1.5-based Models and put them into ./ckpts/sd15. Models used in our work:

🚀 Gradio Demo

python app/gradio_ctrlora.py

Generating a batch of one or four 512x512 images requires at least 9 GB or 21 GB of GPU memory, respectively.

🚀 Single-conditional generation

  1. select the Stable Diffusion checkpoint, Base ControlNet checkpoint, and LoRA checkpoint.
  2. write prompts and negative prompts. We provide several commonly used prompts.
  3. prepare a condition image:
    • upload an image to the left of the "Condition" panel, select the preprocessor corresponding to the LoRA, and click "Detect".
    • or upload the condition image directly, select the "none" preprocessor, and click "Detect".
  4. click "Run" to generate images.
  5. if you upload any new checkpoints, restart Gradio or click "Refresh".

gradio

🚀 Multi-conditional generation

gradio2

🚀 Application to style transfer

  1. select a stylized Stable Diffusion checkpoint to specify the target style, e.g., Pixel.
  2. select the Base ControlNet checkpoint.
  3. select palette for the LoRA1 checkpoint and lineart for the LoRA2 checkpoint.
    • palette + canny or palette + hed also work; there may be more interesting combinations to discover.
  4. write prompts and negative prompts.
  5. upload the source image to the "Condition 1" panel, select the "none" preprocessor, and click "Detect".
  6. upload the source image to the "Condition 2" panel, select the "lineart" preprocessor, and click "Detect".
  7. adjust the weights for the two conditions in the "Basic options" panel.
  8. click "Run" to generate images.

gradio3

🚗 Python API

Besides the Gradio demo, you can also sample images with the following Python code.

🚗 Single-conditional generation

from api import CtrLoRA

ctrlora = CtrLoRA(num_loras=1)
ctrlora.create_model(
    sd_file='ckpts/sd15/v1-5-pruned.ckpt',
    basecn_file='ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt',
    lora_files='ckpts/ctrlora-loras/novel-conditions/ctrlora_sd15_basecn700k_inpainting_brush_rank128_1kimgs_1ksteps.ckpt',
)
samples = ctrlora.sample(
    cond_image_paths='assets/test_images/inpaint_cat.png',
    prompt='A cat wearing a brown cowboy hat, best quality',
    n_prompt='worst quality',
    num_samples=1,
)
samples[0].show()

🚗 Multi-conditional generation

from api import CtrLoRA

ctrlora = CtrLoRA(num_loras=2)
ctrlora.create_model(
    sd_file='ckpts/sd15/v1-5-pruned.ckpt',
    basecn_file='ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt',
    lora_files=('ckpts/ctrlora-loras/novel-conditions/ctrlora_sd15_basecn700k_lineart_rank128_1kimgs_1ksteps.ckpt',
                'ckpts/ctrlora-loras/novel-conditions/ctrlora_sd15_basecn700k_palette_rank128_100kimgs_100ksteps.ckpt'),
)
samples = ctrlora.sample(
    cond_image_paths=('assets/test_images/lineart_bird.png',
                      'assets/test_images/palette_bird.png'),
    prompt='Photo of a parrot, best quality',
    n_prompt='worst quality',
    num_samples=1,
    lora_weights=(1.0, 1.0),
)
samples[0].show()

🔥 Train a LoRA for Your Custom Condition

Based on our Base ControlNet, you can train a LoRA for your custom condition with as few as 1,000 images and less than 1 hour on a single GPU with 20 GB of memory.

First, download the Stable Diffusion v1.5 (v1-5-pruned.ckpt) into ./ckpts/sd15 and the Base ControlNet (ctrlora_sd15_basecn700k.ckpt) into ./ckpts/ctrlora-basecn as described above.
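Before training, a quick check that both checkpoints sit where the training command expects them can save a failed launch. This is a hypothetical helper using `pathlib`; the paths are the ones given in this README, relative to the repository root.

```python
from pathlib import Path

# Expected checkpoint locations from the download instructions (hypothetical check).
REQUIRED_CKPTS = [
    Path("ckpts/sd15/v1-5-pruned.ckpt"),
    Path("ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt"),
]


def missing_checkpoints(root=".", required=REQUIRED_CKPTS):
    """Return the required checkpoint paths (relative to root) that are not files."""
    root = Path(root)
    return [str(p) for p in required if not (root / p).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("missing checkpoints:", ", ".join(missing))
    else:
        print("all checkpoints found")
```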

Second, put your custom data into ./data/<custom_data_name> with the following structure:

data
└── custom_data_name
    ├── prompt.json
    ├── source
    │   ├── 0000.jpg
    │   ├── 0001.jpg
    │   └── ...
    └── target
        ├── 0000.jpg
        ├── 0001.jpg
        └── ...
  • source contains condition images, such as canny edges, segmentation maps, depth images, etc.
  • target contains ground-truth images corresponding to the condition images.
  • each line of prompt.json should be a JSON object of the form {"source": "source/0000.jpg", "target": "target/0000.jpg", "prompt": "The quick brown fox jumps over the lazy dog."}.
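A sketch of generating prompt.json for this layout, assuming that each condition image in source has a ground-truth image with the same file name in target, and writing one JSON object per line as described above. `write_prompt_json` and its `prompts` mapping are hypothetical; where the captions come from (manual labels, a captioning model, etc.) is up to you.

```python
import json
from pathlib import Path


def write_prompt_json(data_dir, prompts, default_prompt=""):
    """Write data_dir/prompt.json with one {"source", "target", "prompt"} object per line.

    prompts maps an image file name (e.g. "0000.jpg") to its caption; images without
    an entry get default_prompt. Only file names present in both source/ and target/
    are written, in sorted order. Returns the number of lines written.
    """
    data_dir = Path(data_dir)
    sources = {p.name for p in (data_dir / "source").iterdir() if p.is_file()}
    targets = {p.name for p in (data_dir / "target").iterdir() if p.is_file()}
    paired = sorted(sources & targets)
    with open(data_dir / "prompt.json", "w", encoding="utf-8") as f:
        for name in paired:
            record = {
                "source": f"source/{name}",
                "target": f"target/{name}",
                "prompt": prompts.get(name, default_prompt),
            }
            f.write(json.dumps(record) + "\n")
    return len(paired)
```

Files present in only one of the two folders are skipped, so it is worth comparing the returned count with the number of image pairs you expect.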

Third, run the following command to train the LoRA for your custom condition:

python scripts/train_ctrlora_finetune.py \
    --dataroot ./data/<custom_data_name> \
    --config ./configs/ctrlora_finetune_sd15_rank128.yaml \
    --sd_ckpt ./ckpts/sd15/v1-5-pruned.ckpt \
    --cn_ckpt ./ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt \
    [--name NAME] \
    [--max_steps MAX_STEPS]
  • --dataroot: path to the custom data.
  • --name: name of the experiment. The logging directory will be ./runs/name. Default: current time.
  • --max_steps: maximum number of training steps. Default: 100000.
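If you launch training from a Python script, the command above can be assembled as an argument list. This is a hypothetical helper that uses only the flags documented here. Note that the example novel-condition LoRAs are named `..._1kimgs_1ksteps`, which suggests a much shorter schedule than the default of 100000 steps; that is an inference from the filenames, not a documented recommendation.

```python
import shlex


def train_command(custom_data_name, name=None, max_steps=None):
    """Build the argv list for scripts/train_ctrlora_finetune.py (hypothetical helper)."""
    cmd = [
        "python", "scripts/train_ctrlora_finetune.py",
        "--dataroot", f"./data/{custom_data_name}",
        "--config", "./configs/ctrlora_finetune_sd15_rank128.yaml",
        "--sd_ckpt", "./ckpts/sd15/v1-5-pruned.ckpt",
        "--cn_ckpt", "./ckpts/ctrlora-basecn/ctrlora_sd15_basecn700k.ckpt",
    ]
    if name is not None:
        cmd += ["--name", name]  # logs go to ./runs/<name>
    if max_steps is not None:
        cmd += ["--max_steps", str(max_steps)]
    return cmd


if __name__ == "__main__":
    # Print a shell-ready command; to launch it instead, pass the list to
    # subprocess.run(cmd, check=True).
    print(shlex.join(train_command("my_condition", name="my_condition", max_steps=1000)))
```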

After training, extract the LoRA weights with the following command:

python scripts/tool_extract_weights.py -t lora --ckpt CHECKPOINT --save_path SAVE_PATH
  • --ckpt: path to the checkpoint produced by the above training.
  • --save_path: path to save the extracted LoRA weights.

Finally, put the extracted LoRA into ./ckpts/ctrlora-loras and use it in the Gradio demo.
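The extraction step and the LoRA naming convention can be combined in a small helper. The sketch below is hypothetical: `lora_filename` mirrors the novel-condition example names used in the Python API section (including the `rank128` field, which matches the `ctrlora_finetune_sd15_rank128.yaml` config), and `extract_command` only assembles the documented `tool_extract_weights.py` arguments. This README does not say whether the Gradio demo depends on the naming, so treat it as a convention rather than a requirement.

```python
from pathlib import Path


def _k(n):
    """Format a count like the example filenames do, e.g. 1000 -> '1k', 1500 -> '1500'."""
    return f"{n // 1000}k" if n >= 1000 and n % 1000 == 0 else str(n)


def lora_filename(condition, num_images, num_steps, rank=128, basecn="basecn700k"):
    """Novel-condition LoRA name modeled on the example checkpoints (assumed format)."""
    return (f"ctrlora_sd15_{basecn}_{condition}_rank{rank}_"
            f"{_k(num_images)}imgs_{_k(num_steps)}steps.ckpt")


def extract_command(checkpoint, save_path):
    """Argument list for the weight-extraction command shown above."""
    return ["python", "scripts/tool_extract_weights.py", "-t", "lora",
            "--ckpt", str(checkpoint), "--save_path", str(save_path)]


if __name__ == "__main__":
    save_path = Path("ckpts/ctrlora-loras") / lora_filename("my_condition", 1000, 1000)
    # Replace CHECKPOINT with the checkpoint produced by training, then pass the
    # list to subprocess.run(cmd, check=True).
    print(extract_command("CHECKPOINT", save_path))
```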

📚 Detailed Instructions

Please refer to the instructions here for more details of training, fine-tuning, and evaluation.

🪧 Acknowledgement

This project is built upon Stable Diffusion, ControlNet, and UniControl. Thanks for their great work!

🖋️ Citation

If you find this project helpful, please consider citing:

@article{xu2024ctrlora,
  title={CtrLoRA: An Extensible and Efficient Framework for Controllable Image Generation},
  author={Xu, Yifeng and He, Zhenliang and Shan, Shiguang and Chen, Xilin},
  journal={arXiv preprint arXiv:2410.09400},
  year={2024}
}
