Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jul 9, 2020
1 parent f1f02e2 commit 8d8f5ed
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .tensor import *
from .transform import *
from .algorithm import *
from .vm import *
from . import vm
from . import nn
from . import annotation
from . import memory
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/vm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# pylint: disable=wildcard-import
"""Dialect operators for Relay VM."""
from __future__ import absolute_import as _abs
from . import vm
from .vm import *
8 changes: 6 additions & 2 deletions python/tvm/relay/op/vm/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def invoke_tvm_op(func, inputs, outputs):
return _ffi_api.invoke_tvm_op(func, inputs, outputs)


def shape_func(func, inputs, outputs, dependent=False):
def shape_func(func, inputs, outputs, is_inputs):
"""Invoke the shape function of the passed function.
Parameters
Expand All @@ -71,9 +71,13 @@ def shape_func(func, inputs, outputs, dependent=False):
outputs : tvm.relay.Tuple
The tupled outputs.
is_inputs : List[bool]
A boolean list indicating whether the shape function should expect
shape or input at each position.
Returns
-------
result : tvm.relay.Expr
The shape function expression.
"""
return _ffi_api.shape_func(func, inputs, outputs, dependent)
return _ffi_api.shape_func(func, inputs, outputs, is_inputs)
4 changes: 2 additions & 2 deletions src/relay/op/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ RELAY_REGISTER_OP("vm.shape_func")
return {topi::identity(inputs[0])};
});

bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4u);
auto func_type = types[0].as<FuncTypeNode>();
Expand Down Expand Up @@ -169,7 +169,7 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op")
.add_argument("op", "Function", "The operation to call")
.add_argument("ins", "Tuple", "The input tensors.")
.add_argument("outs", "Tuple", "The output tensors.")
.add_type_rel("InvokeTVMOP", InvokeTVMOPRel)
.add_type_rel("InvokeTVMOp", InvokeTVMOpRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
Expand Down

0 comments on commit 8d8f5ed

Please sign in to comment.