This repository provides a clean, minimal PyTorch training pipeline for an unconditional diffusion model using a custom UNet2DModel from the HuggingFace diffusers library. The script is designed for fast prototyping and single-class image generation, supporting mixed precision, safe image loading, basic augmentations, and periodic sample export.
- SafeImageFolder: Handles corrupted images robustly.
- Flexible Augmentation: Includes random flip and color jitter.
- Mixed Precision Training: Uses torch.amp for faster, memory-efficient training.
- Cosine Annealing Scheduler: Smoothly anneals the learning rate.
- Progress Image Saving: Periodically generates and saves sample images during training.
- Model Checkpointing: Saves final model and scheduler states for easy resumption.
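To make the cosine annealing behaviour concrete, here is a minimal sketch of the standard closed-form schedule in plain Python. The function name and its parameters are illustrative and are not taken from train.py; the script itself presumably relies on a PyTorch scheduler such as torch.optim.lr_scheduler.CosineAnnealingLR, which follows the same curve up to its final step.

```python
import math


def cosine_annealed_lr(step, total_steps, base_lr, min_lr=0.0):
    """Learning rate at `step` under cosine annealing.

    Illustrative helper (not part of train.py). Unlike PyTorch's
    scheduler, which keeps following the cosine past the last step,
    this sketch clamps the rate at `min_lr` once `total_steps` is reached.
    """
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    progress = min(max(step, 0), total_steps) / total_steps  # in [0, 1]
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

The rate starts at base_lr, passes through the midpoint between base_lr and min_lr halfway through training, and ends at min_lr.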
- Python 3.8+
- PyTorch
- torchvision
- diffusers
- tqdm
- Pillow
You can install the dependencies using:
pip install torch torchvision diffusers tqdm pillow
- Prepare your dataset:
  - Place all training images in a folder named images in the project root.
  - Only images (e.g., .jpg, .png) should be in this folder.
- Run the training script:
  python train.py
  - Training progress and sample images will be saved in the progress/ directory.
  - The final model and scheduler weights are saved as unet_final/ and scheduler_final/.
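Since the images folder should contain only image files, a quick check before training can catch stray files early. The helper below is a hypothetical sketch that is not part of this repository, and the set of allowed extensions is an assumption based on the examples given above.

```python
from pathlib import Path

# Assumption: the extensions the training script accepts are not listed
# exhaustively in this README; adjust this set to match your data.
ALLOWED_SUFFIXES = {".jpg", ".jpeg", ".png"}


def find_stray_files(folder="images"):
    """Return sorted names of entries that are not plain image files.

    Hypothetical helper: subdirectories and files with other extensions
    are both reported, because the folder is expected to be flat.
    """
    root = Path(folder)
    return sorted(
        entry.name
        for entry in root.iterdir()
        if entry.is_dir() or entry.suffix.lower() not in ALLOWED_SUFFIXES
    )
```

Calling find_stray_files("images") from the project root returns an empty list when the folder is clean. Note that this only inspects file names; it does not detect corrupted images.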
- SafeImageFolder: Custom Dataset for robust image loading.
- train_diffusion(): Handles all training logic, including loading data, training loop, and checkpointing.
- save_sample(): Denoises and saves generated samples at regular intervals.
- main(): Entry point that sets up the environment and triggers training.
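This README does not spell out how SafeImageFolder copes with corrupted files, so the following is only a sketch of one common strategy: when an image fails to load, fall back to the next readable one. The class name SafeImageFolderSketch, the loader argument, and the extension set are all assumptions made for illustration. The class implements the map-style dataset protocol (__len__ and __getitem__), which torch.utils.data.DataLoader accepts.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumption, adjust as needed


def pil_loader(path):
    """Load an image as RGB with Pillow (imported lazily)."""
    from PIL import Image

    with Image.open(path) as img:
        return img.convert("RGB")


class SafeImageFolderSketch:
    """Illustrative stand-in for SafeImageFolder, not the repository's code."""

    def __init__(self, root, transform=None, loader=pil_loader):
        self.paths = sorted(
            p for p in Path(root).iterdir() if p.suffix.lower() in IMAGE_SUFFIXES
        )
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Try the requested file first, then walk forward (wrapping around)
        # until some file loads successfully.
        for offset in range(len(self.paths)):
            path = self.paths[(index + offset) % len(self.paths)]
            try:
                image = self.loader(path)
            except (OSError, ValueError):
                continue  # unreadable or corrupted file: try the next one
            return self.transform(image) if self.transform else image
        raise RuntimeError("no readable images found")
```

Falling back to a neighbouring image keeps batch sizes constant at the cost of occasionally repeating a sample; filtering unreadable files once at start-up is an alternative if exact sampling matters.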
- Hyperparameters: Adjust image_size, batch_size, num_epochs, learning_rate directly in the script.
- Model Architecture: Modify UNet2DModel parameters for deeper, wider, or more complex networks.
- Data Augmentation: Tune or expand the transform pipeline for your dataset.
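When changing the architecture, the main constraint is that the per-level settings passed to UNet2DModel stay consistent with each other. The sketch below builds a keyword-argument dictionary from a list of channel widths; the helper name unet_kwargs and all values are hypothetical and do not reflect the configuration used in train.py. The keyword names themselves (sample_size, block_out_channels, down_block_types, up_block_types, and so on) are UNet2DModel constructor arguments in diffusers.

```python
def unet_kwargs(image_size, widths, attention_from=None):
    """Build consistent keyword arguments for diffusers.UNet2DModel.

    Hypothetical helper. `widths` gives the channel count per resolution
    level; levels with index >= `attention_from` use attention blocks.
    """
    n = len(widths)
    attention_from = n if attention_from is None else attention_from
    # The model downsamples once between consecutive levels.
    if image_size % (2 ** (n - 1)) != 0:
        raise ValueError("image_size must be divisible by 2 ** (len(widths) - 1)")
    # Assumes the default of 32 group-norm groups is left unchanged.
    if any(w % 32 != 0 for w in widths):
        raise ValueError("each width should be a multiple of 32")
    down = tuple(
        "AttnDownBlock2D" if i >= attention_from else "DownBlock2D" for i in range(n)
    )
    # Up blocks mirror the down blocks in reverse order.
    up = tuple(
        "AttnUpBlock2D" if i < n - attention_from else "UpBlock2D" for i in range(n)
    )
    return dict(
        sample_size=image_size,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=tuple(widths),
        down_block_types=down,
        up_block_types=up,
    )


# Usage (requires diffusers):
#   from diffusers import UNet2DModel
#   model = UNet2DModel(**unet_kwargs(64, (64, 128, 256, 256), attention_from=2))
```

Adding entries to widths makes the network deeper, larger values make it wider, and lowering attention_from adds attention at higher resolutions, which increases memory use.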
This repository uses HuggingFace Diffusers and PyTorch. If you use this codebase, please consider citing the respective libraries.
This project is licensed under the MIT License.