diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 79f3c6033a1f..61679f0018c0 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -17,6 +17,7 @@ from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .node import NodeBase +from . import node as _node FunctionHandle = ctypes.c_void_p ModuleHandle = ctypes.c_void_p @@ -186,6 +187,23 @@ def __call__(self, *args): _ = args return RETURN_SWITCH[ret_tcode.value](ret_val) + +def __init_handle_by_constructor__(fconstructor, args): + """Initialize handle by constructor""" + temp_args = [] + values, tcodes, num_args = _make_tvm_args(args, temp_args) + ret_val = TVMValue() + ret_tcode = ctypes.c_int() + check_call(_LIB.TVMFuncCall( + fconstructor.handle, values, tcodes, ctypes.c_int(num_args), + ctypes.byref(ret_val), ctypes.byref(ret_tcode))) + _ = temp_args + _ = args + assert ret_tcode.value == TypeCode.NODE_HANDLE + handle = ret_val.v_handle + return handle + + def _return_module(x): """Return function""" handle = x.v_handle @@ -202,6 +220,7 @@ def _handle_return_func(x): # setup return handle for function type +_node.__init_by_constructor__ = __init_handle_by_constructor__ RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False) diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/node.py index 925aa93f8f96..eb9e930b30eb 100644 --- a/python/tvm/_ffi/_ctypes/node.py +++ b/python/tvm/_ffi/_ctypes/node.py @@ -1,5 +1,5 @@ # pylint: disable=invalid-name, protected-access -# pylint: disable=no-member, missing-docstring +# pylint: disable=no-member, missing-docstring, not-callable from __future__ import absolute_import import ctypes @@ -9,6 +9,7 @@ from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func NodeHandle = ctypes.c_void_p +__init_by_constructor__ = None """Maps node type to its constructor""" NODE_TYPE = {} @@ -58,4 +59,26 @@ def __getattr__(self, name): "'%s' object has no attribute '%s'" % (str(type(self)), name)) return RETURN_SWITCH[ret_type_code.value](ret_val) + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + handle = __init_by_constructor__(fconstructor, args) + if not isinstance(handle, NodeHandle): + handle = NodeHandle(handle) + self.handle = handle + _set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index 989f5b8e7b47..dcbf4c665e66 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -196,37 +196,54 @@ cdef inline object make_ret(TVMValue value, int tcode): raise ValueError("Unhandled type code %d" % tcode) -cdef inline object FuncCall3(void* chandle, tuple args, int nargs): +cdef inline int FuncCall3(void* chandle, + tuple args, + int nargs, + TVMValue* ret_val, + int* ret_tcode) except -1: cdef TVMValue[3] values cdef int[3] tcodes - cdef TVMValue ret_val - cdef int ret_code nargs = len(args) temp_args = [] for i in range(nargs): make_arg(args[i], &values[i], &tcodes[i], temp_args) CALL(TVMFuncCall(chandle, &values[0], &tcodes[0], - nargs, &ret_val, &ret_code)) - return make_ret(ret_val, ret_code) + nargs, ret_val, ret_tcode)) + return 0 -cdef inline object FuncCall(void* chandle, tuple args): +cdef inline int FuncCall(void* chandle, + tuple args, + TVMValue* ret_val, + int* ret_tcode) except -1: cdef int nargs nargs = len(args) if nargs <= 3: - return FuncCall3(chandle, args, nargs) + FuncCall3(chandle, args, nargs, ret_val, ret_tcode) + return 0 cdef vector[TVMValue] values cdef vector[int] tcodes - cdef TVMValue ret_val - cdef int ret_code values.resize(max(nargs, 1)) tcodes.resize(max(nargs, 1)) temp_args = [] for i in range(nargs): make_arg(args[i], &values[i], &tcodes[i], temp_args) CALL(TVMFuncCall(chandle, &values[0], &tcodes[0], - nargs, &ret_val, &ret_code)) - return make_ret(ret_val, ret_code) + nargs, ret_val, ret_tcode)) + return 0 + + +cdef inline int ConstructorCall(void* constructor_handle, + int type_code, + tuple args, + void** handle) except -1: + """Call contructor of a handle function""" + cdef TVMValue ret_val + cdef int ret_tcode + FuncCall(constructor_handle, args, &ret_val, &ret_tcode) + assert ret_tcode == type_code + handle[0] = ret_val.v_handle + return 0 cdef class FunctionBase: @@ -264,7 +281,10 @@ cdef class FunctionBase: CALL(TVMFuncFree(self.chandle)) def __call__(self, *args): - return FuncCall(self.chandle, args) + cdef TVMValue ret_val + cdef int ret_tcode + FuncCall(self.chandle, args, &ret_val, &ret_tcode) + return make_ret(ret_val, ret_tcode) _CLASS_FUNCTION = None _CLASS_MODULE = None diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/node.pxi index 1ced48878803..c62e4ab44cef 100644 --- a/python/tvm/_ffi/_cython/node.pxi +++ b/python/tvm/_ffi/_cython/node.pxi @@ -65,4 +65,27 @@ cdef class NodeBase: "'%s' object has no attribute '%s'" % (type(self), name)) return make_ret(ret_val, ret_type_code) + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + cdef void* chandle + ConstructorCall( + (fconstructor).chandle, + kNodeHandle, args, &chandle) + self.chandle = chandle + _set_class_node_base(NodeBase) diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index cfda2a35f9b9..ca1812d4109a 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -262,23 +262,7 @@ def _list(name, func): def _get_api(f): flocal = f flocal.is_global = True - def my_api_func(*args): - """ - - This is a type erased API that calls into Global PackedFunc. - These APIs corresponds to functions registered from C++ backend - and can be used as developer functions. - - args : list - The positional arguments to the function call. - - Returns - ------- - value : int, float, None, Node or Function - The result of the API function call. - """ - return flocal(*args) - return my_api_func + return flocal def _init_api(namespace, target_module_name=None): """Initialize api for a given module name diff --git a/python/tvm/api.py b/python/tvm/api.py index 75debc33db66..2bcb003ee7e5 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -134,9 +134,9 @@ def any(*args): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _make.Or(args[0], args[1]) + ret = _expr.Or(args[0], args[1]) for i in range(2, len(args)): - ret = _make.Or(ret, args[i]) + ret = _expr.Or(ret, args[i]) return ret @@ -158,9 +158,9 @@ def all(*args): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _make.And(args[0], args[1]) + ret = _expr.And(args[0], args[1]) for i in range(2, len(args)): - ret = _make.And(ret, args[i]) + ret = _expr.And(ret, args[i]) return ret @@ -616,7 +616,7 @@ def select(cond, t, f): node : Node The tvm.expr.Select node """ - return _make.Select(convert(cond), convert(t), convert(f)) + return _expr.Select(convert(cond), convert(t), convert(f)) def comm_reducer(fcombine, fidentity, name="reduce"): @@ -699,7 +699,7 @@ def _make_reduce(expr, axis, where=None): axis = convert(axis if isinstance(axis, (list, tuple)) else [axis]) if where is None: where = convert(True) - outputs = tuple(_make.Reduce(combiner, expr, axis, where, i) + outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i) for i in range(size)) return outputs[0] if size == 1 else outputs @@ -751,5 +751,5 @@ def reducer(expr, axis, where=None, *args): _init_api("tvm.api") #pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") -min = comm_reducer(lambda x, y: _make.Min(x, y), max_value, name='min') -max = comm_reducer(lambda x, y: _make.Max(x, y), min_value, name='max') +min = comm_reducer(lambda x, y: _expr.Min(x, y), max_value, name='min') +max = comm_reducer(lambda x, y: _expr.Max(x, y), min_value, name='max') diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 8bf46b7eee62..1c1c9f82cb97 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -225,127 +225,545 @@ class LogicalExpr(Expr): @register_node("Variable") class Var(Expr): - """Symbolic variable.""" - pass + """Symbolic variable. + + Parameters + ---------- + name : str + The name + + dtype : int + The data type + """ + def __init__(self, name, dtype): + self.__init_handle_by_constructor__( + _api_internal._Var, name, dtype) + @register_node class Reduce(Expr): - pass + """Reduce node. + + Parameters + ---------- + combiner : CommReducer + The combiner. + + src : list of Expr + The source expression. + + rdom : list of IterVar + The iteration domain + + condition : Expr + The reduce condition. + + value_index : int + The value index. + """ + def __init__(self, combiner, src, rdom, condition, value_index): + self.__init_handle_by_constructor__( + _make.Reduce, combiner, src, rdom, + condition, value_index) + @register_node class FloatImm(ConstExpr): - pass + """Float constant. + + Parameters + ---------- + dtype : str + The data type + + value : float + The constant value. + """ + def __init__(self, dtype, value): + self.__init_handle_by_constructor__( + _make.FloatImm, dtype, value) @register_node class IntImm(ConstExpr): - pass + """Int constant. + + Parameters + ---------- + dtype : str + The data type + + value : int + The constant value. + """ + def __init__(self, dtype, value): + self.__init_handle_by_constructor__( + _make.IntImm, dtype, value) + @register_node class UIntImm(ConstExpr): - pass + """UInt constant. + + Parameters + ---------- + dtype : str + The data type + + value : int + The constant value. + """ + def __init__(self, dtype, value): + self.__init_handle_by_constructor__( + _make.UIntImm, dtype, value) + @register_node class StringImm(ConstExpr): - pass + """String constant. + + Parameters + ---------- + value : str + The value of the function. + """ + def __init__(self, value): + self.__init_handle_by_constructor__( + _make.StringImm, value) + @register_node class Cast(Expr): - pass + """Cast expression. + + Parameters + ---------- + dtype : str + The data type + + value : Expr + The value of the function. + """ + def __init__(self, dtype, value): + self.__init_handle_by_constructor__( + _make.Cast, dtype, value) + @register_node class Add(BinaryOpExpr): - pass + """Add node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Add, a, b) + @register_node class Sub(BinaryOpExpr): - pass + """Sub node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Sub, a, b) + @register_node class Mul(BinaryOpExpr): - pass + """Mul node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Mul, a, b) + @register_node class Div(BinaryOpExpr): - pass + """Div node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Div, a, b) + @register_node class Mod(BinaryOpExpr): - pass + """Mod node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Mod, a, b) + @register_node class Min(BinaryOpExpr): - pass + """Min node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Min, a, b) + @register_node class Max(BinaryOpExpr): - pass + """Max node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Max, a, b) + @register_node class EQ(CmpExpr): - pass + """EQ node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.EQ, a, b) + @register_node class NE(CmpExpr): - pass + """NE node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.NE, a, b) + @register_node class LT(CmpExpr): - pass + """LT node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.LT, a, b) + @register_node class LE(CmpExpr): - pass + """LE node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.LE, a, b) + @register_node class GT(CmpExpr): - pass + """GT node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.GT, a, b) + @register_node class GE(CmpExpr): - pass + """GE node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.GE, a, b) + @register_node class And(LogicalExpr): - pass + """And node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.And, a, b) + @register_node class Or(LogicalExpr): - pass + """Or node. + + Parameters + ---------- + a : Expr + The left hand operand. + + b : Expr + The right hand operand. + """ + def __init__(self, a, b): + self.__init_handle_by_constructor__( + _make.Or, a, b) + @register_node class Not(LogicalExpr): - pass + """Not node. + + Parameters + ---------- + a : Expr + The input value + """ + def __init__(self, a): + self.__init_handle_by_constructor__( + _make.Not, a) + @register_node class Select(Expr): - pass + """Select node. + + Parameters + ---------- + condition : Expr + The condition expression. + + true_value : Expr + The value to take when condition is true. + + false_value : Expr + The value to take when condition is false. + """ + def __init__(self, condition, true_value, false_value): + self.__init_handle_by_constructor__( + _make.Select, condition, true_value, false_value) + @register_node class Load(Expr): - pass + """Load node. + + Parameters + ---------- + dtype : str + The data type. + + buffer_var : Var + The buffer variable in the load expression. + + index : Expr + The index in the load. + + predicate : Expr + The load predicate. + """ + def __init__(self, dtype, buffer_var, index, predicate): + self.__init_handle_by_constructor__( + _make.Load, dtype, buffer_var, index, predicate) + @register_node class Ramp(Expr): - pass + """Ramp node. + + Parameters + ---------- + base : Expr + The base expression. + + stride : ramp stride + The stride of the ramp. + + lanes : int + The lanes of the expression. + """ + def __init__(self, base, stride, lanes): + self.__init_handle_by_constructor__( + _make.Ramp, base, stride, lanes) + @register_node class Broadcast(Expr): - pass + """Broadcast node. + + Parameters + ---------- + value : Expr + The value of the expression. + + lanes : int + The lanes of the expression. + """ + def __init__(self, value, lanes): + self.__init_handle_by_constructor__( + _make.Broadcast, value, lanes) + @register_node class Shuffle(Expr): - pass + """Shuffle node. + + Parameters + ---------- + vectors : Array of Expr + The vectors + + indices : Array of indices + The indices + """ + def __init__(self, vectors, indices): + self.__init_handle_by_constructor__( + _make.Shuffle, vectors, indices) + @register_node class Call(Expr): + """Call node. + + Parameters + ---------- + dtype : str + The return data type + + name : str + The name of the function + + args : list of Expr + The input arguments to the call + + call_type : int + The type of the call + + func : Operation, optional + Operation if call_type is Halide + + value_index : int + The output value index + """ Extern = 0 ExternCPlusPlus = 1 PureExtern = 2 Halide = 3 Intrinsic = 4 PureIntrinsic = 5 + def __init__(self, dtype, name, args, call_type, func, value_index): + self.__init_handle_by_constructor__( + _make.Call, dtype, name, args, call_type, func, value_index) @register_node class Let(Expr): - pass + """Let node. + + Parameters + ---------- + var : Var + The variable in the binding. + + value : Expr + The value in to be binded. + + body : Expr + The body expression. + """ + def __init__(self, var, value, body): + self.__init_handle_by_constructor__( + _make.Let, var, value, body) diff --git a/python/tvm/make.py b/python/tvm/make.py index 19949509778b..6238fd7f1789 100644 --- a/python/tvm/make.py +++ b/python/tvm/make.py @@ -6,9 +6,10 @@ Each api is a PackedFunc that can be called in a positional argument manner. You can use make function to build the IR node. """ +from __future__ import absolute_import as _abs from ._ffi.function import _init_api from ._ffi.runtime_ctypes import TVMType -from . import stmt as _stmt + def range_by_min_extent(min_value, extent): """Construct a Range by min and extent. @@ -98,44 +99,4 @@ def node(type_key, **kwargs): return _Node(*args) -def stmt_seq(*args): - """Make sequence of statements - - Parameters - ---------- - args : list of Expr or Var - List of statements to be combined as sequence. - - Returns - ------- - stmt : Stmt - The combined statement. - """ - ret = None - for value in args: - if not isinstance(value, _stmt.Stmt): - value = Evaluate(value) - ret = value if ret is None else Block(ret, value) - return ret if ret else Evaluate(0) - - -def stmt_list(stmt): - """Make list of stmt from blocks. - - Parameters - ---------- - stmt : A block statement - - Returns - ------- - stmt_list : list of Stmt - The unpacked list of statements - """ - if isinstance(stmt, _stmt.Block): - return stmt_list(stmt.first) + stmt_list(stmt.rest) - elif isinstance(stmt, _stmt.ProducerConsumer): - return stmt_list(stmt.body) - return [stmt] - - _init_api("tvm.make") diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index 1f5fea11a472..48d91dfa8044 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -15,65 +15,376 @@ """ from __future__ import absolute_import as _abs from ._ffi.node import NodeBase, register_node +from . import make as _make + class Stmt(NodeBase): pass @register_node class LetStmt(Stmt): - pass + """LetStmt node. + + Parameters + ---------- + var : Var + The variable in the binding. + + value : Expr + The value in to be binded. + + body : Stmt + The body statement. + """ + def __init__(self, var, value, body): + self.__init_handle_by_constructor__( + _make.LetStmt, var, value, body) + @register_node class AssertStmt(Stmt): - pass + """AssertStmt node. + + Parameters + ---------- + condition : Expr + The assert condition. + + message : Expr + The error message. + + body : Stmt + The body statement. + """ + def __init__(self, condition, message, body): + self.__init_handle_by_constructor__( + _make.AssertStmt, condition, message, body) + @register_node class ProducerConsumer(Stmt): - pass + """ProducerConsumer node. + + Parameters + ---------- + func : Operation + The Operation. + + is_producer : bool + Whether if the node is producer. + + body : Stmt + The body statement. + """ + def __init__(self, func, is_producer, body): + self.__init_handle_by_constructor__( + _make.ProducerConsumer, func, is_producer, body) + @register_node class For(Stmt): + """For node. + + Parameters + ---------- + loop_var : Var + The loop variable. + + min_val : Expr + The begining value. + + extent : Expr + The length of the loop. + + for_type : int + The for type. + + device_api : int + The device api type. + + body : Stmt + The body statement. + """ Serial = 0 Parallel = 1 Vectorized = 2 Unrolled = 3 + def __init__(self, + loop_var, + min_val, + extent, + for_type, + device_api, + body): + self.__init_handle_by_constructor__( + _make.For, loop_var, min_val, extent, + for_type, device_api, body) + @register_node class Store(Stmt): - pass + """Store node. + + Parameters + ---------- + buffer_var : Var + The buffer Variable. + + value : Expr + The value we want to store. + + index : Expr + The index in the store expression. + + predicate : Expr + The store predicate. + """ + def __init__(self, buffer_var, value, index, predicate): + self.__init_handle_by_constructor__( + _make.Store, buffer_var, value, index, predicate) + @register_node class Provide(Stmt): - pass + """Provide node. + + Parameters + ---------- + func : Operation + The operation to create the function. + + value_index : int + The output value index + + value : Expr + The value to be stored. + + args : list of Expr + The index arguments of the Provide. + """ + def __init__(self, func, value_index, value, args): + self.__init_handle_by_constructor__( + _make.Provide, func, value_index, value, args) + @register_node class Allocate(Stmt): - pass + """Allocate node. + + Parameters + ---------- + buffer_var : Var + The buffer variable. + + dtype : str + The data type of the buffer. + + extents : list of Expr + The extents of the allocate + + condition : Expr + The condition. + + body : Stmt + The body statement. + """ + def __init__(self, + buffer_var, + dtype, + extents, + condition, + body): + self.__init_handle_by_constructor__( + _make.Allocate, buffer_var, dtype, + extents, condition, body) + @register_node class AttrStmt(Stmt): - pass + """AttrStmt node. + + Parameters + ---------- + node : Node + The node to annotate the attribute + + attr_key : str + Attribute type key. + + value : Expr + The value of the attribute + + body : Stmt + The body statement. + """ + def __init__(self, node, attr_key, value, body): + self.__init_handle_by_constructor__( + _make.AttrStmt, node, attr_key, value, body) + @register_node class Free(Stmt): - pass + """Free node. + + Parameters + ---------- + buffer_var : Var + The buffer variable. + """ + def __init__(self, buffer_var): + self.__init_handle_by_constructor__( + _make.Free, buffer_var) + @register_node class Realize(Stmt): - pass + """Realize node. + + Parameters + ---------- + func : Operation + The operation to create the function. + + value_index : int + The output value index + + dtype : str + The data type of the operation. + + bounds : list of range + The bound of realize + + condition : Expr + The realize condition. + + body : Stmt + The realize body + """ + def __init__(self, + func, + value_index, + dtype, + bounds, + condition, + body): + self.__init_handle_by_constructor__( + _make.Realize, func, value_index, dtype, + bounds, condition, body) + @register_node class Block(Stmt): - pass + """Block node. + + Parameters + ---------- + first : Stmt + The first statement. + + rest : Stmt + The following statement. + """ + def __init__(self, first, rest): + self.__init_handle_by_constructor__( + _make.Block, first, rest) + @register_node class IfThenElse(Stmt): - pass + """IfThenElse node. + + Parameters + ---------- + condition : Expr + The expression + + then_case : Stmt + The statement to execute if condition is true. + + else_case : Stmt + The statement to execute if condition is false. + """ + def __init__(self, condition, then_case, else_case): + self.__init_handle_by_constructor__( + _make.IfThenElse, condition, then_case, else_case) + @register_node class Evaluate(Stmt): - pass + """Evaluate node. + + Parameters + ---------- + value : Expr + The expression to be evalued. + """ + def __init__(self, value): + self.__init_handle_by_constructor__( + _make.Evaluate, value) + @register_node class Prefetch(Stmt): - pass + """Prefetch node. + + Parameters + ---------- + func : Operation + The operation to create the function. + + value_index : int + The output value index + + dtype : str + The data type to be prefetched. + + bounds : list of Range + The bounds to be prefetched. + """ + def __init__(self, func, value_index, dtype, bounds): + self.__init_handle_by_constructor__( + _make.Prefetch, func, value_index, dtype, bounds) + + +def stmt_seq(*args): + """Make sequence of statements + + Parameters + ---------- + args : list of Expr or Var + List of statements to be combined as sequence. + + Returns + ------- + stmt : Stmt + The combined statement. + """ + ret = None + for value in args: + if not isinstance(value, Stmt): + value = Evaluate(value) + ret = value if ret is None else Block(ret, value) + return ret if ret else Evaluate(0) + + +def stmt_list(stmt): + """Make list of stmt from blocks. + + Parameters + ---------- + stmt : A block statement + + Returns + ------- + stmt_list : list of Stmt + The unpacked list of statements + """ + if isinstance(stmt, Block): + return stmt_list(stmt.first) + stmt_list(stmt.rest) + elif isinstance(stmt, ProducerConsumer): + return stmt_list(stmt.body) + return [stmt] + + +_make.stmt_list = stmt_list +_make.stmt_seq = stmt_seq diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index bc9293c20b7a..8a65260a0f58 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -170,6 +170,7 @@ REGISTER_MAKE3(Select); REGISTER_MAKE3(Ramp); REGISTER_MAKE2(Cast); REGISTER_MAKE2(Broadcast); +REGISTER_MAKE2(Shuffle); REGISTER_MAKE3(Let); REGISTER_MAKE3(LetStmt); REGISTER_MAKE3(AssertStmt); diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py new file mode 100644 index 000000000000..caca08afa804 --- /dev/null +++ b/tests/python/unittest/test_lang_constructor.py @@ -0,0 +1,202 @@ +import tvm + +def test_expr_constructor(): + x = tvm.expr.Var("xx", "float32") + assert isinstance(x, tvm.expr.Var) + assert x.name == "xx" + + x = tvm.expr.Reduce(None, [1], + [tvm.api._IterVar((0, 1), "x", 2)], + None, 0) + assert isinstance(x, tvm.expr.Reduce) + assert x.combiner == None + assert x.value_index == 0 + + x = tvm.expr.FloatImm("float32", 1.0) + assert isinstance(x, tvm.expr.FloatImm) + assert x.value == 1.0 + assert x.dtype == "float32" + + x = tvm.expr.IntImm("int64", 2) + assert isinstance(x, tvm.expr.IntImm) + assert x.value == 2 + assert x.dtype == "int64" + + x = tvm.expr.UIntImm("uint16", 2) + assert isinstance(x, tvm.expr.UIntImm) + assert x.value == 2 + assert x.dtype == "uint16" + + x = tvm.expr.StringImm("xyza") + assert isinstance(x, tvm.expr.StringImm) + assert x.value == "xyza" + + x = tvm.expr.Cast("float32", tvm.expr.IntImm("int32", 1)) + assert isinstance(x, tvm.expr.Cast) + assert x.dtype == "float32" + assert x.value.value == 1 + + a = tvm.const(1.0, dtype="float32") + b = tvm.var("x", dtype="float32") + + for cls in [tvm.expr.Add, + tvm.expr.Sub, + tvm.expr.Mul, + tvm.expr.Div, + tvm.expr.Mod, + tvm.expr.Min, + tvm.expr.Max, + tvm.expr.LT, + tvm.expr.LE, + tvm.expr.GT, + tvm.expr.GE]: + x = cls(a, b) + assert isinstance(x, cls) + assert x.a == a + assert x.b.same_as(b) + + + a = tvm.convert(tvm.var("x") > 1) + b = tvm.convert(tvm.var("x") == 1) + + for cls in [tvm.expr.And, + tvm.expr.Or]: + x = cls(a, b) + assert isinstance(x, cls) + assert x.a == a + assert x.b.same_as(b) + + x = tvm.expr.Not(a) + assert isinstance(x, tvm.expr.Not) + assert x.a == a + + x = tvm.expr.Select(a, a, b) + assert isinstance(x, tvm.expr.Select) + assert x.true_value == a + assert x.false_value == b + assert x.condition == a + + buffer_var = tvm.var("x", dtype="handle") + x = tvm.expr.Load("float32", buffer_var, 1, a) + assert isinstance(x, tvm.expr.Load) + assert x.dtype == "float32" + assert x.buffer_var == buffer_var + assert x.index.value == 1 + assert x.predicate == a + + x = tvm.expr.Ramp(1, 2, 10) + assert isinstance(x, tvm.expr.Ramp) + assert x.base.value == 1 + assert x.stride.value == 2 + assert x.lanes == 10 + + x = tvm.expr.Broadcast(a, 10) + assert isinstance(x, tvm.expr.Broadcast) + assert x.value == a + assert x.lanes == 10 + + x = tvm.expr.Shuffle([a], [0]) + assert isinstance(x, tvm.expr.Shuffle) + assert x.vectors[0] == a + assert x.indices[0].value == 0 + + x = tvm.expr.Call("float32", "xyz", [a], tvm.expr.Call.Extern, None, 0) + assert isinstance(x, tvm.expr.Call) + assert x.dtype == "float32" + assert x.name == "xyz" + assert x.args[0] == a + assert x.call_type == tvm.expr.Call.Extern + assert x.func == None + assert x.value_index == 0 + + v = tvm.var("aa") + x = tvm.expr.Let(v, 1, v) + assert x.var == v + assert x.value.value == 1 + assert x.body == v + + +def test_stmt_constructor(): + v = tvm.var("aa") + buffer_var = tvm.var("buf", dtype="handle") + nop = tvm.stmt.Evaluate(1) + x = tvm.stmt.LetStmt(v, 1, tvm.stmt.Evaluate(1)) + assert isinstance(x, tvm.stmt.LetStmt) + assert x.var == v + assert x.value.value == 1 + assert isinstance(x.body, tvm.stmt.Evaluate) + + x = tvm.stmt.AttrStmt(v == 1, "xx", 1, tvm.stmt.Evaluate(1)) + assert isinstance(x, tvm.stmt.AttrStmt) + assert x.value.value == 1 + + x = tvm.stmt.Block(tvm.stmt.Evaluate(11), + nop) + assert isinstance(x, tvm.stmt.Block) + assert x.first.value.value == 11 + assert x.rest == nop + + x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"), + tvm.convert("hellow"), + nop) + assert isinstance(x, tvm.stmt.AssertStmt) + assert x.body == nop + + x = tvm.stmt.ProducerConsumer(None, True, nop) + assert isinstance(x, tvm.stmt.ProducerConsumer) + assert x.body == nop + + x = tvm.stmt.For(tvm.var("x"), 0, 10, 0, 0, nop) + assert isinstance(x, tvm.stmt.For) + assert x.min.value == 0 + assert x.extent.value == 10 + assert x.body == nop + + x = tvm.stmt.Store(buffer_var, 1, 10, tvm.const(1, "uint1")) + assert isinstance(x, tvm.stmt.Store) + assert x.buffer_var == buffer_var + assert x.index.value == 10 + assert x.value.value == 1 + + tensor = tvm.placeholder((), dtype="float32") + x = tvm.stmt.Provide(tensor.op, 0, 10, []) + assert isinstance(x, tvm.stmt.Provide) + assert x.value_index == 0 + assert x.value.value == 10 + + x = tvm.stmt.Allocate(buffer_var, "float32", [10], + tvm.const(1, "uint1"), nop) + assert isinstance(x, tvm.stmt.Allocate) + assert x.dtype == "float32" + assert x.buffer_var == buffer_var + assert x.body == nop + + x = tvm.stmt.AttrStmt(buffer_var, "xyz", 1, nop) + assert isinstance(x, tvm.stmt.AttrStmt) + assert x.node == buffer_var + assert x.attr_key == "xyz" + assert x.body == nop + + x = tvm.stmt.Free(buffer_var) + assert isinstance(x, tvm.stmt.Free) + assert x.buffer_var == buffer_var + + x = tvm.stmt.Realize(None, 0, "float", [], tvm.const(1, "uint1"), nop) + assert isinstance(x, tvm.stmt.Realize) + assert x.body == nop + + x = tvm.stmt.IfThenElse(tvm.const(1, "uint1"), + tvm.stmt.Evaluate(11), + nop) + assert isinstance(x, tvm.stmt.IfThenElse) + assert x.then_case.value.value == 11 + assert x.else_case == nop + + x = tvm.stmt.Prefetch(None, 1, "float32", []) + assert isinstance(x, tvm.stmt.Prefetch) + assert x.value_index == 1 + + +if __name__ == "__main__": + test_expr_constructor() + test_stmt_constructor()