Skip to content

Commit

Permalink
Allow "ext_dev" to act as host.
Browse files Browse the repository at this point in the history
Currently, the Target does two independent tasks: (1) defining which
device owns the buffers that are passed as input to a PrimFunc,
and (2) defining which codegen will be used for a PrimFunc.

Prior to this commit, the "ext_dev" target was required to define the
device ownership, but did not provide the `"target.build.ext_dev"`
function that is required for codegen.  This worked, because
`SplitHostDevice` would remove the `"ext_dev"` target without making a
device-side function.  With the single-module lowering flow, the
separate device-side function is required to support UMA codegen.

To resolve this issue, `"ext_dev"` now provides a codegen function,
which is identical to the LLVM codegen.  This may be improved in the
future by allowing the buffer device and the codegen to be specified
independently.
  • Loading branch information
Lunderberg committed Jun 16, 2023
1 parent 888130d commit d3ddd40
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion apps/extension/tests/test_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_ext_dev():
def check_llvm():
if not tvm.testing.device_enabled("llvm"):
return
f = tvm.build(s, [A, B], "ext_dev", "llvm")
f = tvm.build(s, [A, B], "ext_dev", "ext_dev")
dev = tvm.ext_dev(0)
# launch the kernel.
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
Expand Down
16 changes: 10 additions & 6 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,16 @@ void* LLVMModuleNode::GetFunctionAddr(const std::string& name,
}
}

TVM_REGISTER_GLOBAL("target.build.llvm")
.set_body_typed([](IRModule mod, Target target) -> runtime::Module {
auto n = make_object<LLVMModuleNode>();
n->Init(mod, target);
return runtime::Module(n);
});
namespace {
runtime::Module BuildLLVM(IRModule mod, Target target) {
auto n = make_object<LLVMModuleNode>();
n->Init(mod, target);
return runtime::Module(n);
}
} // namespace

TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed(BuildLLVM);
TVM_REGISTER_GLOBAL("target.build.ext_dev").set_body_typed(BuildLLVM);

TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
.set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module {
Expand Down
2 changes: 1 addition & 1 deletion src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break
.set_default_keys({"cpu"});

TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev);
TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev).set_default_keys({"cpu"});

TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU);

Expand Down

0 comments on commit d3ddd40

Please sign in to comment.