Skip to content

Commit

Permalink
Correct Graph Executor Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
Mousius committed Nov 18, 2021
1 parent 229484e commit cfb297c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _lower(mod, target, params):
with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
mod, _ = relay.optimize(mod, target, params)
grc = graph_executor_codegen.GraphExecutorCodegen(None, target)
grc.codegen(mod["main"])
grc.codegen(mod, mod["main"])
return

compiler = relay.vm.VMCompiler()
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relay/backend/graph_executor_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def _setup(self, mod, target):
tgts[_expr.IntImm("int32", 0)] = Target(target)
self._init(mod, tgts)

def codegen(self, func):
def codegen(self, ir_module, func):
"""Compile a single function into a graph.
Parameters
----------
ir_module: tvm.ir.Module
The module to compile
func: tvm.relay.Expr
The function to compile.
Expand All @@ -82,7 +84,7 @@ def codegen(self, func):
Additional constant parameters.
"""
default_mod_name = mangle_module_name("default")
self._codegen(func, default_mod_name)
self._codegen(ir_module, func, default_mod_name)
graph_json = self._get_graph_json()
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
Expand Down

0 comments on commit cfb297c

Please sign in to comment.