diff --git a/python/vta/top/vta_conv2d_transpose.py b/python/vta/top/vta_conv2d_transpose.py
index a2750dc9081d..ff10ff015348 100644
--- a/python/vta/top/vta_conv2d_transpose.py
+++ b/python/vta/top/vta_conv2d_transpose.py
@@ -27,24 +27,28 @@ from ..environment import get_env
 
 
 @autotvm.register_topi_compute(topi.nn.conv2d_transpose_nchw, 'vta', 'direct')
-def _declatation_conv2d_transpose(cfg,
+def _declaration_conv2d_transpose(cfg,
                                   data,
                                   kernel,
                                   strides,
                                   padding,
-                                  out_dtype):
+                                  out_dtype,
+                                  output_padding=(0, 0)):
     ishape = get_const_tuple(data.shape)
     kshape = get_const_tuple(kernel.shape)
     b, c_i, i_h, i_w, t_b, t_ci = ishape
     c_o, _, k_h, k_w, t_co, t_ci = kshape
     stride_h, stride_w = strides
+    opad_h, opad_w = output_padding
+    # FIXME(tmoreau89): currently IR pass breaks when output padding != (0,0)
+    assert opad_h == 0 and opad_w == 0, "VTA does not support output padding for now"
 
     # derive padding parameters
     fpad_top, fpad_left, fpad_bottom, fpad_right = get_pad_tuple(padding, (k_h, k_w))
     bpad_top = k_h - 1 - fpad_top
-    bpad_bottom = k_h - 1 - fpad_bottom
+    bpad_bottom = k_h - 1 - fpad_bottom + opad_h
     bpad_left = k_w - 1 - fpad_left
-    bpad_right = k_w - 1 - fpad_right
+    bpad_right = k_w - 1 - fpad_right + opad_w
 
     # padding stage
     dilated_input = topi.nn.dilate(data, [1, 1, stride_h, stride_w, 1, 1])
@@ -53,8 +57,8 @@ def _declatation_conv2d_transpose(cfg,
                            [0, 0, bpad_bottom, bpad_right, 0, 0])
 
     # convolution transpose stage
-    out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h
-    out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w
+    out_h = (i_h - 1) * stride_h - fpad_top - fpad_bottom + k_h + opad_h
+    out_w = (i_w - 1) * stride_w - fpad_left - fpad_right + k_w + opad_w
     oshape = (b, c_o, out_h, out_w, t_b, t_co)
     d_c = tvm.reduce_axis((0, c_i), name='d_c')
     d_h = tvm.reduce_axis((0, k_h), name='d_h')
diff --git a/scripts/tune_conv2d_transpose.py b/scripts/tune_conv2d_transpose.py
index 3e51d410638b..fa9900a121c4 100644
--- a/scripts/tune_conv2d_transpose.py
+++ b/scripts/tune_conv2d_transpose.py
@@ -33,13 +33,15 @@
 
 Workload = namedtuple("Conv2DTransposeWorkload",
                       ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride',
+                       'o_hpad', 'o_wpad'])
 
+# DCGAN workloads
 dcgan_wkls = [
     # dcgan
-    ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)),
+    ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
 ]
 
 @tvm.tag_scope(tag=topi.tag.ELEMWISE)
@@ -51,7 +53,7 @@ def my_clip(x, a_min, a_max):
     x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
     return x
 
-def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding):
+def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding, opadding):
 
     data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
     kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
@@ -64,7 +66,9 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding):
             Filter=kernel,
             strides=strides,
             padding=padding,
-            out_dtype=env.acc_dtype)
+            out_dtype=env.acc_dtype,
+            output_padding=opadding
+        )
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
@@ -109,11 +113,12 @@ def conv2d_transpose(N, CI, H, W, CO, KH, KW, strides, padding):
         KW = wl.wkernel
         strides = (wl.hstride, wl.wstride)
         padding = (wl.hpad, wl.wpad)
+        opadding = (wl.o_hpad, wl.o_wpad)
 
         # Create task
         task = autotvm.task.create(
             conv2d_transpose,
-            args=(N, CI, H, W, CO, KH, KW, strides, padding),
+            args=(N, CI, H, W, CO, KH, KW, strides, padding, opadding),
             target=tvm.target.vta(),
             target_host=env.target_host,
             template_key='direct')
diff --git a/tests/python/integration/test_benchmark_topi_conv2d_transpose.py b/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
index e2601d1a424f..235076c79528 100644
--- a/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
+++ b/tests/python/integration/test_benchmark_topi_conv2d_transpose.py
@@ -37,7 +37,8 @@
 
 Workload = namedtuple("Conv2DTransposeWorkload",
                       ['batch', 'height', 'width', 'in_filter', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride',
+                       'o_hpad', 'o_wpad'])
 
 # Get batch info from env
 env = vta.get_env()
@@ -45,9 +46,9 @@
 # DCGAN workloads
 dcgan_wklds = [
     # dcgan
-    ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2)),
-    ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2)),
+    ('DCGAN.CT1', Workload(env.BATCH, 4, 4, 1024, 512, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT2', Workload(env.BATCH, 8, 8, 512, 256, 4, 4, 1, 1, 2, 2, 0, 0)),
+    ('DCGAN.CT3', Workload(env.BATCH, 16, 16, 256, 128, 4, 4, 1, 1, 2, 2, 0, 0)),
 ]
 
 # FIXME: we need a custom clip operator to circumvent a pattern detection limitation
@@ -102,7 +103,8 @@ def run_conv2d_transpose(env, remote, wl, target,
     # Define base computation schedule
     with target:
         res = topi.nn.conv2d_transpose_nchw(
-            data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), env.acc_dtype)
+            data, kernel, (wl.hstride, wl.wstride),
+            (wl.hpad, wl.wpad), env.acc_dtype, (wl.o_hpad, wl.o_wpad))
         res = topi.right_shift(res, env.WGT_WIDTH)
         res = my_clip(res, 0, (1 << env.OUT_WIDTH - 1) - 1)
         res = topi.cast(res, env.out_dtype)
@@ -112,8 +114,8 @@ def run_conv2d_transpose(env, remote, wl, target,
         print(vta.lower(s, [data, kernel, res], simple_mode=True))
 
     # Derive number of ops
-    fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel
-    fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel
+    fout_height = (wl.height - 1) * wl.hstride - 2 * wl.hpad + wl.hkernel + wl.o_hpad
+    fout_width = (wl.width - 1) * wl.wstride - 2 * wl.wpad + wl.wkernel + wl.o_wpad
     num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
 
 # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc")
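
Note on the output-shape arithmetic (illustrative, not part of the patch): with symmetric padding, as in these workloads, each spatial dimension of conv2d_transpose follows out = (i - 1) * stride - 2 * pad + k + opad, the formula that the compute declaration and the test's fout_height/fout_width now share. A minimal pure-Python check against the three DCGAN workloads above (the helper conv2d_transpose_out_dim is hypothetical; no TVM required):

def conv2d_transpose_out_dim(i, k, pad, stride, opad):
    # Mirrors the patched formula: (i - 1) * stride - 2 * pad + k + opad
    return (i - 1) * stride - 2 * pad + k + opad

# All three DCGAN layers use kernel 4, pad 1, stride 2, output_padding 0.
for name, size in [("DCGAN.CT1", 4), ("DCGAN.CT2", 8), ("DCGAN.CT3", 16)]:
    out = conv2d_transpose_out_dim(size, k=4, pad=1, stride=2, opad=0)
    print(name, size, "->", out)  # 4 -> 8, 8 -> 16, 16 -> 32: each layer doubles H and W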