From f0f45246c5898d6864632ab5649db1c6713fb9d8 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 4 Aug 2018 10:37:50 -0700 Subject: [PATCH] Add schedule and test for group convolution (#5) * group conv pass all * pass mobilenet --- python/tvm/contrib/util.py | 1 + topi/python/topi/testing/__init__.py | 1 + topi/python/topi/testing/group_conv2d.py | 74 ++++++ vta/config/vta_config.json | 8 +- vta/python/vta/top/__init__.py | 4 +- vta/python/vta/top/arm_conv2d.py | 82 +++++++ vta/python/vta/top/vta_conv2d.py | 44 ++-- vta/python/vta/top/vta_group_conv2d.py | 224 ++++++++++++++++++ .../test_benchmark_topi_group_conv.py | 161 +++++++++++++ 9 files changed, 577 insertions(+), 22 deletions(-) create mode 100644 topi/python/topi/testing/group_conv2d.py create mode 100644 vta/python/vta/top/vta_group_conv2d.py create mode 100644 vta/tests/python/integration/test_benchmark_topi_group_conv.py diff --git a/python/tvm/contrib/util.py b/python/tvm/contrib/util.py index 0d94a8da5058b..d3a727f9389ff 100644 --- a/python/tvm/contrib/util.py +++ b/python/tvm/contrib/util.py @@ -143,6 +143,7 @@ def which(exec_name): return full_path return None + def get_lower_ir(s): """Get lower ir code of a schedule. This is useful for debug, since you don't have to find all inputs/outputs diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index c496e08c1835a..a29ad0fb05a97 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -8,6 +8,7 @@ from .conv2d_nchw_python import conv2d_nchw_python from .conv2d_nhwc_python import conv2d_nhwc_python from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python +from .group_conv2d import group_conv2d_nchw_python from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python from .softmax_python import softmax_python, log_softmax_python diff --git a/topi/python/topi/testing/group_conv2d.py b/topi/python/topi/testing/group_conv2d.py new file mode 100644 index 0000000000000..c9332ffa3ce6c --- /dev/null +++ b/topi/python/topi/testing/group_conv2d.py @@ -0,0 +1,74 @@ +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches +"""Convolution in python""" +import numpy as np +import scipy.signal + + +def group_conv2d_nchw_python(a_np, w_np, stride, padding, groups): + """Convolution operator in HWCN layout. + + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + w_np : numpy.ndarray + 4-D with shape [num_filter, in_channel, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or str or a list/tuple of two ints + Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width] + + groups: int + + Returns + ------- + b_np : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + batch, in_channel, in_height, in_width = a_np.shape + num_filter, ci_g, kernel_h, kernel_w = w_np.shape + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + if isinstance(padding, int): + pad_h = pad_w = padding * 2 + elif isinstance(padding, (list, tuple)): + pad_h, pad_w = padding[0] * 2, padding[1] * 2 + else: + pad_h = 0 if padding == 'VALID' else kernel_h - 1 + pad_w = 0 if padding == 'VALID' else kernel_w - 1 + pad_top = int(np.ceil(float(pad_h) / 2)) + pad_bottom = pad_h - pad_top + pad_left = int(np.ceil(float(pad_w) / 2)) + pad_right = pad_w - pad_left + # compute the output shape + out_channel = num_filter + out_height = (in_height - kernel_h + pad_h) // stride_h + 1 + out_width = (in_width - kernel_w + pad_w) // stride_w + 1 + b_np = np.zeros((batch, out_channel, out_height, out_width)) + + assert ci_g * groups == in_channel + + # group computation + for n in range(batch): + for f in range(out_channel): + for c in range(ci_g): + base = f // (out_channel // groups) * ci_g + if pad_h > 0 or pad_w > 0: + apad = np.zeros((in_height + pad_h, in_width + pad_w)) + if pad_h == 0: + apad[:, pad_left:-pad_right] = a_np[n, base + c] + elif pad_w == 0: + apad[pad_top:-pad_bottom, :] = a_np[n, base + c] + else: + apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, base + c] + else: + apad = a_np[n, base + c] + out = scipy.signal.convolve2d( + apad, np.rot90(np.rot90(w_np[f, c])), mode='valid') + b_np[n, f] += out[::stride_h, ::stride_w] + return b_np diff --git a/vta/config/vta_config.json b/vta/config/vta_config.json index 27a2289b2b8b7..85f21b8270f4f 100644 --- a/vta/config/vta_config.json +++ b/vta/config/vta_config.json @@ -8,14 +8,14 @@ "GEMM_II" : 1, "TALU_II" : 2, "LOG_INP_WIDTH" : 3, - "LOG_WGT_WIDTH" : 1, + "LOG_WGT_WIDTH" : 3, "LOG_ACC_WIDTH" : 5, "LOG_OUT_WIDTH" : 3, "LOG_BATCH" : 0, - "LOG_BLOCK_IN" : 5, - "LOG_BLOCK_OUT" : 5, + "LOG_BLOCK_IN" : 4, + "LOG_BLOCK_OUT" : 4, "LOG_UOP_BUFF_SIZE" : 15, - "LOG_INP_BUFF_SIZE" : 16, + "LOG_INP_BUFF_SIZE" : 15, "LOG_WGT_BUFF_SIZE" : 18, "LOG_ACC_BUFF_SIZE" : 17 } diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py index 46454ebf789f8..6c07c64f27d74 100644 --- a/vta/python/vta/top/__init__.py +++ b/vta/python/vta/top/__init__.py @@ -1,6 +1,8 @@ """TVM TOPI connector, eventually most of these should go to TVM repo""" -from .vta_conv2d import packed_conv2d, schedule_packed_conv2d from . import vta_conv2d from . import arm_conv2d + from .bitpack import bitpack +from .vta_conv2d import packed_conv2d, schedule_packed_conv2d +from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d diff --git a/vta/python/vta/top/arm_conv2d.py b/vta/python/vta/top/arm_conv2d.py index 634348a87cfe8..e3acb7a202df5 100644 --- a/vta/python/vta/top/arm_conv2d.py +++ b/vta/python/vta/top/arm_conv2d.py @@ -5,6 +5,88 @@ from topi.nn import conv2d, conv2d_alter_layout from topi import generic +_WORKLOADS = [ + # resnet 18 + Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2), + Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2), + Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), + Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), + Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), + Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), + Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), + Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), + Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), + Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), + Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), + Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), + + # mobilenet float32 + Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2), + Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1), + Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1), + + # mobilenet int8 + Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2), + Workload('int8', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1), + Workload('int8', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1), +] + +_SCHEDULES = [ + # float32 imagenet + SpatialPack(1, 8, 4, 1, 4, True), + SpatialPack(1, 8, 4, 1, 4, True), + SpatialPack(1, 7, 4, 2, 4, True), + SpatialPack(1, 4, 8, 4, 1, True), + SpatialPack(1, 4, 4, 1, 16, False), + SpatialPack(1, 4, 8, 4, 8, False), + SpatialPack(1, 7, 4, 3, 8, True), + SpatialPack(1, 2, 8, 1, 8, True), + SpatialPack(2, 1, 16, 1, 4, True), + SpatialPack(1, 7, 4, 1, 1, True), + Im2ColPack(7, 4, 1, 16, True), + Im2ColPack(7, 4, 1, 8, False), + Im2ColPack(7, 4, 1, 16, False), + + # float32 mobilenet + SpatialPack(2, 2, 4, 28, 1, True), + SpatialPack(1, 4, 8, 14, 1, False), + SpatialPack(1, 2, 16, 8, 1, True), + SpatialPack(1, 4, 8, 8, 8, True), + SpatialPack(2, 2, 8, 1, 1, False), + SpatialPack(1, 4, 8, 4, 8, False), + SpatialPack(2, 2, 8, 1, 4, False), + SpatialPack(2, 2, 8, 1, 8, False), + Im2ColPack(7, 4, 1, 16, False), + Im2ColPack(7, 4, 1, 4, True), + + # int8 mobilenet + SpatialPack(2, 2, 4, 28, 1, True), + SpatialPack(1, 4, 8, 14, 1, False), + SpatialPack(1, 2, 16, 8, 1, True), + SpatialPack(1, 4, 8, 8, 8, True), + SpatialPack(2, 2, 8, 1, 1, False), + SpatialPack(1, 4, 8, 4, 8, False), + SpatialPack(2, 2, 8, 1, 4, False), + SpatialPack(2, 2, 8, 1, 8, False), + Im2ColPack(7, 4, 1, 16, False), + Im2ColPack(7, 4, 1, 4, True), +] + @conv2d.register(["vtacpu", "vta"]) def compute(*args, **kwargs): with tvm.target.arm_cpu("vtacpu"): diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 7a73b58278052..b0029565f5066 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -11,6 +11,8 @@ from nnvm.top import nn as _nn from ..environment import get_env from ..ptr_alias import reinterpret +from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d + Workload = namedtuple("Conv2DWorkload", ['batch', 'height', 'width', 'in_filter', 'out_filter', @@ -262,22 +264,26 @@ def compute_conv2d(attrs, inputs, out): assert dilation == (1, 1), "not support dilate now" if is_packed_layout(layout): - assert groups == 1 - env = get_env() - assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" - assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now" - inputs = list(inputs) - w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) - assert inputs[1].dtype == "int8" - - # Apply bit packing if necessary - if w_pack_factor != 1: - kshape = list(topi.util.get_const_tuple(inputs[1].shape)) - kshape[-1] *= w_pack_factor - inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype) - - return packed_conv2d(inputs[0], inputs[1], - padding, strides, out_dtype=out_dtype) + if groups == 1: + assert groups == 1 + env = get_env() + assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" + assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now" + inputs = list(inputs) + w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) + assert inputs[1].dtype == "int8" + + # Apply bit packing if necessary + if w_pack_factor != 1: + kshape = list(topi.util.get_const_tuple(inputs[1].shape)) + kshape[-1] *= w_pack_factor + inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype) + + return packed_conv2d(inputs[0], inputs[1], + padding, strides, out_dtype=out_dtype) + else: + return packed_group_conv2d(inputs[0], inputs[1], + padding, strides, groups, out_dtype=out_dtype) return _nn.compute_conv2d(attrs, inputs, out) @@ -286,11 +292,15 @@ def schedule_conv2d(attrs, outs, target): """ 2D convolution schedule. """ layout = attrs["layout"] + groups = attrs.get_int('groups') if is_packed_layout(layout): target = tvm.target.create(target) if target.device_name == "vta": - return schedule_packed_conv2d(outs) + if groups == 1: + return schedule_packed_conv2d(outs) + else: + return schedule_packed_group_conv2d(outs) elif str(target).startswith("llvm"): return tvm.create_schedule([x.op for x in outs]) else: diff --git a/vta/python/vta/top/vta_group_conv2d.py b/vta/python/vta/top/vta_group_conv2d.py new file mode 100644 index 0000000000000..e6891233a18d4 --- /dev/null +++ b/vta/python/vta/top/vta_group_conv2d.py @@ -0,0 +1,224 @@ +import logging +from collections import namedtuple + +import tvm +import topi + + +from topi.util import get_const_int, get_const_tuple +from tvm.contrib.util import get_lower_ir + +from ..environment import get_env + +Workload = namedtuple("GroupConv2DWorkload", + ('batch', 'height', 'width', 'in_filter', 'out_filter', 'groups', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride')) + +Schedule = namedtuple("GroupConv2DSchedule", + ('b_factor', 'oc_factor', 'ic_factor', 'h_factor', 'w_factor', + 'oc_nthread', 'h_nthread', 'debug_sync')) + + +def find_schedules(layer, vt_only=False, best_only=False): + return [Schedule(0, 0, 1, 0, 0, 0, 0, False)] + + +def _get_workload(data, pad_data, kernel, output): + """ Get the workload structure. + """ + o_shape = get_const_tuple(output.shape) + d_shape = get_const_tuple(data.shape) + k_shape = get_const_tuple(kernel.shape) + o_b, o_c, o_h, o_w, ob_blk, o_blk = o_shape + i_b, i_c, i_h, i_w, ib_blk, i_blk = d_shape + k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape + # For now we need to assume that input channel blocking is the same + # as the output channel blocking + assert o_blk == i_blk + assert ob_blk == ib_blk + # Make sure that dimensions match + assert o_b == i_b + assert o_blk == ko_blk + assert i_blk == ki_blk + assert k_o == o_c + groups = i_c // k_i + assert i_c % groups == 0 + assert o_c % groups == 0 + + # Scale the channel size + i_c *= i_blk + o_c *= o_blk + if pad_data is not None: + p_shape = topi.util.get_const_tuple(pad_data.shape) + h_pad = (p_shape[2] - d_shape[2]) // 2 + w_pad = (p_shape[3] - d_shape[3]) // 2 + else: + h_pad, w_pad = 0, 0 + h_str = (i_h + h_pad*2 - k_h) // (o_h - 1) + w_str = (i_w + w_pad*2 - k_w) // (o_w - 1) + return Workload(i_b, i_h, i_w, i_c, o_c, groups, k_h, k_w, h_pad, w_pad, h_str, w_str) + + +def packed_group_conv2d(data, + kernel, + padding, + strides, + group, + out_dtype="int32"): + """ Packed conv2d function.""" + + if padding[0]: + pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data") + else: + pad_data = data + + assert len(data.shape) == 6 + assert len(kernel.shape) == 6 + assert data.dtype == "int8", data.dtype + assert kernel.dtype == "int8", kernel.dtype + + N, CI, IH, IW, B_BATCH, B_CI = get_const_tuple(data.shape) + CO, CI_G, KH, KW, B_CO, B_CI = get_const_tuple(kernel.shape) + PAD_H, PAD_W = padding + STR_H, STR_W = strides + + OH = (IH + 2 * PAD_H - KH) // strides[0] + 1 + OW = (IW + 2 * PAD_W - KW) // strides[1] + 1 + + assert group * CI_G == CI + assert CO % group == 0 + + oshape = (N, CO, OH, OW, B_BATCH, B_CO) + + kh = tvm.reduce_axis((0, KH), name='d_i') + kw = tvm.reduce_axis((0, KW), name='d_j') + ci_o = tvm.reduce_axis((0, CI_G), name='k_o') + ci_i = tvm.reduce_axis((0, B_CI), name='k_ten') + + out = tvm.compute( + oshape, + lambda n, co, h, w, b_n, b_co: tvm.sum( + pad_data[n, co // (CO // group) * CI_G + ci_o, h * STR_H + kh, + w * STR_W + kw, b_n, ci_i].astype(out_dtype) * + kernel[co, ci_o, kh, kw, b_co, ci_i].astype(out_dtype), + axis=[ci_o, kh, kw, ci_i]), + name="res", tag="packed_group_conv2d") + return out + + +def schedule_packed_group_conv2d(outs): + """ Schedule the packed conv2d. + """ + assert len(outs) == 1 + output = outs[0] + ewise_inputs = [] + ewise_ops = [] + conv2d_res = [] + assert output.dtype == "int8" + assert output.op.input_tensors[0].dtype == "int32" + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + ewise_ops.append(op) + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.tensor.PlaceholderOp): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + assert op.tag == "packed_group_conv2d" + conv2d_res.append(op) + + _traverse(output.op) + assert len(conv2d_res) == 1 + conv2d_stage = conv2d_res[0].output(0) + + data, kernel = conv2d_stage.op.input_tensors + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + temp = data.op.input_tensors[0] + pad_data = data + data = temp + else: + pad_data = None + wrkld = _get_workload(data, pad_data, kernel, output) + plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] + logging.info("Trying to find plan for %s", wrkld) + env = get_env() + + load_inp = load_wgt = load_out = store_out = env.dma_copy + alu = env.alu + gemm = env.gemm + + # schedule1 + oshape = topi.util.get_const_tuple(output.shape) + s = tvm.create_schedule(output.op) + + # setup pad + if pad_data is not None: + cdata = pad_data + s[pad_data].set_scope(env.inp_scope) + else: + cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) + ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) + s[conv2d_stage].set_scope(env.acc_scope) + # cache read input + cache_read_ewise = [] + + for consumer, tensor in ewise_inputs: + cache_read_ewise.append( + s.cache_read(tensor, env.acc_scope, [consumer])) + # set ewise scope + for op in ewise_ops: + s[op].set_scope(env.acc_scope) + s[op].pragma(s[op].op.axis[0], alu) + + # tile + oc_factor = (plan.oc_factor if plan.oc_factor else 1) + h_factor = (plan.h_factor if plan.h_factor else 1) + w_factor = (plan.w_factor if plan.w_factor else 1) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis + x_co0, x_co1 = s[output].split(x_co, factor=oc_factor) + x_i0, x_i1 = s[output].split(x_i, factor=h_factor) + x_j0, x_j1 = s[output].split(x_j, factor=w_factor) + s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) + store_pt = x_j0 + + # set all compute scopes + s[conv2d_stage].compute_at(s[output], store_pt) + for op in ewise_ops: + s[op].compute_at(s[output], store_pt) + + for tensor in cache_read_ewise: + s[tensor].compute_at(s[output], store_pt) + s[tensor].pragma(s[tensor].op.axis[0], load_out) + + # virtual threading along output channel axes + if plan.oc_nthread > 1: + _, v_t = s[output].split(x_co0, factor=plan.oc_nthread) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + # virtual threading along spatial rows + if plan.h_nthread > 1: + _, v_t = s[output].split(x_i0, factor=plan.h_nthread) + s[output].reorder(v_t, x_bo) + s[output].bind(v_t, tvm.thread_axis("cthread")) + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis + k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis + s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) + + if plan.ic_factor: + k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) + + # Use VTA instructions + s[cdata].pragma(s[cdata].op.axis[0], load_inp) + s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt) + s[conv2d_stage].tensorize(x_bi, gemm) + s[output].pragma(x_co1, store_out) + + return s diff --git a/vta/tests/python/integration/test_benchmark_topi_group_conv.py b/vta/tests/python/integration/test_benchmark_topi_group_conv.py new file mode 100644 index 0000000000000..0b16c41350c07 --- /dev/null +++ b/vta/tests/python/integration/test_benchmark_topi_group_conv.py @@ -0,0 +1,161 @@ +"""Testing if we can generate code in topi style""" + +import tvm +from tvm import autotvm +from tvm.contrib import util +from tvm.contrib.pickle_memoize import memoize +import topi +import topi.testing +import vta +import vta.testing +import numpy as np + +Workload = vta.top.vta_group_conv2d.Workload + + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + +def test_vta_group_conv2d(): + def run_vta_group_conv2d(env, remote, name, wl, profile=True): + assert wl.in_filter % wl.groups == 0 + assert wl.out_filter % wl.groups == 0 + assert wl.in_filter % (wl.groups * env.BLOCK_IN) == 0 + assert wl.batch % env.BATCH == 0 + assert wl.in_filter % env.BLOCK_IN == 0 + assert wl.out_filter % env.BLOCK_OUT == 0 + + batch_size = wl.batch + CI_G = wl.in_filter // wl.groups + + data_shape = (batch_size//env.BATCH, wl.in_filter//env.BLOCK_IN, + wl.height, wl.width, env.BATCH, env.BLOCK_IN) + kernel_shape = (wl.out_filter//env.BLOCK_OUT, CI_G//env.BLOCK_IN, + wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) + bias_shape = (batch_size//env.BATCH, wl.out_filter//env.BLOCK_OUT, + 1, 1, env.BATCH, env.BLOCK_OUT) + + fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 + fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) + + res_conv = vta.top.packed_group_conv2d( + data, kernel, (wl.hpad, wl.wpad), (wl.hstride, wl.wstride), wl.groups) + res = topi.right_shift(res_conv, 8) + res = topi.add(res, bias) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + # To compute number of ops, use a x2 factor for FMA + num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * \ + wl.out_filter * wl.in_filter // wl.groups + + a_shape = (batch_size, wl.in_filter, wl.height, wl.width) + w_shape = (wl.out_filter, CI_G, wl.hkernel, wl.wkernel) + data_dtype = data.dtype + kernel_dtype = kernel.dtype + acc_dtype = env.acc_dtype + stride = (wl.hstride, wl.wstride) + padding = (wl.hpad, wl.wpad) + groups = wl.groups + + @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") + def get_ref_data(): + a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) + w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) + a_np = np.abs(a_np) + w_np = np.abs(w_np) + b_np = topi.testing.group_conv2d_nchw_python( + a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding, groups).astype(acc_dtype) + return a_np, w_np, b_np + + def verify(s, check_correctness): + mod = vta.build(s, [data, kernel, bias, res], "ext_dev", + env.target_host, name="group_conv2d") + temp = util.tempdir() + + mod.save(temp.relpath("group_conv2d.o")) + remote.upload(temp.relpath("group_conv2d.o")) + f = remote.load_module("group_conv2d.o") + # verify + ctx = remote.ext_dev(0) + # Data in original format + data_orig, kernel_orig, res_ref = get_ref_data() + bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32") + bias_orig = np.abs(bias_orig) + + data_packed = data_orig.reshape( + batch_size//env.BATCH, env.BATCH, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) + kernel_packed = kernel_orig.reshape( + wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, + wl.in_filter//wl.groups//env.BLOCK_IN, env.BLOCK_IN, + wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) + bias_packed = bias_orig.reshape( + 1, wl.out_filter // env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT) + res_shape = topi.util.get_const_tuple(res.shape) + + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_packed, ctx) + kernel_arr = tvm.nd.array(kernel_packed, ctx) + bias_arr = tvm.nd.array(bias_packed, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("group_conv2d", ctx, number=5) + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + res_unpack = res_arr.asnumpy().transpose( + (0, 4, 1, 5, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width) + if check_correctness: + assert wl.hpad == wl.wpad + stride = (wl.hstride, wl.wstride) + padding = (wl.hpad, wl.wpad) + res_ref = res_ref >> 8 + res_ref += bias_orig.reshape(wl.out_filter, 1, 1) + res_ref = np.clip(res_ref, 0, 127).astype("int8") + np.testing.assert_allclose(res_unpack, res_ref) + return cost + + def group_conv_normal(print_ir): + print("----- Group conv2d End-to-End Test-------") + with vta.build_config(): + s = vta.top.schedule_packed_group_conv2d([res]) + if print_ir: + print(vta.lower(s, [data, kernel, bias, res], simple_mode=True)) + cost = verify(s, True) + gops = (num_ops / cost.mean) / float(10 ** 9) + print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + + group_conv_normal(False) + + def _run(env, remote): + tasks = [ + # mobilenet + ('mobilenet.D1', Workload(1, 112, 112, 32, 32, 2, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D2', Workload(1, 112, 112, 64, 64, 4, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D3', Workload(1, 56, 56, 64, 64, 4, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D4', Workload(1, 56, 56, 128, 128, 8, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D5', Workload(1, 28, 28, 256, 256, 8, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D6', Workload(1, 28, 28, 256, 256, 16, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D7', Workload(1, 14, 14, 256, 256, 16, 3, 3, 1, 1, 1, 1)), + ('mobilenet.D8', Workload(1, 14, 14, 256, 256, 16, 3, 3, 1, 1, 2, 2)), + ('mobilenet.D9', Workload(1, 7, 7, 1024, 1024, 64, 3, 3, 1, 1, 1, 1)), + ] + + for tsk in tasks: + print(tsk) + name, wkl = tsk + run_vta_group_conv2d(env, remote, name, wkl) + + vta.testing.run(_run) + +if __name__ == "__main__": + test_vta_group_conv2d()