Skip to content

Commit

Permalink
[Driver] Single-module lowering flow in driver_api.cc
Browse files Browse the repository at this point in the history
Prior to this commit, a build that used multiple targets needed to
provide `tvm::build` with a `Map<Target, IRModule>` specifying which
target should be used to compile each `IRModule`.  As a result,
lowering passes could not introduce new targets based on a PrimFunc's
content (e.g. a `with T.target()` frame to delegate out to another
device), nor simplify based on cross-device subroutines (e.g. simplify
a host-side conditional based on the known output of a device-side
internal subroutine).

This commit makes the `tvm::attr::kTarget` attribute (`"target"`) be
the single source of truth for where a `PrimFunc` will be executed.
Other existing methods for specifying the target (the `target`
parameter for `tvm.build`, the keys in a `Map<Target,IRModule>`, the
parameter to the pass `tir::transform::BindTarget`) are still accepted
as inputs, and may provide a default value for `tvm::attr::kTarget` if
the attribute is missing, but may not overwrite the target attribute.

This is part of a series of commits to simplify the handling of
multi-target builds.
  • Loading branch information
Lunderberg committed Mar 25, 2024
1 parent b2204ae commit 485bd3e
Show file tree
Hide file tree
Showing 21 changed files with 757 additions and 118 deletions.
2 changes: 1 addition & 1 deletion apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_ext_dev():
def check_llvm():
if not tvm.testing.device_enabled("llvm"):
return
f = tvm.build(s, [A, B], "ext_dev", "llvm")
f = tvm.build(s, [A, B], "ext_dev", "ext_dev")
dev = tvm.ext_dev(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ using tvm::transform::Pass;
* \param target The device Target.
* \return The composite Pass for the fused module.
// */
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod,
Optional<Target> target = NullOpt);

/*!
* \brief Configures and returns the composite Pass for the device Target after device/host from
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"])
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name))
primfunc = primfunc.with_attr(
"target", tvm.target.Target(compiler_name, host=compiler_name)
)
return primfunc

def __call__(self, *args, **kwargs):
Expand Down
228 changes: 143 additions & 85 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
return pass_list;
}

IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
auto optimize = tvm::transform::Sequential(pass_list);
mod = optimize(std::move(mod));
return mod;
}

IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
mod = seq(std::move(mod));
return mod;
}

// Convert te schedule to IRModule
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
Expand Down Expand Up @@ -342,7 +331,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")

IRModule LowerModule(IRModule mod, bool simple_mode) {
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
tvm::transform::Sequential optimize(pass_list, "tvm.lower");
return optimize(std::move(mod));
}

TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) {
Expand All @@ -359,10 +349,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
}
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));

// Get the pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
return LowerModule(mod, simple_mode);
}

TVM_REGISTER_GLOBAL("driver.lower_primfunc")
Expand All @@ -384,9 +371,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply, bool simple_mode) {
IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply);
// Get the legacy TE pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(mod, pass_list);
return LowerModule(mod, simple_mode);
}

TVM_REGISTER_GLOBAL("driver.lower_schedule")
Expand All @@ -403,35 +388,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
simple_mode);
});

/**
* This function takes the input module that contains both the device and host opts.
* Then, it applies transformation on the original module before splitting into separate modules for
* device and host. Then it also applies transformations on the new splitted modules.
*/
std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);

ICHECK(mod_mixed.defined()) << "This module must be defined";
IRModule MergeModules(const Map<Target, IRModule>& inputs) {
if (inputs.size() == 1) {
auto [target, mod] = *inputs.begin();
return tir::transform::BindTarget(target)(mod);
}

mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
// Take the attrs from the first module so the eventual modules have them.
IRModule first_module = (*inputs.begin()).second;
IRModule merged = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
for (auto [target, mod] : inputs) {
mod = tir::transform::BindTarget(target)(mod);
merged->Update(mod);
}

IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));
return merged;
}

auto keys = target->GetKeys();
Map<Target, IRModule> SplitModule(const IRModule& module) {
Map<String, IRModule> split;

CheckAndUpdateHostConsistency(&target, &target_host);
for (auto [gvar, base_func] : module->functions) {
auto target_str = base_func->GetAttr<Target>(tvm::attr::kTarget).value()->str();
if (auto it = split.find(target_str); it != split.end()) {
(*it).second->Add(gvar, base_func);
} else {
split.Set(target_str, IRModule({{gvar, base_func}}, {}, {}, {}, module->attrs));
}
}

bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && device_mod->functions.size() == 0) {
DLOG(WARNING) << "Specified target " << target->str()
<< " but cannot find device code. Did you forget to bind?";
Map<Target, IRModule> out;
for (auto [str, mod] : split) {
out.Set(Target(str), mod);
}

return {host_mod, device_mod};
return out;
}

/*!
Expand Down Expand Up @@ -478,52 +470,86 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
// Update target host for all targets
CheckAndUpdateHostConsistency(&inputs, &target_host);

// Take the attrs from the first module so the eventual modules have them.
// Ideally this would just be one unified module all the way through;
IRModule first_module = (*inputs.begin()).second;
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

ICHECK(mhost_all.defined()) << "The host module must be defined";

for (const auto& it : inputs) {
if (it.second.defined()) {
const Target& target = it.first;
const IRModule& ir_module = it.second;
auto pair = SplitMixedModule(ir_module, target, target_host);
auto& host_mod = pair.first;
auto& device_mod = pair.second;

ICHECK(host_mod.defined()) << "The split host module must be defined";

ICHECK(mhost_all.defined()) << "The host module must be defined";

// We don't want library modules going back into host codegen
// unless they're supposed to. Here if we overrode the target host
// to allow lowering previously we check that it's meant to be placed
// back into the host Module.
bool overrides_host_target =
target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
device_modules.push_back(codegen::Build(host_mod, it.first));
} else {
mhost_all->Update(host_mod);
auto has_gpu_function = [](const IRModule& mod) -> bool {
for (const auto& [gvar, func] : mod->functions) {
if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
if (target.value()->HasKey("gpu")) {
return true;
}
}
}
return false;
};

IRModule merged = MergeModules(inputs);

bool contains_gpu_function_pre = has_gpu_function(merged);
merged = MixedModulePassManager(merged)(merged);
bool contains_gpu_function_post = has_gpu_function(merged);
if (contains_gpu_function_pre && !contains_gpu_function_post) {
DLOG(WARNING) << "Specified GPU targets, "
<< "but cannot find device code. Did you forget to bind?";
}

Map<Target, IRModule> split = SplitModule(merged);

if (device_mod->functions.size() != 0) {
device_modules.push_back(codegen::Build(device_mod, it.first));
Map<Target, runtime::Module> built;
for (const auto& [target, mod] : split) {
built.Set(target, codegen::Build(mod, target));
}

auto host_target = [&]() -> Target {
// All targets that contain a kIsEntryFunc=True function
Array<Target> targets_with_entry_func;

// All targets that can run on the CPU and contain at least one
// function without kIsEntryFunc=False.
Array<Target> cpu_targets;
for (const auto& [target, mod] : split) {
bool contains_entry_func = false;
bool may_contain_entry_func = false;
for (const auto& [gvar, func] : mod->functions) {
Optional<Bool> is_entry_func = func->attrs.GetAttr<Bool>(tvm::tir::attr::kIsEntryFunc);
if (is_entry_func.defined() && is_entry_func.value()->value) {
contains_entry_func = true;
} else if (!is_entry_func.defined()) {
may_contain_entry_func = true;
}
}

if (contains_entry_func) {
targets_with_entry_func.push_back(target);
}

if (may_contain_entry_func && target->HasKey("cpu")) {
cpu_targets.push_back(target);
}
}
}

runtime::Module mhost = codegen::Build(mhost_all, target_host);
for (const auto& it : device_modules) {
if (it.operator->()) {
mhost.Import(it);
if (targets_with_entry_func.size()) {
ICHECK_EQ(targets_with_entry_func.size(), 1)
<< "Expected at most one function "
<< "annotated with tvm::tir::attr::kIsEntryFunc "
<< "(\"" << tvm::tir::attr::kIsEntryFunc << "\"), "
<< "but found: " << targets_with_entry_func;
return targets_with_entry_func[0];
} else if (cpu_targets.size() == 1) {
return cpu_targets[0];
} else {
LOG(FATAL) << "Could not determine which target is the host. "
<< "No function was annotated with tvm::tir::attr::kIsEntryFunc (\""
<< tvm::tir::attr::kIsEntryFunc << "\"), "
<< "and " << cpu_targets.size() << " targets have the 'cpu' key";
}
}();

auto runtime_module = built[host_target];
for (const auto& [target, mod] : built) {
if (!mod.same_as(runtime_module)) {
runtime_module.Import(mod);
}
}

return mhost;
return runtime_module;
}

TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
Expand Down Expand Up @@ -564,13 +590,16 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional<Target> target) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Array<Pass> mixed_pass_list;

// FPComputeLegalize uses the target attrs added by BindTarget, so it must come first
mixed_pass_list.push_back(tir::transform::BindTarget(target));
// FPComputeLegalize uses the target attrs added by BindTarget, so
// BindTarget must come first.
if (target) {
mixed_pass_list.push_back(tir::transform::BindTarget(target.value()));
}
mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());

// VerifyVTCMLimit must occur before LowerVtcmAlloc
Expand Down Expand Up @@ -625,7 +654,28 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
// Only applies to the device functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::LowerWarpMemory());

// Only applies to the host functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::LowerTVMBuiltin());

// Apply to both host and device functions
mixed_pass_list.push_back(tir::transform::Simplify());
mixed_pass_list.push_back(tir::transform::LowerCustomDatatypes());
mixed_pass_list.push_back(tir::transform::LowerIntrin());
mixed_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());

// Only applies to the host functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::CombineContextCall());
if (pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value()) {
mixed_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(mixed_pass_list, "tvm.build");
}

TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
Expand All @@ -634,6 +684,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
});

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
LOG(WARNING) << "Use of driver.host_mod_passes is deprecated. "
<< "All lowering passes are now included "
<< "as part of driver.mixed_mod_passes.";

transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

Expand All @@ -659,7 +713,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
host_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(host_pass_list);
return transform::Sequential(host_pass_list, "tir.host_mod_passes");
}

TVM_REGISTER_GLOBAL("driver.host_mod_passes")
Expand All @@ -668,6 +722,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes")
});

transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
LOG(WARNING) << "Use of driver.device_mod_passes is deprecated. "
<< "All lowering passes are now included "
<< "as part of driver.mixed_mod_passes.";

Array<Pass> device_pass_list;
runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
Expand All @@ -683,7 +741,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target)
device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
device_pass_list.push_back(tir::transform::LowerIntrin());

return transform::Sequential(device_pass_list);
return transform::Sequential(device_pass_list, "tir.device_mod_passes");
}

TVM_REGISTER_GLOBAL("driver.device_mod_passes")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
: ir_module_(ir_module),
host_target_(host_target),
custom_target_(Target("example_target_hook")) {}
custom_target_(Target(Target("example_target_hook"), Target("example_target_hook"))) {}

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
Expand Down
16 changes: 10 additions & 6 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,12 +610,16 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
return nullptr;
}

TVM_REGISTER_GLOBAL("target.build.llvm")
.set_body_typed([](IRModule mod, Target target) -> runtime::Module {
auto n = make_object<LLVMModuleNode>();
n->Init(mod, target);
return runtime::Module(n);
});
namespace {
runtime::Module BuildLLVM(IRModule mod, Target target) {
auto n = make_object<LLVMModuleNode>();
n->Init(mod, target);
return runtime::Module(n);
}
} // namespace

TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed(BuildLLVM);
TVM_REGISTER_GLOBAL("target.build.ext_dev").set_body_typed(BuildLLVM);

TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
.set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {
Expand Down
Loading

0 comments on commit 485bd3e

Please sign in to comment.