diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 6a40d86b8984..d8248d4e1a87 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -334,11 +334,14 @@ TVM_DLL const Op& tvm_stack_make_array(); /*! * \brief See pesudo code * - * int tvm_call_packed(name, TVMValue* args) { + * return_type tvm_call_packed(name, TVMValue* args) { + * TVMValue ret_value; + * int ret_code; * ModuleNode* env = GetCurrentEnv(); * const PackedFunc* f = env->GetFuncFromEnv(name); - * (*f)(args, type_code_of(args), len(args)); - * return 0; + * (*f)(args, type_code_of(args), len(args), &ret_value, &ret_code); + * // return type can be int, float, handle. + * return cast(return_type, ret_value.v_return_type); * } */ TVM_DLL const Op& tvm_call_packed(); @@ -346,11 +349,12 @@ TVM_DLL const Op& tvm_call_packed(); /*! * \brief See pesudo code * - * int tvm_call_trace_packed(name, TVMValue* args) { + * return_type tvm_call_trace_packed(name, TVMValue* args) { * ModuleNode* env = GetCurrentEnv(); * const PackedFunc* f = env->GetFuncFromEnv(name); * (*f)(args, type_code_of(args), len(args)); - * return 0; + * // return type can be int, float, handle. + * return cast(return_type, ret_value.v_return_type); * } */ TVM_DLL const Op& tvm_call_trace_packed(); @@ -372,16 +376,18 @@ TVM_DLL const Op& tvm_thread_context(); * \brief Lowered version of call packed, the space of value and * type codes are explicitly allocated. * - * int tvm_call_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, - * int begin, - * int end) { + * return_type tvm_call_packed_lowered(name, + * TVMValue* value_stack, + * int* tcode_stack, + * int begin, + * int end) { * ModuleNode* env = GetCurrentEnv(); * const PackedFunc* f = env->GetFuncFromEnv(name); * f->CallPacked(TVMArgs(value_stack[begin:end], * tcode_stack[begin:end]), * TVMRetValue(value_stack + end, tcode_stack + end)); + * // return type can be int, float, handle. + * return cast(return_type, load_return_from(tcode_stack + end)) * } */ TVM_DLL const Op& tvm_call_packed_lowered(); @@ -391,16 +397,18 @@ TVM_DLL const Op& tvm_call_packed_lowered(); * type codes are explicitly allocated. The return value is the * (end - 1) value on the stack. * - * int tvm_call_trace_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, - * int begin, - * int end) { + * return_type tvm_call_trace_packed_lowered(name, + * TVMValue* value_stack, + * int* tcode_stack, + * int begin, + * int end) { * ModuleNode* env = GetCurrentEnv(); * const PackedFunc* f = env->GetFuncFromEnv(name); * f->CallPacked(TVMArgs(value_stack[begin:end], * tcode_stack[begin:end]), * TVMRetValue(value_stack + end, tcode_stack + end)); + * // return type can be int, float, handle. + * return cast(return_type, load_return_from(tcode_stack + end)) * } */ TVM_DLL const Op& tvm_call_trace_packed_lowered(); diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 2ecbdeda8371..4934bf04727f 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -374,6 +374,26 @@ def _exit_cb(): return WithScope(None, _exit_cb) + def let(self, var_name, value): + """Create a new let stmt binding. + + Parameters + ---------- + var_name : str + The name of the variable + + value : PrimExpr + The value to be bound + + Returns + ------- + var : tvm.tir.Var + The var that can be in for future emits. + """ + var = _expr.Var(var_name, dtype=value.dtype) + self.emit(lambda x: _stmt.LetStmt(var, value, x)) + return var + def allocate(self, dtype, shape, name="buf", scope=None): """Create a allocate statement. diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index c40fd7edfdc2..8d2857ef7a40 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -89,14 +89,19 @@ class BuiltinLower : public StmtExprMutator { } Stmt VisitStmt(const Stmt& s) final { + // allocate space to hold prepare stmts before s + prep_seq_stack_.emplace_back(std::vector()); + auto stmt = StmtExprMutator::VisitStmt(s); auto& scope = alloca_scope_.back(); ICHECK_EQ(scope.run_shape_stack, -1); ICHECK_EQ(scope.run_array_stack, 0); - if (prep_seq_.size() != 0) { - Stmt ret = SeqStmt::Flatten(prep_seq_, stmt); - prep_seq_.clear(); + auto prep_seq = std::move(prep_seq_stack_.back()); + prep_seq_stack_.pop_back(); + + if (prep_seq.size() != 0) { + Stmt ret = SeqStmt::Flatten(prep_seq, stmt); return ret; } else { return stmt; @@ -192,6 +197,7 @@ class BuiltinLower : public StmtExprMutator { // if args.size() == 0, it represents a scalar shape () ICHECK(!alloca_scope_.empty()); auto& scope = alloca_scope_.back(); + auto& prep_seq = prep_seq_stack_.back(); if (scope.run_shape_stack == -1) { scope.run_shape_stack = 0; } @@ -201,8 +207,8 @@ class BuiltinLower : public StmtExprMutator { op = expr.as(); // no need to perform any store for a scalar shape for (size_t i = 0; i < op->args.size(); ++i) { - prep_seq_.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]), - ConstInt32(stack_begin + i), const_true(1))); + prep_seq.emplace_back(Store(scope.stack_shape, cast(DataType::Int(64), op->args[i]), + ConstInt32(stack_begin + i), const_true(1))); } return AddressOffset(scope.stack_shape, DataType::Int(64), stack_begin); } @@ -210,48 +216,54 @@ class BuiltinLower : public StmtExprMutator { PrimExpr MakeArray(const CallNode* op) { ICHECK(!alloca_scope_.empty()); auto& scope = alloca_scope_.back(); + auto& prep_seq = prep_seq_stack_.back(); + size_t idx = scope.run_array_stack; scope.run_array_stack += 1; PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0])); - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1])); + + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrData, op->args[0])); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrShape, op->args[1])); PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides)); - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3])); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrStrides, strides)); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrNDim, op->args[3])); DataType dtype = op->args[4].dtype(); - prep_seq_.emplace_back( + prep_seq.emplace_back( TVMStructSet(scope.stack_array, idx, builtin::kArrTypeCode, make_const(DataType::UInt(8), static_cast(dtype.code())))); - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits, - make_const(DataType::UInt(8), dtype.bits()))); - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes, - make_const(DataType::UInt(16), dtype.lanes()))); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeBits, + make_const(DataType::UInt(8), dtype.bits()))); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrTypeLanes, + make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); PrimExpr byte_offset = op->args[5]; if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset, - cast(DataType::UInt(64), byte_offset))); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrByteOffset, + cast(DataType::UInt(64), byte_offset))); ICHECK(device_type_.defined()) << "Unknown device type in current IR"; ICHECK(device_id_.defined()) << "Unknown device id in current IR"; - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId, - cast(DataType::Int(32), device_id_))); - prep_seq_.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType, - cast(DataType::Int(32), device_type_))); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceId, + cast(DataType::Int(32), device_id_))); + prep_seq.emplace_back(TVMStructSet(scope.stack_array, idx, builtin::kArrDeviceType, + cast(DataType::Int(32), device_type_))); return TVMStructGet(DataType::Handle(), scope.stack_array, idx, builtin::kArrAddr); } // call packed. PrimExpr MakeCallPacked(const CallNode* op) { auto& scope = alloca_scope_.back(); + auto& prep_seq = prep_seq_stack_.back(); + int64_t restore_shape_stack = scope.run_shape_stack; size_t restore_array_stack = scope.run_array_stack; size_t arg_stack_begin = scope.run_arg_stack; + scope.run_arg_stack += op->args.size(); // Specially handle the buffer packed intrinsic PrimExpr expr = StmtExprMutator::VisitExpr_(op); @@ -264,15 +276,15 @@ class BuiltinLower : public StmtExprMutator { if (t != api_type) { arg = Cast(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); + prep_seq.emplace_back(TVMStructSet(scope.stack_value, + static_cast(arg_stack_begin + i - 1), + builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); if (api_type.is_handle() && arg.as()) { arg_tcode = kTVMStr; } if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; - prep_seq_.emplace_back( + prep_seq.emplace_back( Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value @@ -285,12 +297,15 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], scope.stack_value, scope.stack_tcode, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args); + // call_packed_lowered needs to do the type casting properly + return Call(op->dtype, builtin::tvm_call_packed_lowered(), packed_args); } PrimExpr MakeCallTracePacked(const CallNode* op) { ICHECK(!alloca_scope_.empty()); auto& scope = alloca_scope_.back(); + auto& prep_seq = prep_seq_stack_.back(); + int64_t restore_shape_stack = scope.run_shape_stack; size_t restore_array_stack = scope.run_array_stack; size_t arg_stack_begin = scope.run_arg_stack; @@ -307,12 +322,12 @@ class BuiltinLower : public StmtExprMutator { if (t != api_type) { arg = Cast(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet(scope.stack_value, - static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); + prep_seq.emplace_back(TVMStructSet(scope.stack_value, + static_cast(arg_stack_begin + i - 1), + builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); ICHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; - prep_seq_.emplace_back( + prep_seq.emplace_back( Store(scope.stack_tcode, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value @@ -344,8 +359,8 @@ class BuiltinLower : public StmtExprMutator { return false; } - // The prepration sequence to be emitted. - std::vector prep_seq_; + // The prepration sequence to be emitted before the current statement. + std::vector> prep_seq_stack_; PrimExpr device_type_; PrimExpr device_id_; diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 8b2b26a42ceb..d6b427a50fae 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -133,11 +133,45 @@ def check_packed_func(target="llvm"): tvm.ir.assert_structural_equal(alloca_shape, expected_stmt, map_free_vars=True) -def test_packed_func(): +def test_lower_packed_func(): check_packed_func("llvm") check_packed_func("stackvm") +@tvm.testing.requires_llvm +def test_call_packed_return_non_i32(): + # This call packed that return non i32 types + expected_value = np.array([1.2, 1.4], dtype="float32") + + def packed_echo(value): + return tvm.tir.call_intrin( + value.dtype, tvm.ir.Op.get("tir.tvm_call_packed"), "testing.echo", value + ) + + def build_tir(): + Ab = tvm.tir.decl_buffer((2,), "float32") + ib = tvm.tir.ir_builder.create() + Aptr = ib.buffer_ptr(Ab) + # return f32 + # Aptr[0] = testing.echo(expected_value[0]) + Aptr[0] = packed_echo(tvm.tir.const(expected_value[0], "float32")) + # return handle + # let Aptr_var = testing.echo(Aptr) in Aptr_var[1] = expected_value[1] + Aptr_var = ib.let("Aptr_dup", packed_echo(Aptr.asobject())) + ib.emit(tvm.tir.Store(Aptr, tvm.tir.const(expected_value[1], "float32"), 1)) + + stmt = ib.get() + return tvm.IRModule.from_expr( + tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test") + ) + + mod = build_tir() + f = tvm.build(mod, None, "llvm") + a = tvm.nd.array(np.zeros(2, dtype="float32")) + f(a) + tvm.testing.assert_allclose(a.asnumpy(), expected_value) + + if __name__ == "__main__": - # Test cases for issue: https://github.com/apache/tvm/issues/7246 - test_packed_func() + test_call_packed_return_non_i32() + test_lower_packed_func()