Skip to content

Commit

Permalink
[Relay][Frontend][Tensorflow]Add conv2d_transpose (apache#4300)
Browse files Browse the repository at this point in the history
* [Relay][Frontend][Tensorflow]Add conv2d_transpose

* add transformation from NHWC to NCHW to compatible with TVM conv2d_transpose implementation

* remove 'dilations' paramater to compitable with TF1.3
  • Loading branch information
optima2005 authored and yongwww committed Nov 26, 2019
1 parent 133b68b commit a1b5829
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 20 deletions.
60 changes: 41 additions & 19 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,38 @@ def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False

if opname == 'conv_transpose' and attr['data_format'] == 'NHWC':
# transform to NCHW for TVM backend compatible and set 'flip_layout'
# to have output flip back to NHWC
tmp_shape = attr['_input_shapes'][inputs[2]]
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2))
attr['_input_shapes'][inputs[2]] = tmp_shape
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2]
attr['data_format'] = 'NCHW'
flip_layout = True

inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]

# NCHW Layout require weights transpose
if attr['data_format'] == 'NCHW':
tmp_shape = attr['_input_shapes'][inputs[1]]
if opname == 'conv':
if opname in ['conv', 'conv_transpose']:
tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else:
tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1))
attr['_input_shapes'][inputs[1]] = tmp_shape

input_shape = attr['_input_shapes'][inputs[0]]
input_shape = attr['_input_shapes'][inputs_data]
weights_shape = attr['_input_shapes'][inputs[1]]

if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
if opname == 'conv':
inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2))
if opname in ['conv', 'conv_transpose']:
weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1))
else:
Expand All @@ -228,6 +242,8 @@ def _impl(inputs, attr, params):
attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
if opname == 'conv':
attr['channels'] = weights_shape[3]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[2]
else:
attr['channels'] = input_shape[3] * depth_mult

Expand All @@ -239,6 +255,8 @@ def _impl(inputs, attr, params):
attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
if opname == 'conv':
attr['channels'] = weights_shape[0]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[1]
else:
attr['channels'] = input_shape[1] * depth_mult
if attr['channels'] < 0:
Expand Down Expand Up @@ -279,17 +297,17 @@ def _impl(inputs, attr, params):


if attr['data_format'] == 'NHWC':
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))

attr['padding'] = [0, 0]

Expand All @@ -299,27 +317,30 @@ def _impl(inputs, attr, params):
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

if 'kernel_layout' not in attr:
if opname == 'conv':
if opname in ['conv', 'conv_transpose']:
attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
else:
attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

use_bias = len(inputs) == 3
use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
channel_axis = 1 if attr['data_format'] == "NCHW" else 3

# Ignore the new attributes from TF2.0, for now.
out = AttrCvt(
op_name=_dimension_picker('conv'),
op_name=_dimension_picker('conv', \
surfix="_transpose" if opname == 'conv_transpose' else ""),
ignores=['explicit_paddings'],
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'data_layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr)
custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)

if use_bias:
out = _op.nn.bias_add(out, inputs[2], axis=channel_axis)
out = _op.nn.bias_add(out,
inputs[2] if opname != 'conv_transpose' else inputs[3],
axis=channel_axis)

if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 1))
Expand Down Expand Up @@ -1403,6 +1424,7 @@ def _impl(inputs, attr, params):
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'),
'Conv2DBackpropInput' : _conv('conv_transpose'),
'CropAndResize' : _crop_and_resize(),
'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'),
Expand Down
30 changes: 29 additions & 1 deletion tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def test_forward_pooling():


def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format):
dilations, strides, padding, data_format,
deconv_output_shape=[]):
""" One iteration of convolution with given shapes and attributes """

total_size_1 = np.prod(tensor_in_sizes)
Expand Down Expand Up @@ -326,6 +327,16 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes,

compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'Conv2D:0')
elif opname == 'conv_transpose':
nn_ops.conv2d_transpose(in_data,
in_filter,
output_shape=deconv_output_shape,
strides=strides,
padding=padding,
data_format=data_format)

compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'conv2d_transpose:0')
else:
nn_ops.depthwise_conv2d_native(in_data,
in_filter,
Expand All @@ -349,6 +360,14 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 17, 17])

_test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
Expand All @@ -359,6 +378,15 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 19])
_test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 12])


#######################################################################
# BiasAdd
Expand Down

0 comments on commit a1b5829

Please sign in to comment.