Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass #8802

Merged: 12 commits, Aug 24, 2021
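For orientation, here is a minimal sketch of the call-site change this PR makes in the codegens, distilled from the diffs below (identifiers simplified; not a verbatim excerpt of either file):

// Before this PR: LowerTE was invoked directly and returned a tec::LoweredModule.
tec::LoweredModule lowered_module =
    tec::LowerTE(mod, targets, device_context_map, memory_plan, mod_name, process_fn);

// After this PR: lowering is packaged as a Pass. Applying the pass yields a
// plain IRModule, which call sites convert back to a LoweredModule for as long
// as that type still exists.
IRModule new_mod =
    LowerTEPass(targets, device_context_map, memory_plan, mod_name, process_fn)(mod);
tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);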
1 change: 0 additions & 1 deletion include/tvm/relay/function.h
@@ -144,7 +144,6 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";

/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";
} // namespace attr
17 changes: 9 additions & 8 deletions src/relay/backend/aot_executor_codegen.cc
@@ -586,8 +586,9 @@ class AOTExecutorCodegen : public ExprVisitor {
// to instead explicitly lowering the incoming IRModule, and then
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);
auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) {

IRModule new_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -599,8 +600,9 @@
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto lowered_main = lowered_module.main_module->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
@@ -667,11 +669,10 @@ class AOTExecutorCodegen : public ExprVisitor {
ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;

auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Update(mod_run);
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_]->Update(mod_run);
} else {
ret.lowered_funcs.Set(target_host_str, mod_run);
ret.lowered_funcs.Set(target_host_, mod_run);
}

std::vector<String> input_var_names(input_vars_.size());
@@ -776,7 +777,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
return (*it).second.first;
}

Map<String, IRModule> get_irmodule() { return this->output_.lowered_funcs; }
Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

std::shared_ptr<AOTExecutorCodegen> codegen_;
LoweredOutput output_;
7 changes: 4 additions & 3 deletions src/relay/backend/graph_executor_codegen.cc
@@ -221,8 +221,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
IRModule new_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -234,8 +234,9 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto main_module = lowered_module.main_module;
main_module = relay::transform::InferType()(main_module);
18 changes: 9 additions & 9 deletions src/relay/backend/interpreter.cc
@@ -289,7 +289,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
Interpreter(IRModule mod, Map<String, IRModule> per_target_module, Device device, Target target)
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
device_(device),
@@ -373,7 +373,7 @@
*/
PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars,
Target target) {
std::pair<std::string, std::string> packed_func_key(target->str(), tir_fn_var->name_hint);
std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint);
auto packed_itr = compiled_packed_funcs_.find(packed_func_key);
if (packed_itr != compiled_packed_funcs_.end()) {
// Already compiled.
@@ -382,7 +382,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

// Project out just the function(s) we need.
IRModule lowered_projected_mod;
auto mod_itr = per_target_module_.find(target->str());
auto mod_itr = per_target_module_.find(target);
ICHECK(mod_itr != per_target_module_.end())
<< "No target module for target '" << target->str() << "'";
const IRModule& target_module = (*mod_itr).second;
Expand All @@ -407,7 +407,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint
<< "' in compiled module for target '" << target->str() << "'";
compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func);
compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func);
}

// Return just what we need for this call.
@@ -874,10 +874,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// Map from target key to lowered TIR functions derived from mod_.
// Note that primitives are implicitly executed on target_, while shape functions are implicitly
// executed on the default 'cpu' host. Thus this map has at most two entries.
Map<String, IRModule> per_target_module_;
Map<Target, IRModule> per_target_module_;
// Cached packed functions for the primitives and shape functions, keyed by target and
// global var name.
std::unordered_map<std::pair<std::string, std::string>, PackedFunc, PairHash>
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash>
compiled_packed_funcs_;
// Unique device on which primitives (but not shape functions) will be executed.
// (For simplicity we only run the interpreter on a single device.)
@@ -895,7 +895,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<String, IRModule>> Prepare(IRModule mod, Device device, Target target) {
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq({transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
@@ -1014,7 +1014,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
// and can just eval it directly.
expr_to_eval = expr;
}
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_with_expr, device, target);
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
@@ -1057,7 +1057,7 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target) {
std::pair<IRModule, GlobalVar> mod_and_global =
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_and_global.first, device, target);
Interpreter intrp(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
125 changes: 112 additions & 13 deletions src/relay/backend/te_compiler.cc
@@ -20,6 +20,7 @@
#include "te_compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
@@ -84,30 +85,30 @@ class TECompilerImpl : public TECompilerNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Map<String, IRModule> GetLoweredFunctions() {
Map<String, IRModule> lowered_functions;
Map<Target, IRModule> GetLoweredFunctions() {
Map<Target, IRModule> lowered_functions;
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions.Set(target, IRModule(Map<GlobalVar, BaseFunc>({})));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}

for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target->str())) {
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
if (!lowered_functions.count(target)) {
lowered_functions.Set(target, IRModule(Map<GlobalVar, BaseFunc>({})));
}

lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}
return lowered_functions;
}
@@ -749,8 +750,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar
relay_primfuncs);
}

// TODO(@electriclilies): Is the function passed in here relay_func??
// Also should this be inlined?
/*!
* \brief A function to create the function metadata for an input function (ie calculate buffer
* input/output sizes)
@@ -830,9 +829,6 @@ void UpdateFunctionMetadata(Function relay_func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule back into IRModule.
// Currently we rely on accumulating bindings inside the local TECompiler which we then
// host into the LoweredModule result.
LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
@@ -875,6 +871,109 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
return lowered_module;
}

IRModule LoweredModuleToIRModule(LoweredModule mod) {
IRModule unified_module;

// Copy the main module and its typedefs
unified_module->Update(mod.main_module);
for (const auto& kv : mod.main_module->type_definitions) {
unified_module->AddTypeDef(kv.first, kv.second);
}

// Annotate the per-target functions with their target and add them to the unified module
for (const auto& kv : mod.per_target_module) {
const Target target = kv.first;
const IRModule target_module = kv.second;

// Right now, per-target functions are TIR functions, which don't have type definitions, so
// there should be no type defs in the per_target_modules
size_t ty_def_size = target_module->type_definitions.size();
ICHECK(ty_def_size == 0)
<< "Expected there to be no type definitions in the per_target_modules, but found "
<< ty_def_size;

for (const auto& kv : target_module->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
ICHECK(func->IsInstance<tir::PrimFuncNode>())
<< "We expect the target_module to contain only PrimFuncs at this point, but got "
<< func->GetTypeKey();
// TODO(@electriclilies): Change to Target object if possible
tir::PrimFunc primFunc =
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), tvm::attr::kTarget, target);
unified_module->Add(var, primFunc);
}
}

IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
return ret_mod;
}

LoweredModule IRModuleToLoweredModule(IRModule mod) {
IRModule main_mod;
// Copy just the TypeDefs from the IRModule to the LoweredModule's main module
// This is the only time we need to do this since there are no TypeDefs in TIR
for (const auto& kv : mod->type_definitions) {
main_mod->AddTypeDef(kv.first, kv.second);
}

Map<Target, IRModule> per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<relay::FunctionNode>()) {
main_mod->Add(var, func);
} else if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";

// Put the function in per_target_modules
if (!per_target_modules.count(target.value())) {
// Initialize the IRModule for this target and add the function
IRModule target_module;
target_module->Add(var, func);
per_target_modules.Set(target.value(), target_module);
} else {
// The IRModule for this target is initialized, so just add the function.
IRModule target_module = per_target_modules.at(target.value());
target_module->Add(var, func);
}
} else {
LOG(FATAL)
<< "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
<< func->GetTypeKey();
}
}

// Put the LoweredModule together
LoweredModule lowered_module;
lowered_module.main_module = main_mod;
lowered_module.per_target_module = per_target_modules;

// Extract external modules and main func info, add to lowered module if they exist
auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
if (external_mods) {
lowered_module.external_mods = external_mods.value();
}
auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
if (main_func_info) {
lowered_module.main_func_info = main_func_info.value();
}
return lowered_module;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LoweredModuleToIRModule(
LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
};
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
}
Contributor:
[A question to the community and not specific to your PR Lily!]

This is a good example of code which could be easily unit tested in C++ in the, er, 'conventional' sense. That is, as a reader I could expect to go to tests/cpp/relay/backend/te_compiler_test.cc and look for TEST(IRModuleToLoweredModule, ...). Currently this new code is tested indirectly via its use by LowerTEPass and consumers of such, which in turn are tested indirectly by virtue of everything passing into TIR via this choke point. Just wanted to test the water on whether folks on this PR have opinions here so I don't go off tilting at windmills.

Contributor (Author):

This is a good point. Since these two functions are supposed to be inverses of each other, it would be pretty easy to write a unit test for them in theory. When I was developing, I actually inserted the conversions in some places and ran existing unit tests to make sure that the functions worked, but it would be great to have a way to directly write unit tests in C++. That way I wouldn't have to remove my tests before merging!
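A minimal sketch of the round-trip test discussed in this thread, assuming gtest and a hypothetical tests/cpp/relay/backend/te_compiler_test.cc (the file location, include path, and build wiring are illustrative, not part of this PR):

#include <gtest/gtest.h>

#include <tvm/ir/module.h>

#include "../../../../src/relay/backend/te_compiler.h"  // hypothetical include path

TEST(IRModuleToLoweredModule, RoundTripPreservesFunctions) {
  // An empty module exercises the conversion plumbing without having to
  // construct real Relay or TIR functions.
  tvm::IRModule mod;
  tvm::relay::tec::LoweredModule lowered = tvm::relay::tec::IRModuleToLoweredModule(mod);
  tvm::IRModule round_tripped = tvm::relay::tec::LoweredModuleToIRModule(lowered);
  // The two conversions are intended to be inverses at the function level;
  // module-level attributes such as "external_mods" may be attached en route,
  // so this compares function maps rather than whole-module structural equality.
  EXPECT_EQ(round_tripped->functions.size(), mod->functions.size());
}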

Member (@tqchen), Aug 19, 2021:
Currently we unit-test passes by exposing them via a Python API, constructing the expected input and output, and running the tests in Python:

https://github.com/apache/tvm/blob/main/tests/python/unittest/test_tir_transform_loop_partition.py#L30

There are certainly pros and cons of doing so. The original rationale is that we require most of the compiler passes to be accessible from Python, and it is relatively easy to construct and expand test cases there.

We could revisit this point as the need for the related test cases comes up.
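For reference, the usual mechanism for making a C++ pass constructible from Python is a global FFI registration along these lines (a sketch of common TVM practice; the registry name is hypothetical, and whether this PR adds such a hook for LowerTEPass is not visible in this diff):

TVM_REGISTER_GLOBAL("relay.tec.LowerTEPass")  // hypothetical registry name
    .set_body_typed([](TargetMap targets, DeviceMap device_context_map,
                       backend::StaticMemoryPlan memory_plan, String module_name) {
      // process_fn is a C++ std::function and is awkward to pass over the FFI,
      // so this sketch supplies a no-op hook.
      return LowerTEPass(targets, device_context_map, memory_plan, module_name,
                         [](Function) {});
    });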

Member:

@electriclilies is there any reason you can't create a file similar to https://github.com/apache/tvm/blob/main/tests/cpp/build_module_test.cc and test the functions there?

Ideally I'd like to see a C++ test setup as @mbs-octoml describes rather than the single folder, but would this work here? It's not an absolute rule that we must expose via Python for testing, is it @tqchen?

Member:
For most of the passes that can be modularized, we encourage the Python-first principle and expose them via Python. This one is in an intermediate state, so it is not an absolute rule here.

Contributor:

Thanks for the comments. Perhaps a rule of thumb here is: if it's part of the public API it should be tested on the Python side, but otherwise it should stay on the C++ side. I'm struggling to see how to write targeted unit tests on the Python side without both risking making something internal part of the de facto API and paying for every unit-test boundary to be FFI-able.

} // namespace tec
} // namespace relay
} // namespace tvm