-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathddflow_model.py
411 lines (342 loc) · 23.5 KB
/
ddflow_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import numpy as np
import os
import sys
import time
import cv2
from six.moves import xrange
from scipy import misc, io
from tensorflow.contrib import slim
import matplotlib.pyplot as plt
from network import pyramid_processing, pyramid_processing_bidirection, get_shape
from datasets import BasicDataset
from utils import average_gradients, lrelu, occlusion, rgb_bgr
from data_augmentation import flow_resize
from flowlib import flow_to_color, write_flo
from warp import tf_warp
class DDFlowModel(object):
    """Graph construction, training and inference driver for the DDFlow network.

    The class is written against the TensorFlow 1.x graph/session API
    (``tf.Session``, ``tf.contrib.slim``, ``tf.train.AdamOptimizer``).  It wires
    the project helpers (``BasicDataset``, ``pyramid_processing*``,
    ``occlusion``, ``tf_warp``) into loss tensors, a training loop and two
    inference entry points.
    """

    _TRAINING_MODE_ERROR = 'Invalid training_mode. Training_mode should be one of {no_distillation, distillation}'

    def __init__(self, batch_size=8, iter_steps=1000000, initial_learning_rate=1e-4, decay_steps=2e5,
                 decay_rate=0.5, is_scale=True, num_input_threads=4, buffer_size=5000,
                 beta1=0.9, num_gpus=1, save_checkpoint_interval=5000, write_summary_interval=200,
                 display_log_interval=50, allow_soft_placement=True, log_device_placement=False,
                 regularizer_scale=1e-4, cpu_device='/cpu:0', save_dir='KITTI', checkpoint_dir='checkpoints',
                 model_name='model', sample_dir='sample', summary_dir='summary', training_mode="no_distillation",
                 is_restore_model=False, restore_model='./models/KITTI/no_census_no_occlusion',
                 dataset_config=None, distillation_config=None):
        """Store hyper-parameters and create the output directory tree.

        Args:
            batch_size: total batch size; must be divisible by ``num_gpus``.
            iter_steps: last training step executed by ``train``.
            initial_learning_rate, decay_steps, decay_rate: arguments of the
                staircase exponential learning-rate decay built in ``train``.
            is_scale: forwarded to the network builders.
            num_input_threads, buffer_size: forwarded to the dataset iterators.
            beta1: second positional argument of ``tf.train.AdamOptimizer``.
            num_gpus: number of towers; ``0`` or ``1`` builds a single tower.
            save_checkpoint_interval, write_summary_interval,
                display_log_interval: step periods used by ``train``.
            allow_soft_placement, log_device_placement: ``tf.ConfigProto`` flags.
            regularizer_scale: scale of the slim L2 regularizer.
            cpu_device: device used for shared ops unless ``num_gpus == 1``.
            save_dir, checkpoint_dir, sample_dir, summary_dir, model_name:
                components of the output paths (joined with ``'/'``).
            training_mode: ``'no_distillation'`` or ``'distillation'``.
            is_restore_model, restore_model: whether/what to restore in ``train``.
            dataset_config: dict read for ``crop_h``, ``crop_w``,
                ``data_list_file`` and ``img_dir``.  ``None`` means an empty dict.
            distillation_config: dict read for ``fake_flow_occ_dir``,
                ``target_h`` and ``target_w``.  ``None`` means an empty dict.

        Raises:
            AssertionError: if ``batch_size`` is not divisible by ``num_gpus``.
            OSError: if an output directory cannot be created.
        """
        self.batch_size = batch_size
        self.iter_steps = iter_steps
        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.is_scale = is_scale
        self.num_input_threads = num_input_threads
        self.buffer_size = buffer_size
        self.beta1 = beta1
        self.num_gpus = num_gpus
        self.save_checkpoint_interval = save_checkpoint_interval
        self.write_summary_interval = write_summary_interval
        self.display_log_interval = display_log_interval
        self.allow_soft_placement = allow_soft_placement
        self.log_device_placement = log_device_placement
        self.regularizer_scale = regularizer_scale
        self.training_mode = training_mode
        self.is_restore_model = is_restore_model
        self.restore_model = restore_model
        # A fresh dict per instance: a `{}` default would be shared by every
        # instance created without an explicit config.
        self.dataset_config = {} if dataset_config is None else dataset_config
        self.distillation_config = {} if distillation_config is None else distillation_config
        self.shared_device = '/gpu:0' if self.num_gpus == 1 else cpu_device

        # num_gpus == 0 is treated as a single (CPU) tower for batch splitting,
        # so clamp once and use the same divisor for the check and the split.
        tower_count = np.maximum(num_gpus, 1)
        assert np.mod(batch_size, tower_count) == 0, 'batch_size must be divisible by num_gpus'
        self.batch_size_per_gpu = int(batch_size / tower_count)

        self.save_dir = save_dir
        self.checkpoint_dir = '/'.join([self.save_dir, checkpoint_dir])
        self.model_name = model_name
        self.sample_dir = '/'.join([self.save_dir, sample_dir])
        self.summary_dir = '/'.join([self.save_dir, summary_dir])

        output_dirs = (
            self.save_dir,
            self.checkpoint_dir,
            '/'.join([self.checkpoint_dir, self.model_name]),
            self.sample_dir,
            '/'.join([self.sample_dir, self.model_name]),
            self.summary_dir,
            '/'.join([self.summary_dir, 'train']),
            '/'.join([self.summary_dir, 'test']),
        )
        for path in output_dirs:
            self._ensure_dir(path)

    @staticmethod
    def _ensure_dir(path):
        """Create `path` (with parents) unless it already is a directory."""
        if os.path.isdir(path):
            return
        try:
            os.makedirs(path)
        except OSError:
            # Tolerate a concurrent creator; re-raise every other failure.
            if not os.path.isdir(path):
                raise

    def create_dataset_and_iterator(self, training_mode='no_distillation'):
        """Build the training ``BasicDataset`` and its batch iterator.

        In ``'distillation'`` mode the dataset additionally receives
        ``fake_flow_occ_dir`` and the distillation iterator is used.

        Returns:
            ``(dataset, iterator)``.

        Raises:
            ValueError: for an unknown ``training_mode``.
        """
        if training_mode == 'no_distillation':
            dataset = BasicDataset(crop_h=self.dataset_config['crop_h'],
                                   crop_w=self.dataset_config['crop_w'],
                                   batch_size=self.batch_size_per_gpu,
                                   data_list_file=self.dataset_config['data_list_file'],
                                   img_dir=self.dataset_config['img_dir'])
            iterator = dataset.create_batch_iterator(
                data_list=dataset.data_list, batch_size=dataset.batch_size,
                shuffle=True, buffer_size=self.buffer_size, num_parallel_calls=self.num_input_threads)
        elif training_mode == 'distillation':
            dataset = BasicDataset(crop_h=self.dataset_config['crop_h'],
                                   crop_w=self.dataset_config['crop_w'],
                                   batch_size=self.batch_size_per_gpu,
                                   data_list_file=self.dataset_config['data_list_file'],
                                   img_dir=self.dataset_config['img_dir'],
                                   fake_flow_occ_dir=self.distillation_config['fake_flow_occ_dir'])
            iterator = dataset.create_batch_distillation_iterator(
                data_list=dataset.data_list, batch_size=dataset.batch_size,
                shuffle=True, buffer_size=self.buffer_size, num_parallel_calls=self.num_input_threads)
        else:
            raise ValueError(self._TRAINING_MODE_ERROR)
        return dataset, iterator

    def epe_loss(self, diff, mask):
        """Masked mean of the L2 norm of `diff` taken over its last axis.

        The masked norms are summed and divided by ``sum(mask) + 1e-6``.
        """
        diff_norm = tf.norm(diff, axis=-1, keepdims=True)
        diff_norm = tf.multiply(diff_norm, mask)
        diff_norm_sum = tf.reduce_sum(diff_norm)
        loss_mean = diff_norm_sum / (tf.reduce_sum(mask) + 1e-6)
        return loss_mean

    def abs_robust_loss(self, diff, mask, q=0.4):
        """Masked robust penalty ``(|diff| + 0.01) ** q``.

        The masked penalties are summed and divided by
        ``2 * sum(mask) + 1e-6``.  NOTE: the factor 2 is hard-coded and is
        applied whatever the channel count of `diff` is; it is kept as-is
        because changing it would rescale every loss built on this function.
        """
        diff = tf.pow((tf.abs(diff) + 0.01), q)
        diff = tf.multiply(diff, mask)
        diff_sum = tf.reduce_sum(diff)
        loss_mean = diff_sum / (tf.reduce_sum(mask) * 2 + 1e-6)
        return loss_mean

    def create_mask(self, tensor, paddings):
        """Return a 4-D mask that is 1 inside and 0 on a border of `tensor`.

        Args:
            tensor: tensor whose dynamic shape supplies the batch size (axis 0)
                and the two masked axes (axes 1 and 2).
            paddings: ``[[before, after], [before, after]]`` border widths for
                axis 1 and axis 2 respectively.

        Returns:
            A gradient-stopped tensor of shape ``[shape[0], shape[1], shape[2], 1]``.
        """
        with tf.variable_scope('create_mask'):
            shape = tf.shape(tensor)
            inner_rows = shape[1] - (paddings[0][0] + paddings[0][1])
            inner_cols = shape[2] - (paddings[1][0] + paddings[1][1])
            inner = tf.ones([inner_rows, inner_cols])
            mask2d = tf.pad(inner, paddings)
            mask3d = tf.tile(tf.expand_dims(mask2d, 0), [shape[0], 1, 1])
            mask4d = tf.expand_dims(mask3d, 3)
            return tf.stop_gradient(mask4d)

    def census_loss(self, img1, img2_warped, mask, max_distance=3):
        """Robust loss on a soft census (ternary) transform of two images.

        Each image is converted to grayscale, every pixel is compared with the
        pixels of its ``(2 * max_distance + 1)``-wide square neighbourhood, and
        the soft Hamming distance between both descriptors is fed to
        ``abs_robust_loss``.  A border of ``max_distance`` pixels is masked out.
        """
        patch_size = 2 * max_distance + 1
        with tf.variable_scope('census_loss'):
            def _ternary_transform(image):
                intensities = tf.image.rgb_to_grayscale(image) * 255
                # An identity-kernel convolution gathers each pixel's
                # neighbourhood into channels.  It stands in for
                # tf.extract_image_patches, which the original author noted
                # fails when the static row count is None.
                out_channels = patch_size * patch_size
                w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
                weights = tf.constant(w, dtype=tf.float32)
                patches = tf.nn.conv2d(intensities, weights, strides=[1, 1, 1, 1], padding='SAME')
                transf = patches - intensities
                transf_norm = transf / tf.sqrt(0.81 + tf.square(transf))
                return transf_norm

            def _hamming_distance(t1, t2):
                dist = tf.square(t1 - t2)
                dist_norm = dist / (0.1 + dist)
                dist_sum = tf.reduce_sum(dist_norm, 3, keepdims=True)
                return dist_sum

            t1 = _ternary_transform(img1)
            t2 = _ternary_transform(img2_warped)
            dist = _hamming_distance(t1, t2)
            transform_mask = self.create_mask(mask, [[max_distance, max_distance],
                                                     [max_distance, max_distance]])
            return self.abs_robust_loss(dist, mask * transform_mask)

    def compute_losses(self, batch_img1, batch_img2, flow_fw, flow_bw, mask_fw, mask_bw, train=True, is_scale=True):
        """Build the photometric and census losses for one image pair.

        Args:
            batch_img1, batch_img2: the two image batches.
            flow_fw, flow_bw: dicts holding the forward/backward flow under
                the key ``'full_res'``.
            mask_fw, mask_bw: masks applied in the ``'occlusion'`` variants.
            train: forwarded to ``get_shape``.
            is_scale: unused; kept for interface compatibility.

        Returns:
            ``{'abs_robust_mean': {...}, 'census': {...}}`` where each inner
            dict has the keys ``'no_occlusion'`` (all-ones mask) and
            ``'occlusion'`` (the given masks).  Every entry is the sum of the
            forward and the backward term.
        """
        img_size = get_shape(batch_img1, train=train)
        img1_warp = tf_warp(batch_img1, flow_bw['full_res'], img_size[1], img_size[2])
        img2_warp = tf_warp(batch_img2, flow_fw['full_res'], img_size[1], img_size[2])

        ones_fw = tf.ones_like(mask_fw)
        ones_bw = tf.ones_like(mask_bw)

        losses = {}

        abs_robust_mean = {}
        abs_robust_mean['no_occlusion'] = (self.abs_robust_loss(batch_img1 - img2_warp, ones_fw) +
                                           self.abs_robust_loss(batch_img2 - img1_warp, ones_bw))
        abs_robust_mean['occlusion'] = (self.abs_robust_loss(batch_img1 - img2_warp, mask_fw) +
                                        self.abs_robust_loss(batch_img2 - img1_warp, mask_bw))
        losses['abs_robust_mean'] = abs_robust_mean

        census_loss = {}
        census_loss['no_occlusion'] = (self.census_loss(batch_img1, img2_warp, ones_fw, max_distance=3) +
                                       self.census_loss(batch_img2, img1_warp, ones_bw, max_distance=3))
        census_loss['occlusion'] = (self.census_loss(batch_img1, img2_warp, mask_fw, max_distance=3) +
                                    self.census_loss(batch_img2, img1_warp, mask_bw, max_distance=3))
        losses['census'] = census_loss

        return losses

    def add_loss_summary(self, losses, keys=('abs_robust_mean',), prefix=None):
        """Register one scalar summary per entry of ``losses[key]``.

        Summary names are ``key/loss_key`` or, with `prefix`,
        ``prefix/key/loss_key``.
        """
        for key in keys:
            for loss_key, loss_value in losses[key].items():
                if prefix:
                    loss_name = '%s/%s/%s' % (prefix, key, loss_key)
                else:
                    loss_name = '%s/%s' % (key, loss_key)
                tf.summary.scalar(loss_name, loss_value)

    def build_no_data_distillation(self, iterator, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True):
        """Build the network and losses for one tower without distillation.

        Pulls an image pair from `iterator`, estimates bidirectional flow,
        derives non-occlusion masks as ``1 - occlusion`` and evaluates
        ``compute_losses``.

        Returns:
            ``(losses, regularizer_loss)``; the latter is the sum of the
            graph's collected regularization losses.
        """
        batch_img1, batch_img2 = iterator.get_next()
        regularizer = slim.l2_regularizer(scale=regularizer_scale)
        flow_fw, flow_bw = pyramid_processing_bidirection(
            batch_img1, batch_img2,
            train=train, trainable=trainable, reuse=None, regularizer=regularizer, is_scale=is_scale)
        occ_fw, occ_bw = occlusion(flow_fw['full_res'], flow_bw['full_res'])
        mask_fw = 1. - occ_fw
        mask_bw = 1. - occ_bw

        losses = self.compute_losses(batch_img1, batch_img2, flow_fw, flow_bw, mask_fw, mask_bw,
                                     train=train, is_scale=is_scale)

        regularizer_loss = tf.add_n(tf.losses.get_regularization_losses())
        return losses, regularizer_loss

    def build_data_distillation(self, iterator, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True):
        """Build the network and losses for one tower with data distillation.

        The iterator yields images together with previously generated flow
        and occlusion maps.  A random ``target_h x target_w`` window is
        cropped identically from all six tensors, the network is run on the
        cropped images, and the stored flow supervises the prediction where
        the prediction on the crop is occluded but the stored occlusion map
        is not.

        Returns:
            ``(losses, regularizer_loss)``; `losses` is the result of
            ``compute_losses`` on the crop plus
            ``losses['data_distillation']['distillation']``.

        Raises:
            ValueError: if the target window is larger than the training crop.
        """
        batch_img1, batch_img2, flow_fw, flow_bw, occ_fw, occ_bw = iterator.get_next()
        regularizer = slim.l2_regularizer(scale=regularizer_scale)
        h = self.dataset_config['crop_h']
        w = self.dataset_config['crop_w']
        target_h = self.distillation_config['target_h']
        target_w = self.distillation_config['target_w']
        if target_h > h or target_w > w:
            raise ValueError('distillation target size (%s, %s) exceeds crop size (%s, %s)'
                             % (target_h, target_w, h, w))

        # maxval is exclusive for integer sampling: the "+ 1" makes the last
        # valid offset reachable and keeps the range non-empty when the
        # target size equals the crop size.
        offset_h = tf.random_uniform([], minval=0, maxval=h - target_h + 1, dtype=tf.int32)
        offset_w = tf.random_uniform([], minval=0, maxval=w - target_w + 1, dtype=tf.int32)

        def _crop(tensor):
            return tf.image.crop_to_bounding_box(tensor, offset_h, offset_w, target_h, target_w)

        batch_img1_cropped_patch = _crop(batch_img1)
        batch_img2_cropped_patch = _crop(batch_img2)
        flow_fw_cropped_patch = _crop(flow_fw)
        flow_bw_cropped_patch = _crop(flow_bw)
        occ_fw_cropped_patch = _crop(occ_fw)
        occ_bw_cropped_patch = _crop(occ_bw)

        flow_fw_patch, flow_bw_patch = pyramid_processing_bidirection(
            batch_img1_cropped_patch, batch_img2_cropped_patch,
            train=train, trainable=trainable, reuse=None, regularizer=regularizer, is_scale=is_scale)

        occ_fw_patch, occ_bw_patch = occlusion(flow_fw_patch['full_res'], flow_bw_patch['full_res'])
        mask_fw_patch = 1. - occ_fw_patch
        mask_bw_patch = 1. - occ_bw_patch

        losses = self.compute_losses(batch_img1_cropped_patch, batch_img2_cropped_patch,
                                     flow_fw_patch, flow_bw_patch, mask_fw_patch, mask_bw_patch,
                                     train=train, is_scale=is_scale)

        # 1 where the crop prediction is occluded and the stored map is not.
        valid_mask_fw = tf.clip_by_value(occ_fw_patch - occ_fw_cropped_patch, 0., 1.)
        valid_mask_bw = tf.clip_by_value(occ_bw_patch - occ_bw_cropped_patch, 0., 1.)
        data_distillation_loss = {}
        data_distillation_loss['distillation'] = (
            self.abs_robust_loss(flow_fw_cropped_patch - flow_fw_patch['full_res'], valid_mask_fw) +
            self.abs_robust_loss(flow_bw_cropped_patch - flow_bw_patch['full_res'], valid_mask_bw)) / 2
        losses['data_distillation'] = data_distillation_loss

        regularizer_loss = tf.add_n(tf.losses.get_regularization_losses())
        return losses, regularizer_loss

    def build(self, iterator, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True, training_mode='no_distillation'):
        """Dispatch to the tower builder matching `training_mode`.

        Raises:
            ValueError: for an unknown ``training_mode``.
        """
        if training_mode == 'no_distillation':
            builder = self.build_no_data_distillation
        elif training_mode == 'distillation':
            builder = self.build_data_distillation
        else:
            raise ValueError(self._TRAINING_MODE_ERROR)
        losses, regularizer_loss = builder(iterator=iterator, regularizer_scale=regularizer_scale,
                                           train=train, trainable=trainable, is_scale=is_scale)
        return losses, regularizer_loss

    def create_train_op(self, optim, iterator, global_step, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True, training_mode='no_distillation'):
        """Build the tower(s), the training op and the loss summaries.

        The optimised objective is always
        ``losses['abs_robust_mean']['no_occlusion']``, whatever the training
        mode; the other losses and the regularizer are only reported.

        With more than one GPU, one tower is built per device (each pulling
        its own batch from `iterator`), the gradients are averaged and the
        returned losses are the per-tower means.

        Returns:
            ``(train_op, losses, regularizer_loss)``.
        """
        if self.num_gpus <= 1:
            # Also covers num_gpus == 0: a single tower on the shared device.
            losses, regularizer_loss = self.build(iterator, regularizer_scale=regularizer_scale, train=train,
                                                  trainable=trainable, is_scale=is_scale,
                                                  training_mode=training_mode)
            optim_loss = losses['abs_robust_mean']['no_occlusion']
            train_op = optim.minimize(optim_loss, var_list=tf.trainable_variables(), global_step=global_step)
        else:
            tower_grads = []
            tower_losses = []
            tower_regularizer_losses = []

            with tf.variable_scope(tf.get_variable_scope()):
                for i in range(self.num_gpus):
                    with tf.device('/gpu:%d' % i):
                        with tf.name_scope('tower_{}'.format(i)):
                            losses_, regularizer_loss_ = self.build(
                                iterator, regularizer_scale=regularizer_scale, train=train,
                                trainable=trainable, is_scale=is_scale, training_mode=training_mode)
                            optim_loss = losses_['abs_robust_mean']['no_occlusion']

                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()
                            # Use the optimizer passed in rather than
                            # self.optim, which only exists once train() ran.
                            grads = optim.compute_gradients(optim_loss, var_list=tf.trainable_variables())
                            tower_grads.append(grads)
                            tower_losses.append(losses_)
                            tower_regularizer_losses.append(regularizer_loss_)

            grads = average_gradients(tower_grads)
            train_op = optim.apply_gradients(grads, global_step=global_step)

            # Average into a new dict so the per-tower dicts stay untouched.
            losses = {}
            for key, group in tower_losses[0].items():
                losses[key] = {}
                for loss_key in group:
                    total = tf.add_n([tower[key][loss_key] for tower in tower_losses])
                    losses[key][loss_key] = total / self.num_gpus
            regularizer_loss = tf.add_n(tower_regularizer_losses) / self.num_gpus

        self.add_loss_summary(losses, keys=losses.keys())
        tf.summary.scalar('regularizer_loss', regularizer_loss)

        return train_op, losses, regularizer_loss

    def train(self):
        """Build the training graph in a fresh ``tf.Graph`` and run the loop.

        The global step is reset to 0 after the optional restore, so training
        always runs steps ``1 .. iter_steps``.  Progress is printed, summaries
        are written and checkpoints (trainable variables plus global step)
        are saved at the configured intervals.  The session is left open on
        ``self.sess``; the summary writer is closed on exit.
        """
        with tf.Graph().as_default(), tf.device(self.shared_device):
            self.global_step = tf.Variable(0, trainable=False)
            self.dataset, self.iterator = self.create_dataset_and_iterator(training_mode=self.training_mode)

            self.lr_decay = tf.train.exponential_decay(
                self.initial_learning_rate, self.global_step,
                decay_steps=self.decay_steps, decay_rate=self.decay_rate, staircase=True)
            tf.summary.scalar('learning_rate', self.lr_decay)
            self.optim = tf.train.AdamOptimizer(self.lr_decay, self.beta1)
            self.train_op, self.losses, self.regularizer_loss = self.create_train_op(
                optim=self.optim, iterator=self.iterator, global_step=self.global_step,
                regularizer_scale=self.regularizer_scale, train=True, trainable=True,
                is_scale=self.is_scale, training_mode=self.training_mode)

            merge_summary = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(logdir='/'.join([self.summary_dir, 'train', self.model_name]))
            try:
                self.trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
                self.saver = tf.train.Saver(var_list=self.trainable_vars + [self.global_step], max_to_keep=500)

                self.sess = tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=self.allow_soft_placement,
                    log_device_placement=self.log_device_placement))

                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())

                if self.is_restore_model:
                    self.saver.restore(self.sess, self.restore_model)

                # A restored checkpoint carries its own step count; start over.
                self.sess.run(tf.assign(self.global_step, 0))
                start_step = self.sess.run(self.global_step)
                self.sess.run(self.iterator.initializer)
                start_time = time.time()
                for step in range(start_step + 1, self.iter_steps + 1):
                    _, abs_robust_mean_no_occlusion, census_occlusion = self.sess.run(
                        [self.train_op,
                         self.losses['abs_robust_mean']['no_occlusion'],
                         self.losses['census']['occlusion']])
                    if np.mod(step, self.display_log_interval) == 0:
                        print('step: %d time: %.6fs, abs_robust_mean_no_occlusion: %.6f, census_occlusion: %.6f' %
                              (step, time.time() - start_time, abs_robust_mean_no_occlusion, census_occlusion))

                    if np.mod(step, self.write_summary_interval) == 0:
                        # Evaluated in its own session run, separate from the
                        # training step above.
                        summary_str = self.sess.run(merge_summary)
                        summary_writer.add_summary(summary_str, global_step=step)

                    if np.mod(step, self.save_checkpoint_interval) == 0:
                        self.saver.save(self.sess, '/'.join([self.checkpoint_dir, self.model_name, 'model']),
                                        global_step=step, write_meta_graph=False, write_state=False)
            finally:
                # Flush pending events even when training is interrupted.
                summary_writer.close()

    def test(self, restore_model, save_dir):
        """Run forward-flow inference over the dataset and save the results.

        For every sample a colour visualisation
        (``flow_est_color_<name>.png``) and the raw flow
        (``flow_est_<name>.flo``) are written to `save_dir`; ``<name>`` is the
        third column of the dataset's ``data_list``.

        Args:
            restore_model: checkpoint path passed to ``Saver.restore``.
            save_dir: output directory, created if missing.
        """
        dataset = BasicDataset(data_list_file=self.dataset_config['data_list_file'],
                               img_dir=self.dataset_config['img_dir'])
        save_name_list = dataset.data_list[:, 2]
        iterator = dataset.create_one_shot_iterator(dataset.data_list, num_parallel_calls=self.num_input_threads)
        batch_img1, batch_img2 = iterator.get_next()
        flow_est = pyramid_processing(batch_img1, batch_img2, train=False, trainable=False,
                                      regularizer=None, is_scale=True)
        flow_est_color = flow_to_color(flow_est['full_res'], mask=None, max_flow=256)

        restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=restore_vars)
        sess = tf.Session()
        try:
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer)
            saver.restore(sess, restore_model)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            for i in range(dataset.data_num):
                np_flow_est, np_flow_est_color = sess.run([flow_est['full_res'], flow_est_color])
                # NOTE(review): scipy.misc.imsave no longer exists in recent
                # SciPy releases; confirm the pinned SciPy version.
                misc.imsave('%s/flow_est_color_%s.png' % (save_dir, save_name_list[i]), np_flow_est_color[0])
                write_flo('%s/flow_est_%s.flo' % (save_dir, save_name_list[i]), np_flow_est[0])
                print('Finish %d/%d' % (i, dataset.data_num))
        finally:
            sess.close()

    @staticmethod
    def _encode_flow_occ(flow, occ):
        """Pack flow (``flow * 64 + 32768``) and occlusion into one uint16 tensor."""
        packed = tf.concat([flow * 64. + 32768, occ], -1)
        # Saturate instead of relying on an out-of-range float -> uint16 cast.
        packed = tf.clip_by_value(packed, 0., 65535.)
        return tf.cast(packed, tf.uint16)

    def generate_fake_flow_occlusion(self, restore_model, save_dir):
        """Save predicted flow and occlusion maps as 16-bit PNG images.

        For every sample the forward and backward predictions are written to
        ``flow_occ_fw_<name>.png`` and ``flow_occ_bw_<name>.png`` in
        `save_dir`.  Each image concatenates the encoded flow
        (``flow * 64 + 32768``) with the occlusion map along the channel
        axis; channels are reordered with ``rgb_bgr`` before
        ``cv2.imwrite``.

        Args:
            restore_model: checkpoint path passed to ``Saver.restore``.
            save_dir: output directory, created if missing.
        """
        dataset = BasicDataset(data_list_file=self.dataset_config['data_list_file'],
                               img_dir=self.dataset_config['img_dir'])
        save_name_list = dataset.data_list[:, 2]
        iterator = dataset.create_one_shot_iterator(dataset.data_list, num_parallel_calls=self.num_input_threads)
        batch_img1, batch_img2 = iterator.get_next()
        flow_fw, flow_bw = pyramid_processing_bidirection(
            batch_img1, batch_img2,
            train=False, trainable=False, reuse=None, regularizer=None, is_scale=True)
        occ_fw, occ_bw = occlusion(flow_fw['full_res'], flow_bw['full_res'])

        flow_occ_fw = self._encode_flow_occ(flow_fw['full_res'], occ_fw)
        flow_occ_bw = self._encode_flow_occ(flow_bw['full_res'], occ_bw)

        restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=restore_vars)
        sess = tf.Session()
        try:
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer)
            saver.restore(sess, restore_model)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            for i in range(dataset.data_num):
                np_flow_occ_fw, np_flow_occ_bw = sess.run([flow_occ_fw, flow_occ_bw])

                # opencv reads and saves images in bgr order, so convert rgb to bgr
                np_flow_occ_fw = rgb_bgr(np_flow_occ_fw[0])
                np_flow_occ_bw = rgb_bgr(np_flow_occ_bw[0])
                np_flow_occ_fw = np_flow_occ_fw.astype(np.uint16)
                np_flow_occ_bw = np_flow_occ_bw.astype(np.uint16)

                cv2.imwrite('%s/flow_occ_fw_%s.png' % (save_dir, save_name_list[i]), np_flow_occ_fw)
                cv2.imwrite('%s/flow_occ_bw_%s.png' % (save_dir, save_name_list[i]), np_flow_occ_bw)
                print('Finish %d/%d' % (i, dataset.data_num))
        finally:
            sess.close()