forked from harpArk614/3d-pose-warping
-
Notifications
You must be signed in to change notification settings - Fork 1
/
adapted_resnet_utils.py
100 lines (79 loc) · 3.91 KB
/
adapted_resnet_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim
slim = contrib_slim
class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
    """Named tuple describing one block of residual units.

    Fields, as consumed by `stack_blocks_dense` in this module:
      scope: Name passed to `tf.variable_scope` for the whole block.
      unit_fn: Callable invoked as `unit_fn(net, rate=..., **unit)` for every
        unit of the block; its return value becomes the new `net`.
      args: Sequence with one entry per unit. Each entry is a dict of keyword
        arguments for `unit_fn`; its optional 'stride' key (default 1) is
        read to track the accumulated stride.
    """
def subsample(inputs, factor, scope=None):
    """Subsample `inputs` by `factor` using a strided 1x1 max pool.

    Args:
      inputs: Value to subsample; handed to `slim.max_pool2d` unless
        `factor` equals 1.
      factor: Subsampling stride.
      scope: Optional scope forwarded to `slim.max_pool2d`.

    Returns:
      `inputs` itself when `factor == 1`; otherwise the result of
      `slim.max_pool2d` with a [1, 1] kernel and `stride=factor`.
    """
    # Only a real stride needs an op; a factor of 1 adds nothing to the graph.
    if factor != 1:
        return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
    return inputs
def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
    """Apply a strided 2-D convolution with explicit, input-size-independent padding.

    With `stride == 1` this is simply `slim.conv2d` with 'SAME' padding.
    For larger strides the input is zero-padded explicitly according to the
    effective (dilated) kernel size, and the convolution runs with 'VALID'
    padding.

    Args:
      inputs: Rank-4 tensor; dimensions 1 and 2 are the ones that get padded
        (presumably NHWC layout -- confirm against callers).
      num_outputs: Number of output filters, forwarded to `slim.conv2d`.
      kernel_size: Integer kernel size used in the padding arithmetic and
        forwarded to `slim.conv2d`.
      stride: Integer output stride.
      rate: Integer dilation rate for atrous convolution.
      scope: Optional scope forwarded to `slim.conv2d`.

    Returns:
      The output of `slim.conv2d`.
    """
    if stride == 1:
        return slim.conv2d(inputs, num_outputs, kernel_size, stride=1,
                           rate=rate, padding='SAME', scope=scope)

    # Size of the kernel once dilation holes are taken into account.
    effective_kernel = kernel_size + (kernel_size - 1) * (rate - 1)
    total_padding = effective_kernel - 1
    pad_before = total_padding // 2
    pad_after = total_padding - pad_before

    # Pad only dimensions 1 and 2; first and last dimensions stay untouched.
    paddings = [[0, 0],
                [pad_before, pad_after],
                [pad_before, pad_after],
                [0, 0]]
    padded = tf.pad(inputs, paddings)
    return slim.conv2d(padded, num_outputs, kernel_size, stride=stride,
                       rate=rate, padding='VALID', scope=scope)
@slim.add_arg_scope
def stack_blocks_dense(net, blocks, output_stride=None,
                       store_non_strided_activations=False,
                       outputs_collections=None):
    """Stack residual blocks while controlling the overall output stride.

    Units are applied in order. Once the accumulated stride equals
    `output_stride`, further units are called with `stride=1` and the strides
    they would have applied are folded into the atrous `rate` instead.

    Args:
      net: Input passed to the first unit.
      blocks: Iterable of `Block` objects (`scope`, `unit_fn`, `args`).
      output_stride: If None, units use their nominal strides. Otherwise the
        target ratio of input to output resolution that must be reached
        exactly.
      store_non_strided_activations: If True, the stride of each block's last
        unit is removed from that unit and applied via `subsample` after the
        block's activations have been collected.
      outputs_collections: Collection(s) forwarded to
        `slim.utils.collect_named_outputs` for each block's output.

    Returns:
      The output of the last block.

    Raises:
      ValueError: If the accumulated stride overshoots `output_stride`, or has
        not reached it after all blocks are applied.
    """
    unreachable = 'The target output_stride cannot be reached.'

    # Effective stride of the activations produced so far.
    stride_so_far = 1
    # Dilation rate handed to units once the target stride has been reached.
    atrous_rate = 1

    for block in blocks:
        with tf.variable_scope(block.scope, 'block', [net]) as block_scope:
            # Stride postponed from the block's last unit to the block's end.
            deferred_stride = 1
            for index, unit in enumerate(block.args):
                if (store_non_strided_activations
                        and index == len(block.args) - 1):
                    deferred_stride = unit.get('stride', 1)
                    unit = dict(unit, stride=1)

                with tf.variable_scope('unit_%d' % (index + 1), values=[net]):
                    at_target = (output_stride is not None
                                 and stride_so_far == output_stride)
                    if at_target:
                        # Target reached: run unstrided, fold the unit's
                        # stride into the rate for the layers that follow.
                        net = block.unit_fn(net, rate=atrous_rate,
                                            **dict(unit, stride=1))
                        atrous_rate *= unit.get('stride', 1)
                    else:
                        net = block.unit_fn(net, rate=1, **unit)
                        stride_so_far *= unit.get('stride', 1)
                        if (output_stride is not None
                                and stride_so_far > output_stride):
                            raise ValueError(unreachable)

            # Record the block output before any deferred subsampling.
            net = slim.utils.collect_named_outputs(
                outputs_collections, block_scope.name, net)

            # Apply (or fold into the rate) the stride deferred above.
            at_target = (output_stride is not None
                         and stride_so_far == output_stride)
            if at_target:
                atrous_rate *= deferred_stride
            else:
                net = subsample(net, deferred_stride)
                stride_so_far *= deferred_stride
                if (output_stride is not None
                        and stride_so_far > output_stride):
                    raise ValueError(unreachable)

    if output_stride is not None and stride_so_far != output_stride:
        raise ValueError(unreachable)

    return net
def resnet_arg_scope(weight_decay=0.0001,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True,
                     activation_fn=tf.nn.relu,
                     use_batch_norm=True):
    """Build the default arg scope for the ResNet layers in this module.

    NOTE: despite the `batch_norm_*` parameter names, the normalizer
    installed here is `slim.group_norm`.

    Args:
      weight_decay: Weight of the L2 regularizer applied to conv weights.
      batch_norm_epsilon: `epsilon` argument given to the normalizer.
      batch_norm_scale: `scale` argument given to the normalizer.
      activation_fn: Activation function set as the `slim.conv2d` default.
      use_batch_norm: If True, `slim.group_norm` is the conv normalizer;
        if False, convolutions get no normalizer.

    Returns:
      The arg scope captured with defaults for `slim.conv2d`,
      `slim.group_norm` and `slim.max_pool2d` (the latter padded 'SAME').
    """
    norm_params = {
        'epsilon': batch_norm_epsilon,
        'scale': batch_norm_scale
    }
    conv_defaults = dict(
        weights_regularizer=slim.l2_regularizer(weight_decay),
        weights_initializer=slim.variance_scaling_initializer(),
        activation_fn=activation_fn,
        normalizer_fn=slim.group_norm if use_batch_norm else None,
        normalizer_params=norm_params)

    with slim.arg_scope([slim.conv2d], **conv_defaults):
        with slim.arg_scope([slim.group_norm], **norm_params):
            # The innermost scope object carries all three layers of
            # defaults, so it is the one handed back to the caller.
            with slim.arg_scope([slim.max_pool2d],
                                padding='SAME') as combined_scope:
                return combined_scope