Skip to content

Commit

Permalink
[Relay][Frontend][Tensorflow]Add conv2d_transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
optima2005 committed Nov 11, 2019
1 parent d2fc025 commit 08e3fef
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
52 changes: 33 additions & 19 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,30 @@ def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False

if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
raise NotImplementedError( \
"conv2d_transpose with NHWC layout is not implemented.")

inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]

# NCHW Layout require weights transpose
if attr['data_format'] == 'NCHW':
tmp_shape = attr['_input_shapes'][inputs[1]]
if opname == 'conv':
if opname in ['conv', 'conv_transpose']:
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else:
tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
attr['_input_shapes'][inputs[1]] = tmp_shape

input_shape = attr['_input_shapes'][inputs[0]]
input_shape = attr['_input_shapes'][inputs_data]
weights_shape = attr['_input_shapes'][inputs[1]]

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
if opname == 'conv':
inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2))
if opname in ['conv', 'conv_transpose']:
weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else:
Expand All @@ -221,6 +227,8 @@ def _impl(inputs, attr, params):
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv':
attr['channels'] = weights_shape[3]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[2]
else:
attr['channels'] = input_shape[3] * depth_mult

Expand All @@ -232,6 +240,8 @@ def _impl(inputs, attr, params):
attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv':
attr['channels'] = weights_shape[0]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[1]
else:
attr['channels'] = input_shape[1] * depth_mult
if attr['channels'] < 0:
Expand Down Expand Up @@ -272,17 +282,17 @@ def _impl(inputs, attr, params):


if attr['data_format'] == 'NHWC':
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))

attr['padding'] = [0, 0]

Expand All @@ -292,25 +302,28 @@ def _impl(inputs, attr, params):
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

if 'kernel_layout' not in attr:
if opname == 'conv':
if opname in ['conv', 'conv_transpose']:
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

use_bias = len(inputs) == 3
use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
channel_axis = 1 if attr['data_format'] == "NCHW" else 3

out = AttrCvt(
op_name=_dimension_picker('conv'),
op_name=_dimension_picker('conv', \
surfix="_transpose" if opname == 'conv_transpose' else ""),
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'data_layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr)
custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)

if use_bias:
out = _op.nn.bias_add(out, inputs[2], axis=channel_axis)
out = _op.nn.bias_add(out,
inputs[2] if opname != 'conv_transpose' else inputs[3],
axis=channel_axis)

if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 1))
Expand Down Expand Up @@ -1385,6 +1398,7 @@ def _impl(inputs, attr, params):
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'),
'Conv2DBackpropInput' : _conv('conv_transpose'),
'CropAndResize' : _crop_and_resize(),
'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'),
Expand Down
22 changes: 21 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def test_forward_pooling():


def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format):
dilations, strides, padding, data_format,
deconv_output_shape=[]):
""" One iteration of convolution with given shapes and attributes """

total_size_1 = np.prod(tensor_in_sizes)
Expand Down Expand Up @@ -326,6 +327,17 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,

compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'Conv2D:0')
elif opname == 'conv_transpose':
nn_ops.conv2d_transpose(in_data,
in_filter,
output_shape=deconv_output_shape,
strides=strides,
dilations=dilations,
padding=padding,
data_format=data_format)

compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'conv2d_transpose:0')
else:
nn_ops.depthwise_conv2d_native(in_data,
in_filter,
Expand All @@ -349,6 +361,14 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 17, 17])

_test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
Expand Down

0 comments on commit 08e3fef

Please sign in to comment.