-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlabel_smoothed_el_timm.py
36 lines (35 loc) · 1.53 KB
/
label_smoothed_el_timm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# parser.add_argument('--use_el', default=0, type=int,
# help='')
# parser.add_argument('--log_end', default=0.75, type=float,
# help='')
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingEncouragingLoss(nn.Module):
    """Encouraging Loss (EL) combined with label smoothing.

    Standard label-smoothed cross entropy uses -log(p) per class; this
    loss uses (bonus - log(p)) instead, where bonus = log(1 - p) is a
    "likelihood bonus" term. When ``log_end`` < 1.0, the log bonus is
    replaced, for probabilities above ``log_end``, by its straight-line
    continuation through the point (log_end, log(1 - log_end)), which
    bounds the bonus as p -> 1.

    :param smoothing: label-smoothing factor in [0, 1); the target class
        term gets weight ``1 - smoothing`` and the uniform (mean over
        classes) term gets weight ``smoothing``.
    :param log_end: probability threshold in (0, 1]; 1.0 keeps the pure
        log bonus everywhere (no linear continuation).
    :raises ValueError: if ``smoothing`` is outside [0, 1).
    """

    def __init__(self, smoothing=0.1, log_end=0.75):
        super().__init__()
        # Validate with a real exception: `assert` is stripped under
        # `python -O`, silently disabling the check.
        if not 0.0 <= smoothing < 1.0:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.log_end = log_end

    def forward(self, x, target):
        """Compute the mean smoothed encouraging loss.

        :param x: raw logits with the class dimension last, e.g. (N, C).
        :param target: integer class indices matching x's leading dims,
            e.g. (N,).
        :return: scalar tensor (mean loss over all elements).
        """
        logprobs = F.log_softmax(x, dim=-1)
        probs = torch.exp(logprobs)
        # Likelihood bonus log(1 - p); clamp avoids log(0) as p -> 1.
        bonus = torch.log(torch.clamp(1.0 - probs, min=1e-5))
        if self.log_end != 1.0:  # e.g. 0.75: linear continuation past log_end
            log_end = self.log_end
            y_log_end = torch.log(torch.ones_like(probs) - log_end)
            # Tangent-like line through (log_end, log(1 - log_end)) with
            # slope 1/(log_end - 1); scalar arithmetic replaces the
            # original's needless ones_like tensor, same value.
            bonus_after_log_end = (probs - log_end) / (log_end - 1.0) + y_log_end
            bonus = torch.where(probs > log_end, bonus_after_log_end, bonus)
        # unsqueeze/squeeze on the last dim (not dim 1) so ND logits with
        # a trailing class dim also work; identical for the (N, C) case.
        el_loss = (bonus - logprobs).gather(dim=-1, index=target.unsqueeze(-1))
        el_loss = el_loss.squeeze(-1)
        smooth_loss = (bonus - logprobs).mean(dim=-1)
        loss = self.confidence * el_loss + self.smoothing * smooth_loss
        return loss.mean()