[ADDON] Allow piggy back nvcc compiler and code #35

Merged: 1 commit, Feb 7, 2017
14 changes: 12 additions & 2 deletions include/tvm/runtime/c_runtime_api.h
@@ -51,8 +51,9 @@ typedef enum {
   kArrayHandle = 5U,
   kTVMType = 6U,
   kNodeHandle = 7U,
-  kStr = 8U,
-  kFuncHandle = 9U
+  kFuncHandle = 8U,
+  kStr = 9U,
+  kBytes = 10U
 } TVMTypeCode;
 
 /*!
@@ -86,6 +87,15 @@ typedef union {
   TVMType v_type;
 } TVMValue;
 
+/*!
+ * \brief Byte array type used to pass in byte arrays
+ *  when kBytes is used as the type code.
+ */
+typedef struct {
+  const char* data;
+  size_t size;
+} TVMByteArray;
+
 /*!
  * \brief The device type
  */
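For orientation, a minimal ctypes sketch (not part of the patch) of how a Python bytearray can be laid out to match the new struct; the explicit size field is what lets embedded NUL bytes cross the FFI boundary, which NUL-terminated char* strings cannot:

    import ctypes

    class TVMByteArray(ctypes.Structure):
        # Field order and widths must mirror the C struct above exactly.
        _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
                    ("size", ctypes.c_size_t)]

    payload = bytearray(b"\x00\x01 binary payload")  # embedded NUL survives
    arr = TVMByteArray()
    arr.data = ctypes.cast((ctypes.c_byte * len(payload)).from_buffer(payload),
                           ctypes.POINTER(ctypes.c_byte))
    arr.size = len(payload)
    assert ctypes.string_at(arr.data, arr.size) == bytes(payload)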
17 changes: 14 additions & 3 deletions include/tvm/runtime/packed_func.h
@@ -111,6 +111,12 @@ class PackedFunc {
    * \return reference to the registered function.
    */
   static const PackedFunc& GetGlobal(const std::string& name);
+  /*!
+   * \brief Whether the global function exists.
+   * \param name The name of the function.
+   * \return Whether the global function exists.
+   */
+  static bool GlobalExist(const std::string& name);
   /*!
    * \brief Get the names of currently registered global function.
    */
@@ -267,9 +273,13 @@ class TVMArgValue : public TVMPODValue_ {
   operator std::string() const {
     if (type_code_ == kTVMType) {
       return TVMType2String(operator TVMType());
+    } else if (type_code_ == kBytes) {
+      TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
+      return std::string(arr->data, arr->size);
+    } else {
+      TVM_CHECK_TYPE_CODE(type_code_, kStr);
+      return std::string(value_.v_str);
     }
-    TVM_CHECK_TYPE_CODE(type_code_, kStr);
-    return std::string(value_.v_str);
   }
   operator TVMType() const {
     if (type_code_ == kStr) {
@@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
   template<typename T>
   void Assign(const T& other) {
     switch (other.type_code()) {
-      case kStr: {
+      case kStr:
+      case kBytes: {
         SwitchToClass<std::string>(kStr, other);
         break;
       }
13 changes: 11 additions & 2 deletions python/tvm/_ctypes/_function.py
@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access
+# pylint: disable=invalid-name, protected-access, too-many-branches
 """Symbolic configuration API."""
 from __future__ import absolute_import
 
@@ -9,7 +9,7 @@
 
 from .._base import _LIB, check_call
 from .._base import c_str, py_str, string_types
-from ._types import TVMValue, TypeCode, TVMType
+from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
 from ._types import TVMPackedCFunc, TVMCFuncFinalizer
 from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
 from ._node import NodeBase, SliceBase, convert_to_node
@@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
         elif isinstance(arg, TVMType):
             values[i].v_str = c_str(str(arg))
             type_codes[i] = TypeCode.STR
+        elif isinstance(arg, bytearray):
+            arr = TVMByteArray()
+            arr.data = ctypes.cast(
+                (ctypes.c_byte * len(arg)).from_buffer(arg),
+                ctypes.POINTER(ctypes.c_byte))
+            arr.size = len(arg)
+            values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
+            temp_args.append(arr)
+            type_codes[i] = TypeCode.BYTES
         elif isinstance(arg, string_types):
             values[i].v_str = c_str(arg)
             type_codes[i] = TypeCode.STR
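A side note on lifetimes that the hunk above relies on: values[i].v_handle stores a raw address, so both the TVMByteArray struct and the buffer it points into must be kept alive (via temp_args) until the C call returns. A minimal sketch of the failure mode, with an illustrative helper name; TVMByteArray and TVMValue are the ctypes types from _types.py:

    import ctypes
    from tvm._ctypes._types import TVMByteArray, TVMValue

    def make_value_without_keepalive(buf):
        # BUG (deliberate, for illustration): nothing keeps 'arr' alive
        # after this function returns, so the stored address dangles.
        arr = TVMByteArray()
        arr.data = ctypes.cast((ctypes.c_byte * len(buf)).from_buffer(buf),
                               ctypes.POINTER(ctypes.c_byte))
        arr.size = len(buf)
        value = TVMValue()
        value.v_handle = ctypes.c_void_p(ctypes.addressof(arr))
        return value  # 'arr' is garbage once we return

    value = make_value_without_keepalive(bytearray(b"payload"))
    # Any C callee dereferencing value.v_handle now reads freed memory;
    # appending 'arr' to temp_args, as _make_tvm_args does, prevents this.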
30 changes: 25 additions & 5 deletions python/tvm/_ctypes/_types.py
@@ -18,8 +18,9 @@ class TypeCode(object):
     ARRAY_HANDLE = 5
     TVM_TYPE = 6
     NODE_HANDLE = 7
-    STR = 8
-    FUNC_HANDLE = 9
+    FUNC_HANDLE = 8
+    STR = 9
+    BYTES = 10
 
 def _api_type(code):
     """create a type accepted by API"""
@@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
                 ("v_handle", ctypes.c_void_p),
                 ("v_str", ctypes.c_char_p)]
 
+class TVMByteArray(ctypes.Structure):
+    """Byte array structure used in API calls"""
+    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
+                ("size", ctypes.c_size_t)]
+
 
 TVMPackedCFunc = ctypes.CFUNCTYPE(
     None,
@@ -110,20 +116,34 @@ def _return_handle(x):
         handle = ctypes.c_void_p(handle)
     return handle
 
+def _return_bytes(x):
+    """return bytes from byte array handle"""
+    handle = x.v_handle
+    if not isinstance(handle, ctypes.c_void_p):
+        handle = ctypes.c_void_p(handle)
+    arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
+    size = arr.size
+    res = bytearray(size)
+    rptr = (ctypes.c_byte * size).from_buffer(res)
+    if not ctypes.memmove(rptr, arr.data, size):
+        raise RuntimeError('memmove failed')
+    return res
 
 RETURN_SWITCH = {
     TypeCode.INT: lambda x: x.v_int64,
     TypeCode.FLOAT: lambda x: x.v_float64,
     TypeCode.HANDLE: _return_handle,
     TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str)
+    TypeCode.STR: lambda x: py_str(x.v_str),
+    TypeCode.BYTES: _return_bytes
 }
 
 
 C_TO_PY_ARG_SWITCH = {
     TypeCode.INT: lambda x: x.v_int64,
     TypeCode.FLOAT: lambda x: x.v_float64,
     TypeCode.HANDLE: _return_handle,
     TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str)
+    TypeCode.STR: lambda x: py_str(x.v_str),
+    TypeCode.BYTES: _return_bytes
 }
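To make the new return path concrete, a hedged round-trip sketch, assuming the module layout above: it builds a TVMValue whose v_handle points at a TVMByteArray and dispatches through RETURN_SWITCH, showing that _return_bytes hands back an independent copy because it memmoves into a fresh bytearray.

    import ctypes
    from tvm._ctypes._types import (
        TVMValue, TVMByteArray, TypeCode, RETURN_SWITCH)

    payload = bytearray(b"\x00ptx-or-cubin")        # embedded NUL is fine
    arr = TVMByteArray()
    arr.data = ctypes.cast((ctypes.c_byte * len(payload)).from_buffer(payload),
                           ctypes.POINTER(ctypes.c_byte))
    arr.size = len(payload)

    ret = TVMValue()
    ret.v_handle = ctypes.cast(ctypes.byref(arr), ctypes.c_void_p)
    out = RETURN_SWITCH[TypeCode.BYTES](ret)

    assert out == payload      # same contents ...
    payload[0] = 1
    assert out != payload      # ... but an independent, memmoved copy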
1 change: 1 addition & 0 deletions python/tvm/addon/__init__.py
@@ -0,0 +1 @@
"""Addon utilities to python"""
55 changes: 55 additions & 0 deletions python/tvm/addon/nvcc_compiler.py
@@ -0,0 +1,55 @@
+"""Util to compile with NVCC"""
+import os
+import sys
+import tempfile
+import subprocess
+
+def compile_source(code, target="cubin"):
+    """Compile CUDA code with NVCC from the environment.
+
+    Parameters
+    ----------
+    code : str
+        The CUDA code.
+
+    target : str
+        The target format, one of "cubin", "ptx", "fatbin".
+
+    Returns
+    -------
+    cubin : bytearray
+        The bytearray of the cubin, or None if compilation failed.
+    """
+    temp_dir = tempfile.mkdtemp()
+    if target not in ["cubin", "ptx", "fatbin"]:
+        raise ValueError("target must be in cubin, ptx, fatbin")
+    path_code = os.path.join(temp_dir, "my_kernel.cu")
+    path_target = os.path.join(temp_dir, "my_kernel.%s" % target)
+
+    with open(path_code, "w") as out_file:
+        out_file.write(code)
+
+    cmd = ["nvcc"]
+    cmd += ["--%s" % target, "-O3"]
+    cmd += ["-o", path_target]
+    cmd += [path_code]
+    args = ' '.join(cmd)
+
+    proc = subprocess.Popen(
+        args, shell=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT)
+    (out, _) = proc.communicate()
+
+    if proc.returncode != 0:
+        sys.stderr.write("Compilation error:\n")
+        sys.stderr.write(out.decode("utf-8"))
+        sys.stderr.flush()
+        cubin = None
+    else:
+        cubin = bytearray(open(path_target, "rb").read())
+    os.remove(path_code)
+    if os.path.exists(path_target):
+        os.remove(path_target)
+    os.rmdir(temp_dir)
+    return cubin
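A usage sketch for the helper above, assuming nvcc is on PATH and a CUDA toolchain is installed; the kernel source is illustrative only:

    from tvm.addon import nvcc_compiler

    kernel = r"""
    extern "C" __global__ void vadd(const float* a, const float* b,
                                    float* c, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) c[i] = a[i] + b[i];
    }
    """
    ptx = nvcc_compiler.compile_source(kernel, target="ptx")
    if ptx is None:
        raise RuntimeError("nvcc failed; its output went to stderr")
    print(ptx.decode("utf-8"))   # PTX is text; cubin/fatbin are binary blobs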
13 changes: 9 additions & 4 deletions src/arithmetic/canonical.cc
@@ -158,7 +158,8 @@ class Canonical::Internal : public IRMutator {
   }
   // functions
   Stmt Mutate(Stmt stmt) final {
-    return IRMutator::Mutate(stmt);
+    stmt = IRMutator::Mutate(stmt);
+    return stmt;
   }
   Expr MutateExpr_(Expr expr) {
     static const FMutateExpr& f = Internal::vtable_expr();
@@ -176,6 +177,7 @@
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
     ret_entry_.max_level = stack_.back().max_level;
     stack_.pop_back();
+    CHECK(expr.defined());
     return expr;
   }
   // call produce to get a cache entry.
@@ -399,6 +401,7 @@ class Canonical::Internal : public IRMutator {
   // subroutine to do produce
   Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
     ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
+    CHECK_NE(stack_.size(), 0U);
     ret_entry_.max_level = stack_.back().max_level;
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
     auto it = cache_sum_.find(ret_entry_.sum);
@@ -408,8 +411,6 @@
       ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
       cache_sum_[ret_entry_.sum] = ret_entry_;
     }
-    ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
-    cache_sum_[ret_entry_.sum] = ret_entry_;
     return ret_entry_.value;
   }
   // convert sum to expr
@@ -444,7 +445,11 @@
         }
       }
     }
-    return vsum;
+    if (vsum.defined()) {
+      return vsum;
+    } else {
+      return make_zero(t);
+    }
   }
 };
 
14 changes: 13 additions & 1 deletion src/codegen/codegen_cuda.cc
@@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
     os << CodeGenCUDA().Compile(f, output_ssa);
     os << '\n';
   }
-  std::string ptx = runtime::NVRTCCompile(os.str());
+  std::string code = os.str();
+
+  if (PackedFunc::GlobalExist("tvm_callback_cuda_postproc")) {
+    const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
+    code = f(code).operator std::string();
+  }
+  std::string ptx;
+  if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
+    const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
+    ptx = f(code).operator std::string();
+  } else {
+    ptx = runtime::NVRTCCompile(code);
+  }
   std::unordered_map<LoweredFunc, PackedFunc> ret;
 
   runtime::CUDAModule m = runtime::CUDAModule::Create(ptx);
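Both hooks are looked up by name in the global function registry, so they can be supplied from Python simply by registering functions with matching names; a sketch of the override pattern (the same pattern test_gemm.py uses below):

    import tvm
    from tvm.addon import nvcc_compiler

    @tvm.register_func
    def tvm_callback_cuda_postproc(code):
        # Runs first: may rewrite the generated CUDA source before compilation.
        return "// user-patched\n" + code

    @tvm.register_func
    def tvm_callback_cuda_compile(code):
        # Runs second: replaces the default NVRTC path with an external nvcc.
        return nvcc_compiler.compile_source(code, target="ptx")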
6 changes: 6 additions & 0 deletions src/runtime/packed_func_registry.cc
@@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
   return *(it->second);
 }
 
+bool PackedFunc::GlobalExist(const std::string& name) {
+  PackedFuncRegistry* r = PackedFuncRegistry::Global();
+  auto it = r->fmap.find(name);
+  return it != r->fmap.end();
+}
+
 std::vector<std::string> PackedFunc::ListGlobalNames() {
   PackedFuncRegistry* r = PackedFuncRegistry::Global();
   std::vector<std::string> keys;
22 changes: 16 additions & 6 deletions tests/python/integration/test_gemm.py
@@ -1,6 +1,18 @@
 import tvm
+from tvm.addon import nvcc_compiler
 import numpy as np
 
+@tvm.register_func
+def tvm_callback_cuda_compile(code):
+    ptx = nvcc_compiler.compile_source(code, target="ptx")
+    print(ptx.decode("utf-8"))
+    return ptx
+
+@tvm.register_func
+def tvm_callback_cuda_postproc(code):
+    print(code)
+    return code
+
 def test_gemm():
     # graph
     nn = 1024
@@ -23,7 +35,6 @@ def test_gemm():
     s = tvm.Schedule(C.op)
     xtile, ytile = 32, 32
     s[AA].set_scope("shared")
-    #s[CC].set_scope("global")
     s[BB].set_scope("shared")
 
     scale = 8
@@ -60,8 +71,6 @@ def check_device(target):
         codes = []
         f = tvm.build(s, [A, B, C], target, record_codes=codes,
                       max_auto_unroll_step=max_auto_unroll_step)
-        for c in codes[1:]:
-            print(c)
         if target == "cuda":
             ctx = tvm.gpu(0)
         else:
@@ -77,13 +86,14 @@ def check_device(target):
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
-        f(a, b, c)
+        for i in range(4):
+            f(a, b, c)
         np.testing.assert_allclose(
             c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
 
-    tvm.init_opencl()
     check_device("cuda")
-    check_device("opencl")
+    #tvm.init_opencl()
+    #check_device("opencl")
 
 if __name__ == "__main__":
     test_gemm()
10 changes: 9 additions & 1 deletion tests/python/unittest/test_runtime_packed_func.py
@@ -35,9 +35,17 @@ def myfunc(*args):
     assert isinstance(f, tvm.nd.Function)
     f(*targs)
 
+def test_byte_array():
+    s = "hello"
+    a = bytearray(s, encoding="ascii")
+
+    def myfunc(ss):
+        assert ss == a
+    f = tvm.convert(myfunc)
+    f(a)
+
 if __name__ == "__main__":
     test_function()
     test_convert()
     test_get_global()
     test_return_func()
+    test_byte_array()