Skip to content

Commit

Permalink
Use main_func_info rather than bespoke logic in AOT
Browse files Browse the repository at this point in the history
This moves from using the bespoke AOT UpdateMainWorkspaceSize to the
LoweredModule main_func_info property to unify with Graph executor
codegen.
  • Loading branch information
Mousius committed Aug 11, 2021
1 parent c9c965d commit b39e8cc
Showing 1 changed file with 1 addition and 36 deletions.
37 changes: 1 addition & 36 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,40 +363,6 @@ class AOTExecutorCodegen : public ExprVisitor {
return ss.str();
}

/*!
* \brief Update the "main" control function's metadata
*
* \param func The main function that contains calls to operator tir primitive functions
*/
void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) {
auto workspace_byte_alignment = target_host_->GetAttr<Integer>("workspace-byte-alignment")
.value_or(tvm::runtime::kDefaultWorkspaceAlignment);
Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
// Populate FunctionInfo
auto fi_node = make_object<FunctionInfoNode>();
// Initialize all target workspaces to zero
for (const auto& kv : targets_) {
auto tgt = kv.second;
fi_node->workspace_sizes.Set(tgt, 0);
}
fi_node->workspace_sizes.Set(target_host_, workspace_size);
fi_node->relay_primfuncs.Set(target_host_, func);

int64_t io_size = 0;
for (const auto& input : input_vars_) {
io_size += CalculateRelayExprSizeBytes(input->checked_type());
}
io_size += CalculateRelayExprSizeBytes(func->body->checked_type());
fi_node->io_sizes.Set(target_host_, io_size);

int64_t const_size = 0;
for (const auto& kv : params_by_expr_) {
const_size += CalculateRelayExprSizeBytes(kv.first->checked_type());
}
fi_node->constant_sizes.Set(target_host_, const_size);
function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
}

void VisitExpr_(const CallNode* op) override {
// Descend the call tree
for (auto arg : op->args) {
Expand Down Expand Up @@ -635,6 +601,7 @@ class AOTExecutorCodegen : public ExprVisitor {
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});

function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
auto lowered_main = lowered_module.main_module->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

Expand Down Expand Up @@ -670,8 +637,6 @@ class AOTExecutorCodegen : public ExprVisitor {
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
// to run the LegalizePackedCalls pass.
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
UpdateMainWorkspaceSize(prim_func, lowered_main_func);

LoweredOutput ret;

ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
Expand Down

0 comments on commit b39e8cc

Please sign in to comment.