train.py
import os
import time

import tensorflow as tf

import model
import audio_producer

def run_epoch(epoch, ghostvlad_model, producer, sess, save_path, saver):
    """Run a single pass over the dataset, checkpointing every 1000 steps."""
    ops = [ghostvlad_model.cost, ghostvlad_model.learning_rate,
           ghostvlad_model.global_step, ghostvlad_model.train_op]
    for inputs, labels in producer.iterator():
        feed_dict = ghostvlad_model.feed_dict(inputs, labels)
        cost, lr, step, _ = sess.run(ops, feed_dict)
        if step % 1000 == 0:
            # Checkpoint and log progress every 1000 global steps.
            saver.save(sess, save_path)
            print('Epoch {}, iter {}: Cost = {:.2f}, lr = {:.2e}'.format(epoch, step, cost, lr))
    # Save once more at the end of the epoch.
    saver.save(sess, save_path)

def main(argv):
    restore_path = argv.get('restore_path', None)
    save_path = argv['save_path']

    producer = audio_producer.AudioProducer(argv['json_path'], argv['batch_size'],
                                            sample_rate=argv['sample_rate'],
                                            min_duration=argv['min_duration'],
                                            max_duration=argv['max_duration'])

    graph = tf.Graph()
    with graph.as_default():
        ghostvlad_model = model.GhostVLADModel(argv)
        ghostvlad_model.init_inference(is_training=True)
        ghostvlad_model.init_cost()

        # Intel OpenMP / MKL threading settings for CPU training.
        os.environ['OMP_NUM_THREADS'] = '32'
        os.environ['KMP_BLOCKTIME'] = '0'
        os.environ['KMP_SETTINGS'] = '0'
        os.environ['KMP_AFFINITY'] = 'granularity=fine,compact,1,0'

        sess_conf = tf.ConfigProto()
        sess_conf.inter_op_parallelism_threads = 1
        sess_conf.intra_op_parallelism_threads = 32

        with tf.Session(config=sess_conf) as sess:
            # Train only the ArcFace head; every other (non-Adam) variable is
            # restored from the pretrained checkpoint when restore_path is set.
            restore_vars = []
            train_vars = []
            for var in tf.global_variables():
                if var.name.startswith('arcface/'):
                    train_vars.append(var)
                elif 'Adam' not in var.name:
                    restore_vars.append(var)

            ghostvlad_model.init_train(train_vars)
            sess.run(tf.global_variables_initializer())

            if restore_path:
                saver = tf.train.Saver(restore_vars)
                saver.restore(sess, restore_path)
            saver = tf.train.Saver()

            print("Begin training...")
            for e in range(argv['epochs']):
                run_epoch(e, ghostvlad_model, producer, sess, save_path, saver)
                print("========" * 5)
                print("Finished epoch", e)

if __name__ == "__main__":
    args_params = {
        "json_path": r"vox.json",
        "sample_rate": 16000,
        "min_duration": 1000,
        "max_duration": 3000,
        "save_path": r"saver/model.ckpt",
        "restore_path": r"ckpt/model.ckpt",
        "batch_size": 256,
        "epochs": 1000,
        "learning_rate": 0.001,
        "max_grad_norm": 50,
        "decay_steps": 5000,
        "decay_rate": 0.95,
        "vlad_clusters": 8,
        "ghost_clusters": 2,
        "embedding_dim": 512,
        "num_class": 5994
    }
    main(args_params)