Skip to content

Twilight92z/DDPM_pytorch

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

DDPM in Pytorch

Denoising Diffusion Probablistic Model的Pytorch实现。这是一种新的生成模型方法,过程如图所示:

img

介绍

原理参考苏剑林的博客:

代码实现参考:

说明

使用CIFAR10训练数据集训练80个epoch,使用Adamw优化器,学习率5e-5,结果如图 img

1. 训练

config.yml文件中修改训练所需要的参数

# model_path is the path to save checkpoint
python main.py --mode train --model_path model.pth

2. 推理

config.yml文件中修改推理所需要的参数,模型配置参数要与训练时所用参数一致

# output_path is the path to save output picture, and the nums is the batch_size of generation
python main.py --mode predict --nums 10 --output_path result.png

环境

Package Version
opencv-python 4.9.0.80
torch 1.13.1

About

DDPM in Pytorch

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages