From e9ec06cd894aaecc4a8aa9387e12dfc1a2f3876d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 18 Jul 2023 13:10:36 -0500 Subject: [PATCH] [TIR] Return error code from kernels in SplitHostDevice (#15241) * [TVMScript] Handle parsing of PrimFunc calls with non-void return Prior to this commit, the return type of all internal function calls was hard-coded as `"void"`. After this commit, the `GlobalVar` representing the internal function has type annotation based on the callee's signature, which is then used as the return type of the internal call. * Update CallNode return type in MakeUnpackedAPI * [TIR] Return error code from kernels in SplitHostDevice Some codegen types delegate to `CodeGenCPU` for their compute kernels, as they may delegate work to packed functions. Because `CodeGenCPU` assumes that it can return an error code at any point (e.g. when launching a parallel for loop), the compute kernel should return an error code. * [TIR] Remove builtin::ret(0) from device-side kernel * Restrict the int32 return type to targets that need to propagate errors * Updated unit tests for CPU-specific checks --- .../transforms/lower_device_kernel_launch.cc | 41 ++++++++++++++++++- src/tir/transforms/split_host_device.cc | 33 +++++++++++++-- .../test_tir_transform_split_host_device.py | 38 +++++++++++++++++ 3 files changed, 108 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc index 52f06ea45c7c7..932116485fa1d 100644 --- a/src/tir/transforms/lower_device_kernel_launch.cc +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -145,6 +145,36 @@ class DeviceInfoCollector : public StmtVisitor { // The amount of dynamic shared memory used Optional<PrimExpr> dyn_shmem_size{NullOpt}; }; + +class ReturnRemover : public StmtExprMutator { + public: + static Stmt Apply(const Stmt& stmt) { + ReturnRemover mutator; + return mutator(stmt); + } + + private: + using Parent = StmtExprMutator; + 
Stmt VisitStmt_(const EvaluateNode* op) override { + if (auto* call = op->value.as<CallNode>()) { + if (call->op.same_as(builtin::ret())) { + ICHECK_EQ(call->args.size(), 1); + auto as_int = call->args[0].as<IntImmNode>(); + ICHECK(as_int && as_int->value == 0) + << "Device kernel may only contain successful return, T.ret(0)"; + return Evaluate(0); + } + } + return Parent::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::ret())) { + LOG(FATAL) << "Call to builtin::ret() should only appear within an Evaluate node"; + } + return Parent::VisitExpr_(op); + } +}; } // namespace class DeviceKernelMutator : public StmtExprMutator { @@ -185,10 +215,19 @@ class DeviceKernelMutator : public StmtExprMutator { if (is_kernel_launch) { const auto& info = device_info_map_.at(gvar.get()); + // Kernel launches provide an int32 error code to the caller, + // but do not accept any return type from the callee. + { + auto write_ptr = func.CopyOnWrite(); + write_ptr->ret_type = VoidType(); + write_ptr->body = ReturnRemover::Apply(write_ptr->body); + } + func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)}, {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, {tvm::attr::kGlobalSymbol, info.global_symbol}}); + } else if (is_call_extern && !func->GetAttr<String>(tvm::attr::kGlobalSymbol)) { func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); } @@ -197,7 +236,7 @@ class DeviceKernelMutator : public StmtExprMutator { } private: - PrimExpr VisitExpr_(const CallNode* op) { + PrimExpr VisitExpr_(const CallNode* op) override { auto node = Downcast<Call>(Parent::VisitExpr_(op)); auto* gvar = op->op.as<GlobalVarNode>(); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index ac5dc7131d33f..9b1dbf1a66188 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -60,7 +60,7 @@ class HostDeviceSplitter : public StmtMutator { 
VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); use_def(body); - // Sort first by variable typ, then by variable name + // Sort first by variable type, then by variable name std::vector<Var> params{use_def.undefined_.begin(), use_def.undefined_.end()}; std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) { auto sort_key = [](const Var& var) { @@ -74,8 +74,25 @@ class HostDeviceSplitter : public StmtMutator { return params; }(); + // CodeGenCPU is used for some device-side targets, such as + // "ext_dev", and expects to be able to return a int32_t status + // code. + + bool can_propagate_errors = [&]() { + auto kind = device_target->GetTargetDeviceType(); + return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon; + }(); + IntImm success(DataType::Int(32), 0); + Type kernel_ret_type; + if (can_propagate_errors) { + kernel_ret_type = PrimType(DataType::Int(32)); + body = SeqStmt::Flatten(body, Evaluate(ret(success))); + } else { + kernel_ret_type = VoidType(); + } + GlobalVar kernel_symbol_global = var_supply_(); - PrimFunc device_func(params, body); + PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}, {tir::attr::kIsGlobalFunc, Bool(true)}}); @@ -83,7 +100,17 @@ class HostDeviceSplitter : public StmtMutator { (*device_mod_)->Add(kernel_symbol_global, device_func); Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; }); - return Evaluate(Call(DataType::Void(), kernel_symbol_global, args)); + if (can_propagate_errors) { + Var kernel_error_code("kernel_error_code", success->dtype); + Call kernel_call(success->dtype, kernel_symbol_global, args); + AssertStmt assert_success(kernel_error_code == success, + StringImm("Error executing compute kernel"), Evaluate(0)); + LetStmt let_check(kernel_error_code, kernel_call, assert_success); + + return std::move(let_check); + } else { + return 
Evaluate(Call(DataType::Void(), kernel_symbol_global, args)); + } } // target ir module diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index ca16fe908ff36..a4dbb6b6b9a3c 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -129,6 +129,44 @@ def main_kernel(n: T.int32): return mod +class TestSplitHostDeviceOnCPU(BaseCompare): + """A kernel running on the CPU may return an error code""" + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) + T.attr(T.target("llvm"), "target", 0) + T.evaluate(n) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) + err = mod.main_kernel(n) + assert err == 0, "Error executing compute kernel" + + @T.prim_func + def main_kernel(n: T.int32) -> T.int32: + T.func_attr( + { + "target": T.target("llvm"), + "tir.noalias": T.bool(True), + "tir.is_global_func": True, + } + ) + T.evaluate(n) + T.ret(0) + + return mod + + class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare): """Like TestSplitHostDevice, but no host specified in the host's target