Skip to content

Commit

Permalink
Add LowerTEPass, and convert calls to LowerTE to application of Lower…
Browse files Browse the repository at this point in the history
…TEPass (apache#8802)

* Initial commit

Initial stab at IRModule -> LoweredModule conversion func, notes

Add external_mods and main_func_info to conversion funcs

Test lowered module to IR module

fix problem with conversion funcs + print stmts

Add LowerTE pass

Add LowerTEPass

Add LowerTEPass to graph_executor_codegen.cc

Use LowerTEPass instead of LowerTE in graph_executor_codegen.cc

Code cleanup

Add docs, more cleanup

Formatting

* Fix bad rebase

* Address 1st round of comments

* Use tir kTarget instead of relay one

* Change target string to Target obj

* removing target string causing issues

* Fix typos

* Revert target str -> target obj changes

* Don't use Update : IRModule because it is broken

* Fix check

* flaky test?

* lint
  • Loading branch information
electriclilies authored and ylc committed Jan 13, 2022
1 parent 6cdc2ff commit 52f06e9
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 15 deletions.
1 change: 0 additions & 1 deletion include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";

/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";
} // namespace attr
Expand Down
8 changes: 5 additions & 3 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,9 @@ class AOTExecutorCodegen : public ExprVisitor {
// to instead explicitly lowering the incoming IRModule, and then
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);
auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) {

IRModule new_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
Expand All @@ -599,8 +600,9 @@ class AOTExecutorCodegen : public ExprVisitor {
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto lowered_main = lowered_module.main_module->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
Expand Down
7 changes: 4 additions & 3 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
IRModule new_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
Expand All @@ -234,8 +234,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto main_module = lowered_module.main_module;
main_module = relay::transform::InferType()(main_module);
Expand Down
117 changes: 112 additions & 5 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
Expand Down Expand Up @@ -749,8 +750,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
relay_primfuncs);
}

// TODO(@electriclilies): Is the function passed in here relay_func??
// Also should this be inlined?
/*!
* \brief A function to create the function metadata for an input function (ie calculate buffer
* input/output sizes)
Expand Down Expand Up @@ -830,9 +829,6 @@ void UpdateFunctionMetadata(Function relay_func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule back into IRModule.
// Currently we rely on accumulating bindings inside the local TECompiler which we then
// host into the LoweredModule result.
LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
Expand Down Expand Up @@ -875,6 +871,117 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
return lowered_module;
}

/*!
 * \brief Flatten a LoweredModule into a single IRModule.
 *
 * Every function and type definition of the main module is copied over unchanged. Every
 * function found in a per-target module is annotated with tvm::attr::kTarget (set to the key
 * of that per-target module) before it is added. The external modules and the main function
 * info are attached to the result as the module attributes "external_mods" and
 * "main_func_info".
 *
 * \param mod The LoweredModule to convert.
 * \return The IRModule holding all functions of the input LoweredModule.
 */
IRModule LoweredModuleToIRModule(LoweredModule mod) {
  IRModule unified_module;

  // Copy the main module and its typedefs
  for (const auto& kv : mod.main_module->functions) {
    unified_module->Add(kv.first, kv.second);
  }
  for (const auto& kv : mod.main_module->type_definitions) {
    unified_module->AddTypeDef(kv.first, kv.second);
  }

  // Annotate the per-target functions with their target and add them to the unified module
  for (const auto& target_kv : mod.per_target_module) {
    const String target = target_kv.first;
    const IRModule target_module = target_kv.second;

    // Right now, per-target functions are TIR functions, which don't have type definitions, so
    // there should be no type defs in the per_target_modules
    size_t ty_def_size = target_module->type_definitions.size();
    ICHECK(ty_def_size == 0)
        << "Expected there to be no type definitions in the per_target_modules, but found "
        << ty_def_size;

    // Distinct name for the inner loop variable so it does not shadow the outer one.
    for (const auto& func_kv : target_module->functions) {
      const GlobalVar& var = func_kv.first;
      const BaseFunc& func = func_kv.second;
      // `func` is a const reference, so std::move on it would silently degrade to a copy.
      // Pass it directly to make the copy explicit.
      if (func->IsInstance<tir::PrimFuncNode>()) {
        tir::PrimFunc prim_func =
            WithAttr(Downcast<tir::PrimFunc>(func), tvm::attr::kTarget, target);
        unified_module->Add(var, prim_func);
      } else if (func->IsInstance<relay::FunctionNode>()) {
        relay::Function relay_func =
            WithAttr(Downcast<relay::Function>(func), tvm::attr::kTarget, target);
        unified_module->Add(var, relay_func);
      } else {
        LOG(FATAL)
            << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
            << func->GetTypeKey();
      }
    }
  }

  // Carry the non-function parts of the LoweredModule along as module attributes.
  IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
  return ret_mod;
}

/*!
 * \brief Split a unified IRModule back into a LoweredModule.
 *
 * Type definitions and Relay functions go into the main module. Each PrimFunc goes into the
 * per-target module keyed by its tvm::attr::kTarget attribute. The "external_mods" and
 * "main_func_info" module attributes are copied over when they are present.
 *
 * \param mod The IRModule to convert.
 * \return The LoweredModule form of the input IRModule.
 */
LoweredModule IRModuleToLoweredModule(IRModule mod) {
  // TIR has no type definitions, so all of them belong to the main module and this is the
  // only place where they need to be copied.
  IRModule relay_mod;
  for (const auto& type_def : mod->type_definitions) {
    relay_mod->AddTypeDef(type_def.first, type_def.second);
  }

  Map<String, IRModule> modules_by_target;
  for (const auto& entry : mod->functions) {
    const GlobalVar& gvar = entry.first;
    const BaseFunc& base_func = entry.second;

    // Relay functions stay in the main module.
    if (base_func->IsInstance<relay::FunctionNode>()) {
      relay_mod->Add(gvar, base_func);
      continue;
    }

    // PrimFuncs are grouped by the target they were annotated with.
    if (base_func->IsInstance<tir::PrimFuncNode>()) {
      Optional<String> target = base_func->GetAttr<String>(tvm::attr::kTarget);
      ICHECK(target) << "Target should be set at this point";

      if (modules_by_target.count(target.value())) {
        // A module for this target already exists; extend it in place.
        IRModule existing_module = modules_by_target.at(target.value());
        existing_module->Add(gvar, base_func);
      } else {
        // First function seen for this target: create its module and register it.
        IRModule fresh_module;
        fresh_module->Add(gvar, base_func);
        modules_by_target.Set(target.value(), fresh_module);
      }
      continue;
    }

    LOG(FATAL)
        << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
        << base_func->GetTypeKey();
  }

  // Assemble the result.
  LoweredModule result;
  result.main_module = relay_mod;
  result.per_target_module = modules_by_target;

  // The external modules and main function info are optional module attributes.
  if (auto ext_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods")) {
    result.external_mods = ext_mods.value();
  }
  if (auto main_info = mod->GetAttr<backend::FunctionInfo>("main_func_info")) {
    result.main_func_info = main_info.value();
  }
  return result;
}

/*!
 * \brief Create a module pass that runs LowerTE and returns its result as an IRModule.
 *
 * The pass calls LowerTE on the incoming module with the arguments given here and converts
 * the resulting LoweredModule into an IRModule via LoweredModuleToIRModule.
 *
 * \param targets The mapping for devices to targets.
 * \param device_context_map An analysis result mapping each sub-expression to a device.
 * \param memory_plan The memory plan used during lowering.
 * \param module_name The name of this module.
 * \param process_fn Callback forwarded to LowerTE.
 * \return The module pass named "LowerTE".
 */
Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
                 backend::StaticMemoryPlan memory_plan, const String& module_name,
                 std::function<void(Function)> process_fn) {
  // Capture by value: the returned pass is used after this function has returned.
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule module, PassContext ctx) {
        return LoweredModuleToIRModule(
            LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
      };
  // Lowering is required to produce a compilable module; it is not an optimization.
  // Register it at opt_level 0 so pipelines that filter passes by optimization level
  // never drop it. Invoking the pass directly is unaffected by this value.
  return tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {});
}
} // namespace tec
} // namespace relay
} // namespace tvm
49 changes: 46 additions & 3 deletions src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,30 +166,73 @@ void UpdateFunctionMetadata(Function relay_func,
/*!
 * \brief Obtain the Target from the device type.
 * If homogeneous compilation, this will return the only target.
 * If heterogeneous compilation, this will select the associated target using the
 * targets Map.
 *
 * \param dev_type The device type to obtain the target for.
 * \param targets The mapping for devices to targets.
 * \return The Target associated with dev_type.
 */
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);

/*! \brief Utility to convert a LoweredModule to an IRModule.
 *
 * This function takes all the target specific modules in LoweredModule and
 * annotates their functions with the correct target, and puts all those functions
 * in one IRModule together with the functions and type definitions of the main module.
 * The external modules and the main function info of the LoweredModule are attached to
 * the returned IRModule as the attributes "external_mods" and "main_func_info".
 * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase.
 *
 * \param mod The LoweredModule to convert.
 * \return The IRModule form of the input LoweredModule.
 */
IRModule LoweredModuleToIRModule(LoweredModule mod);

/*! \brief Utility to convert an IRModule to a LoweredModule.
 *
 * This function splits the functions in the IRModule: Relay functions and all type
 * definitions are placed in the LoweredModule's main module, while each PrimFunc is moved
 * into the target-specific IRModule selected by its target attribute (which must be set).
 * The "external_mods" and "main_func_info" attributes of the IRModule, if present, are
 * copied into the LoweredModule.
 * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase.
 *
 * \param mod The IRModule to convert.
 * \return The LoweredModule form of the input IRModule.
 */
LoweredModule IRModuleToLoweredModule(IRModule mod);

/*! \brief Lower an IRModule's primitive functions to TIR.
 *
 * This is the "back half" of the Relay compiler which lowers "primitive functions"
 * to TE expressions, schedules them, and then to TIR.
 *
 * \param module The IRModule.
 * \param targets The mapping for devices to targets.
 * \param device_map An analysis result mapping each sub-expression to a device.
 * \param memory_plan The memory plan used during lowering
 * \param module_name The name of this module
 * \param process_fn Callback allowing one-level up code generators to process
 * each function that we lower. Defaults to a callback that does nothing.
 * \return The lowered module, see above.
 */
// TODO(@electriclilies): Not sure if this default initialization is correct...
LoweredModule LowerTE(
const IRModule& module, TargetMap targets, DeviceMap device_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
ProcessFn process_fn = [](Function f) {});

/*! \brief Pass to lower an IRModule's primitive functions to TIR.
 *
 * This is the "back half" of the Relay compiler which lowers "primitive functions"
 * to TE expressions, schedules them, and then to TIR. This Pass calls LowerTE, and
 * uses the LoweredModuleToIRModule utility to convert LowerTE's output
 * LoweredModule into an IRModule before returning it.
 *
 * \param targets The mapping for devices to targets.
 * \param device_context_map An analysis result mapping each sub-expression to a device.
 * \param memory_plan The memory plan used during lowering
 * \param module_name The name of this module
 * \param process_fn Callback allowing one-level up code generators to process
 * each function that we lower
 * \return The pass which lowers primitive functions to TIR
 */
transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn);
} // namespace tec
} // namespace relay
} // namespace tvm
Expand Down

0 comments on commit 52f06e9

Please sign in to comment.