-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: train.py
135 lines (100 loc) · 4.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import tensorflow as tf
import time
import numpy as np
import os
from macn.model import MACN, VINConfig
from dataset import get_datasets
# Command-line flag container (TF1-style); values are read as FLAGS.<name>
# throughout this script after tf.app.run() parses argv.
FLAGS = tf.flags.FLAGS
# Hyperparameter
tf.flags.DEFINE_integer("epochs", 10, "Number of epochs for training")
tf.flags.DEFINE_integer("ep_per_epoch", 1000, "Number of episodes per epochs")
# NOTE(review): 10e-5 evaluates to 1e-4 — if 1e-5 was intended this is a
# 10x-too-large learning rate; confirm against the paper/experiments.
tf.flags.DEFINE_float( "learning_rate", 10e-5, "The learning rate")
# MACN conf
tf.flags.DEFINE_integer("im_h", 9, "Image height")
tf.flags.DEFINE_integer("im_w", 9, "Image width")
tf.flags.DEFINE_integer("ch_i", 2, "Channels in input layer (~2 in [grid, reward])")
# VIN conf
tf.flags.DEFINE_integer("k", 10, "Number of iteration for planning (VIN)")
tf.flags.DEFINE_integer("ch_q", 4, "Channels in q layer (~actions)")
tf.flags.DEFINE_integer("ch_h", 150, "Channels in initial hidden layer")
# DNC Conf
tf.flags.DEFINE_integer("hidden_size", 256, "Size of LSTM hidden layer.")
tf.flags.DEFINE_integer("memory_size", 32, "The number of memory slots.")
tf.flags.DEFINE_integer("word_size", 8, "The width of each memory slot.")
tf.flags.DEFINE_integer("num_read_heads", 4, "Number of memory read heads.")
tf.flags.DEFINE_integer("num_write_heads", 1, "Number of memory write heads.")
# Paths: save and load default to the same checkpoint prefix, so a rerun
# resumes from the previous weights (see loadfile_exists below).
tf.flags.DEFINE_string('dataset', "./data/dataset.pkl", "Path to dataset file")
tf.flags.DEFINE_string('save', "./model/weights.ckpt", "File to save the weights")
tf.flags.DEFINE_string('load', "./model/weights.ckpt", "File to load the weights")
def main(args):
    """Build the MACN graph, train it by supervised imitation (cross-entropy
    on expert action labels), checkpoint each epoch, then evaluate on the
    held-out test split.

    TF1 graph/session style: ops are declared first, executed later inside
    the tf.Session block. Statement order matters — tf.train.Saver must be
    created after every variable (model + optimizer slots) exists.
    """
    # Fail fast if the save directory is missing (checks() exits the process).
    checks()
    macn = MACN(
        image_shape=[FLAGS.im_h, FLAGS.im_w, FLAGS.ch_i],
        vin_config=VINConfig(k=FLAGS.k, ch_h=FLAGS.ch_h, ch_q=FLAGS.ch_q),
        access_config={
            "memory_size": FLAGS.memory_size,
            "word_size": FLAGS.word_size,
            "num_reads": FLAGS.num_read_heads,
            "num_writes": FLAGS.num_write_heads
        },
        controller_config={
            "hidden_size": FLAGS.hidden_size
        }
    )
    y = tf.placeholder(tf.int64, shape=[None], name='y')  # labels : actions {0,1,2,3}
    # Training
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=macn.logits, name='cross_entropy')
    # NOTE(review): reduce_sum despite the op name 'cross_entropy_mean' —
    # the loss scales with episode length; confirm sum (not mean) is intended.
    loss = tf.reduce_sum(cross_entropy, name='cross_entropy_mean')
    train_step = tf.train.RMSPropOptimizer(FLAGS.learning_rate, epsilon=1e-6, centered=True).minimize(loss)
    # Reporting
    y_ = tf.argmax(macn.prob_actions, axis=-1)  # predicted action
    nb_errors = tf.reduce_sum(tf.to_float(tf.not_equal(y_, y)))  # Number of wrongly selected actions

    def train_on_episode(images, labels):
        # `sess` is resolved late (Python closure): it is bound by the
        # `with tf.Session()` block below before this is ever called.
        _, _loss, _nb_err = sess.run([train_step, loss, nb_errors], feed_dict={macn.X : images, y : labels})
        return _loss, _nb_err

    def test_on_episode(images, labels):
        # Forward pass only — no train_step, so the weights are untouched.
        return sess.run([loss, nb_errors], feed_dict={macn.X : images, y : labels})

    trainset, testset = get_datasets(FLAGS.dataset, test_percent=0.1)
    # Start training
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Resume from checkpoint when one exists; otherwise fresh init.
        if loadfile_exists(FLAGS.load):
            saver.restore(sess, FLAGS.load)
            print("Weights reloaded")
        else:
            sess.run(tf.global_variables_initializer())
        print("Start training...")
        for epoch in range(1, FLAGS.epochs + 1):
            start_time = time.time()
            mean_loss, mean_accuracy = compute_on_dataset(sess, trainset, train_on_episode)
            print('Epoch: {:3d} ({:.1f} s):'.format(epoch, time.time() - start_time))
            print('\t Train Loss: {:.5f} \t Train accuracy: {:.2f}%'.format(mean_loss, 100*(mean_accuracy)))
            # Checkpoint after every epoch so interrupted runs can resume.
            saver.save(sess, FLAGS.save)
        print('Training finished.')
        print('Testing...')
        mean_loss, mean_accuracy = compute_on_dataset(sess, testset, test_on_episode)
        print('Test Accuracy: {:.2f}%'.format(100*(mean_accuracy)))
def compute_on_dataset(sess, dataset, compute_episode, num_episodes=None):
    """Run `compute_episode` over a number of episodes and average the results.

    Args:
        sess: TF session. Kept for interface compatibility with existing
            callers; the `compute_episode` closures already capture the
            session, so it is not used directly here.
        dataset: object whose next_episode() returns (images, labels),
            where labels exposes .shape[0] (episode length).
        compute_episode: callable(images, labels) -> (loss, nb_errors).
        num_episodes: how many episodes to run; defaults to
            FLAGS.ep_per_epoch (backward-compatible with the old behavior).

    Returns:
        (mean_loss, mean_accuracy): per-episode averages, where accuracy is
        the fraction of correctly predicted actions within an episode.
    """
    if num_episodes is None:
        num_episodes = FLAGS.ep_per_epoch
    total_loss = 0.0
    total_accuracy = 0.0
    for _ in range(num_episodes):  # index unused, hence `_`
        images, labels = dataset.next_episode()
        ep_loss, nb_err = compute_episode(images, labels)
        total_loss += ep_loss
        # 1 - error rate = fraction of correct actions in this episode.
        total_accuracy += 1 - (nb_err / labels.shape[0])
    mean_loss = total_loss / num_episodes
    mean_accuracy = total_accuracy / num_episodes
    return mean_loss, mean_accuracy
def loadfile_exists(filepath):
    """Return True if a TF checkpoint with this path prefix exists.

    TF1 checkpoints are written as several files sharing the prefix
    (e.g. weights.ckpt.index, weights.ckpt.data-00000-of-00001), so a plain
    os.path.exists(filepath) would be False; instead look for any file in
    the directory whose name starts with the basename.

    Args:
        filepath: checkpoint prefix, e.g. "./model/weights.ckpt".

    Returns:
        True if at least one matching file exists, False otherwise.
        Fix over the original: a bare filename (empty dirname) or a missing
        directory now returns False instead of raising FileNotFoundError
        from os.listdir.
    """
    directory = os.path.dirname(filepath) or "."  # "" dirname means cwd
    if not os.path.isdir(directory):
        return False
    prefix = os.path.basename(filepath)
    return any(name.startswith(prefix) for name in os.listdir(directory))
def checks():
    """Validate runtime preconditions before building the graph.

    Verifies the directory for FLAGS.save exists, so training does not run
    for a full epoch only to crash inside saver.save. On failure, prints an
    error and terminates the process with a non-zero status.
    """
    # `or "."` handles a bare filename (empty dirname = current directory);
    # isdir (not exists) rejects the case where the path is a regular file.
    save_dir = os.path.dirname(FLAGS.save) or "."
    if not os.path.isdir(save_dir):
        print("Error : save file cannot be created (need folders) : " + FLAGS.save)
        # Fix: the original bare exit() terminated with status 0, reporting
        # success to the shell on the error path. Exit non-zero instead.
        raise SystemExit(1)
if __name__ == "__main__":
    # tf.app.run() parses the flags defined above, then calls main(argv).
    tf.app.run()