diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 221df958a8cb..54a10add2f07 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -38,7 +38,7 @@ #include #include -#include "compile_engine.h" +#include "te_compiler.h" #include "utils.h" namespace tvm { @@ -46,7 +46,6 @@ namespace relay { namespace backend { using IntegerArray = Array; -using TargetsMap = std::unordered_map; using StorageMap = std::unordered_map; @@ -287,7 +286,6 @@ class AOTExecutorCodegen : public ExprVisitor { void CreateFuncCall(Call call, std::string func_name) { tvm::Array args{tvm::tir::StringImm(func_name)}; std::vector create_func_call_stmts; - // Pack the inputs for (Expr arg : call->args) { if (params_by_expr_.find(arg) != params_by_expr_.end()) { @@ -365,155 +363,21 @@ class AOTExecutorCodegen : public ExprVisitor { return ss.str(); } - /*! - * \brief Update the "main" control function's metadata - * - * \param func The main function that contains calls to operator tir primitive functions - */ - void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) { - auto workspace_byte_alignment = target_host_->GetAttr("workspace-byte-alignment") - .value_or(tvm::runtime::kDefaultWorkspaceAlignment); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - // Populate FunctionInfo - auto fi_node = make_object(); - // Initialize all target workspaces to zero - for (const auto& kv : targets_) { - auto tgt = kv.second; - fi_node->workspace_sizes.Set(tgt, 0); - } - fi_node->workspace_sizes.Set(target_host_, workspace_size); - fi_node->relay_primfuncs.Set(target_host_, func); - - int64_t io_size = 0; - for (const auto& input : input_vars_) { - io_size += CalculateRelayExprSizeBytes(input->checked_type()); - } - io_size += CalculateRelayExprSizeBytes(func->body->checked_type()); - fi_node->io_sizes.Set(target_host_, io_size); - - int64_t const_size = 0; - for (const auto& kv : params_by_expr_) { - const_size += CalculateRelayExprSizeBytes(kv.first->checked_type()); - } - fi_node->constant_sizes.Set(target_host_, const_size); - function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node)); - } - - /*! - * \brief Update the function metadata for a given cached function and its relay - * primitive function. - * - * \param cfunc The cached function as provided the by the compile engine - * \param relay_func The source relay primitive function - * \param relay_target The target associated with relay primitive function - */ - void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func, - const Target& relay_target) { - auto fi_node = make_object(); - for (const auto& kv : cfunc->funcs->functions) { - auto primfunc = Downcast(kv.second); - auto workspace_byte_alignment = - target_host_->GetAttr("workspace-byte-alignment").value_or(16); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - Target primfunc_target = relay_target; - if (primfunc->attrs->dict.count("target")) { - primfunc_target = Downcast(primfunc->attrs->dict["target"]); - } - fi_node->workspace_sizes.Set(primfunc_target, workspace_size); - // Calculating size for I/O - for (auto const& param : primfunc->params) { - auto p_shape = primfunc->buffer_map[param]->shape; - int num_of_elements = 1; - for (const auto& dim_index_expr : p_shape) { - if (dim_index_expr->IsInstance()) { - num_of_elements *= dim_index_expr.as()->value; - } else { - // If shape is dynamic, we cannot calculate workspace in compile time. - num_of_elements = 0; - } - } - int element_size = primfunc->buffer_map[param]->dtype.bytes(); - fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements); - } - fi_node->constant_sizes.Set(primfunc_target, 0); - fi_node->tir_primfuncs.Set(primfunc_target, primfunc); - fi_node->relay_primfuncs.Set(primfunc_target, relay_func); - } - function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node)); - } - void VisitExpr_(const CallNode* op) override { // Descend the call tree for (auto arg : op->args) { VisitExpr(arg); } - Expr expr = GetRef(op); - Function func; if (op->op.as()) { LOG(FATAL) << "Operators should be transformed away; try applying" << "the fuse_ops transformation to the expression."; } else if (op->op.as()) { - LOG(FATAL) << "Not implemented"; - } else if (op->op.as()) { - func = GetRef(op->op.as()); + GlobalVar node = GetRef(op->op.as()); + CreateFuncCall(GetRef(op), node->name_hint); } else { LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey(); } - if (!func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(FATAL) << "TVM only support calls to primitive functions " - << "(i.e functions composed of fusable operator invocations)"; - } - - Target target; - - // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { - target = Target("ext_dev"); - CCacheKey key = CCacheKey(func, target); - CachedFunc ext_func = compile_engine_->Lower(key, mod_name_); - ICHECK(ext_func.defined()) << "External function is not defined."; - UpdateConstants(func, ¶ms_); - - // Generate the TIR function call - CreateFuncCall(GetRef(op), ext_func->prim_fn_var->name_hint); - return; - } - - ICHECK_GE(storage_device_map_.count(expr), 0); - StorageInfo& sinfo = storage_device_map_[expr]; - auto call_dev_type = sinfo->device_types[0]; - // Normal Relay Function - if (targets_.size() == 1) { - // homogeneous execution. - const auto& it = targets_.begin(); - target = (*it).second; - } else { - // heterogeneous execution. - std::string call_dev_name; - if (call_dev_type == 0) { - call_dev_name = "llvm"; - } else { - call_dev_name = runtime::DeviceName(call_dev_type); - } - if (targets_.count(call_dev_type) == 0) { - LOG(FATAL) << "No target is provided for device " << call_dev_name; - } - target = targets_[call_dev_type]; - } - - CCacheKey key = CCacheKey(func, target); - CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_); - - if (!lowered_funcs_.count(target->str())) { - lowered_funcs_[target->str()] = IRModule(Map({})); - } - lowered_funcs_[target->str()]->Update(lowered_func->funcs); - // Update function metadata via looking at all primfuncs - UpdateFunctionMetadata(lowered_func, func, target); - - // Generate the TIR function call - CreateFuncCall(GetRef(op), lowered_func->prim_fn_var->name_hint); } void VisitExpr_(const VarNode* op) override { @@ -598,7 +462,7 @@ class AOTExecutorCodegen : public ExprVisitor { // Create the main PrimFunc to execute the graph. Please note that // the packed function calls don't pack their arguments. The AOT // runner function needs to be legalized by the LegalizePackedCalls pass. - tir::PrimFunc CreateMainFunc(unsigned int relay_params) { + tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) { tir::Stmt body = tir::SeqStmt(stmts_); // Allocate the sids @@ -637,7 +501,7 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the PrimFunc attributes Map dict_attrs; String run_func_name = - runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix); + runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix); dict_attrs.Set("global_symbol", run_func_name); dict_attrs.Set("runner_function", Bool(true)); @@ -654,7 +518,7 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief input and output variables belonging to the main function signature */ Array main_signature_; /*! \brief target device */ - TargetsMap targets_; + tec::TargetMap targets_; /*! \brief target host */ Target target_host_; /*! @@ -684,35 +548,70 @@ class AOTExecutorCodegen : public ExprVisitor { /*! \brief mapping sid -> tir::Var */ std::unordered_map sids_table_; /*! \brief lowered funcs */ - std::unordered_map lowered_funcs_; - /*! \brief lowered funcs */ Map function_metadata_; - /*! \brief compile engine */ - CompileEngine compile_engine_; /*! \brief the set of statements that make the program */ std::vector stmts_; /*! \brief the list of return sids (note that the function might return more then one output */ std::vector return_sid_; - /*! \brief the module name we use to mangle the function names */ - String mod_name_; public: - AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host) + AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host) : mod_(mod), targets_(targets), target_host_(target_host), - use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))), - compile_engine_(CompileEngine::Global()) {} + use_unpacked_api_(target_host->GetAttr("unpacked-api").value_or(Bool(false))) {} LoweredOutput Codegen(relay::Function func, String mod_name) { auto aot_allocator = AOTOnDemandAllocator(); aot_allocator.Run(func); - // Retrieve the storage map - storage_device_map_ = aot_allocator.GetStorageMap(); - mod_name_ = mod_name; + // Pre-lowering storage map and memory plan + StorageMap initial_storage_map = aot_allocator.GetStorageMap(); + StaticMemoryPlan memory_plan(initial_storage_map); + + // Build a map from each operation to device. + tec::DeviceMap device_context_map; + for (const auto& it : memory_plan->expr_to_storage_info) { + auto expr = it.first; + auto storage_info = it.second; + auto device_types = storage_info->device_types; + // CHECK_EQ(device_types.size(), 1); + tvm::Device dev; + dev.device_id = 0; + dev.device_type = device_types[0]; + device_context_map.insert({expr, dev}); + } + + // This first phase moves from implicit use of compile engine, + // to instead explicitly lowering the incoming IRModule, and then + // performing the preexisting AOT executor code generation phase. + IRModule mod = IRModule::FromExpr(func); + auto lowered_module = tec::LowerTE( + mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) { + // We need to maintain the constant map for external + // functions so we pass this processing function which + // allows us to process each function as we lower it. + if (func->GetAttr(attr::kCompiler).defined()) { + UpdateConstants(func, ¶ms_); + } + + // TODO(@areusch, @jroesch): We should refactor this to + // execute as a further pass, instead writing data to the + // lowering process directly. + tec::UpdateFunctionMetadata(func, this->function_metadata_); + }); - for (auto input : func->params) { + function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); + auto lowered_main = lowered_module.main_module->Lookup("main"); + auto lowered_main_func = GetRef(lowered_main.as()); + + // Post-lowering storage map for writing main func - this should be the same map as previously + // created, just referencing the new expressions created from lowering + auto new_allocator = AOTOnDemandAllocator(); + new_allocator.Run(lowered_main_func); + storage_device_map_ = new_allocator.GetStorageMap(); + + for (auto input : lowered_main_func->params) { input_vars_.push_back(input); main_signature_.push_back(tir::Var("input", DataType::Handle())); } @@ -732,13 +631,12 @@ class AOTExecutorCodegen : public ExprVisitor { main_signature_.push_back(tir::Var("output", DataType::Handle())); } - VisitExpr(func->body); + VisitExpr(lowered_main_func->body); // Create the runner function. Please note that the function is not legal yet // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. - auto prim_func = CreateMainFunc(func->params.size()); - UpdateMainWorkspaceSize(prim_func, func); + auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); LoweredOutput ret; ret.params = std::unordered_map>(); @@ -748,17 +646,7 @@ class AOTExecutorCodegen : public ExprVisitor { std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); } - for (auto& kv : lowered_funcs_) { - if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, IRModule(Map({}))); - } - auto& mod = ret.lowered_funcs[kv.first]; - mod->Update(kv.second); - ret.lowered_funcs.Set(kv.first, mod); - } - ret.external_mods = compile_engine_->LowerExternalFunctions(); - - // Build the TIR IRModule + // Build the TIR IRModule for the AOT function Map symbol_map; symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func); IRModule mod_run(symbol_map); @@ -774,14 +662,17 @@ class AOTExecutorCodegen : public ExprVisitor { mod_run = pack_calls(mod_run); } - // Update the lowered functions + ret.function_metadata = std::move(function_metadata_); + + ret.lowered_funcs = lowered_module.per_target_module; + ret.external_mods = lowered_module.external_mods; + auto target_host_str = target_host_->str(); if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) { ret.lowered_funcs[target_host_str]->Update(mod_run); } else { ret.lowered_funcs.Set(target_host_str, mod_run); } - ret.function_metadata = std::move(function_metadata_); std::vector input_var_names(input_vars_.size()); std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), @@ -845,7 +736,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { private: void init(void* mod, Map tmp) { - TargetsMap targets; + tec::TargetMap targets; Target target_host; for (const auto& it : tmp) { auto dev_type = it.first.as(); @@ -853,7 +744,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { target_host = it.second; } ICHECK(dev_type); - targets[dev_type->value] = it.second; + targets[static_cast(dev_type->value)] = it.second; } codegen_ = std::make_shared(reinterpret_cast(mod), targets, target_host);