diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 40ec2cc9e0b8..9bc56dc91458 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -487,7 +487,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { } llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { - int num_elems = llvm::cast(vec->getType())->getNumElements(); + int num_elems = GetVectorNumElements(vec); if (extent == num_elems && begin == 0) return vec; CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n"; std::vector indices; @@ -503,7 +503,7 @@ llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent } llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { - int num_elems = llvm::cast(vec->getType())->getNumElements(); + int num_elems = GetVectorNumElements(vec); #if TVM_LLVM_VERSION >= 110 std::vector indices; #else @@ -517,7 +517,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); - int num_elems = llvm::cast(vec->getType())->getNumElements(); + int num_elems = GetVectorNumElements(vec); if (num_elems == target_lanes) return vec; CHECK_LT(num_elems, target_lanes); for (int i = 0; i < num_elems; ++i) { @@ -531,15 +531,15 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { int total_lanes = 0; for (llvm::Value* v : vecs) { - total_lanes += llvm::cast(v->getType())->getNumElements(); + total_lanes += GetVectorNumElements(v); } while (vecs.size() > 1) { std::vector new_vecs; for (size_t i = 0; i < vecs.size() - 1; i += 2) { llvm::Value* lhs = vecs[i]; llvm::Value* rhs = vecs[i + 1]; - const size_t lhs_lanes = llvm::cast(lhs->getType())->getNumElements(); - const size_t rhs_lanes = llvm::cast(rhs->getType())->getNumElements(); + const size_t lhs_lanes = GetVectorNumElements(lhs); + const size_t rhs_lanes = GetVectorNumElements(rhs); if (lhs_lanes < rhs_lanes) { lhs = CreateVecPad(lhs, rhs_lanes); } else if (rhs_lanes < lhs_lanes) { @@ -843,16 +843,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return builder_->CreateFCmpUNO(a, a); } else if (op->op.same_as(builtin::vectorlow())) { llvm::Value* v = MakeValue(op->args[0]); - int l = llvm::cast(v->getType())->getNumElements(); + int l = GetVectorNumElements(v); return CreateVecSlice(v, 0, l / 2); } else if (op->op.same_as(builtin::vectorhigh())) { llvm::Value* v = MakeValue(op->args[0]); - int l = llvm::cast(v->getType())->getNumElements(); + int l = GetVectorNumElements(v); return CreateVecSlice(v, l / 2, l / 2); } else if (op->op.same_as(builtin::vectorcombine())) { llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); - int num_elems = llvm::cast(v0->getType())->getNumElements() * 2; + int num_elems = GetVectorNumElements(v0) * 2; #if TVM_LLVM_VERSION >= 110 std::vector indices; #else diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 3b0ce10534fd..78eb5e2dcac7 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -242,6 +242,11 @@ class CodeGenLLVM : public ExprFunctor, */ llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); + /*! + * \brief Get the number of elements in the given vector value. + * \param vec The value, must be of a vector type. + */ + inline int GetVectorNumElements(llvm::Value* vec); // initialize the function state. void InitFuncState(); // Get alignment given index. @@ -348,6 +353,15 @@ class CodeGenLLVM : public ExprFunctor, */ static std::unique_ptr CreateDebugInfo(llvm::Module* module); }; + +inline int CodeGenLLVM::GetVectorNumElements(llvm::Value* vec) { +#if TVM_LLVM_VERSION >= 120 + return llvm::cast(vec->getType())->getNumElements(); +#else + return llvm::cast(vec->getType())->getNumElements(); +#endif +} + } // namespace codegen } // namespace tvm #endif // LLVM_VERSION diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index f3362fb0f1eb..a71a0226c958 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -117,7 +117,11 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr llvm::Type* result_ty, const std::vector& args) { llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {}); +#if TVM_LLVM_VERSION >= 120 + size_t num_elems = llvm::cast(result_ty)->getNumElements(); +#else size_t num_elems = llvm::cast(result_ty)->getNumElements(); +#endif if (intrin_lanes == num_elems) { return builder_->CreateCall(f, args); } @@ -130,7 +134,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr std::vector split_args; for (const auto& v : args) { if (v->getType()->isVectorTy()) { - CHECK_EQ(llvm::cast(v->getType())->getNumElements(), num_elems); + CHECK_EQ(GetVectorNumElements(v), num_elems); split_args.push_back(CreateVecSlice(v, i, intrin_lanes)); } else { split_args.push_back(v);