Skip to content

Commit

Permalink
Remove LoweredModule (#8886)
Browse files Browse the repository at this point in the history
* Remove LoweredModule

* Clean up some comments

* QEMU flaky tests

* Don't add external functions to the LoweredFunctions module

* QEMU flaky test

* Respond to feedback

* flaky test
  • Loading branch information
electriclilies authored Sep 3, 2021
1 parent a890bb9 commit 0744641
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 174 deletions.
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,7 @@ class Map : public ObjectRef {
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
* \return Handle to the internal node container(which guarantees to be unique)
*/
MapNode* CopyOnWrite() {
if (data_.get() == nullptr) {
Expand Down
1 change: 0 additions & 1 deletion include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <tvm/target/target_kind.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

Expand Down
20 changes: 14 additions & 6 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);

IRModule new_mod =
IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
Expand All @@ -598,9 +598,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
tec::UpdateFunctionMetadata(func, this->function_metadata_);
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto lowered_main = lowered_module.main_module->Lookup("main");
Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());
auto lowered_main = lowered_mod->Lookup("main");

auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

// Post-lowering storage map for writing main func - this should be the same map as previously
Expand Down Expand Up @@ -662,8 +665,13 @@ class AOTExecutorCodegen : public MixedModeVisitor {

ret.function_metadata = std::move(function_metadata_);

ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;
Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";

// This is the point where we separate the functions in the module by target
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
ret.external_mods = external_modules.value();

if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_]->Update(mod_run);
Expand Down
22 changes: 16 additions & 6 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}

IRModule new_mod =
IRModule lowered_mod =
LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
Expand All @@ -236,9 +236,13 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
tec::UpdateFunctionMetadata(func, this->function_metadata_);
})(mod);

tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto main_module = lowered_module.main_module;
Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

// Get only the Relay functions out of the lowered module so we can run type inference on them
IRModule main_module = tec::GetMainModule(lowered_mod);
main_module = relay::transform::InferType()(main_module);
relay::Function main_func = Downcast<relay::Function>(main_module->Lookup("main"));

Expand Down Expand Up @@ -270,8 +274,14 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
}
ret.function_metadata = std::move(function_metadata_);
ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";

// This is the point where we separate the functions in the module by target
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
ret.external_mods = external_modules.value();
return ret;
}

Expand Down
42 changes: 21 additions & 21 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
// TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes
// LoweredModule.
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
Expand Down Expand Up @@ -902,20 +903,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq({transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType()});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);

// Things to initialize to pass into tec::LowerTEPass
// We only have one device-specific target.
tec::TargetMap targets = {{device.device_type, target}};

Expand All @@ -925,13 +913,25 @@ std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device,
// No need for a memory plan.
backend::StaticMemoryPlan memory_plan; /*=nullptr*/

// Run minimal transforms on module to establish invariants needed by interpreter.
transform::Sequential seq(
{transform::SimplifyInference(),
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
// attribute.
transform::FuseOps(/*fuse_opt_level=*/0), transform::ToANormalForm(),
// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(targets, device_map, memory_plan, /*module_name=*/"intrp",
[](Function func) { /* no-op */ })});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);

// Lower all primitive functions reachable from expr.
// TODO(mbs): This should be just another pass in seq above, which requires LoweredModule to
// be merged into IRModule.
LoweredModule lowered_module =
tec::LowerTE(mod, targets, device_map, memory_plan, /*module_name=*/"intrp",
[](Function func) { /* no-op */ });
return {lowered_module.main_module, lowered_module.per_target_module};
return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)};
}

/*! \brief Check if an expression could be changed by \p Prepare.
Expand Down
159 changes: 61 additions & 98 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,33 +85,46 @@ class TECompilerImpl : public TECompilerNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Map<Target, IRModule> GetLoweredFunctions() {
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
lowered_functions;
IRModule GetLoweredFunctions() {
IRModule mod;
// Extract lowered functions from the cache
for (const auto& it : cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target)) {
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
IRModule lowered_mod = lowered_func->cached_func->funcs;

lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}
// Annotate functions with their target and put them in the return module
for (auto kv : lowered_mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;

// Only add functions that are not external functions
if (!func->GetAttr<String>(attr::kCompiler).defined()) {
ICHECK(func->IsInstance<tir::PrimFuncNode>())
<< "Expected all functions that are not external to be PrimFuncs, but found "
<< func->GetTypeKey();
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
}
}
}
// Extract lowered dynamic shape functions from the shape cache
for (const auto& it : shape_func_cache_) {
auto source_func = it.first;
auto lowered_func = it.second;
auto target = source_func->target;

if (!lowered_functions.count(target)) {
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
IRModule lowered_mod = lowered_func->cached_func->funcs;

// Annotate functions with their target and put them in the return module
for (auto kv : lowered_mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
const tir::PrimFunc& prim_func = Downcast<tir::PrimFunc>(func);
mod->Update(var, WithAttr(prim_func, tvm::attr::kTarget, source_func->target));
}

lowered_functions[target]->Update(lowered_func->cached_func->funcs);
}
return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions);
return mod;
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
Expand Down Expand Up @@ -830,9 +843,9 @@ void UpdateFunctionMetadata(Function relay_func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}

LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
IRModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
DLOG(INFO) << "lowering module:\n" << PrettyPrint(module);

TECompiler compiler;
Expand Down Expand Up @@ -864,76 +877,23 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
(*te_compiler_update_weights)(weight_map);
}

LoweredModule lowered_module;
lowered_module.main_module = updated_module;
lowered_module.per_target_module = compiler->GetLoweredFunctions();
lowered_module.external_mods = compiler->LowerExternalFunctions();
lowered_module.main_func_info = func_info;
return lowered_module;
}
// Copy the lowered functions into the return module
updated_module->Update(compiler->GetLoweredFunctions());

IRModule LoweredModuleToIRModule(LoweredModule mod) {
IRModule unified_module;
// Annotate the module with the external modules and function info
updated_module = WithAttr(updated_module, "external_mods", compiler->LowerExternalFunctions());
updated_module = WithAttr(updated_module, "main_func_info", func_info);

// Copy the main module and its typedefs
for (const auto& kv : mod.main_module->functions) {
unified_module->Add(kv.first, kv.second);
}
for (const auto& kv : mod.main_module->type_definitions) {
unified_module->AddTypeDef(kv.first, kv.second);
}

// Annotate the per-target functions with their target and add them to the unified module
for (const auto& kv : mod.per_target_module) {
const Target target = kv.first;
const IRModule target_module = kv.second;

// Right now, per-target functions are TIR functions, which don't have type definitions, so
// there should be no type defs in the per_target_modules
size_t ty_def_size = target_module->type_definitions.size();
ICHECK(ty_def_size == 0)
<< "Expected there to be no type definitions in the per_target_modules, but found "
<< ty_def_size;

for (const auto& kv : target_module->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<tir::PrimFuncNode>()) {
tir::PrimFunc primFunc =
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), tvm::attr::kTarget, target);
unified_module->Add(var, primFunc);
} else if (func->IsInstance<relay::FunctionNode>()) {
relay::Function relayFunc =
WithAttr(Downcast<relay::Function>(std::move(func)), tvm::attr::kTarget, target);
unified_module->Add(var, relayFunc);
} else {
LOG(FATAL)
<< "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
<< func->GetTypeKey();
}
}
}

IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
return ret_mod;
return updated_module;
}

LoweredModule IRModuleToLoweredModule(IRModule mod) {
IRModule main_mod;
// Copy just the TypeDefs from the IRModule to the LoweredModule's main module
// This is the only time we need to do this since there are no TypeDefs in TIR
for (const auto& kv : mod->type_definitions) {
main_mod->AddTypeDef(kv.first, kv.second);
}

Map<Target, IRModule> per_target_modules;
Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
per_target_modules;
for (const auto& kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<relay::FunctionNode>()) {
main_mod->Add(var, func);
} else if (func->IsInstance<tir::PrimFuncNode>()) {
if (func->IsInstance<tir::PrimFuncNode>()) {
// Extract target
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target) << "Target should be set at this point";
Expand All @@ -943,44 +903,47 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
// Initialize the IRModule for this target and add the function
IRModule target_module;
target_module->Add(var, func);
per_target_modules.Set(target.value(), target_module);
per_target_modules[target.value()] = target_module;
} else {
// The IRModule for this target is initialized, so just add the function.
IRModule target_module = per_target_modules.at(target.value());
target_module->Add(var, func);
}
} else {
} else if (!func->IsInstance<relay::FunctionNode>()) {
LOG(FATAL)
<< "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
<< func->GetTypeKey();
}
}
return per_target_modules;
}

// Put the LoweredModule together
LoweredModule lowered_module;
lowered_module.main_module = main_mod;
lowered_module.per_target_module = per_target_modules;

// Extract external modules and main func info, add to lowered module if they exist
auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
if (external_mods) {
lowered_module.external_mods = external_mods.value();
IRModule GetMainModule(IRModule mod) {
IRModule main_module;
// Copy the type defs
for (const auto& kv : mod->type_definitions) {
main_module->AddTypeDef(kv.first, kv.second);
}
auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
if (main_func_info) {
lowered_module.main_func_info = main_func_info.value();
// Copy all Relay functions (we don't include PrimFuncs)
for (auto kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<tvm::relay::FunctionNode>()) {
main_module->Add(var, func);
}
}
return lowered_module;
return main_module;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LoweredModuleToIRModule(
LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
};
// TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to
// be called afterwards
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
}
} // namespace tec
Expand Down
Loading

0 comments on commit 0744641

Please sign in to comment.