Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Dec 18, 2024
1 parent c5ef0b8 commit 068c6f3
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 29 deletions.
47 changes: 42 additions & 5 deletions llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,42 @@ Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) {
return NewInnermostTy;
}

// Local variant of llvm::stripPointerCastsAndOffsets. The difference from the
// upstream helper is that GEPs are looked through unconditionally, i.e. also
// when their indices are non-zero.
//
// Starting at V, the walk repeatedly steps to:
//   - the pointer operand of a GEP,
//   - operand 0 of a bitcast (only while that operand is itself a pointer),
//   - operand 0 of an addrspacecast,
//   - the returned argument operand of a call, when the call has one.
// The value at which none of these steps applies is returned. A non-pointer V
// is returned unchanged.
Value *stripPointerCastsAndOffsets(Value *V) {
  if (!V->getType()->isPointerTy())
    return V;

  // PHI nodes are never looked through, yet an instruction sitting in an
  // unreachable block may still be part of a cycle. Remember every value that
  // was stepped onto so such a cycle ends the walk.
  SmallPtrSet<Value *, 4> Seen;
  Seen.insert(V);

  for (;;) {
    if (auto *GEP = dyn_cast<GEPOperator>(V)) {
      V = GEP->getPointerOperand();
      assert(V->getType()->isPointerTy() && "Unexpected operand type!");
    } else {
      const unsigned Opcode = Operator::getOpcode(V);
      const bool IsBitCast = Opcode == Instruction::BitCast;
      if (IsBitCast || Opcode == Instruction::AddrSpaceCast) {
        Value *Source = cast<Operator>(V)->getOperand(0);
        // A bitcast whose source is not a pointer is where the walk stops;
        // the cast itself is the result.
        if (IsBitCast && !Source->getType()->isPointerTy())
          return V;
        V = Source;
        assert(V->getType()->isPointerTy() && "Unexpected operand type!");
      } else {
        // Last option: a call that hands one of its arguments back.
        auto *Call = dyn_cast<CallBase>(V);
        Value *Returned = Call ? Call->getReturnedArgOperand() : nullptr;
        if (!Returned)
          return V;
        V = Returned;
      }
    }

    // Stepping onto an already seen value means the chain loops; stop here.
    if (!Seen.insert(V).second)
      return V;
  }
}

// This function finds all calls to __spirv_AccessChain function and transforms
// its users and operands to make LLVM IR more SPIR-V friendly.
bool transformAccessChain(Function *F) {
Expand Down Expand Up @@ -77,8 +113,8 @@ bool transformAccessChain(Function *F) {
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
// function call. It's necessary because otherwise OpAccessChain indices
// would be wrong.
Instruction *Ptr =
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
Instruction *Ptr = dyn_cast<Instruction>(
stripPointerCastsAndOffsets(CI->getArgOperand(0)));
if (!Ptr || !isa<AllocaInst>(Ptr))
continue;

Expand Down Expand Up @@ -118,7 +154,7 @@ bool transformAccessChain(Function *F) {
if (!GEP)
continue;
Value *LastIndex = GEP->getOperand(GEP->getNumOperands() - 1);
if ((GEP->getNumIndices() == NestingLevel + 1) &&
if ((GEP->getNumIndices() == NestingLevel + 2) &&
(!isa<ConstantInt>(LastIndex) ||
!cast<ConstantInt>(LastIndex)->isZero())) {
assert(false && "Unexpected GEP pattern");
Expand All @@ -130,8 +166,9 @@ bool transformAccessChain(Function *F) {

for (auto *GEP : GEPsToReplace) {
SmallVector<Value *, 3> Indices(GEP->idx_begin(), GEP->idx_end());
// Remove the last index, as nesting level is decreased
if (GEP->getNumIndices() == NestingLevel + 1)
// Remove the last index, if it is addressing element of struct
// that was removed
if (GEP->getNumIndices() == NestingLevel + 2)
Indices.pop_back();
IRBuilder Builder(GEP);
Value *NewGEP = Builder.CreateInBoundsGEP(
Expand Down
75 changes: 51 additions & 24 deletions llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,24 @@

; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s

; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK-NEXT: [[TC_I:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
; CHECK-NEXT: [[ARRAY_BEGIN_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_I]], i64 0, i64 0
; CHECK-NEXT: [[ARRAYCTOR_END_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_I]], i64 5
; CHECK-NEXT: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)
; CHECK: [[TMP4:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4)
; CHECK-NEXT: [[TMP5:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP4]], i64 noundef 0)
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0)
; CHECK: [[ALLOC:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
; CHECK-NEXT: [[TC:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
; CHECK-NEXT: [[TC_1:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
; CHECK-NEXT: [[TC_2:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
; CHECK-NEXT: [[CAST:%.*]] = addrspacecast ptr [[ALLOC]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST]], i64 noundef 0)
; CHECK: [[ARRAY_BEGIN_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC]], i64 0, i64 0, i64 0
; CHECK-NEXT: [[ARRAYCTOR_END_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC]], i64 1
; CHECK-NEXT: [[CAST_1:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4)
; CHECK-NEXT: [[CAST_2:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_1]], i64 noundef 0)
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_2]], i64 noundef 0)
; CHECK: [[TARGET_PTR:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_1]], i64 0, i64 0, i64 0
; CHECK-NEXT: [[CAST_3:%.*]] = addrspacecast ptr [[TARGET_PTR]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_3]], i64 noundef 0)
; CHECK: [[NON_CONST_OFFSET:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_2]], i64 0, i64 %ind, i64 %ind
; CHECK-NEXT: [[CAST_4:%.*]] = addrspacecast ptr [[NON_CONST_OFFSET]] to ptr addrspace(4)
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_4]], i64 noundef 0)

; ModuleID = 'test.bc'
source_filename = "test.cpp"
Expand All @@ -21,21 +29,40 @@ target triple = "spir64-unknown-unknown"

%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }

define weak_odr dso_local spir_kernel void @test() {
define weak_odr dso_local spir_kernel void @test(i64 %ind) {
entry:
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
%tC.i = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8
%array.begin.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.i, i64 0, i64 0, i64 0
%arrayctor.end.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.i, i64 5
%1 = addrspacecast ptr %0 to ptr addrspace(4)
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
%3 = load i8, ptr addrspace(4) %2
%4 = addrspacecast ptr %array.begin.i to ptr addrspace(4)
%5 = addrspacecast ptr %array.begin.i to ptr addrspace(4)
%6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %4, i64 noundef 0)
%7 = load i8, ptr addrspace(4) %6
%8 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %5, i64 noundef 0)
%9 = load i8, ptr addrspace(4) %8
; allocas
%alloc = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
%tC = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8
%tC.1 = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8
%tC.2 = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8

; simple case
%cast = addrspacecast ptr %alloc to ptr addrspace(4)
%1 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast, i64 noundef 0)
%2 = load i8, ptr addrspace(4) %1

; array case
%array.begin.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC, i64 0, i64 0, i64 0
%arrayctor.end.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC, i64 1
%cast.1 = addrspacecast ptr %array.begin.i to ptr addrspace(4)
%cast.2 = addrspacecast ptr %array.begin.i to ptr addrspace(4)
%3 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.1, i64 noundef 0)
%4 = load i8, ptr addrspace(4) %3
%5 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.2, i64 noundef 0)
%6 = load i8, ptr addrspace(4) %5

; array case with indexing into the struct
%target.ptr = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.1, i64 0, i64 0, i64 0, i32 0
%cast.3 = addrspacecast ptr %target.ptr to ptr addrspace(4)
%7 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.3, i64 noundef 0)
%8 = load i8, ptr addrspace(4) %7

; array case with variable index
%non.const.offset = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.2, i64 0, i64 %ind, i64 %ind
%cast.4 = addrspacecast ptr %non.const.offset to ptr addrspace(4)
%9 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.4, i64 noundef 0)
%10 = load i8, ptr addrspace(4) %9
ret void
}

Expand Down

0 comments on commit 068c6f3

Please sign in to comment.