Skip to content

Commit

Permalink
Fix SPIR-V global to function replacement for differing load types
Browse files Browse the repository at this point in the history
In some cases, we will see IR like the following:

```
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

...

%0 = load <6 x i32>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%1 = extractelement <6 x i32> %0, i64 0
```

Note that the global's value type (<3 x i64>) and the load type (<6 x i32>) are different. Change the handling of vector loads
from vector globals to reconstruct the vector using the global's type and then bitcast it to the load type.

Thanks to @jcranmer-intel for helping me find the simplest solution.

Signed-off-by: Sarnie, Nick <nick.sarnie@intel.com>
  • Loading branch information
sarnex committed Sep 19, 2023
1 parent aab0dac commit af8eb11
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
17 changes: 13 additions & 4 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1989,19 +1989,20 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
/// are accumulated in the AccumulatedOffset parameter, which will eventually be
/// used to figure out which index of a variable is being used.
static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
Function *ReplacementFunc) {
Function *ReplacementFunc,
GlobalVariable *GV) {
const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
SmallVector<Instruction *, 4> InstsToRemove;
for (User *U : V->users()) {
if (auto *Cast = dyn_cast<CastInst>(U)) {
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc, GV);
InstsToRemove.push_back(Cast);
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
APInt NewOffset = AccumulatedOffset.sextOrTrunc(
DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
if (!GEP->accumulateConstantOffset(DL, NewOffset))
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc, GV);
InstsToRemove.push_back(GEP);
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
// Figure out which index the accumulated offset corresponds to. If we
Expand All @@ -2024,6 +2025,10 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
} else {
// The function has an index parameter.
if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
// Reconstruct the original global variable vector because
// the load type may not match.
// global <3 x i64>, load <6 x i32>
VecTy = cast<FixedVectorType>(GV->getValueType());
if (!Index.isZero())
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
Replacement = UndefValue::get(VecTy);
Expand All @@ -2034,6 +2039,10 @@ static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
Builder.getInt32(I));
}
// Insert a bitcast from the reconstructed vector to the load vector
// type in case they are different. %2 = insertelement <3 x i64> %0,
// i64 %1, i32 2 bitcast <3 x i64> %2 to <6 x i32>
Replacement = Builder.CreateBitCast(Replacement, Load->getType());
} else if (Load->getType() == ScalarTy) {
Replacement = setAttrByCalledFunc(Builder.CreateCall(
ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
Expand Down Expand Up @@ -2087,7 +2096,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
Func->setDoesNotAccessMemory();
}

replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func, GV);
return true;
}

Expand Down
27 changes: 27 additions & 0 deletions test/transcoding/builtin_vars_different_type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
; Verifies lowering of a vector load from a SPIR-V builtin variable when the
; load type (<6 x i32>) differs from the global's value type (<3 x i64>).
; The module is translated to SPIR-V and then reverse-translated to SPV-IR;
; the checks below are run against the disassembled reverse-translated module.
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc -o %t.spv -spirv-ext=+SPV_INTEL_vector_compute
; RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o %t.out.bc
; RUN: llvm-dis %t.out.bc -o - | FileCheck %s --check-prefix=CHECK-SPV-IR

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

; The builtin is declared as a <3 x i64> vector.
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

; Function Attrs: nounwind readnone
define spir_kernel void @f() {
entry:
; The load reinterprets the same builtin as <6 x i32>, so it does not match the global's type.
%0 = load <6 x i32>, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%1 = extractelement <6 x i32> %0, i64 0
%2 = add i32 5, %1
ret void
; Expected output: one builtin call per element (indices 0, 1, 2) inserted into
; a <3 x i64> vector, followed by a bitcast to the <6 x i32> type that was loaded.
; CHECK-SPV-IR: %[[#ID0:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0) #1
; CHECK-SPV-IR: %[[#ID1:]] = insertelement <3 x i64> undef, i64 %[[#ID0]], i32 0
; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1) #1
; CHECK-SPV-IR: %[[#ID3:]] = insertelement <3 x i64> %[[#ID1]], i64 %[[#ID2]], i32 1
; CHECK-SPV-IR: %[[#ID4:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2) #1
; CHECK-SPV-IR: %[[#ID5:]] = insertelement <3 x i64> %[[#ID3]], i64 %[[#ID4]], i32 2
; CHECK-SPV-IR: %[[#ID6:]] = bitcast <3 x i64> %[[#ID5]] to <6 x i32>
; CHECK-SPV-IR: %[[#ID7:]] = extractelement <6 x i32> %[[#ID6]], i32 0
; CHECK-SPV-IR: = add i32 5, %[[#ID7]]
}

0 comments on commit af8eb11

Please sign in to comment.