-
Notifications
You must be signed in to change notification settings - Fork 14
/
loss.py
33 lines (26 loc) · 883 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch import nn
import torch
class IoU_loss(nn.Module):
    """Soft IoU (Jaccard) loss, summed over the batch.

    For each sample ``i`` the soft intersection is ``sum(target[i] * pred[i])``
    and the soft union is ``sum(target[i]) + sum(pred[i]) - intersection``.
    The per-sample loss is ``1 - intersection / union``; per-sample losses are
    summed (not averaged) over the batch dimension.
    """

    def forward(self, pred, target):
        """Return the batch-summed soft-IoU loss of ``pred`` against ``target``.

        Args:
            pred: tensor of rank >= 2 whose first dimension is the batch.
                Presumably foreground probabilities in [0, 1] -- TODO confirm
                with callers.
            target: ground-truth tensor broadcastable against ``pred`` with the
                same batch size.

        Returns:
            0-d tensor: the sum over the batch of ``1 - IoU``. A sample whose
            soft union is exactly zero (both masks empty) is treated as a
            perfect match (IoU = 1, loss 0) instead of producing NaN.
        """
        # One reduction per sample over all non-batch dims; works for any
        # rank >= 2 (not only 4-D) and for an empty batch.
        inter = (target * pred).flatten(start_dim=1).sum(dim=1)
        union = (
            target.flatten(start_dim=1).sum(dim=1)
            + pred.flatten(start_dim=1).sum(dim=1)
            - inter
        )
        empty = union == 0
        ones = torch.ones_like(union)
        # Divide by 1 where the union is empty so the unselected branch of
        # torch.where stays finite; a NaN there would still poison gradients.
        iou = torch.where(empty, ones, inter / torch.where(empty, ones, union))
        return (1 - iou).sum()
class DSLoss_IoU_noCAM(nn.Module):
    """Deep-supervision loss: sum of IoU losses over all but the first level.

    ``forward`` applies an ``IoU_loss`` instance to every entry of
    ``scaled_preds`` except the first and adds the results together.
    """

    def __init__(self):
        super().__init__()
        self.iou = IoU_loss()

    def forward(self, scaled_preds, gt):
        """Return the sum of ``self.iou(level, gt)`` over ``scaled_preds[1:]``.

        The first entry of ``scaled_preds`` is deliberately skipped.
        NOTE(review): presumably it is an unsupervised (e.g. CAM) output,
        judging by the class name -- confirm with callers. Every remaining
        level is compared against the same ``gt`` without resizing, so each
        level is assumed to already match ``gt``'s shape -- TODO confirm.

        Returns the integer ``0`` (not a tensor) when fewer than two levels
        are supplied.
        """
        supervised_levels = scaled_preds[1:]
        return sum(self.iou(level, gt) for level in supervised_levels)