From 2e02cf7cbeb4129e0cb83c5c8574c44dce506679 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 3 Aug 2022 18:16:48 -0500 Subject: [PATCH] [LLVM] Create LLVM scope object for use with LLVM libraries (#12140) This implements RFC 80. See https://github.com/apache/tvm-rfcs/pull/83. Summary of changes: - Created an `LLVMInstance` class. Uses of LLVM functions and data structures should be contained within the lifetime of an object of this class. An LLVMInstance object contains an LLVMContext, and implements member functions to deserialize an llvm::Module. - Created an `LLVMTarget` class. Once an LLVMInstance object has been created, an object of the LLVMTarget class can be created from a TVM target string or a Target object for the "llvm" target. Once LLVM command line flags are added to the "llvm" target, one of the goals of this object will be to save/restore relevant LLVM global state. Another objective for the LLVMTarget object is to be a single location for all LLVM-related compilation structures and options (such as TargetMachine, FastMathFlags, etc.) 
--- src/target/llvm/codegen_amdgpu.cc | 30 +- src/target/llvm/codegen_arm.cc | 5 +- src/target/llvm/codegen_blob.cc | 24 +- src/target/llvm/codegen_blob.h | 15 +- src/target/llvm/codegen_cpu.cc | 94 ++-- src/target/llvm/codegen_cpu.h | 10 +- src/target/llvm/codegen_hexagon.cc | 51 ++- src/target/llvm/codegen_llvm.cc | 142 +++--- src/target/llvm/codegen_llvm.h | 29 +- src/target/llvm/codegen_nvptx.cc | 33 +- src/target/llvm/codegen_x86_64.cc | 8 +- src/target/llvm/llvm_common.cc | 211 --------- src/target/llvm/llvm_common.h | 89 ---- src/target/llvm/llvm_instance.cc | 365 ++++++++++++++++ src/target/llvm/llvm_instance.h | 266 ++++++++++++ src/target/llvm/llvm_module.cc | 672 +++++++++++++---------------- src/target/llvm/llvm_module.h | 1 - 17 files changed, 1157 insertions(+), 888 deletions(-) delete mode 100644 src/target/llvm/llvm_common.cc delete mode 100644 src/target/llvm/llvm_common.h create mode 100644 src/target/llvm/llvm_instance.cc create mode 100644 src/target/llvm/llvm_instance.h diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 2e5a4bc23bd5..c08081405648 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -51,7 +51,7 @@ #include "../../runtime/rocm/rocm_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -238,27 +238,25 @@ class CodeGenAMDGPU : public CodeGenLLVM { } protected: - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // Maximum vector lane = float4 native_vector_bits_ = 4 * 32; - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } }; runtime::Module BuildAMDGPU(IRModule mod, Target target) { + LLVMInstance llvm_instance; + + With llvm_target(llvm_instance, target); #if TVM_LLVM_VERSION < 90 LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; // Lower versions will crash when loading the bitcode, see // issue #4087 for a 
discussion #endif - InitializeLLVM(); - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); - // careful: cg will hold a naked pointer reference to ctx, so it should - // have a shorter lifetime than the ctx. std::unique_ptr cg(new CodeGenAMDGPU()); - cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMAMDGPUModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { ICHECK(kv.second->template IsInstance()) @@ -266,20 +264,15 @@ runtime::Module BuildAMDGPU(IRModule mod, Target target) { return Downcast(kv.second); }); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); Array bitcode_files = (*find_rocm_bitcodes)(); for (auto& bitcode_path : bitcode_files) { - std::string path = bitcode_path; - llvm::SMDiagnostic err; - std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); - if (mlib.get() == nullptr) { - std::string msg(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << path << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - mlib->setTargetTriple(tm->getTargetTriple().str()); + std::unique_ptr mlib = llvm_instance.LoadIR(bitcode_path); + mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); + for (llvm::Function& f : mlib->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); } @@ -351,4 +344,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_rocm") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index f5ce0d550b1f..15d1699b3b59 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -42,10 +42,10 @@ class CodeGenARM final : public CodeGenCPU { CodeGenARM() = default; virtual ~CodeGenARM() = 
default; - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // set native vector bits. native_vector_bits_ = 16 * 8; - CodeGenCPU::InitTarget(tm); + CodeGenCPU::InitTarget(); } llvm::Value* CreateIntrinsic(const CallNode* op) override; @@ -139,4 +139,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index 8e6041b4c970..b67aac480654 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -52,25 +52,20 @@ #include #include -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { -std::pair, std::shared_ptr> CodeGenBlob( - const std::string& data, bool system_lib, const std::string& llvm_target_string) { - InitializeLLVM(); - Target target(llvm_target_string); - auto tm = GetLLVMTargetMachine(target); - auto triple = tm->getTargetTriple(); - auto ctx = std::make_shared(); +std::unique_ptr CodeGenBlob(const std::string& data, bool system_lib, + LLVMTarget* llvm_target) { + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); + const llvm::Triple& triple = tm->getTargetTriple(); + llvm::LLVMContext* ctx = llvm_target->GetContext(); std::string module_name = "devc"; - std::unique_ptr module(new llvm::Module(module_name, *ctx)); + auto module = std::make_unique(module_name, *ctx); module->setTargetTriple(triple.str()); - // Store full target string in metadata, because flags such as -mfloat-abi must be preserved for - // ModulePackImportsToLLVM. 
- module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(module.get()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); auto* tvm_dev_mblob = new llvm::GlobalVariable( @@ -188,9 +183,10 @@ std::pair, std::shared_ptr> Cod ir_builder.CreateRetVoid(); } - return std::make_pair(std::move(module), ctx); + return module; } } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_blob.h b/src/target/llvm/codegen_blob.h index 46c037a30af2..a06c043c07b1 100644 --- a/src/target/llvm/codegen_blob.h +++ b/src/target/llvm/codegen_blob.h @@ -26,15 +26,18 @@ #ifdef TVM_LLVM_VERSION -#include -#include - #include #include -#include + +namespace llvm { +class Module; +} namespace tvm { namespace codegen { + +class LLVMTarget; + /** * \brief Code Generation of blob data * @@ -44,8 +47,8 @@ namespace codegen { * * \return LLVM module and LLVM context */ -std::pair, std::shared_ptr> CodeGenBlob( - const std::string& data, bool system_lib, const std::string& llvm_target_string); +std::unique_ptr CodeGenBlob(const std::string& data, bool system_lib, + LLVMTarget* llvm_target); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f2ce6fb848b4..c4aed1a237dd 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -60,6 +60,7 @@ #include "../func_registry_generator.h" #include "../metadata_utils.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -69,10 +70,9 @@ namespace codegen { CodeGenCPU::CodeGenCPU() = default; CodeGenCPU::~CodeGenCPU() = default; -void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - 
CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime); +void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + CodeGenLLVM::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime); dbg_info_ = CreateDebugInfo(module_.get()); static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); func_handle_map_.clear(); @@ -80,7 +80,8 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, // Runtime types. - t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, DataType::ShapeIndex().bits()); + t_tvm_shape_index_ = + llvm::Type::getIntNTy(*llvm_target_->GetContext(), DataType::ShapeIndex().bits()); // Defined in 3rdparty/dlpack/include/dlpack/dlpack.h: // typedef struct { DLDeviceType device_type; int device_id; } DLDevice; t_tvm_device_ = llvm::StructType::create({t_int_, t_int_}); @@ -177,7 +178,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get()); } - this->InitGlobalContext(dynamic_lookup); + InitGlobalContext(dynamic_lookup); target_c_runtime_ = target_c_runtime; is_system_lib_ = system_lib; } @@ -240,6 +241,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { } llvm::DebugLoc DL; builder.SetCurrentDebugLocation(DL); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); for (size_t i = 0; i < f_llvm->arg_size(); ++i) { auto* paramAlloca = builder.CreateAlloca(f_llvm->getFunctionType()->getParamType(i)); std::string paramName = "arg" + std::to_string(i + 1); @@ -248,7 +250,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)), /*alwaysPreserve=*/true); auto* store = builder.CreateStore(f_llvm->arg_begin() 
+ i, paramAlloca); - auto* di_loc = llvm::DILocation::get(*ctx_, 0, 0, DIFunction); + auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, DIFunction); dbg_info_->di_builder_->insertDeclare(paramAlloca, param, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), store); @@ -263,7 +265,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { if (I.getDebugLoc()) { continue; } - auto* di_loc = llvm::DILocation::get(*ctx_, 0, 0, scope); + auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, scope); I.setDebugLoc(llvm::DebugLoc(di_loc)); } } @@ -273,7 +275,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) { if (ty_llvm == t_void_) { return nullptr; - } else if (ty_llvm == llvm::Type::getFloatTy(*ctx_)) { + } else if (ty_llvm == llvm::Type::getFloatTy(*llvm_target_->GetContext())) { return dbg_info_->di_builder_->createBasicType("float", 32, llvm::dwarf::DW_ATE_float); } else if (ty_llvm == t_int8_) { return dbg_info_->di_builder_->createBasicType("int8", 8, llvm::dwarf::DW_ATE_signed); @@ -311,13 +313,14 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { #endif // comdat is needed for windows select any linking to work // set comdat to Any(weak linking) - if (target_machine_->getTargetTriple().isOSWindows()) { + if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main); comdat->setSelectionKind(llvm::Comdat::Any); global->setComdat(comdat); } - global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, entry_func_name)); + global->setInitializer( + llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), entry_func_name)); global->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); } @@ -475,7 +478,7 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, 
std::string gv->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); // comdat is needed for windows select any linking to work // set comdat to Any(weak linking) - if (target_machine_->getTargetTriple().isOSWindows()) { + if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { llvm::Comdat* comdat = module_->getOrInsertComdat(name); comdat->setSelectionKind(llvm::Comdat::Any); gv->setComdat(comdat); @@ -525,8 +528,9 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { // create emit codes that checks and load the function. - auto* fail_block = llvm::BasicBlock::Create(*ctx_, "call_fail", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "call_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* fail_block = llvm::BasicBlock::Create(*ctx, "call_fail", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "call_end", function_); auto* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0)); builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); builder_->SetInsertPoint(fail_block); @@ -584,6 +588,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { SetTargetAttributes(fcompute); llvm::BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); // enter compute scope and setup compute function. 
With scope_states_guard(this); size_t idx = 0; @@ -607,7 +612,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { if (f != alloc_storage_info_.end()) { unsigned align = f->second.alignment; if (align > 1) { - auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align); fcompute->addParamAttr(idx, attr); } } @@ -615,7 +620,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } function_ = fcompute; - auto* compute_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + auto* compute_entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(compute_entry); this->VisitStmt(op->body); builder_->CreateRet(ConstInt32(0)); @@ -679,7 +684,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin launch_callee, {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. - auto* lambda_entry = llvm::BasicBlock::Create(*ctx_, "parallel_closure_entry", f); + auto* lambda_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "parallel_closure_entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); @@ -747,7 +753,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. 
- auto* lambda_entry = llvm::BasicBlock::Create(*ctx_, "entry", f); + auto* lambda_entry = llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); @@ -793,9 +799,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { hptr = it->second; } // create emit codes that checks and load the function. + llvm::LLVMContext* ctx = llvm_target_->GetContext(); llvm::BasicBlock* pre_block = builder_->GetInsertBlock(); - auto* init_block = llvm::BasicBlock::Create(*ctx_, "handle_init", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "handle_init_end", function_); + auto* init_block = llvm::BasicBlock::Create(*ctx, "handle_init", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align)); #elif TVM_LLVM_VERSION >= 80 @@ -811,22 +818,22 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, - llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); #elif TVM_LLVM_VERSION >= 80 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, - gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx_load = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else - llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx_load = 
builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif - ctx->setMetadata("tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + ctx_load->setMetadata( + "tbaa", md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); #if TVM_LLVM_VERSION >= 90 auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); #else auto env_callee = RuntimeTVMGetFuncFromEnv(); #endif - llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); + llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx_load, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 llvm::Value* loaded_handle = @@ -946,13 +953,14 @@ llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, op->args[4].as()->value, true); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); // Get traced value. llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx_, "update_block", function_); + llvm::BasicBlock* update_block = llvm::BasicBlock::Create(*ctx, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx_, "continue_block", function_); + llvm::BasicBlock* continue_block = llvm::BasicBlock::Create(*ctx, "continue_block", function_); // Check the ret_type_code and create cmp instruction. 
llvm::Value* cmp = @@ -1254,14 +1262,15 @@ class MetadataSerializerLLVM : public AttrVisitor { }; void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { + llvm::LLVMContext* ctx = llvm_target_->GetContext(); MetadataLlvmTypes llvm_types{ t_float64_ /* t_float64 */, - llvm::Type::getInt8Ty(*ctx_) /* t_uint8 */, + llvm::Type::getInt8Ty(*ctx) /* t_uint8 */, t_int64_ /* t_int64 */, - llvm::Type::getInt8Ty(*ctx_) /* t_bool */, + llvm::Type::getInt8Ty(*ctx) /* t_bool */, t_char_->getPointerTo() /* t_cstring */, t_void_p_ /* t_void_p */, - llvm::StructType::create(*ctx_, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, + llvm::StructType::create(*ctx, {t_int8_, t_int8_, t_int8_}, "DLDataType") /* t_data_type */, }; // create sample ConstantInfoMetadata instance for MetadataTypeDefiner @@ -1278,7 +1287,7 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { metadata::DiscoverComplexTypesVisitor discover_complex{&queue}; discover_complex.Discover(metadata); - MetadataTypeDefiner definer{ctx_, &llvm_types}; + MetadataTypeDefiner definer{ctx, &llvm_types}; for (auto md : queue) { if (md.defined()) { definer.DefineType(md); @@ -1295,7 +1304,7 @@ void CodeGenCPU::DefineMetadata(runtime::metadata::Metadata metadata) { function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); - llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* entry_point_entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(entry_point_entry); auto ret_values_p = builder_->CreateBitCast(GetArg(function_, 3), t_void_p_->getPointerTo()); @@ -1350,7 +1359,8 @@ void CodeGenCPU::DefineFunctionRegistry(Array func_names) { function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, "TVMSystemLibEntryPoint", module_.get()); SetTargetAttributes(function_); - llvm::BasicBlock* 
entry_point_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* entry_point_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_); builder_->SetInsertPoint(entry_point_entry); builder_->CreateRet(builder_->CreateBitCast(module, t_void_p_)); } @@ -1361,7 +1371,8 @@ void CodeGenCPU::AddStartupFunction() { function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, "__tvm_module_startup", module_.get()); SetTargetAttributes(function_); - llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::BasicBlock* startup_entry = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "entry", function_); builder_->SetInsertPoint(startup_entry); for (const auto& kv : export_system_symbols_) { llvm::Value* name = GetConstString(kv.first); @@ -1385,7 +1396,8 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_throw_last_error())) { builder_->CreateRet(ConstInt32(-1)); auto next_block = std::next(builder_->GetInsertBlock()->getIterator()); - llvm::BasicBlock* new_bb = llvm::BasicBlock::Create(*ctx_, "cont", function_, &*next_block); + llvm::BasicBlock* new_bb = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "cont", function_, &*next_block); builder_->SetInsertPoint(new_bb); return ConstInt32(-1); } else if (op->op.same_as(builtin::tvm_struct_get())) { @@ -1443,8 +1455,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { os << ", " << op->message.as()->value; } llvm::Value* msg = GetConstString(os.str()); - auto* fail_block = llvm::BasicBlock::Create(*ctx_, "assert_fail", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "assert_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* fail_block = llvm::BasicBlock::Create(*ctx, "assert_fail", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "assert_end", function_); builder_->CreateCondBr(cond, 
end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); @@ -1549,4 +1562,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_cpu") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index eec38b122a0b..e0716ac8be2d 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -24,6 +24,8 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_ #define TVM_TARGET_LLVM_CODEGEN_CPU_H_ +#ifdef TVM_LLVM_VERSION + #include #include #include @@ -54,14 +56,16 @@ class Module; namespace tvm { namespace codegen { +class LLVMTarget; + // CPU host code generation class CodeGenCPU : public CodeGenLLVM { public: CodeGenCPU(); virtual ~CodeGenCPU(); - void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; + void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) override; void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; @@ -197,4 +201,6 @@ class CodeGenCPU : public CodeGenLLVM { } // namespace codegen } // namespace tvm + +#endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_CODEGEN_CPU_H_ diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index cab77697164d..1b9233d2ad2f 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -62,7 +62,7 @@ #include "../../runtime/hexagon/hexagon_module.h" #include "../build_common.h" #include "codegen_cpu.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -70,9 +70,9 @@ namespace codegen { // Hexagon code generation class CodeGenHexagon final : public CodeGenCPU { public: - void Init(const std::string& 
module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime) override; - void InitTarget(llvm::TargetMachine* tm) final; + void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) override; + void InitTarget() final; using CodeGenCPU::VisitStmt_; llvm::Value* VisitExpr_(const BufferLoadNode* op) override; @@ -117,29 +117,30 @@ class CodeGenHexagon final : public CodeGenCPU { llvm::Value* Intrinsic(llvm::Intrinsic::ID, llvm::ArrayRef args); }; -void CodeGenHexagon::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - CodeGenCPU::Init(module_name, tm, ctx, system_lib, dynamic_lookup, target_c_runtime); +void CodeGenHexagon::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + CodeGenCPU::Init(module_name, llvm_target, system_lib, dynamic_lookup, target_c_runtime); } -void CodeGenHexagon::InitTarget(llvm::TargetMachine* tm) { - native_vector_bits_ = 64; // Assume "scalar" vectors at first. - llvm::StringRef fs = tm->getTargetFeatureString(); - size_t npos = llvm::StringRef::npos; +void CodeGenHexagon::InitTarget() { + native_vector_bits_ = 64; // Assume "scalar" vectors at first. const auto hvx_length_feature = "+hvx-length"; // +hvx-length{64|128}b - size_t len_begin = fs.find(hvx_length_feature); - size_t len_end = len_begin != npos ? 
fs.find('b', len_begin) : npos; - if (len_end != npos) { + for (const std::string& f : llvm_target_->GetTargetFeatures()) { + llvm::StringRef fs(f); + if (!fs.startswith(hvx_length_feature)) continue; + + ICHECK(fs.endswith("b")) << "malformed target feature: " << f; int hvx_bytes = 0; - len_begin += std::strlen(hvx_length_feature); - ICHECK(!fs.substr(len_begin, len_end - len_begin).getAsInteger(10, hvx_bytes)) - << "invalid HVX length in feature string: " << fs.str(); + size_t len_begin = std::strlen(hvx_length_feature); + ICHECK(!fs.substr(len_begin, fs.size() - len_begin - 1).getAsInteger(10, hvx_bytes)) + << "invalid HVX length in feature string: " << f; ICHECK(hvx_bytes == 64 || hvx_bytes == 128) << "invalid HVX vector length: " << hvx_bytes << ", should be 64 or 128"; native_vector_bits_ = hvx_bytes * 8; + // There should only be one hvx-length... + break; } - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } llvm::Value* CodeGenHexagon::CreateCallExternQHL(Type ret_type, String global_symbol, @@ -510,9 +511,8 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { } // namespace runtime::Module BuildHexagon(IRModule mod, Target target) { - // Make sure all targets are registered. InitializeLLVM can be called - // multiple times, after the first call all subsequent calls are no-ops. 
- InitializeLLVM(); + LLVMInstance llvm_instance; + With llvm_target(llvm_instance, target); auto split = [](const std::string& str, char delim = ' ') { std::vector vec; @@ -552,8 +552,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true); (void)CallOnce; - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); std::unique_ptr cg(new CodeGenHexagon()); std::vector funcs; @@ -574,7 +572,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { funcs.emplace_back(f); } - cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMHexagonModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); if (entry_func.length() != 0) { cg->AddMainFunction(entry_func); @@ -586,7 +584,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { enum CodeGenFileType { Asm, Obj, IR, BC }; - auto EmitToString = [&tm](const llvm::Module& m, CodeGenFileType cgft) { + auto EmitToString = [&llvm_target](const llvm::Module& m, CodeGenFileType cgft) { std::string out; if (cgft == IR || cgft == BC) { @@ -607,6 +605,7 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { llvm::raw_svector_ostream os(ss); std::unique_ptr cm = llvm::CloneModule(m); llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); ICHECK(tm->addPassesToEmitFile(pass, os, nullptr, ft) == 0) << "Cannot emit target code"; pass.run(*cm.get()); out.assign(ss.c_str(), ss.size()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f1d891e2c3bd..305358d079d0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -89,7 +89,7 @@ #include "../build_common.h" #include "../func_registry_generator.h" #include "codegen_params.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen 
{ @@ -102,8 +102,8 @@ CodeGenLLVM::CodeGenLLVM() = default; CodeGenLLVM::~CodeGenLLVM() = default; CodeGenLLVM::DebugInfo::~DebugInfo() = default; -std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { - std::string target = tm->getTarget().getName(); +std::unique_ptr CodeGenLLVM::Create(LLVMTarget* llvm_target) { + std::string target = llvm_target->GetOrCreateTargetMachine()->getTarget().getName(); std::string factory_template = "tvm.codegen.llvm.target_"; void* handle = nullptr; if (const PackedFunc* f = runtime::Registry::Get(factory_template + target)) { @@ -121,38 +121,37 @@ std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { } } -void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup, - bool target_c_runtime) { - InitializeLLVM(); - ctx_ = ctx; - builder_.reset(new IRBuilder(*ctx_)); - module_.reset(new llvm::Module(module_name, *ctx_)); - md_builder_.reset(new llvm::MDBuilder(*ctx_)); +void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime) { + llvm_target_ = llvm_target; + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + builder_.reset(new IRBuilder(*ctx)); + module_.reset(new llvm::Module(module_name, *ctx)); + md_builder_.reset(new llvm::MDBuilder(*ctx)); // types - t_void_ = llvm::Type::getVoidTy(*ctx_); - t_void_p_ = llvm::Type::getInt8Ty(*ctx_)->getPointerTo(GetGlobalAddressSpace()); - t_int_ = llvm::Type::getInt32Ty(*ctx_); - t_char_ = llvm::Type::getInt8Ty(*ctx_); - t_int8_ = llvm::Type::getInt8Ty(*ctx_); - t_int16_ = llvm::Type::getInt16Ty(*ctx_); - t_int32_ = llvm::Type::getInt32Ty(*ctx_); - t_int64_ = llvm::Type::getInt64Ty(*ctx_); - t_float64_ = llvm::Type::getDoubleTy(*ctx_); + t_void_ = llvm::Type::getVoidTy(*ctx); + t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo(GetGlobalAddressSpace()); + t_int_ = llvm::Type::getInt32Ty(*ctx); + t_char_ = 
llvm::Type::getInt8Ty(*ctx); + t_int8_ = llvm::Type::getInt8Ty(*ctx); + t_int16_ = llvm::Type::getInt16Ty(*ctx); + t_int32_ = llvm::Type::getInt32Ty(*ctx); + t_int64_ = llvm::Type::getInt64Ty(*ctx); + t_float64_ = llvm::Type::getDoubleTy(*ctx); // meta data md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1); md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_); - this->InitTarget(tm); + InitTarget(); } -void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } +void CodeGenLLVM::SetFastMathFlags(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); } -void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { +void CodeGenLLVM::InitTarget() { + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); module_->setTargetTriple(tm->getTargetTriple().str()); module_->setDataLayout(tm->createDataLayout()); data_layout_.reset(new llvm::DataLayout(module_.get())); - target_machine_ = tm; if (native_vector_bits_ == 0) { const auto& arch = tm->getTargetTriple().getArch(); if (arch == llvm::Triple::x86_64) { @@ -230,7 +229,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } } - llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx, "entry", function_); builder_->SetInsertPoint(entry); this->VisitStmt(f->body); @@ -242,7 +242,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { if (f != alloc_storage_info_.end()) { unsigned align = f->second.alignment; if (align > 1) { - auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + auto attr = llvm::Attribute::get(*ctx, llvm::Attribute::Alignment, align); function_->addParamAttr(i, attr); } } @@ -269,28 +269,16 @@ std::unique_ptr CodeGenLLVM::Finish() { } void 
CodeGenLLVM::HandleImport(const std::string& code) { + llvm::StringRef code_str(code); std::unique_ptr mlib; - llvm::SMDiagnostic err; - if (code.length() >= 3 && - (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) { - mlib = llvm::parseIRFile(code, err, *ctx_); - if (mlib.get() == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << code << "\n" - << "line " << err.getLineNo() << ":" << msg; - } + if (code_str.endswith(".ll") || code_str.endswith(".bc")) { + mlib = llvm_target_->GetInstance().LoadIR(code); } else { - std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); - mlib = llvm::parseIR(*buf, err, *ctx_); - if (mlib.get() == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load llvm ir " - << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n" - << code; - } + mlib = llvm_target_->GetInstance().ParseIR(code); } - mlib->setTargetTriple(target_machine_->getTargetTriple().str()); - mlib->setDataLayout(target_machine_->createDataLayout()); + + mlib->setTargetTriple(llvm_target_->GetTargetTriple()); + mlib->setDataLayout(llvm_target_->GetOrCreateTargetMachine()->createDataLayout()); // mark all the functions as force inline for (llvm::Function& f : mlib->functions()) { f.removeFnAttr(llvm::Attribute::NoInline); @@ -338,16 +326,15 @@ void CodeGenLLVM::Optimize() { // pass manager FPassManager fpass(module_.get()); MPassManager mpass; - mpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); - fpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? 
target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); + mpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); + fpass.add(llvm::createTargetTransformInfoWrapperPass(tm->getTargetIRAnalysis())); // place optimization pass llvm::PassManagerBuilder builder; // Use the same opt-level as specified in TargetMachine for running passes - llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel(); + llvm::CodeGenOpt::Level opt_level = llvm_target_->GetOptLevel(); switch (opt_level) { case llvm::CodeGenOpt::Level::None: @@ -376,7 +363,7 @@ void CodeGenLLVM::Optimize() { this->InitPassManagerBuilder(&builder); #if TVM_LLVM_VERSION >= 50 - target_machine_->adjustPassManager(builder); + tm->adjustPassManager(builder); #endif builder.populateFunctionPassManager(fpass); @@ -405,18 +392,19 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { return t_void_; } llvm::Type* etype = nullptr; + llvm::LLVMContext* ctx = llvm_target_->GetContext(); if (dtype.is_int() || dtype.is_uint()) { - etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); + etype = llvm::Type::getIntNTy(*ctx, dtype.bits()); } else if (dtype.is_float()) { switch (dtype.bits()) { case 16: - etype = llvm::Type::getHalfTy(*ctx_); + etype = llvm::Type::getHalfTy(*ctx); break; case 32: - etype = llvm::Type::getFloatTy(*ctx_); + etype = llvm::Type::getFloatTy(*ctx); break; case 64: - etype = llvm::Type::getDoubleTy(*ctx_); + etype = llvm::Type::getDoubleTy(*ctx); break; default: LOG(FATAL) << "do not support " << dtype; @@ -702,9 +690,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va const Var& loop_var, const Stmt& body) { llvm::BasicBlock* pre_block = builder_->GetInsertBlock(); std::string loop_var_name = loop_var->name_hint; - auto* for_begin = llvm::BasicBlock::Create(*ctx_, "for_begin_" + loop_var_name, function_); - auto* for_body = 
llvm::BasicBlock::Create(*ctx_, "for_body_" + loop_var_name, function_); - auto* for_end = llvm::BasicBlock::Create(*ctx_, "for_end_" + loop_var_name, function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* for_begin = llvm::BasicBlock::Create(*ctx, "for_begin_" + loop_var_name, function_); + auto* for_body = llvm::BasicBlock::Create(*ctx, "for_body_" + loop_var_name, function_); + auto* for_end = llvm::BasicBlock::Create(*ctx, "for_end_" + loop_var_name, function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); @@ -777,7 +766,7 @@ llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; - auto llvm_str = llvm::ConstantDataArray::getString(*ctx_, str); + auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str); auto ptr = GetGlobalConstant(llvm_str, ".str", llvm::GlobalValue::PrivateLinkage); str_map_[str] = ptr; return ptr; @@ -950,11 +939,11 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type } void CodeGenLLVM::SetTargetAttributes(llvm::Function* func) { - llvm::StringRef cpu = target_machine_->getTargetCPU(); + const std::string& cpu = llvm_target_->GetCPU(); if (!cpu.empty()) { func->addFnAttr("target-cpu", cpu); } - llvm::StringRef features = target_machine_->getTargetFeatureString(); + const std::string& features = llvm_target_->GetTargetFeatureString(); if (!features.empty()) { func->addFnAttr("target-features", features); } @@ -980,8 +969,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // mismatch will have to be treated specially here. // TODO(kparzysz-quic): fix this once TVM prefetch uses the same // type as LLVM. - llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? 
GetLLVMType(GetRef(op)) - : llvm::Type::getVoidTy(*ctx_); + llvm::Type* return_type = + (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef(op)) : t_void_; llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " #if TVM_LLVM_VERSION >= 130 @@ -1039,9 +1028,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->op.same_as(builtin::if_then_else())) { ICHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; - auto* then_block = llvm::BasicBlock::Create(*ctx_, "if_then", function_); - auto* else_block = llvm::BasicBlock::Create(*ctx_, "if_else", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "if_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); + auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_); builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block); builder_->SetInsertPoint(then_block); llvm::Value* then_value = MakeValue(op->args[1]); @@ -1065,7 +1055,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. 
- llvm::BasicBlock* ret_dummy = llvm::BasicBlock::Create(*ctx_, "ret_dummy", function_); + llvm::BasicBlock* ret_dummy = + llvm::BasicBlock::Create(*llvm_target_->GetContext(), "ret_dummy", function_); builder_->SetInsertPoint(ret_dummy); return ret_dummy; } else if (op->op.same_as(builtin::reinterpret())) { @@ -1519,9 +1510,10 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } void CodeGenLLVM::VisitStmt_(const WhileNode* op) { - auto* while_cond = llvm::BasicBlock::Create(*ctx_, "while_cond", function_); - auto* while_body = llvm::BasicBlock::Create(*ctx_, "while_body", function_); - auto* while_merge = llvm::BasicBlock::Create(*ctx_, "while_merge", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond", function_); + auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body", function_); + auto* while_merge = llvm::BasicBlock::Create(*ctx, "while_merge", function_); builder_->CreateBr(while_cond); builder_->SetInsertPoint(while_cond); builder_->CreateCondBr(MakeValue(op->condition), while_body, while_merge); @@ -1533,10 +1525,11 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) { void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { llvm::Value* cond = MakeValue(op->condition); - auto* then_block = llvm::BasicBlock::Create(*ctx_, "if_then", function_); - auto* end_block = llvm::BasicBlock::Create(*ctx_, "if_end", function_); + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); + auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_); if (op->else_case.defined()) { - auto* else_block = llvm::BasicBlock::Create(*ctx_, "if_else", function_); + auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); builder_->CreateCondBr(cond, then_block, else_block); builder_->SetInsertPoint(then_block); this->VisitStmt(op->then_case); @@ -1555,7 +1548,7 @@ void 
CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { auto data = op->data.value(); - auto array = codegen::NDArrayToLLVMArray(ctx_, data); + auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data); std::string symbol_name = op->buffer_var->name_hint; llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); @@ -1673,4 +1666,5 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index c6129c238c7f..e6321be647aa 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -23,6 +23,7 @@ */ #ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ + #ifdef TVM_LLVM_VERSION #include @@ -40,7 +41,6 @@ #include #include #include -#include #include #if TVM_LLVM_VERSION >= 140 #include @@ -78,7 +78,6 @@ class Function; class GlobalVariable; class Instruction; class PassManagerBuilder; -class TargetMachine; class DIFile; class DICompileUnit; class MDNode; @@ -93,6 +92,8 @@ class MDBuilder; namespace tvm { namespace codegen { +class LLVMTarget; + using namespace tir; /*! @@ -109,7 +110,7 @@ class CodeGenLLVM : public ExprFunctor, * \param tm The target machine * \return The created llvm generator. */ - static std::unique_ptr Create(llvm::TargetMachine* tm); + static std::unique_ptr Create(LLVMTarget* llvm_target); /*! * \brief Initialize the code generator with given context * \param module_name The name of the module. @@ -121,14 +122,14 @@ class CodeGenLLVM : public ExprFunctor, * \param target_c_runtime If true, generate a module to be executed by the C runtime. In practice * this option influences whether global ctors are used. 
*/ - virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, - bool system_lib, bool dynamic_lookup, bool target_c_runtime); + virtual void Init(const std::string& module_name, LLVMTarget* llvm_target, bool system_lib, + bool dynamic_lookup, bool target_c_runtime); /*! * \brief Turn on fast math flags for floating point operations. * \param fmf FastMathFlags to use for code generation. */ - void SetFastMathFlag(llvm::FastMathFlags fmf); + void SetFastMathFlags(llvm::FastMathFlags fmf); /*! * \brief Compile and add function f to the current module. @@ -229,9 +230,6 @@ class CodeGenLLVM : public ExprFunctor, llvm::Constant* GetGlobalConstant( llvm::Constant* const_data, const std::string& name = "", llvm::GlobalValue::LinkageTypes linkage_type = llvm::GlobalValue::InternalLinkage); - inline llvm::ConstantArray* NDArrayToLLVMArray(::tvm::runtime::NDArray arr) { - return codegen::NDArrayToLLVMArray(ctx_, arr); - } protected: /*! @@ -340,7 +338,7 @@ class CodeGenLLVM : public ExprFunctor, bool is_volatile)> make_instruction); // Initialize target - virtual void InitTarget(llvm::TargetMachine* tm); + virtual void InitTarget(); // Add module startup function if needed. virtual void AddStartupFunction() {} // apply optimization on the module. @@ -476,10 +474,8 @@ class CodeGenLLVM : public ExprFunctor, std::unique_ptr data_layout_; // Internal metabuilder std::unique_ptr md_builder_; - // llvm target machine - llvm::TargetMachine* target_machine_{nullptr}; - // llvm context - llvm::LLVMContext* ctx_{nullptr}; + // llvm target info + LLVMTarget* llvm_target_{nullptr}; // helpful data types llvm::Type* t_void_{nullptr}; llvm::PointerType* t_void_p_{nullptr}; @@ -495,7 +491,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::MDNode* md_tbaa_root_{nullptr}; llvm::MDNode* md_tbaa_alias_set_{nullptr}; // modules to be linked. - std::vector > link_modules_; + std::vector> link_modules_; /*! 
\brief native vector bits of current targetx*/ int native_vector_bits_{0}; /*! \brief the storage scope of allocation */ @@ -567,5 +563,6 @@ void CodeGenLLVM::AddFunctionsOrdered(IterType begin, IterType end, ConvType pfu } // namespace codegen } // namespace tvm -#endif // LLVM_VERSION + +#endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_CODEGEN_LLVM_H_ diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index a74274009cf4..c758ca383621 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -56,7 +56,7 @@ #include "../../runtime/cuda/cuda_module.h" #include "../build_common.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -68,10 +68,11 @@ class CodeGenNVPTX : public CodeGenLLVM { // add function as void return value CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function + llvm::LLVMContext* ctx = llvm_target_->GetContext(); module_->getOrInsertNamedMetadata("nvvm.annotations") ->addOperand(llvm::MDNode::get( - *ctx_, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"), - llvm::ValueAsMetadata::get(ConstInt32(1))})); + *ctx, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx, "kernel"), + llvm::ValueAsMetadata::get(ConstInt32(1))})); } void VisitStmt_(const AllocateNode* op) final { @@ -203,10 +204,10 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Value* CreateIntrinsic(const CallNode* op) override; protected: - void InitTarget(llvm::TargetMachine* tm) final { + void InitTarget() final { // Maximum vector lane = float4 native_vector_bits_ = 4 * 32; - CodeGenLLVM::InitTarget(tm); + CodeGenLLVM::InitTarget(); } }; @@ -298,15 +299,13 @@ int GetCUDAComputeVersion(const Target& target) { } runtime::Module BuildNVPTX(IRModule mod, Target target) { - InitializeLLVM(); + LLVMInstance llvm_instance; + With llvm_target(llvm_instance, target); + int compute_ver = 
GetCUDAComputeVersion(target); - std::unique_ptr tm = GetLLVMTargetMachine(target); - std::unique_ptr ctx(new llvm::LLVMContext()); - // careful: cg will hold a naked pointer reference to ctx, so it should - // have a shorter lifetime than the ctx. std::unique_ptr cg(new CodeGenNVPTX()); - cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false, false); + cg->Init("TVMPTXModule", llvm_target.get(), false, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end(), [](auto& kv) { ICHECK(kv.second->template IsInstance()) @@ -314,18 +313,13 @@ runtime::Module BuildNVPTX(IRModule mod, Target target) { return Downcast(kv.second); }); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { - llvm::SMDiagnostic err; - std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); - if (mlib.get() == nullptr) { - std::string msg(err.getMessage()); - LOG(FATAL) << "Fail to load bitcode file " << path << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - mlib->setTargetTriple(tm->getTargetTriple().str()); + std::unique_ptr mlib = llvm_instance.LoadIR(path); + mlib->setTargetTriple(llvm_target->GetTargetTriple()); mlib->setDataLayout(tm->createDataLayout()); cg->AddLinkModule(std::move(mlib)); } @@ -365,4 +359,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_nvptx") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 2d36e0b022e1..efe15c5c4aac 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -38,6 +38,7 @@ #include #include "codegen_cpu.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -91,9 +92,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const 
auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { ICHECK_EQ(from.lanes(), to.lanes()); - CHECK_NOTNULL(target_machine_); + llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); - const auto has_avx512 = TargetHasFeature(*target_machine_, "avx512f"); + const auto has_avx512 = TargetHasFeature(*tm, "avx512f"); if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( @@ -110,7 +111,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { #if TVM_LLVM_VERSION <= 100 // The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11. - const auto has_f16c = TargetHasFeature(*target_machine_, "f16c"); + const auto has_f16c = TargetHasFeature(*tm, "f16c"); if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8, @@ -168,4 +169,5 @@ TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") } // namespace codegen } // namespace tvm + #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc deleted file mode 100644 index 83de839a926e..000000000000 --- a/src/target/llvm/llvm_common.cc +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -/*! - * \file llvm_common.cc - */ -#ifdef TVM_LLVM_VERSION - -#include "llvm_common.h" - -#if TVM_LLVM_VERSION >= 140 -#include -#else -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace codegen { - -struct LLVMEnv { - std::mutex mu; - std::atomic all_initialized{false}; - - static LLVMEnv* Global() { - static LLVMEnv inst; - return &inst; - } -}; - -void InitializeLLVM() { - LLVMEnv* e = LLVMEnv::Global(); - if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) { - std::lock_guard lock(e->mu); - if (!e->all_initialized.load(std::memory_order::memory_order_acquire)) { - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); - e->all_initialized.store(true, std::memory_order::memory_order_release); - } - } -} - -void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, - std::string* mattr, llvm::TargetOptions* options) { - // simple parser - triple->resize(0); - mcpu->resize(0); - mattr->resize(0); - bool soft_float_abi = false; - if (const Optional& v = target->GetAttr("mtriple")) { - *triple = v.value(); - } - if (const Optional& v = target->GetAttr("mcpu")) { - *mcpu = v.value(); - } - if (const Optional>& v = target->GetAttr>("mattr")) { - std::ostringstream os; - bool is_first = true; - for (const String& s : v.value()) { - if (!is_first) { - os << ','; - } - is_first = false; - os << s; - } - *mattr = os.str(); - } - if (const Optional& v = target->GetAttr("mfloat-abi")) { - String value = v.value(); - if (value == "hard") { -#if TVM_LLVM_VERSION < 60 - LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0"; -#endif - soft_float_abi = false; - } else if (value == "soft") { - soft_float_abi = true; - } else { - 
LOG(FATAL) << "invalid -mfloat-abi option " << value; - } - } - if (triple->length() == 0 || *triple == "default") { - *triple = llvm::sys::getDefaultTargetTriple(); - } - // set target option - llvm::TargetOptions& opt = *options; - opt = llvm::TargetOptions(); -#if TVM_LLVM_VERSION < 50 - opt.LessPreciseFPMADOption = true; -#endif - // In clang, these are fed from LangOpts which describe language specific features - // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags - opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - if (soft_float_abi) { - opt.FloatABIType = llvm::FloatABI::Soft; - } else { - opt.FloatABIType = llvm::FloatABI::Hard; - } - if (const Optional& v = target->GetAttr("mabi")) { - opt.MCOptions.ABIName = v.value(); - } -} - -std::unique_ptr GetLLVMTargetMachine(const Target& target, bool allow_null) { - std::string target_triple, mcpu, mattr; - llvm::TargetOptions opt; - - ParseLLVMTargetOptions(target, &target_triple, &mcpu, &mattr, &opt); - - if (target_triple.length() == 0 || target_triple == "default") { - target_triple = llvm::sys::getDefaultTargetTriple(); - } - if (mcpu.length() == 0) { - mcpu = "generic"; - } - - std::string err; - const llvm::Target* llvm_target = llvm::TargetRegistry::lookupTarget(target_triple, err); - if (llvm_target == nullptr) { - ICHECK(allow_null) << err << " target_triple=" << target_triple; - return nullptr; - } - - int llvm_opt_level = target->GetAttr("opt-level").value_or(Integer(3)).IntValue(); - llvm::CodeGenOpt::Level llvm_opt; - if (llvm_opt_level <= 0) { - llvm_opt = llvm::CodeGenOpt::None; - } else if (llvm_opt_level == 1) { - llvm_opt = llvm::CodeGenOpt::Less; - } else if (llvm_opt_level == 2) { - llvm_opt = llvm::CodeGenOpt::Default; - } else { - // llvm_opt_level >= 3 - llvm_opt = llvm::CodeGenOpt::Aggressive; - } - - llvm::TargetMachine* tm = llvm_target->createTargetMachine( - target_triple, mcpu, mattr, 
opt, llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt); - return std::unique_ptr(tm); -} - -std::string LLVMTargetToString(const Target& target) { - std::ostringstream os; - os << "llvm"; - if (Optional mtriple = target->GetAttr("mtriple")) { - os << " -mtriple=" << mtriple.value(); - } - if (Optional mcpu = target->GetAttr("mcpu")) { - os << " -mcpu=" << mcpu.value(); - } - if (Optional> mattr = target->GetAttr>("mattr")) { - bool is_first = true; - os << " -mattr="; - for (const String& attr : mattr.value()) { - if (!is_first) { - os << ","; - } - is_first = false; - os << attr; - } - } - if (Optional mfloat_abo = target->GetAttr("mfloat-abi")) { - os << " -mfloat-abi=" << mfloat_abo.value(); - } - if (Optional mabi = target->GetAttr("mabi")) { - os << " -mabi=" << mabi.value(); - } - return os.str(); -} - -} // namespace codegen -} // namespace tvm -#endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h deleted file mode 100644 index c127b77c03ac..000000000000 --- a/src/target/llvm/llvm_common.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file llvm_common.h - * \brief Common utilities for llvm initialization. 
- */ -#ifndef TVM_TARGET_LLVM_LLVM_COMMON_H_ -#define TVM_TARGET_LLVM_LLVM_COMMON_H_ - -#ifdef _MSC_VER -#pragma warning(disable : 4141 4291 4146 4624) -#endif -#ifdef TVM_LLVM_VERSION - -#include - -#include -#include -#include - -namespace llvm { -class Module; -class Target; -class TargetMachine; -class TargetOptions; -} // namespace llvm - -namespace tvm { - -// The TVM target -class Target; - -namespace codegen { - -/*! - * \brief Initialize LLVM on this process, - * can be called multiple times. - */ -void InitializeLLVM(); - -/*! - * \brief Parse target options - * \param target The TVM target - * \param triple Target triple - * \param mcpu cpu info - * \param options the options - * \param mattr The attributes - */ -void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, - std::string* mattr, llvm::TargetOptions* options); - -/*! - * \brief Get target machine from TVM target. - * \param target The TVM target - * \param allow_null Whether allow null to be returned. - * \return target machine - */ -std::unique_ptr GetLLVMTargetMachine(const Target& target, - bool allow_null = false); - -/*! - * \brief Convert the TVM's LLVM target to string by extracting only relevant fields - * \param target The TVM target to be extracted - * \return The raw string format for the TVM LLVM target - */ -std::string LLVMTargetToString(const Target& target); - -} // namespace codegen -} // namespace tvm - -#endif // TVM_LLVM_VERSION -#endif // TVM_TARGET_LLVM_LLVM_COMMON_H_ diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc new file mode 100644 index 000000000000..772e71b28724 --- /dev/null +++ b/src/target/llvm/llvm_instance.cc @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef TVM_LLVM_VERSION + +#include "llvm_instance.h" + +#include +#include +#include +#if TVM_LLVM_VERSION >= 150 +#include +#else +#include +#endif +#include +#include +#include +#include +#if TVM_LLVM_VERSION >= 140 +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +namespace { +namespace defaults { +static const char* cpu = "generic"; +static const llvm::CodeGenOpt::Level opt_level = llvm::CodeGenOpt::Aggressive; +} // namespace defaults +} // namespace + +namespace { +bool InitializeLLVM() { + static std::atomic_flag initialized = ATOMIC_FLAG_INIT; + if (!initialized.test_and_set()) { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + } + return true; +} + +std::string Join(std::string sep, llvm::ArrayRef strings) { + std::string result; + bool is_first = true; + for (const std::string& s : strings) { + if (!is_first) { + result += sep; + } + result += s; + is_first = false; + } + return result; +} + +} // namespace + +// LLVMInstance + +LLVMInstance::LLVMInstance() { + // Call InitializeLLVM before anything else. 
+ static const bool DMLC_ATTRIBUTE_UNUSED init_llvm = InitializeLLVM(); + ctx_ = std::make_shared(); +} + +LLVMInstance::~LLVMInstance() = default; + +std::unique_ptr LLVMInstance::ParseIR(const std::string& llvm_ir) const { + auto buffer = llvm::MemoryBuffer::getMemBuffer(llvm_ir, /*BufferName=*/"", + /*RequiresNullTerminator=*/false); + return ParseBuffer(*buffer); +} + +std::unique_ptr LLVMInstance::LoadIR(const std::string& file_name) const { + llvm::ErrorOr> maybe_buffer = + llvm::MemoryBuffer::getFileAsStream(file_name); + if (std::error_code ec = maybe_buffer.getError()) { + LOG(FATAL) << "Failed to load IR file " << file_name << ": " << ec.message();  // name the file + } + return ParseBuffer(**maybe_buffer); +} + +std::unique_ptr LLVMInstance::ParseBuffer(const llvm::MemoryBuffer& buffer) const { + llvm::SMDiagnostic error; + std::unique_ptr module = llvm::parseIR(buffer.getMemBufferRef(), error, *ctx_); + if (module == nullptr) { + std::string message; + llvm::raw_string_ostream ostream(message); + error.print(/*ProgName=*/nullptr, ostream, /*ShowColors=*/false, /*ShowKindLabel=*/true); + LOG(FATAL) << ostream.str(); + } + + return module; +} + +// LLVMTarget + +LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) + : instance_(instance), ctx_(instance.GetContext()) { + triple_ = target->GetAttr("mtriple").value_or("default"); + + if (triple_.empty() || triple_ == "default") { + triple_ = llvm::sys::getDefaultTargetTriple(); + } + cpu_ = target->GetAttr("mcpu").value_or(defaults::cpu); + + if (const Optional>& v = target->GetAttr>("mattr")) { + for (const String& s : v.value()) { + attrs_.push_back(s); + } + } + + llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; + if (const Optional& v = target->GetAttr("mfloat-abi")) { + String value = v.value(); + if (value == "hard") { + float_abi = llvm::FloatABI::Hard; + } else if (value == "soft") { + float_abi = llvm::FloatABI::Soft; + } else { + LOG(FATAL) << "invalid -mfloat-abi option " << value; + } + } + + // Target options + +#if TVM_LLVM_VERSION < 
50 + target_options_.LessPreciseFPMADOption = true; +#endif + // In clang, these are fed from LangOpts which describe language specific features + // TODO(AndrewZhaoLuo): figure out how these relate to fast math flags + target_options_.AllowFPOpFusion = llvm::FPOpFusion::Fast; + target_options_.UnsafeFPMath = false; + target_options_.NoInfsFPMath = false; + target_options_.NoNaNsFPMath = true; + target_options_.FloatABIType = float_abi; + if (const Optional& v = target->GetAttr("mabi")) { + target_options_.MCOptions.ABIName = v.value(); + } + + auto maybe_level = target->GetAttr("opt-level"); + + if (maybe_level.defined()) { + int level = maybe_level.value()->value; + if (level <= 0) { + opt_level_ = llvm::CodeGenOpt::None; + } else if (level == 1) { + opt_level_ = llvm::CodeGenOpt::Less; + } else if (level == 2) { + opt_level_ = llvm::CodeGenOpt::Default; + } else { + // level >= 3 + opt_level_ = llvm::CodeGenOpt::Aggressive; + } + } else { + opt_level_ = defaults::opt_level; + } + + // Fast math options + + auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { + return target->GetAttr(flag.str()).value_or(Bool(false)); + }; + if (GetBoolFlag("fast-math")) { +#if TVM_LLVM_VERSION >= 60 + fast_math_flags_.setFast(); +#else + fast_math_flags_.setUnsafeAlgebra(); +#endif + } else { +#if TVM_LLVM_VERSION >= 50 + // This option was added in 5.x, and has a boolean argument, + // unlike the rest of options at the time. 
+ fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract")); +#endif +#if TVM_LLVM_VERSION >= 70 + fast_math_flags_.setNoNaNs(GetBoolFlag("fast-math-nnan")); + fast_math_flags_.setNoInfs(GetBoolFlag("fast-math-ninf")); + fast_math_flags_.setNoSignedZeros(GetBoolFlag("fast-math-nsz")); + fast_math_flags_.setAllowReciprocal(GetBoolFlag("fast-math-arcp")); + fast_math_flags_.setAllowContract(GetBoolFlag("fast-math-contract")); + fast_math_flags_.setAllowReassoc(GetBoolFlag("fast-math-reassoc")); + fast_math_flags_.setApproxFunc(GetBoolFlag("fast-math-afn")); +#else + // LLVM 4.x, 5.x, and 6.x + if (GetBoolFlag("fast-math-nnan")) fast_math_flags_.setNoNaNs(); + if (GetBoolFlag("fast-math-ninf")) fast_math_flags_.setNoInfs(); + if (GetBoolFlag("fast-math-nsz")) fast_math_flags_.setNoSignedZeros(); + if (GetBoolFlag("fast-math-arcp")) fast_math_flags_.setAllowReciprocal(); +#if TVM_LLVM_VERSION >= 60 + if (GetBoolFlag("fast-math-reassoc")) fast_math_flags_.setAllowReassoc(); + if (GetBoolFlag("fast-math-afn")) fast_math_flags_.setApproxFunc(); +#endif +#endif + } +} + +LLVMTarget::LLVMTarget(LLVMInstance& scope, const std::string& target_str) + : LLVMTarget(scope, Target(target_str)) {} + +LLVMTarget::~LLVMTarget() = default; + +llvm::LLVMContext* LLVMTarget::GetContext() const { + ICHECK(!ctx_.expired()) << "LLVM scope has been deleted"; + return ctx_.lock().get(); +} + +// Return the cached llvm::TargetMachine, creating it on first use from the stored triple, +// CPU, feature string, target options, relocation/code model and optimization level. +// Aborts with LLVM's error message when no target machine is available, unless +// `allow_missing` is true, in which case nullptr is returned instead. +llvm::TargetMachine* LLVMTarget::GetOrCreateTargetMachine(bool allow_missing) { + if (target_machine_) return target_machine_.get(); + + std::string error; + if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) { + llvm::TargetMachine* tm = + llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, + reloc_model_, code_model_, opt_level_); + target_machine_ = std::unique_ptr(tm); + } + // The check must live outside of the lookup branch: a failed lookupTarget (e.g. unknown + // triple) leaves target_machine_ empty too, and has to abort just like a failed + // createTargetMachine instead of silently handing nullptr to the caller. + ICHECK(target_machine_ != nullptr || allow_missing) << error; + return target_machine_.get(); +} + +std::string 
LLVMTarget::GetTargetFeatureString() const { // + return Join(",", attrs_); +} + +std::string LLVMTarget::str() const { + std::ostringstream os; + os << "llvm"; + if (!triple_.empty()) { + os << " -mtriple=" << triple_; + } + if (!cpu_.empty() && cpu_ != defaults::cpu) { + os << " -mcpu=" << cpu_; + } + if (!attrs_.empty()) { + os << " -mattr=" << GetTargetFeatureString(); + } + + switch (target_options_.FloatABIType) { + case llvm::FloatABI::Soft: + os << " -mfloat-abi=soft"; + break; + case llvm::FloatABI::Hard: + os << " -mfloat-abi=hard"; + break; + case llvm::FloatABI::Default: + break; + } + if (!target_options_.MCOptions.ABIName.empty()) { + os << " -mabi=" << target_options_.MCOptions.ABIName; + } + + bool do_individual = true; +#if TVM_LLVM_VERSION >= 60 + if (fast_math_flags_.isFast()) { + os << " -fast-math"; + do_individual = false; + } +#else + if (fast_math_flags_.unsafeAlgebra()) { + os << " -fast-math"; + do_individual = false; + } +#endif + + if (do_individual) { + if (fast_math_flags_.noNaNs()) os << " -fast-math-nnan"; + if (fast_math_flags_.noInfs()) os << " -fast-math-ninf"; + if (fast_math_flags_.noSignedZeros()) os << " -fast-math-nsz"; + if (fast_math_flags_.allowReciprocal()) os << " -fast-math-arcp"; +#if TVM_LLVM_VERSION >= 50 + if (fast_math_flags_.allowContract()) os << " -fast-math-contract"; +#endif +#if TVM_LLVM_VERSION >= 60 + if (fast_math_flags_.allowReassoc()) os << " -fast-math-reassoc"; + if (fast_math_flags_.approxFunc()) os << " -fast-math-afn"; +#endif + } + + if (opt_level_ != defaults::opt_level) { + os << " -opt-level="; + switch (opt_level_) { + case llvm::CodeGenOpt::None: + os << "0"; + break; + case llvm::CodeGenOpt::Less: + os << "1"; + break; + case llvm::CodeGenOpt::Default: + os << "2"; + break; + case llvm::CodeGenOpt::Aggressive: + os << "3"; + break; + } + } + + return os.str(); +} + +std::string LLVMTarget::GetTargetMetadata(const llvm::Module& module) { + if (llvm::Metadata* tvm_target = 
module.getModuleFlag("tvm_target")) { + auto* mdstr = llvm::cast(tvm_target); + llvm::StringRef meta = mdstr->getString(); + if (meta.startswith("llvm")) { + return meta.str(); + } + } + return "llvm -mtriple " + module.getTargetTriple(); +} + +void LLVMTarget::SetTargetMetadata(llvm::Module* module) const { + module->addModuleFlag(llvm::Module::Warning, "tvm_target", + llvm::MDString::get(*GetContext(), str())); +} + +} // namespace codegen +} // namespace tvm + +#endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h new file mode 100644 index 000000000000..afb6e58deb1f --- /dev/null +++ b/src/target/llvm/llvm_instance.h @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
\file llvm_instance.h + */ +#ifndef TVM_TARGET_LLVM_LLVM_INSTANCE_H_ +#define TVM_TARGET_LLVM_LLVM_INSTANCE_H_ + +#ifdef TVM_LLVM_VERSION + +#include +#if TVM_LLVM_VERSION >= 150 +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace llvm { +class LLVMContext; +class MemoryBuffer; +class Module; +class TargetMachine; +} // namespace llvm + +namespace tvm { +namespace codegen { + +class LLVMTarget; + +/*! + * \class LLVMInstance + * \brief LLVMInstance is a class that (conceptually) starts and stops LLVM. All + * uses of LLVM should take place within the lifetime of an object of this class. + * + * E.g. + * ```{.cpp} + * { + * LLVMInstance llvm_instance; + * ... + * someFunctionFromLLVM(...); + * ... + * } + * // no more calls to LLVM here + * ``` + * In addition to that, LLVMInstance provides an LLVM context (llvm::LLVMContext). + * The context is a structure in LLVM where common IR constructs are maintained + * (such as types, constants, etc.) so that they can be identified by their + * address (i.e. pointer comparison). Because of that, it's important to use + * the same context throughout compilation. + * + * At the moment the "starting" of LLVM performs initialization of LLVM, but + * "stopping" doesn't do anything. In the future, if such a need arises, this + * functionality may be extended to perform dlopen/dlclose of the LLVM-based + * code in TVM. + * + * This class provides means to deserialize an LLVM module, either from text + * (in a string), or from a file. In either case, the serialized module can + * be LLVM IR assembly, or binary bitcode encoding. + */ +class LLVMInstance { + public: + /*! + * \brief Constructs LLVMInstance + */ + LLVMInstance(); + /*! + * \brief Destroys LLVMInstance object + */ + ~LLVMInstance(); // Must not be "= default" here in the header file. + + /*! + * \brief Get the LLVM context for this scope. 
+ */ + std::shared_ptr GetContext() const { return ctx_; } + + /*! + * \brief Create `llvm::Module` from a string. + * + * Parse the string in \param llvm_ir, and return the `llvm::Module`. + * At the moment this function will abort if the parsing fails. + * \param llvm_ir string with the LLVM IR assembly or bitcode + * \return created `llvm::Module` + */ + std::unique_ptr ParseIR(const std::string& llvm_ir) const; + /*! + * \brief Load `llvm::Module` from a given file + * + * Read the file \param file_name, and return the `llvm::Module`. + * At the moment this function will abort if reading of the file or creation + * of the module fails. + * \param file_name file with the LLVM IR assembly or bitcode + * \return created `llvm::Module` + */ + std::unique_ptr LoadIR(const std::string& file_name) const; + + private: + std::unique_ptr ParseBuffer(const llvm::MemoryBuffer& buffer) const; + + std::shared_ptr ctx_; +}; + +/*! + * \class LLVMTarget + * \brief Information used by LLVM for code generation for particular target + * + * This class contains all information that LLVM needs for code generation for + * a particular target. Since Target in TVM will soon contain command line + * flags for LLVM, objects of this class will handle saving and restoring + * global LLVM state that may be affected by these flags. This way, code + * generation for each LLVM-based target in TVM will start with the same LLVM + * global state. + * + * Note that objects of this class must be created within the lifetime of an + * LLVMInstance object. + */ +class LLVMTarget { + public: + /*! + * \brief Constructs LLVMTarget from `Target` + * \param scope LLVMInstance object + * \param target TVM Target object for target "llvm" + */ + LLVMTarget(LLVMInstance& scope, const Target& target); // NOLINT(runtime/references) + /*! 
+ * \brief Constructs LLVMTarget from target string + * \param scope LLVMInstance object + * \param target_str TVM target string for target "llvm" + */ + LLVMTarget(LLVMInstance& scope, const std::string& target_str); // NOLINT(runtime/references) + /*! + * \brief Destroys LLVMTarget object + */ + ~LLVMTarget(); + + /*! + * \brief Returns string representation (as TVM target) of the LLVMTarget + * \return Target string + * + * Note: If the LLVMTarget object was created from a string `s`, the string + * returned here may not be exactly equal to `s`. For example, if the CPU + * was "default", the returned string will have CPU set to the detected host + * CPU. + */ + std::string str() const; + + /*! + * \brief Get the LLVMInstance object from which the LLVMTarget object was + * created + * \return The enclosing LLVMInstance object + */ + const LLVMInstance& GetInstance() const { return instance_; } + /*! + * \brief Get the current LLVM context + * \return the current LLVM context + */ + llvm::LLVMContext* GetContext() const; + /*! + * \brief Return LLVM's `TargetMachine`, or nullptr + * \param allow_missing do not abort if the target machine cannot be created, + * return nullptr instead + * \return Pointer to the `TargetMachine` object (or nullptr if it cannot be + * created, \see allow_missing) + */ + llvm::TargetMachine* GetOrCreateTargetMachine(bool allow_missing = false); + + /*! + * \brief Get the target triple + * \return the target triple + */ + const std::string& GetTargetTriple() const { return triple_; } + /*! + * \brief Get the CPU name + * \return the CPU name: the detected host CPU if the original TVM target + * specified it as "default" + */ + const std::string& GetCPU() const { return cpu_; } + /*! + * \brief Get the list of LLVM target features + * \return array of individual feature strings + */ + llvm::ArrayRef GetTargetFeatures() const { return attrs_; } + /*! 
+ * \brief Get the LLVM target feature string + * \return comma-separated list of LLVM target features + */ + std::string GetTargetFeatureString() const; + /*! + * \brief Get the LLVM target options + * \return `llvm::TargetOptions` object for this target + */ + const llvm::TargetOptions& GetTargetOptions() const { return target_options_; } + /*! + * \brief Get fast math flags + * \return `llvm::FastMathFlags` for this target + */ + llvm::FastMathFlags GetFastMathFlags() const { return fast_math_flags_; } + /*! + * \brief Get the LLVM optimization level + * \return optimization level for this target + */ + llvm::CodeGenOpt::Level GetOptLevel() const { return opt_level_; } + + /*! + * \brief Extract the target string from given `llvm::Module` + * \param module LLVM module with the TVM target string embedded as metadata + * \return the target string from module's metadata + */ + static std::string GetTargetMetadata(const llvm::Module& module); + /*! + * \brief Embed target string as metadata in given `llvm::Module` + * \param module the module to insert the target string into + */ + void SetTargetMetadata(llvm::Module* module) const; + + // Stubs to enable use with `With`. 
+ void EnterWithScope() {} + void ExitWithScope() {} + + private: + const LLVMInstance& instance_; + std::weak_ptr ctx_; + + std::string triple_; + std::string cpu_; + std::vector attrs_; + llvm::TargetOptions target_options_; + llvm::FastMathFlags fast_math_flags_; + llvm::CodeGenOpt::Level opt_level_; + llvm::Reloc::Model reloc_model_ = llvm::Reloc::PIC_; + llvm::CodeModel::Model code_model_ = llvm::CodeModel::Small; + std::shared_ptr target_machine_; +}; + +} // namespace codegen +} // namespace tvm + +#endif // TVM_LLVM_VERSION +#endif // TVM_TARGET_LLVM_LLVM_INSTANCE_H_ diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 69c7632d65ea..9aed66fffc5c 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -51,11 +51,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -74,7 +76,7 @@ #include "codegen_blob.h" #include "codegen_cpu.h" #include "codegen_llvm.h" -#include "llvm_common.h" +#include "llvm_instance.h" namespace tvm { namespace codegen { @@ -85,398 +87,338 @@ using runtime::TVMRetValue; class LLVMModuleNode final : public runtime::ModuleNode { public: - ~LLVMModuleNode() { - module_owning_ptr_.reset(); - if (ee_ != nullptr) { - ee_->runStaticConstructorsDestructors(true); - delete ee_; - } - } + ~LLVMModuleNode(); const char* type_key() const final { return "llvm"; } - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "__tvm_is_system_module") { - bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); - return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); - } else if (name == "get_func_names") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; }); - } else if (name == "get_symbol") { - return PackedFunc(nullptr); - } else if (name == "get_const_vars") { - return PackedFunc(nullptr); - } 
else if (name == "_get_target_string") { - std::string target_string = LLVMTargetToString(target_); - return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; }); - } - if (ee_ == nullptr) LazyInitJIT(); - - std::lock_guard lock(mutex_); - - TVMBackendPackedCFunc faddr; - if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_main)); - ICHECK(entry_name != nullptr) - << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; - faddr = reinterpret_cast(GetFunctionAddr(entry_name)); - } else { - faddr = reinterpret_cast(GetFunctionAddr(name)); - } - if (faddr == nullptr) return PackedFunc(); - return WrapPackedFunc(faddr, sptr_to_self); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + + void SaveToFile(const std::string& file_name, const std::string& format) final; + void SaveToBinary(dmlc::Stream* stream) final; + std::string GetSource(const std::string& format) final; + + void Init(const IRModule& mod, const Target& target); + void Init(std::unique_ptr module, std::unique_ptr llvm_instance); + void LoadIR(const std::string& file_name); + bool IsDSOExportable() const final { return true; } + + bool ImplementsFunction(const String& name, bool query_imports) final; + + private: + void LazyInitJIT(); + bool IsCompatibleWithHost(const llvm::TargetMachine* tm) const; + void* GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const; + void* GetFunctionAddr(const std::string& name, const LLVMTarget& llvm_target) const; + + // The LLVM scope object. + std::unique_ptr llvm_instance_; + // JIT lock + std::mutex mutex_; + // execution engine + llvm::ExecutionEngine* ee_{nullptr}; + // The raw pointer to the module. + llvm::Module* module_{nullptr}; + // The unique_ptr owning the module. This becomes empty once JIT has been initialized + // (EngineBuilder takes ownership of the module). 
+ std::unique_ptr module_owning_ptr_; + /* \brief names of the functions declared in this module */ + Array function_names_; +}; + +LLVMModuleNode::~LLVMModuleNode() { + if (ee_ != nullptr) { + ee_->runStaticConstructorsDestructors(true); + delete ee_; } + module_owning_ptr_.reset(); +} - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = runtime::GetFileFormat(file_name, format); - std::error_code ecode; +PackedFunc LLVMModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + if (name == "__tvm_is_system_module") { + bool flag = (module_->getFunction("__tvm_module_startup") != nullptr); + return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); + } else if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; }); + } else if (name == "get_symbol") { + return PackedFunc(nullptr); + } else if (name == "get_const_vars") { + return PackedFunc(nullptr); + } else if (name == "_get_target_string") { + std::string target_string = LLVMTarget::GetTargetMetadata(*module_); + return PackedFunc([target_string](TVMArgs args, TVMRetValue* rv) { *rv = target_string; }); + } + if (ee_ == nullptr) LazyInitJIT(); + + std::lock_guard lock(mutex_); + + TVMBackendPackedCFunc faddr; + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); + if (name == runtime::symbol::tvm_module_main) { + const char* entry_name = reinterpret_cast( + GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target)); + ICHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main + << " is not presented"; + faddr = reinterpret_cast(GetFunctionAddr(entry_name, *llvm_target)); + } else { + faddr = reinterpret_cast(GetFunctionAddr(name, *llvm_target)); + } + if (faddr == nullptr) return PackedFunc(); + return WrapPackedFunc(faddr, sptr_to_self); +} + +void LLVMModuleNode::SaveToFile(const 
std::string& file_name, const std::string& format) { + std::string fmt = runtime::GetFileFormat(file_name, format); + std::error_code ecode; #if TVM_LLVM_VERSION <= 70 - llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); + llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None); #else - llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None); + llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::OF_None); #endif - ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); - if (fmt == "o" || fmt == "obj") { + ICHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); + if (fmt == "o" || fmt == "obj") { + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == - 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, 
llvm::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #endif - pass.run(*m); - } else if (fmt == "s" || fmt == "asm") { + pass.run(*m); + } else if (fmt == "s" || fmt == "asm") { + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, - llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) + << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #endif - pass.run(*m); - } else if (fmt == "ll") { - module_->print(dest, nullptr); - } else if (fmt == "bc") { + pass.run(*m); + } else if (fmt == "ll") { + module_->print(dest, nullptr); + } else if (fmt == "bc") { #if TVM_LLVM_VERSION <= 60 - llvm::WriteBitcodeToFile(module_, dest); + llvm::WriteBitcodeToFile(module_, dest); #else - llvm::WriteBitcodeToFile(*module_, dest); + llvm::WriteBitcodeToFile(*module_, dest); #endif - } else 
{ - LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format - << "\'"; - } - dest.close(); + } else { + LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format + << "\'"; } + dest.close(); +} - void SaveToBinary(dmlc::Stream* stream) final { - LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; - } +void LLVMModuleNode::SaveToBinary(dmlc::Stream* stream) { + LOG(FATAL) << "LLVMModule: SaveToBinary not supported"; +} - std::string GetSource(const std::string& format) final { - std::string fmt = runtime::GetFileFormat("", format); - std::string type_str; - llvm::SmallString<256> str; - llvm::raw_svector_ostream rso(str); +std::string LLVMModuleNode::GetSource(const std::string& format) { + std::string fmt = runtime::GetFileFormat("", format); + std::string type_str; + llvm::SmallString<256> str; + llvm::raw_svector_ostream rso(str); - if (fmt == "s" || fmt == "asm") { + if (fmt == "s" || fmt == "asm") { + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(module_); + std::unique_ptr m = llvm::CloneModule(module_); #else - std::unique_ptr m = llvm::CloneModule(*module_); + std::unique_ptr m = llvm::CloneModule(*module_); #endif - llvm::legacy::PassManager pass; - ICHECK(tm_); + llvm::legacy::PassManager pass; + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); #if TVM_LLVM_VERSION <= 60 - ICHECK(tm_->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - ICHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == - 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, 
llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #else - ICHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; + ICHECK(tm->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; #endif - pass.run(*m); - return rso.str().str(); - } else if (fmt == "" || fmt == "ll") { - std::string type_str; - llvm::raw_string_ostream rso(type_str); - ICHECK(module_ != nullptr); - module_->print(rso, nullptr); - return rso.str(); - } else { - LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; - } - return ""; + pass.run(*m); + return rso.str().str(); + } else if (fmt == "" || fmt == "ll") { + std::string type_str; + llvm::raw_string_ostream rso(type_str); + ICHECK(module_ != nullptr); + module_->print(rso, nullptr); + return rso.str(); + } else { + LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; } + return ""; +} - void Init(const IRModule& mod, const Target& target) { - InitializeLLVM(); - tm_ = GetLLVMTargetMachine(target); - ctx_ = std::make_shared(); - std::unique_ptr cg = CodeGenLLVM::Create(tm_.get()); - - std::vector funcs; - std::string entry_func; - relay::Runtime runtime = - mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); - bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); - bool target_c_runtime = runtime->name == "crt"; - - for (auto kv : mod->functions) { - if (!kv.second->IsInstance()) { - // (@jroesch): we relax constraints here, Relay functions will just be ignored. 
- DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " - << kv.second->GetTypeKey(); - continue; - } - auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); - function_names_.push_back(global_symbol.value()); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - entry_func = global_symbol.value(); - } - funcs.push_back(f); - } - // TODO(@jroesch): follow up on this condition. - // ICHECK(funcs.size() > 0); - // TODO(tqchen): remove the entry function behavior as it does not - // makes sense when we start to use multiple modules. - cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime); - - // See https://llvm.org/docs/LangRef.html#fast-math-flags for details - Bool fast_math_all = target->GetAttr("fast-math").value_or(Bool(false)); - Bool fast_math_nnan = target->GetAttr("fast-math-nnan").value_or(Bool(false)); - Bool fast_math_ninf = target->GetAttr("fast-math-ninf").value_or(Bool(false)); - Bool fast_math_nsz = target->GetAttr("fast-math-nsz").value_or(Bool(false)); - Bool fast_math_arcp = target->GetAttr("fast-math-arcp").value_or(Bool(false)); - - llvm::FastMathFlags fmf; - if (fast_math_all) { -#if TVM_LLVM_VERSION >= 60 - fmf.setFast(); -#else - fmf.setUnsafeAlgebra(); -#endif - } +void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { + llvm_instance_ = std::make_unique(); + With llvm_target(*llvm_instance_, target); + llvm::TargetMachine* tm = llvm_target->GetOrCreateTargetMachine(); + std::unique_ptr cg = CodeGenLLVM::Create(llvm_target.get()); - if (fast_math_nnan) { - fmf.setNoNaNs(); - } - if (fast_math_ninf) { - fmf.setNoInfs(); - } - if (fast_math_nsz) { - fmf.setNoSignedZeros(); - } - if (fast_math_arcp) { - fmf.setAllowReciprocal(); - } + std::vector funcs; + std::string entry_func; + relay::Runtime runtime = + mod->GetAttr(tvm::attr::kRuntime).value_or(relay::Runtime::Create("cpp")); + bool system_lib = 
runtime->GetAttr("system-lib").value_or(Bool(false)); + bool target_c_runtime = runtime->name == "crt"; -#if TVM_LLVM_VERSION >= 60 - Bool fast_math_contract = target->GetAttr("fast-math-contract").value_or(Bool(false)); - Bool fast_math_afn = target->GetAttr("fast-math-afn").value_or(Bool(false)); - Bool fast_math_reassoc = target->GetAttr("fast-math-reassoc").value_or(Bool(false)); - if (fast_math_contract) { - fmf.setAllowContract(true); - } - if (fast_math_afn) { - fmf.setApproxFunc(); + for (auto kv : mod->functions) { + if (!kv.second->IsInstance()) { + // (@jroesch): we relax constraints here, Relay functions will just be ignored. + DLOG(INFO) << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); + continue; } - if (fast_math_reassoc) { - fmf.setAllowReassoc(); + auto f = Downcast(kv.second); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + function_names_.push_back(global_symbol.value()); + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + entry_func = global_symbol.value(); } -#endif + funcs.push_back(f); + } + // TODO(@jroesch): follow up on this condition. + // ICHECK(funcs.size() > 0); + // TODO(tqchen): remove the entry function behavior as it does not + // makes sense when we start to use multiple modules. 
+ cg->Init("TVMMod", llvm_target.get(), system_lib, system_lib, target_c_runtime); + cg->SetFastMathFlags(llvm_target->GetFastMathFlags()); + + cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); + if (entry_func.length() != 0) { + cg->AddMainFunction(entry_func); + } - cg->SetFastMathFlag(fmf); + module_owning_ptr_ = cg->Finish(); + module_ = module_owning_ptr_.get(); + llvm_target->SetTargetMetadata(module_); + module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", + llvm::DEBUG_METADATA_VERSION); - cg->AddFunctionsOrdered(funcs.begin(), funcs.end()); - if (entry_func.length() != 0) { - cg->AddMainFunction(entry_func); - } + if (tm->getTargetTriple().isOSDarwin()) { + module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); + } - module_owning_ptr_ = cg->Finish(); - module_ = module_owning_ptr_.get(); + std::string verify_errors_storage; + llvm::raw_string_ostream verify_errors(verify_errors_storage); + LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) + << "LLVM module verification failed with the following errors: \n" + << verify_errors.str(); +} - module_->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx_, LLVMTargetToString(target))); - module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", - llvm::DEBUG_METADATA_VERSION); +void LLVMModuleNode::Init(std::unique_ptr module, + std::unique_ptr llvm_instance) { + module_owning_ptr_ = std::move(module); + module_ = module_owning_ptr_.get(); + llvm_instance_ = std::move(llvm_instance); +} - if (tm_->getTargetTriple().isOSDarwin()) { - module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); - } +void LLVMModuleNode::LoadIR(const std::string& file_name) { + auto llvm_instance = std::make_unique(); + std::unique_ptr module = llvm_instance->LoadIR(file_name); + Init(std::move(module), std::move(llvm_instance)); +} - std::string verify_errors_storage; - llvm::raw_string_ostream verify_errors(verify_errors_storage); - LOG_IF(FATAL, 
llvm::verifyModule(*module_, &verify_errors)) - << "LLVM module verification failed with the following errors: \n" - << verify_errors.str(); - target_ = target; - } +bool LLVMModuleNode::ImplementsFunction(const String& name, bool query_imports) { + return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); +} - void Init(std::unique_ptr module, std::shared_ptr ctx) { - InitializeLLVM(); - ctx_ = ctx; - llvm::SMDiagnostic err; - module_owning_ptr_ = std::move(module); - module_ = module_owning_ptr_.get(); - if (module_ == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load module: " << msg; - } - std::string target_metadata; - llvm::Metadata* tvm_target = module_->getModuleFlag("tvm_target"); - if (tvm_target != nullptr) { - llvm::MDString* pstr = llvm::dyn_cast(tvm_target); - ICHECK(pstr != nullptr); - target_metadata = pstr->getString().str(); - if (!(target_metadata.length() >= 4 && target_metadata.substr(0, 4) == "llvm")) { - target_metadata = "llvm " + target_metadata; - } - } else { - std::ostringstream os; - os << "llvm -mtriple " << module_->getTargetTriple(); - target_metadata = os.str(); - } - target_ = Target(target_metadata); - tm_ = GetLLVMTargetMachine(target_); +void LLVMModuleNode::LazyInitJIT() { + std::lock_guard lock(mutex_); + if (ee_) { + return; } - - void LoadIR(const std::string& file_name) { - auto ctx = std::make_shared(); - llvm::SMDiagnostic err; - auto module = llvm::parseIRFile(file_name, err, *ctx); - if (module == nullptr) { - std::string msg = std::string(err.getMessage()); - LOG(FATAL) << "Fail to load ir file " << file_name << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - Init(std::move(module), ctx); + With llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); + llvm::EngineBuilder builder(std::move(module_owning_ptr_)); + builder.setEngineKind(llvm::EngineKind::JIT); + builder.setOptLevel(llvm::CodeGenOpt::Aggressive); + 
builder.setMCPU(llvm_target->GetCPU()); + builder.setMAttrs(llvm_target->GetTargetFeatures()); + builder.setTargetOptions(llvm_target->GetTargetOptions()); + auto tm = std::unique_ptr(builder.selectTarget()); + if (!IsCompatibleWithHost(tm.get())) { + LOG(FATAL) << "Cannot run module, architecture mismatch"; } - - bool IsDSOExportable() const final { return true; } - - bool ImplementsFunction(const String& name, bool query_imports) final { - return std::find(function_names_.begin(), function_names_.end(), name) != function_names_.end(); + llvm::DataLayout layout(tm->createDataLayout()); + ICHECK(layout == module_->getDataLayout()) + << "Data layout mismatch between module(" + << module_->getDataLayout().getStringRepresentation() << ")" + << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; + ee_ = builder.create(tm.release()); + ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple(); + ee_->runStaticConstructorsDestructors(false); + + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx, *llvm_target))) { + *ctx_addr = this; } + runtime::InitContextFunctions( + [this, &llvm_target](const char* name) { return GetGlobalAddr(name, *llvm_target); }); + // There is a problem when a JITed function contains a call to a runtime function. + // The runtime function (e.g. __truncsfhf2) may not be resolved, and calling it will + // lead to a runtime crash. + // Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize + // all loaded objects, which will resolve symbols in JITed code. 
+ ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91"); +} - private: - void LazyInitJIT() { - std::lock_guard lock(mutex_); - if (ee_) { - return; - } - if (!target_.defined()) { - target_ = Target("llvm"); - } - llvm::EngineBuilder builder(std::move(module_owning_ptr_)); - std::string triple, mcpu, mattr; - llvm::TargetOptions opt; - ParseLLVMTargetOptions(target_, &triple, &mcpu, &mattr, &opt); - builder.setEngineKind(llvm::EngineKind::JIT); - builder.setOptLevel(llvm::CodeGenOpt::Aggressive); - if (mcpu.length() != 0) { - builder.setMCPU(mcpu); - } - if (mattr.length() != 0) { - std::vector mattrs{mattr}; - builder.setMAttrs(mattrs); - } - builder.setTargetOptions(opt); - auto tm = std::unique_ptr(builder.selectTarget()); - std::unique_ptr tm_sys = GetLLVMTargetMachine(Target("llvm")); - if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { - LOG(FATAL) << "Cannot run module, architecture mismatch " - << " module=" << tm->getTargetTriple().str() - << " system=" << tm_sys->getTargetTriple().str(); - } - llvm::DataLayout layout(tm->createDataLayout()); - ICHECK(layout == module_->getDataLayout()) - << "Data layout mismatch between module(" - << module_->getDataLayout().getStringRepresentation() << ")" - << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; - ee_ = builder.create(tm.release()); - ICHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << module_->getTargetTriple(); - ee_->runStaticConstructorsDestructors(false); - - if (void** ctx_addr = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { - *ctx_addr = this; - } - runtime::InitContextFunctions( - [this](const char* name) { return reinterpret_cast(GetGlobalAddr(name)); }); - // There is a problem when a JITed function contains a call to a runtime function. - // The runtime function (e.g. __truncsfhf2) may not be resolved, and calling it will - // lead to a runtime crash. 
- // Do name lookup on a symbol that doesn't exist. This will force MCJIT to finalize - // all loaded objects, which will resolve symbols in JITed code. - ee_->getFunctionAddress("__some_name_that_hopefully_doesnt_exist__b49f8aaade5877eaba7583b91"); +bool LLVMModuleNode::IsCompatibleWithHost(const llvm::TargetMachine* tm) const { + With host_target(*llvm_instance_, "llvm"); // FIXME(kparzysz-quic): nesting + auto tm_host = host_target->GetOrCreateTargetMachine(); + if (tm_host->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { + LOG(INFO) << "Architecture mismatch: module=" << tm->getTargetTriple().str() + << " host=" << tm_host->getTargetTriple().str(); + return false; } + return true; +} - // Get global address from execution engine. - uint64_t GetGlobalAddr(const std::string& name) const { - // first verifies if GV exists. - if (module_->getGlobalVariable(name) != nullptr) { - return ee_->getGlobalValueAddress(name); - } else { - return 0; - } +// Get global address from execution engine. +void* LLVMModuleNode::GetGlobalAddr(const std::string& name, const LLVMTarget& llvm_target) const { + // first verifies if GV exists. + if (module_->getGlobalVariable(name) != nullptr) { + return reinterpret_cast(ee_->getGlobalValueAddress(name)); + } else { + return nullptr; } +} - uint64_t GetFunctionAddr(const std::string& name) const { - // first verifies if GV exists. - if (module_->getFunction(name) != nullptr) { - return ee_->getFunctionAddress(name); - } else { - return 0; - } +void* LLVMModuleNode::GetFunctionAddr(const std::string& name, + const LLVMTarget& llvm_target) const { + // first verifies if GV exists. 
+ if (module_->getFunction(name) != nullptr) { + return reinterpret_cast(ee_->getFunctionAddress(name)); + } else { + return nullptr; } - - // The target configuration string - Target target_; - // JIT lock - std::mutex mutex_; - // execution engine - llvm::ExecutionEngine* ee_{nullptr}; - // The target machine - std::unique_ptr tm_{nullptr}; - // The raw pointer to the module. - llvm::Module* module_{nullptr}; - // The unique_ptr owning the module. This becomes empty once JIT has been initialized - // (EngineBuilder takes ownership of the module). - std::unique_ptr module_owning_ptr_; - // the context. - std::shared_ptr ctx_; - /* \brief names of the functions declared in this module */ - Array function_names_; -}; +} TVM_REGISTER_GLOBAL("target.build.llvm") .set_body_typed([](IRModule mod, Target target) -> runtime::Module { @@ -487,18 +429,15 @@ TVM_REGISTER_GLOBAL("target.build.llvm") TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { - Target target = Target(target_str); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target_str); auto n = make_object(); // Generate a LLVM module from an input target string - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); - auto ctx = std::make_shared(); - std::unique_ptr module(new llvm::Module(module_name, *ctx)); - // Use a default data layout and target triple - auto triple = tm->getTargetTriple(); - module->setTargetTriple(triple.str()); - module->setDataLayout(tm->createDataLayout()); - n->Init(std::move(module), ctx); + auto module = std::make_unique(module_name, *llvm_target->GetContext()); + llvm_target->SetTargetMetadata(module.get()); + module->setTargetTriple(llvm_target->GetTargetTriple()); + module->setDataLayout(llvm_target->GetOrCreateTargetMachine()->createDataLayout()); + n->Init(std::move(module), std::move(llvm_instance)); return runtime::Module(n); }); @@ -535,38 +474,39 @@ 
TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") .set_body_typed([](std::string target_str) -> bool { - InitializeLLVM(); - Target target = Target(target_str); - return (GetLLVMTargetMachine(target, true) != nullptr); + LLVMInstance llvm_instance; + auto* tm = With(llvm_instance, target_str) + ->GetOrCreateTargetMachine(/*allow_missing=*/true); + return tm != nullptr; }); TVM_REGISTER_GLOBAL("codegen.codegen_blob") .set_body_typed([](std::string data, bool system_lib, std::string llvm_target_string) -> runtime::Module { auto n = make_object(); - auto p = CodeGenBlob(data, system_lib, llvm_target_string); - n->Init(std::move(p.first), p.second); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, llvm_target_string); + std::unique_ptr blob = CodeGenBlob(data, system_lib, llvm_target.get()); + n->Init(std::move(blob), std::move(llvm_instance)); return runtime::Module(n); }); runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata, Target target, tvm::relay::Runtime runtime) { - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); - auto ctx = std::make_shared(); std::unique_ptr cg{new CodeGenCPU()}; - cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, - false /* target_c_runtime */); + cg->Init("TVMMetadataMod", llvm_target.get(), system_lib, system_lib, + /*target_c_runtime=*/false); cg->DefineMetadata(metadata); auto mod = cg->Finish(); - mod->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(mod.get()); mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); - if (tm->getTargetTriple().isOSDarwin()) { + if 
(llvm_target->GetOrCreateTargetMachine()->getTargetTriple().isOSDarwin()) { mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } @@ -577,7 +517,7 @@ runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata metadata << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), ctx); + n->Init(std::move(mod), std::move(llvm_instance)); auto meta_mod = MetadataModuleCreate(metadata); meta_mod->Import(runtime::Module(n)); @@ -597,24 +537,22 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } } - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); + auto llvm_instance = std::make_unique(); + With llvm_target(*llvm_instance, target); bool system_lib = runtime->GetAttr("system-lib").value_or(Bool(false)); bool target_c_runtime = runtime->name == "crt"; ICHECK(system_lib && target_c_runtime) << "For LLVM C-runtime metadata module, must include --system-lib and --runtime=c; " << "got target: " << target->str(); - auto ctx = std::make_shared(); std::unique_ptr cg{new CodeGenCPU()}; - cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib, target_c_runtime); + cg->Init("TVMMetadataMod", llvm_target.operator->(), system_lib, system_lib, target_c_runtime); cg->DefineFunctionRegistry(func_names); auto mod = cg->Finish(); - mod->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx, LLVMTargetToString(target))); + llvm_target->SetTargetMetadata(mod.get()); mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); - if (tm->getTargetTriple().isOSDarwin()) { + if (llvm_target->GetOrCreateTargetMachine()->getTargetTriple().isOSDarwin()) { mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } @@ -625,7 +563,7 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module << verify_errors.str(); auto n = make_object(); - n->Init(std::move(mod), ctx); + n->Init(std::move(mod), std::move(llvm_instance)); for (auto m : modules) { 
n->Import(m); } diff --git a/src/target/llvm/llvm_module.h b/src/target/llvm/llvm_module.h index 3a50c2c4244f..66492f8152e5 100644 --- a/src/target/llvm/llvm_module.h +++ b/src/target/llvm/llvm_module.h @@ -46,5 +46,4 @@ runtime::Module CreateLLVMCrtMetadataModule(const Array& module } // namespace tvm #endif // TVM_LLVM_VERSION - #endif // TVM_TARGET_LLVM_LLVM_MODULE_H_