From 6f8d2781f604cfcf9ea6facecc0bea8e4d682e1e Mon Sep 17 00:00:00 2001 From: Sterling-Augustine <56981066+Sterling-Augustine@users.noreply.github.com> Date: Mon, 9 Sep 2024 20:49:49 +0000 Subject: [PATCH] [SandboxIR] Add missing VectorType functions (#107650) Fills in many missing functions from VectorType --- llvm/include/llvm/SandboxIR/Type.h | 27 ++++++++++++-- llvm/lib/SandboxIR/Type.cpp | 38 ++++++++++++++++++- llvm/unittests/SandboxIR/TypesTest.cpp | 51 ++++++++++++++++++++++++-- 3 files changed, 108 insertions(+), 8 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h index 69ca156e82101c..44aee4e4a5b46e 100644 --- a/llvm/include/llvm/SandboxIR/Type.h +++ b/llvm/include/llvm/SandboxIR/Type.h @@ -50,8 +50,8 @@ class Type { friend class ConstantArray; // For LLVMTy. friend class ConstantStruct; // For LLVMTy. friend class ConstantVector; // For LLVMTy. - friend class CmpInst; // For LLVMTy. TODO: Cleanup after sandboxir::VectorType - // is more complete. + friend class CmpInst; // For LLVMTy. TODO: Cleanup after + // sandboxir::VectorType is more complete. // Friend all instruction classes because `create()` functions use LLVMTy. #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS; @@ -317,7 +317,28 @@ class StructType : public Type { class VectorType : public Type { public: static VectorType *get(Type *ElementType, ElementCount EC); - // TODO: add missing functions + static VectorType *get(Type *ElementType, unsigned NumElements, + bool Scalable) { + return VectorType::get(ElementType, + ElementCount::get(NumElements, Scalable)); + } + Type *getElementType() const; + + static VectorType *get(Type *ElementType, const VectorType *Other) { + return VectorType::get(ElementType, Other->getElementCount()); + } + + inline ElementCount getElementCount() const { + return cast(LLVMTy)->getElementCount(); + } + static VectorType *getInteger(VectorType *VTy); + static VectorType *getExtendedElementVectorType(VectorType *VTy); + static VectorType *getTruncatedElementVectorType(VectorType *VTy); + static VectorType *getSubdividedVectorType(VectorType *VTy, int NumSubdivs); + static VectorType *getHalfElementsVectorType(VectorType *VTy); + static VectorType *getDoubleElementsVectorType(VectorType *VTy); + static bool isValidElementType(Type *ElemTy); + static bool classof(const Type *From) { return isa(From->LLVMTy); } diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp index 11a16e865213fb..bf9f02e2ba3111 100644 --- a/llvm/lib/SandboxIR/Type.cpp +++ b/llvm/lib/SandboxIR/Type.cpp @@ -36,7 +36,6 @@ Type *Type::getDoubleTy(Context &Ctx) { Type *Type::getFloatTy(Context &Ctx) { return Ctx.getType(llvm::Type::getFloatTy(Ctx.LLVMCtx)); } - PointerType *PointerType::get(Type *ElementType, unsigned AddressSpace) { return cast(ElementType->getContext().getType( llvm::PointerType::get(ElementType->LLVMTy, AddressSpace))); @@ -67,6 +66,43 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) { llvm::VectorType::get(ElementType->LLVMTy, EC))); } +Type *VectorType::getElementType() const { + return Ctx.getType(cast(LLVMTy)->getElementType()); +} +VectorType *VectorType::getInteger(VectorType *VTy) { + return cast(VTy->getContext().getType( + llvm::VectorType::getInteger(cast(VTy->LLVMTy)))); +} +VectorType *VectorType::getExtendedElementVectorType(VectorType *VTy) { + return cast( + VTy->getContext().getType(llvm::VectorType::getExtendedElementVectorType( + cast(VTy->LLVMTy)))); +} +VectorType *VectorType::getTruncatedElementVectorType(VectorType *VTy) { + return cast( + VTy->getContext().getType(llvm::VectorType::getTruncatedElementVectorType( + cast(VTy->LLVMTy)))); +} +VectorType *VectorType::getSubdividedVectorType(VectorType *VTy, + int NumSubdivs) { + return cast( + VTy->getContext().getType(llvm::VectorType::getSubdividedVectorType( + cast(VTy->LLVMTy), NumSubdivs))); +} +VectorType *VectorType::getHalfElementsVectorType(VectorType *VTy) { + return cast( + VTy->getContext().getType(llvm::VectorType::getHalfElementsVectorType( + cast(VTy->LLVMTy)))); +} +VectorType *VectorType::getDoubleElementsVectorType(VectorType *VTy) { + return cast( + VTy->getContext().getType(llvm::VectorType::getDoubleElementsVectorType( + cast(VTy->LLVMTy)))); +} +bool VectorType::isValidElementType(Type *ElemTy) { + return llvm::VectorType::isValidElementType(ElemTy->LLVMTy); +} + IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) { return cast( Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits))); diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp index 36ef0cf8e52911..e4f9235c1ef3ca 100644 --- a/llvm/unittests/SandboxIR/TypesTest.cpp +++ b/llvm/unittests/SandboxIR/TypesTest.cpp @@ -268,16 +268,59 @@ define void @foo({i32, i8} %v0) { TEST_F(SandboxTypeTest, VectorType) { parseIR(C, R"IR( -define void @foo(<2 x i8> %v0) { +define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) { ret void } )IR"); llvm::Function *LLVMF = &*M->getFunction("foo"); sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); - // Check classof(), creation. - [[maybe_unused]] auto *VecTy = - cast(F->getArg(0)->getType()); + // Check classof(), creation, accessors + auto *VecTy = cast(F->getArg(0)->getType()); + EXPECT_TRUE(VecTy->getElementType()->isIntegerTy(16)); + EXPECT_EQ(VecTy->getElementCount(), ElementCount::getFixed(4)); + + // get(ElementType, NumElements, Scalable) + EXPECT_EQ(sandboxir::VectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4, + /*Scalable=*/false), + F->getArg(0)->getType()); + // get(ElementType, Other) + EXPECT_EQ(sandboxir::VectorType::get( + sandboxir::Type::getInt16Ty(Ctx), + cast(F->getArg(0)->getType())), + F->getArg(0)->getType()); + auto *FVecTy = cast(F->getArg(1)->getType()); + EXPECT_TRUE(FVecTy->getElementType()->isFloatTy()); + // getInteger + auto *IVecTy = sandboxir::VectorType::getInteger(FVecTy); + EXPECT_TRUE(IVecTy->getElementType()->isIntegerTy(32)); + EXPECT_EQ(IVecTy->getElementCount(), FVecTy->getElementCount()); + // getExtendedElementCountVectorType + auto *ExtVecTy = sandboxir::VectorType::getExtendedElementVectorType(IVecTy); + EXPECT_TRUE(ExtVecTy->getElementType()->isIntegerTy(64)); + EXPECT_EQ(ExtVecTy->getElementCount(), VecTy->getElementCount()); + // getTruncatedElementVectorType + auto *TruncVecTy = + sandboxir::VectorType::getTruncatedElementVectorType(IVecTy); + EXPECT_TRUE(TruncVecTy->getElementType()->isIntegerTy(16)); + EXPECT_EQ(TruncVecTy->getElementCount(), VecTy->getElementCount()); + // getSubdividedVectorType + auto *SubVecTy = sandboxir::VectorType::getSubdividedVectorType(VecTy, 1); + EXPECT_TRUE(SubVecTy->getElementType()->isIntegerTy(8)); + EXPECT_EQ(SubVecTy->getElementCount(), ElementCount::getFixed(8)); + // getHalfElementsVectorType + auto *HalfVecTy = sandboxir::VectorType::getHalfElementsVectorType(VecTy); + EXPECT_TRUE(HalfVecTy->getElementType()->isIntegerTy(16)); + EXPECT_EQ(HalfVecTy->getElementCount(), ElementCount::getFixed(2)); + // getDoubleElementsVectorType + auto *DoubleVecTy = sandboxir::VectorType::getDoubleElementsVectorType(VecTy); + EXPECT_TRUE(DoubleVecTy->getElementType()->isIntegerTy(16)); + EXPECT_EQ(DoubleVecTy->getElementCount(), ElementCount::getFixed(8)); + // isValidElementType + auto *I8Type = F->getArg(2)->getType(); + EXPECT_TRUE(I8Type->isIntegerTy()); + EXPECT_TRUE(sandboxir::VectorType::isValidElementType(I8Type)); + EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy)); } TEST_F(SandboxTypeTest, FunctionType) {