Skip to content

Commit

Permalink
Merge branch 'main' into controlflowfixpart1try2
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Oct 6, 2023
2 parents dfe725a + 1ada437 commit b0f4384
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 6 deletions.
18 changes: 18 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,24 @@ LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC,
LLVMTypeRef EnzymeAllocaType(LLVMValueRef V) {
return wrap(cast<AllocaInst>(unwrap(V))->getAllocatedType());
}
LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r,
LLVMTypeRef T_r) {
IRBuilder<> &B = *unwrap(B_r);
auto T = cast<IntegerType>(unwrap(T_r));
auto width = T->getBitWidth();
auto gep = cast<GetElementPtrInst>(unwrap(V_r));
auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout();

MapVector<Value *, APInt> VariableOffsets;
APInt Offset(width, 0);
bool success = collectOffset(gep, DL, width, VariableOffsets, Offset);
assert(success);
Value *start = ConstantInt::get(T, Offset);
for (auto &pair : VariableOffsets)
start = B.CreateAdd(
start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second)));
return wrap(start);
}
}

static size_t num_rooting(llvm::Type *T, llvm::Function *F) {
Expand Down
37 changes: 37 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,43 @@ void attributeKnownFunctions(llvm::Function &F) {
F.addParamAttr(2, Attribute::WriteOnly);
F.addParamAttr(2, Attribute::NoCapture);
}
// Map of MPI function name to the arg index of its type argument
std::map<std::string, int> MPI_TYPE_ARGS = {
{"MPI_Send", 2}, {"MPI_Ssend", 2}, {"MPI_Bsend", 2},
{"MPI_Recv", 2}, {"MPI_Brecv", 2}, {"PMPI_Send", 2},
{"PMPI_Ssend", 2}, {"PMPI_Bsend", 2}, {"PMPI_Recv", 2},
{"PMPI_Brecv", 2},

{"MPI_Isend", 2}, {"MPI_Irecv", 2}, {"PMPI_Isend", 2},
{"PMPI_Irecv", 2},

{"MPI_Reduce", 3}, {"PMPI_Reduce", 3},

{"MPI_Allreduce", 3}, {"PMPI_Allreduce", 3}};
{
auto found = MPI_TYPE_ARGS.find(F.getName().str());
if (found != MPI_TYPE_ARGS.end()) {
for (auto user : F.users()) {
if (auto CI = dyn_cast<CallBase>(user))
if (CI->getCalledFunction() == &F) {
if (Constant *C =
dyn_cast<Constant>(CI->getArgOperand(found->second))) {
while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
C = CE->getOperand(0);
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_cxx_bool") {
CI->addAttribute(
AttributeList::FunctionIndex,
Attribute::get(CI->getContext(), "enzyme_inactive"));
}
}
}
}
}
}
}

if (F.getName() == "omp_get_max_threads" ||
F.getName() == "omp_get_thread_num") {
#if LLVM_VERSION_MAJOR >= 16
Expand Down
28 changes: 23 additions & 5 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/ModuleSlotTracker.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"

Expand Down Expand Up @@ -1887,10 +1888,11 @@ void TypeAnalyzer::visitExtractElementInst(ExtractElementInst &I) {
auto &dl = fntypeinfo.Function->getParent()->getDataLayout();
VectorType *vecType = cast<VectorType>(I.getVectorOperand()->getType());

size_t size = (dl.getTypeSizeInBits(vecType->getElementType()) + 7) / 8;
size_t bitsize = dl.getTypeSizeInBits(vecType->getElementType());
size_t size = (bitsize + 7) / 8;

if (auto CI = dyn_cast<ConstantInt>(I.getIndexOperand())) {
size_t off = CI->getZExtValue() * size;
size_t off = (CI->getZExtValue() * bitsize) / 8;

if (direction & DOWN)
updateAnalysis(&I,
Expand Down Expand Up @@ -2168,8 +2170,15 @@ void TypeAnalyzer::visitInsertValueInst(InsertValueInst &I) {

void TypeAnalyzer::dump(llvm::raw_ostream &ss) {
ss << "<analysis>\n";
// We don't care about correct MD node numbering here.
ModuleSlotTracker MST(fntypeinfo.Function->getParent(),
/*ShouldInitializeAllMetadata*/ false);
for (auto &pair : analysis) {
ss << *pair.first << ": " << pair.second.str()
if (auto F = dyn_cast<Function>(pair.first))
ss << "@" << F->getName();
else
pair.first->print(ss, MST);
ss << ": " << pair.second.str()
<< ", intvals: " << to_string(knownIntegralValues(pair.first)) << "\n";
}
ss << "</analysis>\n";
Expand Down Expand Up @@ -4023,6 +4032,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
buf.insert({0}, Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
buf.insert({0}, Type::getFloatTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_cxx_bool") {
buf.insert({0}, BaseType::Integer);
}
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
// MPICH
Expand All @@ -4043,7 +4054,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;
}
if (funcName == "MPI_Isend" || funcName == "MPI_Irecv") {
if (funcName == "MPI_Isend" || funcName == "MPI_Irecv" ||
funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") {
TypeTree buf = TypeTree(BaseType::Pointer);

if (Constant *C = dyn_cast<Constant>(call.getOperand(2))) {
Expand All @@ -4055,6 +4067,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
buf.insert({0}, Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
buf.insert({0}, Type::getFloatTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_cxx_bool") {
buf.insert({0}, BaseType::Integer);
}
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
// MPICH
Expand Down Expand Up @@ -4129,6 +4143,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
buf.insert({0}, Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
buf.insert({0}, Type::getFloatTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_cxx_bool") {
buf.insert({0}, BaseType::Integer);
}
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
// MPICH
Expand Down Expand Up @@ -4156,7 +4172,7 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;
}
if (funcName == "MPI_Allreduce") {
if (funcName == "MPI_Allreduce" || funcName == "PMPI_Allreduce") {
TypeTree buf = TypeTree(BaseType::Pointer);

if (Constant *C = dyn_cast<Constant>(call.getOperand(3))) {
Expand All @@ -4168,6 +4184,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
buf.insert({0}, Type::getDoubleTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_float") {
buf.insert({0}, Type::getFloatTy(C->getContext()));
} else if (GV->getName() == "ompi_mpi_cxx_bool") {
buf.insert({0}, BaseType::Integer);
}
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
// MPICH
Expand Down
65 changes: 65 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -2648,3 +2649,67 @@ CountTrackedPointers::CountTrackedPointers(Type *T) {
if (count == 0)
all = false;
}

bool collectOffset(GetElementPtrInst *gep, const DataLayout &DL,
unsigned BitWidth,
MapVector<Value *, APInt> &VariableOffsets,
APInt &ConstantOffset) {
#if LLVM_VERSION_MAJOR >= 13
return cast<GEPOperator>(gep)->collectOffset(DL, BitWidth, VariableOffsets,
ConstantOffset);
#else
assert(BitWidth == DL.getIndexSizeInBits(gep->getPointerAddressSpace()) &&
"The offset bit width does not match DL specification.");

auto CollectConstantOffset = [&](APInt Index, uint64_t Size) {
Index = Index.sextOrTrunc(BitWidth);
APInt IndexedSize = APInt(BitWidth, Size);
ConstantOffset += Index * IndexedSize;
};

for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep);
GTI != GTE; ++GTI) {
// Scalable vectors are multiplied by a runtime constant.
bool ScalableType = isa<ScalableVectorType>(GTI.getIndexedType());

Value *V = GTI.getOperand();
StructType *STy = GTI.getStructTypeOrNull();
// Handle ConstantInt if possible.
if (auto ConstOffset = dyn_cast<ConstantInt>(V)) {
if (ConstOffset->isZero())
continue;
// If the type is scalable and the constant is not zero (vscale * n * 0 =
// 0) bailout.
// TODO: If the runtime value is accessible at any point before DWARF
// emission, then we could potentially keep a forward reference to it
// in the debug value to be filled in later.
if (ScalableType)
return false;
// Handle a struct index, which adds its field offset to the pointer.
if (STy) {
unsigned ElementIdx = ConstOffset->getZExtValue();
const StructLayout *SL = DL.getStructLayout(STy);
// Element offset is in bytes.
CollectConstantOffset(APInt(BitWidth, SL->getElementOffset(ElementIdx)),
1);
continue;
}
CollectConstantOffset(ConstOffset->getValue(),
DL.getTypeAllocSize(GTI.getIndexedType()));
continue;
}

if (STy || ScalableType)
return false;
APInt IndexedSize =
APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType()));
// Insert an initial offset of 0 for V iff none exists already, then
// increment the offset by IndexedSize.
if (IndexedSize != 0) {
VariableOffsets.insert({V, APInt(BitWidth, 0)});
VariableOffsets[V] += IndexedSize;
}
}
return true;
#endif
}
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef ENZYME_UTILS_H
#define ENZYME_UTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"

Expand Down Expand Up @@ -1754,4 +1755,8 @@ static inline bool isSpecialPtr(llvm::Type *Ty) {
return AddressSpace::FirstSpecial <= AS && AS <= AddressSpace::LastSpecial;
}

bool collectOffset(llvm::GetElementPtrInst *gep, const llvm::DataLayout &DL,
unsigned BitWidth,
llvm::MapVector<llvm::Value *, llvm::APInt> &VariableOffsets,
llvm::APInt &ConstantOffset);
#endif
23 changes: 23 additions & 0 deletions enzyme/test/TypeAnalysis/veci1.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=f -o /dev/null | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=f -S -o /dev/null | FileCheck %s

; ModuleID = 'test.c'
source_filename = "test.c"
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

define dso_local i1 @f(<2 x i1> %inp) {
entry:
%e0 = extractelement <2 x i1> %inp, i32 0
%e1 = extractelement <2 x i1> %inp, i32 1
%res = and i1 %e0, %e1
ret i1 %res
}

; CHECK: f - {[-1]:Integer} |{[-1]:Integer}:{}
; CHECK-NEXT: <2 x i1> %inp: {[-1]:Integer}
; CHECK-NEXT: entry
; CHECK-NEXT: %e0 = extractelement <2 x i1> %inp, i32 0: {[-1]:Integer}
; CHECK-NEXT: %e1 = extractelement <2 x i1> %inp, i32 1: {[-1]:Integer}
; CHECK-NEXT: %res = and i1 %e0, %e1: {[-1]:Integer}
; CHECK-NEXT: ret i1 %res: {}
2 changes: 1 addition & 1 deletion enzyme/test/test_find_package/main.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <stdio.h>
int printf(const char*, ...);

extern double __enzyme_autodiff(void*, double);

Expand Down

0 comments on commit b0f4384

Please sign in to comment.