Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] Handle constants in IRModule-at-a-time external codegen #11770

Merged
merged 1 commit into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/modules/contrib/CODEGENC.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.

tvm_file_glob(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/codegen.cc)
tvm_file_glob(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/*.cc)
list(APPEND COMPILER_SRCS ${CSOURCE_RELAY_CONTRIB_SRC})

30 changes: 27 additions & 3 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,10 @@ TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,

namespace attr {

// Following are attributes for IRModule only.

/*!
* \brief Executor targetted by the module
* \brief Executor targeted by the module
*
* Type: Executor
*
Expand Down Expand Up @@ -516,9 +518,31 @@ constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";
constexpr const char* kConstantMemoryPools = "constant_memory_pools";

/*
* \brief Module attribute for tir constants
* \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
* node will record the index into this array. See also kConstNameToConstant below, which is
 * the analog for Relay Functions.
*
* Type: Array<runtime::NDArray>
*/
constexpr const char* kConstants = "constants";

/*!
* \brief All the runtime::Modules accumulated during compilation by external codegen. These
* modules must be either directly linked or captured in the final compilation artifact.
*
* Type: Array<runtime::Module>
*/
constexpr const char* kExternalMods = "external_mods";

/*!
* \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
* Generally the associated runtime::Module will indicate it requires bindings for these names,
* and during module initialization these bindings will be recovered from a ConstLoaderModule.
 * See also kConstants above, which is the analog for PrimFuncs.
*
* Type: Map<String, runtime::NDArray>
*/
constexpr const char* kConstantsArray = "Constants";
constexpr const char* kConstNameToConstant = "const_name_to_constant";

} // namespace attr
} // namespace tvm
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,9 @@ class AllocateConstNode : public StmtNode {
/*! \brief The optional data associated to the constant.
*/
Optional<runtime::NDArray> data;
/*! \brief If the PrimFunc containing the Stmt is added to IRModule,
this is an optional index to indicate the index within
"Constants" attribute, that is a Array<NDArray> of IRModule.
/*!
* \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index
 * to indicate the index within "constants" attribute, that is an Array<NDArray> of IRModule.
*/
Optional<Integer> irmod_storage_idx;
/*! \brief The type of the buffer. */
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class Interpreter(Executor):
The runtime device to run the code on.

target : tvm.Target
The target option to build the function.
The target option to build the function. Only homogeneous execution is supported.

CAUTION: Despite the API the module is prepared upon each call to evaluate
rather than once in create_executor.
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ class VMExecutor(Executor):
device : :py:class:`~tvm.runtime.Device`
The runtime device to run the code on.

target : :py:class:`Target`
The target option to build the function.
target : any multi-target like object, see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
"""

def __init__(self, mod, device, target):
Expand Down
41 changes: 26 additions & 15 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,9 @@ class GraphExecutor(_interpreter.Executor):
device : :py:class:`Device`
The runtime device to run the code on.

target : :py:class:`Target`
The target option to build the function.
target : any multi-target like object, see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
"""

def __init__(self, mod, device, target):
Expand Down Expand Up @@ -630,16 +631,16 @@ class AotExecutor(_interpreter.Executor):
device : :py:class:`Device`
The runtime device to run the code on.

target : :py:class:`Target`
The target option to build the function.
target : any multi-target like object, see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
"""

def __init__(self, mod, device, target):
assert mod is not None
self.mod = mod
self.device = device
self.target = target
assert target.attrs.get("executor", "graph") == "aot"

def _make_executor(self, expr=None):
if expr:
Expand Down Expand Up @@ -719,8 +720,11 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
device : :py:class:`Device`
The device to execute the code.

target : :py:class:`tvm.Target`
The corresponding context
target : any multi-target like object, see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so
heterogeneous compilation is not yet supported.

params : dict of str to NDArray
Input parameters to the graph that do not change
Expand All @@ -730,24 +734,31 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
-------
executor : :py:class:`~tvm.relay.backend.interpreter.Executor`
"""
raw_targets = Target.canon_multi_target(target)
mbs-octoml marked this conversation as resolved.
Show resolved Hide resolved
if mod is None:
mod = IRModule()
if device is not None:
assert device.device_type == _nd.device(str(target), 0).device_type
assert device.device_type == raw_targets[0].kind.device_type
else:
device = _nd.device(str(target), 0)
# Derive the default device from the first target.
device = _nd.device(raw_targets[0].kind.device_type, 0)

if params is not None:
mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))

if isinstance(target, str):
target = Target(target)
assert "executor" not in raw_targets[0].attrs or raw_targets[0].attrs["executor"] == kind

if kind == "debug":
return _interpreter.Interpreter(mod, device, target)
assert len(raw_targets) == 1, "The interpreter currently only supports a single target"
return _interpreter.Interpreter(mod, device, raw_targets[0])
if kind == "graph":
return GraphExecutor(mod, device, target)
return GraphExecutor(mod, device, raw_targets)
if kind == "vm":
return VMExecutor(mod, device, target)
return VMExecutor(mod, device, raw_targets)
if kind == "aot":
return AotExecutor(mod, device, target)
# The AOT requires the executor as a target attribute.
# (The compilation paths for the other executors currently do not always provide this
# attribute, hence the above generic assert is more forgiving).
assert "executor" in raw_targets[0].attrs
return AotExecutor(mod, device, raw_targets)
raise RuntimeError("unknown execution strategy: {0}".format(kind))
4 changes: 2 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,7 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
If non-empty, the "Compiler" attribute to filter on.

Returns
-------
Expand All @@ -1412,7 +1412,7 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""):
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
If non-empty, the "Compiler" attribute to filter on.

Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ class AllocateConst(Stmt):
data_or_idx : Union[NDArray, int]
If an NDArray, this is the const data associated with the
constant. If an integer, this is the index into the
"Constants" attribute of the `IRModule` that contains the
"constants" attribute of the `IRModule` that contains the
`AllocateConst`.

body : Stmt
Expand Down
39 changes: 17 additions & 22 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1167,11 +1167,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
// to run the LegalizePackedCalls pass.
LoweredOutput ret;
ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
for (auto param : params_) {
ret.params.emplace(std::make_pair(
param.first,
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));

// Collect any constants extracted by external codegen.
ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
Map<String, runtime::NDArray> const_name_to_constant =
lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
.value_or({});
for (const auto& kv : const_name_to_constant) {
ICHECK(ret.params.emplace(kv.first, kv.second).second);
}

// Collect any constants extracted during lowering.
for (const auto& kv : params_) {
ICHECK(ret.params.emplace(kv.first, kv.second).second);
}

// AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
Expand Down Expand Up @@ -1212,9 +1220,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
lowered_mod = pack_calls(lowered_mod);
}

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";
// Collect any runtime modules generated by external codegen.
ret.external_mods =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods).value_or({});

// This is the point where we separate the functions in the module by target
VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
Expand All @@ -1227,8 +1235,6 @@ class AOTExecutorCodegen : public MixedModeVisitor {
<< PrettyPrint(kv.second);
}

ret.external_mods = external_modules.value();

// Extract USMP metadata to pass onto metadata sources
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
std::vector<tir::Var> pool_vars;
Expand Down Expand Up @@ -1316,11 +1322,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
String key = args[0];
*rv = get_param_by_name(key);
});
} else if (name == "get_param_id") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
String key = args[0];
*rv = get_param_id(key);
});
} else if (name == "get_irmodule") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); });
Expand Down Expand Up @@ -1362,17 +1363,11 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
runtime::NDArray get_param_by_name(String key) {
auto it = this->output_.params.find(key);
CHECK(it != this->output_.params.end()) << "no such parameter " << key;
return (*it).second.second;
return (*it).second;
mbs-octoml marked this conversation as resolved.
Show resolved Hide resolved
}

Array<tvm::runtime::Module> get_external_modules() { return output_.external_mods; }

int get_param_id(String key) {
auto it = this->output_.params.find(key);
CHECK(it != this->output_.params.end()) << "no such parameter " << key;
return (*it).second.first;
}

Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

std::shared_ptr<AOTExecutorCodegen> codegen_;
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,6 @@ struct ExecutorCodegen {
return ret;
}

std::unordered_map<std::string, int64_t> GetParamIds() {
std::unordered_map<std::string, int64_t> ret;
auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
for (const auto& expr : names) {
// Implicit cast from runtime::String to std::string
std::string key = expr;
ret[key] = CallFunc<int64_t>("get_param_id", key);
}
return ret;
}

Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}
Expand Down Expand Up @@ -478,6 +467,7 @@ class RelayBuildModule : public runtime::ModuleNode {
for (size_t i = 0; i < variables.size(); i++) {
auto it = ret_.params.find(variables[i].operator std::string());
if (it != ret_.params.end()) {
VLOG(1) << "constant '" << variables[i] << "' has been captured in external module";
ret_.params.erase(it);
}
}
Expand Down
9 changes: 7 additions & 2 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,15 @@ runtime::Module ACLCompiler(const ObjectRef& ref) {
ACLJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto param_names = serializer.GetParams();

// Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
// a callback which calls backend::UpdateConstants to capture the map before the function
// 'disappears' into lowered form, on the assumption the visit order and thus constant
// names match those generated by the JSONSerializer.

const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
runtime::Module lib = (*pf)(func_name, graph_json, param_names);
runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names());
return lib;
}

Expand Down
8 changes: 6 additions & 2 deletions src/relay/backend/contrib/bnns/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,15 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) {
BNNSJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

// Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
// a callback which calls backend::UpdateConstants to capture the map before the function
// 'disappears' into lowered form, on the assumption the visit order and thus constant
// names match those generated by the JSONSerializer.

const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
auto mod = (*pf)(func_name, graph_json, params);
auto mod = (*pf)(func_name, graph_json, serializer.const_names());
return mod;
}

Expand Down
Loading