-
Notifications
You must be signed in to change notification settings - Fork 59
/
utils.py
68 lines (57 loc) · 1.93 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np
from sklearn.metrics import precision_recall_curve, auc
from torch.utils.data import Dataset
# Pick the compute device once at import time: prefer the GPU when available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
def train(data, model, optim, criterion, lbd, max_clip_norm=5):
    """Run a single optimization step on one batch.

    Args:
        data: 2-D tensor; the last column is the binary label, the
            remaining columns are the model input.
        model: module returning ``(logits, kld)`` for an input batch.
        optim: optimizer over ``model.parameters()``.
        criterion: binary loss on ``(logits, labels)``,
            e.g. ``BCEWithLogitsLoss``.
        lbd: weight of the KL-divergence term in the total loss.
        max_clip_norm: gradient-norm clipping threshold.

    Returns:
        Tuple of Python floats ``(total_loss, kld, bce)``.
    """
    model.train()
    inputs = data[:, :-1].to(device)
    labels = data[:, -1].float().to(device)
    optim.zero_grad()
    logits, kld = model(inputs)
    logits = logits.squeeze(-1)
    kld = kld.sum()
    bce = criterion(logits, labels)
    loss = bce + lbd * kld
    loss.backward()
    # Clip AFTER backward(): gradients only exist once backward() has run.
    # The original clipped before backward(), which made clipping a no-op.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
    optim.step()
    return loss.item(), kld.item(), bce.item()
def evaluate(model, data_iter, length):
    """Score ``model`` over ``data_iter`` and compute area under the PR curve.

    Args:
        model: module returning ``(logits, kld)``; only the logits are used.
        data_iter: iterable of 2-D batches; in each batch the last column
            is the binary label, the rest is the model input.
        length: total number of examples across all batches (pre-sizes the
            output arrays).

    Returns:
        ``(auprc, (y_pred, y_prob, y_true))`` — the area under the
        precision-recall curve plus the per-example hard predictions,
        sigmoid probabilities, and ground-truth labels.
    """
    model.eval()
    y_pred = np.zeros(length)
    y_true = np.zeros(length)
    y_prob = np.zeros(length)
    pointer = 0
    # Inference only: disable autograd bookkeeping (the original relied on
    # .detach() but still built the graph during the forward pass).
    with torch.no_grad():
        for batch in data_iter:
            inputs = batch[:, :-1].to(device)
            labels = batch[:, -1]
            batch_size = len(labels)
            logits, _ = model(inputs)
            probability = torch.sigmoid(logits.squeeze(-1))
            predicted = probability > 0.5
            y_true[pointer: pointer + batch_size] = labels.numpy()
            y_pred[pointer: pointer + batch_size] = predicted.cpu().numpy()
            y_prob[pointer: pointer + batch_size] = probability.cpu().numpy()
            pointer += batch_size
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    return auc(recall, precision), (y_pred, y_prob, y_true)
class EHRData(Dataset):
    """Minimal map-style Dataset pairing per-row features with class labels."""

    def __init__(self, data, cla):
        # data: indexable feature rows; cla: matching labels
        # (cla defines the dataset length).
        self.data = data
        self.cla = cla

    def __len__(self):
        return len(self.cla)

    def __getitem__(self, idx):
        return self.data[idx], self.cla[idx]
def collate_fn(data):
    """Stack (sparse_features, label) pairs into one dense LongTensor batch.

    Each element of ``data`` is a pair whose first item is a scipy-style
    sparse feature row and whose second item is the label; the row is
    densified, flattened, and the label appended as the final column(s).
    """
    rows = [np.hstack((features.toarray().ravel(), label))
            for features, label in data]
    return torch.from_numpy(np.array(rows)).long()