Skip to content

Commit

Permalink
[TVMScript] Handle parsing of PrimFunc calls with non-void return (#1…
Browse files Browse the repository at this point in the history
…5239)

* [TVMScript] Handle parsing of PrimFunc calls with non-void return

Prior to this commit, the return type of all internal function calls
was hard-coded as `"void"`.  After this commit, the `GlobalVar`
representing the internal function has type annotation based on the
callee's signature, which is then used as the return type of the
internal call.

* Update CallNode return type in MakeUnpackedAPI
  • Loading branch information
Lunderberg authored Jul 7, 2023
1 parent 81463d7 commit 3a33771
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 6 deletions.
9 changes: 8 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,14 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
return Call(dtype="void", op=global_var, args=args)

dtype = "void"
if global_var.checked_type is not None:
ret_type = global_var.checked_type.ret_type
if hasattr(ret_type, "dtype"):
dtype = ret_type.dtype

return Call(dtype=dtype, op=global_var, args=args)


def start_profile_intrinsic(id):
Expand Down
14 changes: 13 additions & 1 deletion src/script/ir_builder/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <tvm/ir/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/ir_builder/ir/ir.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

#include "./utils.h"

Expand All @@ -38,7 +40,17 @@ GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature)
IRModuleFrame frame = FindModuleFrame("I.DeclFunction");
CHECK(!frame->global_var_map.count(func_name))
<< "ValueError: function " << func_name << " already exists";
GlobalVar gv = GlobalVar(func_name);

auto gvar_type = [&]() -> Type {
if (auto prim_func = func_signature.as<tir::PrimFuncNode>()) {
Array<Type> arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); });
return FuncType(arg_types, prim_func->ret_type, {}, {});
}

return {};
}();

GlobalVar gv = GlobalVar(func_name, gvar_type);
CHECK(frame->functions.find(gv) == frame->functions.end())
<< "ValueError: function " << func_name << " has already been defined.";
frame->global_var_map.Set(func_name, gv);
Expand Down
11 changes: 7 additions & 4 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,21 @@ class SubroutineCallRewriter : public StmtExprMutator {

if (auto gvar = node->op.as<GlobalVarNode>()) {
if (external_methods_.count(gvar)) {
Array<PrimExpr> args = node->args.Map([this](const PrimExpr& arg) -> PrimExpr {
Array<PrimExpr> args = node->args.Map([](const PrimExpr& arg) -> PrimExpr {
if (auto* as_call = arg.as<CallNode>()) {
if (as_call->op.same_as(builtin::tvm_stack_make_array())) {
PrimExpr data_ptr = as_call->args[0];
made_change_ = true;
return data_ptr;
}
}
return arg;
});
if (!args.same_as(node->args)) {
node.CopyOnWrite()->args = args;

if (!args.same_as(node->args) || node->dtype != DataType::Int(32)) {
auto write_ptr = node.CopyOnWrite();
write_ptr->dtype = DataType::Int(32);
write_ptr->args = args;
made_change_ = true;
}
}
}
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3817,6 +3817,22 @@ def subroutine(A_data: T.handle("float32"), n: T.int32):
return mod


def subroutine_call_returning_int():
"""An internal function call may return non-void"""

@I.ir_module
class mod:
@T.prim_func
def main(A: T.Buffer(2, "float32")):
mod.subroutine(A[0]) + mod.subroutine(A[1])

@T.prim_func
def subroutine(x: T.float32) -> T.float32:
T.ret(x * x)

return mod


def undefined_data_ptr_in_decl_buffer():
"""The T.decl_buffer syntax should not introduce an Allocate
Expand Down Expand Up @@ -4009,6 +4025,7 @@ def func():
ir_module_with_attrs,
nested_seqstmt,
subroutine_call,
subroutine_call_returning_int,
undefined_data_ptr_in_decl_buffer,
undefined_shape_in_decl_buffer,
undefined_stride_in_decl_buffer,
Expand Down

0 comments on commit 3a33771

Please sign in to comment.