Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to 9] Fix SPIR-V global to function replacement for differing load types #2241

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1877,19 +1877,20 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
/// are accumulated in the AccumulatedOffset parameter, which will eventually be
/// used to figure out which index of a variable is being used.
static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
Function *ReplacementFunc) {
Function *ReplacementFunc,
GlobalVariable *GV) {
const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
SmallVector<Instruction *, 4> InstsToRemove;
for (User *U : V->users()) {
if (auto *Cast = dyn_cast<CastInst>(U)) {
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc, GV);
InstsToRemove.push_back(Cast);
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
APInt NewOffset = AccumulatedOffset.sextOrTrunc(
DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
if (!GEP->accumulateConstantOffset(DL, NewOffset))
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc, GV);
InstsToRemove.push_back(GEP);
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
// Figure out which index the accumulated offset corresponds to. If we
Expand All @@ -1912,7 +1913,12 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
} else {
// The function has an index parameter.
if (auto *VecTy = dyn_cast<llvm::VectorType>(Load->getType())) {
if (!!Index)
// Rebuild the vector using the global variable's own value type,
// because the type used by the load may differ from it,
// e.g. global is <3 x i64> but the load is of <6 x i32>.
VecTy = cast<llvm::VectorType>(GV->getValueType());
if (!!Index || DL.getTypeSizeInBits(VecTy) !=
DL.getTypeSizeInBits(Load->getType()))
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
Replacement = UndefValue::get(VecTy);
for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
Expand All @@ -1922,6 +1928,19 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
Builder.getInt32(I));
}
// Insert a bitcast from the reconstructed vector to the load vector
// type in case they are different.
// Input:
// %1 = load <6 x i32>, ptr addrspace(1) %0, align 32
// %2 = extractelement <6 x i32> %1, i32 0
// %3 = add i32 5, %2
// Modified:
// < reconstruct global vector elements 0 and 1 >
// %2 = insertelement <3 x i64> %0, i64 %1, i32 2
// %3 = bitcast <3 x i64> %2 to <6 x i32>
// %4 = extractelement <6 x i32> %3, i32 0
// %5 = add i32 5, %4
Replacement = Builder.CreateBitCast(Replacement, Load->getType());
} else if (Load->getType() == ScalarTy) {
Replacement = setAttrByCalledFunc(Builder.CreateCall(
ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
Expand Down Expand Up @@ -1975,7 +1994,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
Func->setDoesNotAccessMemory();
}

replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func, GV);
return true;
}

Expand Down
28 changes: 28 additions & 0 deletions test/transcoding/builtin_vars_different_type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; Test: a SPIR-V builtin variable declared as <3 x i64> is loaded through a
; pointer of a different vector type (<6 x i32>). After a round trip through
; SPIR-V the load is expected to become per-element builtin calls followed by
; a bitcast to the loaded type.
;
; Steps below: assemble this file to bitcode, translate it to SPIR-V with the
; SPV_INTEL_vector_compute extension enabled, translate the SPIR-V back (-r)
; with the SPV-IR target environment, then disassemble the result and verify
; it with FileCheck.
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc -o %t.spv -spirv-ext=+SPV_INTEL_vector_compute
; RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o %t.out.bc
; RUN: llvm-dis %t.out.bc -o - | FileCheck %s --check-prefix=CHECK-SPV-IR

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

; The builtin variable under test: an external constant in addrspace(1) whose
; value type is <3 x i64>.
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

; Kernel that reads the <3 x i64> builtin variable as a <6 x i32> vector and
; uses element 0 of the loaded value.
; Function Attrs: nounwind readnone
define spir_kernel void @f() {
entry:
; Reinterpret the pointer to the <3 x i64> global as a pointer to <6 x i32>,
; so the load type below differs from the global's value type.
%0 = bitcast <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId to <6 x i32> addrspace(1)*
%1 = load <6 x i32>, <6 x i32> addrspace(1)* %0, align 32
; Use one i32 lane of the loaded vector so the load has a consumer.
%2 = extractelement <6 x i32> %1, i64 0
%3 = add i32 5, %2
ret void
; Expected output: the load is replaced by three calls to the builtin function
; (indices 0, 1, 2) whose results are inserted into a <3 x i64> vector; that
; vector is then bitcast to <6 x i32> so the original extractelement/add users
; still operate on the loaded type.
; CHECK-SPV-IR: %[[#ID0:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0) #1
; CHECK-SPV-IR: %[[#ID1:]] = insertelement <3 x i64> undef, i64 %[[#ID0]], i32 0
; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1) #1
; CHECK-SPV-IR: %[[#ID3:]] = insertelement <3 x i64> %[[#ID1]], i64 %[[#ID2]], i32 1
; CHECK-SPV-IR: %[[#ID4:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2) #1
; CHECK-SPV-IR: %[[#ID5:]] = insertelement <3 x i64> %[[#ID3]], i64 %[[#ID4]], i32 2
; CHECK-SPV-IR: %[[#ID6:]] = bitcast <3 x i64> %[[#ID5]] to <6 x i32>
; CHECK-SPV-IR: %[[#ID7:]] = extractelement <6 x i32> %[[#ID6]], i32 0
; CHECK-SPV-IR: = add i32 5, %[[#ID7]]
}