diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index e58fab8cc21a..ddafe500482f 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -10,7 +10,7 @@ #include "./base.h" #include "./expr.h" #include "./module.h" -#include "./runtime/runtime.h" +#include "./runtime/packed_func.h" namespace tvm { diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 24be3f7881fe..6f61aced7aea 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -81,11 +81,13 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; /*! * \brief See pesudo code * - * bool tvm_print(VType value) { - * LOG(INFO) << value; + * int tvm_call_global(name, TVMValue* args) { + * PackedFunc f = PackedFunc::GetGlobal(name); + * f (args, type_code_of(args), len(args)); + * return 0; * } */ -constexpr const char* tvm_print = "tvm_print"; +constexpr const char* tvm_call_global = "tvm_call_global"; /*! \brief The field id of each field in array */ enum TVMArrayFieldKind { diff --git a/include/tvm/module.h b/include/tvm/module.h index 263fdc2f28f1..34ce7438b1a5 100644 --- a/include/tvm/module.h +++ b/include/tvm/module.h @@ -20,9 +20,6 @@ namespace tvm { // Internal node container of lowered function. class LoweredFuncNode; -// Internal node container of module. -class ModuleNode; - /*! * \brief LoweredFunc represents function after lowering. * This is the final IR representation before codegen. diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 34db255398ce..f3e9eee88acd 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -161,7 +161,7 @@ TVM_DLL const char *TVMGetLastError(void); * \param option_vals Additional option values to pass * \param num_options Number of options to be passed into it. * \param out_code 1: success, 0: already initialized - * \return Whether the function is successful. + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMDeviceInit(int dev_mask, const char** option_keys, @@ -188,7 +188,7 @@ TVM_DLL int TVMContextEnabled(TVMContext ctx, * \param dtype The array data type. * \param ctx The ctx this array sits on. * \param out The output handle. - * \return Whether the function is successful. + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, tvm_index_t ndim, @@ -198,6 +198,7 @@ TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, /*! * \brief Free the TVM Array. * \param handle The array handle to be freed. + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); @@ -206,6 +207,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); * \param from The array to be copied from. * \param to The target space. * \param stream The stream where the copy happens, can be NULL. + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, @@ -214,13 +216,14 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, * \brief Wait until all computations on stream completes. * \param ctx The ctx to be synchronized. * \param stream The stream to be synchronized. + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream); /*! * \brief Free the function when it is no longer needed. * \param func The function handle - * \return whether + * \return 0 when success, -1 when failure happens */ TVM_DLL int TVMFuncFree(TVMFunctionHandle func); @@ -239,6 +242,57 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* type_codes, int num_args); + +/*! + * \brief C type of packed function. + * + * \param args The arguments + * \param type_codes The type codes of the arguments + * \param num_args Number of arguments. + * \param resource_handle The handle additional resouce handle from fron-end. + */ +typedef void (*TVMPackedCFunc)( + TVMValue* args, int* type_codes, int num_args, void* resource_handle); + +/*! + * \brief C callback to free the resource handle in C packed function. + * \param resource_handle The handle additional resouce handle from fron-end. + */ +typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle); + +/*! + * \brief Wrap a TVMPackedCFunc to become a FunctionHandle. + * + * The resource_handle will be managed by TVM API, until the function is no longer used. + * + * \param func The packed C function. + * \param resource_handle The resource handle from front-end, can be NULL. + * \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL + * \param out the result function handle. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, + void* resource_handle, + TVMPackedCFuncFinalizer fin, + TVMFunctionHandle *out); + +/*! + * \brief Register the function to runtime's global table. + * + * The registered function then can be pulled by the backend by the name. + * + * \param name The name of the function. + * \param f The function to be registered. + */ +TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f); + +/*! + * \brief Get a global function. + * + * \param name The name of the function. + * \param out the result function pointer. + */ +TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); } // TVM_EXTERN_C #endif // TVM_RUNTIME_C_RUNTIME_API_H_ diff --git a/include/tvm/runtime/runtime.h b/include/tvm/runtime/packed_func.h similarity index 69% rename from include/tvm/runtime/runtime.h rename to include/tvm/runtime/packed_func.h index ef53d6c5f5e8..b4868f757f1e 100644 --- a/include/tvm/runtime/runtime.h +++ b/include/tvm/runtime/packed_func.h @@ -1,35 +1,43 @@ /*! * Copyright (c) 2016 by Contributors - * \file runtime.h + * \file packed_func.h * \brief Runtime related c++ class. */ -#ifndef TVM_RUNTIME_RUNTIME_H_ -#define TVM_RUNTIME_RUNTIME_H_ +#ifndef TVM_RUNTIME_PACKED_FUNC_H_ +#define TVM_RUNTIME_PACKED_FUNC_H_ #include #include +#include +#include #include "./c_runtime_api.h" namespace tvm { namespace runtime { /*! - * \brief Packed function is a runtime function - * whose argument type_codes are erased by packed format. + * \brief Packed function is a type-erased function. + * The arguments are passed by packed format. * - * This is an useful unified interface to call generated functions. + * This is an useful unified interface to call generated functions, + * It is the unified function function type of TVM. + * It corresponds to TVMFunctionHandle in C runtime API. */ class PackedFunc { public: /*! \brief The internal std::function */ using FType = std::function; + /*! \brief default constructor */ PackedFunc() {} + /*! + * \brief constructing a packed function from a std::function. + * \param body the internal container of packed function. + */ explicit PackedFunc(FType body) : body_(body) {} /*! - * \brief invoke the packed function by directly passing in arguments. + * \brief Call packed function by directly passing in unpacked format. * \param args Arguments to be passed. * \tparam Args arguments to be passed. - * \return The first return value. */ template inline void operator()(Args&& ...args) const; @@ -41,9 +49,25 @@ class PackedFunc { */ inline void CallPacked(const TVMValue* args, const int* type_codes, int num_args) const; /*! \return the internal body function */ - inline FType body() const { - return body_; - } + inline FType body() const; + /*! + * \brief Register f as into global function table + * \param name The name of the function. + * \param f The function to be registered. + * \return Reference to the registered function. + * \note The returned reference is valid until the end of the program + */ + static const PackedFunc& RegisterGlobal(const std::string& name, PackedFunc f); + /*! + * \brief Get the global function by name. + * \param name The name of the function. + * \return reference to the registered function. + */ + static const PackedFunc& GetGlobal(const std::string& name); + /*! + * \brief Get the names of currently registered global function. + */ + static std::vector ListGlobalNames(); private: /*! \brief internal container of packed function */ @@ -56,6 +80,10 @@ inline void PackedFunc::CallPacked( body_(args, type_codes, num_args); } +inline PackedFunc::FType PackedFunc::body() const { + return body_; +} + template struct for_each_dispatcher_ { static inline void run(const std::tuple& args, F f) { @@ -124,4 +152,4 @@ inline void PackedFunc::operator()(Args&& ...args) const { } } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_RUNTIME_H_ +#endif // TVM_RUNTIME_PACKED_FUNC_H_ diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index 19d638d1af86..f6eb9ac05134 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -1,6 +1,6 @@ # coding: utf-8 # pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines -# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring +# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring, too-many-return-statements """Symbolic configuration API.""" from __future__ import absolute_import as _abs @@ -13,7 +13,7 @@ from .._base import check_call, ctypes2docstring from .. import _api_internal from . import _runtime_api -from ._types import TVMValue, TypeCode +from ._types import TVMValue, TypeCode, TVMPackedCFunc, TVMCFuncFinalizer # type definitions APIFuncHandle = ctypes.c_void_p @@ -57,6 +57,13 @@ def _return_func(x): return _runtime_api._function_cls(handle) +def _return_handle(x): + handle = x.v_handle + if not isinstance(handle, ctypes.c_void_p): + handle = ctypes.c_void_p(handle) + return handle + + RET_SWITCH = { TypeCode.NULL: lambda x: None, TypeCode.INT: lambda x: x.v_int64, @@ -66,6 +73,15 @@ def _return_func(x): TypeCode.FUNC_HANDLE: _return_func } +PACK_ARG_SWITCH = { + TypeCode.NULL: lambda x: None, + TypeCode.INT: lambda x: x.v_int64, + TypeCode.FLOAT: lambda x: x.v_float64, + TypeCode.STR: lambda x: py_str(x.v_str), + TypeCode.HANDLE: lambda x: _return_handle, +} + + class SliceBase(object): """base class of slice object""" pass @@ -159,10 +175,53 @@ def const(value, dtype=None): return _api_internal._const(value, dtype) +def _ctypes_free_resource(rhandle): + """callback to free resources when it it not needed.""" + pyobj = ctypes.cast(rhandle, ctypes.py_object) + ctypes.pythonapi.Py_DecRef(pyobj) + +# Global callback that is always alive +TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource) +ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ)) + +def convert_to_tvm_func(pyfunc): + """Convert a python function to TVM function + + Parameters + ---------- + pyfunc : python function + The python function to be converted. + + Returns + ------- + tvmfunc: tvm.nd.Function + The converted tvm function. + """ + local_pyfunc = pyfunc + def cfun(args, type_codes, num_args, _): + """ ctypes function """ + num_args = num_args.value if isinstance(num_args, ctypes.c_int) else num_args + pyargs = [PACK_ARG_SWITCH[type_codes[i]](args[i]) for i in range(num_args)] + local_pyfunc(*pyargs) + handle = FunctionHandle() + f = TVMPackedCFunc(cfun) + # NOTE: We will need to use python-api to increase ref count of the f + # TVM_FREE_PYOBJ will be called after it is no longer needed. + pyobj = ctypes.py_object(f) + ctypes.pythonapi.Py_IncRef(pyobj) + check_call(_LIB.TVMFuncCreateFromCFunc( + f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle))) + return _runtime_api._function_cls(handle) + + def convert(value): """Convert a value to expression.""" - if isinstance(value, Number): + if isinstance(value, (NodeBase, _runtime_api.FunctionBase)): + return value + elif isinstance(value, Number): return const(value) + elif isinstance(value, string_types): + return _api_internal._str(value) elif isinstance(value, (list, tuple)): value = [convert(x) for x in value] return _api_internal._Array(*value) @@ -176,10 +235,11 @@ def convert(value): return _api_internal._Map(*vlist) elif isinstance(value, SliceBase): return value.tensor(*value.indices) + elif callable(value): + return convert_to_tvm_func(value) else: - if not isinstance(value, NodeBase): - raise ValueError("don't know how to handle type %s" % type(value)) - return value + raise ValueError("don't know how to handle type %s" % type(value)) + return value def _push_arg(arg): @@ -270,6 +330,59 @@ def register(cls): NODE_TYPE[cls.__name__] = cls return cls + +def register_func(func_name, f=None): + """Register global function + + Parameters + ---------- + func_name : str or function + The function name + + f : function + The function to be registered. + + Returns + ------- + fregister : function + Register function if f is not specified. + """ + if callable(func_name): + f = func_name + func_name = f.__name__ + + if not isinstance(func_name, str): + raise ValueError("expect string function name") + def register(myf): + """internal register function""" + if not isinstance(myf, _runtime_api.FunctionBase): + myf = convert_to_tvm_func(myf) + check_call(_LIB.TVMFuncRegisterGlobal( + c_str(func_name), myf.handle)) + if f: + register(f) + else: + return register + + +def get_global_func(name): + """Get a global function by name + + Parameters + ---------- + name : str + The name of the global function + + Returns + ------- + func : tvm.nd.Function + The function to be returned. + """ + handle = FunctionHandle() + check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))) + return _runtime_api._function_cls(handle) + + def _init_api_module(root_namespace): """List and add all the functions to current module.""" plist = ctypes.POINTER(ctypes.c_char_p)() diff --git a/python/tvm/_ctypes/_types.py b/python/tvm/_ctypes/_types.py index 29d78981a1db..b7d8e3537de4 100644 --- a/python/tvm/_ctypes/_types.py +++ b/python/tvm/_ctypes/_types.py @@ -70,3 +70,16 @@ def __repr__(self): if self.lanes != 1: x += "x%d" % self.lanes return x + + +TVMPackedCFunc = ctypes.CFUNCTYPE( + None, + ctypes.POINTER(TVMValue), + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, + ctypes.c_void_p) + + +TVMCFuncFinalizer = ctypes.CFUNCTYPE( + None, + ctypes.c_void_p) diff --git a/python/tvm/api.py b/python/tvm/api.py index e537c9321fd3..e181db9f19a2 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -1,9 +1,9 @@ # pylint: disable=protected-access, no-member, invalid-name -# pylint: disable=redefined-builtin, undefined-variable +# pylint: disable=redefined-builtin, undefined-variable, unused-import """Functions defined in TVM.""" from __future__ import absolute_import as _abs from numbers import Integral as _Integral -from ._ctypes._api import _init_api_module, convert +from ._ctypes._api import _init_api_module, convert, register_func, get_global_func from . import _api_internal from . import make as _make from . import expr as _expr diff --git a/src/c_api/c_api_codegen.cc b/src/c_api/c_api_codegen.cc index e38a90777b0c..a198365f1a6a 100644 --- a/src/c_api/c_api_codegen.cc +++ b/src/c_api/c_api_codegen.cc @@ -52,5 +52,10 @@ TVM_REGISTER_API(_codegen_DummyHelloFunction) *ret = runtime::PackedFunc(DummyHelloFunction); }); +TVM_REGISTER_API(_codegen_BuildStackVM) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = BuildStackVM(args.at(0)); + }); + } // namespace codegen } // namespace tvm diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc index c2110c76d9f6..41739032ef91 100644 --- a/src/c_api/c_api_ir.cc +++ b/src/c_api/c_api_ir.cc @@ -46,7 +46,8 @@ TVM_REGISTER_API(_make_Call) args.at(1), args.at(2), static_cast(args.at(3).operator int()), - args.at(4)); + args.at(4), + args.at(5)); }); TVM_REGISTER_API(_make_Allocate) diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc index 4119d2fb3832..2c65d81b306e 100644 --- a/src/c_api/c_api_lang.cc +++ b/src/c_api/c_api_lang.cc @@ -4,6 +4,7 @@ * \file c_api_lang.cc */ #include +#include #include #include #include @@ -27,6 +28,13 @@ TVM_REGISTER_API(_const) .add_argument("src", "Number", "source number") .add_argument("dtype", "str", "data type"); + +TVM_REGISTER_API(_str) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = ir::StringImm::make(args.at(0)); +}); + + TVM_REGISTER_API(_Array) .set_body([](const ArgStack& args, RetValue *ret) { std::vector > data; diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 648c0c84ec2b..8852b937743e 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/src/codegen/codegen_stack_vm.cc b/src/codegen/codegen_stack_vm.cc new file mode 100644 index 000000000000..d9ca3d388ee4 --- /dev/null +++ b/src/codegen/codegen_stack_vm.cc @@ -0,0 +1,497 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_stack_vm.cc + */ +#include +#include "./codegen_stack_vm.h" + +namespace tvm { +namespace codegen { + +using namespace ir; + +runtime::PackedFunc BuildStackVM(LoweredFunc func) { + StackVM vm = codegen::CodeGenStackVM().Compile(func); + auto f = [vm](const TVMValue* args, const int* type_codes, int num_args) { + LOG(INFO) << "Run stack VM"; + StackVM::State* s = StackVM::ThreadLocalState(); + s->sp = 0; + s->pc = 0; + if (s->heap.size() < vm.heap_size) { + s->heap.resize(vm.heap_size); + } + s->heap[0].v_handle = (void*)args; // NOLINT(*) + s->heap[1].v_handle = (void*)type_codes; // NOLINT(*) + s->heap[2].v_int64 = num_args; + vm.Run(s); + }; + return runtime::PackedFunc(f); +} + +TVMValue TVMPrint(const TVMValue* args, int num_args) { + CHECK_EQ(num_args, 2); + int tcode = static_cast(args[1].v_int64); + int code = (tcode >> (8 * 3)) & 255; + int bits = (tcode >> (8 * 2)) & 255; + int lanes = tcode & ((1 << 16) - 1); + Type t((halide_type_code_t)code, bits, lanes); + if (t.is_handle()) { + LOG(INFO) << t << ": " << args[0].v_handle; + } else if (t.is_float()) { + LOG(INFO) << t << ": " << args[0].v_float64; + } else { + LOG(INFO) << t << ": " << args[0].v_int64; + } + TVMValue r; r.v_int64 = 0; + return r; +} + +CodeGenStackVM::FType& CodeGenStackVM::vtable() { // NOLINT(*) + static FType inst; return inst; +} + +StackVM CodeGenStackVM::Compile(LoweredFunc f) { + for (size_t i = 0; i < f->args.size(); ++i) { + Var v = f->args[i]; + int vid = AllocVarID(v.get()); + CHECK_EQ(static_cast(vid), i); + } + this->Push(f->body); + return std::move(vm_); +} + +void CodeGenStackVM::Push(const Stmt& n) { + static const FType& f = vtable(); + f(n, this); + if (debug_) { + this->PushOp(StackVM::ASSERT_SP, 0); + } +} + +void CodeGenStackVM::Push(const Expr& n) { + static const FType& f = vtable(); + f(n, this); +} + +void CodeGenStackVM::PushOp(StackVM::OpCode opcode) { + StackVM::Code code; + code.op_code = opcode; + vm_.code.push_back(code); +} + +void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) { + CHECK(operand >= std::numeric_limits::min() && + operand <= std::numeric_limits::max()); + vm_.code.at(operand_index).v_int = static_cast(operand); +} + +int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) { + int64_t pc = static_cast(vm_.code.size()); + StackVM::Code code; + code.op_code = opcode; + vm_.code.push_back(code); + code.v_int = operand; + vm_.code.push_back(code); + return pc + 1; +} + +int CodeGenStackVM::GetStrID(const std::string& key) { + auto it = str_idmap_.find(key); + if (it != str_idmap_.end()) return it->second; + int sid = static_cast(vm_.str_data.size()); + vm_.str_data.push_back(key); + str_idmap_[key] = sid; + return sid; +} + +int CodeGenStackVM::AllocVarID(const Variable* v) { + CHECK(!var_idmap_.count(v)); + int vid = static_cast(vm_.heap_size); + CHECK_EQ(vm_.heap_size, var_idmap_.size()); + vm_.heap_id_name.push_back(v->name_hint); + ++vm_.heap_size; + var_idmap_[v] = vid; + return vid; +} + +int CodeGenStackVM::GetGlobalFuncID(std::string name) { + auto it = fun_idmap_.find(name); + if (it != fun_idmap_.end()) return it->second; + using runtime::PackedFunc; + PackedFunc f = PackedFunc::GetGlobal(name); + auto extern_f = [f](const TVMValue* args, int num_args) { + CHECK_EQ(num_args % 2, 0); + num_args = num_args / 2; + std::vector type_codes(std::max(num_args, 1)); + for (int i = 0; i < num_args; ++i) { + int tcode = static_cast(args[num_args + i].v_int64); + int code = (tcode >> (8 * 3)) & 255; + type_codes[i] = code; + } + f.CallPacked(args, &type_codes[0], num_args); + TVMValue r; r.v_int64 = 0; + return r; + }; + int fid = static_cast(vm_.extern_func.size()); + vm_.extern_func.push_back(extern_f); + fun_idmap_[name] = fid; + + return fid; +} + +int CodeGenStackVM::GetVarID(const Variable* v) const { + auto it = var_idmap_.find(v); + CHECK(it != var_idmap_.end()) + << "Find undefined Variable " << v->name_hint; + return it->second; +} + +void CodeGenStackVM::Push_(const ir::Load* op) { + this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get())); + if (op->type == UInt(32) && op->index.as()) { + this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as()->value); + } else { + this->Push(op->index); + this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes()); + this->PushOp(StackVM::MUL_I64); + this->PushOp(StackVM::ADDR_ADD); + this->PushOp(StackVM::GetLoad(op->type)); + } +} +void CodeGenStackVM::Push_(const ir::Store* op) { + this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get())); + this->Push(op->index); + this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes()); + this->PushOp(StackVM::MUL_I64); + this->PushOp(StackVM::ADDR_ADD); + this->Push(op->value); + this->PushOp(StackVM::GetStore(op->value.type())); +} + +void CodeGenStackVM::Push_(const ir::Allocate* op) { + CHECK(!is_zero(op->condition)); + int vid = AllocVarID(op->buffer_var.get()); + if (op->new_expr.defined()) { + // Prefer global static allocation for the program + CHECK_EQ(op->free_function, "nop"); + this->Push(op->new_expr); + this->PushOp(StackVM::STORE_HEAP, vid); + } else { + LOG(FATAL) << "Dynamic allocation not supported"; + } +} + +void CodeGenStackVM::Push_(const ir::Call* op) { + if (op->is_intrinsic(Call::address_of)) { + const Load *l = op->args[0].as(); + CHECK(op->args.size() == 1 && l); + this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); + this->Push(l->index); + this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes()); + this->PushOp(StackVM::MUL_I64); + this->PushOp(StackVM::ADDR_ADD); + } else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) { + CHECK_EQ(op->args.size(), 3U); + this->Push(op->args[0]); + this->Push(op->args[1]); + this->Push(op->args[2]); + if (op->type.is_handle()) { + this->PushOp(StackVM::TVM_LOAD_ARG_HANDLE); + } else if (op->type.is_float()) { + this->PushOp(StackVM::TVM_LOAD_ARG_FP64); + } else if (op->type.is_int() || op->type.is_uint()) { + this->PushOp(StackVM::TVM_LOAD_ARG_INT64); + } else { + LOG(FATAL) << "donot know how to handle type" << op->type; + } + } else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) { + CHECK_EQ(op->args.size(), 2U); + this->Push(op->args[0]); + switch (op->args[1].as()->value) { + case intrinsic::kData: PushOp(StackVM::TVM_ARRAY_GET_DATA); break; + case intrinsic::kShape: PushOp(StackVM::TVM_ARRAY_GET_SHAPE); break; + case intrinsic::kStrides: PushOp(StackVM::TVM_ARRAY_GET_STRIDES); break; + case intrinsic::kNDim: PushOp(StackVM::TVM_ARRAY_GET_NDIM); break; + case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break; + case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break; + case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break; + default: LOG(FATAL) << "unknown field code"; + } + } else if (op->is_intrinsic(intrinsic::tvm_call_global)) { + CHECK_GE(op->args.size(), 1U); + const StringImm* s = op->args[0].as(); + CHECK(s != nullptr) << "tvm_call_global expect first argument as function name"; + for (size_t i = 1; i < op->args.size(); ++i) { + this->Push(op->args[i]); + } + for (size_t i = 1; i < op->args.size(); ++i) { + Type t = op->args[i].type(); + int code = t.code(); + int bits = t.bits(); + int lanes = t.lanes(); + int tcode = (code << (8 * 3)) | (bits << 16) | lanes; + this->PushOp(StackVM::PUSH_I64, tcode); + } + int num_args = static_cast((op->args.size() - 1) * 2); + this->PushOp(StackVM::PUSH_I64, num_args); + this->PushOp(StackVM::CALL_EXTERN, GetGlobalFuncID(s->value)); + } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + CHECK_EQ(op->args.size(), 1U); + this->Push(op->args[0]); + this->PushOp(StackVM::PUSH_I64, 0); + this->PushOp(StackVM::EQ_I64); + } else { + this->HandleUnknownCall(op); + } +} + +void CodeGenStackVM::HandleUnknownCall(const ir::Call* op) { + LOG(FATAL) << "donot know how to handle call " << op->name; +} + +inline void PushBinary(StackVM::OpCode op_int64, + const Expr& a, + const Expr& b, + CodeGenStackVM* p) { + p->Push(a); + p->Push(b); + Type t = a.type(); + if (t.is_int()) { + p->PushOp(op_int64); + } else if (t.is_uint()) { + if (t.bits() <= 32) { + p->PushOp(op_int64); + } else { + LOG(FATAL) << "Cannot handle uint64_t in StackVM"; + } + } else { + p->PushOp(StackVM::CodeI64ToF64(op_int64)); + } +} + + +inline void PushCast(Type dst, + Type src, + CodeGenStackVM* p) { + if (dst.is_int()) { + if (src.is_int()) return; + if (src.is_uint() && src.bits() <= 32) return; + } else if (dst.is_uint() && dst.bits() <= 32) { + if (src.is_int()) return; + if (src.is_uint() && src.bits() <= 32) return; + } else if (dst.is_float()) { + if (src.is_float()) return; + } + LOG(FATAL) << "Cannot handle cast " << src << " to " << dst; +} + +TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) +.set_dispatch([](const StringImm *op, CodeGenStackVM *p) { + int sid = p->GetStrID(op->value); + p->PushOp(StackVM::PUSH_I64, sid); + }) +.set_dispatch([](const IntImm *op, CodeGenStackVM *p) { + CHECK(op->value >= std::numeric_limits::min() && + op->value <= std::numeric_limits::max()) + << "Int constant exceed bound"; + p->PushOp(StackVM::PUSH_I64, static_cast(op->value)); + }) +.set_dispatch([](const UIntImm *op, CodeGenStackVM *p) { + CHECK(op->value <= std::numeric_limits::max()) + << "Int constant exceed bound"; + p->PushOp(StackVM::PUSH_I64, static_cast(op->value)); + }) +.set_dispatch([](const FloatImm *op, CodeGenStackVM *p) { + LOG(FATAL) << "Float Imm is not supported"; + }); + +TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) +.set_dispatch([](const Variable *op, CodeGenStackVM* p) { + int vid = p->GetVarID(op); + p->PushOp(StackVM::LOAD_HEAP, vid); + }) +.set_dispatch([](const Cast *op, CodeGenStackVM* p) { + p->Push(op->value); + PushCast(op->type, op->value.type(), p); + }) +.set_dispatch([](const Add *op, CodeGenStackVM* p) { + PushBinary(StackVM::ADD_I64, op->a, op->b, p); + }) +.set_dispatch([](const Sub *op, CodeGenStackVM* p) { + PushBinary(StackVM::SUB_I64, op->a, op->b, p); + }) +.set_dispatch([](const Mul *op, CodeGenStackVM* p) { + PushBinary(StackVM::MUL_I64, op->a, op->b, p); + }) +.set_dispatch
([](const Div *op, CodeGenStackVM* p) { + PushBinary(StackVM::DIV_I64, op->a, op->b, p); + }) +.set_dispatch([](const Mod *op, CodeGenStackVM* p) { + PushBinary(StackVM::MOD_I64, op->a, op->b, p); + }) +.set_dispatch([](const Min *op, CodeGenStackVM* p) { + p->Push(op->a); + p->Push(op->b); + p->PushOp(StackVM::PUSH_VALUE, -1); + p->PushOp(StackVM::PUSH_VALUE, -1); + p->PushOp(StackVM::LT_I64); + p->PushOp(StackVM::SELECT); + }) +.set_dispatch([](const Max *op, CodeGenStackVM* p) { + p->Push(op->a); + p->Push(op->b); + p->PushOp(StackVM::PUSH_VALUE, 0); + p->PushOp(StackVM::PUSH_VALUE, -2); + p->PushOp(StackVM::LT_I64); + p->PushOp(StackVM::SELECT); + }) +.set_dispatch([](const EQ *op, CodeGenStackVM* p) { + PushBinary(StackVM::EQ_I64, op->a, op->b, p); + }) +.set_dispatch([](const LE *op, CodeGenStackVM* p) { + PushBinary(StackVM::LE_I64, op->a, op->b, p); + }) +.set_dispatch([](const NE *op, CodeGenStackVM* p) { + PushBinary(StackVM::EQ_I64, op->a, op->b, p); + p->PushOp(StackVM::NOT); + }) +.set_dispatch([](const LT *op, CodeGenStackVM* p) { + PushBinary(StackVM::LT_I64, op->a, op->b, p); + }) +.set_dispatch([](const GE *op, CodeGenStackVM* p) { + PushBinary(StackVM::LT_I64, op->a, op->b, p); + p->PushOp(StackVM::NOT); + }) +.set_dispatch([](const GT *op, CodeGenStackVM* p) { + PushBinary(StackVM::LE_I64, op->a, op->b, p); + p->PushOp(StackVM::NOT); + }) +.set_dispatch([](const And *op, CodeGenStackVM* p) { + p->Push(op->a); + int64_t pc_jump = p->GetPC(); + int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_FALSE, 0); + p->PushOp(StackVM::POP); + p->Push(op->b); + int64_t diff = p->GetPC() - pc_jump; + p->SetOperand(opr_index, diff); +}) +.set_dispatch([](const Or *op, CodeGenStackVM* p) { + p->Push(op->a); + int64_t pc_jump = p->GetPC(); + int64_t opr_index = p->PushOp(StackVM::RJUMP_IF_TRUE, 0); + p->Push(op->b); + int64_t diff = p->GetPC() - pc_jump; + p->SetOperand(opr_index, diff); +}) +.set_dispatch([](const Not* op, CodeGenStackVM* p) { + p->PushOp(StackVM::NOT); + }); + + +TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) +.set_dispatch([](const ProducerConsumer *op, CodeGenStackVM* p) { + p->Push(op->body); + }) +.set_dispatch([](const For *op, CodeGenStackVM* p) { + CHECK(is_zero(op->min)); + int vid = p->AllocVarID(op->loop_var.get()); + p->PushOp(StackVM::PUSH_I64, 0); + int64_t loop_head = p->GetPC(); + p->PushOp(StackVM::STORE_HEAP, vid); + p->PushOp(StackVM::LOAD_HEAP, vid); + p->Push(op->extent); + p->PushOp(StackVM::LT_I64); + int64_t label_fjump = p->GetPC(); + int64_t foward_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0); + p->PushOp(StackVM::POP); + p->Push(op->body); + p->PushOp(StackVM::LOAD_HEAP, vid); + p->PushOp(StackVM::PUSH_I64, 1); + p->PushOp(StackVM::ADD_I64); + int64_t label_bjump = p->GetPC(); + int64_t backward_jump = p->PushOp(StackVM::RJUMP, 0); + int64_t loop_end = p->GetPC(); + p->PushOp(StackVM::POP); + p->SetOperand(foward_jump, loop_end - label_fjump); + p->SetOperand(backward_jump, loop_head - label_bjump); + }) +.set_dispatch([](const Block *op, CodeGenStackVM* p) { + p->Push(op->first); + if (op->rest.defined()) p->Push(op->rest); + }) +.set_dispatch([](const Evaluate *op, CodeGenStackVM* p) { + if (is_const(op->value)) return; + p->Push(op->value); + p->PushOp(StackVM::POP); + }) +.set_dispatch([](const IfThenElse *op, CodeGenStackVM* p) { + p->Push(op->condition); + int64_t label_ejump = p->GetPC(); + int64_t else_jump = p->PushOp(StackVM::RJUMP_IF_FALSE, 0); + p->PushOp(StackVM::POP); + p->Push(op->then_case); + if (op->else_case.defined()) { + int64_t label_then_jump = p->GetPC(); + int64_t then_jump = p->PushOp(StackVM::RJUMP, 0); + int64_t else_begin = p->GetPC(); + p->SetOperand(else_jump, else_begin - label_ejump); + p->PushOp(StackVM::POP); + p->Push(op->else_case); + int64_t if_end = p->GetPC(); + p->SetOperand(then_jump, if_end - label_then_jump); + } else { + int64_t if_end = p->GetPC(); + p->SetOperand(else_jump, if_end - label_ejump); + p->PushOp(StackVM::POP); + } + }) +.set_dispatch([](const LetStmt *op, CodeGenStackVM* p) { + p->Push(op->value); + int64_t vid = p->AllocVarID(op->var.get()); + p->PushOp(StackVM::STORE_HEAP, vid); + p->Push(op->body); + }) +.set_dispatch([](const Ramp *op, CodeGenStackVM* p) { + LOG(FATAL) << "Ramp is not supported"; + }) +.set_dispatch([](const Broadcast *op, CodeGenStackVM* p) { + LOG(FATAL) << "Broadcast is not supported"; + }) +.set_dispatch