Skip to content

Commit

Permalink
Better detection of host target
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Jun 13, 2023
1 parent c35a8b5 commit 03c39ef
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
20 changes: 16 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,17 +473,29 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
}

auto host_target = [&]() -> Target {
// All targets that contain a kIsEntryFunc=True function
Array<Target> targets_with_entry_func;

// All targets that can run on the CPU and contain at least one
// function without kIsEntryFunc=False.
Array<Target> cpu_targets;
for (const auto& [target, mod] : split) {
bool contains_entry_func = std::any_of(
mod->functions.begin(), mod->functions.end(),
[](const auto& kv) { return kv.second->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc); });
bool contains_entry_func = false;
bool may_contain_entry_func = false;
for (const auto& [gvar, func] : mod->functions) {
Optional<Bool> is_entry_func = func->attrs.GetAttr<Bool>(tvm::tir::attr::kIsEntryFunc);
if (is_entry_func.defined() && is_entry_func.value()->value) {
contains_entry_func = true;
} else if (!is_entry_func.defined()) {
may_contain_entry_func = true;
}
}

if (contains_entry_func) {
targets_with_entry_func.push_back(target);
}

if (target->HasKey("cpu")) {
if (may_contain_entry_func && target->HasKey("cpu")) {
cpu_targets.push_back(target);
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/tir/transforms/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class HostDeviceSplitter : public StmtMutator {
PrimFunc device_func(params, body);
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
{tir::attr::kNoAlias, Bool(true)},
{tir::attr::kIsGlobalFunc, Bool(true)}});
{tir::attr::kIsGlobalFunc, Bool(true)},
{tir::attr::kIsEntryFunc, Bool(false)}});

(*device_mod_)->Add(kernel_symbol_global, device_func);
Array<PrimExpr> args = params.Map([](const Var& var) -> PrimExpr { return var; });
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_tir_transform_split_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def main_kernel(n: T.int32):
"target": T.target("cuda"),
"tir.noalias": T.bool(True),
"tir.is_global_func": True,
"tir.is_entry_func": False,
}
)
T.evaluate(n)
Expand Down Expand Up @@ -162,6 +163,7 @@ def main_kernel(n: T.int32):
"target": T.target("cuda"),
"tir.noalias": T.bool(True),
"tir.is_global_func": True,
"tir.is_entry_func": False,
}
)
T.evaluate(n)
Expand Down

0 comments on commit 03c39ef

Please sign in to comment.