Skip to content

Commit

Permalink
More flexible conv2d_NCHWc_int8 generic operator. (apache#6714)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalint13 authored and trevor-m committed Dec 4, 2020
1 parent 2225bb3 commit de6a58d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
12 changes: 6 additions & 6 deletions python/tvm/topi/generic/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
num_int8_elements,
)

oc_bn = int32_lanes
oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
Expand Down Expand Up @@ -99,7 +99,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
num_int8_elements,
)

oc_bn = int32_lanes
oc_bn = int32_lanes if int32_lanes >= num_int8_elements else num_int8_elements
ic_bn = 1
for bn in range(oc_bn, 0, -4):
if wkl.in_filter % bn == 0:
Expand All @@ -119,7 +119,7 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):


def schedule_conv_NCHWc_cpu_common_int8(
s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
):
"""
Defines the schedule for INT8 for Intel and ARM machines
Expand Down Expand Up @@ -180,7 +180,7 @@ def schedule_conv_NCHWc_cpu_common_int8(
ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)

assert oc_bn % int32_lanes == 0
assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32
assert ic_bn % int8_elems == 0 # (u)int8 elements in (u)int32

oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)

Expand Down Expand Up @@ -245,7 +245,7 @@ def schedule_conv_NCHWc_cpu_common_int8(


def schedule_conv_NCHWc_cpu_1x1_int8(
s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, intrin=None
s, cfg, data_vec, kernel_vec, conv_out, last, int32_lanes=16, int8_elems=4, intrin=None
):
"""
Defines the 1x1 conv schedule for INT8 for Intel and ARM machines
Expand Down Expand Up @@ -305,7 +305,7 @@ def schedule_conv_NCHWc_cpu_1x1_int8(
kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis

assert oc_bn % int32_lanes == 0
assert ic_bn % 4 == 0 # 4 (u)int8 elements in (u)int32
assert ic_bn % int8_elems == 0 # (u)int8 elements in (u)int32

oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes)

Expand Down
7 changes: 4 additions & 3 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou


def conv2d_NCHWc_int8(
data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4
):
"""Conv2D operator for nChw[x]c layout.
Expand Down Expand Up @@ -539,6 +539,9 @@ def conv2d_NCHWc_int8(
out_dtype : str
output data type
n_elems : int
numer of int8 elements accumulated
Returns
-------
output : tvm.te.Tensor
Expand Down Expand Up @@ -588,7 +591,6 @@ def conv2d_NCHWc_int8(
kw = te.reduce_axis((0, kernel_width), name="kw")

if groups == 1:
n_elems = 4
ic_outer = te.reduce_axis((0, in_channel // ic_bn), name="ic_outer")
ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner")
Expand All @@ -611,7 +613,6 @@ def conv2d_NCHWc_int8(
tag="conv2d_NCHWc_int8",
)
# for int8 group conv support
n_elems = 4
ic_chunk = in_channel // ic_bn
ic_outer = te.reduce_axis((0, ic_chunk // groups), name="ic_outer")
ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner")
Expand Down

0 comments on commit de6a58d

Please sign in to comment.