[TIR] Return error code from kernels in SplitHostDevice (#15241)
* [TVMScript] Handle parsing of PrimFunc calls with non-void return

Prior to this commit, the return type of all internal function calls
was hard-coded as `"void"`.  After this commit, the `GlobalVar`
representing the internal function has a type annotation based on the
callee's signature, which is then used as the return type of the
internal call.
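
For illustration, a minimal TVMScript sketch of the newly supported
pattern (the `Module`/`callee` names are hypothetical, not part of this
commit):

    from tvm.script import ir as I, tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def callee(n: T.int32) -> T.int32:
            T.ret(n + 1)

        @T.prim_func
        def main(n: T.int32) -> T.int32:
            # The internal call is now typed T.int32, taken from callee's
            # signature; previously it would have been hard-coded as void.
            result = Module.callee(n)
            T.ret(result)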

* Update CallNode return type in MakeUnpackedAPI

* [TIR] Return error code from kernels in SplitHostDevice

Some codegen backends delegate to `CodeGenCPU` for their compute
kernels, as those kernels may delegate work to packed functions.
Because `CodeGenCPU` assumes that it can return an error code at any
point (e.g. when launching a parallel for loop), the compute kernel
should return an error code.  (A sketch of the end-to-end effect
follows the change list below.)

* [TIR] Remove builtin::ret(0) from device-side kernel

* Restrict the int32 return type to targets that need to propagate errors

* Updated unit tests for CPU-specific checks
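
Taken together, a hedged sketch of driving the pass directly, mirroring
the unit test added below (the `Before` module name is illustrative):

    import tvm
    from tvm.script import ir as I, tir as T

    @I.ir_module
    class Before:
        @T.prim_func
        def main(n: T.int32):
            T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
            T.attr(T.target("llvm"), "target", 0)
            T.evaluate(n)

    # The extracted main_kernel targets llvm, which can propagate errors,
    # so it returns T.int32 and ends in T.ret(0); the host-side main binds
    # the returned status code and asserts that it is zero.
    After = tvm.tir.transform.SplitHostDevice()(Before)
    print(After)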
Lunderberg authored Jul 18, 2023 · 1 parent d81e880 · commit 2eca9f0
Showing 3 changed files with 108 additions and 4 deletions.
41 changes: 40 additions & 1 deletion src/tir/transforms/lower_device_kernel_launch.cc
@@ -145,6 +145,36 @@ class DeviceInfoCollector : public StmtVisitor {
// The amount of dynamic shared memory used
Optional<PrimExpr> dyn_shmem_size{NullOpt};
};

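// Strips the trailing builtin::ret(0) that SplitHostDevice appends to a
// device kernel body. The kernel-launch calling convention has no callee
// return value, so only a successful T.ret(0) may appear here; anything
// else fails the ICHECK below.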
class ReturnRemover : public StmtExprMutator {
public:
static Stmt Apply(const Stmt& stmt) {
ReturnRemover mutator;
return mutator(stmt);
}

private:
using Parent = StmtExprMutator;
Stmt VisitStmt_(const EvaluateNode* op) override {
if (auto* call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::ret())) {
ICHECK_EQ(call->args.size(), 1);
auto as_int = call->args[0].as<IntImmNode>();
ICHECK(as_int && as_int->value == 0)
<< "Device kernel may only contain successful return, T.ret(0)";
return Evaluate(0);
}
}
return Parent::VisitStmt_(op);
}

PrimExpr VisitExpr_(const CallNode* op) override {
if (op->op.same_as(builtin::ret())) {
LOG(FATAL) << "Call to builtin::ret() should only appear within an Evaluate node";
}
return Parent::VisitExpr_(op);
}
};
} // namespace

class DeviceKernelMutator : public StmtExprMutator {
@@ -185,10 +215,19 @@ class DeviceKernelMutator : public StmtExprMutator {
if (is_kernel_launch) {
const auto& info = device_info_map_.at(gvar.get());

// Kernel launches provide an int32 error code to the caller,
// but do not accept any return type from the callee.
{
auto write_ptr = func.CopyOnWrite();
write_ptr->ret_type = VoidType();
write_ptr->body = ReturnRemover::Apply(write_ptr->body);
}

func = WithAttrs(std::move(func),
{{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)},
{tvm::tir::attr::kKernelLaunchParams, info.launch_params},
{tvm::attr::kGlobalSymbol, info.global_symbol}});

} else if (is_call_extern && !func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);
}
@@ -197,7 +236,7 @@
}

private:
- PrimExpr VisitExpr_(const CallNode* op) {
+ PrimExpr VisitExpr_(const CallNode* op) override {
auto node = Downcast<Call>(Parent::VisitExpr_(op));

auto* gvar = op->op.as<GlobalVarNode>();
33 changes: 30 additions & 3 deletions src/tir/transforms/split_host_device.cc
@@ -60,7 +60,7 @@ class HostDeviceSplitter : public StmtMutator {
VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false);
use_def(body);

- // Sort first by variable typ, then by variable name
+ // Sort first by variable type, then by variable name
std::vector<Var> params{use_def.undefined_.begin(), use_def.undefined_.end()};
std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) {
auto sort_key = [](const Var& var) {
@@ -74,16 +74,43 @@
return params;
}();

// CodeGenCPU is used for some device-side targets, such as
// "ext_dev", and expects to be able to return an int32_t status
// code.

bool can_propagate_errors = [&]() {
auto kind = device_target->GetTargetDeviceType();
return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
}();
IntImm success(DataType::Int(32), 0);
Type kernel_ret_type;
if (can_propagate_errors) {
kernel_ret_type = PrimType(DataType::Int(32));
body = SeqStmt::Flatten(body, Evaluate(ret(success)));
} else {
kernel_ret_type = VoidType();
}

GlobalVar kernel_symbol_global = var_supply_();
- PrimFunc device_func(params, body);
+ PrimFunc device_func(params, body, kernel_ret_type);
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
{tir::attr::kNoAlias, Bool(true)},
{tir::attr::kIsGlobalFunc, Bool(true)}});

(*device_mod_)->Add(kernel_symbol_global, device_func);
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });

- return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
+ if (can_propagate_errors) {
+   Var kernel_error_code("kernel_error_code", success->dtype);
+   Call kernel_call(success->dtype, kernel_symbol_global, args);
+   AssertStmt assert_success(kernel_error_code == success,
+                             StringImm("Error executing compute kernel"), Evaluate(0));
+   LetStmt let_check(kernel_error_code, kernel_call, assert_success);
+
+   return std::move(let_check);
+ } else {
+   return Evaluate(Call(DataType::Void(), kernel_symbol_global, args));
+ }
}

// target ir module
38 changes: 38 additions & 0 deletions tests/python/unittest/test_tir_transform_split_host_device.py
@@ -129,6 +129,44 @@ def main_kernel(n: T.int32):
return mod


class TestSplitHostDeviceOnCPU(BaseCompare):
"""A kernel running on the CPU may return an error code"""

def before(self):
@I.ir_module
class mod:
@T.prim_func
def main(n: T.int32):
T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
T.attr(T.target("llvm"), "target", 0)
T.evaluate(n)

return mod

def expected(self):
@I.ir_module
class mod:
@T.prim_func
def main(n: T.int32):
T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")})
err = mod.main_kernel(n)
assert err == 0, "Error executing compute kernel"

@T.prim_func
def main_kernel(n: T.int32) -> T.int32:
T.func_attr(
{
"target": T.target("llvm"),
"tir.noalias": T.bool(True),
"tir.is_global_func": True,
}
)
T.evaluate(n)
T.ret(0)

return mod


class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare):
"""Like TestSplitHostDevice, but no host specified in the host's target
