From b6d7ce6a8a3803f917763e9570d10fea0ca1fffa Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 May 2023 11:41:56 -0500 Subject: [PATCH] [LLVM] Use Var annotation in LetStmt for pointer type (#14570) * [LLVM] Use Var annotation in LetStmt for pointer type TIR has type-annotations on variables, but not on each PrimExpr. Because some versions of LLVM (before LLVM 15) default to using typed pointers, passing LLVM validation may require correct pointer types. This is typically encountered in unpacking of PackedFunc arguments, where the LLVM type of a data array must be cast from `i8*` to a pointer to the buffer type. These pointer-to-pointer casts are not expressible in TIR, because the `tir::CastNode` represents all pointers as `DataType::Handle()`, and so must be handled during codegen instead of earlier in the lowering flow. * Resolve arg-type incompatibility with re-used context functions * Handle LLVM codegen for custom data types * Applied pointer-type fix to external callee as well * Use correct callee from GetContextPtr --- src/target/llvm/codegen_cpu.cc | 42 ++++++++++++++++----------------- src/target/llvm/codegen_llvm.cc | 23 +++++++++++++++--- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index dbcdb4a3af87..7677e61ea614 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -465,31 +465,31 @@ llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, } llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false); // Check if it is available in global function table as injected function. - auto it = gv_func_map_.find(global_symbol); - if (it != gv_func_map_.end()) { - if (it->second == nullptr) { - gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); - it = gv_func_map_.find(global_symbol); - } -#if TVM_LLVM_VERSION >= 90 - auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second)); -#else - auto ext_callee = GetContextPtr(it->second); -#endif - return builder_->CreateCall(ext_callee, arg_values); - } else { - llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol)); - if (f == nullptr) { - f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - MakeStringRef(global_symbol), module_.get()); + + auto callee = [&]() -> llvm::Value* { + if (auto it = gv_func_map_.find(global_symbol); it != gv_func_map_.end()) { + if (it->second == nullptr) { + it->second = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); + } + return GetContextPtr(it->second); + } else if (llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol))) { + return f; + } else { + return llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + MakeStringRef(global_symbol), module_.get()); } + }(); + + if (callee->getType() != ftype->getPointerTo()) { + callee = builder_->CreatePointerCast(callee, ftype->getPointerTo()); + } + #if TVM_LLVM_VERSION >= 90 - auto ext_callee = llvm::FunctionCallee(f); + auto ext_callee = llvm::FunctionCallee(ftype, callee); #else - auto ext_callee = f; + auto ext_callee = f; #endif - return builder_->CreateCall(ext_callee, arg_values); - } + return builder_->CreateCall(ext_callee, arg_values); } llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 01e25d536118..eb53e9b6dc87 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -548,10 +548,11 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { return DTypeToLLVMType(ptr->dtype); } else if (auto* ptr = type.as()) { - // LLVM IR doesn't allow void*, so we need to recognize this - // pattern explicitly. + // LLVM IR doesn't allow void*, nor do we require custom datatypes + // to have LLVM equivalents, so we need to recognize these + // patterns explicitly. if (auto* primtype = ptr->element_type.as()) { - if (primtype->dtype.is_void()) { + if (primtype->dtype.is_void() || primtype->dtype.code() >= DataType::kCustomBegin) { return t_void_p_; } } @@ -1975,6 +1976,22 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { } } llvm::Value* value = MakeValue(op->value); + + // TIR has type-annotations on variables, but not on each PrimExpr. + // Therefore, to have the correct LLVM type for pointers, we may + // need to introduce a pointer-cast, even though pointer-to-pointer + // casts are not expressible with the `tir::CastNode`. + if (v->dtype.is_handle() && v->type_annotation.defined()) { + CHECK(op->value->dtype.is_handle()) + << "Variable " << op->var << " is a pointer with type " << op->value + << ", but is being bound to expression with type " << op->value->dtype; + auto* llvm_type = GetLLVMType(v->type_annotation); + if (llvm_type != value->getType()) { + value->setName((v->name_hint + "_void_ptr").c_str()); + value = builder_->CreatePointerCast(value, llvm_type); + } + } + value->setName(v->name_hint.c_str()); var_map_[v] = value; analyzer_->Bind(op->var, op->value);