Skip to content

Commit

Permalink
[C API] Add accessors for new no-wrap flags on GEP instructions (#97970)
Browse files Browse the repository at this point in the history
Summary:
Previously, only the inbounds flag was accessible via the C API. This
adds support for any no-wrap related flags (currently nuw and nusw).

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251605
  • Loading branch information
Benjins authored and yuxuanchen1997 committed Jul 25, 2024
1 parent ff98b7e commit 6a53959
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 9 deletions.
8 changes: 8 additions & 0 deletions llvm/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,14 @@ They are described in detail in the `debug info migration guide <https://llvm.or
* ``LLVMGetTargetExtTypeNumTypeParams``/``LLVMGetTargetExtTypeTypeParam``
* ``LLVMGetTargetExtTypeNumIntParams``/``LLVMGetTargetExtTypeIntParam``

* Added the following functions for accessing/setting the no-wrap flags for a
GetElementPtr instruction:

* ``LLVMBuildGEPWithNoWrapFlags``
* ``LLVMConstGEPWithNoWrapFlags``
* ``LLVMGEPGetNoWrapFlags``
* ``LLVMGEPSetNoWrapFlags``

Changes to the CodeGen infrastructure
-------------------------------------

Expand Down
50 changes: 50 additions & 0 deletions llvm/include/llvm-c/Core.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,20 @@ enum {
*/
typedef unsigned LLVMFastMathFlags;

enum {
LLVMGEPFlagInBounds = (1 << 0),
LLVMGEPFlagNUSW = (1 << 1),
LLVMGEPFlagNUW = (1 << 2),
};

/**
* Flags that constrain the allowed wrap semantics of a getelementptr
* instruction.
*
* See https://llvm.org/docs/LangRef.html#getelementptr-instruction
*/
typedef unsigned LLVMGEPNoWrapFlags;

/**
* @}
*/
Expand Down Expand Up @@ -2395,6 +2409,17 @@ LLVMValueRef LLVMConstGEP2(LLVMTypeRef Ty, LLVMValueRef ConstantVal,
LLVMValueRef LLVMConstInBoundsGEP2(LLVMTypeRef Ty, LLVMValueRef ConstantVal,
LLVMValueRef *ConstantIndices,
unsigned NumIndices);
/**
* Creates a constant GetElementPtr expression. Similar to LLVMConstGEP2, but
* allows specifying the no-wrap flags.
*
* @see llvm::ConstantExpr::getGetElementPtr()
*/
LLVMValueRef LLVMConstGEPWithNoWrapFlags(LLVMTypeRef Ty,
LLVMValueRef ConstantVal,
LLVMValueRef *ConstantIndices,
unsigned NumIndices,
LLVMGEPNoWrapFlags NoWrapFlags);
LLVMValueRef LLVMConstTrunc(LLVMValueRef ConstantVal, LLVMTypeRef ToType);
LLVMValueRef LLVMConstPtrToInt(LLVMValueRef ConstantVal, LLVMTypeRef ToType);
LLVMValueRef LLVMConstIntToPtr(LLVMValueRef ConstantVal, LLVMTypeRef ToType);
Expand Down Expand Up @@ -3904,6 +3929,20 @@ void LLVMSetIsInBounds(LLVMValueRef GEP, LLVMBool InBounds);
*/
LLVMTypeRef LLVMGetGEPSourceElementType(LLVMValueRef GEP);

/**
* Get the no-wrap related flags for the given GEP instruction.
*
* @see llvm::GetElementPtrInst::getNoWrapFlags
*/
LLVMGEPNoWrapFlags LLVMGEPGetNoWrapFlags(LLVMValueRef GEP);

/**
* Set the no-wrap related flags for the given GEP instruction.
*
* @see llvm::GetElementPtrInst::setNoWrapFlags
*/
void LLVMGEPSetNoWrapFlags(LLVMValueRef GEP, LLVMGEPNoWrapFlags NoWrapFlags);

/**
* @}
*/
Expand Down Expand Up @@ -4363,6 +4402,17 @@ LLVMValueRef LLVMBuildGEP2(LLVMBuilderRef B, LLVMTypeRef Ty,
LLVMValueRef LLVMBuildInBoundsGEP2(LLVMBuilderRef B, LLVMTypeRef Ty,
LLVMValueRef Pointer, LLVMValueRef *Indices,
unsigned NumIndices, const char *Name);
/**
* Creates a GetElementPtr instruction. Similar to LLVMBuildGEP2, but allows
* specifying the no-wrap flags.
*
* @see llvm::IRBuilder::CreateGEP()
*/
LLVMValueRef LLVMBuildGEPWithNoWrapFlags(LLVMBuilderRef B, LLVMTypeRef Ty,
LLVMValueRef Pointer,
LLVMValueRef *Indices,
unsigned NumIndices, const char *Name,
LLVMGEPNoWrapFlags NoWrapFlags);
LLVMValueRef LLVMBuildStructGEP2(LLVMBuilderRef B, LLVMTypeRef Ty,
LLVMValueRef Pointer, unsigned Idx,
const char *Name);
Expand Down
58 changes: 58 additions & 0 deletions llvm/lib/IR/Core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,32 @@ static int map_from_llvmopcode(LLVMOpcode code)
llvm_unreachable("Unhandled Opcode.");
}

/*-- GEP wrap flag conversions */

static GEPNoWrapFlags mapFromLLVMGEPNoWrapFlags(LLVMGEPNoWrapFlags GEPFlags) {
GEPNoWrapFlags NewGEPFlags;
if ((GEPFlags & LLVMGEPFlagInBounds) != 0)
NewGEPFlags |= GEPNoWrapFlags::inBounds();
if ((GEPFlags & LLVMGEPFlagNUSW) != 0)
NewGEPFlags |= GEPNoWrapFlags::noUnsignedSignedWrap();
if ((GEPFlags & LLVMGEPFlagNUW) != 0)
NewGEPFlags |= GEPNoWrapFlags::noUnsignedWrap();

return NewGEPFlags;
}

static LLVMGEPNoWrapFlags mapToLLVMGEPNoWrapFlags(GEPNoWrapFlags GEPFlags) {
LLVMGEPNoWrapFlags NewGEPFlags = 0;
if (GEPFlags.isInBounds())
NewGEPFlags |= LLVMGEPFlagInBounds;
if (GEPFlags.hasNoUnsignedSignedWrap())
NewGEPFlags |= LLVMGEPFlagNUSW;
if (GEPFlags.hasNoUnsignedWrap())
NewGEPFlags |= LLVMGEPFlagNUW;

return NewGEPFlags;
}

/*--.. Constant expressions ................................................--*/

LLVMOpcode LLVMGetConstOpcode(LLVMValueRef ConstantVal) {
Expand Down Expand Up @@ -1789,6 +1815,18 @@ LLVMValueRef LLVMConstInBoundsGEP2(LLVMTypeRef Ty, LLVMValueRef ConstantVal,
return wrap(ConstantExpr::getInBoundsGetElementPtr(unwrap(Ty), Val, IdxList));
}

LLVMValueRef LLVMConstGEPWithNoWrapFlags(LLVMTypeRef Ty,
LLVMValueRef ConstantVal,
LLVMValueRef *ConstantIndices,
unsigned NumIndices,
LLVMGEPNoWrapFlags NoWrapFlags) {
ArrayRef<Constant *> IdxList(unwrap<Constant>(ConstantIndices, NumIndices),
NumIndices);
Constant *Val = unwrap<Constant>(ConstantVal);
return wrap(ConstantExpr::getGetElementPtr(
unwrap(Ty), Val, IdxList, mapFromLLVMGEPNoWrapFlags(NoWrapFlags)));
}

LLVMValueRef LLVMConstTrunc(LLVMValueRef ConstantVal, LLVMTypeRef ToType) {
return wrap(ConstantExpr::getTrunc(unwrap<Constant>(ConstantVal),
unwrap(ToType)));
Expand Down Expand Up @@ -3102,6 +3140,16 @@ LLVMTypeRef LLVMGetGEPSourceElementType(LLVMValueRef GEP) {
return wrap(unwrap<GEPOperator>(GEP)->getSourceElementType());
}

LLVMGEPNoWrapFlags LLVMGEPGetNoWrapFlags(LLVMValueRef GEP) {
GEPOperator *GEPOp = unwrap<GEPOperator>(GEP);
return mapToLLVMGEPNoWrapFlags(GEPOp->getNoWrapFlags());
}

void LLVMGEPSetNoWrapFlags(LLVMValueRef GEP, LLVMGEPNoWrapFlags NoWrapFlags) {
GetElementPtrInst *GEPInst = unwrap<GetElementPtrInst>(GEP);
GEPInst->setNoWrapFlags(mapFromLLVMGEPNoWrapFlags(NoWrapFlags));
}

/*--.. Operations on phi nodes .............................................--*/

void LLVMAddIncoming(LLVMValueRef PhiNode, LLVMValueRef *IncomingValues,
Expand Down Expand Up @@ -3902,6 +3950,16 @@ LLVMValueRef LLVMBuildInBoundsGEP2(LLVMBuilderRef B, LLVMTypeRef Ty,
unwrap(B)->CreateInBoundsGEP(unwrap(Ty), unwrap(Pointer), IdxList, Name));
}

LLVMValueRef LLVMBuildGEPWithNoWrapFlags(LLVMBuilderRef B, LLVMTypeRef Ty,
LLVMValueRef Pointer,
LLVMValueRef *Indices,
unsigned NumIndices, const char *Name,
LLVMGEPNoWrapFlags NoWrapFlags) {
ArrayRef<Value *> IdxList(unwrap(Indices), NumIndices);
return wrap(unwrap(B)->CreateGEP(unwrap(Ty), unwrap(Pointer), IdxList, Name,
mapFromLLVMGEPNoWrapFlags(NoWrapFlags)));
}

LLVMValueRef LLVMBuildStructGEP2(LLVMBuilderRef B, LLVMTypeRef Ty,
LLVMValueRef Pointer, unsigned Idx,
const char *Name) {
Expand Down
12 changes: 12 additions & 0 deletions llvm/test/Bindings/llvm-c/echo.ll
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ module asm "classical GAS"

@const_gep = global ptr getelementptr (i32, ptr @var, i64 2)
@const_inbounds_gep = global ptr getelementptr inbounds (i32, ptr @var, i64 1)
@const_gep_nuw = global ptr getelementptr nuw (i32, ptr @var, i64 1)
@const_gep_nusw = global ptr getelementptr nusw (i32, ptr @var, i64 1)
@const_gep_nuw_inbounds = global ptr getelementptr nuw inbounds (i32, ptr @var, i64 1)

@aliased1 = alias i32, ptr @var
@aliased2 = internal alias i32, ptr @var
Expand Down Expand Up @@ -391,6 +394,15 @@ bb_03:
ret void
}

define ptr @test_gep_no_wrap_flags(ptr %0) {
%gep.1 = getelementptr i8, ptr %0, i32 4
%gep.inbounds = getelementptr inbounds i8, ptr %0, i32 4
%gep.nuw = getelementptr nuw i8, ptr %0, i32 4
%gep.nuw.inbounds = getelementptr inbounds nuw i8, ptr %0, i32 4
%gep.nusw = getelementptr nusw i8, ptr %0, i32 4
ret ptr %gep.nusw
}

!llvm.dbg.cu = !{!0, !2}
!llvm.module.flags = !{!3}

Expand Down
16 changes: 7 additions & 9 deletions llvm/tools/llvm-c-test/echo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,10 +412,9 @@ static LLVMValueRef clone_constant_impl(LLVMValueRef Cst, LLVMModuleRef M) {
SmallVector<LLVMValueRef, 8> Idx;
for (int i = 1; i <= NumIdx; i++)
Idx.push_back(clone_constant(LLVMGetOperand(Cst, i), M));
if (LLVMIsInBounds(Cst))
return LLVMConstInBoundsGEP2(ElemTy, Ptr, Idx.data(), NumIdx);
else
return LLVMConstGEP2(ElemTy, Ptr, Idx.data(), NumIdx);

return LLVMConstGEPWithNoWrapFlags(ElemTy, Ptr, Idx.data(), NumIdx,
LLVMGEPGetNoWrapFlags(Cst));
}
default:
fprintf(stderr, "%d is not a supported opcode for constant expressions\n",
Expand Down Expand Up @@ -767,11 +766,10 @@ struct FunCloner {
int NumIdx = LLVMGetNumIndices(Src);
for (int i = 1; i <= NumIdx; i++)
Idx.push_back(CloneValue(LLVMGetOperand(Src, i)));
if (LLVMIsInBounds(Src))
Dst = LLVMBuildInBoundsGEP2(Builder, ElemTy, Ptr, Idx.data(), NumIdx,
Name);
else
Dst = LLVMBuildGEP2(Builder, ElemTy, Ptr, Idx.data(), NumIdx, Name);

Dst = LLVMBuildGEPWithNoWrapFlags(Builder, ElemTy, Ptr, Idx.data(),
NumIdx, Name,
LLVMGEPGetNoWrapFlags(Src));
break;
}
case LLVMAtomicRMW: {
Expand Down

0 comments on commit 6a53959

Please sign in to comment.