From 6eb077944295b96d70531db9b7048f2e87af1cfc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 26 May 2023 15:22:12 -0500 Subject: [PATCH] [TIR] SplitHostDevice, handle subroutines (#14918) This PR refactors SplitHostDevice into three separate transformations. Previously, SplitHostDevice would extract device regions and replace them with a builtin::tvm_call_packed() node. After this PR, this process is performed in three separate steps. AnnotateDeviceRegions: Annotate the regions that should be executed on another target. SplitHostDevice: Extract the annotated region into an independent PrimFunc, with a GlobalVar to represent the call into the new subroutine. LowerDeviceKernelLaunch: For any subroutine call where the caller and callee are on different devices, replace with a device kernel launch. * PR#14915 [TVMScript] Allow T.target("device", host="host") in TVMScript Prior to this commit, the `TargetNode::host` could be specified in TVMScript as part of the config dictionary, under the key `"host"`. However, this required all other device parameters to be explicitly specified, rather than using any of the short-hand string representations. This commit forwards the `host` argument from TVMScript's `T.target` method to `tvm.target.Target`, allowing both the device and host to be specified using the shorthand string representation. 
```python @T.prim_func def before_this_commit(): T.func_attr( { "target": T.target( { "arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32, } ) } ) T.evaluate(0) @T.prim_func def after_this_commit(): T.func_attr({"target": T.target("cuda", host="llvm")}) T.evaluate(0) ``` * [Target] Added WithoutHost method * [TIR] SplitHostDevice, handle missing kGlobalSymbol Previously, the symbol name of the extracted compute kernel was defined based on the `kGlobalSymbol` attribute, which was required to be present. This commit updates `SplitHostDevice` to generate the symbol name using `kGlobalSymbol` if present, and to fall back to the name of the `tvm::GlobalVar` for internal functions. * [TIR] Refactor SplitHostDevice into three separate passes First pass, `AnnotateDeviceRegions`. This pass decides which portions of a PrimFunc should be run on the device, and annotates them with `kTarget` attribute, indicating which target should be used for later lowering steps. Second pass, `SplitHostDevice`. This pass extracts the annotated region into an independent PrimFunc. The `kTarget` attribute of the extracted kernel is defined by the `kTarget` annotation inserted by `AnnotateDeviceRegions`. The host function is marked by the `tvm::tir::attr::kIsHostFunc` attribute, allowing it to be recognized by later host-only lowering passes. Third pass, `LowerDeviceKernelLaunch`. This pass identifies subroutine calls that call into device kernels, and rewrites them into `T.tvm_call_packed`. * Add unit tests specifically for SplitHostDevice behavior * Added unit test specifically for AnnotateDeviceRegions * Added unit tests for LowerDeviceKernelLaunch * Minor cleanup, moved all kernel launch collection into one spot Previously, the SplitHostDevice pass added the `tir::attr::kKernelLaunchParams` attribute, and the LowerDeviceKernelLaunch pass filled in the values for it. 
This cleanup makes the kernel launch params be the sole responsibility of LowerDeviceKernelLaunch. * Updated unit tests for LowerWarpMemory * Updated unit tests for ThreadSync * Updated unit test for inject ptx async copy * [Bugfix] Avoid symbol conflicts in MakePackedAPI/MakeUnpackedAPI PRs https://github.com/apache/tvm/pull/14913 and https://github.com/apache/tvm/pull/14914 made analogous changes to `MakePackedAPI` and `MakeUnpackedAPI` to handle subroutine calls. Both PRs introduced the same symbol, `tvm::tir::SubroutineCallRewriter`, a local utility to update internal calls to a modified function. While each PR passed CI individually, and was therefore able to merge, having both changes caused a duplicate symbol. This commit updates `MakePackedAPI` and `MakeUnpackedAPI` to place their local utilities into anonymous namespaces, avoiding the conflict. * Maintain "tir.is_global_func" attr in device-side entry point * SplitHostDevice, update the host-side target to be the host * [TIR] Update LowerDeviceKernelLaunch to avoid kIsHostFunc Update to use the `tvm::tir::IsHostFunc` utility function, rather than the `kIsHostFunc` attribute. Per discussion on https://github.com/apache/tvm/pull/14020, the `kIsHostFunc` attribute should only be used in `BindTarget`, and should not be re-introduced in `SplitHostDevice`. 
* Remove is_host_func from SplitHostDevice tests --- include/tvm/tir/transform.h | 38 +++ python/tvm/tir/op.py | 2 +- python/tvm/tir/transform/transform.py | 38 +++ src/driver/driver_api.cc | 3 + src/tir/transforms/annotate_device_regions.cc | 81 +++++ .../transforms/lower_device_kernel_launch.cc | 305 ++++++++++++++++++ src/tir/transforms/split_host_device.cc | 272 ++++------------ ...t_tir_transform_annotate_device_regions.py | 58 ++++ ...test_tir_transform_device_kernel_launch.py | 193 +++++++++++ ...est_tir_transform_inject_ptx_async_copy.py | 2 +- .../test_tir_transform_lower_warp_memory.py | 37 +-- .../test_tir_transform_split_host_device.py | 113 ++++++- .../test_tir_transform_thread_sync.py | 5 +- 13 files changed, 908 insertions(+), 239 deletions(-) create mode 100644 src/tir/transforms/annotate_device_regions.cc create mode 100644 src/tir/transforms/lower_device_kernel_launch.cc create mode 100644 tests/python/unittest/test_tir_transform_annotate_device_regions.py create mode 100644 tests/python/unittest/test_tir_transform_device_kernel_launch.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 8dee176277d7..d9d68e0a8b6a 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -263,13 +263,51 @@ TVM_DLL Pass LowerCustomDatatypes(); */ TVM_DLL Pass DecorateDeviceScope(); +/*! + * \brief Annotate locations that should be run on the device + * + * Insert `AttrStmt` nodes specifying a target on which regions within + * the PrimFunc should be executed. Only modifies functions that have + * a `tvm::attr::kTarget` attribute, and where that target defines a + * host. + * + * \return The pass. + */ +TVM_DLL Pass AnnotateDeviceRegions(); + /*! * \brief Split the function into a host function and device functions. * + * The resulting host-side function will keep the same + * `tvm::attr::kTarget` attribute (e.g. `T.target("cuda", + * host=T.target("llvm"))`). 
This ensures that `MakePackedAPI` knows + * which device type should be used for the input buffers. + * + * The resulting device-side function will + * have the host stripped from its target attribute + * (e.g. `T.target("cuda")`). + * * \return The pass. */ TVM_DLL Pass SplitHostDevice(); +/*! + * \brief Lower cross-device function calls. + * + * Prior to this pass, host to device calls are represented as + * subroutine calls, with environment parameters (e.g. env_thread) + * specified internally. The device function is an internal function, + * without a `tvm::attr::kGlobalSymbol` attribute. + * + * After this pass, host to device calls are represented as + * tvm_call_packed built-in. The device function is an + * externally-exposed function, with a non-empty + * `tvm::attr::kGlobalSymbol` attribute. + * + * \return The pass. + */ +TVM_DLL Pass LowerDeviceKernelLaunch(); + /*! * \brief skip assert stmt. * diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 90e3db4cb96b..098c13f04e9d 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -445,7 +445,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args): The call expression. """ assert isinstance(global_var, tvm.ir.GlobalVar) - return Call(dtype="handle", op=global_var, args=args) + return Call(dtype="void", op=global_var, args=args) def start_profile_intrinsic(id): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f2ce4378141e..9e038f618bc3 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -435,6 +435,22 @@ def MakeUnpackedAPI(): return _ffi_api.MakeUnpackedAPI() # type: ignore +def AnnotateDeviceRegions(): + """Annotate locations that should be run on the device + + Insert `AttrStmt` nodes specifying a target on which regions + within the PrimFunc should be executed. Only modifies functions + that have a `tvm::attr::kTarget` attribute, and where that target + defines a host. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.AnnotateDeviceRegions() # type: ignore + + def SplitHostDevice(): """Split the function into a host function and device functions. @@ -446,6 +462,28 @@ def SplitHostDevice(): return _ffi_api.SplitHostDevice() # type: ignore +def LowerDeviceKernelLaunch(): + """Lower cross-device function calls. + + Prior to this pass, host to device calls are represented as + subroutine calls, with environment parameters (e.g. env_thread) + specified internally. The device function is an internal + function, without a `tvm::attr::kGlobalSymbol` attribute. + + After this pass, host to device calls are represented as + tvm_call_packed built-in. The device function is an + externally-exposed function, with a non-empty + `tvm::attr::kGlobalSymbol` attribute. + + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerDeviceKernelLaunch() # type: ignore + + def DecorateDeviceScope(): """Decorate all the function's body as device function. 
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 91bc57ccbeb2..e5f71c38320d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -587,7 +587,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); + + mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); + mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); return transform::Sequential(mixed_pass_list); } diff --git a/src/tir/transforms/annotate_device_regions.cc b/src/tir/transforms/annotate_device_regions.cc new file mode 100644 index 000000000000..a81af7d7805b --- /dev/null +++ b/src/tir/transforms/annotate_device_regions.cc @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file annotate_device_regions.cc + * \brief Split device function from host. 
+ */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +class DeviceRegionAnnotater : public StmtMutator { + public: + explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tvm::attr::kTarget) { + // If a target attribute already exists, use it as-is. + return GetRef(op); + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || + op->attr_key == attr::device_scope) { + // These attributes are only allowed in device-side code, so + // they should be annotated with the function's default target. + Stmt body = GetRef(op); + return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); + } else { + // All other annotations are ignored + return StmtMutator::VisitStmt_(op); + } + } + + private: + Target device_target_; +}; + +namespace transform { + +Pass AnnotateDeviceRegions() { + auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { + auto opt_target = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; + Target target = opt_target.value(); + + if (target->GetHost()) { + DeviceRegionAnnotater mutator(target.WithoutHost()); + func.CopyOnWrite()->body = mutator(func->body); + } + return func; + }; + + return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_device_kernel_launch.cc b/src/tir/transforms/lower_device_kernel_launch.cc new file mode 100644 index 000000000000..5ffbf0d7a7fd --- /dev/null +++ b/src/tir/transforms/lower_device_kernel_launch.cc @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_device_kernel_launch.cc + * \brief Split device function from host. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +namespace { +struct KernelInfo { + // The device on which the PrimFunc runs + Target target; + + // The externally visible symbol which may refer to the PrimFunc + // when launching a device kernel. + String global_symbol; + + // The parameters accepted by the PrimFunc. Used to rewrite + // `launch_args` to be in terms of the calling scope. + Array params; + + // The launch parameters that should annotate the PrimFunc, if the + // kernel is ever called from the host. + Array launch_params; + + // Additional arguments which must be provided to the host-side + // PackedFunc. These may be in terms of the function's parameters + // (e.g. a function that computes the average of `N` elements, and + // which must be launched with `N` CUDA threads). + Array launch_args; +}; + +/*! + * \brief Visitor class to collect device-side program information. 
+ */ +class DeviceInfoCollector : public StmtVisitor { + public: + static KernelInfo Collect(const GlobalVar& gvar, const PrimFunc& func) { + DeviceInfoCollector collector; + collector.info_.target = func->GetAttr(tvm::attr::kTarget).value().WithoutHost(); + collector.info_.params = func->params; + + collector(func->body); + + // The dynamic shared memory is required to be the last of the + // kernel launch parameters + if (collector.dyn_shmem_size) { + collector.info_.launch_params.push_back( + tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); + } + + collector.info_.global_symbol = + func->GetAttr(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); + + collector.info_.launch_args = collector.info_.launch_params.Map( + [&](const auto& param) { return collector.GetArgument(param); }); + + return collector.info_; + } + + private: + PrimExpr GetArgument(const String& launch_param) const { + if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { + CHECK(dyn_shmem_size.defined()) + << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc did not contain Allocate node with shared dynamic scope."; + return dyn_shmem_size.value(); + } + + auto extent = thread_extent.Get(launch_param); + CHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent + << "\" defining this thread extent"; + return extent.value(); + } + + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent) { + IterVar iv = Downcast(op->node); + ICHECK_NE(iv->thread_tag.length(), 0U); + // thread_extent can appear multiple times + // use the first appearance as def. 
+ if (!defined_thread.count(iv.get())) { + defined_thread.insert(iv.get()); + info_.launch_params.push_back(iv->thread_tag); + thread_extent.Set(iv->thread_tag, op->value); + } + } + + StmtVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AllocateNode* op) final { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + ICHECK(!dyn_shmem_size.defined()) << "Only one dynamic shared memory allocation is allowed."; + ICHECK_GT(op->extents.size(), 0); + + PrimExpr dyn_size = Integer(1); + for (const auto& extent : op->extents) { + dyn_size *= extent; + } + dyn_size *= op->dtype.bytes(); + + dyn_shmem_size = dyn_size; + } + StmtVisitor::VisitStmt_(op); + } + + // The collected results + KernelInfo info_; + // recording what thread axis have been visited. + std::unordered_set defined_thread; + // The extent of each thread + Map thread_extent; + // The amount of dynamic shared memory used + Optional dyn_shmem_size{NullOpt}; +}; +} // namespace + +class DeviceKernelMutator : public StmtExprMutator { + public: + using Parent = StmtExprMutator; + + explicit DeviceKernelMutator(std::unordered_map device_info_map) + : device_info_map_(std::move(device_info_map)) {} + + PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) { + ICHECK(!current_target_.defined()); + auto it = device_info_map_.find(gvar.get()); + ICHECK(it != device_info_map_.end()); + current_target_ = it->second.target; + + auto body = VisitStmt(func->body); + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + + current_target_ = NullOpt; + return func; + } + + PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { + if (device_kernel_launch_.count(gvar.get())) { + const auto& info = device_info_map_.at(gvar.get()); + + func = WithAttrs(std::move(func), + {{tvm::attr::kCallingConv, 
Integer(tvm::CallingConv::kDeviceKernelLaunch)}, + {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, + {tvm::attr::kGlobalSymbol, info.global_symbol}, + {tvm::tir::attr::kIsGlobalFunc, Bool(true)}}); + } + + return func; + } + + private: + PrimExpr VisitExpr_(const CallNode* op) { + auto node = Downcast(Parent::VisitExpr_(op)); + + auto* gvar = op->op.as(); + if (!gvar) return std::move(node); + + auto it = device_info_map_.find(gvar); + ICHECK(it != device_info_map_.end()) + << "CallNode attempted subroutine call to " << gvar->name_hint << ", but " + << gvar->name_hint << " did not appear within the IRModule"; + const KernelInfo& dev_info = it->second; + + auto caller_device_type = current_target_.value()->GetTargetDeviceType(); + auto callee_device_type = dev_info.target->GetTargetDeviceType(); + if (caller_device_type == callee_device_type) { + return std::move(node); + } + + ICHECK(dev_info.launch_params.defined()) + << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " + << dev_info.target << ", but subroutine " << gvar->name_hint + << " did not have the tir::attr::kKernelLaunchParams attribute " + << "required for cross-target kernel launch"; + + // Collected kernel information may be in terms of the callee's + // arguments, but we need expressions for them in terms of the + // caller's parameters. The param_map allows substitution of + // parameter values into the thread extents, to generate + // expressions that are valid within the caller. 
+ Map param_map = [&]() { + Map param_map; + CHECK_EQ(node->args.size(), dev_info.params.size()) + << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() + << " arguments as input, but is called using " << node->args.size() << " arguments"; + for (size_t i = 0; i < node->args.size(); i++) { + param_map.Set(dev_info.params[i], node->args[i]); + } + return param_map; + }(); + + device_kernel_launch_.insert(gvar); + + Array call_args; + call_args.push_back(StringImm(dev_info.global_symbol)); + for (PrimExpr arg : node->args) { + call_args.push_back(arg); + } + for (const auto& launch_arg : dev_info.launch_args) { + call_args.push_back(Substitute(launch_arg, param_map)); + } + + auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype; + + return Call(dtype, builtin::tvm_call_packed(), call_args); + } + + Optional current_target_; + std::unordered_map device_info_map_; + std::unordered_set device_kernel_launch_; +}; + +namespace transform { + +Pass LowerDeviceKernelLaunch() { + auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { + auto mutator = [&mod]() { + std::unordered_map device_info_map; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as()) { + device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func.value()); + } + } + return DeviceKernelMutator(std::move(device_info_map)); + }(); + + { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + auto prim_func = mutator.RewriteKernelLaunchSite(gvar, GetRef(ptr)); + if (!prim_func.same_as(base_func)) { + updates->Add(gvar, prim_func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + + { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as()) { + auto prim_func = mutator.UpdateKernelAttributes(gvar, GetRef(ptr)); + if (!prim_func.same_as(base_func)) { + 
updates->Add(gvar, prim_func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + + return mod; + }; + + return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceKernelLaunch") + .set_body_typed(LowerDeviceKernelLaunch); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 4f47b8ce2bf9..9270b356ba22 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -41,246 +41,102 @@ namespace tvm { namespace tir { -/*! - * \brief Visitor class to collect device-side program information. - */ -class DeviceInfoCollector : public StmtVisitor { - public: - Array thread_axis_; - Array thread_extent_; - PrimExpr dyn_shmem_size_{0}; - bool use_dyn_shmem_{false}; - - Array GetLaunchParams() const { - Array output; - for (const auto& axis : thread_axis_) { - output.push_back(axis->thread_tag); - } - if (use_dyn_shmem_) { - output.push_back(runtime::launch_param::kUseDynamicSharedMemoryTag); - } - return output; - } - - private: - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent) { - IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); - // thread_extent can appear multiple times - // use the first appearance as def. 
- if (!defined_thread.count(iv.get())) { - defined_thread.insert(iv.get()); - thread_axis_.push_back(iv); - thread_extent_.push_back(op->value); - } - - this->VisitExpr(op->value); - this->VisitStmt(op->body); - } else { - StmtVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const AllocateNode* op) final { - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); - if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - ICHECK_EQ(use_dyn_shmem_, false) << "Only one dynamic shared memory allocation is allowed."; - ICHECK_GT(op->extents.size(), 0); - dyn_shmem_size_ = op->extents[0]; - for (size_t i = 1; i < op->extents.size(); ++i) { - dyn_shmem_size_ *= op->extents[i]; - } - dyn_shmem_size_ = dyn_shmem_size_ * (op->dtype.bytes()); - use_dyn_shmem_ = true; - } - StmtVisitor::VisitStmt_(op); - } - - // recording what thread axis have been visited. - std::unordered_set defined_thread; -}; - -/*! - * \brief Mutator class to remove unrefenced let stmt/expressions. - * \param use_count The pre-computed variable to use count map. 
- */ -class UnreferencedLetRemover : public StmtExprMutator { - public: - explicit UnreferencedLetRemover(const std::unordered_map& use_count) - : use_count_(use_count) {} - - private: - Stmt VisitStmt_(const LetStmtNode* op) final { - Stmt body = this->VisitStmt(op->body); - // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { - return body; - } else { - PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return LetStmt(op->var, value, body); - } - } - } - - PrimExpr VisitExpr_(const LetNode* op) final { - PrimExpr body = this->VisitExpr(op->body); - PrimExpr value = this->VisitExpr(op->value); - if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState) { - return body; - } else { - if (body.same_as(op->body) && value.same_as(op->value)) { - return GetRef(op); - } else { - return Let(op->var, value, body); - } - } - } - - // pre-computed variable to use count map. 
- const std::unordered_map& use_count_; -}; - class HostDeviceSplitter : public StmtMutator { public: - explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) - : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {} - - Stmt VisitStmt_(const AllocateNode* op) final { - handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); - return StmtMutator::VisitStmt_(op); - } + explicit HostDeviceSplitter(IRModule* device_mod, std::string name_prefix) + : device_mod_(device_mod), name_prefix_(name_prefix) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || - op->attr_key == attr::device_scope) { - return SplitDeviceFunc(GetRef(op)); + if (op->attr_key == tvm::attr::kTarget) { + auto device_target = op->node.as().value().WithoutHost(); + return SplitDeviceFunc(op->body, device_target); } return StmtMutator::VisitStmt_(op); } private: - Stmt SplitDeviceFunc(Stmt body) { - std::ostringstream os; - os << name_prefix_ << "_kernel" << device_func_counter_++; - std::string kernel_symbol = os.str(); - // isolate the device function. - VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); - use_def(body); - DeviceInfoCollector dev_info; - dev_info(body); - UnreferencedLetRemover let_remover(use_def.use_count_); - body = let_remover(std::move(body)); - - Array params; - Array arguments; - Map remap_vars; - - // Strictly order the arguments: Var pointers, positional arguments. - for (Var var : use_def.undefined_) { - if (var.dtype().is_handle()) { - // Create a new version of v. 
- auto it = handle_data_type_.find(var.get()); - if (it != handle_data_type_.end()) { - String storage_scope; - if (auto* ptr_type = var->type_annotation.as()) { - storage_scope = ptr_type->storage_scope; - } - tir::Var new_var(var->name_hint, - PointerType(PrimType((*it).second->dtype), storage_scope)); - params.push_back(new_var); - remap_vars.Set(var, new_var); - } else { - params.push_back(var); - } - arguments.push_back(var); - } - } - // positional arguments - for (Var var : use_def.undefined_) { - if (!var.dtype().is_handle()) { - params.push_back(var); - arguments.push_back(var); - } - } - GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_); - GlobalVar kernel_symbol_global = global_var_supply->FreshGlobal(kernel_symbol, false); - - PrimFunc device_func(params, Substitute(body, remap_vars)); - device_func = WithAttr(std::move(device_func), tir::attr::kKernelLaunchParams, - dev_info.GetLaunchParams()); - - device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, - Integer(CallingConv::kDeviceKernelLaunch)); - device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, - runtime::String(kernel_symbol_global->name_hint)); - device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); - device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); - device_func = WithAttr(std::move(device_func), tir::attr::kIsGlobalFunc, Integer(1)); + Stmt SplitDeviceFunc(Stmt body, Target device_target) { + Array params = [&]() { + VarUseDefAnalyzer use_def(/*defined_vars=*/{}, /*visit_thread_extent=*/false); + use_def(body); + + // Sort first by variable typ, then by variable name + std::vector params{use_def.undefined_.begin(), use_def.undefined_.end()}; + std::sort(params.begin(), params.end(), [](const Var& a, const Var& b) { + auto sort_key = [](const Var& var) { + return std::tuple{ + !var->dtype.is_handle(), + var->name_hint, + }; + }; + return sort_key(a) < sort_key(b); + }); + return 
params; + }(); + + GlobalVar kernel_symbol_global = [&]() { + std::stringstream name; + name << name_prefix_ << "_kernel"; + GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_); + return global_var_supply->FreshGlobal(name.str(), false); + }(); + + PrimFunc device_func(params, body); + device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, + {tir::attr::kNoAlias, Bool(true)}, + {tir::attr::kIsGlobalFunc, Bool(true)}}); (*device_mod_)->Add(kernel_symbol_global, device_func); + Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); - // generate calls to the device function - Array call_args; - call_args.push_back(StringImm(kernel_symbol_global->name_hint)); - for (PrimExpr arg : arguments) { - call_args.push_back(arg); - } - for (PrimExpr ext : dev_info.thread_extent_) { - call_args.push_back(ext); - } - if (dev_info.use_dyn_shmem_) { - call_args.push_back(dev_info.dyn_shmem_size_); - } - return Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), call_args)); + return Evaluate(Call(DataType::Void(), kernel_symbol_global, args)); } // target ir module IRModule* device_mod_; - // Device target - Target device_target_; // function name hint std::string name_prefix_; - // Number of device functions. 
- int device_func_counter_{0}; - std::unordered_map handle_data_type_; }; -PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { - auto target = func->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; +PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& gvar) { + auto opt_target = func->GetAttr(tvm::attr::kTarget); + ICHECK(opt_target) << "SplitHostDevice: Require the target attribute"; + Target target = opt_target.value(); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; + auto name_prefix = global_symbol.value_or(gvar->name_hint); + + HostDeviceSplitter splitter(device_mod, name_prefix); - HostDeviceSplitter splitter(device_mod, target.value(), - static_cast(global_symbol.value())); + auto body = splitter(func->body); + + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + auto target_host = target->GetHost().value_or(Target("llvm")); + func = WithAttr(std::move(func), tvm::attr::kTarget, target_host); + } - auto* n = func.CopyOnWrite(); - n->body = splitter(std::move(n->body)); - // set the host target to None. 
- func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr)); - return std::move(func); + return func; } namespace transform { Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { - IRModuleNode* mod_ptr = mod.CopyOnWrite(); - auto* func_dict = mod_ptr->functions.CopyOnWrite(); IRModule device_mod = IRModule(Map({})); - - for (auto& kv : *func_dict) { - if (kv.second->IsInstance()) { - PrimFunc func = Downcast(std::move(kv.second)); - ICHECK(device_mod.defined()) << "The device module must be defined."; - kv.second = SplitHostDevice(std::move(func), &device_mod); + IRModule updates = IRModule(Map({})); + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto opt = base_func.as()) { + PrimFunc func = opt.value(); + func = SplitHostDevice(std::move(func), &device_mod, gvar); + if (!func.same_as(base_func)) { + updates->Add(gvar, func); + } } } + + mod->Update(updates); mod->Update(device_mod); return ConvertSSA()(mod); }; diff --git a/tests/python/unittest/test_tir_transform_annotate_device_regions.py b/tests/python/unittest/test_tir_transform_annotate_device_regions.py new file mode 100644 index 000000000000..efa43027e9c6 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_annotate_device_regions.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import tir as T, ir as I + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.AnnotateDeviceRegions() + + +class TestAnnotateThreadExtent(BaseCompare): + """Annotation inserted at the "thread_extent" attribute""" + + def before(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + def expected(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(T.target("cuda"), "target", 0) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + +class TestAnnotateDeviceScope(BaseCompare): + """Annotation inserted at the "device_scope" attribute""" + + def before(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(0, "device_scope", 0) + A[0] = 0.0 + + def expected(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("cuda", host="llvm")}) + T.attr(T.target("cuda"), "target", 0) + T.attr(0, "device_scope", 0) + A[0] = 0.0 + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_device_kernel_launch.py b/tests/python/unittest/test_tir_transform_device_kernel_launch.py new file mode 100644 index 000000000000..a0f77da3766b --- /dev/null +++ b/tests/python/unittest/test_tir_transform_device_kernel_launch.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import tir as T, ir as I + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.LowerDeviceKernelLaunch() + + +class TestLowerDeviceKernelLaunch(BaseCompare): + """Kernel launch parameters are added at the call site + + The "tir.kernel_launch_params" determines which parameters belong + to the runtime, and which belong to the device-side PrimFunc. + Parameters that are required prior to launching a kernel (e.g. the + number of Cuda threads to use) are stored in the + `"tir.kernel_launch_params"` attribute, and are used by the + runtime in order to launch the generated kernel.
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda")}) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + T.call_packed("kernel", A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "calling_conv": 2, + "tir.kernel_launch_params": [], + "global_symbol": "kernel", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + +class TestExternallyVisibleKernelLaunch(BaseCompare): + """Like TestLowerDeviceKernelLaunch, with pre-defined global_symbol + + Because the host and kernel will be handled by different code + generators, the device-side kernel must be externally exposed for + use by the host-side wrapper, even if the host-side wrapper does + not directly expose the kernel. Therefore, a "global_symbol" + attribute must be added for the kernel if not already present. + + If the kernel already has a specific name, that name should be + preserved. 
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel_by_another_name"}) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, "float32")): + T.func_attr({"target": T.target("llvm")}) + T.call_packed("kernel_by_another_name", A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "calling_conv": 2, + "tir.kernel_launch_params": [], + "global_symbol": "kernel_by_another_name", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 0.0 + + return mod + + +class TestCollectLaunchParameter(BaseCompare): + """Kernel launch parameters are added at the call site + + The "tir.kernel_launch_params" determines which parameters belong + to the runtime, and which belong to the device-side PrimFunc. + Parameters that are required prior to launching a kernel (e.g. the + number of Cuda threads to use) are stored in the + `"tir.kernel_launch_params"` attribute, and are used by the + runtime in order to launch the generated kernel.
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("llvm")}) + mod.kernel(A.data) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "global_symbol": "kernel", + } + ) + A = T.decl_buffer(16, dtype="float32", data=A_data) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(16, "float32")): + T.func_attr({"target": T.target("llvm")}) + T.call_packed("kernel", A.data, 16) + + @T.prim_func + def kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "calling_conv": 2, + "tir.kernel_launch_params": ["threadIdx.x"], + "global_symbol": "kernel", + "tir.is_global_func": True, + } + ) + A = T.decl_buffer(16, dtype="float32", data=A_data) + i = T.launch_thread("threadIdx.x", 16) + A[i] = 0.0 + + return mod + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 5db33a1e057b..1e1ef410b4e1 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -201,7 +201,7 @@ def test_inject_async_copy_shared_dyn(): #define int64_t long long #define uint64_t unsigned long long #endif -extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { +extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { __shared__ float A_shared[64]; __shared__ float B_shared[64]; A_shared[((int)threadIdx.x)] = 0.000000e+00f; diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py 
b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index d4abc26bb204..c7e90d4e7dc9 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -22,6 +22,16 @@ from tvm.contrib.nvcc import have_fp16 +def _run_passes(mod): + cuda_target = tvm.target.Target("cuda", host="llvm") + assert cuda_target.thread_warp_size == 32 + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) + mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) + mod = tvm.tir.transform.SplitHostDevice()(mod) + mod = tvm.tir.transform.LowerWarpMemory()(mod) + return mod + + @tvm.testing.requires_cuda def test_lower_warp_memory_local_scope(): m = 128 @@ -39,16 +49,12 @@ def test_lower_warp_memory_local_scope(): xo, xi = s[AA].split(s[AA].op.axis[0], 32) s[AA].bind(xi, tx) - cuda_target = tvm.target.Target("cuda") - assert cuda_target.thread_warp_size == 32 # lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): mod = tvm.lower(s, [A, B], name="f") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] - mod = tvm.IRModule.from_expr(fdevice) - fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + mod = _run_passes(mod) + fdevice = mod["f_kernel"] allocate = fdevice.body.body assert allocate.buffer_var.type_annotation.storage_scope == "local" assert fdevice.body.body.extents[0].value == 2 @@ -103,7 +109,7 @@ def check_cuda(dtype): A = te.placeholder((m,), name="A", dtype=dtype) B = te.compute((m,), lambda i: A[i // 32 * 32 + (i + 1) % 32], name="B") - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") assert cuda_target.thread_warp_size == 32 with cuda_target: s = te.create_schedule(B.op) @@ -168,7 +174,7 @@ def 
check_cuda(dtype): name="B", ) - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") assert cuda_target.thread_warp_size == 2 * m with cuda_target: s = te.create_schedule(B.op) @@ -214,7 +220,7 @@ def check_cuda(dtype): B = te.placeholder((m,), name="B", dtype=dtype) C = te.compute((m,), lambda i: A[(i + 1) % m] + B[(i + 1) % m], name="C") - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") assert m <= cuda_target.thread_warp_size with cuda_target: s = te.create_schedule(C.op) @@ -310,15 +316,12 @@ def test_lower_warp_memory_same_thread(): xo, xi = s[BB].split(s[BB].op.axis[0], factor=32) s[BB].bind(xi, tx) - cuda_target = tvm.target.Target("cuda") - assert cuda_target.thread_warp_size == 32 # lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): mod = tvm.lower(s, [A, B], name="f") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] - mod = tvm.IRModule.from_expr(fdevice) - fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + + mod = _run_passes(mod) + fdevice = mod["f_kernel"] assert "tvm_warp_shuffle" not in fdevice.script() @@ -338,13 +341,11 @@ def test_lower_warp_memory_divide_by_factor(): stmt = ib.get() func = tvm.tir.PrimFunc([], stmt) func = func.with_attr("from_legacy_te_schedule", True) - cuda_target = tvm.target.Target("cuda") # lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): mod = tvm.lower(func, name="f") - mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) with pytest.raises(tvm.error.TVMError, match="Divide by zero") as cm: - tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"] + _run_passes(mod) if 
__name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 680f23e07a17..cf866ae005c8 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -35,17 +35,26 @@ def test_split_host_device_func_attr(): s[A1].compute_at(s[A2], xo) s[A1].set_scope("shared") - mod = tvm.lower(s, [A, A2], name="f") + mod = tvm.lower(s, [A, A2]) - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") mod = tvm.tir.transform.Apply( lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) )(mod) - fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] - assert fdevice.attrs["global_symbol"] == "test_kernel0" + mod = tvm.ir.transform.Sequential( + [ + tvm.tir.transform.AnnotateDeviceRegions(), + tvm.tir.transform.SplitHostDevice(), + tvm.tir.transform.LowerDeviceKernelLaunch(), + ] + )(mod) + + fdevice = mod["test_kernel"] + + assert fdevice.attrs["global_symbol"] == "test_kernel" assert fdevice.attrs["calling_conv"].value == 2 - assert fdevice.attrs["target"] == cuda_target + assert str(fdevice.attrs["target"]) == str(tvm.target.Target("cuda")) assert fdevice.attrs["tir.is_global_func"].value @@ -60,18 +69,104 @@ def test_ssa_across_entire_module(): class before: @T.prim_func def main(): - T.func_attr({"global_symbol": "main", "target": T.target("cuda")}) + T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) for i in range(16): T.attr(0, "device_scope", 0) for j in range(16): T.evaluate(i) - after = tvm.tir.transform.SplitHostDevice()(before) + after = tvm.ir.transform.Sequential( + [ + tvm.tir.transform.AnnotateDeviceRegions(), + tvm.tir.transform.SplitHostDevice(), + tvm.tir.transform.LowerDeviceKernelLaunch(), + ] + )(before) loop_var = after["main"].body.loop_var - param_var = 
after["main_kernel0"].params[0] + param_var = after["main_kernel"].params[0] assert not loop_var.same_as(param_var) +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.tir.transform.SplitHostDevice() + + +class TestSplitHostDevice(BaseCompare): + """SplitHostDevice divides a function at the "target" attribute""" + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("cuda", host="llvm -opt-level=0")}) + T.attr(T.target("cuda"), "target", 0) + T.evaluate(n) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("llvm -opt-level=0")}) + mod.main_kernel(n) + + @T.prim_func + def main_kernel(n: T.int32): + T.func_attr( + { + "target": T.target("cuda"), + "tir.noalias": T.bool(True), + "tir.is_global_func": True, + } + ) + T.evaluate(n) + + return mod + + +class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare): + """Like TestSplitHostDevice, but no host specified in the host's target + + The `T.attr` specifying the device still requires splitting out + the kernel. 
+ """ + + def before(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("llvm")}) + T.attr(T.target("cuda"), "target", 0) + T.evaluate(n) + + return mod + + def expected(self): + @I.ir_module + class mod: + @T.prim_func + def main(n: T.int32): + T.func_attr({"target": T.target("llvm")}) + mod.main_kernel(n) + + @T.prim_func + def main_kernel(n: T.int32): + T.func_attr( + { + "target": T.target("cuda"), + "tir.noalias": T.bool(True), + "tir.is_global_func": True, + } + ) + T.evaluate(n) + + return mod + + if __name__ == "__main__": - test_split_host_device_func_attr() + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index eb578a8817b5..57ea223cf984 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -24,12 +24,13 @@ def run_passes(func: tvm.tir.PrimFunc): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.StorageFlatten(64)(mod) - cuda_target = tvm.target.Target("cuda") + cuda_target = tvm.target.Target("cuda", host="llvm") mod = tvm.tir.transform.Apply( lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) )(mod) + mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) mod = tvm.tir.transform.SplitHostDevice()(mod) return tvm.tir.transform.ThreadSync("shared")(mod) @@ -55,7 +56,7 @@ def test_thread_storage_sync(): func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) mod = run_passes(func) - f = mod["test_kernel0"] + f = mod["test_kernel"] body_list = tvm.tir.stmt_list(f.body.body.body) assert body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))