[Doc] update frontend tutorials to new model based runtime interface #6063

Merged (1 commit, Jul 15, 2020)
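This PR migrates the frontend tutorials to the new model-based runtime interface: relay.build now returns a single factory module that bundles the graph JSON, the compiled library, and the parameters, so graph_runtime.create(graph, lib, ctx) plus a manual set_input(**params) collapses to GraphModule(lib['default'](ctx)). A minimal before/after sketch of the change repeated throughout the diff (mod, params, input_name, and data stand in for whatever each tutorial defines):

    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    # Old interface: three return values, params uploaded by hand
    # graph, lib, params = relay.build(mod, target="llvm", params=params)
    # m = graph_runtime.create(graph, lib, tvm.cpu(0))
    # m.set_input(**params)

    # New interface: one factory module; its 'default' function creates a
    # runtime module with graph and parameters already embedded
    lib = relay.build(mod, target="llvm", params=params)
    m = graph_runtime.GraphModule(lib['default'](tvm.cpu(0)))

    m.set_input(input_name, tvm.nd.array(data))
    m.run()
    out = m.get_output(0).asnumpy()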
python/tvm/relay/build_module.py (2 additions, 4 deletions)

@@ -363,10 +363,8 @@ def _make_executor(self, expr=None):
             raise ValueError("Graph Runtime only supports static graphs, got output type",
                              ret_type)
         num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
-        graph_json, mod, params = build(self.mod, target=self.target)
-        gmodule = _graph_rt.create(graph_json, mod, self.ctx)
-        if params:
-            gmodule.set_input(**params)
+        mod = build(self.mod, target=self.target)
+        gmodule = _graph_rt.GraphModule(mod['default'](self.ctx))

         def _graph_wrapper(*args, **kwargs):
             args = self._convert_args(self.mod["main"], args, kwargs)
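This first hunk is the library-internal side of the change: GraphExecutor._make_executor now builds a single factory module and instantiates the runtime from its 'default' function instead of wiring graph JSON, library, and params together by hand. The user-facing entry point that exercises this path is relay.create_executor; a sketch of how it is typically driven (the tiny Relay function below is illustrative, not from this PR):

    import numpy as np
    import tvm
    from tvm import relay

    # build a tiny Relay function: y = x + x
    x = relay.var("x", shape=(2, 2), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

    # "graph" selects the GraphExecutor path patched in this hunk
    executor = relay.create_executor("graph", mod=mod, ctx=tvm.cpu(0), target="llvm")
    f = executor.evaluate()
    out = f(np.ones((2, 2), dtype="float32"))
    print(out)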
tutorials/frontend/build_gcn.py (2 additions, 3 deletions)

@@ -337,12 +337,11 @@ def prepare_params(g, data):
 mod["main"] = func
 # Build with Relay
 with tvm.transform.PassContext(opt_level=0): # Currently only support opt_level=0
-    graph, lib, params = relay.build(mod, target, params=params)
+    lib = relay.build(mod, target, params=params)

 # Generate graph runtime
 ctx = tvm.context(target, 0)
-m = graph_runtime.create(graph, lib, ctx)
-m.set_input(**params)
+m = graph_runtime.GraphModule(lib['default'](ctx))

 ######################################################################
 # Run the TVM model, test for accuracy and verify with DGL
tutorials/frontend/deploy_model_on_android.py (3 additions, 5 deletions)

@@ -264,8 +264,8 @@ def transform_image(image):
 mod, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict)

 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(mod, target=target,
-                                     target_host=target_host, params=params)
+    lib = relay.build(mod, target=target,
+                      target_host=target_host, params=params)

 # After `relay.build`, you will get three return values: graph,
 # library and the new parameter, since we do some optimization that will

@@ -309,14 +309,12 @@ def transform_image(image):
 rlib = remote.load_module('net.so')

 # create the remote runtime module
-module = runtime.create(graph, rlib, ctx)
+module = runtime.GraphModule(rlib['default'](ctx))

 ######################################################################
 # Execute on TVM
 # --------------

-# set parameter (upload params to the remote device. This may take a while)
-module.set_input(**params)
 # set input data
 module.set_input(input_name, tvm.nd.array(x.astype(dtype)))
 # run
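A practical consequence on the RPC path: the parameters are now embedded in the exported shared library, so a single .so carries everything the device needs and the separate params upload disappears. A condensed sketch of the remote flow this hunk sits in (the `remote` session, `lib`, `ndk`, and the input names come from the surrounding tutorial and are assumed here):

    import tvm
    from tvm.contrib import ndk, graph_runtime as runtime

    # export the factory module as one self-contained shared library
    lib.export_library("net.so", ndk.create_shared)  # ndk toolchain for Android
    remote.upload("net.so")
    rlib = remote.load_module("net.so")

    # instantiate the runtime on the remote context; graph JSON and
    # params are already inside rlib, so no set_input(**params) step
    ctx = remote.cpu(0)
    module = runtime.GraphModule(rlib["default"](ctx))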
tutorials/frontend/deploy_model_on_rasp.py (2 additions, 4 deletions)

@@ -180,7 +180,7 @@ def transform_image(image):
 # target = tvm.target.create('llvm -device=arm_cpu -model=bcm2837 -mtriple=armv7l-linux-gnueabihf -mattr=+neon')

 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(func, target, params=params)
+    lib = relay.build(func, target, params=params)

 # After `relay.build`, you will get three return values: graph,
 # library and the new parameter, since we do some optimization that will

@@ -212,9 +212,7 @@ def transform_image(image):

 # create the remote runtime module
 ctx = remote.cpu(0)
-module = runtime.create(graph, rlib, ctx)
-# set parameter (upload params to the remote device. This may take a while)
-module.set_input(**params)
+module = runtime.GraphModule(rlib['default'](ctx))
 # set input data
 module.set_input('data', tvm.nd.array(x.astype('float32')))
 # run
tutorials/frontend/deploy_prequantized.py (2 additions, 3 deletions)

@@ -82,10 +82,9 @@ def get_synset():

 def run_tvm_model(mod, params, input_name, inp, target="llvm"):
     with tvm.transform.PassContext(opt_level=3):
-        json, lib, params = relay.build(mod, target=target, params=params)
+        lib = relay.build(mod, target=target, params=params)

-    runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.context(target, 0))
-    runtime.set_input(**params)
+    runtime = tvm.contrib.graph_runtime.GraphModule(lib['default'](tvm.context(target, 0)))

     runtime.set_input(input_name, inp)
     runtime.run()
tutorials/frontend/deploy_prequantized_tflite.py (4 additions, 6 deletions)

@@ -157,10 +157,9 @@ def run_tflite_model(tflite_model_buf, input_data):

 ###############################################################################
 # Lets run TVM compiled pre-quantized model inference and get the TVM prediction.
-def run_tvm(graph, lib, params):
+def run_tvm(lib):
     from tvm.contrib import graph_runtime
-    rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
-    rt_mod.set_input(**params)
+    rt_mod = graph_runtime.GraphModule(lib['default'](tvm.cpu(0)))
     rt_mod.set_input('input', data)
     rt_mod.run()
     tvm_res = rt_mod.get_output(0).asnumpy()

@@ -199,12 +198,11 @@ def run_tvm(graph, lib, params):
 # target platform that you are interested in.
 target = 'llvm'
 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build_module.build(mod, target=target,
-                                                  params=params)
+    lib = relay.build_module.build(mod, target=target, params=params)

 ###############################################################################
 # Finally, lets call inference on the TVM compiled module.
-tvm_pred, rt_mod = run_tvm(graph, lib, params)
+tvm_pred, rt_mod = run_tvm(lib)

 ###############################################################################
 # Accuracy comparison
tutorials/frontend/deploy_sparse.py (2 additions, 3 deletions)

@@ -216,15 +216,14 @@ def import_graphdef(
 # tensors instead of sparse aware kernels.
 def run_relay_graph(mod, params, shape_dict, target, ctx):
     with relay.build_config(opt_level=3):
-        graph, lib, params = relay.build(mod, target=target, params=params)
+        lib = relay.build(mod, target=target, params=params)
     input_shape = shape_dict["input_1"]
     dummy_data = np.random.uniform(size=input_shape, low=0, high=input_shape[1]).astype(
         "int32"
     )

-    m = graph_runtime.create(graph, lib, ctx)
+    m = graph_runtime.GraphModule(lib['default'](ctx))
     m.set_input(0, dummy_data)
-    m.set_input(**params)
     m.run()
     tvm_output = m.get_output(0)

tutorials/frontend/deploy_ssd_gluoncv.py (6 additions, 7 deletions)

@@ -88,27 +88,26 @@
 def build(target):
     mod, params = relay.frontend.from_mxnet(block, {"data": dshape})
     with tvm.transform.PassContext(opt_level=3):
-        graph, lib, params = relay.build(mod, target, params=params)
-    return graph, lib, params
+        lib = relay.build(mod, target, params=params)
+    return lib

 ######################################################################
 # Create TVM runtime and do inference

-def run(graph, lib, params, ctx):
+def run(lib, ctx):
     # Build TVM runtime
-    m = graph_runtime.create(graph, lib, ctx)
+    m = graph_runtime.GraphModule(lib['default'](ctx))
     tvm_input = tvm.nd.array(x.asnumpy(), ctx=ctx)
     m.set_input('data', tvm_input)
-    m.set_input(**params)
     # execute
     m.run()
     # get outputs
     class_IDs, scores, bounding_boxs = m.get_output(0), m.get_output(1), m.get_output(2)
     return class_IDs, scores, bounding_boxs

 for target, ctx in target_list:
-    graph, lib, params = build(target)
-    class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
+    lib = build(target)
+    class_IDs, scores, bounding_boxs = run(lib, ctx)

 ######################################################################
 # Display result
tutorials/frontend/from_caffe2.py (2 additions, 4 deletions)

@@ -89,7 +89,7 @@ def transform_image(image):
 # target x86 CPU
 target = 'llvm'
 with transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(mod, target, params=params)
+    lib = relay.build(mod, target, params=params)

 ######################################################################
 # Execute on TVM

@@ -101,11 +101,9 @@ def transform_image(image):
 # context x86 CPU, use tvm.gpu(0) if you run on GPU
 ctx = tvm.cpu(0)
 # create a runtime executor module
-m = graph_runtime.create(graph, lib, ctx)
+m = graph_runtime.GraphModule(lib['default'](ctx))
 # set inputs
 m.set_input(input_name, tvm.nd.array(data.astype('float32')))
-# set related params
-m.set_input(**params)
 # execute
 m.run()
 # get outputs
tutorials/frontend/from_coreml.py (2 additions, 5 deletions)

@@ -75,9 +75,7 @@
 mod, params = relay.frontend.from_coreml(mlmodel, shape_dict)

 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(mod,
-                                     target,
-                                     params=params)
+    lib = relay.build(mod, target, params=params)

 ######################################################################
 # Execute on TVM

@@ -86,10 +84,9 @@
 from tvm.contrib import graph_runtime
 ctx = tvm.cpu(0)
 dtype = 'float32'
-m = graph_runtime.create(graph, lib, ctx)
+m = graph_runtime.GraphModule(lib['default'](ctx))
 # set inputs
 m.set_input('image', tvm.nd.array(x.astype(dtype)))
-m.set_input(**params)
 # execute
 m.run()
 # get outputs
tutorials/frontend/from_darknet.py (2 additions, 6 deletions)

@@ -101,10 +101,7 @@
 shape = {'data': data.shape}
 print("Compiling the model...")
 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(mod,
-                                     target=target,
-                                     target_host=target_host,
-                                     params=params)
+    lib = relay.build(mod, target=target, target_host=target_host, params=params)

 [neth, netw] = shape['data'][2:] # Current image shape is 608x608
 ######################################################################

@@ -122,11 +119,10 @@
 # The process is no different from other examples.
 from tvm.contrib import graph_runtime

-m = graph_runtime.create(graph, lib, ctx)
+m = graph_runtime.GraphModule(lib['default'](ctx))

 # set inputs
 m.set_input('data', tvm.nd.array(data.astype(dtype)))
-m.set_input(**params)
 # execute
 print("Running the test image...")

tutorials/frontend/from_mxnet.py (2 additions, 3 deletions)

@@ -91,7 +91,7 @@ def transform_image(image):
 # now compile the graph
 target = 'cuda'
 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(func, target, params=params)
+    lib = relay.build(func, target, params=params)

 ######################################################################
 # Execute the portable graph on TVM

@@ -100,10 +100,9 @@ def transform_image(image):
 from tvm.contrib import graph_runtime
 ctx = tvm.gpu(0)
 dtype = 'float32'
-m = graph_runtime.create(graph, lib, ctx)
+m = graph_runtime.GraphModule(lib['default'](ctx))
 # set inputs
 m.set_input('data', tvm.nd.array(x.astype(dtype)))
-m.set_input(**params)
 # execute
 m.run()
 # get outputs
tutorials/frontend/from_pytorch.py (2 additions, 6 deletions)

@@ -102,21 +102,17 @@
 target_host = 'llvm'
 ctx = tvm.cpu(0)
 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(mod,
-                                     target=target,
-                                     target_host=target_host,
-                                     params=params)
+    lib = relay.build(mod, target=target, target_host=target_host, params=params)

 ######################################################################
 # Execute the portable graph on TVM
 # ---------------------------------
 # Now we can try deploying the compiled model on target.
 from tvm.contrib import graph_runtime
 dtype = 'float32'
-m = graph_runtime.create(graph, lib, ctx)
+m = graph_runtime.GraphModule(lib['default'](ctx))
 # Set inputs
 m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
-m.set_input(**params)
 # Execute
 m.run()
 # Get outputs
tutorials/frontend/from_tensorflow.py (2 additions, 6 deletions)

@@ -145,10 +145,7 @@
 # lib: target library which can be deployed on target with TVM runtime.

 with tvm.transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(mod,
-                                     target=target,
-                                     target_host=target_host,
-                                     params=params)
+    lib = relay.build(mod, target=target, target_host=target_host, params=params)

 ######################################################################
 # Execute the portable graph on TVM

@@ -157,10 +154,9 @@

 from tvm.contrib import graph_runtime
 dtype = 'uint8'
-m = graph_runtime.create(graph, lib, ctx)
+m = graph_runtime.GraphModule(lib['default'](ctx))
 # set inputs
 m.set_input('DecodeJpeg/contents', tvm.nd.array(x.astype(dtype)))
-m.set_input(**params)
 # execute
 m.run()
 # get outputs
tutorials/frontend/from_tflite.py (2 additions, 5 deletions)

@@ -136,7 +136,7 @@ def extract(path):
 # Build the module against to x86 CPU
 target = "llvm"
 with transform.PassContext(opt_level=3):
-    graph, lib, params = relay.build(mod, target, params=params)
+    lib = relay.build(mod, target, params=params)

 ######################################################################
 # Execute on TVM

@@ -146,14 +146,11 @@ def extract(path):
 from tvm.contrib import graph_runtime as runtime

 # Create a runtime executor module
-module = runtime.create(graph, lib, tvm.cpu())
+module = runtime.GraphModule(lib['default'](tvm.cpu()))

 # Feed input data
 module.set_input(input_tensor, tvm.nd.array(image_data))

-# Feed related params
-module.set_input(**params)
-
 # Run
 module.run()

tutorials/frontend/using_external_lib.py (4 additions, 8 deletions)

@@ -71,13 +71,11 @@
 logging.basicConfig(level=logging.DEBUG) # to dump TVM IR after fusion

 target = "cuda"
-graph, lib, params = relay.build_module.build(
-    net, target, params=params)
+lib = relay.build_module.build(net, target, params=params)

 ctx = tvm.context(target, 0)
 data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
-module = runtime.create(graph, lib, ctx)
-module.set_input(**params)
+module = runtime.GraphModule(lib['default'](ctx))
 module.set_input("data", data)
 module.run()
 out_shape = (batch_size, out_channels, 224, 224)

@@ -494,13 +492,11 @@
 # To do that, all we need to do is to append the option " -libs=cudnn" to the target string.
 net, params = testing.create_workload(simple_net)
 target = "cuda -libs=cudnn" # use cudnn for convolution
-graph, lib, params = relay.build_module.build(
-    net, target, params=params)
+lib = relay.build_module.build(net, target, params=params)

 ctx = tvm.context(target, 0)
 data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
-module = runtime.create(graph, lib, ctx)
-module.set_input(**params)
+module = runtime.GraphModule(lib['default'](ctx))
 module.set_input("data", data)
 module.run()
 out_shape = (batch_size, out_channels, 224, 224)
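Because the factory module bundles graph JSON, compiled operators, and parameters, saving and restoring a built model also collapses to a single artifact under the new interface. A minimal sketch of that round trip (the file name is illustrative; this flow is implied by, not part of, the diff above):

    import tvm
    from tvm.contrib import graph_runtime

    # assume `lib` is the factory module returned by relay.build(...)
    lib.export_library("deploy_lib.so")

    # later, or in another process: reload and instantiate the runtime;
    # no separate graph JSON or params file is needed
    loaded = tvm.runtime.load_module("deploy_lib.so")
    ctx = tvm.cpu(0)
    module = graph_runtime.GraphModule(loaded["default"](ctx))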