Skip to content

Commit

Permalink
[NNVM, TOPI] Bug fixes (apache#24)
Browse files Browse the repository at this point in the history
* bug fix

* passing the parameters when building the nnvm graph before extracting the tasks in autotvm

* bug fix for operator fusion

* fixing integration and tuning scripts
  • Loading branch information
tmoreau89 committed Jan 2, 2019
1 parent 09c94e3 commit 9274d25
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 22 deletions.
10 changes: 5 additions & 5 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,25 @@ def compute_conv2d(attrs, inputs, _):
if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
dilation, layout, out_dtype)
# pylint: enable=assignment-from-no-return
elif groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)
elif layout == "NCHW" and \
groups == get_const_int(inputs[0].shape[1]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ["NCHW", "NCHW4c"]:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and \
groups == get_const_int(inputs[0].shape[3]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
else:
raise ValueError("not support arbitrary group number for now")

Expand Down
9 changes: 6 additions & 3 deletions python/tvm/autotvm/task/nnvm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger = logging.getLogger('autotvm')


def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
def extract_from_graph(graph, shape, dtype, target, symbols, params, target_host=None):
""" Extract tuning tasks from a nnvm graph.
This function collects tuning tasks by building the graph and trace all the calls to topi.
Expand All @@ -33,6 +33,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols want to be tuned
params : dict of str to NDArray
The parameter dictionary.
target_host: tvm.target.Target
The host compilation target
Expand Down Expand Up @@ -66,7 +68,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
# run compiler to collect all TOPI calls during compilation
nnvm.compiler.engine.clear_cache()
with ApplyHistoryBest([]):
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype,
target_host=target_host, params=params)
nnvm.compiler.engine.clear_cache()

logger.disabled = old_state
Expand All @@ -80,7 +83,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("shape error")
print("[Warning] invalid shape")

return tasks

Expand Down
2 changes: 1 addition & 1 deletion vta/python/vta/top/vta_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def schedule_packed_conv2d(cfg, outs,
ewise_inputs = []
ewise_ops = []
conv2d_res = []
assert output.op.input_tensors[0].dtype == "int32"
assert "int" in output.op.input_tensors[0].dtype

def _traverse(op):
if topi.tag.is_broadcast(op.tag):
Expand Down
14 changes: 7 additions & 7 deletions vta/scripts/tune_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def my_clip(x, a_min, a_max):
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
return x

def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype):
data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
Expand All @@ -33,7 +33,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)

with tvm.target.vta():
res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides,
res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides, dilation=dilation,
layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), out_dtype='int32')
res = topi.add(res, bias)
res = topi.right_shift(res, 8)
Expand All @@ -46,13 +46,13 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
s = tvm.create_schedule([res.op])


return s, [data, kernel, bias, res]
return s, [data, kernel, bias, res]

if __name__ == '__main__':
N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype = \
1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), 'int8', 'int32'
N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype = \
1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), (1, 1), 'int8', 'int32'

task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype),
task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype),
target=tvm.target.vta(env.MODEL), target_host=env.target_host, template_key='direct')
print(task.config_space)

Expand All @@ -62,7 +62,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):

measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190, number=4, repeat=3, timeout=30,
runner=autotvm.RPCRunner(env.TARGET, '10.77.1.109', 9190, number=4, repeat=3, timeout=30,
check_correctness=True))

tuner = autotvm.tuner.RandomTuner(task)
Expand Down
9 changes: 5 additions & 4 deletions vta/scripts/tune_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def extract_tasks(sym, params, target, target_host):
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)

with vta.build_config():
tasks = autotvm.task.extract_from_graph(sym, target=target, target_host=target_host,
shape=shape_dict, dtype=dtype_dict, symbols=(nnvm.sym.conv2d,))
tasks = autotvm.task.extract_from_graph(sym, shape=shape_dict, dtype=dtype_dict, target=target,
params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host,
)
return tasks


Expand Down Expand Up @@ -169,7 +170,7 @@ def tune_tasks(tasks,

'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190,
runner=autotvm.RPCRunner(env.TARGET, '10.77.1.109', 9190,
number=4, repeat=3, timeout=60,
check_correctness=True))
}
Expand Down Expand Up @@ -202,7 +203,7 @@ def tune_tasks(tasks,

# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(env.TARGET, 'fleet', 9190, timeout=10000)
remote = autotvm.measure.request_remote(env.TARGET, '10.77.1.109', 9190, timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)

Expand Down
4 changes: 2 additions & 2 deletions vta/tests/python/integration/test_benchmark_topi_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def run_cpu_conv2d(env, remote, wl, target):

with target:
res_conv = topi.nn.conv2d(
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), "NCHW", "int32")
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1), "NCHW", "int32")
res = topi.right_shift(res_conv, 8)
res = my_clip(res, 0, 127)
res = topi.cast(res, "int8")
Expand Down Expand Up @@ -202,7 +202,7 @@ def run_vta_conv2d(env, remote, wl, target, check_correctness=True, print_ir=Fal

with target:
res_conv = topi.nn.conv2d(
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad),
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1),
"NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN), 'int32')
res = topi.right_shift(res_conv, 8)
res = topi.add(res, bias)
Expand Down

0 comments on commit 9274d25

Please sign in to comment.