diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index 5151fb5392ee..92eda1f9021c 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -51,8 +51,9 @@ typedef enum {
   kArrayHandle = 5U,
   kTVMType = 6U,
   kNodeHandle = 7U,
-  kStr = 8U,
-  kFuncHandle = 9U
+  kFuncHandle = 8U,
+  kStr = 9U,
+  kBytes = 10U
 } TVMTypeCode;
 
 /*!
@@ -86,6 +87,15 @@ typedef union {
   TVMType v_type;
 } TVMValue;
 
+/*!
+ * \brief Byte array type used to pass in byte arrays,
+ *  when kBytes is used as the type code.
+ */
+typedef struct {
+  const char* data;
+  size_t size;
+} TVMByteArray;
+
 /*!
  * \brief The device type
  */
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 3b1921ee8868..bcc5fffc43b2 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -111,6 +111,12 @@ class PackedFunc {
    * \return reference to the registered function.
    */
   static const PackedFunc& GetGlobal(const std::string& name);
+  /*!
+   * \brief Whether the global function exists.
+   * \param name The name of the function.
+   * \return Whether the global function exists.
+   */
+  static bool GlobalExist(const std::string& name);
   /*!
    * \brief Get the names of currently registered global function.
    */
@@ -267,9 +273,13 @@ class TVMArgValue : public TVMPODValue_ {
   operator std::string() const {
     if (type_code_ == kTVMType) {
       return TVMType2String(operator TVMType());
+    } else if (type_code_ == kBytes) {
+      TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
+      return std::string(arr->data, arr->size);
+    } else {
+      TVM_CHECK_TYPE_CODE(type_code_, kStr);
+      return std::string(value_.v_str);
     }
-    TVM_CHECK_TYPE_CODE(type_code_, kStr);
-    return std::string(value_.v_str);
   }
   operator TVMType() const {
     if (type_code_ == kStr) {
@@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
   template<typename T>
   void Assign(const T& other) {
     switch (other.type_code()) {
-      case kStr: {
+      case kStr:
+      case kBytes: {
         SwitchToClass<std::string>(kStr, other);
         break;
       }
diff --git a/python/tvm/_ctypes/_function.py b/python/tvm/_ctypes/_function.py
index e8377fe4bdfe..6393d1fed58d 100644
--- a/python/tvm/_ctypes/_function.py
+++ b/python/tvm/_ctypes/_function.py
@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access
+# pylint: disable=invalid-name, protected-access, too-many-branches
 """Symbolic configuration API."""
 from __future__ import absolute_import
 
@@ -9,7 +9,7 @@
 
 from .._base import _LIB, check_call
 from .._base import c_str, py_str, string_types
-from ._types import TVMValue, TypeCode, TVMType
+from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
 from ._types import TVMPackedCFunc, TVMCFuncFinalizer
 from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
 from ._node import NodeBase, SliceBase, convert_to_node
@@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
         elif isinstance(arg, TVMType):
             values[i].v_str = c_str(str(arg))
             type_codes[i] = TypeCode.STR
+        elif isinstance(arg, bytearray):
+            arr = TVMByteArray()
+            arr.data = ctypes.cast(
+                (ctypes.c_byte * len(arg)).from_buffer(arg),
+                ctypes.POINTER(ctypes.c_byte))
+            arr.size = len(arg)
+            values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
+            temp_args.append(arr)
+            type_codes[i] = TypeCode.BYTES
         elif isinstance(arg, string_types):
             values[i].v_str = c_str(arg)
             type_codes[i] = TypeCode.STR
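The new kBytes type code lets binary blobs (such as compiled PTX or cubin) travel through PackedFunc arguments without being truncated at the first NUL byte. Not part of the patch: a minimal sketch of the Python-side behaviour this enables, mirroring test_runtime_packed_func.py further down; the function name consume is made up.

```python
import tvm

blob = bytearray(b"\x00\x01\x02\x03")

def consume(data):
    # The bytearray argument is marshalled through TVMByteArray /
    # TypeCode.BYTES and arrives on the callee side as a bytearray again.
    assert isinstance(data, bytearray)
    assert data == blob

f = tvm.convert(consume)  # wrap the Python callable as a packed function
f(blob)
```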
diff --git a/python/tvm/_ctypes/_types.py b/python/tvm/_ctypes/_types.py
index 58bdf111e54c..33d0e8fe0737 100644
--- a/python/tvm/_ctypes/_types.py
+++ b/python/tvm/_ctypes/_types.py
@@ -18,8 +18,9 @@ class TypeCode(object):
     ARRAY_HANDLE = 5
     TVM_TYPE = 6
     NODE_HANDLE = 7
-    STR = 8
-    FUNC_HANDLE = 9
+    FUNC_HANDLE = 8
+    STR = 9
+    BYTES = 10
 
 def _api_type(code):
     """create a type accepted by API"""
@@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
                 ("v_handle", ctypes.c_void_p),
                 ("v_str", ctypes.c_char_p)]
 
+class TVMByteArray(ctypes.Structure):
+    """Byte array structure, mirrors TVMByteArray in the C API"""
+    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
+                ("size", ctypes.c_size_t)]
+
 
 TVMPackedCFunc = ctypes.CFUNCTYPE(
     None,
@@ -110,20 +116,34 @@ def _return_handle(x):
         handle = ctypes.c_void_p(handle)
     return handle
 
+def _return_bytes(x):
+    """return the content of a TVMByteArray as a bytearray"""
+    handle = x.v_handle
+    if not isinstance(handle, ctypes.c_void_p):
+        handle = ctypes.c_void_p(handle)
+    arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
+    size = arr.size
+    res = bytearray(size)
+    rptr = (ctypes.c_byte * size).from_buffer(res)
+    if not ctypes.memmove(rptr, arr.data, size):
+        raise RuntimeError('memmove failed')
+    return res
+
 RETURN_SWITCH = {
     TypeCode.INT: lambda x: x.v_int64,
     TypeCode.FLOAT: lambda x: x.v_float64,
     TypeCode.HANDLE: _return_handle,
     TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str)
+    TypeCode.STR: lambda x: py_str(x.v_str),
+    TypeCode.BYTES: _return_bytes
 }
-
 C_TO_PY_ARG_SWITCH = {
     TypeCode.INT: lambda x: x.v_int64,
     TypeCode.FLOAT: lambda x: x.v_float64,
     TypeCode.HANDLE: _return_handle,
     TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str)
+    TypeCode.STR: lambda x: py_str(x.v_str),
+    TypeCode.BYTES: _return_bytes
 }
diff --git a/python/tvm/addon/__init__.py b/python/tvm/addon/__init__.py
new file mode 100644
index 000000000000..e0759cf544e2
--- /dev/null
+++ b/python/tvm/addon/__init__.py
@@ -0,0 +1 @@
+"""Addon utilities for Python"""
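Not part of the patch: a self-contained ctypes sketch of the marshalling that _make_tvm_args and _return_bytes perform, using only the struct layout defined above; the name payload is illustrative.

```python
import ctypes

class TVMByteArray(ctypes.Structure):
    # Same layout as the struct added to _types.py / c_runtime_api.h.
    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
                ("size", ctypes.c_size_t)]

payload = bytearray(b"hello")

# Argument path (_make_tvm_args): view the bytearray without copying.
arr = TVMByteArray()
arr.data = ctypes.cast((ctypes.c_byte * len(payload)).from_buffer(payload),
                       ctypes.POINTER(ctypes.c_byte))
arr.size = len(payload)

# Return path (_return_bytes): copy the viewed bytes into a fresh bytearray.
out = bytearray(arr.size)
ctypes.memmove((ctypes.c_byte * arr.size).from_buffer(out), arr.data, arr.size)
assert out == payload
```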
diff --git a/python/tvm/addon/nvcc_compiler.py b/python/tvm/addon/nvcc_compiler.py
new file mode 100644
index 000000000000..a1c2b938d58c
--- /dev/null
+++ b/python/tvm/addon/nvcc_compiler.py
@@ -0,0 +1,55 @@
+"""Utility to compile CUDA code with NVCC"""
+import os
+import sys
+import tempfile
+import subprocess
+
+def compile_source(code, target="cubin"):
+    """Compile CUDA code with NVCC from the environment.
+
+    Parameters
+    ----------
+    code : str
+        The CUDA code.
+
+    target : str
+        The target format ("cubin", "ptx" or "fatbin").
+
+    Returns
+    -------
+    cubin : bytearray
+        The bytearray of the compiled output, or None on failure.
+    """
+    temp_dir = tempfile.mkdtemp()
+    if target not in ["cubin", "ptx", "fatbin"]:
+        raise ValueError("target must be in cubin, ptx, fatbin")
+    path_code = os.path.join(temp_dir, "my_kernel.cu")
+    path_target = os.path.join(temp_dir, "my_kernel.%s" % target)
+
+    with open(path_code, "w") as out_file:
+        out_file.write(code)
+
+    cmd = ["nvcc"]
+    cmd += ["--%s" % target, "-O3"]
+    cmd += ["-o", path_target]
+    cmd += [path_code]
+    args = ' '.join(cmd)
+
+    proc = subprocess.Popen(
+        args, shell=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT)
+    (out, _) = proc.communicate()
+
+    if proc.returncode != 0:
+        sys.stderr.write("Compilation error:\n")
+        sys.stderr.write(out)
+        sys.stderr.flush()
+        cubin = None
+    else:
+        cubin = bytearray(open(path_target, "rb").read())
+    os.remove(path_code)
+    if os.path.exists(path_target):
+        os.remove(path_target)
+    os.rmdir(temp_dir)
+    return cubin
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 2c99094551c5..1bc28374418f 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -158,7 +158,8 @@ class Canonical::Internal : public IRMutator {
   }
   // functions
   Stmt Mutate(Stmt stmt) final {
-    return IRMutator::Mutate(stmt);
+    stmt = IRMutator::Mutate(stmt);
+    return stmt;
   }
   Expr MutateExpr_(Expr expr) {
     static const FMutateExpr& f = Internal::vtable_expr();
@@ -176,6 +177,7 @@ class Canonical::Internal : public IRMutator {
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
     ret_entry_.max_level = stack_.back().max_level;
     stack_.pop_back();
+    CHECK(expr.defined());
     return expr;
   }
   // call produce to get a cache entry.
@@ -399,6 +401,7 @@ class Canonical::Internal : public IRMutator {
   // subroutine to do produce
   Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
     ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
+    CHECK_NE(stack_.size(), 0U);
     ret_entry_.max_level = stack_.back().max_level;
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
     auto it = cache_sum_.find(ret_entry_.sum);
@@ -408,8 +411,6 @@ class Canonical::Internal : public IRMutator {
       ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
       cache_sum_[ret_entry_.sum] = ret_entry_;
     }
-    ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
-    cache_sum_[ret_entry_.sum] = ret_entry_;
     return ret_entry_.value;
   }
   // convert sum to expr
@@ -444,7 +445,11 @@ class Canonical::Internal : public IRMutator {
        }
      }
    }
-    return vsum;
+    if (vsum.defined()) {
+      return vsum;
+    } else {
+      return make_zero(t);
+    }
   }
 };
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 9098200bdf27..06a8762a8244 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
     os << CodeGenCUDA().Compile(f, output_ssa);
     os << '\n';
   }
-  std::string ptx = runtime::NVRTCCompile(os.str());
+  std::string code = os.str();
+
+  if (PackedFunc::GlobalExist("tvm_callback_cuda_postproc")) {
+    const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
+    code = f(code).operator std::string();
+  }
+  std::string ptx;
+  if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
+    const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
+    ptx = f(code).operator std::string();
+  } else {
+    ptx = runtime::NVRTCCompile(code);
+  }
   std::unordered_map ret;
   runtime::CUDAModule m = runtime::CUDAModule::Create(ptx);
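Not part of the patch: a quick smoke test of the new compile_source helper. It assumes nvcc is on PATH; the kernel source and its name add_one are made up for illustration.

```python
from tvm.addon import nvcc_compiler

code = r'''
extern "C" __global__ void add_one(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] += 1.0f;
}
'''
ptx = nvcc_compiler.compile_source(code, target="ptx")
if ptx is None:
    print("nvcc failed, see stderr for the compiler log")
else:
    print(ptx.decode("utf-8"))
```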
diff --git a/src/runtime/packed_func_registry.cc b/src/runtime/packed_func_registry.cc
index 8a9c60944d97..b3f2ff371a5a 100644
--- a/src/runtime/packed_func_registry.cc
+++ b/src/runtime/packed_func_registry.cc
@@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
   return *(it->second);
 }
 
+bool PackedFunc::GlobalExist(const std::string& name) {
+  PackedFuncRegistry* r = PackedFuncRegistry::Global();
+  auto it = r->fmap.find(name);
+  return it != r->fmap.end();
+}
+
 std::vector<std::string> PackedFunc::ListGlobalNames() {
   PackedFuncRegistry* r = PackedFuncRegistry::Global();
   std::vector<std::string> keys;
diff --git a/tests/python/integration/test_gemm.py b/tests/python/integration/test_gemm.py
index 8b63d8c08e4c..feab008f08c4 100644
--- a/tests/python/integration/test_gemm.py
+++ b/tests/python/integration/test_gemm.py
@@ -1,6 +1,18 @@
 import tvm
+from tvm.addon import nvcc_compiler
 import numpy as np
 
+@tvm.register_func
+def tvm_callback_cuda_compile(code):
+    ptx = nvcc_compiler.compile_source(code, target="ptx")
+    print(ptx.decode("utf-8"))
+    return ptx
+
+@tvm.register_func
+def tvm_callback_cuda_postproc(code):
+    print(code)
+    return code
+
 def test_gemm():
     # graph
     nn = 1024
@@ -23,7 +35,6 @@ def test_gemm():
     s = tvm.Schedule(C.op)
     xtile, ytile = 32, 32
     s[AA].set_scope("shared")
-    #s[CC].set_scope("global")
     s[BB].set_scope("shared")
 
     scale = 8
@@ -60,8 +71,6 @@ def check_device(target):
         codes = []
         f = tvm.build(s, [A, B, C], target, record_codes=codes,
                       max_auto_unroll_step=max_auto_unroll_step)
-        for c in codes[1:]:
-            print(c)
         if target == "cuda":
             ctx = tvm.gpu(0)
         else:
@@ -77,13 +86,14 @@ def check_device(target):
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
-        f(a, b, c)
+        for i in range(4):
+            f(a, b, c)
         np.testing.assert_allclose(
             c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
 
-    tvm.init_opencl()
     check_device("cuda")
-    check_device("opencl")
+    #tvm.init_opencl()
+    #check_device("opencl")
 
 if __name__ == "__main__":
     test_gemm()
diff --git a/tests/python/unittest/test_runtime_packed_func.py b/tests/python/unittest/test_runtime_packed_func.py
index bd9eada79b6d..6aa8dd8959c7 100644
--- a/tests/python/unittest/test_runtime_packed_func.py
+++ b/tests/python/unittest/test_runtime_packed_func.py
@@ -35,9 +35,17 @@ def myfunc(*args):
     assert isinstance(f, tvm.nd.Function)
     f(*targs)
 
+def test_byte_array():
+    s = "hello"
+    a = bytearray(s, encoding="ascii")
+
+    def myfunc(ss):
+        assert ss == a
+    f = tvm.convert(myfunc)
+    f(a)
 
 if __name__ == "__main__":
-    test_function()
     test_convert()
     test_get_global()
     test_return_func()
+    test_byte_array()
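Not part of the patch: a variation on the hooks registered in test_gemm.py above, showing how the two callbacks are meant to be used. The postproc hook can dump or rewrite the generated CUDA source before compilation, and the compile hook can swap NVRTC for nvcc; the output file name is arbitrary.

```python
import tvm
from tvm.addon import nvcc_compiler

@tvm.register_func
def tvm_callback_cuda_postproc(code):
    # Inspect (or rewrite) the generated CUDA source before it is compiled.
    with open("generated_kernel.cu", "w") as f:
        f.write(code)
    return code

@tvm.register_func
def tvm_callback_cuda_compile(code):
    # Hand compilation to nvcc; the returned bytearray flows back to C++
    # through the new kBytes/TVMByteArray path.
    return nvcc_compiler.compile_source(code, target="ptx")
```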