-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathValidate_knn.py
55 lines (36 loc) · 2.15 KB
/
Validate_knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# -*- coding: utf-8 -*-
"""
Created on Tue May 26 14:21:42 2020
@author: luist
"""
class Validate_knn:
    """Training loop for a siamese network that tracks k-NN validation accuracy."""

    @staticmethod
    def validate_knn(n_iter, loader, time, siamese_net, batch_size, t_start,
                     n_val, best_val, evaluate_every, loss_every, n):
        """Train ``siamese_net`` and report the best accuracies observed.

        For each iteration ``i`` in ``1..n_iter``, one batch is drawn with
        ``loader.batch_function(batch_size)`` and trained on with
        ``siamese_net.train_on_batch``. Every ``evaluate_every`` iterations:

        - validation accuracy is measured with ``loader.knn_test(n, n_val)``;
        - training accuracy is measured with
          ``loader.oneshot_test(siamese_net, n, n_val, s='train')``;
        - the best value of each is kept.

        Every ``loss_every`` iterations, the loss, the current bests and the
        elapsed time are printed.

        Args:
            n_iter: Number of training iterations.
            loader: Object providing ``batch_function``, ``knn_test`` and
                ``oneshot_test``.
            time: Object with a ``time()`` method. Presumably the stdlib
                ``time`` module; it is used only for elapsed-time printing.
            siamese_net: Model exposing ``train_on_batch(inputs, targets)``.
            batch_size: Forwarded to ``loader.batch_function``.
            t_start: Start timestamp from the same clock as ``time.time()``.
                Elapsed time is printed as ``(time.time() - t_start) / 60.0``,
                i.e. minutes if the clock counts seconds.
            n_val: Forwarded to ``knn_test`` and ``oneshot_test``.
            best_val: Ignored. Tracking always restarts from -1; the
                parameter is kept so existing callers keep working.
            evaluate_every: Evaluation period, in iterations.
            loss_every: Loss-logging period, in iterations.
            n: Forwarded to ``knn_test`` and ``oneshot_test``, and echoed as
                "N" in log lines.

        Returns:
            ``([best_val], [best_train])``: one-element lists holding the
            best validation and training accuracies. Each stays -1 if no
            evaluation ran (e.g. ``evaluate_every > n_iter``).
        """
        best_train = -1
        # The incoming ``best_val`` is deliberately discarded so each call
        # tracks its own best from scratch.
        best_val = -1

        for i in range(1, n_iter + 1):
            inputs, targets = loader.batch_function(batch_size)
            loss = siamese_net.train_on_batch(inputs, targets)

            if i % evaluate_every == 0:
                print("\n ------------- \n")
                print("Time for {0} iterations: {1} mins".format(
                    i, (time.time() - t_start) / 60.0))
                print("Train Loss: {0}".format(loss))
                val_acc = loader.knn_test(n, n_val)
                train_acc = loader.oneshot_test(siamese_net, n, n_val, s='train')
                # ``>=`` so that ties also refresh (and announce) the best.
                if val_acc >= best_val:
                    print("Current best val: {0}, previous best: {1}".format(val_acc, best_val))
                    best_val = val_acc
                if train_acc >= best_train:
                    print("Current best train: {0}, previous best: {1}".format(train_acc, best_train))
                    best_train = train_acc

            if i % loss_every == 0:
                print("iteration {}, training loss: {:.2f},".format(i, loss))
                print("Current best val: {0}".format(best_val), " - N:", n)
                print("Current best train: {0}".format(best_train), " - N:", n)
                print("Tempo decorrido:", (time.time() - t_start) / 60.0)

        print("The final best accuracy value (validation): {0}".format(best_val), " - N:", n)
        print("The final best accuracy value (training): {0}".format(best_train), " - N:", n)
        return [best_val], [best_train]