forked from ikostrikov/TensorFlow-VAE-GAN-DRAW
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdeconv.py
201 lines (178 loc) · 7.35 KB
/
deconv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Copyright 2015 Google Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A quick hack to try deconv out."""
import collections
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from prettytensor import layers
from prettytensor import pretty_tensor_class as prettytensor
from prettytensor.pretty_tensor_class import PAD_SAME
from prettytensor.pretty_tensor_class import Phase
from prettytensor.pretty_tensor_class import PROVIDED
# pylint: disable=redefined-outer-name,invalid-name
@prettytensor.Register(
assign_defaults=('activation_fn', 'l2loss', 'stddev', 'batch_normalize'))
class deconv2d(prettytensor.VarStoreMethod):
def __call__(self,
input_layer,
kernel,
depth,
name=PROVIDED,
stride=None,
activation_fn=None,
l2loss=None,
init=None,
stddev=None,
bias=True,
edges=PAD_SAME,
batch_normalize=False):
"""Adds a convolution to the stack of operations.
The current head must be a rank 4 Tensor.
Args:
input_layer: The chainable object, supplied.
kernel: The size of the patch for the pool, either an int or a length 1 or
2 sequence (if length 1 or int, it is expanded).
depth: The depth of the new Tensor.
name: The name for this operation is also used to create/find the
parameter variables.
stride: The strides as a length 1, 2 or 4 sequence or an integer. If an
int, length 1 or 2, the stride in the first and last dimensions are 1.
activation_fn: A tuple of (activation_function, extra_parameters). Any
function that takes a tensor as its first argument can be used. More
common functions will have summaries added (e.g. relu).
l2loss: Set to a value greater than 0 to use L2 regularization to decay
the weights.
init: An optional initialization. If not specified, uses Xavier
initialization.
stddev: A standard deviation to use in parameter initialization.
bias: Set to False to not have a bias.
edges: Either SAME to use 0s for the out of bounds area or VALID to shrink
the output size and only uses valid input pixels.
batch_normalize: Set to True to batch_normalize this layer.
Returns:
Handle to the generated layer.
Raises:
ValueError: If head is not a rank 4 tensor or the depth of the input
(4th dim) is not known.
"""
if len(input_layer.shape) != 4:
raise ValueError(
'Cannot perform conv2d on tensor with shape %s' % input_layer.shape)
if input_layer.shape[3] is None:
raise ValueError('Input depth must be known')
kernel = _kernel(kernel)
stride = _stride(stride)
size = [kernel[0], kernel[1], depth, input_layer.shape[3]]
books = input_layer.bookkeeper
if init is None:
if stddev is None:
patch_size = size[0] * size[1]
init = layers.xavier_init(size[2] * patch_size, size[3] * patch_size)
elif stddev:
init = tf.truncated_normal_initializer(stddev=stddev)
else:
init = tf.zeros_initializer
elif stddev is not None:
raise ValueError('Do not set both init and stddev.')
dtype = input_layer.tensor.dtype
params = self.variable('weights', size, init, dt=dtype)
input_height = input_layer.shape[1]
input_width = input_layer.shape[2]
filter_height = kernel[0]
filter_width = kernel[1]
row_stride = stride[1]
col_stride = stride[2]
out_rows, out_cols = get2d_deconv_output_size(input_height, input_width, filter_height,
filter_width, row_stride, col_stride, edges)
output_shape = [input_layer.shape[0], out_rows, out_cols, depth]
y = tf.nn.conv2d_transpose(input_layer, params, output_shape, stride, edges)
layers.add_l2loss(books, params, l2loss)
if bias:
y += self.variable(
'bias',
[size[-2]],
tf.zeros_initializer,
dt=dtype)
books.add_scalar_summary(
tf.reduce_mean(
layers.spatial_slice_zeros(y)), '%s/zeros_spatial' % y.op.name)
if batch_normalize:
y = input_layer.with_tensor(y).batch_normalize()
if activation_fn is not None:
if not isinstance(activation_fn, collections.Sequence):
activation_fn = (activation_fn,)
y = layers.apply_activation(
books,
y,
activation_fn[0],
activation_args=activation_fn[1:])
return input_layer.with_tensor(y)
# pylint: enable=redefined-outer-name,invalid-name
# Helper methods
def get2d_deconv_output_size(input_height, input_width, filter_height,
filter_width, row_stride, col_stride, padding_type):
"""Returns the number of rows and columns in a convolution/pooling output."""
input_height = tensor_shape.as_dimension(input_height)
input_width = tensor_shape.as_dimension(input_width)
filter_height = tensor_shape.as_dimension(filter_height)
filter_width = tensor_shape.as_dimension(filter_width)
row_stride = int(row_stride)
col_stride = int(col_stride)
# Compute number of rows in the output, based on the padding.
if input_height.value is None or filter_height.value is None:
out_rows = None
elif padding_type == "VALID":
out_rows = (input_height.value - 1) * row_stride + filter_height.value
elif padding_type == "SAME":
out_rows = input_height.value * row_stride
else:
raise ValueError("Invalid value for padding: %r" % padding_type)
# Compute number of columns in the output, based on the padding.
if input_width.value is None or filter_width.value is None:
out_cols = None
elif padding_type == "VALID":
out_cols = (input_width.value - 1) * col_stride + filter_width.value
elif padding_type == "SAME":
out_cols = input_width.value * col_stride
return out_rows, out_cols
def _kernel(kernel_spec):
"""Expands the kernel spec into a length 2 list.
Args:
kernel_spec: An integer or a length 1 or 2 sequence that is expanded to a
list.
Returns:
A length 2 list.
"""
if isinstance(kernel_spec, int):
return [kernel_spec, kernel_spec]
elif len(kernel_spec) == 1:
return [kernel_spec[0], kernel_spec[0]]
else:
assert len(kernel_spec) == 2
return kernel_spec
def _stride(stride_spec):
"""Expands the stride spec into a length 4 list.
Args:
stride_spec: None, an integer or a length 1, 2, or 4 sequence.
Returns:
A length 4 list.
"""
if stride_spec is None:
return [1, 1, 1, 1]
elif isinstance(stride_spec, int):
return [1, stride_spec, stride_spec, 1]
elif len(stride_spec) == 1:
return [1, stride_spec[0], stride_spec[0], 1]
elif len(stride_spec) == 2:
return [1, stride_spec[0], stride_spec[1], 1]
else:
assert len(stride_spec) == 4
return stride_spec