English | 中文
[TOC]
该模型是一个盲人脸修复模型。作者将前人提出的 StyleGAN V2 的解码器嵌入模型,作为GPEN的解码器;用DNN重新构建了一种简单的编码器,为解码器提供输入。这样模型在保留了 StyleGAN V2 解码器优秀的性能的基础上,将模型的功能由图像风格转换变为了盲人脸修复。模型的总体结构如下图所示:
该模型的总体结构如上图所示。左边的区域为生成器(Generator)的结构,绿色部分为编码器,中间的Mapping Network用于特征映射,这两部分为作者添加的结构。蓝色的部分为解码器,const为作者添加的噪声输入。最右侧的为对抗网络的鉴别器。其中解码器、鉴别器的结构与StyleGAN V2一致。
本项目也集成到了百度飞浆AI Studio中,可更快进行复现。
地址:https://aistudio.baidu.com/aistudio/projectdetail/3936241?contributionType=1
GPEN模型训练集是经典的FFHQ人脸数据集,共70000张1024 x 1024高分辨率的清晰人脸图片,测试集是CELEBA-HQ数据集,共2000张高分辨率人脸图片。详细信息可以参考数据集网址: FFHQ ,CELEBA-HQ 。以下给出了具体的下载链接:
数据集下载地址:
FFHQ : https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_open
下载后,文件组织形式如下
|-- data/GPEN
|-- train
|-- 00000
|-- 00000.png
|-- 00001.png
|-- ......
|-- 00999.png
|-- 01000
|-- ......
|-- ......
|-- 69000
|-- ......
|-- 69999.png
|-- test
|-- 2000张png图片
模型参数文件及训练日志下载地址:
链接: 提取码:
从链接中下载模型参数,并放到项目根目录下的data文件夹下,这样data下有个GPEN文件夹,FFA文件夹下包含四个模型参数文件以及一个参考训练日志文件,具体文件结构如下所示:
文件结构
data/GPEN
|-- model_ir_se50_2.pdparams #复现的模型经过400000step训练后得到的室内去雾模型的参数文件
|-- G_256_repo.pdparams #复现的256分辨率生成器的的模型参数文件
|-- D_256_repo.pdparams #复现的256分辨率鉴别器的的模型参数文件
|-- G_256_article_pretrained.pdparams #作者提供的256分辨率的生成器的模型参数文件
|-- train.log #完整的训练日志文件
注:由于paddle.grad存在问题,还未进行完整的训练过程,以下提供一个临时链接,
链接中的模型参数文件与G_256_article_pretrained.pdparams相同,可更名后进行测试。
经过测试,训练512 x 512分辨率的模型所需的时间是256 x 256分辨率模型的三倍,因此在这里我只复现了256 x 256 分辨率的GPEN模型,来验证复现的效果。当然该项目也支持复现512分辨率的模型,只需修改配置文件中模型的size参数为512即可。256分辨率模型的复现精度如下表所示:
CELE数据集测试精度(256 x 256 分辨率):
Backbone | Train dataset | Test dataset | FID | PSNR | checkpoints |
---|---|---|---|---|---|
GPEN | FFHQ | CELEBA-HQ | 123.48 | 21.85 | G_256_article_pretrained.pdparams |
在控制台输入以下代码,开始训练:
python train.py --size=256 --mul=1 --narrow=0.5 --start_iter=0 --max_iter=15000 --batch_size=2 --pretrain=None --train_path='data/GPEN/train/' --test_path='data/GPEN/test/'
训练过程中会在ckpts/文件夹下生成train.log文件夹,用于保存训练日志。
如果你想要在其他数据集上训练网络,在 configs/GPEN.yaml 中修改数据集的路径。
如果要修改模型的参数,可修改的参数主要是size、mul和narrow,决定了模型针对的图片的分辨率,推荐选用256或512。如果要将size改为512,则要同时将mul改为2、narrow改为1,下面的操作步骤也是这样。
如果要改变训练的step数,需要修改max_iter 参数。
如果要从训练断点继续训练,则修改pretrain参数为模型参数文件位置,并根据需要修改start_iter来保证训练日志的延续性。
模型只支持单卡训练。
模型训练需使用paddle2.3及以上版本,且需等paddle实现elementwise_pow 的二阶算子相关功能,使用paddle2.2.2版本能正常运行,但因部分损失函数会求出错误梯度,导致模型无法训练成功。如训练时报错则暂不支持进行训练,可跳过训练部分,直接使用提供的模型参数进行测试。模型评估和测试使用paddle2.2.2及以上版本即可。
对模型进行评估时,在控制台输入以下代码,下面代码中使用上面提到的下载的模型参数:
###对作者提供的模型进行评估###
python test.py -w data/GPEN/G_256_article_pretrained.pdparams --test_path='data/GPEN/test/' --size=256 --mul=1 --narrow=0.5
###对我复现的模型进行评估###
python test.py -w data/GPEN/G_256_repo.pdparams --test_path='data/GPEN/test/' --size=256 --mul=1 --narrow=0.5
如果要测试你自己准备的图像,请更改test_path参数。
如果要在自己提供的模型上进行测试,请将模型的路径放在 -w 后面。
对模型进行单图像的简单测试时,在控制台输入以下代码,下面代码中使用上面提到的下载的模型参数:
###对作者提供的模型进行测试###
python predict.py --size 256 --mul=1 --narrow=0.5 --w data/GPEN/G_256_article_pretrained.pdparams --img data/GPEN/predict/test_img.jpg
###对我复现的模型进行测试###
python predict.py --size 256 --mul=1 --narrow=0.5 --w data/GPEN/G_256_repo.pdparams --img data/GPEN/predict/test_img.jpg
如果要在自己提供的模型上进行测试,请将模型的路径放在 -w 后面。如要修改测试的图片,请修改--img后的模型参数。
python export_model.py --model_path data/GPEN/G_256_repo.pdparams --save_dir inference/GPEN
上述命令将生成预测所需的模型结构文件model.pdmodel
和模型权重文件model.pdiparams
以及model.pdiparams.info
文件,均存放在inference/GPEN/
目录下。
python infer.py --model_file inference/GPEN/model.pdmodel --params_file inference/GPEN/model.pdiparams --img data/GPEN/predict/test_img.jpg
推理结束会默认保存下模型生成的修复图像,并输出测试得到的FID和psnr值。
以下是样例图片和对应的修复图像:
输出示例如下:
result saved in : data/GPEN/predict/test_img_predict.png
FID: 151.78178552134233
PSNR:21.65281356833421
注:由于对高清图片进行退化的操作具有一定的随机性,所以每次测试的结果都会有所不同。为了保证测试结果一致,在这里我固定了随机种子,使每次测试时对图片都进行相同的退化操作。
测试基本训练预测功能的lite_train_lite_infer
模式,运行:
# 准备数据
bash test_tipc/prepare.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
# 运行测试
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/GPEN/train_infer_python.txt 'lite_train_lite_infer'
GPEN/
|-- data/ #存放数据集及下载的模型参数文件的文件夹
|-- data_loader/ #存放数据预处理相关的代码
|-- model/ #存放GPEN模型结构相关的代码
|-- loss/ #存放损失函数计算相关的代码
|-- metric/ #存放计算评估指标(FID,PSNR)相关的代码
|-- test_tipc/ #存放tipc相关文件
|-- figs/ #存放说明文档用到的图片
|-- ckpts/ #训练时生成的文件夹,用于存放训练过程中保存的模型参数
|-- samples/ #训练时生成的文件夹,用于存放训练过程中保存的测试图片,来直观地展示训练过程中模型生成图片的变化
|-- train.py #模型训练时调用
|-- test.py #模型评估时调用
|-- predict.py #用模型测试单张图片时调用
|-- export_model.py #tipc生成推理模型时调用
|-- infer.py #tipc进行推理时调用
|-- readme.md #项目说明文档
本项目的发布受Apache 2.0 license许可认证。
论文地址:https://paperswithcode.com/paper/gan-prior-embedded-network-for-blind-face
参考repo Github:https://github.com/yangxy/GPEN
readme文档模板:https://github.com/PaddlePaddle/models/blob/release/2.2/community/repo_template/README.md