Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM] Fix generation of LLVM intrinsics #5282

Merged
merged 3 commits into from
Apr 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 84 additions & 8 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,26 +684,102 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) {
return call;
}

llvm::Function* CodeGenLLVM::GetIntrinsicDecl(
llvm::Intrinsic::ID id, llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types) {
llvm::Module* module = module_.get();

if (!llvm::Intrinsic::isOverloaded(id)) {
return llvm::Intrinsic::getDeclaration(module, id, {});
}

llvm::SmallVector<llvm::Intrinsic::IITDescriptor, 4> infos;
llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos);
llvm::SmallVector<llvm::Type*, 4> overload_types;

#if TVM_LLVM_VERSION >= 90
auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) {
overload_types.clear();
llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
auto match =
llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types);
if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref);
if (error) {
return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg;
}
}
return match;
};

// First, try matching the signature assuming non-vararg case.
auto* fn_ty = llvm::FunctionType::get(ret_type, arg_types, false);
switch (try_match(fn_ty, false)) {
case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet:
// The return type doesn't match, there is nothing else to do.
return nullptr;
case llvm::Intrinsic::MatchIntrinsicTypes_Match:
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg:
break;
}

// Keep adding one type at a time (starting from empty list), and
// try matching the vararg signature.
llvm::SmallVector<llvm::Type*, 4> var_types;
for (int i = 0, e = arg_types.size(); i <= e; ++i) {
if (i > 0) var_types.push_back(arg_types[i - 1]);
auto* ft = llvm::FunctionType::get(ret_type, var_types, true);
if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) {
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
}
}
// Failed to identify the type.
return nullptr;

#else // TVM_LLVM_VERSION
llvm::ArrayRef<llvm::Intrinsic::IITDescriptor> ref(infos);
// matchIntrinsicType returns true on error.
if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) {
return nullptr;
}
for (llvm::Type* t : arg_types) {
if (llvm::Intrinsic::matchIntrinsicType(t, ref, overload_types)) {
return nullptr;
}
}
return llvm::Intrinsic::getDeclaration(module, id, overload_types);
#endif // TVM_LLVM_VERSION
}

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->is_intrinsic("llvm_intrin")) {
CHECK_GE(op->args.size(), 2U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(
Downcast<IntImm>(op->args[0])->value);
int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> sig_type;
std::vector<llvm::Type*> arg_type;
for (size_t i = 2; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i]));
if (i - 2 < static_cast<size_t>(num_signature)) {
sig_type.push_back(arg_value.back()->getType());
arg_type.push_back(arg_value.back()->getType());
}
}
llvm::Type *return_type = GetLLVMType(GetRef<PrimExpr>(op));
if (sig_type.size() > 0 && return_type != sig_type[0]) {
sig_type.insert(sig_type.begin(), return_type);
}
llvm::Function* f = llvm::Intrinsic::getDeclaration(
module_.get(), id, sig_type);
// LLVM's prefetch intrinsic returns "void", while TVM's prefetch
// returns int32. This causes problems because prefetch is one of
// those intrinsics that is generated automatically via the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, let us leave a TODO here, hopefully after we introduce the type system to type the intrinsics, we will be able to use the same type as LLVM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

// tvm.intrin.rule mechanism. Any other intrinsic with a type
// mismatch will have to be treated specially here.
// TODO(kparzysz-quic): fix this once TVM prefetch uses the same
// type as LLVM.
llvm::Type *return_type = (id != llvm::Intrinsic::prefetch)
? GetLLVMType(GetRef<PrimExpr>(op))
: llvm::Type::getVoidTy(*ctx_);

llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
<< llvm::Intrinsic::getName(id, {});
return builder_->CreateCall(f, arg_value);
} else if (op->is_intrinsic(CallNode::bitwise_and)) {
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
Expand Down
15 changes: 15 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,21 @@ class CodeGenLLVM :
* \param type The corresponding TVM Type.
*/
llvm::Type* GetLLVMType(const PrimExpr& expr) const;
/*!
* \brief Get the declaration of the LLVM intrinsic based on the intrinsic
* id, and the type of the actual call.
*
* \param id The intrinsic id.
* \param ret_type The call return type.
* \param arg_types The types of the call arguments.
*
* \return Return the llvm::Function pointer, or nullptr if the declaration
* could not be generated (e.g. if the argument/return types do not
* match).
*/
llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id,
llvm::Type* ret_type,
llvm::ArrayRef<llvm::Type*> arg_types);
// initialize the function state.
void InitFuncState();
// Get alignment given index.
Expand Down
6 changes: 3 additions & 3 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace codegen {
namespace llvm {

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 0>);
.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
Expand All @@ -53,7 +53,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
Expand Down Expand Up @@ -109,7 +109,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>);
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
Expand Down
29 changes: 27 additions & 2 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ def test_llvm_intrin():
fcode = tvm.build(func, None, "llvm")


def test_llvm_overloaded_intrin():
# Name lookup for overloaded intrinsics in LLVM 4- requires a name
# that includes the overloaded types.
if tvm.target.codegen.llvm_version_major() < 5:
return

def use_llvm_intrinsic(A, C):
ib = tvm.tir.ir_builder.create()
L = A.vload((0,0))
I = tvm.tir.call_llvm_intrin('int32', 'llvm.ctlz',
tvm.tir.const(2, 'uint32'), L, tvm.tir.const(0, 'int1'))
S = C.vstore((0,0), I)
ib.emit(S)
return ib.get()

A = tvm.te.placeholder((1,1), dtype = 'int32', name = 'A')
C = tvm.te.extern((1,1), [A],
lambda ins, outs: use_llvm_intrinsic(ins[0], outs[0]),
name = 'C' , dtype = 'int32')

s = tvm.te.create_schedule(C.op)
f = tvm.build(s, [A, C], target = 'llvm')


def test_llvm_import():
# extern "C" is necessary to get the correct signature
cc_code = """
Expand Down Expand Up @@ -82,9 +106,9 @@ def check_llvm(use_file):

def test_llvm_lookup_intrin():
ib = tvm.tir.ir_builder.create()
m = te.size_var("m")
A = ib.pointer("uint8x8", name="A")
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A)
z = tvm.tir.const(0, 'int32')
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
ib.emit(x)
body = ib.get()
func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
Expand Down Expand Up @@ -680,6 +704,7 @@ def vectorizer(op):
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
test_llvm_overloaded_intrin()
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
Expand Down
9 changes: 4 additions & 5 deletions topi/python/topi/arm_cpu/bitserial_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def _intrin_func(ins, outs):
ww, xx = ins
zz = outs[0]

args_1 = tvm.tir.const(1, 'uint32')
args_2 = tvm.tir.const(2, 'uint32')

if unipolar:
Expand Down Expand Up @@ -237,10 +236,10 @@ def _instr(index):
cnts8[i] = upper_half + lower_half
for i in range(m//2):
cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts8[i*2], cnts8[i*2+1])
args_2, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts4[i*2], cnts4[i*2+1])
args_2, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.tir.call_pure_intrin(
full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
Expand All @@ -257,10 +256,10 @@ def _instr(index):
cnts8[i] = tvm.tir.popcount(w_ & x_)
for i in range(m//2):
cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts8[i*2], cnts8[i*2+1])
args_2, cnts8[i*2], cnts8[i*2+1])
for i in range(m//4):
cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
args_1, cnts4[i*2], cnts4[i*2+1])
args_2, cnts4[i*2], cnts4[i*2+1])
cnts = tvm.tir.call_pure_intrin(
full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
Expand Down