diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 32c98efa6963..cdbdb4b5424f 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -445,7 +445,14 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): The call expression. """ assert isinstance(global_var, tvm.ir.GlobalVar) - return Call(dtype="void", op=global_var, args=args) + + dtype = "void" + if global_var.checked_type is not None: + ret_type = global_var.checked_type.ret_type + if hasattr(ret_type, "dtype"): + dtype = ret_type.dtype + + return Call(dtype=dtype, op=global_var, args=args) def start_profile_intrinsic(id): diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 0c34f85246c9..02fb899f0dd9 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -19,6 +19,8 @@ #include #include #include +#include +#include #include "./utils.h" @@ -38,7 +40,17 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); CHECK(!frame->global_var_map.count(func_name)) << "ValueError: function " << func_name << " already exists"; - GlobalVar gv = GlobalVar(func_name); + + auto gvar_type = [&]() -> Type { + if (auto prim_func = func_signature.as()) { + Array arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); }); + return FuncType(arg_types, prim_func->ret_type, {}, {}); + } + + return {}; + }(); + + GlobalVar gv = GlobalVar(func_name, gvar_type); CHECK(frame->functions.find(gv) == frame->functions.end()) << "ValueError: function " << func_name << " has already been defined."; frame->global_var_map.Set(func_name, gv); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 2646b5baea7c..0cb072701cb5 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -64,18 +64,21 @@ class SubroutineCallRewriter : public StmtExprMutator { if (auto gvar = node->op.as()) { if (external_methods_.count(gvar)) { - Array args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr { + Array args = node->args.Map([](const PrimExpr& arg) -> PrimExpr { if (auto* as_call = arg.as()) { if (as_call->op.same_as(builtin::tvm_stack_make_array())) { PrimExpr data_ptr = as_call->args[0]; - made_change_ = true; return data_ptr; } } return arg; }); - if (!args.same_as(node->args)) { - node.CopyOnWrite()->args = args; + + if (!args.same_as(node->args) || node->dtype != DataType::Int(32)) { + auto write_ptr = node.CopyOnWrite(); + write_ptr->dtype = DataType::Int(32); + write_ptr->args = args; + made_change_ = true; } } } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index d36641dfc28f..90d2599b58bd 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3817,6 +3817,22 @@ def subroutine(A_data: T.handle("float32"), n: T.int32): return mod +def subroutine_call_returning_int(): + """An internal function call may return non-void""" + + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(2, "float32")): + mod.subroutine(A[0]) + mod.subroutine(A[1]) + + @T.prim_func + def subroutine(x: T.float32) -> T.float32: + T.ret(x * x) + + return mod + + def undefined_data_ptr_in_decl_buffer(): """The T.decl_buffer syntax should not introduce an Allocate @@ -4009,6 +4025,7 @@ def func(): ir_module_with_attrs, nested_seqstmt, subroutine_call, + subroutine_call_returning_int, undefined_data_ptr_in_decl_buffer, undefined_shape_in_decl_buffer, undefined_stride_in_decl_buffer,