-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathdeformable_conv_layer.py
246 lines (220 loc) · 10.2 KB
/
deformable_conv_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
############
# AUTHOR: An Jiaoyang
# DATE: 2018-10-11
############
"""Deformable Convolutional Layer
"""
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
class DeformableConvLayer(Conv2D):
    """Deformable 2-D convolution layer (only the "channels_last" data format is supported).

    An auxiliary convolution predicts a (dy, dx) offset for every kernel tap of every
    deformable group. The input is sampled at the regular convolution grid shifted by
    those offsets (bilinear interpolation, coordinates clamped to the feature map), and
    the layer's kernel is then applied to the sampled values.

    The spatial size and channel count of the input must be known statically; the batch
    dimension may be dynamic (``None``).
    """

    def __init__(self,
                 filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 num_deformable_group=None,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        """Create the layer; all arguments except ``num_deformable_group`` go to ``Conv2D``.

        The original author notes that `kernel_size`, `strides` and `dilation_rate` must
        have the same value in both axes.

        :param num_deformable_group: split output channels into groups; all output channels
            in a group share the same sampling offsets. If None, ``filters`` is used (one
            offset set per output channel). Must be a positive divisor of ``filters``.
        :raises ValueError: if the resolved data format is not "channels_last", or if
            ``num_deformable_group`` is not a positive divisor of ``filters``.
        """
        super().__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)
        # The sampling code below indexes axes as NHWC; fail early instead of computing
        # garbage when Conv2D resolved a channels_first format.
        if self.data_format != 'channels_last':
            raise ValueError('DeformableConvLayer only supports the "channels_last" data format')
        self.kernel = None
        self.bias = None
        self.offset_layer_kernel = None
        self.offset_layer_bias = None
        if num_deformable_group is None:
            num_deformable_group = filters
        if num_deformable_group <= 0:
            raise ValueError('"num_deformable_group" must be a positive integer')
        if filters % num_deformable_group != 0:
            raise ValueError('"filters" mod "num_deformable_group" must be zero')
        self.num_deformable_group = num_deformable_group

    def build(self, input_shape):
        """Create the depth-wise kernel, the bias and the offset-predicting conv weights."""
        input_dim = int(input_shape[-1])
        # Depth-wise kernel: one [kh, kw] filter per (output filter, input channel) pair.
        # Summing over input channels in `call` makes it a regular convolution kernel.
        kernel_shape = self.kernel_size + (self.filters * input_dim, 1)
        self.kernel = self.add_weight(
            name='kernel',
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            trainable=True,
            dtype=self.dtype)
        if self.use_bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(self.filters,),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=True,
                dtype=self.dtype)
        # Offset conv: predicts (dy, dx) per kernel tap per deformable group. Zero
        # initialisation means every tap starts at its regular, undeformed position.
        offset_num = self.kernel_size[0] * self.kernel_size[1] * self.num_deformable_group
        self.offset_layer_kernel = self.add_weight(
            name='offset_layer_kernel',
            shape=self.kernel_size + (input_dim, offset_num * 2),  # 2 means y and x axis
            initializer=tf.zeros_initializer(),
            regularizer=self.kernel_regularizer,
            trainable=True,
            dtype=self.dtype)
        self.offset_layer_bias = self.add_weight(
            name='offset_layer_bias',
            shape=(offset_num * 2,),
            initializer=tf.zeros_initializer(),
            regularizer=self.bias_regularizer,
            trainable=True,
            dtype=self.dtype)
        self.built = True

    def call(self, inputs, training=None, **kwargs):
        """Apply the deformable convolution.

        :param inputs: NHWC tensor; height, width and channels must be static.
        :return: tensor of shape [batch, out_h, out_w, filters].
        """
        filter_h, filter_w = self.kernel_size
        num_groups = self.num_deformable_group

        # Offsets from a regular conv over the *unpadded* input,
        # shape [batch, out_h, out_w, filter_h * filter_w * num_groups * 2].
        offset = tf.nn.conv2d(inputs,
                              self.offset_layer_kernel,
                              strides=[1, *self.strides, 1],
                              padding=self.padding.upper(),
                              dilations=[1, *self.dilation_rate, 1])
        offset += self.offset_layer_bias

        # Sampling below uses 'VALID' windows, so apply 'same' padding explicitly.
        inputs = self._pad_input(inputs)
        channel_in = int(inputs.get_shape()[-1])
        in_h, in_w = [int(i) for i in inputs.get_shape()[1:3]]  # padded input size
        out_h, out_w = [int(i) for i in offset.get_shape()[1:3]]  # output size
        num_taps = filter_h * filter_w * num_groups

        offset = tf.reshape(offset, [-1, out_h, out_w, num_taps, 2])
        y_off, x_off = offset[..., 0], offset[..., 1]

        # Regular sampling grid, repeated once per deformable group, with a leading axis
        # of 1 so it broadcasts against the (possibly dynamic) batch of offsets.
        y, x = self._get_conv_indices([in_h, in_w])
        y, x = [tf.tile(tf.expand_dims(i, axis=-1), [1, 1, 1, 1, num_groups]) for i in (y, x)]
        y, x = [tf.cast(tf.reshape(i, [1, out_h, out_w, num_taps]), offset.dtype) for i in (y, x)]

        # Deformed samples: [batch, out_h, out_w, num_taps, channel_in].
        pixels = DeformableConvLayer._bilinear_sample(inputs, y + y_off, x + x_off, in_h, in_w)

        # Lay samples out as a "big" map in which each output position owns its own
        # filter_h x filter_w tile, so a stride-(filter_h, filter_w) 'VALID' depth-wise
        # conv visits each tile exactly once.
        pixels = tf.reshape(pixels, [-1, out_h, out_w, filter_h, filter_w, num_groups, channel_in])
        pixels = tf.transpose(pixels, [0, 1, 3, 2, 4, 5, 6])
        pixels = tf.reshape(pixels, [-1, out_h * filter_h, out_w * filter_w, num_groups, channel_in])

        # Every output filter of a group reuses that group's samples.
        feat_in_group = self.filters // num_groups
        pixels = tf.tile(pixels, [1, 1, 1, 1, feat_in_group])
        pixels = tf.reshape(pixels,
                            [-1, out_h * filter_h, out_w * filter_w, self.filters * channel_in])

        out = tf.nn.depthwise_conv2d(pixels, self.kernel, [1, filter_h, filter_w, 1], 'VALID')
        # Sum each filter's per-input-channel responses.
        out = tf.reshape(out, [-1, out_h, out_w, self.filters, channel_in])
        out = tf.reduce_sum(out, axis=-1)
        if self.use_bias:
            out += self.bias
        return self.activation(out)

    def get_config(self):
        """Return the layer config including ``num_deformable_group`` for re-creation."""
        config = super().get_config()
        config['num_deformable_group'] = self.num_deformable_group
        return config

    def _pad_input(self, inputs):
        """Pad the feature map for padding='same', because we don't use the standard Conv() op.

        With 'same' the output size must be ``ceil(input / stride)``. Padding follows
        TensorFlow's "SAME" rule, ``total = max((out - 1) * stride + dilated_filter - in, 0)``
        with the smaller half on top/left. The sampling windows then line up with the
        windows of the offset convolution in `call`, which uses TF's own "SAME" padding.

        :param inputs: NHWC tensor with static height and width.
        :return: padded feature map (unchanged for 'valid' or when no padding is needed).
        """
        if self.padding != 'same':
            return inputs
        in_shape = inputs.get_shape().as_list()[1:3]
        pads = []
        for size, filter_size, stride, dilation in zip(
                in_shape, self.kernel_size, self.strides, self.dilation_rate):
            dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
            out_size = (size + stride - 1) // stride
            total = max((out_size - 1) * stride + dilated_filter_size - size, 0)
            pads.append([total // 2, total - total // 2])  # [before, after]
        if not any(before or after for before, after in pads):
            return inputs
        return tf.pad(inputs, [[0, 0], pads[0], pads[1], [0, 0]])

    def _get_conv_indices(self, feature_map_size):
        """Return the (y, x) coordinates each kernel tap visits while sliding ('VALID').

        :param feature_map_size: (height, width) of the already padded feature map.
        :return: int32 tensors y, x with shape [1, out_h, out_w, filter_h * filter_w].
        """
        feat_h, feat_w = [int(i) for i in feature_map_size[0:2]]
        x, y = tf.meshgrid(tf.range(feat_w), tf.range(feat_h))
        x, y = [tf.reshape(i, [1, feat_h, feat_w, 1]) for i in (x, y)]  # shape [1, h, w, 1]
        # `tf.image.extract_patches` (newer TF) and `tf.image.extract_image_patches`
        # (older TF 1.x) take the same positional arguments.
        extract = getattr(tf.image, 'extract_patches', None) or tf.image.extract_image_patches
        x, y = [extract(i,
                        [1, *self.kernel_size, 1],
                        [1, *self.strides, 1],
                        [1, *self.dilation_rate, 1],
                        'VALID')
                for i in (x, y)]
        return y, x

    @staticmethod
    def _bilinear_sample(inputs, y, x, in_h, in_w):
        """Bilinearly sample `inputs` at fractional (y, x), clamping coordinates to the map.

        :param inputs: [batch, in_h, in_w, channels] feature map.
        :param y: row coordinates, shape [batch, out_h, out_w, num_taps].
        :param x: column coordinates, same shape as `y`.
        :param in_h: static height of `inputs`.
        :param in_w: static width of `inputs`.
        :return: sampled values, shape [batch, out_h, out_w, num_taps, channels].
        """
        y = tf.clip_by_value(y, 0.0, float(in_h - 1))
        x = tf.clip_by_value(x, 0.0, float(in_w - 1))
        y0 = tf.floor(y)
        x0 = tf.floor(x)
        # Weights use the *unclamped* neighbour (y0 + 1, x0 + 1), so a coordinate exactly
        # on the last row/column gets full weight there. Clamping the neighbour before
        # computing the weights would make all four weights zero on that border.
        dy = y - y0
        dx = x - x0
        y0_idx = tf.cast(y0, tf.int32)
        x0_idx = tf.cast(x0, tf.int32)
        # Only the gather indices are clamped; their weight is 0 whenever clamping applies.
        y1_idx = tf.minimum(y0_idx + 1, in_h - 1)
        x1_idx = tf.minimum(x0_idx + 1, in_w - 1)
        gather = DeformableConvLayer._get_pixel_values_at_point
        p00 = gather(inputs, [y0_idx, x0_idx])
        p01 = gather(inputs, [y0_idx, x1_idx])
        p10 = gather(inputs, [y1_idx, x0_idx])
        p11 = gather(inputs, [y1_idx, x1_idx])
        # Trailing size-1 axis so the weights broadcast over channels.
        dy = tf.expand_dims(dy, axis=-1)
        dx = tf.expand_dims(dx, axis=-1)
        return tf.add_n([(1.0 - dy) * (1.0 - dx) * p00,
                         (1.0 - dy) * dx * p01,
                         dy * (1.0 - dx) * p10,
                         dy * dx * p11])

    @staticmethod
    def _get_pixel_values_at_point(inputs, indices):
        """Gather pixel values at integer coordinates, one lookup per batch element.

        :param inputs: [batch, H_in, W_in, channels] feature map.
        :param indices: pair (y, x) of int32 tensors of shape [batch, H, W, I].
        :return: tensor of shape [batch, H, W, I, channels].
        """
        y, x = indices
        # Dynamic shape so an unknown (None) batch size works.
        shape = tf.shape(y)
        batch, h, w, n = shape[0], shape[1], shape[2], shape[3]
        batch_idx = tf.reshape(tf.range(0, batch), [-1, 1, 1, 1])
        b = tf.tile(batch_idx, [1, h, w, n])
        pixel_idx = tf.stack([b, y, x], axis=-1)
        return tf.gather_nd(inputs, pixel_idx)