diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index a094ab7a64d4..ffe7b232d309 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -17,8 +17,11 @@ from typing import List, Optional, Union, Dict, Tuple import tvm +from tvm import relax +from tvm.ir.module import IRModule from tvm.runtime import Object, Device, Module, PackedFunc from tvm._ffi.base import _LIB, check_call +from tvm.tir.function import PrimFunc from . import _ffi_api from . import transform from ..rpc.base import RPC_SESS_MASK @@ -169,5 +172,22 @@ def build(mod: tvm.IRModule, new_mod = transform.call_dps_rewrite(new_mod) new_mod = transform.vm_memory_lower(new_mod) new_mod = transform.vm_shape_lower(new_mod) - ex, lib = _ffi_api.VMBuild(new_mod, target, target_host) + + # split primfunc and relax function + rx_mod, tir_mod = _split_tir_relax(new_mod) + + lib = tvm.build(tir_mod, target, target_host) + ex = _ffi_api.VMCodeGen(rx_mod) return ex, lib + +def _split_tir_relax(mod: tvm.IRModule) -> Tuple[tvm.IRModule, tvm.IRModule]: + rx_mod = IRModule({}) + tir_mod = IRModule({}) + for gv in mod.get_global_vars(): + if isinstance(mod[gv], PrimFunc): + tir_mod[gv] = mod[gv] + elif isinstance(mod[gv], relax.Function): + rx_mod[gv] = mod[gv] + else: + raise ValueError("An IRModule should contain contain relax function and TIR primfunc.") + return rx_mod, tir_mod \ No newline at end of file diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index a40c749a4072..1af98d4e8a0e 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -19,7 +19,7 @@ /*! * \file src/relax/backend/vm/codegen_vm.cc - * \brief A compiler to compile an IRModule to VM executable. + * \brief A codegen to generate VM executable from an IRModule with relax functions. */ #include "codegen_vm.h" @@ -64,7 +64,7 @@ class CodeGenVM : public ExprFunctor { // TODO(@yuchen): handle local functions that capture local vars outside the func // TODO(@yuchen): a renaming pass to resolve name conflicts, e.g. the input module has a // function named "local_funcN" - // lift the local func to a global func and compile it normally + // lift the local func to a global func and process it normally builder_->EmitFunction("local_func" + std::to_string(local_func_counter_++), func_node->params.size()); } @@ -287,49 +287,27 @@ class CodeGenVM : public ExprFunctor { const Op& load_shape_op_ = Op::Get("relax.vm.builtin.load_shape"); }; -void VMCompiler::Compile(IRModule mod, Target target, Target target_host) { +void VMCodeGen::CodeGen(IRModule rx_mod) { builder_ = relax::ExecBuilderNode::Create(); - - IRModule tir_mod; - IRModule rx_mod; - for (auto& p : mod->functions) { - auto gvar = p.first; - - BaseFunc func = p.second; - if (func.as()) { - tir_mod->Add(gvar, func); - } else if (func.as()) { - rx_mod->Add(gvar, func); - } else { - LOG(FATAL) << "Cannot handle such function node now:\n" << func; - } - } - lib_ = tvm::build(tir_mod, target, target_host); - - CodeGenVM compiler(builder_.operator->()); + CodeGenVM codegen(builder_.operator->()); for (auto& p : rx_mod->functions) { - compiler.VisitExpr(p.second); + codegen.VisitExpr(p.second); } } -Executable VMCompiler::GetExec() { +Executable VMCodeGen::GetExec() { return builder_->Get(); } -runtime::Module VMCompiler::GetLib() { - return lib_; -} - -Array Build(IRModule mod, Target target, Target target_host) { - auto compiler = make_object(); - compiler->Compile(mod, target, target_host); - Executable exec = compiler->GetExec(); - Module lib = compiler->GetLib(); - return Array({exec, lib}); +Executable CodeGen(IRModule mod) { + auto codegen = make_object(); + codegen->CodeGen(mod); + Executable exec = codegen->GetExec(); + return exec; } -TVM_REGISTER_GLOBAL("relax.VMBuild") -.set_body_typed(Build); +TVM_REGISTER_GLOBAL("relax.VMCodeGen") +.set_body_typed(CodeGen); } // namespace relax_vm } // namespace relax diff --git a/src/relax/backend/vm/codegen_vm.h b/src/relax/backend/vm/codegen_vm.h index 36c78f7e3ec6..f0ac9bed0ecd 100644 --- a/src/relax/backend/vm/codegen_vm.h +++ b/src/relax/backend/vm/codegen_vm.h @@ -19,11 +19,11 @@ /*! * \file src/relax/backend/vm/codegen_vm.h - * \brief A compiler to compile an IRModule to VM executable. + * \brief A codegen to generate VM executable from an IRModule with relax functions. */ -#ifndef TVM_RELAX_BACKEND_VM_COMPILER_H_ -#define TVM_RELAX_BACKEND_VM_COMPILER_H_ +#ifndef TVM_RELAX_BACKEND_CODEGEN_VM_H_ +#define TVM_RELAX_BACKEND_CODEGEN_VM_H_ #include #include @@ -40,37 +40,29 @@ using tvm::Target; using namespace tvm::runtime::relax_vm; using namespace tvm::runtime; -class VMCompiler : public Object { +class VMCodeGen : public Object { public: /*! * \brief Compile the functions in a Module. - * \param mod Input IRModule to be compiled. + * \param rx_mod Input IRModule that constains relax functions. */ - void Compile(IRModule mod, Target target, Target target_host); + void CodeGen(IRModule rx_mod); /*! * \brief Get the compiled executable. * \return The compiled executable. */ Executable GetExec(); - /*! - * \brief Get the compiled library. - * \return The compiled lirary. - */ - Module GetLib(); static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.VMCompiler"; - TVM_DECLARE_FINAL_OBJECT_INFO(ExecutableNode, Object); + static constexpr const char* _type_key = "relax.VMCodeGen"; protected: /*! \brief Internal executable builder. */ relax::ExecBuilder builder_; - /*! \brief Built library. */ - runtime::Module lib_; }; } // namespace relax_vm } // namespace relax } // namespace tvm -#endif // TVM_RELAX_BACKEND_VM_COMPILER_H_ +#endif // TVM_RELAX_BACKEND_CODEGEN_VM_H_