Skip to content

Commit

Permalink
[PYTHON] Enable constructors in Node
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Aug 23, 2018
1 parent 56ab0ad commit a41c247
Show file tree
Hide file tree
Showing 11 changed files with 1,079 additions and 122 deletions.
19 changes: 19 additions & 0 deletions python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .node import NodeBase
from . import node as _node

FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
Expand Down Expand Up @@ -186,6 +187,23 @@ def __call__(self, *args):
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)


def __init_handle_by_constructor__(fconstructor, args):
"""Initialize handle by constructor"""
temp_args = []
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE
handle = ret_val.v_handle
return handle


def _return_module(x):
"""Return function"""
handle = x.v_handle
Expand All @@ -202,6 +220,7 @@ def _handle_return_func(x):


# setup return handle for function type
_node.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False)
Expand Down
25 changes: 24 additions & 1 deletion python/tvm/_ffi/_ctypes/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=invalid-name, protected-access
# pylint: disable=no-member, missing-docstring
# pylint: disable=no-member, missing-docstring, not-callable
from __future__ import absolute_import

import ctypes
Expand All @@ -9,6 +9,7 @@
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func

NodeHandle = ctypes.c_void_p
__init_by_constructor__ = None

"""Maps node type to its constructor"""
NODE_TYPE = {}
Expand Down Expand Up @@ -58,4 +59,26 @@ def __getattr__(self, name):
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val)

def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, NodeHandle):
handle = NodeHandle(handle)
self.handle = handle

_set_class_node_base(NodeBase)
40 changes: 28 additions & 12 deletions python/tvm/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -196,37 +196,50 @@ cdef inline object make_ret(TVMValue value, int tcode):
raise ValueError("Unhandled type code %d" % tcode)


cdef inline object FuncCall3(void* chandle, tuple args, int nargs):
cdef inline void FuncCall3(void* chandle,
tuple args,
int nargs,
TVMValue* ret_val,
int* ret_tcode):
cdef TVMValue[3] values
cdef int[3] tcodes
cdef TVMValue ret_val
cdef int ret_code
nargs = len(args)
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
nargs, ret_val, ret_tcode))

cdef inline object FuncCall(void* chandle, tuple args):
cdef inline void FuncCall(void* chandle,
tuple args,
TVMValue* ret_val,
int* ret_tcode):
cdef int nargs
nargs = len(args)
if nargs <= 3:
return FuncCall3(chandle, args, nargs)
FuncCall3(chandle, args, nargs, ret_val, ret_tcode)
return

cdef vector[TVMValue] values
cdef vector[int] tcodes
cdef TVMValue ret_val
cdef int ret_code
values.resize(max(nargs, 1))
tcodes.resize(max(nargs, 1))
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(TVMFuncCall(chandle, &values[0], &tcodes[0],
nargs, &ret_val, &ret_code))
return make_ret(ret_val, ret_code)
nargs, ret_val, ret_tcode))


cdef inline void* ConstructorCall(void* constructor_handle,
int type_code,
tuple args):
"""Call contructor of a handle function"""
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
assert ret_tcode == type_code
return ret_val.v_handle


cdef class FunctionBase:
Expand Down Expand Up @@ -264,7 +277,10 @@ cdef class FunctionBase:
CALL(TVMFuncFree(self.chandle))

def __call__(self, *args):
return FuncCall(self.chandle, args)
cdef TVMValue ret_val
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)

_CLASS_FUNCTION = None
_CLASS_MODULE = None
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/_ffi/_cython/node.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,26 @@ cdef class NodeBase:
"'%s' object has no attribute '%s'" % (type(self), name))
return make_ret(ret_val, ret_type_code)

def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.
Parameters
----------
fconstructor : Function
Constructor function.
args: list of objects
The arguments to the constructor
Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
self.chandle = ConstructorCall(
(<FunctionBase>fconstructor).chandle,
kNodeHandle, args)


_set_class_node_base(NodeBase)
18 changes: 1 addition & 17 deletions python/tvm/_ffi/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,23 +262,7 @@ def _list(name, func):
def _get_api(f):
flocal = f
flocal.is_global = True
def my_api_func(*args):
"""
This is a type erased API that calls into Global PackedFunc.
These APIs corresponds to functions registered from C++ backend
and can be used as developer functions.
args : list
The positional arguments to the function call.
Returns
-------
value : int, float, None, Node or Function
The result of the API function call.
"""
return flocal(*args)
return my_api_func
return flocal

def _init_api(namespace, target_module_name=None):
"""Initialize api for a given module name
Expand Down
16 changes: 8 additions & 8 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ def any(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _make.Or(args[0], args[1])
ret = _expr.Or(args[0], args[1])
for i in range(2, len(args)):
ret = _make.Or(ret, args[i])
ret = _expr.Or(ret, args[i])
return ret


Expand All @@ -158,9 +158,9 @@ def all(*args):
raise ValueError("Any must take at least 1 argument")
if len(args) == 1:
return args[0]
ret = _make.And(args[0], args[1])
ret = _expr.And(args[0], args[1])
for i in range(2, len(args)):
ret = _make.And(ret, args[i])
ret = _expr.And(ret, args[i])
return ret


Expand Down Expand Up @@ -616,7 +616,7 @@ def select(cond, t, f):
node : Node
The tvm.expr.Select node
"""
return _make.Select(convert(cond), convert(t), convert(f))
return _expr.Select(convert(cond), convert(t), convert(f))


def comm_reducer(fcombine, fidentity, name="reduce"):
Expand Down Expand Up @@ -699,7 +699,7 @@ def _make_reduce(expr, axis, where=None):
axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
if where is None:
where = convert(True)
outputs = tuple(_make.Reduce(combiner, expr, axis, where, i)
outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i)
for i in range(size))
return outputs[0] if size == 1 else outputs

Expand Down Expand Up @@ -751,5 +751,5 @@ def reducer(expr, axis, where=None, *args):
_init_api("tvm.api")
#pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make.Max(x, y), min_value, name='max')
min = comm_reducer(lambda x, y: _expr.Min(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _expr.Max(x, y), min_value, name='max')
Loading

0 comments on commit a41c247

Please sign in to comment.