-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmetrics.py
36 lines (25 loc) · 910 Bytes
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import numpy as np
def iou_score(y_pred, y_true):
    """Return the intersection-over-union of two binarized masks.

    ``y_pred`` is passed through a sigmoid and both inputs are thresholded
    at 0.5; the score is computed over every element of the arrays at once
    (no per-sample or per-channel averaging).

    Args:
        y_pred: Tensor of raw logits.
        y_true: Tensor of target values, binarized with ``> 0.5``
            (assumed to have the same shape as ``y_pred`` -- TODO confirm
            against callers).

    Returns:
        ``intersection / union`` as a float. When both binarized masks are
        empty (union of zero) the masks agree perfectly, so 1.0 is returned
        instead of the ``nan`` a 0/0 division would produce.
    """
    # detach() replaces the legacy ``.data`` access; the metric never needs
    # gradients, and NumPy conversion requires a CPU tensor.
    pred_mask = torch.sigmoid(y_pred).detach().cpu().numpy() > 0.5
    true_mask = y_true.detach().cpu().numpy() > 0.5

    intersection = (pred_mask & true_mask).sum()
    union = (pred_mask | true_mask).sum()

    # Guard the 0/0 case: nothing predicted and nothing to find.
    if union == 0:
        return 1.0
    return intersection / union
def dice_score(y_pred, y_true, smooth=0.):
    """Return the soft Dice coefficient between logits and targets.

    The prediction is converted to probabilities with a sigmoid (no
    thresholding), both tensors are flattened, and the coefficient is
    computed over all elements at once. The result stays a tensor, so it
    can be used inside a differentiable loss.

    Args:
        y_pred: Tensor of raw logits.
        y_true: Tensor of target values with the same number of elements
            as ``y_pred``.
        smooth: Constant added to both numerator and denominator. With the
            default of 0 the result is 0/0 (``nan``) if both sums are
            exactly zero.

    Returns:
        A scalar tensor:
        ``(2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.
    """
    # reshape(-1) rather than view(-1): view raises on non-contiguous
    # tensors (e.g. after transpose/permute), reshape copies only if needed.
    probs = torch.sigmoid(y_pred).reshape(-1)
    target = y_true.reshape(-1)

    intersection = (probs * target).sum()
    return (2. * intersection + smooth) / (probs.sum() + target.sum() + smooth)
class Loss(nn.Module):
    """Equal-weight combination of Dice loss and binary cross-entropy.

    Both terms are computed from raw logits: the BCE term uses the
    with-logits formulation and ``dice_score`` applies its own sigmoid.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        """Return ``0.5 * (1 - dice) + 0.5 * bce`` for a batch.

        Args:
            y_pred: Tensor of raw logits.
            y_true: Target tensor; presumably float values in [0, 1] with
                the same shape as ``y_pred``, as the BCE call expects --
                verify against the data loader.

        Returns:
            A scalar tensor holding the combined loss.
        """
        # Reach the functional API through the already-imported ``nn``;
        # the file never imports ``torch.nn.functional as F``, so the
        # previous ``F.`` reference raised NameError.
        bce = nn.functional.binary_cross_entropy_with_logits(y_pred, y_true)
        # A small smoothing term keeps the Dice ratio defined when both
        # the prediction and the target sums are (near) zero.
        dice = dice_score(y_pred, y_true, smooth=1e-3)
        return 0.5 * (1 - dice) + 0.5 * bce