-
Notifications
You must be signed in to change notification settings - Fork 2
/
auroc.py
118 lines (103 loc) · 4.56 KB
/
auroc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import tensorflow.keras.backend as kb
import numpy as np
import os
import shutil
from tensorflow.keras.callbacks import Callback
from utils import normalize_hinge_output
from sklearn.metrics import precision_recall_fscore_support, average_precision_score, accuracy_score, hamming_loss, \
roc_auc_score
class MultipleClassAUROC(Callback):
    """
    Monitor mean AUROC and keep the best model weights.

    At the end of every epoch this callback predicts on ``sequence`` and computes:

    - a per-class ROC AUC;
    - several thresholded multi-label metrics.

    When the mean AUROC beats the best seen so far, it does three things:

    - copies the weights file at ``weights_path`` to ``best_weights_path``;
    - appends a line to ``best_auroc.log``;
    - writes ``training_stats.json`` so training can be resumed.
    """

    def __init__(self, sequence, class_names, weights_path, output_weights_path, loss_function, confidence_thresh=0.5,
                 stats=None, workers=1):
        """
        :param sequence: evaluation data. It is passed to ``model.predict_generator``
            and must expose ``get_y_true()`` returning labels indexable as
            ``y[:, class_index]``.
        :param class_names: class names, in the same order as the model's output columns.
        :param weights_path: weights file that is copied when the score improves.
            NOTE(review): presumably written by another callback such as
            ModelCheckpoint; confirm.
        :param output_weights_path: only its file name is used. The best weights are
            stored under that name in the directory of ``weights_path``.
        :param loss_function: loss name. If it contains ``'Hinge'``, predictions are
            passed through ``normalize_hinge_output`` before scoring.
        :param confidence_thresh: score threshold that turns predictions into binary
            labels for precision/recall/F-score, exact-match accuracy and Hamming loss.
        :param stats: previously saved stats dict, used to resume training. A falsy
            value starts from a fresh dict.
        :param workers: forwarded to ``predict_generator``.
        """
        # Run Callback's own initialiser; ``super(Callback, self)`` would skip it
        # and dispatch straight to ``object.__init__``.
        super().__init__()
        self.sequence = sequence
        self.workers = workers
        self.class_names = class_names
        self.weights_path = weights_path
        self.loss_function = loss_function
        self.confidence_thresh = confidence_thresh

        # Every output file lives next to ``weights_path``.
        output_dir = os.path.split(weights_path)[0]
        self.best_weights_path = os.path.join(output_dir, os.path.basename(output_weights_path))
        self.best_auroc_log_path = os.path.join(output_dir, "best_auroc.log")
        self.stats_output_path = os.path.join(output_dir, "training_stats.json")

        # for resuming previous training
        self.stats = stats if stats else {}
        # A resumed stats dict may lack the key; default it so the comparison
        # in on_epoch_end cannot raise KeyError.
        self.stats.setdefault("best_mean_auroc", 0)

        # Per-class history of AUROC scores, one entry appended per epoch.
        self.aurocs = {c: [] for c in self.class_names}

    def _class_aurocs(self, y, y_hat):
        """
        Compute, print and record the ROC AUC of every class.

        A class whose AUROC cannot be computed (``roc_auc_score`` raises
        ``ValueError``, e.g. when only one label value is present) scores 0,
        which pulls the mean down.

        :return: list of scores in ``class_names`` order.
        """
        scores = []
        for i, class_name in enumerate(self.class_names):
            try:
                score = roc_auc_score(y[:, i], y_hat[:, i])
            except ValueError:
                score = 0
            self.aurocs[class_name].append(score)
            scores.append(score)
            print(f"{i + 1}. {class_name}: {score}")
        return scores

    def on_epoch_end(self, epoch, logs=None):
        """
        Calculate the average AUROC and save the best model weights according
        to this metric.

        :param epoch: epoch index (displayed as ``epoch + 1``).
        :param logs: Keras epoch logs; unused here.
        """
        print("\n*********************************")
        self.stats["lr"] = float(kb.eval(self.model.optimizer.lr))
        print(f"current learning rate: {self.stats['lr']}")

        # y_hat: model scores, one column per entry of class_names.
        # y: ground-truth labels, indexed below as y[:, class_index].
        y_hat = self.model.predict_generator(self.sequence, workers=self.workers)
        if 'Hinge' in self.loss_function:
            y_hat = normalize_hinge_output(y_hat)
        y = self.sequence.get_y_true()

        print(f"*** epoch#{epoch + 1} dev auroc ***")
        current_auroc = self._class_aurocs(y, y_hat)
        print("*********************************")

        # Binarised predictions shared by the thresholded metrics below.
        y_pred = y_hat >= self.confidence_thresh
        prec, rec, fscore, _ = precision_recall_fscore_support(y, y_pred, average='macro')
        ap = average_precision_score(y, y_hat)
        exact_accuracy = accuracy_score(y, y_pred)
        ham_loss = hamming_loss(y, y_pred)
        print(
            f"precision:{prec:.2f}, recall: {rec:.2f}, fscore: {fscore:.2f}, AP: {ap:.2f}, "
            f"exact match accuracy: {exact_accuracy:.2f}, hamming loss: {ham_loss:.2f}")

        # customize your multiple class metrics here
        mean_auroc = float(np.mean(current_auroc))
        print(f"mean auroc: {mean_auroc}")

        best_so_far = self.stats["best_mean_auroc"]
        if not mean_auroc > best_so_far:
            print(f"best auroc is still {best_so_far}")
            return

        print(f"update best auroc from {best_so_far} to {mean_auroc}")
        # Update the in-memory stats BEFORE writing them, so the resume file
        # holds this epoch's best score instead of the previous one.
        # Values are cast to plain floats so json.dump always accepts them.
        self.stats.update({
            "best_mean_auroc": mean_auroc,
            "AP": float(ap),
            "precision": float(prec),
            "recall": float(rec),
            "fscore": float(fscore),
            "hamming_loss": float(ham_loss),
            "exact_accuracy": float(exact_accuracy),
        })
        # 1. copy best model
        shutil.copy(self.weights_path, self.best_weights_path)
        # 2. update log file
        print(f"update log file: {self.best_auroc_log_path}")
        with open(self.best_auroc_log_path, "a") as f:
            f.write(f"(epoch#{epoch + 1}) auroc: {mean_auroc}, lr: {self.stats['lr']}\n")
        # 3. write stats output, this is used for resuming the training
        with open(self.stats_output_path, 'w') as f:
            json.dump(self.stats, f)
        print(f"update model file: {self.weights_path} -> {self.best_weights_path}")
        print("*********************************")