loss.py
import torch
import torch.nn.functional as F
import numpy as np

def weighted_mse_loss(inputs, targets, weights=None):
    # Per-element squared error, optionally rescaled by a weight tensor
    # broadcast to the loss shape, then averaged.
    loss = (inputs - targets) ** 2
    if weights is not None:
        loss *= weights.expand_as(loss)
    loss = torch.mean(loss)
    return loss
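
# A minimal usage sketch (hypothetical shapes, not part of the original module):
# per-sample weights of shape (N, 1) broadcast across the feature dimension
# via expand_as.
#   preds   = torch.randn(8, 3)
#   labels  = torch.randn(8, 3)
#   weights = torch.rand(8, 1)   # one weight per sample
#   loss = weighted_mse_loss(preds, labels, weights)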

def weighted_focal_mse_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
    # Squared error scaled by a "focal" factor that grows with the absolute
    # error, so small residuals contribute less to the mean.
    loss = (inputs - targets) ** 2
    loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
        (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
    if weights is not None:
        loss *= weights.expand_as(loss)
    loss = torch.mean(loss)
    return loss

def weighted_focal_l1_loss(inputs, targets, weights=None, activate='sigmoid', beta=.2, gamma=1):
    # Same focal scaling as above, applied to the per-element L1 error.
    loss = F.l1_loss(inputs, targets, reduction='none')
    loss *= (torch.tanh(beta * torch.abs(inputs - targets))) ** gamma if activate == 'tanh' else \
        (2 * torch.sigmoid(beta * torch.abs(inputs - targets)) - 1) ** gamma
    if weights is not None:
        loss *= weights.expand_as(loss)
    loss = torch.mean(loss)
    return loss
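
# Note on the focal factor (applies to both focal losses above): with the
# default activate='sigmoid', the scaling term 2*sigmoid(beta*|err|) - 1 is
# close to 0 for small errors and approaches 1 for large ones, so easy
# (small-error) elements are down-weighted relative to hard ones.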

def balanced_l1_loss(inputs, targets, beta=1.0, alpha=0.5, gamma=1.5, loss_weight=1.0):
    """Balanced L1 loss (Libra R-CNN, CVPR 2019).

    Reference: https://paperswithcode.com/method/balanced-l1-loss
    arXiv: https://arxiv.org/pdf/1904.02701.pdf
    """
    assert beta > 0
    assert inputs.size() == targets.size() and targets.numel() > 0
    diff = torch.abs(inputs - targets)
    # b is chosen so the logarithmic and linear branches join continuously
    # at diff == beta.
    b = np.e ** (gamma / alpha) - 1
    loss = torch.where(
        diff < beta,
        alpha / b * (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
        gamma * diff + gamma / b - alpha * beta)
    return torch.mean(loss * loss_weight)
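
# Quick sanity check (illustrative values, not from the original module):
# with the defaults alpha=0.5, gamma=1.5, beta=1.0, an absolute error equal
# to beta falls on the linear branch and evaluates to
# gamma*beta + gamma/b - alpha*beta ≈ 1.0786.
#   balanced_l1_loss(torch.tensor([1.0]), torch.tensor([0.0]))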

def weighted_huber_loss(inputs, targets, weights=None, beta=1.):
    # Huber / smooth L1 loss: quadratic for errors below beta, linear above.
    l1_loss = torch.abs(inputs - targets)
    cond = l1_loss < beta
    loss = torch.where(cond, 0.5 * l1_loss ** 2 / beta, l1_loss - 0.5 * beta)
    if weights is not None:
        loss *= weights.expand_as(loss)
    loss = torch.mean(loss)
    return loss
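
# With weights=None this matches the standard smooth L1 formulation, e.g.
# (assuming a PyTorch version where F.smooth_l1_loss accepts a beta argument):
#   F.smooth_l1_loss(preds, labels, beta=1.0)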

def choose_criterion_type(loss_type):
    # Map a config string to one of the loss functions above;
    # unknown names fall through to None.
    if loss_type == 'mse':
        criterion = weighted_mse_loss
    elif loss_type == 'weighted_huber':
        criterion = weighted_huber_loss
    elif loss_type == 'weighted_focal_l1':
        criterion = weighted_focal_l1_loss
    elif loss_type == 'balanced_l1':
        criterion = balanced_l1_loss
    elif loss_type == 'weighted_focal_mse':
        criterion = weighted_focal_mse_loss
    else:
        criterion = None
    return criterion
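

# A minimal, self-contained usage sketch (hypothetical tensors and loss_type,
# not part of the original module): pick a criterion by name and evaluate it.
if __name__ == "__main__":
    preds = torch.randn(16, 1)
    labels = torch.randn(16, 1)
    sample_weights = torch.rand(16, 1)

    criterion = choose_criterion_type('weighted_huber')
    if criterion is None:
        raise ValueError("unknown loss_type")
    print(criterion(preds, labels, sample_weights))

    # balanced_l1_loss takes no weights argument, so call it without them.
    print(choose_criterion_type('balanced_l1')(preds, labels))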