Workaround to make conv2d_transpose compilation for CUDA work #4472

Merged
merged 1 commit into from
Dec 8, 2019
16 changes: 16 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -368,6 +368,14 @@ def test_forward_convolution():
                       'NCHW', [4, 124, 17, 17])
     _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
                       'NCHW', [4, 12, 17, 17])
+    # kernel 2x2, strides (2,2)
+    _test_convolution('conv_transpose', [4, 19, 8, 8], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID',
+                      'NCHW', [4, 19, 16, 16])
+    _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID',
+                      'NCHW', [4, 12, 16, 16])
+    # output channel is 1
+    _test_convolution('conv_transpose', [1, 19, 8, 8], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID',
+                      'NCHW', [1, 1, 8, 8])

     _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
     _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
@@ -386,6 +394,14 @@ def test_forward_convolution():
                       'NHWC', [4, 17, 17, 124])
     _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
                       'NHWC', [4, 17, 17, 12])
+    # kernel 2x2, strides (2,2)
+    _test_convolution('conv_transpose', [4, 8, 8, 19], [2, 2, 19, 19], [1, 1], [2, 2], 'VALID',
+                      'NHWC', [4, 16, 16, 19])
+    _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 12, 32], [1, 1], [2, 2], 'VALID',
+                      'NHWC', [4, 16, 16, 12])
+    # output channel is 1
+    _test_convolution('conv_transpose', [1, 8, 8, 19], [1, 1, 1, 19], [1, 1], [1, 1], 'VALID',
+                      'NHWC', [1, 8, 8, 1])


 #######################################################################
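As a sanity check on the expected output shapes in the new test cases, the spatial size of a `VALID` transposed convolution can be worked out by hand. The sketch below assumes the usual relation `out = (in - 1) * stride + kernel`, which gives the smallest consistent output size when `kernel >= stride`. The helper name `valid_transpose_out_size` is hypothetical and is not part of TVM or of the test file.

```python
def valid_transpose_out_size(in_size, kernel, stride):
    """Smallest output length of a VALID transposed convolution.

    Assumption: kernel >= stride, as in every test case of this PR.
    """
    return (in_size - 1) * stride + kernel


# Input spatial size is 8 in all of the conv_transpose tests.
print(valid_transpose_out_size(8, 3, 2))  # pre-existing 3x3 kernel, stride 2 case
print(valid_transpose_out_size(8, 2, 2))  # new 2x2 kernel, stride 2 case
print(valid_transpose_out_size(8, 1, 1))  # new 1x1 kernel, stride 1 case
```

The three calls give 17, 16 and 8, which match the spatial dimensions passed as the expected output shape in the corresponding `_test_convolution` calls.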
20 changes: 18 additions & 2 deletions topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -185,8 +185,24 @@ def _callback(op):
             cfg.define_knob("unroll_explicit", [0, 1])

             if cfg.is_fallback:
-                N, F, Y, X = get_const_tuple(conv.shape)
-                _fallback_schedule(N, F, Y, X)
+                ko = int(kernel.shape[1])
+                kh = int(kernel.shape[2])
+                kw = int(kernel.shape[3])
+                stride_h, stride_w = cfg.stride
+                # Workaround to make CUDA compilation work. Issue #4470

@vinx13 (Member), Dec 6, 2019

Can we still use the fallback for the other cases by checking the input params here?

@apivovarov (Contributor, Author), Dec 6, 2019

I checked more kernel and stride combinations and found that the error occurs when the kernel size is equal to the strides, e.g.

# kernel and strides for which compilation for CUDA fails
2x2 and (2,2)
3x3 and (3,3)
4x4 and (4,4)
5x5 and (5,5)
2x3 and (2,3)
3x2 and (3,2)
1x2 and (1,2)
etc.

I also found that compilation fails if the output channel is 1.

@apivovarov (Contributor, Author), Dec 6, 2019

Added a kernel/strides check, and _fallback_schedule is now skipped when the output channel is 1.
In all other cases it runs _fallback_schedule for a 1x1 kernel or when kernel != strides.

+                # TODO make _fallback_schedule work for all kernel/strides combinations
+                # after issue #4470 is resolved
+                do_fallback = True
+                if ko == 1:
+                    do_fallback = False
+                elif (kh, kw) == (1, 1):
+                    do_fallback = True
+                elif (kh, kw) == (stride_h, stride_w):
+                    do_fallback = False
+
+                if do_fallback:
+                    N, F, Y, X = get_const_tuple(conv.shape)
+                    _fallback_schedule(N, F, Y, X)

             ##### space definition end #####

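The decision made in the hunk above can be tried out in isolation. The following is a minimal sketch that mirrors the `do_fallback` branches as a pure function; the name `should_run_fallback` is hypothetical and does not exist in TOPI, and the arguments stand in for `ko`, `kh`, `kw`, `stride_h` and `stride_w` from the diff.

```python
def should_run_fallback(ko, kh, kw, stride_h, stride_w):
    """Return True when _fallback_schedule would be called (hypothetical helper).

    Mirrors the branch order of the workaround: the output-channel check
    comes first, then the 1x1 kernel case, then the kernel == strides case.
    """
    if ko == 1:
        return False  # output channel is 1: skip the fallback
    if (kh, kw) == (1, 1):
        return True  # 1x1 kernel: fallback still runs
    if (kh, kw) == (stride_h, stride_w):
        return False  # kernel equal to strides: skip the fallback
    return True


# Cases taken from the tests and the review discussion.
print(should_run_fallback(19, 2, 2, 2, 2))  # 2x2 kernel, strides (2,2)
print(should_run_fallback(12, 3, 3, 2, 2))  # 3x3 kernel, strides (2,2)
print(should_run_fallback(1, 1, 1, 1, 1))   # output channel is 1
```

Because the `ko == 1` check comes first, a 1x1 kernel with a single output channel skips the fallback even though the 1x1 branch on its own would keep it.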