Minimal Unconditional Diffusion Model Trainer

This repository provides a clean, minimal PyTorch training pipeline for an unconditional diffusion model built around a UNet2DModel from the Hugging Face diffusers library. The script is designed for fast prototyping and single-class image generation, and supports mixed precision, safe image loading, basic augmentations, and periodic sample export.
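
For reference, a UNet2DModel for this kind of single-class setup might be configured roughly as below. The sizes and block types here are illustrative assumptions, not the exact values hard-coded in train.py:

    from diffusers import UNet2DModel

    # Illustrative configuration; train.py may use different sizes and block types.
    model = UNet2DModel(
        sample_size=64,                # resolution of the training images
        in_channels=3,                 # RGB input
        out_channels=3,                # predicted noise has the same shape as the input
        layers_per_block=2,
        block_out_channels=(128, 256, 512),
        down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
    )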


Features

  • SafeImageFolder: Handles corrupted images robustly.
  • Flexible Augmentation: Includes random flip and color jitter.
  • Mixed Precision Training: Uses torch.amp for faster, memory-efficient training; a combined sketch follows this list.
  • Cosine Annealing Scheduler: Smoothly anneals the learning rate over the course of training.
  • Progress Image Saving: Periodically generates and saves sample images during training.
  • Model Checkpointing: Saves final model and scheduler states for easy resumption.
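
The mixed precision and cosine annealing features above typically combine as in the following sketch. Names such as model, dataloader, learning_rate, and num_epochs stand in for the script's actual objects, and DDPMScheduler is an assumption about the noise scheduler used:

    import torch
    import torch.nn.functional as F
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from diffusers import DDPMScheduler

    # Placeholders: `model` is the UNet2DModel and `dataloader` yields
    # (image, label) batches; both are assumed to be defined elsewhere.
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = torch.amp.GradScaler("cuda")

    for epoch in range(num_epochs):
        for images, _ in dataloader:
            images = images.to("cuda")
            noise = torch.randn_like(images)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (images.shape[0],), device=images.device,
            )
            noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

            # Forward pass and loss under autocast for mixed precision.
            with torch.amp.autocast("cuda"):
                noise_pred = model(noisy_images, timesteps).sample
                loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        lr_scheduler.step()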

Requirements

  • Python 3.8+
  • PyTorch
  • torchvision
  • diffusers
  • tqdm
  • Pillow

You can install the dependencies using:

pip install torch torchvision diffusers tqdm pillow

Usage

  1. Prepare your dataset:

    • Place all training images in a folder named images in the project root.
    • Only image files (e.g., .jpg, .png) should be in this folder.
  2. Run the training script:

    python train.py
    • Training progress and sample images are saved in the progress/ directory.
    • The final model and scheduler weights are saved to unet_final/ and scheduler_final/; a reloading sketch follows below.
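
Assuming the script exports these checkpoints with diffusers' save_pretrained, the final weights can be reloaded and sampled from roughly as follows (DDPMScheduler is an assumption about the scheduler class):

    import torch
    from diffusers import DDPMScheduler, UNet2DModel

    # Reload the exported weights from the training run.
    unet = UNet2DModel.from_pretrained("unet_final")
    scheduler = DDPMScheduler.from_pretrained("scheduler_final")
    scheduler.set_timesteps(scheduler.config.num_train_timesteps)

    # Generate one image by iteratively denoising pure Gaussian noise.
    unet.eval()
    sample = torch.randn(
        1, unet.config.in_channels,
        unet.config.sample_size, unet.config.sample_size,
    )
    for t in scheduler.timesteps:
        with torch.no_grad():
            noise_pred = unet(sample, t).sample
        sample = scheduler.step(noise_pred, t, sample).prev_sample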

Training Script Structure

  • SafeImageFolder: Custom Dataset for robust image loading; a minimal sketch follows this list.
  • train_diffusion(): Handles all training logic, including data loading, the training loop, and checkpointing.
  • save_sample(): Denoises and saves generated samples at regular intervals.
  • main(): Entry point that sets up the environment and triggers training.
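
The SafeImageFolder idea can be sketched as an ImageFolder subclass that traps loading errors and falls back to a neighboring sample; the class in train.py may recover differently:

    from torchvision.datasets import ImageFolder

    class SafeImageFolder(ImageFolder):
        # Skip over samples whose image files are corrupted or unreadable.
        def __getitem__(self, index):
            try:
                return super().__getitem__(index)
            except (OSError, ValueError):
                # Fall back to the next sample, wrapping around at the end.
                return self.__getitem__((index + 1) % len(self))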

Customization

  • Hyperparameters: Adjust image_size, batch_size, num_epochs, and learning_rate directly in the script.
  • Model Architecture: Modify the UNet2DModel parameters for deeper, wider, or more complex networks.
  • Data Augmentation: Tune or expand the transform pipeline for your dataset; a sample pipeline is sketched after this list.
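
A flip-plus-color-jitter pipeline like the one described above could look as follows; the jitter strengths here are placeholders, and image_size is the hyperparameter set in the script:

    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale to [-1, 1]
    ])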

Citation

This repository uses Hugging Face Diffusers and PyTorch. If you use this codebase, please consider citing those libraries.


License

This project is licensed under the MIT License.
