From de6a58d3fa702634ee70f987bd52c0494ff98e64 Mon Sep 17 00:00:00 2001
From: Balint Cristian
Date: Sun, 8 Nov 2020 06:54:45 +0200
Subject: [PATCH] More flexible conv2d_NCHWc_int8 generic operator. (#6714)

---
 python/tvm/topi/generic/conv2d.py | 12 ++++++------
 python/tvm/topi/nn/conv2d.py      |  7 ++++---
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/python/tvm/topi/generic/conv2d.py b/python/tvm/topi/generic/conv2d.py
index f23cff3bef84..7dd9aed7545d 100644
--- a/python/tvm/topi/generic/conv2d.py
+++ b/python/tvm/topi/generic/conv2d.py
@@ -51,7 +51,7 @@ def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
         num_int8_elements,
     )
 
-    oc_bn = int32_lanes
+    oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
     ic_bn = 1
     for bn in range(oc_bn, 0, -4):
         if wkl.in_filter % bn == 0:
@@ -99,7 +99,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
         num_int8_elements,
     )
 
-    oc_bn = int32_lanes
+    oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
     ic_bn = 1
     for bn in range(oc_bn, 0, -4):
         if wkl.in_filter % bn == 0:
@@ -119,7 +119,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
 
 
 def schedule_conv_NCHWc_cpu_common_int8(
-    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
 ):
     """
     Defines the schedule for INT8 for Intel and ARM machines
@@ -180,7 +180,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
     ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
 
     assert oc_bn % int32_lanes == 0
-    assert ic_bn % 4 == 0  # 4 (u)int8 elements in (u)int32
+    assert ic_bn % int8_elems == 0  # (u)int8 elements in (u)int32
 
     oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
@@ -245,7 +245,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
 
 
 def schedule_conv_NCHWc_cpu_1x1_int8(
-    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
+    s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
 ):
     """
     Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
@@ -305,7 +305,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
     kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis
 
     assert oc_bn % int32_lanes == 0
-    assert ic_bn % 4 == 0  # 4 (u)int8 elements in (u)int32
+    assert ic_bn % int8_elems == 0  # (u)int8 elements in (u)int32
 
     oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 2e147fc148de..cd10c757e956 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -505,7 +505,7 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
 
 
 def conv2d_NCHWc_int8(
-    data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
+    data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4
 ):
     """Conv2D operator for nChw[x]c layout.
@@ -539,6 +539,9 @@ def conv2d_NCHWc_int8(
     out_dtype : str
         output data type
 
+    n_elems : int
+        number of int8 elements accumulated
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -588,7 +591,6 @@ def conv2d_NCHWc_int8(
     kw = te.reduce_axis((0, kernel_width), name="kw")
 
     if groups == 1:
-        n_elems = 4
         ic_outer = te.reduce_axis((0, in_channel // ic_bn), name="ic_outer")
         ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
         ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner")
@@ -611,7 +613,6 @@ def conv2d_NCHWc_int8(
         tag="conv2d_NCHWc_int8",
     )
     # for int8 group conv support
-    n_elems = 4
     ic_chunk = in_channel // ic_bn
     ic_outer = te.reduce_axis((0, ic_chunk // groups), name="ic_outer")
     ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
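
Note on the fallback change: the only functional difference in
fallback_schedule_cpu_common_int8 / fallback_schedule_cpu_1x1_int8 is that the
output block size oc_bn is now bounded below by num_int8_elements instead of
always equalling int32_lanes, so targets whose dot-product intrinsic consumes
more int8 values than it has int32 lanes still get a valid block. A minimal
standalone sketch of that selection (pick_oc_bn and the example lane/element
pairs are illustrative, not from the patch):

```python
# Sketch (not part of the patch) of the new fallback block-size selection
# shared by fallback_schedule_cpu_common_int8 and fallback_schedule_cpu_1x1_int8.
def pick_oc_bn(int32_lanes, num_int8_elements):
    # Previously: oc_bn = int32_lanes, unconditionally.
    # Now oc_bn is never smaller than the number of int8 values the
    # dot-product intrinsic reduces into each int32 lane.
    return int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements

assert pick_oc_bn(16, 4) == 16  # e.g. a 16-lane target: behaviour unchanged
assert pick_oc_bn(2, 4) == 4    # narrow-lane target: oc_bn clamped up to 4
```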
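
Note on the operator change: conv2d_NCHWc_int8 now exposes the number of int8
values reduced per int32 accumulation as a parameter (n_elems, default 4,
matching the previously hard-coded value), so backends with different
dot-product widths can reuse the generic compute. A hedged usage sketch; the
tensor shapes, layouts, and channel counts below are invented for illustration:

```python
import tvm
from tvm import te, topi

# NCHW4c input: batch 1, 32 channels as 8 chunks of 4, 56x56 spatial, uint8.
data = te.placeholder((1, 8, 56, 56, 4), name="data", dtype="uint8")
# Packed int8 kernel: (oc_chunk=2, ic_chunk=8, kh=3, kw=3,
#                      ic_bn // n_elems = 1, oc_bn=16, n_elems=4).
kernel = te.placeholder((2, 8, 3, 3, 1, 16, 4), name="kernel", dtype="int8")

out = topi.nn.conv2d_NCHWc_int8(
    data,
    kernel,
    stride=(1, 1),
    padding=(1, 1),
    dilation=(1, 1),
    layout="NCHW4c",
    out_layout="NCHW16c",
    out_dtype="int32",
    n_elems=4,  # was hard-coded to 4 inside the operator before this patch
)
```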