Skip to content

Commit

Permalink
[SYCL][Matrix] Extend W/A for more corner cases of AccessChain usage
Browse files Browse the repository at this point in the history
The new corner case is:
AccessChain is used on arrays of Joint Matrices
  • Loading branch information
YuriPlyakhin committed Dec 14, 2024
1 parent 3524739 commit c5ef0b8
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
59 changes: 57 additions & 2 deletions llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ namespace {
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";

// Peels array types off \p Ty until a non-array type is reached and returns
// that type. \p NestingLevel is overwritten with the number of array layers
// that were removed (0 when \p Ty is not an array type).
Type *getInnermostType(Type *Ty, unsigned &NestingLevel) {
  unsigned Depth = 0;
  Type *Current = Ty;
  for (auto *Arr = dyn_cast<ArrayType>(Current); Arr;
       Arr = dyn_cast<ArrayType>(Current)) {
    // Step one array layer inwards and remember that we did so.
    Current = Arr->getElementType();
    ++Depth;
  }
  NestingLevel = Depth;
  return Current;
}

// Rebuilds \p Ty keeping every array dimension intact but substituting the
// innermost non-array element type with \p NewInnermostTy. When \p Ty is not
// an array type, \p NewInnermostTy itself is returned.
Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) {
  auto *OuterArrayTy = dyn_cast<ArrayType>(Ty);
  // Base case: nothing left to unwrap, so this is the spot to substitute.
  if (!OuterArrayTy)
    return NewInnermostTy;

  // Recurse into the element type first, then re-wrap it with the same
  // element count as the original array layer.
  Type *RebuiltElementTy =
      replaceInnermostType(OuterArrayTy->getElementType(), NewInnermostTy);
  return ArrayType::get(RebuiltElementTy, OuterArrayTy->getNumElements());
}

// This function finds all calls to __spirv_AccessChain function and transforms
// its users and operands to make LLVM IR more SPIR-V friendly.
bool transformAccessChain(Function *F) {
Expand Down Expand Up @@ -64,8 +81,13 @@ bool transformAccessChain(Function *F) {
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
if (!Ptr || !isa<AllocaInst>(Ptr))
continue;

Type *AllocaTy = cast<AllocaInst>(Ptr)->getAllocatedType();
// The sycl::joint_matrix class object may be wrapped in nested arrays, so
// look through the array types to find the innermost type, which is expected
// to be the joint_matrix wrapper struct.
unsigned NestingLevel = 0;
StructType *WrapperMatrixTy =
dyn_cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType());
dyn_cast<StructType>(getInnermostType(AllocaTy, NestingLevel));
if (!WrapperMatrixTy)
continue;
TargetExtType *MatrixTy =
Expand All @@ -81,13 +103,46 @@ bool transformAccessChain(Function *F) {
IRBuilder Builder(CI);
IRBuilderBase::InsertPointGuard IG(Builder);
Builder.SetInsertPointPastAllocas(CI->getFunction());
Alloca = Builder.CreateAlloca(MatrixTy);
Alloca = Builder.CreateAlloca(replaceInnermostType(AllocaTy, MatrixTy));
Alloca->takeName(Ptr);
}
Ptr->replaceAllUsesWith(Alloca);
Ptr->dropAllReferences();
Ptr->eraseFromParent();
ModuleChanged = true;

// Update also all getelementptr instructions which use the new alloca
SmallVector<GetElementPtrInst *, 4> GEPsToReplace;
for (auto *User : Alloca->users()) {
auto *GEP = dyn_cast<GetElementPtrInst>(User);
if (!GEP)
continue;
Value *LastIndex = GEP->getOperand(GEP->getNumOperands() - 1);
if ((GEP->getNumIndices() == NestingLevel + 1) &&
(!isa<ConstantInt>(LastIndex) ||
!cast<ConstantInt>(LastIndex)->isZero())) {
assert(false && "Unexpected GEP pattern");
continue;
}

GEPsToReplace.push_back(GEP);
}

for (auto *GEP : GEPsToReplace) {
SmallVector<Value *, 3> Indices(GEP->idx_begin(), GEP->idx_end());
// Remove the last index, as nesting level is decreased
if (GEP->getNumIndices() == NestingLevel + 1)
Indices.pop_back();
IRBuilder Builder(GEP);
Value *NewGEP = Builder.CreateInBoundsGEP(
Alloca->getAllocatedType(), GEP->getPointerOperand(), Indices);
NewGEP->takeName(GEP);
GEP->replaceAllUsesWith(NewGEP);
GEP->dropAllReferences();
GEP->eraseFromParent();
}
}

return ModuleChanged;
}
} // namespace
Expand Down
22 changes: 19 additions & 3 deletions llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@

; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s

; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)
; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK-NEXT: [[TC_I:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
; CHECK-NEXT: [[ARRAY_BEGIN_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_I]], i64 0, i64 0
; CHECK-NEXT: [[ARRAYCTOR_END_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_I]], i64 5
; CHECK-NEXT: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)
; CHECK: [[TMP4:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4)
; CHECK-NEXT: [[TMP5:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP4]], i64 noundef 0)
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0)

; ModuleID = 'test.bc'
source_filename = "test.cpp"
Expand All @@ -17,9 +24,18 @@ target triple = "spir64-unknown-unknown"
define weak_odr dso_local spir_kernel void @test() {
entry:
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
%tC.i = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8
%array.begin.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.i, i64 0, i64 0, i64 0
%arrayctor.end.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.i, i64 5
%1 = addrspacecast ptr %0 to ptr addrspace(4)
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
%3 = load i8, ptr addrspace(4) %2
%4 = addrspacecast ptr %array.begin.i to ptr addrspace(4)
%5 = addrspacecast ptr %array.begin.i to ptr addrspace(4)
%6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %4, i64 noundef 0)
%7 = load i8, ptr addrspace(4) %6
%8 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %5, i64 noundef 0)
%9 = load i8, ptr addrspace(4) %8
ret void
}

Expand Down

0 comments on commit c5ef0b8

Please sign in to comment.