Skip to content

Commit

Permalink
[TIR][TRANSFORM] Return value support in tir.tvm_call_packed (apache#…
Browse files Browse the repository at this point in the history
…7932)

This PR fixes the return value support in tir.tvm_call_packed

- Clarified the semantics of the intrinsics
- Fix a problem when lowering call packed with nested scopes(let bindings)
- Added regression tests to cover the changes
  • Loading branch information
tqchen authored and trevor-m committed May 11, 2021
1 parent 7c177f3 commit f71ba44
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 49 deletions.
38 changes: 23 additions & 15 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,23 +334,27 @@ TVM_DLL const Op& tvm_stack_make_array();
/*!
* \brief See pesudo code
*
* int tvm_call_packed(name, TVMValue* args) {
* return_type tvm_call_packed(name, TVMValue* args) {
* TVMValue ret_value;
* int ret_code;
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* (*f)(args, type_code_of(args), len(args), &ret_value, &ret_code);
* // return type can be int, float, handle.
* return cast(return_type, ret_value.v_return_type);
* }
*/
TVM_DLL const Op& tvm_call_packed();

/*!
* \brief See pesudo code
*
* int tvm_call_trace_packed(name, TVMValue* args) {
* return_type tvm_call_trace_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* // return type can be int, float, handle.
* return cast(return_type, ret_value.v_return_type);
* }
*/
TVM_DLL const Op& tvm_call_trace_packed();
Expand All @@ -372,16 +376,18 @@ TVM_DLL const Op& tvm_thread_context();
* \brief Lowered version of call packed, the space of value and
* type codes are explicitly allocated.
*
* int tvm_call_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* return_type tvm_call_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* // return type can be int, float, handle.
* return cast(return_type, load_return_from(tcode_stack + end))
* }
*/
TVM_DLL const Op& tvm_call_packed_lowered();
Expand All @@ -391,16 +397,18 @@ TVM_DLL const Op& tvm_call_packed_lowered();
* type codes are explicitly allocated. The return value is the
* (end - 1) value on the stack.
*
* int tvm_call_trace_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* return_type tvm_call_trace_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* // return type can be int, float, handle.
* return cast(return_type, load_return_from(tcode_stack + end))
* }
*/
TVM_DLL const Op& tvm_call_trace_packed_lowered();
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,26 @@ def _exit_cb():

return WithScope(None, _exit_cb)

def let(self, var_name, value):
"""Create a new let stmt binding.
Parameters
----------
var_name : str
The name of the variable
value : PrimExpr
The value to be bound
Returns
-------
var : tvm.tir.Var
The var that can be in for future emits.
"""
var = _expr.Var(var_name, dtype=value.dtype)
self.emit(lambda x: _stmt.LetStmt(var, value, x))
return var

def allocate(self, dtype, shape, name="buf", scope=None):
"""Create a allocate statement.
Expand Down
77 changes: 46 additions & 31 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,19 @@ class BuiltinLower : public StmtExprMutator {
}

Stmt VisitStmt(const Stmt& s) final {
// allocate space to hold prepare stmts before s
prep_seq_stack_.emplace_back(std::vector<Stmt>());

auto stmt = StmtExprMutator::VisitStmt(s);
auto& scope = alloca_scope_.back();
ICHECK_EQ(scope.run_shape_stack, -1);
ICHECK_EQ(scope.run_array_stack, 0);

if (prep_seq_.size() != 0) {
Stmt ret = SeqStmt::Flatten(prep_seq_, stmt);
prep_seq_.clear();
auto prep_seq = std::move(prep_seq_stack_.back());
prep_seq_stack_.pop_back();

if (prep_seq.size() != 0) {
Stmt ret = SeqStmt::Flatten(prep_seq, stmt);
return ret;
} else {
return stmt;
Expand Down Expand Up @@ -192,6 +197,7 @@ class BuiltinLower : public StmtExprMutator {
// if args.size() == 0, it represents a scalar shape ()
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();
if (scope.run_shape_stack == -1) {
scope.run_shape_stack = 0;
}
Expand All @@ -201,57 +207,63 @@ class BuiltinLower : public StmtExprMutator {
op = expr.as<CallNode>();
// no need to perform any store for a scalar shape
for (size_t i = 0; i < op->args.size(); ++i) {
prep_seq_.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin + i), const_true(1)));
prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]),
ConstInt32(stack_begin + i), const_true(1)));
}
return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin);
}
// make array
PrimExpr MakeArray(const CallNode* op) {
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

size_t idx = scope.run_array_stack;
scope.run_array_stack += 1;
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));

prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0]));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1]));
PrimExpr strides = op->args[2];
if (!strides.defined() || is_zero(strides)) {
strides = make_zero(DataType::Handle());
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3]));
DataType dtype = op->args[4].dtype();
prep_seq_.emplace_back(
prep_seq.emplace_back(
TVMStructSet(scope.stack_array, idx, builtin::kArrTypeCode,
make_const(DataType::UInt(8), static_cast<int>(dtype.code()))));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
make_const(DataType::UInt(8), dtype.bits())));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
make_const(DataType::UInt(16), dtype.lanes())));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits,
make_const(DataType::UInt(8), dtype.bits())));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes,
make_const(DataType::UInt(16), dtype.lanes())));
// set byte offset
int data_bytes = GetVectorBytes(dtype);
PrimExpr byte_offset = op->args[5];
if (!is_zero(byte_offset)) {
byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes);
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
cast(DataType::UInt(64), byte_offset)));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset,
cast(DataType::UInt(64), byte_offset)));
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
ICHECK(device_id_.defined()) << "Unknown device id in current IR";
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
cast(DataType::Int(32), device_id_)));
prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
cast(DataType::Int(32), device_type_)));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId,
cast(DataType::Int(32), device_id_)));
prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType,
cast(DataType::Int(32), device_type_)));
return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr);
}
// call packed.
PrimExpr MakeCallPacked(const CallNode* op) {
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

int64_t restore_shape_stack = scope.run_shape_stack;
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;

scope.run_arg_stack += op->args.size();
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
Expand All @@ -264,15 +276,15 @@ class BuiltinLower : public StmtExprMutator {
if (t != api_type) {
arg = Cast(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kTVMStr;
}
if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
prep_seq_.emplace_back(
prep_seq.emplace_back(
Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
Expand All @@ -285,12 +297,15 @@ class BuiltinLower : public StmtExprMutator {
Array<PrimExpr> packed_args = {op->args[0], scope.stack_value, scope.stack_tcode,
ConstInt32(arg_stack_begin),
ConstInt32(arg_stack_begin + op->args.size() - 1)};
return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args);
// call_packed_lowered needs to do the type casting properly
return Call(op->dtype, builtin::tvm_call_packed_lowered(), packed_args);
}

PrimExpr MakeCallTracePacked(const CallNode* op) {
ICHECK(!alloca_scope_.empty());
auto& scope = alloca_scope_.back();
auto& prep_seq = prep_seq_stack_.back();

int64_t restore_shape_stack = scope.run_shape_stack;
size_t restore_array_stack = scope.run_array_stack;
size_t arg_stack_begin = scope.run_arg_stack;
Expand All @@ -307,12 +322,12 @@ class BuiltinLower : public StmtExprMutator {
if (t != api_type) {
arg = Cast(api_type, arg);
}
prep_seq_.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
prep_seq.emplace_back(TVMStructSet(scope.stack_value,
static_cast<int>(arg_stack_begin + i - 1),
builtin::kTVMValueContent, arg));
int arg_tcode = api_type.code();
ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
prep_seq_.emplace_back(
prep_seq.emplace_back(
Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1)));
}
// UPDATE stack value
Expand Down Expand Up @@ -344,8 +359,8 @@ class BuiltinLower : public StmtExprMutator {
return false;
}

// The prepration sequence to be emitted.
std::vector<Stmt> prep_seq_;
// The prepration sequence to be emitted before the current statement.
std::vector<std::vector<Stmt>> prep_seq_stack_;
PrimExpr device_type_;
PrimExpr device_id_;

Expand Down
40 changes: 37 additions & 3 deletions tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,45 @@ def check_packed_func(target="llvm"):
tvm.ir.assert_structural_equal(alloca_shape, expected_stmt, map_free_vars=True)


def test_packed_func():
def test_lower_packed_func():
check_packed_func("llvm")
check_packed_func("stackvm")


@tvm.testing.requires_llvm
def test_call_packed_return_non_i32():
# This call packed that return non i32 types
expected_value = np.array([1.2, 1.4], dtype="float32")

def packed_echo(value):
return tvm.tir.call_intrin(
value.dtype, tvm.ir.Op.get("tir.tvm_call_packed"), "testing.echo", value
)

def build_tir():
Ab = tvm.tir.decl_buffer((2,), "float32")
ib = tvm.tir.ir_builder.create()
Aptr = ib.buffer_ptr(Ab)
# return f32
# Aptr[0] = testing.echo(expected_value[0])
Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32"))
# return handle
# let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1]
Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject()))
ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1))

stmt = ib.get()
return tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test")
)

mod = build_tir()
f = tvm.build(mod, None, "llvm")
a = tvm.nd.array(np.zeros(2, dtype="float32"))
f(a)
tvm.testing.assert_allclose(a.asnumpy(), expected_value)


if __name__ == "__main__":
# Test cases for issue: https://github.com/apache/tvm/issues/7246
test_packed_func()
test_call_packed_return_non_i32()
test_lower_packed_func()

0 comments on commit f71ba44

Please sign in to comment.