From b39e8cc931295892906a839fb4c58f9ee574ae72 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Tue, 10 Aug 2021 20:07:17 +0100 Subject: [PATCH] Use main_func_info rather than bespoke logic in AOT This moves from using the bespoke AOT UpdateMainWorkspaceSize to the LoweredModule main_func_info property to unify with Graph executor codegen. --- src/relay/backend/aot_executor_codegen.cc | 37 +---------------------- 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index afa73b946694..54a10add2f07 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -363,40 +363,6 @@ class AOTExecutorCodegen : public ExprVisitor { return ss.str(); } - /*! - * \brief Update the "main" control function's metadata - * - * \param func The main function that contains calls to operator tir primitive functions - */ - void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) { - auto workspace_byte_alignment = target_host_->GetAttr("workspace-byte-alignment") - .value_or(tvm::runtime::kDefaultWorkspaceAlignment); - Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment); - // Populate FunctionInfo - auto fi_node = make_object(); - // Initialize all target workspaces to zero - for (const auto& kv : targets_) { - auto tgt = kv.second; - fi_node->workspace_sizes.Set(tgt, 0); - } - fi_node->workspace_sizes.Set(target_host_, workspace_size); - fi_node->relay_primfuncs.Set(target_host_, func); - - int64_t io_size = 0; - for (const auto& input : input_vars_) { - io_size += CalculateRelayExprSizeBytes(input->checked_type()); - } - io_size += CalculateRelayExprSizeBytes(func->body->checked_type()); - fi_node->io_sizes.Set(target_host_, io_size); - - int64_t const_size = 0; - for (const auto& kv : params_by_expr_) { - const_size += CalculateRelayExprSizeBytes(kv.first->checked_type()); - } - fi_node->constant_sizes.Set(target_host_, const_size); - function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node)); - } - void VisitExpr_(const CallNode* op) override { // Descend the call tree for (auto arg : op->args) { @@ -635,6 +601,7 @@ class AOTExecutorCodegen : public ExprVisitor { tec::UpdateFunctionMetadata(func, this->function_metadata_); }); + function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info); auto lowered_main = lowered_module.main_module->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); @@ -670,8 +637,6 @@ class AOTExecutorCodegen : public ExprVisitor { // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size()); - UpdateMainWorkspaceSize(prim_func, lowered_main_func); - LoweredOutput ret; ret.params = std::unordered_map>();