-
Notifications
You must be signed in to change notification settings - Fork 1
/
losses.py
386 lines (328 loc) · 14.1 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import utils.util as util
from SSIM import SSIM
##########################################################################
#from IQA_pytorch import SSIM, MS_SSIM
class CharbonnierLoss1(nn.Module):
    """Charbonnier penalty, a smooth variant of the L1 loss.

    Every element contributes ``sqrt((x - y)**2 + eps)``.

    Args:
        eps: Constant added under the square root.
        reduction: ``'mean'`` averages the per-element penalties; any other
            value sums them.
    """

    def __init__(self, eps=1e-6, reduction='mean'):
        super().__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, y):
        """Return the reduced Charbonnier penalty between ``x`` and ``y``."""
        residual = x - y
        penalty = torch.sqrt(residual * residual + self.eps)
        # Anything other than 'mean' falls through to a plain sum.
        reducer = torch.mean if self.reduction == 'mean' else torch.sum
        return reducer(penalty)
class HuberLoss(nn.Module):
    """Huber loss: quadratic up to ``delta``, linear beyond it.

    Args:
        delta: Absolute-residual threshold between the two regimes.
        reduction: ``'mean'`` averages the per-element loss; any other value
            sums it.
    """

    def __init__(self, delta=1e-2, reduction='mean'):
        super().__init__()
        self.delta = delta
        self.reduction = reduction

    def forward(self, x, y):
        """Return the reduced Huber loss between ``x`` and ``y``."""
        residual = torch.abs(x - y)
        # Part of the residual handled quadratically (capped at delta) ...
        capped = torch.min(residual, torch.full_like(residual, self.delta))
        # ... and the remainder, which is penalised linearly.
        overflow = residual - capped
        per_element = 0.5 * capped ** 2 + self.delta * overflow
        if self.reduction == 'mean':
            return torch.mean(per_element)
        return torch.sum(per_element)
class TVLoss(nn.Module):
    """Anisotropic total variation of a 4-D tensor.

    Sums the absolute differences between neighbouring elements along the
    last two dimensions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the total variation of ``x`` as a scalar tensor."""
        along_last = torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])
        along_rows = torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])
        return torch.sum(along_last) + torch.sum(along_rows)
class GWLoss(nn.Module):
    """Gradient-weighted L1 loss.

    The absolute pixel difference is scaled by
    ``(1 + w * dx) * (1 + w * dy)``, where ``dx`` / ``dy`` are the absolute
    differences of the Sobel responses of the two inputs.

    Args:
        w: Weight applied to the gradient differences.
        reduction: ``'mean'`` averages the weighted loss; any other value
            sums it.
    """

    def __init__(self, w=4, reduction='mean'):
        super().__init__()
        self.w = w
        self.reduction = reduction
        horizontal = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float)
        vertical = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float)
        # Fixed Sobel filters, stored as non-trainable parameters.
        self.weight_x = nn.Parameter(data=horizontal, requires_grad=False)
        self.weight_y = nn.Parameter(data=vertical, requires_grad=False)

    def forward(self, x1, x2):
        """Return the reduced gradient-weighted loss between ``x1`` and ``x2``."""
        _, channels, _, _ = x1.shape
        # One 3x3 filter per channel, applied depthwise via ``groups``.
        kernel_x = self.weight_x.expand(channels, 1, 3, 3).type_as(x1)
        kernel_y = self.weight_y.expand(channels, 1, 3, 3).type_as(x1)

        def sobel(img, kernel):
            return F.conv2d(img, kernel, stride=1, padding=1, groups=channels)

        grad_x1 = sobel(x1, kernel_x)
        grad_x2 = sobel(x2, kernel_x)
        grad_y1 = sobel(x1, kernel_y)
        grad_y2 = sobel(x2, kernel_y)
        dx = torch.abs(grad_x1 - grad_x2)
        dy = torch.abs(grad_y1 - grad_y2)
        weighted = (1 + self.w * dx) * (1 + self.w * dy) * torch.abs(x1 - x2)
        if self.reduction == 'mean':
            return torch.mean(weighted)
        return torch.sum(weighted)
class StyleLoss(nn.Module):
    """Style loss: mean squared error between Gram matrices of two feature maps."""

    def __init__(self):
        super(StyleLoss, self).__init__()

    @staticmethod
    def gram_matrix(x):
        """Return the normalised Gram matrix of a 4-D tensor.

        ``x`` of shape ``(B, C, H, W)`` is flattened to ``(B * C, H * W)``;
        the result is ``features @ features.T`` divided by ``B * C * H * W``.

        This is a static method, so it must not declare ``self``: the
        previous signature ``gram_matrix(self, x)`` made every
        ``self.gram_matrix(t)`` call fail with a missing-argument TypeError.
        """
        B, C, H, W = x.size()
        features = x.view(B * C, H * W)
        G = torch.mm(features, features.t())
        return G.div(B * C * H * W)

    def forward(self, input, target):
        """Return the MSE between the Gram matrices of ``input`` and ``target``.

        The target Gram matrix is detached, so no gradient flows into
        ``target``.
        """
        G_i = self.gram_matrix(input)
        G_t = self.gram_matrix(target).detach()
        return F.mse_loss(G_i, G_t)
class GANLoss(nn.Module):
    """Adversarial loss selected by name.

    Supported ``gan_type`` values (case-insensitive):
    ``'gan'`` / ``'ragan'`` (BCE with logits), ``'lsgan'`` (MSE) and
    ``'wgan-gp'`` (signed mean of the critic output).

    Args:
        gan_type: Name of the loss variant.
        real_label_val: Target value used for "real" predictions.
        fake_label_val: Target value used for "fake" predictions.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported names.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super().__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type in ('gan', 'ragan'):
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan-gp':

            def wgan_loss(input, target):
                # ``target`` is a boolean here, not a label tensor.
                mean_score = input.mean()
                return -1 * mean_score if target else mean_score

            self.loss = wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

    def get_target_label(self, input, target_is_real):
        """Return the target for ``input``.

        For ``'wgan-gp'`` the boolean is passed through unchanged; otherwise
        a tensor shaped like ``input`` filled with the real or fake label
        value is returned.
        """
        if self.gan_type == 'wgan-gp':
            return target_is_real
        label_value = self.real_label_val if target_is_real else self.fake_label_val
        return torch.empty_like(input).fill_(label_value)

    def forward(self, input, target_is_real):
        """Return the loss of ``input`` against the real or fake target."""
        target = self.get_target_label(input, target_is_real)
        return self.loss(input, target)
class GradientPenaltyLoss(nn.Module):
    """Gradient penalty: mean of ``(||d interp_crit / d interp||_2 - 1) ** 2``.

    Args:
        device: Device on which the initial (empty) ``grad_outputs`` buffer
            is placed.
    """

    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        self.register_buffer('grad_outputs', torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        """Return a cached all-ones tensor matching ``input``.

        The cache is rebuilt whenever shape, dtype or device differ from
        ``input``. Checking the shape alone (as before) could hand
        ``torch.autograd.grad`` a tensor on the wrong device or with the
        wrong dtype.
        """
        cached = self.grad_outputs
        if (cached.size() != input.size()
                or cached.dtype != input.dtype
                or cached.device != input.device):
            self.grad_outputs = torch.ones_like(input)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        """Return the gradient penalty.

        Args:
            interp: Tensor that ``interp_crit`` was computed from; it must be
                part of the autograd graph.
            interp_crit: Output whose gradient with respect to ``interp`` is
                penalised.
        """
        grad_outputs = self.get_grad_outputs(interp_crit)
        # create_graph=True keeps the penalty itself differentiable.
        grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
                                          grad_outputs=grad_outputs, create_graph=True,
                                          retain_graph=True)[0]
        # reshape (not view): the gradient is not guaranteed to be contiguous.
        grad_interp = grad_interp.reshape(grad_interp.size(0), -1)
        grad_interp_norm = grad_interp.norm(2, dim=1)
        return ((grad_interp_norm - 1) ** 2).mean()
class PyramidLoss(nn.Module):
    """Sum of a per-level loss over image pyramids of two inputs.

    Args:
        num_levels: Number of pyramid levels that contribute to the loss.
        pyr_mode: ``'gau'`` or ``'lap'``; selects ``util.gau_pyramid`` or
            ``util.lap_pyramid``.
        loss_mode: ``'l1'``, ``'l2'``, ``'hb'`` (Huber) or ``'cb'``
            (Charbonnier).
        reduction: Forwarded to the per-level loss.

    Raises:
        AssertionError: If ``pyr_mode`` is neither ``'gau'`` nor ``'lap'``.
        ValueError: If ``loss_mode`` is not recognised.
    """

    def __init__(self, num_levels=3, pyr_mode='gau', loss_mode='l1', reduction='mean'):
        super().__init__()
        self.num_levels = num_levels
        self.pyr_mode = pyr_mode
        self.loss_mode = loss_mode
        assert self.pyr_mode == 'gau' or self.pyr_mode == 'lap'
        self.loss = self._make_criterion(self.loss_mode, reduction)

    @staticmethod
    def _make_criterion(loss_mode, reduction):
        """Build the per-level loss module named by ``loss_mode``."""
        if loss_mode == 'l1':
            return nn.L1Loss(reduction=reduction)
        if loss_mode == 'l2':
            return nn.MSELoss(reduction=reduction)
        if loss_mode == 'hb':
            return HuberLoss(reduction=reduction)
        if loss_mode == 'cb':
            return CharbonnierLoss1(reduction=reduction)
        raise ValueError()

    def forward(self, x, y):
        """Return the summed per-level loss between the pyramids of ``x`` and ``y``."""
        _, channels, _, _ = x.shape
        kernel = util.gauss_kernel(size=5, device=x.device, channels=channels)
        build = util.gau_pyramid if self.pyr_mode == 'gau' else util.lap_pyramid
        pyr_x = build(img=x, kernel=kernel, max_levels=self.num_levels)
        pyr_y = build(img=y, kernel=kernel, max_levels=self.num_levels)
        return sum(self.loss(pyr_x[level], pyr_y[level])
                   for level in range(self.num_levels))
class LapPyrLoss(nn.Module):
    """Charbonnier loss on an image pair plus Laplacian-pyramid levels.

    Args:
        num_levels: ``max_levels`` passed to ``util.laplacian_pyramid``; the
            first ``num_levels - 1`` pyramid levels contribute to the loss.
        lf_mode: Stored on the instance; not used by the current
            implementation.
        hf_mode: Stored on the instance; not used by the current
            implementation.
        reduction: Accepted but not used; the Charbonnier term is built with
            its own defaults.
    """

    def __init__(self, num_levels=3, lf_mode='ssim', hf_mode='cb', reduction='mean'):
        super().__init__()
        self.num_levels = num_levels
        self.lf_mode = lf_mode
        self.hf_mode = hf_mode
        # An earlier revision chose separate low-/high-frequency losses
        # (SSIM or Charbonnier) from lf_mode / hf_mode; that selection is
        # disabled and one Charbonnier loss is used for every term.
        self.CharLoss = CharbonnierLoss1()

    def forward(self, x_img, y_img, x, y):
        """Return the combined loss.

        Args:
            x_img, y_img: Pair compared directly with the Charbonnier loss.
            x, y: Pair decomposed with ``util.laplacian_pyramid``; pyramid
                levels are compared index by index.
        """
        _, channels, _, _ = x.shape
        kernel = util.gauss_kernel(size=5, device=x.device, channels=channels)
        pyr_x = util.laplacian_pyramid(img=x, kernel=kernel, max_levels=self.num_levels)
        pyr_y = util.laplacian_pyramid(img=y, kernel=kernel, max_levels=self.num_levels)
        total = self.CharLoss(x_img, y_img)
        # Only the first num_levels - 1 pyramid entries are compared.
        for level in range(self.num_levels - 1):
            total += self.CharLoss(pyr_x[level], pyr_y[level])
        return total
#########################################################################################
def dwt_init(x):
    """Split a 4-D tensor into four half-resolution Haar-style sub-bands.

    Rows and columns are separated into even and odd samples (each halved),
    then combined with the four sign patterns below.

    Returns:
        ``[x_LL, x_HL, x_LH, x_HH]``, each taken from every second row and
        column of ``x``.
    """
    even_rows = x[:, :, 0::2, :] / 2
    odd_rows = x[:, :, 1::2, :] / 2
    # The four samples of every 2x2 block.
    a = even_rows[:, :, :, 0::2]  # even row, even column
    b = odd_rows[:, :, :, 0::2]   # odd row, even column
    c = even_rows[:, :, :, 1::2]  # even row, odd column
    d = odd_rows[:, :, :, 1::2]   # odd row, odd column
    low_low = a + b + c + d
    high_low = -a - b + c + d
    low_high = -a + b - c + d
    high_high = a - b - c + d
    return [low_low, high_low, low_high, high_high]
# Inverse 2-D discrete wavelet transform implemented with the Haar wavelet.
def iwt_init(x):
    """Rebuild a full-resolution tensor from four channel-stacked sub-bands.

    Args:
        x: Tensor of shape ``(B, 4 * C, H, W)`` whose channel axis holds four
            consecutive groups of ``C`` channels, one group per sub-band.

    Returns:
        Tensor of shape ``(B, C, 2 * H, 2 * W)``, allocated on ``x``'s device
        with the dtype of the halved sub-bands. The buffer used to be
        created as a CPU float32 tensor regardless of the input, which failed
        for inputs on another device and downcast float64 inputs.
    """
    r = 2
    in_batch, in_channel, in_height, in_width = x.size()
    # e.g. an input of [1, 12, 56, 56] produces an output of [1, 3, 112, 112]
    out_channel = in_channel // (r ** 2)
    out_height, out_width = r * in_height, r * in_width

    x1 = x[:, 0:out_channel, :, :] / 2
    x2 = x[:, out_channel:out_channel * 2, :, :] / 2
    x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
    x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2

    h = torch.zeros([in_batch, out_channel, out_height, out_width],
                    dtype=x1.dtype, device=x.device)
    # Scatter the four combinations into the 2x2 interleaved output grid.
    h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
    h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
    h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
    h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
    return h
# 2-D discrete wavelet transform.
class DWT(nn.Module):
    """Module wrapper that applies ``dwt_init`` to its input."""

    def __init__(self):
        super().__init__()
        # Original intent: signal processing rather than a convolution, so no
        # gradient is needed.
        # NOTE(review): this only sets an attribute named ``requires_grad`` on
        # the module — confirm whether it has any effect on autograd.
        self.requires_grad = False

    def forward(self, x):
        """Return the sub-band list produced by ``dwt_init(x)``."""
        subbands = dwt_init(x)
        return subbands
# Inverse 2-D discrete wavelet transform.
class IWT(nn.Module):
    """Module wrapper that applies ``iwt_init`` to its input."""

    def __init__(self):
        super().__init__()
        # NOTE(review): this only sets an attribute named ``requires_grad`` on
        # the module — confirm whether it has any effect on autograd.
        self.requires_grad = False

    def forward(self, x):
        """Return the tensor produced by ``iwt_init(x)``."""
        restored = iwt_init(x)
        return restored
class CharbonnierLoss_dwt(nn.Module):
    """Charbonnier loss summed over the wavelet sub-bands of two inputs.

    Both inputs go through a ``DWT`` module; each pair of sub-bands
    contributes ``mean(sqrt(diff ** 2 + eps ** 2))`` and the contributions
    are added.

    Args:
        eps: Smoothing constant; it is squared inside the square root.
    """

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss_dwt, self).__init__()
        self.eps = eps
        # Kept for interface compatibility; not used by forward().
        self.target_down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        self.x_dwt = DWT()
        self.y_dwt = DWT()

    def forward(self, x, y):
        """Return the summed per-sub-band Charbonnier loss as a torch tensor."""
        x_fea = self.x_dwt(x)
        y_fea = self.y_dwt(y)
        eps_sq = self.eps * self.eps
        # Accumulate with torch arithmetic rather than np.sum: handing a list
        # of tensors to NumPy depends on its array coercion, which can fail
        # for tensors that require grad or are not on the CPU, or return a
        # value detached from the autograd graph.
        return sum(
            torch.mean(torch.sqrt((xb - yb) * (xb - yb) + eps_sq))
            for xb, yb in zip(x_fea, y_fea)
        )
class CharbonnierLoss(nn.Module):
    """Mean Charbonnier loss, ``mean(sqrt(diff ** 2 + eps ** 2))``.

    When dimension 2 of the two inputs differs, the target ``y`` is first
    resized by a factor of 0.5 with bilinear interpolation.

    Args:
        eps: Smoothing constant; it is squared inside the square root.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.target_down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)

    def forward(self, x, y):
        """Return the mean Charbonnier loss between ``x`` and ``y``."""
        _, _, x_dim2, _ = x.shape
        _, _, y_dim2, _ = y.shape
        # Only dimension 2 is compared to decide whether y is downsampled.
        target = y if x_dim2 == y_dim2 else self.target_down(y)
        residual = x - target
        return torch.mean(torch.sqrt((residual * residual) + (self.eps * self.eps)))
class L1smooth(nn.Module):
    """Smooth L1 loss with optional 0.5x bilinear resizing of the target.

    The target ``y`` is resized when dimension 2 of the two inputs differs.
    """

    def __init__(self):
        super().__init__()
        self.target_down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        self.L1_smooth = torch.nn.SmoothL1Loss()

    def forward(self, x, y):
        """Return the smooth L1 loss between ``x`` and ``y``."""
        _, _, x_dim2, _ = x.shape
        _, _, y_dim2, _ = y.shape
        # Only dimension 2 is compared to decide whether y is downsampled.
        target = y if x_dim2 == y_dim2 else self.target_down(y)
        return self.L1_smooth(x, target)
class EdgeLoss(nn.Module):
    """Charbonnier loss between Laplacian (edge) responses of two images.

    Each image is blurred with a 5x5 kernel, decimated by 2, zero-stuffed
    back to full size, blurred again and subtracted from the original; the
    two residuals are compared with ``CharbonnierLoss``.

    Args:
        channels: Number of image channels the depthwise blur kernel is
            built for. Defaults to 3, the previously hard-coded value.
    """

    def __init__(self, channels=3):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        # Outer product gives the 5x5 kernel; one copy per channel.
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(channels, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()
        self.target_down = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)

    def conv_gauss(self, img):
        """Blur ``img`` depthwise with the stored kernel, keeping its size.

        The kernel is moved to ``img``'s device and dtype for this call. It
        is created on CUDA whenever CUDA is available and is always float32,
        so using it as stored failed for CPU or non-float32 inputs.
        """
        kernel = self.kernel.to(device=img.device, dtype=img.dtype)
        n_channels, _, k_h, k_w = kernel.shape
        # F.pad order is (left, right, top, bottom).
        img = F.pad(img, (k_w // 2, k_w // 2, k_h // 2, k_h // 2), mode='replicate')
        return F.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """Return ``current`` minus its blurred, decimated and re-expanded copy."""
        filtered = self.conv_gauss(current)        # filter
        down = filtered[:, :, ::2, ::2]            # downsample
        new_filter = torch.zeros_like(filtered)
        # Zero-stuffed upsample; x4 offsets the three zeros inserted per sample.
        new_filter[:, :, ::2, ::2] = down * 4
        filtered = self.conv_gauss(new_filter)     # filter
        return current - filtered

    def forward(self, x, y):
        """Return the edge loss; ``y`` is halved when dimension 2 differs."""
        _, _, x_dim2, _ = x.shape
        _, _, y_dim2, _ = y.shape
        if x_dim2 != y_dim2:
            y = self.target_down(y)
        return self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))