diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 7c1d34f3fd2c..b5bdbb9e8929 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -205,6 +205,12 @@ def _impl(inputs, attr, params): attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ attr['strides'][3], attr['strides'][1], attr['strides'][2] attr['data_format'] = 'NCHW' + + if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0: + tmp_shape = attr['_output_shapes'][0] + tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] + attr['_output_shapes'][0] = tmp_shape + flip_layout = True inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] @@ -281,12 +287,17 @@ def _impl(inputs, attr, params): elif attr['padding'] == 'SAME': stride_h, stride_w = attr['strides'] kernel_h, kernel_w = attr['kernel_shape'] + + pdata_shape = input_shape + if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0: + pdata_shape = attr['_output_shapes'][0] + if attr['data_format'] == 'NHWC': - in_h = input_shape[1] - in_w = input_shape[2] + in_h = pdata_shape[1] + in_w = pdata_shape[2] else: - in_h = input_shape[2] - in_w = input_shape[3] + in_h = pdata_shape[2] + in_w = pdata_shape[3] dilation_h = attr['dilations'][0] dilation_w = attr['dilations'][1] @@ -295,21 +306,23 @@ def _impl(inputs, attr, params): pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w) + if opname != 'conv_transpose': + if attr['data_format'] == 'NHWC': + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]), + (0, 0))) + else: + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]))) - if attr['data_format'] == 'NHWC': - inputs_data = _op.nn.pad(data=inputs_data, - pad_width=((0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]), - (0, 0))) + attr['padding'] = [0, 0] else: - inputs_data = _op.nn.pad(data=inputs_data, - pad_width=((0, 0), - (0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]))) - - attr['padding'] = [0, 0] + attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]] else: msg = 'Value {} in attribute "padding" of operator Conv is not ' \ diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 3c9bebc1b0d0..234d575184e3 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -248,10 +248,22 @@ bool Conv2DTransposeRel(const Array& types, } // dilation Array oshape({dshape_nchw[0], channels, 0, 0}); + auto pad_h = param->padding[0]; + auto pad_w = param->padding[1]; + if (param->padding.size() == 2) { + pad_h *= 2; + pad_w *= 2; + } else if (param->padding.size() == 4) { + pad_h += param->padding[2]; + pad_w += param->padding[3]; + } else { + CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got " + << param->padding.size(); + } oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - - 2 * param->padding[0] + param->output_padding[0])); + pad_h + param->output_padding[0])); oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - - 2 * param->padding[1] + param->output_padding[1])); + pad_w + param->output_padding[1])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index cb4a9263b8f1..bcdc617f3f75 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -362,10 +362,22 @@ def test_forward_convolution(): _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME', + 'NCHW', [4, 176, 15, 15]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NCHW', [4, 176, 15, 15]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NCHW', [4, 176, 16, 16]) _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NCHW', [4, 19, 17, 17]) _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NCHW', [4, 124, 17, 17]) + _test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 124, 17, 17]) _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW', [4, 12, 17, 17]) # kernel 2x2, strides (2,2) @@ -388,10 +400,22 @@ def test_forward_convolution(): _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC', [4, 8, 8, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 8, 8, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME', + 'NHWC', [4, 15, 15, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 8, 8, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NHWC', [4, 15, 15, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME', + 'NHWC', [4, 16, 16, 176]) _test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC', [4, 17, 17, 19]) _test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', [4, 17, 17, 124]) + _test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 17, 17, 124]) _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', [4, 17, 17, 12]) # kernel 2x2, strides (2,2) diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py index a3a4cfe6c87f..274dfb03e794 100644 --- a/topi/python/topi/cuda/conv2d_transpose_nchw.py +++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py @@ -197,6 +197,8 @@ def _callback(op): do_fallback = False elif (kh, kw) == (1, 1): do_fallback = True + elif (stride_h, stride_w) == (2, 2): + do_fallback = False elif (kh, kw) == (stride_h, stride_w): do_fallback = False diff --git a/topi/python/topi/nn/util.py b/topi/python/topi/nn/util.py index 463edaa463dc..2281f49e4037 100644 --- a/topi/python/topi/nn/util.py +++ b/topi/python/topi/nn/util.py @@ -103,8 +103,13 @@ def get_pad_tuple(padding, kernel): """ # compute the padding size if isinstance(padding, (tuple, list)): - pad_h = padding[0] * 2 - pad_w = padding[1] * 2 + if len(padding) == 2: + pad_h = padding[0] * 2 + pad_w = padding[1] * 2 + elif len(padding) == 4: + return padding[0], padding[1], padding[2], padding[3] + else: + raise ValueError("Size of padding can only be 2 or 4") elif isinstance(padding, int): pad_h = pad_w = padding * 2 elif padding == "VALID":