Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Driver] Single-module lowering flow in driver_api.cc #14985

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_ext_dev():
def check_llvm():
if not tvm.testing.device_enabled("llvm"):
return
f = tvm.build(s, [A, B], "ext_dev", "llvm")
f = tvm.build(s, [A, B], "ext_dev", "ext_dev")
dev = tvm.ext_dev(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ using tvm::transform::Pass;
* \param target The device Target.
* \return The composite Pass for the fused module.
*/
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod,
Optional<Target> target = NullOpt);

/*!
* \brief Configures and returns the composite Pass for the device Target after device/host from
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
primfunc = tir_mod["main"]
primfunc = primfunc.with_attr("global_symbol", func.attrs["global_symbol"])
primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
primfunc = primfunc.with_attr("target", tvm.target.Target(compiler_name))
primfunc = primfunc.with_attr(
"target", tvm.target.Target(compiler_name, host=compiler_name)
)
return primfunc

def __call__(self, *args, **kwargs):
Expand Down
232 changes: 147 additions & 85 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
return pass_list;
}

/*!
 * \brief Apply a list of passes to a module.
 *
 * The passes are wrapped in a single tvm::transform::Sequential, which is
 * then invoked on the module.
 *
 * \param mod The module to transform.
 * \param pass_list The passes handed to the Sequential.
 * \return The module produced by the Sequential.
 */
IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
  tvm::transform::Sequential pipeline(pass_list);
  return pipeline(std::move(mod));
}

/*!
 * \brief Invoke an already-constructed pass sequence on a module.
 * \param mod The module to transform.
 * \param seq The pass sequence to invoke.
 * \return The module produced by the sequence.
 */
IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
  return seq(std::move(mod));
}

// Convert te schedule to IRModule
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
Expand Down Expand Up @@ -343,7 +332,8 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")

IRModule LowerModule(IRModule mod, bool simple_mode) {
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
tvm::transform::Sequential optimize(pass_list, "tvm.lower");
return optimize(std::move(mod));
}

TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) {
Expand All @@ -360,10 +350,7 @@ IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
}
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));

// Get the pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(std::move(mod), pass_list);
return LowerModule(mod, simple_mode);
}

TVM_REGISTER_GLOBAL("driver.lower_primfunc")
Expand All @@ -385,9 +372,7 @@ IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply, bool simple_mode) {
IRModule mod = ScheduleToModule(std::move(sch), args, name, binds, global_var_supply);
// Get the legacy TE pass list
Array<transform::Pass> pass_list = CreatePassList(simple_mode);
return LowerWithPassList(mod, pass_list);
return LowerModule(mod, simple_mode);
}

TVM_REGISTER_GLOBAL("driver.lower_schedule")
Expand All @@ -403,35 +388,42 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode);
});

/**
 * This function takes the input module that contains both the device and host opts.
 * It first applies transformations to the original module, then splits it into separate
 * modules for device and host, and finally applies transformations to each of the split modules.
 */
std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);

ICHECK(mod_mixed.defined()) << "This module must be defined";
IRModule MergeModules(const Map<Target, IRModule>& inputs) {
if (inputs.size() == 1) {
auto [target, mod] = *inputs.begin();
return tir::transform::BindTarget(target)(mod);
}

mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
// Take the attrs from the first module so the eventual modules have them.
IRModule first_module = (*inputs.begin()).second;
IRModule merged = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
for (auto [target, mod] : inputs) {
mod = tir::transform::BindTarget(target)(mod);
merged->Update(mod);
}

IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));
return merged;
}

auto keys = target->GetKeys();
Map<Target, IRModule> SplitModule(const IRModule& module) {
Map<String, IRModule> split;

CheckAndUpdateHostConsistency(&target, &target_host);
for (auto [gvar, base_func] : module->functions) {
auto target_str = base_func->GetAttr<Target>(tvm::attr::kTarget).value()->str();
if (auto it = split.find(target_str); it != split.end()) {
(*it).second->Add(gvar, base_func);
} else {
split.Set(target_str, IRModule({{gvar, base_func}}, {}, {}, {}, module->attrs));
}
}

bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && device_mod->functions.size() == 0) {
DLOG(WARNING) << "Specified target " << target->str()
<< " but cannot find device code. Did you forget to bind?";
Map<Target, IRModule> out;
for (auto [str, mod] : split) {
out.Set(Target(str), mod);
}

return {host_mod, device_mod};
return out;
}

/*!
Expand Down Expand Up @@ -478,52 +470,86 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
// Update target host for all targets
CheckAndUpdateHostConsistency(&inputs, &target_host);

// Take the attrs from the first module so the eventual modules have them.
// Ideally this would just be one unified module all the way through;
IRModule first_module = (*inputs.begin()).second;
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

ICHECK(mhost_all.defined()) << "The host module must be defined";

for (const auto& it : inputs) {
if (it.second.defined()) {
const Target& target = it.first;
const IRModule& ir_module = it.second;
auto pair = SplitMixedModule(ir_module, target, target_host);
auto& host_mod = pair.first;
auto& device_mod = pair.second;

ICHECK(host_mod.defined()) << "The split host module must be defined";

ICHECK(mhost_all.defined()) << "The host module must be defined";

// We don't want library modules going back into host codegen
// unless they're supposed to. Here if we overrode the target host
// to allow lowering previously we check that it's meant to be placed
// back into the host Module.
bool overrides_host_target =
target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
bool non_host_target_kind = target->kind != target_host->kind;
if (overrides_host_target && non_host_target_kind) {
device_modules.push_back(codegen::Build(host_mod, it.first));
} else {
mhost_all->Update(host_mod);
auto has_gpu_function = [](const IRModule& mod) -> bool {
for (const auto& [gvar, func] : mod->functions) {
if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
if (target.value()->HasKey("gpu")) {
return true;
}
}
}
return false;
};

IRModule merged = MergeModules(inputs);

bool contains_gpu_function_pre = has_gpu_function(merged);
merged = MixedModulePassManager(merged)(merged);
bool contains_gpu_function_post = has_gpu_function(merged);
if (contains_gpu_function_pre && !contains_gpu_function_post) {
DLOG(WARNING) << "Specified GPU targets, "
<< "but cannot find device code. Did you forget to bind?";
}

Map<Target, IRModule> split = SplitModule(merged);

if (device_mod->functions.size() != 0) {
device_modules.push_back(codegen::Build(device_mod, it.first));
Map<Target, runtime::Module> built;
for (const auto& [target, mod] : split) {
built.Set(target, codegen::Build(mod, target));
}

auto host_target = [&]() -> Target {
// All targets that contain a kIsEntryFunc=True function
Array<Target> targets_with_entry_func;

// All targets that can run on the CPU and contain at least one
// function that carries no kIsEntryFunc annotation at all.
Array<Target> cpu_targets;
for (const auto& [target, mod] : split) {
bool contains_entry_func = false;
bool may_contain_entry_func = false;
for (const auto& [gvar, func] : mod->functions) {
Optional<Bool> is_entry_func = func->attrs.GetAttr<Bool>(tvm::tir::attr::kIsEntryFunc);
if (is_entry_func.defined() && is_entry_func.value()->value) {
contains_entry_func = true;
} else if (!is_entry_func.defined()) {
may_contain_entry_func = true;
}
}

if (contains_entry_func) {
targets_with_entry_func.push_back(target);
}

if (may_contain_entry_func && target->HasKey("cpu")) {
cpu_targets.push_back(target);
}
}
}

runtime::Module mhost = codegen::Build(mhost_all, target_host);
for (const auto& it : device_modules) {
if (it.operator->()) {
mhost.Import(it);
if (targets_with_entry_func.size()) {
ICHECK_EQ(targets_with_entry_func.size(), 1)
<< "Expected at most one function "
<< "annotated with tvm::tir::attr::kIsEntryFunc "
<< "(\"" << tvm::tir::attr::kIsEntryFunc << "\"), "
<< "but found: " << targets_with_entry_func;
return targets_with_entry_func[0];
} else if (cpu_targets.size() == 1) {
return cpu_targets[0];
} else {
LOG(FATAL) << "Could not determine which target is the host. "
<< "No function was annotated with tvm::tir::attr::kIsEntryFunc (\""
<< tvm::tir::attr::kIsEntryFunc << "\"), "
<< "and " << cpu_targets.size() << " targets have the 'cpu' key";
}
}();

auto runtime_module = built[host_target];
for (const auto& [target, mod] : built) {
if (!mod.same_as(runtime_module)) {
runtime_module.Import(mod);
}
}

return mhost;
return runtime_module;
}

TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
Expand Down Expand Up @@ -564,13 +590,16 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Optional<Target> target) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Array<Pass> mixed_pass_list;

// FPComputeLegalize uses the target attrs added by BindTarget, so it must come first
mixed_pass_list.push_back(tir::transform::BindTarget(target));
// FP8ComputeLegalize uses the target attrs added by BindTarget, so
// BindTarget must come first.
if (target) {
mixed_pass_list.push_back(tir::transform::BindTarget(target.value()));
}
mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());

// VerifyVTCMLimit must occur before LowerVtcmAlloc
Expand Down Expand Up @@ -625,7 +654,32 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
// After the device kernels have been split into host/device
// sections, the host section can be inlined.
mixed_pass_list.push_back(tir::transform::InlinePrivateFunctions());

// Only applies to the device functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::LowerWarpMemory());

// Only applies to the host functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::LowerTVMBuiltin());

// Apply to both host and device functions
mixed_pass_list.push_back(tir::transform::Simplify());
mixed_pass_list.push_back(tir::transform::LowerCustomDatatypes());
mixed_pass_list.push_back(tir::transform::LowerIntrin());
mixed_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());

// Only applies to the host functions, identified by inspection of
// each function's tvm::attr::kTarget attribute.
mixed_pass_list.push_back(tir::transform::CombineContextCall());
if (pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value()) {
mixed_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(mixed_pass_list, "tvm.build");
}

TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
Expand All @@ -634,6 +688,10 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
});

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
LOG(WARNING) << "Use of driver.host_mod_passes is deprecated. "
<< "All lowering passes are now included "
<< "as part of driver.mixed_mod_passes.";

transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_debug = pass_ctx->GetConfig<Bool>("tir.enable_debug", Bool(false)).value();

Expand All @@ -659,7 +717,7 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho
host_pass_list.push_back(tir::transform::InstallDebugSpans());
}

return transform::Sequential(host_pass_list);
return transform::Sequential(host_pass_list, "tir.host_mod_passes");
}

TVM_REGISTER_GLOBAL("driver.host_mod_passes")
Expand All @@ -668,6 +726,10 @@ TVM_REGISTER_GLOBAL("driver.host_mod_passes")
});

transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) {
LOG(WARNING) << "Use of driver.device_mod_passes is deprecated. "
<< "All lowering passes are now included "
<< "as part of driver.mixed_mod_passes.";

Array<Pass> device_pass_list;
runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
Expand All @@ -683,7 +745,7 @@ transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target)
device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
device_pass_list.push_back(tir::transform::LowerIntrin());

return transform::Sequential(device_pass_list);
return transform::Sequential(device_pass_list, "tir.device_mod_passes");
}

TVM_REGISTER_GLOBAL("driver.device_mod_passes")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
explicit ConvertAddToSubtract(IRModule ir_module, Target host_target)
: ir_module_(ir_module),
host_target_(host_target),
custom_target_(Target("example_target_hook")) {}
custom_target_(Target(Target("example_target_hook"), Target("example_target_hook"))) {}

IRModule Mutate() {
GlobalVar main_global_var = ir_module_->GetGlobalVar("main");
Expand Down
16 changes: 10 additions & 6 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -618,12 +618,16 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
return nullptr;
}

TVM_REGISTER_GLOBAL("target.build.llvm")
.set_body_typed([](IRModule mod, Target target) -> runtime::Module {
auto n = make_object<LLVMModuleNode>();
n->Init(mod, target);
return runtime::Module(n);
});
namespace {
/*!
 * \brief Create an LLVMModuleNode, initialize it from the given module and
 *        target, and wrap it in a runtime::Module.
 * \param mod The IRModule passed to LLVMModuleNode::Init.
 * \param target The Target passed to LLVMModuleNode::Init.
 * \return The runtime::Module wrapping the initialized node.
 */
runtime::Module BuildLLVM(IRModule mod, Target target) {
  auto llvm_node = make_object<LLVMModuleNode>();
  llvm_node->Init(mod, target);
  return runtime::Module(llvm_node);
}
}  // namespace

TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed(BuildLLVM);
TVM_REGISTER_GLOBAL("target.build.ext_dev").set_body_typed(BuildLLVM);

TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
.set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {
Expand Down
Loading
Loading