-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
133 lines (111 loc) · 5.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import torch
import network
import dataset
import argparse
import numpy as np
from tqdm import tqdm
from torch import nn, optim
from torch.nn import functional as F
from torchvision.utils import save_image
import math
from loss_function import *
def psnr1(img1, img2, max_val=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Args:
        img1, img2: tensors of the same (or broadcastable) shape. Callers in
            this script pass images multiplied by 255, i.e. they assume a
            [0, 1] network range — TODO confirm against the data loader.
        max_val: peak pixel value used in the PSNR formula; the default
            255.0 matches the previous hard-coded behaviour.

    Returns:
        ``10 * log10(max_val**2 / MSE)`` as a Python float, or ``100`` when
        the MSE is below 1e-10 (identical images would otherwise divide by
        zero / give infinity).
    """
    # Dividing by 1.0 promotes integer (e.g. uint8) tensors to floating point
    # *before* subtracting, so the difference cannot wrap around.
    mse = torch.mean((img1 / 1.0 - img2 / 1.0) ** 2).item()
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(max_val ** 2 / mse)
def _evaluate_psnr(model, loader):
    """Return the mean per-batch PSNR of ``model`` over ``loader``.

    Each batch is indexed like the training batches: ``batch[0]`` is the
    model input and ``batch[1]`` the reference image. Both the output and the
    reference are multiplied by 255 before ``psnr1`` (assumes a [0, 1] range —
    TODO confirm). Batches whose PSNR cannot be computed are reported and
    skipped. Returns ``float('nan')`` if no batch could be scored.
    """
    model.eval()
    total_psnr = 0.0
    scored = 0
    with torch.no_grad():
        for batch in tqdm(loader, total=len(loader)):
            content = batch[0].float().cuda()
            information = batch[1].float().cuda()
            output = model(content)
            try:
                total_psnr += psnr1(information * 255, output * 255)
            except (RuntimeError, ValueError) as e:
                # Report and skip the batch instead of aborting validation.
                print("PSNR computation failed, skipping batch: " + str(e))
                continue
            scored += 1
    # psnr1 yields one value per *batch*, so average over the scored batches
    # (dividing by the dataset size under-reports when batch_size > 1).
    return total_psnr / scored if scored else float('nan')


def train(args, val_loaders=None):
    """Train ``network.Zero_Net`` and optionally validate it after every epoch.

    The loss combines:
    - an L1 term against DCP-generated pseudo ground truth;
    - ``exposure_control_loss`` and ``color_constency_loss`` from ``loss_function``;
    - a mean ``output ** 2.2`` term.

    Every 5 epochs a strip of (input, output, pseudo ground truth) from the
    last batch is written to ``<save_dir>/image/iter<N>_h.jpg``.

    Args:
        args: parsed CLI namespace. Reads ``lr``, ``size``, ``batch_size``,
            ``epoch`` and ``save_dir``. Optional attributes:
            ``gpu_id`` (default 1, the previously hard-coded GPU), and
            ``train_dir`` / ``gt_dir`` (default: the original hard-coded folders).
        val_loaders: optional mapping ``name -> loader``. Each loader is
            evaluated with PSNR after every epoch and printed as
            ``"<name> val psnr: ..."``.

    Returns:
        dict mapping each validation name to ``(best_psnr, best_epoch)``.
        Names that never produced a valid PSNR are omitted.
    """
    # Must be set before the first CUDA call (model.cuda() below) to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(getattr(args, "gpu_id", 1))
    model = network.Zero_Net().cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    # Mean absolute error (L1), despite the earlier name `mse`.
    l1_loss = nn.L1Loss().cuda()

    content_folder = getattr(args, "train_dir", '/home/wenwen/Dehaze_project/wen/trains/')
    # The ground truths (gts) are generated from the training images with the DCP algorithm.
    information_folder = getattr(args, "gt_dir", '/home/wenwen/Dehaze_project/wen/gts/')
    train_loader = dataset.style_loader(content_folder, information_folder, args.size, args.batch_size)
    num_batch = len(train_loader)

    # save_image does not create missing folders.
    image_dir = os.path.join(args.save_dir, 'image')
    os.makedirs(image_dir, exist_ok=True)

    best = {}
    for epoch in range(args.epoch):
        model.train(True)
        loop = tqdm(train_loader, total=num_batch)
        for batch in loop:
            content = batch[0].float().cuda()
            information = batch[1].float().cuda()
            optimizer.zero_grad()
            output = model(content)
            # `information` is the DCP pseudo ground truth. The author describes
            # `output ** 2.2` as a gamma-correction term; it is reduced with
            # mean() so total_loss is a scalar (backward() and .item() need one).
            # NOTE(review): ** 2.2 gives NaN for negative outputs — assumes the
            # model output is non-negative; confirm in network.Zero_Net.
            # NOTE(review): assumes the two loss_function terms return scalars.
            total_loss = (0.4 * l1_loss(output, information)
                          + 0.2 * exposure_control_loss(output)
                          + 0.2 * color_constency_loss(output)
                          + 0.1 * torch.mean(output ** 2.2))
            total_loss.backward()
            optimizer.step()
            loop.set_description(f'Epoch [{epoch+1}/{args.epoch}]')
            loop.set_postfix(loss=total_loss.item())

        # Save once per qualifying epoch, using the last batch
        # (guarded so an empty loader doesn't hit undefined names).
        if (epoch + 1) % 5 == 0 and num_batch > 0:
            out_image = torch.cat([content[0:1], output[0:1], information[0:1]], dim=0)
            save_image(out_image.detach(), os.path.join(image_dir, 'iter{}_h.jpg'.format(epoch + 1)))

        for name, loader in (val_loaders or {}).items():
            psnr = _evaluate_psnr(model, loader)
            if psnr > best.get(name, (float('-inf'), None))[0]:
                best[name] = (psnr, epoch)
            print('{} val psnr: {:4f}'.format(name, psnr))
    return best
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_id', default=0, type=int)
    parser.add_argument('--epoch', default=1000, type=int, help='number of training epochs')
    parser.add_argument('--size', default=512, type=int, help='size passed to dataset.style_loader')
    parser.add_argument('--batch_size', default=8, type=int, help='training batch size')
    parser.add_argument('--lr', default=0.001, type=float, help='Adam learning rate')
    parser.add_argument('--save_dir', default='result', type=str,
                        help='output folder; sample images are written to <save_dir>/image/')
    args = parser.parse_args()
    # os.makedirs (unlike os.mkdir) handles nested paths and an existing folder.
    # It also creates the image/ subfolder that train() saves samples into.
    os.makedirs(os.path.join(args.save_dir, 'image'), exist_ok=True)
    train(args)