Skip to content

Commit

Permalink
[LLVM] Use Var annotation in LetStmt for pointer type (#14570)
Browse files Browse the repository at this point in the history
* [LLVM] Use Var annotation in LetStmt for pointer type

TIR has type-annotations on variables, but not on each PrimExpr.
Because some versions of LLVM (before LLVM 15) default to using typed
pointers, passing LLVM validation may require correct pointer types.
This is typically encountered in unpacking of PackedFunc arguments,
where the LLVM type of a data array must be cast from `i8*` to a
pointer to the buffer type.  These pointer-to-pointer casts are not
expressible in TIR, because the `tir::CastNode` represents all
pointers as `DataType::Handle()`, and so must be handled during
codegen instead of earlier in the lowering flow.

* Resolve arg-type incompatibility with re-used context functions

* Handle LLVM codegen for custom data types

* Applied pointer-type fix to external callee as well

* Use correct callee from GetContextPtr
  • Loading branch information
Lunderberg authored May 15, 2023
1 parent 5566c3e commit b6d7ce6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
42 changes: 21 additions & 21 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,31 +465,31 @@ llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol,
}
llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false);
// Check if it is available in global function table as injected function.
auto it = gv_func_map_.find(global_symbol);
if (it != gv_func_map_.end()) {
if (it->second == nullptr) {
gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol);
it = gv_func_map_.find(global_symbol);
}
#if TVM_LLVM_VERSION >= 90
auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second));
#else
auto ext_callee = GetContextPtr(it->second);
#endif
return builder_->CreateCall(ext_callee, arg_values);
} else {
llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol));
if (f == nullptr) {
f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
MakeStringRef(global_symbol), module_.get());

auto callee = [&]() -> llvm::Value* {
if (auto it = gv_func_map_.find(global_symbol); it != gv_func_map_.end()) {
if (it->second == nullptr) {
it->second = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol);
}
return GetContextPtr(it->second);
} else if (llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol))) {
return f;
} else {
return llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
MakeStringRef(global_symbol), module_.get());
}
}();

if (callee->getType() != ftype->getPointerTo()) {
callee = builder_->CreatePointerCast(callee, ftype->getPointerTo());
}

#if TVM_LLVM_VERSION >= 90
auto ext_callee = llvm::FunctionCallee(f);
auto ext_callee = llvm::FunctionCallee(ftype, callee);
#else
auto ext_callee = f;
auto ext_callee = f;
#endif
return builder_->CreateCall(ext_callee, arg_values);
}
return builder_->CreateCall(ext_callee, arg_values);
}

llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) {
Expand Down
23 changes: 20 additions & 3 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,11 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
if (auto* ptr = type.as<PrimTypeNode>()) {
return DTypeToLLVMType(ptr->dtype);
} else if (auto* ptr = type.as<PointerTypeNode>()) {
// LLVM IR doesn't allow void*, so we need to recognize this
// pattern explicitly.
// LLVM IR doesn't allow void*, nor do we require custom datatypes
// to have LLVM equivalents, so we need to recognize these
// patterns explicitly.
if (auto* primtype = ptr->element_type.as<PrimTypeNode>()) {
if (primtype->dtype.is_void()) {
if (primtype->dtype.is_void() || primtype->dtype.code() >= DataType::kCustomBegin) {
return t_void_p_;
}
}
Expand Down Expand Up @@ -1975,6 +1976,22 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
}
}
llvm::Value* value = MakeValue(op->value);

// TIR has type-annotations on variables, but not on each PrimExpr.
// Therefore, to have the correct LLVM type for pointers, we may
// need to introduce a pointer-cast, even though pointer-to-pointer
// casts are not expressible with the `tir::CastNode`.
if (v->dtype.is_handle() && v->type_annotation.defined()) {
CHECK(op->value->dtype.is_handle())
<< "Variable " << op->var << " is a pointer with type " << op->value
<< ", but is being bound to expression with type " << op->value->dtype;
auto* llvm_type = GetLLVMType(v->type_annotation);
if (llvm_type != value->getType()) {
value->setName((v->name_hint + "_void_ptr").c_str());
value = builder_->CreatePointerCast(value, llvm_type);
}
}

value->setName(v->name_hint.c_str());
var_map_[v] = value;
analyzer_->Bind(op->var, op->value);
Expand Down

0 comments on commit b6d7ce6

Please sign in to comment.