diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 4ec8abdfb336..cb4a9263b8f1 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -368,6 +368,14 @@ def test_forward_convolution(): 'NCHW', [4, 124, 17, 17]) _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NCHW', [4, 12, 17, 17]) + # kernel 2x2, strides (2,2) + _test_convolution('conv_transpose', [4, 19, 8, 8], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID', + 'NCHW', [4, 19, 16, 16]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID', + 'NCHW', [4, 12, 16, 16]) + # output channel is 1 + _test_convolution('conv_transpose', [1, 19, 8, 8], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID', + 'NCHW', [1, 1, 8, 8]) _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') @@ -386,6 +394,14 @@ def test_forward_convolution(): 'NHWC', [4, 17, 17, 124]) _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', [4, 17, 17, 12]) + # kernel 2x2, strides (2,2) + _test_convolution('conv_transpose', [4, 8, 8, 19], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID', + 'NHWC', [4, 16, 16, 19]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID', + 'NHWC', [4, 16, 16, 12]) + # output channel is 1 + _test_convolution('conv_transpose', [1, 8, 8, 19], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID', + 'NHWC', [1, 8, 8, 1]) ####################################################################### diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py index 320187886322..a3a4cfe6c87f 100644 --- a/topi/python/topi/cuda/conv2d_transpose_nchw.py +++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py @@ -185,8 +185,24 @@ def _callback(op): cfg.define_knob("unroll_explicit", [0, 1]) if cfg.is_fallback: - N, F, Y, X = get_const_tuple(conv.shape) - _fallback_schedule(N, F, Y, X) + ko = int(kernel.shape[1]) + kh = int(kernel.shape[2]) + kw = int(kernel.shape[3]) + stride_h, stride_w = cfg.stride + # Workaround to make CUDA compilation work. Issue #4470 + # TODO make _fallback_schedule work for all kernel/strides combinations + # after issue #4470 is resolved + do_fallback = True + if ko == 1: + do_fallback = False + elif (kh, kw) == (1, 1): + do_fallback = True + elif (kh, kw) == (stride_h, stride_w): + do_fallback = False + + if do_fallback: + N, F, Y, X = get_const_tuple(conv.shape) + _fallback_schedule(N, F, Y, X) ##### space definition end #####