From 068c6f38cf6893866db446619eb3c66fd9d4cb05 Mon Sep 17 00:00:00 2001 From: "Plyakhin, Yury" Date: Tue, 17 Dec 2024 16:19:04 -0800 Subject: [PATCH] update --- .../SYCLLowerIR/SYCLJointMatrixTransform.cpp | 47 ++++++++++-- .../JointMatrixTransform/access_chain.ll | 75 +++++++++++++------ 2 files changed, 93 insertions(+), 29 deletions(-) diff --git a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp index 586175120d79e..0e95f21d555ae 100644 --- a/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp +++ b/llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp @@ -39,6 +39,42 @@ Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) { return NewInnermostTy; } +// This function is a copy of llvm::stripPointerCastsAndOffsets with +// modification to strip non-zero GEP indices as well. +Value *stripPointerCastsAndOffsets(Value *V) { + if (!V->getType()->isPointerTy()) + return V; + + // Even though we don't look through PHI nodes, we could be called on an + // instruction in an unreachable block, which may be on a cycle. + SmallPtrSet Visited; + + Visited.insert(V); + do { + if (auto *GEP = dyn_cast(V)) { + V = GEP->getPointerOperand(); + } else if (Operator::getOpcode(V) == Instruction::BitCast) { + Value *NewV = cast(V)->getOperand(0); + if (!NewV->getType()->isPointerTy()) + return V; + V = NewV; + } else if (Operator::getOpcode(V) == Instruction::AddrSpaceCast) { + V = cast(V)->getOperand(0); + } else { + if (auto *Call = dyn_cast(V)) { + if (Value *RV = Call->getReturnedArgOperand()) { + V = RV; + continue; + } + } + return V; + } + assert(V->getType()->isPointerTy() && "Unexpected operand type!"); + } while (Visited.insert(V).second); + + return V; +} + // This function finds all calls to __spirv_AccessChain function and transforms // its users and operands to make LLVM IR more SPIR-V friendly. bool transformAccessChain(Function *F) { @@ -77,8 +113,8 @@ bool transformAccessChain(Function *F) { // from sycl::joint_matrix class object if it's used in __spirv_AccessChain // function call. It's necessary because otherwise OpAccessChain indices // would be wrong. - Instruction *Ptr = - dyn_cast(CI->getArgOperand(0)->stripPointerCasts()); + Instruction *Ptr = dyn_cast( + stripPointerCastsAndOffsets(CI->getArgOperand(0))); if (!Ptr || !isa(Ptr)) continue; @@ -118,7 +154,7 @@ bool transformAccessChain(Function *F) { if (!GEP) continue; Value *LastIndex = GEP->getOperand(GEP->getNumOperands() - 1); - if ((GEP->getNumIndices() == NestingLevel + 1) && + if ((GEP->getNumIndices() == NestingLevel + 2) && (!isa(LastIndex) || !cast(LastIndex)->isZero())) { assert(false && "Unexpected GEP pattern"); @@ -130,8 +166,9 @@ bool transformAccessChain(Function *F) { for (auto *GEP : GEPsToReplace) { SmallVector Indices(GEP->idx_begin(), GEP->idx_end()); - // Remove the last index, as nesting level is decreased - if (GEP->getNumIndices() == NestingLevel + 1) + // Remove the last index, if it is addressing element of struct + // that was removed + if (GEP->getNumIndices() == NestingLevel + 2) Indices.pop_back(); IRBuilder Builder(GEP); Value *NewGEP = Builder.CreateInBoundsGEP( diff --git a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll index e7f802b80e346..6584a5c7cdb2d 100644 --- a/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll +++ b/llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll @@ -3,16 +3,24 @@ ; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s -; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) -; CHECK-NEXT: [[TC_I:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]] -; CHECK-NEXT: [[ARRAY_BEGIN_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_I]], i64 0, i64 0 -; CHECK-NEXT: [[ARRAYCTOR_END_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_I]], i64 5 -; CHECK-NEXT: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4) -; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0) -; CHECK: [[TMP4:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4) -; CHECK-NEXT: [[TMP5:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4) -; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP4]], i64 noundef 0) -; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0) +; CHECK: [[ALLOC:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) +; CHECK-NEXT: [[TC:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]] +; CHECK-NEXT: [[TC_1:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]] +; CHECK-NEXT: [[TC_2:%.*]] = alloca [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]] +; CHECK-NEXT: [[CAST:%.*]] = addrspacecast ptr [[ALLOC]] to ptr addrspace(4) +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST]], i64 noundef 0) +; CHECK: [[ARRAY_BEGIN_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC]], i64 0, i64 0, i64 0 +; CHECK-NEXT: [[ARRAYCTOR_END_I:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC]], i64 1 +; CHECK-NEXT: [[CAST_1:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4) +; CHECK-NEXT: [[CAST_2:%.*]] = addrspacecast ptr [[ARRAY_BEGIN_I]] to ptr addrspace(4) +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_1]], i64 noundef 0) +; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_2]], i64 noundef 0) +; CHECK: [[TARGET_PTR:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_1]], i64 0, i64 0, i64 0 +; CHECK-NEXT: [[CAST_3:%.*]] = addrspacecast ptr [[TARGET_PTR]] to ptr addrspace(4) +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_3]], i64 noundef 0) +; CHECK: [[NON_CONST_OFFSET:%.*]] = getelementptr inbounds [5 x [7 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]], ptr [[TC_2]], i64 0, i64 %ind, i64 %ind +; CHECK-NEXT: [[CAST_4:%.*]] = addrspacecast ptr [[NON_CONST_OFFSET]] to ptr addrspace(4) +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[CAST_4]], i64 noundef 0) ; ModuleID = 'test.bc' source_filename = "test.cpp" @@ -21,21 +29,40 @@ target triple = "spir64-unknown-unknown" %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) } -define weak_odr dso_local spir_kernel void @test() { +define weak_odr dso_local spir_kernel void @test(i64 %ind) { entry: - %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 - %tC.i = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8 - %array.begin.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.i, i64 0, i64 0, i64 0 - %arrayctor.end.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.i, i64 5 - %1 = addrspacecast ptr %0 to ptr addrspace(4) - %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) - %3 = load i8, ptr addrspace(4) %2 - %4 = addrspacecast ptr %array.begin.i to ptr addrspace(4) - %5 = addrspacecast ptr %array.begin.i to ptr addrspace(4) - %6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %4, i64 noundef 0) - %7 = load i8, ptr addrspace(4) %6 - %8 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %5, i64 noundef 0) - %9 = load i8, ptr addrspace(4) %8 + ; allocas + %alloc = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 + %tC = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8 + %tC.1 = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8 + %tC.2 = alloca [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], align 8 + + ; simple case + %cast = addrspacecast ptr %alloc to ptr addrspace(4) + %1 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast, i64 noundef 0) + %2 = load i8, ptr addrspace(4) %1 + + ; array case + %array.begin.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC, i64 0, i64 0, i64 0 + %arrayctor.end.i = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC, i64 1 + %cast.1 = addrspacecast ptr %array.begin.i to ptr addrspace(4) + %cast.2 = addrspacecast ptr %array.begin.i to ptr addrspace(4) + %3 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.1, i64 noundef 0) + %4 = load i8, ptr addrspace(4) %3 + %5 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.2, i64 noundef 0) + %6 = load i8, ptr addrspace(4) %5 + + ; array case with indexing into the struct + %target.ptr = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.1, i64 0, i64 0, i64 0, i32 0 + %cast.3 = addrspacecast ptr %target.ptr to ptr addrspace(4) + %7 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.3, i64 noundef 0) + %8 = load i8, ptr addrspace(4) %7 + + ; array case with variable index + %non.const.offset = getelementptr inbounds [5 x [7 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"]], ptr %tC.2, i64 0, i64 %ind, i64 %ind + %cast.4 = addrspacecast ptr %non.const.offset to ptr addrspace(4) + %9 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %cast.4, i64 noundef 0) + %10 = load i8, ptr addrspace(4) %9 ret void }