Skip to content

Commit

Permalink
Add collect offset c api function (#1465)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Oct 6, 2023
1 parent 08288a0 commit 1ada437
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 1 deletion.
18 changes: 18 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,24 @@ LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC,
LLVMTypeRef EnzymeAllocaType(LLVMValueRef V) {
return wrap(cast<AllocaInst>(unwrap(V))->getAllocatedType());
}
LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r,
LLVMTypeRef T_r) {
IRBuilder<> &B = *unwrap(B_r);
auto T = cast<IntegerType>(unwrap(T_r));
auto width = T->getBitWidth();
auto gep = cast<GetElementPtrInst>(unwrap(V_r));
auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout();

MapVector<Value *, APInt> VariableOffsets;
APInt Offset(width, 0);
bool success = collectOffset(gep, DL, width, VariableOffsets, Offset);
assert(success);
Value *start = ConstantInt::get(T, Offset);
for (auto &pair : VariableOffsets)
start = B.CreateAdd(
start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second)));
return wrap(start);
}
}

static size_t num_rooting(llvm::Type *T, llvm::Function *F) {
Expand Down
65 changes: 65 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -2648,3 +2649,67 @@ CountTrackedPointers::CountTrackedPointers(Type *T) {
if (count == 0)
all = false;
}

bool collectOffset(GetElementPtrInst *gep, const DataLayout &DL,
unsigned BitWidth,
MapVector<Value *, APInt> &VariableOffsets,
APInt &ConstantOffset) {
#if LLVM_VERSION_MAJOR >= 13
return cast<GEPOperator>(gep)->collectOffset(DL, BitWidth, VariableOffsets,
ConstantOffset);
#else
assert(BitWidth == DL.getIndexSizeInBits(gep->getPointerAddressSpace()) &&
"The offset bit width does not match DL specification.");

auto CollectConstantOffset = [&](APInt Index, uint64_t Size) {
Index = Index.sextOrTrunc(BitWidth);
APInt IndexedSize = APInt(BitWidth, Size);
ConstantOffset += Index * IndexedSize;
};

for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep);
GTI != GTE; ++GTI) {
// Scalable vectors are multiplied by a runtime constant.
bool ScalableType = isa<ScalableVectorType>(GTI.getIndexedType());

Value *V = GTI.getOperand();
StructType *STy = GTI.getStructTypeOrNull();
// Handle ConstantInt if possible.
if (auto ConstOffset = dyn_cast<ConstantInt>(V)) {
if (ConstOffset->isZero())
continue;
// If the type is scalable and the constant is not zero (vscale * n * 0 =
// 0) bailout.
// TODO: If the runtime value is accessible at any point before DWARF
// emission, then we could potentially keep a forward reference to it
// in the debug value to be filled in later.
if (ScalableType)
return false;
// Handle a struct index, which adds its field offset to the pointer.
if (STy) {
unsigned ElementIdx = ConstOffset->getZExtValue();
const StructLayout *SL = DL.getStructLayout(STy);
// Element offset is in bytes.
CollectConstantOffset(APInt(BitWidth, SL->getElementOffset(ElementIdx)),
1);
continue;
}
CollectConstantOffset(ConstOffset->getValue(),
DL.getTypeAllocSize(GTI.getIndexedType()));
continue;
}

if (STy || ScalableType)
return false;
APInt IndexedSize =
APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType()));
// Insert an initial offset of 0 for V iff none exists already, then
// increment the offset by IndexedSize.
if (IndexedSize != 0) {
VariableOffsets.insert({V, APInt(BitWidth, 0)});
VariableOffsets[V] += IndexedSize;
}
}
return true;
#endif
}
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef ENZYME_UTILS_H
#define ENZYME_UTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"

Expand Down Expand Up @@ -1754,4 +1755,8 @@ static inline bool isSpecialPtr(llvm::Type *Ty) {
return AddressSpace::FirstSpecial <= AS && AS <= AddressSpace::LastSpecial;
}

bool collectOffset(llvm::GetElementPtrInst *gep, const llvm::DataLayout &DL,
unsigned BitWidth,
llvm::MapVector<llvm::Value *, llvm::APInt> &VariableOffsets,
llvm::APInt &ConstantOffset);
#endif
2 changes: 1 addition & 1 deletion enzyme/test/test_find_package/main.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <stdio.h>
int printf(const char*, ...);

extern double __enzyme_autodiff(void*, double);

Expand Down

0 comments on commit 1ada437

Please sign in to comment.