From 52f06e99fc716cd0533e9697ff3ff505a5c04d24 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 23 Aug 2021 20:12:07 -0700 Subject: [PATCH] Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass (#8802) * Initial commit Initial stab at IRModule -> LoweredModule conversion func, notes Add external_mods and main_func_info to conversion funcs Test lowered module to ir module fix problem with conversion funcs + print stmts Add LowerTE pass Add LowerTEPass Add LowerTEPass to graph_executor_codegen.cc Use LowerTEPass instead of LowerTE in graph_executor_codegen.cc Code cleanup Add docs, more cleanup Formatting * Fix bad rebase * Address 1st round of comments * Use tir kTarget instead of relay one * Change target string to Target obj * removing target string causing issues * Fix typos * Revert target str -> target obj changes * Don't use Update : IRModule because it is broken * Fix check * flaky test? * lint --- include/tvm/relay/function.h | 1 - src/relay/backend/aot_executor_codegen.cc | 8 +- src/relay/backend/graph_executor_codegen.cc | 7 +- src/relay/backend/te_compiler.cc | 117 +++++++++++++++++++- src/relay/backend/te_compiler.h | 49 +++++++- 5 files changed, 167 insertions(+), 15 deletions(-) diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index fccd1f937a06..9170bc53ea02 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -144,7 +144,6 @@ constexpr const char* kComposite = "Composite"; constexpr const char* kInline = "Inline"; /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; - /*! \brief Mark the function as only composed of reshape operations. 
*/ constexpr const char* kReshapeOnly = "relay.reshape_only"; } // namespace attr diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 54a10add2f07..942bc0d1d44a 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -586,8 +586,9 @@ class AOTExecutorCodegen : public ExprVisitor { // to instead explicitly lowering the incoming IRModule, and then // performing the preexisting AOT executor code generation phase. IRModule mod = IRModule::FromExpr(func); - auto lowered_module = tec::LowerTE( - mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) { + + IRModule new_mod = + LowerTEPass(targets_, device_context_map, memory_plan, mod_name, [this](Function func) { // We need to maintain the constant map for external // functions so we pass this processing function which // allows us to process each function as we lower it. @@ -599,8 +600,9 @@ class AOTExecutorCodegen : public ExprVisitor { // execute as a further pass, instead writing data to the // lowering process directly. 
tec::UpdateFunctionMetadata(func, this->function_metadata_); - }); + })(mod); + tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto lowered_main = lowered_module.main_module->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index cc54a52be200..486a6dcd7d87 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -221,8 +221,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); - }); + })(mod); + tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod); function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto main_module = lowered_module.main_module; main_module = relay::transform::InferType()(main_module); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 93fcf73b17a2..71ac752ec680 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -20,6 +20,7 @@ #include "te_compiler.h" #include +#include #include #include #include @@ -749,8 +750,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, TargetMap tar relay_primfuncs); } -// TODO(@electriclilies): Is the function passed in here relay_func?? -// Also should this be inlined? /*! * \brief A function to create the function metadata for an input function (ie calculate buffer * input/output sizes) @@ -830,9 +829,6 @@ void UpdateFunctionMetadata(Function relay_func, function_metadata.Set(prim_fn_var.value()->name_hint, fi); } -// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule back into IRModule. -// Currently we rely on accumulating bindings inside the local TECompiler which we then -// host into the LoweredModule result. 
LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap device_context_map, backend::StaticMemoryPlan memory_plan, const String& module_name, std::function process_fn) { @@ -875,6 +871,117 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic return lowered_module; } +IRModule LoweredModuleToIRModule(LoweredModule mod) { + IRModule unified_module; + + // Copy the main module and its typedefs + for (const auto& kv : mod.main_module->functions) { + unified_module->Add(kv.first, kv.second); + } + for (const auto& kv : mod.main_module->type_definitions) { + unified_module->AddTypeDef(kv.first, kv.second); + } + + // Annotate the per-target functions with their target and add them to the unified module + for (const auto& kv : mod.per_target_module) { + const String target = kv.first; + const IRModule target_module = kv.second; + + // Right now, per-target functions are TIR functions, which don't have type definitions, so + // there should be no type defs in the per_target_modules + size_t ty_def_size = target_module->type_definitions.size(); + ICHECK(ty_def_size == 0) + << "Expected there to be no type definitions in the per_target_modules, but found " + << ty_def_size; + + for (const auto& kv : target_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance()) { + tir::PrimFunc primFunc = + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); + unified_module->Add(var, primFunc); + } else if (func->IsInstance()) { + relay::Function relayFunc = + WithAttr(Downcast(std::move(func)), tvm::attr::kTarget, target); + unified_module->Add(var, relayFunc); + } else { + LOG(FATAL) + << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found " + << func->GetTypeKey(); + } + } + } + + IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods); + ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info); + 
return ret_mod; +} + +LoweredModule IRModuleToLoweredModule(IRModule mod) { + IRModule main_mod; + // Copy just the TypeDefs from the IRModule to the LoweredModule's main module + // This is the only time we need to do this since there are no TypeDefs in TIR + for (const auto& kv : mod->type_definitions) { + main_mod->AddTypeDef(kv.first, kv.second); + } + + Map per_target_modules; + for (const auto& kv : mod->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance()) { + main_mod->Add(var, func); + } else if (func->IsInstance()) { + // Extract target + Optional target = func->GetAttr(tvm::attr::kTarget); + ICHECK(target) << "Target should be set at this point"; + + // Put the function in per_target_modules + if (!per_target_modules.count(target.value())) { + // Initialize the IRModule for this target and add the function + IRModule target_module; + target_module->Add(var, func); + per_target_modules.Set(target.value(), target_module); + } else { + // The IRModule for this target is initialized, so just add the function. 
+ IRModule target_module = per_target_modules.at(target.value()); + target_module->Add(var, func); + } + } else { + LOG(FATAL) + << "The function types in the IRModule should be RelayFunction or PrimFunc, but got " + << func->GetTypeKey(); + } + } + + // Put the LoweredModule together + LoweredModule lowered_module; + lowered_module.main_module = main_mod; + lowered_module.per_target_module = per_target_modules; + + // Extract external modules and main func info, add to lowered module if they exist + auto external_mods = mod->GetAttr>("external_mods"); + if (external_mods) { + lowered_module.external_mods = external_mods.value(); + } + auto main_func_info = mod->GetAttr("main_func_info"); + if (main_func_info) { + lowered_module.main_func_info = main_func_info.value(); + } + return lowered_module; +} + +Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn) { + runtime::TypedPackedFunc pass_func = [=](IRModule module, + PassContext ctx) { + return LoweredModuleToIRModule( + LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn)); + }; + return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {}); +} } // namespace tec } // namespace relay } // namespace tvm diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 8376b99d79cd..e9cfb0d62e66 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -166,30 +166,73 @@ void UpdateFunctionMetadata(Function relay_func, /*! * \brief Obtain the Target from the device type. * If homogenous compilation, this will return the only target. - * If heteregenous compilation, this will select associated using the targets_ Map. + * If heterogeneous compilation, this will select the associated target using the + * targets_ Map. 
* * \param dev_type * \return Target */ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); +/*! \brief Utility to convert a LoweredModule to an IRModule. + * + * This function takes all the target specific modules in LoweredModule and + * annotates their functions with the correct target, and puts all those functions + * in one IRModule. + * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. + * + * \param mod The LoweredModule to convert. + * \return The IRModule form of the input LoweredModule. + */ +IRModule LoweredModuleToIRModule(LoweredModule mod); + +/*! \brief Utility to convert an IRModule to a LoweredModule. + * + * This function takes all the functions in the IRModule and moves them into target-specific + * IRModules stored inside a LoweredModule. + * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase. + * \param mod The IRModule to convert. + * \return The LoweredModule form of the input IRModule. + */ +LoweredModule IRModuleToLoweredModule(IRModule mod); + /*! \brief Lower an IRModule's primitive functions to TIR. * * This is the "back half" of the Relay compiler which lowers "primitive functions" * to TE expressions, schedules them, and then to TIR. * - * \param compiler The TE-to-TIR compliler (which caches lowered functions) * \param module The IRModule. * \param targets The mapping for devices to targets. * \param device_map An analysis result mapping each sub-expression to a device. + * \param memory_plan The memory plan used during lowering + * \param module_name The name of this module + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower * \return The lowered module, see above. */ -// TODO(@electriclilies): Not sure if this default initialization is correct... 
LoweredModule LowerTE( const IRModule& module, TargetMap targets, DeviceMap device_map, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); +/*! \brief Pass to lower an IRModule's primitive functions to TIR. + * + * This is the "back half" of the Relay compiler which lowers "primitive functions" + * to TE expressions, schedules them, and then to TIR. This Pass calls LowerTE, and + * uses the LoweredModuleToIRModule utility to convert LowerTE's output + * LoweredModule into an IRModule before returning it. + * + * \param targets The mapping for devices to targets. + * \param device_context_map An analysis result mapping each sub-expression to a device. + * \param memory_plan The memory plan used during lowering + * \param module_name The name of this module + * \param process_fn Callback allowing one-level up code generators to process + * each function that we lower + * \returns The pass which lowers primitive functions to TIR + */ +transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function process_fn); } // namespace tec } // namespace relay } // namespace tvm