[BYOC] Handle constants in IRModule-at-a-time external codegen
I tried to do to the TensorRT integration what #11631 did to the CUTLASS integration, viz:
 - Make sure all compilation options are passed in Target instances. This helps Collage.
 - Use a custom pass invoked via RelayToTIRTargetHooks instead of the relay.ext.$toolchain mechanism.
   This helps us decouple external codegen from lowering.

This PR collects the prep for that change:
 - TensorRT uses the JSONSerializer visitor to encode each partition function. Previously, when the
   visitor encountered a Constant it simply generated and recorded a name for the constant. Then,
   completely separately, and via a callback in TECompiler, the function is visited again in the
   same order and with the same name generation convention by a ConstantUpdater to actually collect the
   bindings, which are then encoded into a ConstLoaderModule to be made available at runtime.

   However if all TensorRT compilation is to be done by a stand-alone pass there's no TECompiler callback
   hackery available. So I've added a "const_name_to_constant" attribute to the IRModule of type
   Map<String, runtime::NDArray> so that named constants can be accumulated throughout compilation by
   any pass which needs to do so. Then the Graph, AOT and VM executors are all updated to merge those
   constants into the final runtime artifact.

   (Compare with "Constants", the equivalent attribute for extracting TIR AllocateConsts.)
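The merge step described above (executors folding externally accumulated constants together with constants extracted during lowering) can be sketched in a few lines. This is an illustrative Python model, not the TVM API; the attribute name "const_name_to_constant" comes from the diff below, while the constant names and values are made up:

```python
def merge_constants(const_name_to_constant, lowered_params):
    """Combine constants captured by external codegen with constants
    extracted during lowering, refusing silent name collisions
    (the analog of the ICHECKs in the AOT codegen change below)."""
    merged = dict(const_name_to_constant)
    for name, value in lowered_params.items():
        assert name not in merged, f"duplicate constant name: {name}"
        merged[name] = value
    return merged

# External codegen contributed two named constants; lowering added one more.
external = {"trt_0_const_0": [1.0, 2.0], "trt_0_const_1": [3.0]}
lowered = {"p0": [4.0, 5.0]}
table = merge_constants(external, lowered)
```

The point of the single module-level map is that any pass can contribute to it, and the final executor codegen only has to read one attribute.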

 - The TensorRT tests use the create_executor interface, but it wasn't quite ready for the
   new, more general form of passing a list of targets.

 - I want TensorRT compilation to work out of the box, without the need for any special targets,
   when all the default options apply. This PR also goes back and makes the CUTLASS integration
   follow the same convention.

 - To test this I also switched the 'demo' "ccompiler" external codegen target to IRModule-at-a-time
   style. This means we can test most of the external codegen machinery in one place, without
   depending on any target which may not be enabled in CI (e.g. TensorRT):
     - Target instances are plumbed correctly so compile-time options are available.
     - External modules are conveyed to the final export library.
     - Constant bindings are conveyed to the metadata module.
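As a rough picture of what "IRModule-at-a-time style" means for an external codegen pass, here is a stub-only Python sketch. The attribute names "Compiler", "Extern" and "external_mods" appear in the changes in this commit; the classes and the compile step are stand-ins, not TVM code:

```python
class Func:
    def __init__(self, name, attrs=None):
        self.name = name
        self.attrs = dict(attrs or {})

class Module:
    def __init__(self, functions):
        self.functions = functions  # list of Func
        self.attrs = {}             # module-level attributes

def ccompiler_pass(mod):
    # IRModule-at-a-time: one pass sees the whole module, compiles every
    # partition function marked Compiler="ccompiler", and records the
    # resulting artifacts in the "external_mods" module attribute.
    compiled = []
    for fn in mod.functions:
        if fn.attrs.get("Compiler") == "ccompiler":
            compiled.append(f"runtime_module:{fn.name}")  # stand-in artifact
            fn.attrs["Extern"] = True  # mark as handled so lowering skips it
    mod.attrs["external_mods"] = mod.attrs.get("external_mods", []) + compiled
    return mod

mod = Module([Func("main"), Func("ccompiler_0", {"Compiler": "ccompiler"})])
mod = ccompiler_pass(mod)
```

Contrast this with the relay.ext.$toolchain mechanism, where the compiler is handed one partition function at a time from inside TECompiler and has no module-level place to leave its outputs.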
mbs-octoml committed Jun 29, 2022
1 parent c9d0d25 commit ea69f71
Showing 34 changed files with 571 additions and 262 deletions.
2 changes: 1 addition & 1 deletion cmake/modules/contrib/CODEGENC.cmake
Original file line number Diff line number Diff line change
@@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.

tvm_file_glob(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/codegen.cc)
tvm_file_glob(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/*.cc)
list(APPEND COMPILER_SRCS ${CSOURCE_RELAY_CONTRIB_SRC})

30 changes: 27 additions & 3 deletions include/tvm/ir/module.h
@@ -479,8 +479,10 @@ TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,

namespace attr {

// Following are attributes for IRModule only.

/*!
* \brief Executor targetted by the module
* \brief Executor targeted by the module
*
* Type: Executor
*
@@ -516,9 +518,31 @@ constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools";
constexpr const char* kConstantMemoryPools = "constant_memory_pools";

/*
* \brief Module attribute for tir constants
* \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
* node will record the index into this array. See also kConstNameToConstant below, which is
* the analog for Relay Functions.
*
* Type: Array<runtime::NDArray>
*/
constexpr const char* kConstants = "constants";

/*!
* \brief All the runtime::Modules accumulated during compilation by external codegen. These
* modules must be either directly linked or captured in the final compilation artifact.
*
* Type: Array<runtime::Module>
*/
constexpr const char* kExternalMods = "external_mods";

/*!
* \brief All the named runtime::NDArrays accumulated during compilation by external codegen.
* Generally the associated runtime::Module will indicate it requires bindings for these names,
* and during module initialization these bindings will be recovered from a ConstLoaderModule.
* See also kConstants above, which is the analog for PrimFuncs.
*
* Type: Map<String, runtime::NDArray>
*/
constexpr const char* kConstantsArray = "Constants";
constexpr const char* kConstNameToConstant = "const_name_to_constant";

} // namespace attr
} // namespace tvm
6 changes: 3 additions & 3 deletions include/tvm/tir/stmt.h
@@ -599,9 +599,9 @@ class AllocateConstNode : public StmtNode {
/*! \brief The optional data associated to the constant.
*/
Optional<runtime::NDArray> data;
/*! \brief If the PrimFunc containing the Stmt is added to IRModule,
this is an optional index to indicate the index within
"Constants" attribute, that is a Array<NDArray> of IRModule.
/*!
* \brief If the PrimFunc containing the Stmt is added to an IRModule, this is an optional index
* into the "constants" attribute, which is an Array<NDArray> of the IRModule.
*/
Optional<Integer> irmod_storage_idx;
/*! \brief The type of the buffer. */
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
@@ -195,7 +195,7 @@ class Interpreter(Executor):
The runtime device to run the code on.
target : tvm.Target
The target option to build the function.
The target option to build the function. Only homogeneous execution is supported.
CAUTION: Despite the API the module is prepared upon each call to evaluate
rather than once in create_executor.
5 changes: 3 additions & 2 deletions python/tvm/relay/backend/vm.py
@@ -198,8 +198,9 @@ class VMExecutor(Executor):
device : :py:class:`~tvm.runtime.Device`
The runtime device to run the code on.
target : :py:class:`Target`
The target option to build the function.
target : any multi-target-like object; see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
"""

def __init__(self, mod, device, target):
41 changes: 26 additions & 15 deletions python/tvm/relay/build_module.py
@@ -570,8 +570,9 @@ class GraphExecutor(_interpreter.Executor):
device : :py:class:`Device`
The runtime device to run the code on.
target : :py:class:`Target`
The target option to build the function.
target : any multi-target-like object; see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
"""

def __init__(self, mod, device, target):
@@ -630,16 +631,16 @@ class AotExecutor(_interpreter.Executor):
device : :py:class:`Device`
The runtime device to run the code on.
target : :py:class:`Target`
The target option to build the function.
target : any multi-target-like object; see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
"""

def __init__(self, mod, device, target):
assert mod is not None
self.mod = mod
self.device = device
self.target = target
assert target.attrs.get("executor", "graph") == "aot"

def _make_executor(self, expr=None):
if expr:
@@ -719,8 +720,11 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
device : :py:class:`Device`
The device to execute the code.
target : :py:class:`tvm.Target`
The corresponding context
target : any multi-target-like object; see Target.canon_multi_target
For homogeneous compilation, the unique build target.
For heterogeneous compilation, a dictionary or list of possible build targets.
CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so
heterogeneous compilation is not yet supported.
params : dict of str to NDArray
Input parameters to the graph that do not change
@@ -730,24 +734,31 @@
-------
executor : :py:class:`~tvm.relay.backend.interpreter.Executor`
"""
raw_targets = Target.canon_multi_target(target)
if mod is None:
mod = IRModule()
if device is not None:
assert device.device_type == _nd.device(str(target), 0).device_type
assert device.device_type == raw_targets[0].kind.device_type
else:
device = _nd.device(str(target), 0)
# Derive the default device from the first target.
device = _nd.device(raw_targets[0].kind.device_type, 0)

if params is not None:
mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))

if isinstance(target, str):
target = Target(target)
assert "executor" not in raw_targets[0].attrs or raw_targets[0].attrs["executor"] == kind

if kind == "debug":
return _interpreter.Interpreter(mod, device, target)
assert len(raw_targets) == 1, "The interpreter currently only supports a single target"
return _interpreter.Interpreter(mod, device, raw_targets[0])
if kind == "graph":
return GraphExecutor(mod, device, target)
return GraphExecutor(mod, device, raw_targets)
if kind == "vm":
return VMExecutor(mod, device, target)
return VMExecutor(mod, device, raw_targets)
if kind == "aot":
return AotExecutor(mod, device, target)
# The AOT path requires the 'executor' attribute on the target.
# (The compilation paths for the other executors currently do not always provide this
# attribute, hence the above generic assert is more forgiving).
assert "executor" in raw_targets[0].attrs
return AotExecutor(mod, device, raw_targets)
raise RuntimeError("unknown execution strategy: {0}".format(kind))
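The create_executor hunk above replaces ad-hoc string handling with Target.canon_multi_target, which normalizes whatever the caller passed (a string, a Target, or a list of either) into a list of Target instances before dispatching to an executor. A self-contained sketch of that normalization, using a stub Target class rather than tvm.target.Target (the real helper also accepts device-to-target dictionaries, omitted here):

```python
class Target:
    def __init__(self, kind):
        self.kind = kind  # e.g. "llvm", "cuda"

def canon_multi_target(target):
    # Always return a non-empty list of Target instances, whatever the
    # caller provided, so downstream code has a single form to handle.
    if isinstance(target, Target):
        return [target]
    if isinstance(target, str):
        return [Target(target)]
    if isinstance(target, (list, tuple)) and target:
        return [t if isinstance(t, Target) else Target(t) for t in target]
    raise ValueError(f"cannot canonicalize target: {target!r}")

raw_targets = canon_multi_target("llvm")       # homogeneous case
multi = canon_multi_target(["llvm", "cuda"])   # heterogeneous case
```

With the list form, raw_targets[0] gives a well-defined "first" target from which the default device can be derived, as the hunk does.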
4 changes: 2 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -1386,7 +1386,7 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""):
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
If non-empty, the "Compiler" attribute to filter on.
Returns
-------
@@ -1412,7 +1412,7 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""):
Parameters
----------
compiler_filter : String
If non-empty, the 'compiler' attribute to filter on.
If non-empty, the "Compiler" attribute to filter on.
Returns
-------
2 changes: 1 addition & 1 deletion python/tvm/tir/stmt.py
@@ -358,7 +358,7 @@ class AllocateConst(Stmt):
data_or_idx : Union[NDArray, int]
If an NDArray, this is the const data associated with the
constant. If an integer, this is the index into the
"Constants" attribute of the `IRModule` that contains the
"constants" attribute of the `IRModule` that contains the
`AllocateConst`.
body : Stmt
39 changes: 17 additions & 22 deletions src/relay/backend/aot_executor_codegen.cc
@@ -1167,11 +1167,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
// to run the LegalizePackedCalls pass.
LoweredOutput ret;
ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
for (auto param : params_) {
ret.params.emplace(std::make_pair(
param.first,
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));

// Collect any constants extracted by external codegen.
ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
Map<String, runtime::NDArray> const_name_to_constant =
lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
.value_or({});
for (const auto& kv : const_name_to_constant) {
ICHECK(ret.params.emplace(kv.first, kv.second).second);
}

// Collect any constants extracted during lowering.
for (const auto& kv : params_) {
ICHECK(ret.params.emplace(kv.first, kv.second).second);
}

// AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
@@ -1212,9 +1220,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
lowered_mod = pack_calls(lowered_mod);
}

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";
// Collect any runtime modules generated by external codegen.
ret.external_mods =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods).value_or({});

// This is the point where we separate the functions in the module by target
VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
@@ -1227,8 +1235,6 @@
<< PrettyPrint(kv.second);
}

ret.external_mods = external_modules.value();

// Extract USMP metadata to pass onto metadata sources
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
std::vector<tir::Var> pool_vars;
@@ -1316,11 +1322,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
String key = args[0];
*rv = get_param_by_name(key);
});
} else if (name == "get_param_id") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
String key = args[0];
*rv = get_param_id(key);
});
} else if (name == "get_irmodule") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); });
@@ -1362,17 +1363,11 @@
runtime::NDArray get_param_by_name(String key) {
auto it = this->output_.params.find(key);
CHECK(it != this->output_.params.end()) << "no such parameter " << key;
return (*it).second.second;
return (*it).second;
}

Array<tvm::runtime::Module> get_external_modules() { return output_.external_mods; }

int get_param_id(String key) {
auto it = this->output_.params.find(key);
CHECK(it != this->output_.params.end()) << "no such parameter " << key;
return (*it).second.first;
}

Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

std::shared_ptr<AOTExecutorCodegen> codegen_;
12 changes: 1 addition & 11 deletions src/relay/backend/build_module.cc
@@ -86,17 +86,6 @@ struct ExecutorCodegen {
return ret;
}

std::unordered_map<std::string, int64_t> GetParamIds() {
std::unordered_map<std::string, int64_t> ret;
auto names = CallFunc<Array<runtime::String>>("list_params_name", nullptr);
for (const auto& expr : names) {
// Implicit cast from runtime::String to std::string
std::string key = expr;
ret[key] = CallFunc<int64_t>("get_param_id", key);
}
return ret;
}

Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}
@@ -478,6 +467,7 @@ class RelayBuildModule : public runtime::ModuleNode {
for (size_t i = 0; i < variables.size(); i++) {
auto it = ret_.params.find(variables[i].operator std::string());
if (it != ret_.params.end()) {
VLOG(1) << "constant '" << variables[i] << "' has been captured in external module";
ret_.params.erase(it);
}
}
9 changes: 7 additions & 2 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
@@ -392,10 +392,15 @@ runtime::Module ACLCompiler(const ObjectRef& ref) {
ACLJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto param_names = serializer.GetParams();

// Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
// a callback which calls backend::UpdateConstants to capture the map before the function
// 'disappears' into lowered form, on the assumption the visit order and thus constant
// names match those generated by the JSONSerializer.

const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
runtime::Module lib = (*pf)(func_name, graph_json, param_names);
runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names());
return lib;
}

8 changes: 6 additions & 2 deletions src/relay/backend/contrib/bnns/codegen.cc
@@ -136,11 +136,15 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) {
BNNSJSONSerializer serializer(func_name, func);
serializer.serialize();
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();

// Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
// a callback which calls backend::UpdateConstants to capture the map before the function
// 'disappears' into lowered form, on the assumption the visit order and thus constant
// names match those generated by the JSONSerializer.

const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate");
ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
auto mod = (*pf)(func_name, graph_json, params);
auto mod = (*pf)(func_name, graph_json, serializer.const_names());
return mod;
}
