From 84a6c4ba5b08f14760ad1ddf2949d70c171f334c Mon Sep 17 00:00:00 2001 From: optima2005 <56945758+optima2005@users.noreply.github.com> Date: Sun, 29 Dec 2019 04:05:14 +0800 Subject: [PATCH] [FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 (#4484) * [FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 * revised per review comments * add more fallback workaround to make all tests pass --- python/tvm/relay/frontend/tensorflow.py | 47 ++++++++++++------- src/relay/op/nn/convolution.cc | 16 ++++++- .../frontend/tensorflow/test_forward.py | 24 ++++++++++ .../python/topi/cuda/conv2d_transpose_nchw.py | 2 + topi/python/topi/nn/util.py | 9 +++- 5 files changed, 77 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index a07f0dd29828..f748fe828bfd 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -269,6 +269,12 @@ def _impl(inputs, attr, params): attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ attr['strides'][3], attr['strides'][1], attr['strides'][2] attr['data_format'] = 'NCHW' + + if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0: + tmp_shape = attr['_output_shapes'][0] + tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] + attr['_output_shapes'][0] = tmp_shape + flip_layout = True inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] @@ -345,12 +351,17 @@ def _impl(inputs, attr, params): elif attr['padding'] == 'SAME': stride_h, stride_w = attr['strides'] kernel_h, kernel_w = attr['kernel_shape'] + + pdata_shape = input_shape + if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0: + pdata_shape = attr['_output_shapes'][0] + if attr['data_format'] == 'NHWC': - in_h = input_shape[1] - in_w = input_shape[2] + in_h = pdata_shape[1] + in_w = pdata_shape[2] else: - in_h = input_shape[2] - in_w = input_shape[3] + in_h = pdata_shape[2] + in_w = pdata_shape[3] 
dilation_h = attr['dilations'][0] dilation_w = attr['dilations'][1] @@ -359,21 +370,23 @@ def _impl(inputs, attr, params): pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) + if opname != 'conv_transpose': + if attr['data_format'] == 'NHWC': + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]), + (0, 0))) + else: + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]))) - if attr['data_format'] == 'NHWC': - inputs_data = _op.nn.pad(data=inputs_data, - pad_width=((0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]), - (0, 0))) + attr['padding'] = [0, 0] else: - inputs_data = _op.nn.pad(data=inputs_data, - pad_width=((0, 0), - (0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]))) - - attr['padding'] = [0, 0] + attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 5f1b194a2b3a..4a1fd466108d 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -249,10 +249,22 @@ bool Conv2DTransposeRel(const Array& types, } // dilation Array oshape({dshape_nchw[0], channels, 0, 0}); + auto pad_h = param->padding[0]; + auto pad_w = param->padding[1]; + if (param->padding.size() == 2) { + pad_h *= 2; + pad_w *= 2; + } else if (param->padding.size() == 4) { + pad_h += param->padding[2]; + pad_w += param->padding[3]; + } else { + CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got " + << param->padding.size(); + } oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - - 2 * param->padding[0] + param->output_padding[0])); + pad_h + param->output_padding[0])); oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - - 2 * param->padding[1] + 
param->output_padding[1])); + pad_w + param->output_padding[1])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 7163eead8435..9b7fe62306fd 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -403,10 +403,22 @@ def test_forward_convolution(): _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME', + 'NCHW', [4, 176, 15, 15]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NCHW', [4, 176, 15, 15]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NCHW', [4, 176, 16, 16]) _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NCHW', [4, 19, 17, 17]) _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NCHW', [4, 124, 17, 17]) + _test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 124, 17, 17]) _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW', [4, 12, 17, 17]) # kernel 2x2, strides (2,2) @@ -429,10 +441,22 @@ def test_forward_convolution(): _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 
1], [1, 1], 'SAME', 'NHWC', [4, 8, 8, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 8, 8, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME', + 'NHWC', [4, 15, 15, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 8, 8, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NHWC', [4, 15, 15, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NHWC', [4, 16, 16, 176]) _test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC', [4, 17, 17, 19]) _test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', [4, 17, 17, 124]) + _test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 17, 17, 124]) _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', [4, 17, 17, 12]) # kernel 2x2, strides (2,2) diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py index a3a4cfe6c87f..274dfb03e794 100644 --- a/topi/python/topi/cuda/conv2d_transpose_nchw.py +++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py @@ -197,6 +197,8 @@ def _callback(op): do_fallback = False elif (kh, kw) == (1, 1): do_fallback = True + elif (stride_h, stride_w) == (2, 2): + do_fallback = False elif (kh, kw) == (stride_h, stride_w): do_fallback = False diff --git a/topi/python/topi/nn/util.py b/topi/python/topi/nn/util.py index 1ee086288f60..847a5c84daaa 100644 --- a/topi/python/topi/nn/util.py +++ b/topi/python/topi/nn/util.py @@ -103,8 +103,13 @@ def get_pad_tuple(padding, kernel): """ # compute the padding size if isinstance(padding, (tuple, list)): - pad_h = padding[0] * 2 - pad_w = padding[1] * 2 + if 
len(padding) == 2: + pad_h = padding[0] * 2 + pad_w = padding[1] * 2 + elif len(padding) == 4: + return padding[0], padding[1], padding[2], padding[3] + else: + raise ValueError("Size of padding can only be 2 or 4") elif isinstance(padding, int): pad_h = pad_w = padding * 2 elif padding == "VALID":