diff --git a/lib/SPIRV/SPIRVUtil.cpp b/lib/SPIRV/SPIRVUtil.cpp index ab244ed719..f1c35aa5e5 100644 --- a/lib/SPIRV/SPIRVUtil.cpp +++ b/lib/SPIRV/SPIRVUtil.cpp @@ -1877,19 +1877,20 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV, /// are accumulated in the AccumulatedOffset parameter, which will eventually be /// used to figure out which index of a variable is being used. static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset, - Function *ReplacementFunc) { + Function *ReplacementFunc, + GlobalVariable *GV) { const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout(); SmallVector InstsToRemove; for (User *U : V->users()) { if (auto *Cast = dyn_cast(U)) { - replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc); + replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc, GV); InstsToRemove.push_back(Cast); } else if (auto *GEP = dyn_cast(U)) { APInt NewOffset = AccumulatedOffset.sextOrTrunc( DL.getIndexSizeInBits(GEP->getPointerAddressSpace())); if (!GEP->accumulateConstantOffset(DL, NewOffset)) llvm_unreachable("Illegal GEP of a SPIR-V builtin variable"); - replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc); + replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc, GV); InstsToRemove.push_back(GEP); } else if (auto *Load = dyn_cast(U)) { // Figure out which index the accumulated offset corresponds to. If we @@ -1912,7 +1913,12 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset, } else { // The function has an index parameter. if (auto *VecTy = dyn_cast(Load->getType())) { - if (!!Index) + // Reconstruct the original global variable vector because + // the load type may not match. + // global <3 x i64>, load <6 x i32> + VecTy = cast(GV->getValueType()); + if (!!Index || DL.getTypeSizeInBits(VecTy) != + DL.getTypeSizeInBits(Load->getType())) llvm_unreachable("Illegal use of a SPIR-V builtin variable"); Replacement = UndefValue::get(VecTy); for (unsigned I = 0; I < VecTy->getNumElements(); I++) { @@ -1922,6 +1928,19 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset, Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})), Builder.getInt32(I)); } + // Insert a bitcast from the reconstructed vector to the load vector + // type in case they are different. + // Input: + // %1 = load <6 x i32>, ptr addrspace(1) %0, align 32 + // %2 = extractelement <6 x i32> %1, i32 0 + // %3 = add i32 5, %2 + // Modified: + // < reconstruct global vector elements 0 and 1 > + // %2 = insertelement <3 x i64> %0, i64 %1, i32 2 + // %3 = bitcast <3 x i64> %2 to <6 x i32> + // %4 = extractelement <6 x i32> %3, i32 0 + // %5 = add i32 5, %4 + Replacement = Builder.CreateBitCast(Replacement, Load->getType()); } else if (Load->getType() == ScalarTy) { Replacement = setAttrByCalledFunc(Builder.CreateCall( ReplacementFunc, {Builder.getInt32(Index.getZExtValue())})); @@ -1975,7 +1994,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV, Func->setDoesNotAccessMemory(); } - replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func); + replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func, GV); return true; } diff --git a/test/transcoding/builtin_vars_different_type.ll b/test/transcoding/builtin_vars_different_type.ll new file mode 100644 index 0000000000..4dc838ac03 --- /dev/null +++ b/test/transcoding/builtin_vars_different_type.ll @@ -0,0 +1,27 @@ +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv %t.bc -o %t.spv -spirv-ext=+SPV_INTEL_vector_compute +; RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o %t.out.bc +; RUN: llvm-dis %t.out.bc -o - | FileCheck %s --check-prefix=CHECK-SPV-IR + +target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" +target triple = "spir-unknown-unknown" + +@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 + +; Function Attrs: nounwind readnone +define spir_kernel void @f() { +entry: + %0 = load <6 x i32>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32 + %1 = extractelement <6 x i32> %0, i64 0 + %2 = add i32 5, %1 + ret void +; CHECK-SPV-IR: %[[#ID0:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0) #1 +; CHECK-SPV-IR: %[[#ID1:]] = insertelement <3 x i64> undef, i64 %[[#ID0]], i32 0 +; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1) #1 +; CHECK-SPV-IR: %[[#ID3:]] = insertelement <3 x i64> %[[#ID1]], i64 %[[#ID2]], i32 1 +; CHECK-SPV-IR: %[[#ID4:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2) #1 +; CHECK-SPV-IR: %[[#ID5:]] = insertelement <3 x i64> %[[#ID3]], i64 %[[#ID4]], i32 2 +; CHECK-SPV-IR: %[[#ID6:]] = bitcast <3 x i64> %[[#ID5]] to <6 x i32> +; CHECK-SPV-IR: %[[#ID7:]] = extractelement <6 x i32> %[[#ID6]], i32 0 +; CHECK-SPV-IR: = add i32 5, %[[#ID7]] +}