Skip to content

Commit

Permalink
[SandboxIR] Add missing VectorType functions (llvm#107650)
Browse files Browse the repository at this point in the history
Fills in many missing functions from VectorType
  • Loading branch information
Sterling-Augustine authored Sep 9, 2024
1 parent 53a81d4 commit 6f8d278
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 8 deletions.
27 changes: 24 additions & 3 deletions llvm/include/llvm/SandboxIR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class Type {
friend class ConstantArray; // For LLVMTy.
friend class ConstantStruct; // For LLVMTy.
friend class ConstantVector; // For LLVMTy.
friend class CmpInst; // For LLVMTy. TODO: Cleanup after sandboxir::VectorType
// is more complete.
friend class CmpInst; // For LLVMTy. TODO: Cleanup after
// sandboxir::VectorType is more complete.

// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
Expand Down Expand Up @@ -317,7 +317,28 @@ class StructType : public Type {
class VectorType : public Type {
public:
static VectorType *get(Type *ElementType, ElementCount EC);
// TODO: add missing functions
static VectorType *get(Type *ElementType, unsigned NumElements,
bool Scalable) {
return VectorType::get(ElementType,
ElementCount::get(NumElements, Scalable));
}
Type *getElementType() const;

static VectorType *get(Type *ElementType, const VectorType *Other) {
return VectorType::get(ElementType, Other->getElementCount());
}

inline ElementCount getElementCount() const {
return cast<llvm::VectorType>(LLVMTy)->getElementCount();
}
static VectorType *getInteger(VectorType *VTy);
static VectorType *getExtendedElementVectorType(VectorType *VTy);
static VectorType *getTruncatedElementVectorType(VectorType *VTy);
static VectorType *getSubdividedVectorType(VectorType *VTy, int NumSubdivs);
static VectorType *getHalfElementsVectorType(VectorType *VTy);
static VectorType *getDoubleElementsVectorType(VectorType *VTy);
static bool isValidElementType(Type *ElemTy);

static bool classof(const Type *From) {
return isa<llvm::VectorType>(From->LLVMTy);
}
Expand Down
38 changes: 37 additions & 1 deletion llvm/lib/SandboxIR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ Type *Type::getDoubleTy(Context &Ctx) {
Type *Type::getFloatTy(Context &Ctx) {
return Ctx.getType(llvm::Type::getFloatTy(Ctx.LLVMCtx));
}

PointerType *PointerType::get(Type *ElementType, unsigned AddressSpace) {
return cast<PointerType>(ElementType->getContext().getType(
llvm::PointerType::get(ElementType->LLVMTy, AddressSpace)));
Expand Down Expand Up @@ -67,6 +66,43 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
llvm::VectorType::get(ElementType->LLVMTy, EC)));
}

Type *VectorType::getElementType() const {
return Ctx.getType(cast<llvm::VectorType>(LLVMTy)->getElementType());
}
VectorType *VectorType::getInteger(VectorType *VTy) {
return cast<VectorType>(VTy->getContext().getType(
llvm::VectorType::getInteger(cast<llvm::VectorType>(VTy->LLVMTy))));
}
VectorType *VectorType::getExtendedElementVectorType(VectorType *VTy) {
return cast<VectorType>(
VTy->getContext().getType(llvm::VectorType::getExtendedElementVectorType(
cast<llvm::VectorType>(VTy->LLVMTy))));
}
VectorType *VectorType::getTruncatedElementVectorType(VectorType *VTy) {
return cast<VectorType>(
VTy->getContext().getType(llvm::VectorType::getTruncatedElementVectorType(
cast<llvm::VectorType>(VTy->LLVMTy))));
}
VectorType *VectorType::getSubdividedVectorType(VectorType *VTy,
int NumSubdivs) {
return cast<VectorType>(
VTy->getContext().getType(llvm::VectorType::getSubdividedVectorType(
cast<llvm::VectorType>(VTy->LLVMTy), NumSubdivs)));
}
VectorType *VectorType::getHalfElementsVectorType(VectorType *VTy) {
return cast<VectorType>(
VTy->getContext().getType(llvm::VectorType::getHalfElementsVectorType(
cast<llvm::VectorType>(VTy->LLVMTy))));
}
VectorType *VectorType::getDoubleElementsVectorType(VectorType *VTy) {
return cast<VectorType>(
VTy->getContext().getType(llvm::VectorType::getDoubleElementsVectorType(
cast<llvm::VectorType>(VTy->LLVMTy))));
}
bool VectorType::isValidElementType(Type *ElemTy) {
return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
}

IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
return cast<IntegerType>(
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
Expand Down
51 changes: 47 additions & 4 deletions llvm/unittests/SandboxIR/TypesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,16 +268,59 @@ define void @foo({i32, i8} %v0) {

TEST_F(SandboxTypeTest, VectorType) {
parseIR(C, R"IR(
define void @foo(<2 x i8> %v0) {
define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
// Check classof(), creation.
[[maybe_unused]] auto *VecTy =
cast<sandboxir::VectorType>(F->getArg(0)->getType());
// Check classof(), creation, accessors
auto *VecTy = cast<sandboxir::VectorType>(F->getArg(0)->getType());
EXPECT_TRUE(VecTy->getElementType()->isIntegerTy(16));
EXPECT_EQ(VecTy->getElementCount(), ElementCount::getFixed(4));

// get(ElementType, NumElements, Scalable)
EXPECT_EQ(sandboxir::VectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4,
/*Scalable=*/false),
F->getArg(0)->getType());
// get(ElementType, Other)
EXPECT_EQ(sandboxir::VectorType::get(
sandboxir::Type::getInt16Ty(Ctx),
cast<sandboxir::VectorType>(F->getArg(0)->getType())),
F->getArg(0)->getType());
auto *FVecTy = cast<sandboxir::VectorType>(F->getArg(1)->getType());
EXPECT_TRUE(FVecTy->getElementType()->isFloatTy());
// getInteger
auto *IVecTy = sandboxir::VectorType::getInteger(FVecTy);
EXPECT_TRUE(IVecTy->getElementType()->isIntegerTy(32));
EXPECT_EQ(IVecTy->getElementCount(), FVecTy->getElementCount());
// getExtendedElementCountVectorType
auto *ExtVecTy = sandboxir::VectorType::getExtendedElementVectorType(IVecTy);
EXPECT_TRUE(ExtVecTy->getElementType()->isIntegerTy(64));
EXPECT_EQ(ExtVecTy->getElementCount(), VecTy->getElementCount());
// getTruncatedElementVectorType
auto *TruncVecTy =
sandboxir::VectorType::getTruncatedElementVectorType(IVecTy);
EXPECT_TRUE(TruncVecTy->getElementType()->isIntegerTy(16));
EXPECT_EQ(TruncVecTy->getElementCount(), VecTy->getElementCount());
// getSubdividedVectorType
auto *SubVecTy = sandboxir::VectorType::getSubdividedVectorType(VecTy, 1);
EXPECT_TRUE(SubVecTy->getElementType()->isIntegerTy(8));
EXPECT_EQ(SubVecTy->getElementCount(), ElementCount::getFixed(8));
// getHalfElementsVectorType
auto *HalfVecTy = sandboxir::VectorType::getHalfElementsVectorType(VecTy);
EXPECT_TRUE(HalfVecTy->getElementType()->isIntegerTy(16));
EXPECT_EQ(HalfVecTy->getElementCount(), ElementCount::getFixed(2));
// getDoubleElementsVectorType
auto *DoubleVecTy = sandboxir::VectorType::getDoubleElementsVectorType(VecTy);
EXPECT_TRUE(DoubleVecTy->getElementType()->isIntegerTy(16));
EXPECT_EQ(DoubleVecTy->getElementCount(), ElementCount::getFixed(8));
// isValidElementType
auto *I8Type = F->getArg(2)->getType();
EXPECT_TRUE(I8Type->isIntegerTy());
EXPECT_TRUE(sandboxir::VectorType::isValidElementType(I8Type));
EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy));
}

TEST_F(SandboxTypeTest, FunctionType) {
Expand Down

0 comments on commit 6f8d278

Please sign in to comment.