From f0909e3a991c7b6515cafe2ce0a069a856fe675f Mon Sep 17 00:00:00 2001 From: Sterling-Augustine <56981066+Sterling-Augustine@users.noreply.github.com> Date: Fri, 11 Oct 2024 11:19:59 -0700 Subject: [PATCH] [SandboxIR] Add utility function to find the base Value for Mem instructions (#112030) --- llvm/include/llvm/SandboxIR/Context.h | 1 + llvm/include/llvm/SandboxIR/Utils.h | 10 ++++++++ llvm/unittests/SandboxIR/UtilsTest.cpp | 32 ++++++++++++++++++++++++++ 3 files changed, 43 insertions(+) diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index 77924fbcd5acea..1285598a1c0282 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -68,6 +68,7 @@ class Context { } /// Get or create a sandboxir::Constant from an existing LLVM IR \p LLVMC. Constant *getOrCreateConstant(llvm::Constant *LLVMC); + friend class Utils; // For getMemoryBase // Friends for getOrCreateConstant(). #define DEF_CONST(ID, CLASS) friend class CLASS; diff --git a/llvm/include/llvm/SandboxIR/Utils.h b/llvm/include/llvm/SandboxIR/Utils.h index 4ff4509b7086c5..a73498adea1d59 100644 --- a/llvm/include/llvm/SandboxIR/Utils.h +++ b/llvm/include/llvm/SandboxIR/Utils.h @@ -50,6 +50,16 @@ class Utils { return const_cast(I); } + /// \Returns the base Value for load or store instruction \p LSI. + template + static Value *getMemInstructionBase(const LoadOrStoreT *LSI) { + static_assert(std::is_same_v || + std::is_same_v, + "Expected sandboxir::Load or sandboxir::Store!"); + return LSI->Ctx.getOrCreateValue( + getUnderlyingObject(LSI->getPointerOperand()->Val)); + } + /// \Returns the number of bits required to represent the operands or return /// value of \p V in \p DL. static unsigned getNumBits(Value *V, const DataLayout &DL) { diff --git a/llvm/unittests/SandboxIR/UtilsTest.cpp b/llvm/unittests/SandboxIR/UtilsTest.cpp index 90396eaa53ab38..a30fc253a1a742 100644 --- a/llvm/unittests/SandboxIR/UtilsTest.cpp +++ b/llvm/unittests/SandboxIR/UtilsTest.cpp @@ -215,3 +215,35 @@ define void @foo(float %arg0, double %arg1, i8 %arg2, i64 %arg3, ptr %arg4) { EXPECT_EQ(sandboxir::Utils::getNumBits(L2), 8u); EXPECT_EQ(sandboxir::Utils::getNumBits(L3), 64u); } + +TEST_F(UtilsTest, GetMemBase) { + parseIR(C, R"IR( +define void @foo(ptr %ptrA, float %val, ptr %ptrB) { +bb: + %gepA0 = getelementptr float, ptr %ptrA, i32 0 + %gepA1 = getelementptr float, ptr %ptrA, i32 1 + %gepB0 = getelementptr float, ptr %ptrB, i32 0 + %gepB1 = getelementptr float, ptr %ptrB, i32 1 + store float %val, ptr %gepA0 + store float %val, ptr %gepA1 + store float %val, ptr %gepB0 + store float %val, ptr %gepB1 + ret void +} +)IR"); + llvm::Function &Foo = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + sandboxir::Function *F = Ctx.createFunction(&Foo); + + auto It = std::next(F->begin()->begin(), 4); + auto *St0 = cast(&*It++); + auto *St1 = cast(&*It++); + auto *St2 = cast(&*It++); + auto *St3 = cast(&*It++); + EXPECT_EQ(sandboxir::Utils::getMemInstructionBase(St0), + sandboxir::Utils::getMemInstructionBase(St1)); + EXPECT_EQ(sandboxir::Utils::getMemInstructionBase(St2), + sandboxir::Utils::getMemInstructionBase(St3)); + EXPECT_NE(sandboxir::Utils::getMemInstructionBase(St0), + sandboxir::Utils::getMemInstructionBase(St3)); +}