train.py
'''
Training script: trains the Predictron model on randomly generated mazes.
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six.moves.queue as queue
from six.moves import range
import datetime
import logging
import os
import threading
import time
import numpy as np
import tensorflow as tf
from maze import MazeGenerator
from predictron import Predictron
FLAGS = tf.app.flags.FLAGS
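# Hyper-parameter and environment flags: checkpointing, optimization, and maze generation.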
tf.flags.DEFINE_string('train_dir', './ckpts/predictron_train',
                       'dir to save checkpoints and TB logs')
tf.flags.DEFINE_integer('max_steps', 10000000, 'num of batches')
tf.flags.DEFINE_float('learning_rate', 1e-3, 'learning rate')
tf.flags.DEFINE_integer('batch_size', 128, 'batch size')
tf.flags.DEFINE_integer('maze_size', 20, 'size of maze (square)')
tf.flags.DEFINE_float('maze_density', 0.3, 'maze density')
tf.flags.DEFINE_integer('max_depth', 16, 'maximum model depth')
tf.flags.DEFINE_float('max_grad_norm', 10., 'clip grad norm into this value')
tf.flags.DEFINE_boolean('log_device_placement', False,
                        'whether to log device placement')
tf.flags.DEFINE_integer('num_threads', 10, 'num of threads used to generate mazes')
logging.basicConfig()
logger = logging.getLogger('training')
logger.setLevel(logging.INFO)
def train():
  config = FLAGS

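  # Non-trainable global step counter, incremented each time gradients are applied.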
  global_step = tf.get_variable(
      'global_step', [],
      initializer=tf.constant_initializer(0), trainable=False)

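  # Input placeholders: maze images [batch, maze_size, maze_size, 1] and label vectors [batch, maze_size].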
  maze_ims_ph = tf.placeholder(tf.float32, [None, FLAGS.maze_size, FLAGS.maze_size, 1])
  maze_labels_ph = tf.placeholder(tf.float32, [None, FLAGS.maze_size])

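  # Build the Predictron graph; the total loss combines the preturn and lambda-preturn losses.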
  model = Predictron(maze_ims_ph, maze_labels_ph, config)
  model.build()

  loss = model.total_loss
  loss_preturns = model.loss_preturns
  loss_lambda_preturns = model.loss_lambda_preturns

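  # Adam optimizer with gradients clipped by global norm.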
  opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
  grad_vars = opt.compute_gradients(loss, tf.trainable_variables())
  grads, variables = zip(*grad_vars)
  grads_clipped, _ = tf.clip_by_global_norm(grads, FLAGS.max_grad_norm)
  grad_vars = list(zip(grads_clipped, variables))
  apply_gradient_op = opt.apply_gradients(grad_vars, global_step=global_step)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  update_op = tf.group(*update_ops)

  # Group all updates into a single train op.
  train_op = tf.group(apply_gradient_op, update_op)

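  # Create a session and initialize all variables.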
  init = tf.global_variables_initializer()
  sess = tf.Session()
  sess.run(init)

  saver = tf.train.Saver(tf.global_variables())
  tf.train.start_queue_runners(sess=sess)

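  # Per-max_depth output directory for checkpoints and TensorBoard summaries.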
  train_dir = os.path.join(FLAGS.train_dir, 'max_depth_{}'.format(FLAGS.max_depth))
  summary_merged = tf.summary.merge_all()
  summary_writer = tf.summary.FileWriter(train_dir, sess.graph)

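  # Bounded queue of pre-generated maze batches, filled by background producer threads.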
  maze_queue = queue.Queue(100)

  def maze_generator():
    maze_gen = MazeGenerator(
        height=FLAGS.maze_size,
        width=FLAGS.maze_size,
        density=FLAGS.maze_density)

    while True:
      maze_ims, maze_labels = maze_gen.generate_labelled_mazes(FLAGS.batch_size)
      maze_queue.put((maze_ims, maze_labels))

  for thread_i in range(FLAGS.num_threads):
    t = threading.Thread(target=maze_generator)
    t.daemon = True  # Do not keep the process alive once training finishes.
    t.start()

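  # Main training loop: consume maze batches, run a train step, and periodically log, summarize, and checkpoint.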
  for step in range(FLAGS.max_steps):
    start_time = time.time()
    maze_ims_np, maze_labels_np = maze_queue.get()
    _, loss_value, loss_preturns_val, loss_lambda_preturns_val, summary_str = sess.run(
        [train_op, loss, loss_preturns, loss_lambda_preturns, summary_merged],
        feed_dict={
            maze_ims_ph: maze_ims_np,
            maze_labels_ph: maze_labels_np
        })
    duration = time.time() - start_time

    assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

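    # Log loss values and throughput every 10 steps.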
    if step % 10 == 0:
      num_examples_per_step = FLAGS.batch_size
      examples_per_sec = num_examples_per_step / duration
      sec_per_batch = duration

      format_str = ('%s: step %d, loss = %.4f, loss_preturns = %.4f, '
                    'loss_lambda_preturns = %.4f (%.1f examples/sec; %.3f sec/batch)')
      logger.info(format_str % (datetime.datetime.now(), step, loss_value,
                                loss_preturns_val, loss_lambda_preturns_val,
                                examples_per_sec, sec_per_batch))

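    # Write TensorBoard summaries every 100 steps.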
    if step % 100 == 0:
      summary_writer.add_summary(summary_str, step)

    # Save the model checkpoint periodically.
    if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
      checkpoint_path = os.path.join(train_dir, 'model.ckpt')
      saver.save(sess, checkpoint_path, global_step=step)

def main(argv=None):
  train()


if __name__ == '__main__':
  tf.app.run()