-
Notifications
You must be signed in to change notification settings - Fork 13
/
loss.py
29 lines (23 loc) · 804 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from torch.autograd import Function
from torch import nn
# Instance-level variance loss; intended to be combined with a cross-entropy loss.
class LossVariance(nn.Module):
    """Instance-level variance loss, meant to be added to a cross-entropy loss.

    For every labelled instance in ``target`` (pixels with label 0 are
    ignored), the per-channel variance of ``input`` over that instance's
    pixels is summed over channels. These sums are averaged over the
    instances of each image, then averaged over the batch.

    The instances in ``target`` should be labeled: each instance carries its
    own distinct non-zero integer id.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Compute the variance loss.

        Args:
            input: batched feature tensor. ``input[k][:, mask]`` must be valid
                for a boolean ``mask`` shaped like ``target[k]``; presumably
                (B, C, H, W) with ``target`` of shape (B, H, W). TODO confirm
                against callers.
            target: per-image instance-label maps; 0 marks non-instance pixels.

        Returns:
            A 0-d tensor on ``input``'s device/dtype. Images without any
            instance contribute 0. An empty batch returns 0 instead of raising.
        """
        batch_size = input.size(0)
        # Always accumulate into a tensor so the result is a tensor even when
        # no instance contributes. The previous int accumulator returned a
        # Python number in that case.
        loss = input.new_zeros(())
        if batch_size == 0:
            return loss

        for k in range(batch_size):
            labels = target[k].unique()
            labels = labels[labels != 0]
            num_instances = labels.numel()
            if num_instances == 0:
                # No instances in this image: it contributes 0 to the sum.
                continue

            sum_var = input.new_zeros(())
            for val in labels:
                # Shape (input[k].size(0), N): features of the N pixels of this instance.
                instance = input[k][:, target[k] == val]
                # Tensor.var is unbiased by default, so a single pixel would
                # yield NaN; such instances are skipped.
                if instance.size(1) > 1:
                    sum_var = sum_var + instance.var(dim=1).sum()

            loss = loss + sum_var / num_instances

        return loss / batch_size