From 07f3e71dcc9b45a1f6ea14f8b8eb2a0a2d008f35 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 1 Oct 2023 22:23:21 -0400 Subject: [PATCH 1/4] Simplify type analysis dump (#1461) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index a389ad94c292..d7b1132f2f2b 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/ModuleSlotTracker.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -2168,8 +2169,15 @@ void TypeAnalyzer::visitInsertValueInst(InsertValueInst &I) { void TypeAnalyzer::dump(llvm::raw_ostream &ss) { ss << "\n"; + // We don't care about correct MD node numbering here. + ModuleSlotTracker MST(fntypeinfo.Function->getParent(), + /*ShouldInitializeAllMetadata*/ false); for (auto &pair : analysis) { - ss << *pair.first << ": " << pair.second.str() + if (auto F = dyn_cast<Function>(pair.first)) + ss << "@" << F->getName(); + else + pair.first->print(ss, MST); + ss << ": " << pair.second.str() << ", intvals: " << to_string(knownIntegralValues(pair.first)) << "\n"; } ss << "\n"; From 9aa0b8d55bc9c2e8d8fb34eccb491007f41986a4 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Oct 2023 12:41:44 -0500 Subject: [PATCH 2/4] Mark mpi functions of booleans as inactive (#1464) --- enzyme/Enzyme/Enzyme.cpp | 37 +++++++++++++++++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 13 ++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 275951c1d638..7c14c5caf2d4 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -239,6 +239,43 @@ void attributeKnownFunctions(llvm::Function &F) { F.addParamAttr(2, Attribute::WriteOnly);
F.addParamAttr(2, Attribute::NoCapture); } + // Map of MPI function name to the arg index of its type argument + std::map<std::string, size_t> MPI_TYPE_ARGS = { + {"MPI_Send", 2}, {"MPI_Ssend", 2}, {"MPI_Bsend", 2}, + {"MPI_Recv", 2}, {"MPI_Brecv", 2}, {"PMPI_Send", 2}, + {"PMPI_Ssend", 2}, {"PMPI_Bsend", 2}, {"PMPI_Recv", 2}, + {"PMPI_Brecv", 2}, + + {"MPI_Isend", 2}, {"MPI_Irecv", 2}, {"PMPI_Isend", 2}, + {"PMPI_Irecv", 2}, + + {"MPI_Reduce", 3}, {"PMPI_Reduce", 3}, + + {"MPI_Allreduce", 3}, {"PMPI_Allreduce", 3}}; + { + auto found = MPI_TYPE_ARGS.find(F.getName().str()); + if (found != MPI_TYPE_ARGS.end()) { + for (auto user : F.users()) { + if (auto CI = dyn_cast<CallBase>(user)) + if (CI->getCalledFunction() == &F) { + if (Constant *C = + dyn_cast<Constant>(CI->getArgOperand(found->second))) { + while (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { + C = CE->getOperand(0); + } + if (auto GV = dyn_cast<GlobalVariable>(C)) { + if (GV->getName() == "ompi_mpi_cxx_bool") { + CI->addAttribute( + AttributeList::FunctionIndex, + Attribute::get(CI->getContext(), "enzyme_inactive")); + } + } + } + } + } + } + } + if (F.getName() == "omp_get_max_threads" || F.getName() == "omp_get_thread_num") { #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index d7b1132f2f2b..1276aa310769 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4031,6 +4031,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast<ConstantInt>(C)) { // MPICH @@ -4051,7 +4053,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); return; } - if (funcName == "MPI_Isend" || funcName == "MPI_Irecv") { + if
(funcName == "MPI_Isend" || funcName == "MPI_Irecv" || + funcName == "PMPI_Isend" || funcName == "PMPI_Irecv") { TypeTree buf = TypeTree(BaseType::Pointer); if (Constant *C = dyn_cast<Constant>(call.getOperand(2))) { @@ -4063,6 +4066,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast<ConstantInt>(C)) { // MPICH @@ -4137,6 +4142,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast<ConstantInt>(C)) { // MPICH @@ -4164,7 +4171,7 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call); return; } - if (funcName == "MPI_Allreduce") { + if (funcName == "MPI_Allreduce" || funcName == "PMPI_Allreduce") { TypeTree buf = TypeTree(BaseType::Pointer); if (Constant *C = dyn_cast<Constant>(call.getOperand(3))) { @@ -4176,6 +4183,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { buf.insert({0}, Type::getDoubleTy(C->getContext())); } else if (GV->getName() == "ompi_mpi_float") { buf.insert({0}, Type::getFloatTy(C->getContext())); + } else if (GV->getName() == "ompi_mpi_cxx_bool") { + buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast<ConstantInt>(C)) { // MPICH From 08288a01fafc745d9c94fbe9a0b053701733bf4d Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Oct 2023 16:50:59 -0500 Subject: [PATCH 3/4] [Type Analysis] handle extract vector of i1 (#1463) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 5 +++-- enzyme/test/TypeAnalysis/veci1.ll | 23 +++++++++++++++++++++ 2 files changed, 26
insertions(+), 2 deletions(-) create mode 100644 enzyme/test/TypeAnalysis/veci1.ll diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 1276aa310769..866480f0e80a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -1888,10 +1888,11 @@ void TypeAnalyzer::visitExtractElementInst(ExtractElementInst &I) { auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); VectorType *vecType = cast<VectorType>(I.getVectorOperand()->getType()); - size_t size = (dl.getTypeSizeInBits(vecType->getElementType()) + 7) / 8; + size_t bitsize = dl.getTypeSizeInBits(vecType->getElementType()); + size_t size = (bitsize + 7) / 8; if (auto CI = dyn_cast<ConstantInt>(I.getIndexOperand())) { - size_t off = CI->getZExtValue() * size; + size_t off = (CI->getZExtValue() * bitsize) / 8; if (direction & DOWN) updateAnalysis(&I, diff --git a/enzyme/test/TypeAnalysis/veci1.ll b/enzyme/test/TypeAnalysis/veci1.ll new file mode 100644 index 000000000000..30d3f6eb78d8 --- /dev/null +++ b/enzyme/test/TypeAnalysis/veci1.ll @@ -0,0 +1,23 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=f -o /dev/null | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=f -S -o /dev/null | FileCheck %s + +; ModuleID = 'test.c' +source_filename = "test.c" +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +define dso_local i1 @f(<2 x i1> %inp) { +entry: + %e0 = extractelement <2 x i1> %inp, i32 0 + %e1 = extractelement <2 x i1> %inp, i32 1 + %res = and i1 %e0, %e1 + ret i1 %res +} + +; CHECK: f - {[-1]:Integer} |{[-1]:Integer}:{} +; CHECK-NEXT: <2 x i1> %inp: {[-1]:Integer} +; CHECK-NEXT: entry +; CHECK-NEXT: %e0 = extractelement <2 x i1> %inp, i32 0: {[-1]:Integer} +; CHECK-NEXT: %e1 = extractelement <2 x i1> %inp, i32 1: {[-1]:Integer} +; CHECK-NEXT: %res = and i1 %e0, %e1:
{[-1]:Integer} +; CHECK-NEXT: ret i1 %res: {} From 1ada437c5b5f7a794f7d96c957d451c4c21cb5a7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 6 Oct 2023 16:53:16 -0500 Subject: [PATCH 4/4] Add collect offset c api function (#1465) --- enzyme/Enzyme/CApi.cpp | 18 ++++++++ enzyme/Enzyme/Utils.cpp | 65 ++++++++++++++++++++++++++++ enzyme/Enzyme/Utils.h | 5 +++ enzyme/test/test_find_package/main.c | 2 +- 4 files changed, 89 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 212843852ef7..c5680af9112f 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -1188,6 +1188,24 @@ LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, LLVMTypeRef EnzymeAllocaType(LLVMValueRef V) { return wrap(cast<AllocaInst>(unwrap(V))->getAllocatedType()); } +LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r, + LLVMTypeRef T_r) { + IRBuilder<> &B = *unwrap(B_r); + auto T = cast<IntegerType>(unwrap(T_r)); + auto width = T->getBitWidth(); + auto gep = cast<GetElementPtrInst>(unwrap(V_r)); + auto &DL = B.GetInsertBlock()->getParent()->getParent()->getDataLayout(); + + MapVector<Value *, APInt> VariableOffsets; + APInt Offset(width, 0); + bool success = collectOffset(gep, DL, width, VariableOffsets, Offset); + assert(success); + Value *start = ConstantInt::get(T, Offset); + for (auto &pair : VariableOffsets) + start = B.CreateAdd( + start, B.CreateMul(pair.first, ConstantInt::get(T, pair.second))); + return wrap(start); +} } static size_t num_rooting(llvm::Type *T, llvm::Function *F) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 3445550a1b13..756c972996d9 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Module.h" @@ -2648,3 +2649,67 @@
CountTrackedPointers::CountTrackedPointers(Type *T) { if (count == 0) all = false; } + +bool collectOffset(GetElementPtrInst *gep, const DataLayout &DL, + unsigned BitWidth, + MapVector<Value *, APInt> &VariableOffsets, + APInt &ConstantOffset) { +#if LLVM_VERSION_MAJOR >= 13 + return cast<GEPOperator>(gep)->collectOffset(DL, BitWidth, VariableOffsets, + ConstantOffset); +#else + assert(BitWidth == DL.getIndexSizeInBits(gep->getPointerAddressSpace()) && + "The offset bit width does not match DL specification."); + + auto CollectConstantOffset = [&](APInt Index, uint64_t Size) { + Index = Index.sextOrTrunc(BitWidth); + APInt IndexedSize = APInt(BitWidth, Size); + ConstantOffset += Index * IndexedSize; + }; + + for (gep_type_iterator GTI = gep_type_begin(gep), GTE = gep_type_end(gep); + GTI != GTE; ++GTI) { + // Scalable vectors are multiplied by a runtime constant. + bool ScalableType = isa<ScalableVectorType>(GTI.getIndexedType()); + + Value *V = GTI.getOperand(); + StructType *STy = GTI.getStructTypeOrNull(); + // Handle ConstantInt if possible. + if (auto ConstOffset = dyn_cast<ConstantInt>(V)) { + if (ConstOffset->isZero()) + continue; + // If the type is scalable and the constant is not zero (vscale * n * 0 = + // 0) bailout. + // TODO: If the runtime value is accessible at any point before DWARF + // emission, then we could potentially keep a forward reference to it + // in the debug value to be filled in later. + if (ScalableType) + return false; + // Handle a struct index, which adds its field offset to the pointer. + if (STy) { + unsigned ElementIdx = ConstOffset->getZExtValue(); + const StructLayout *SL = DL.getStructLayout(STy); + // Element offset is in bytes.
+ CollectConstantOffset(APInt(BitWidth, SL->getElementOffset(ElementIdx)), + 1); + continue; + } + CollectConstantOffset(ConstOffset->getValue(), + DL.getTypeAllocSize(GTI.getIndexedType())); + continue; + } + + if (STy || ScalableType) + return false; + APInt IndexedSize = + APInt(BitWidth, DL.getTypeAllocSize(GTI.getIndexedType())); + // Insert an initial offset of 0 for V iff none exists already, then + // increment the offset by IndexedSize. + if (IndexedSize != 0) { + VariableOffsets.insert({V, APInt(BitWidth, 0)}); + VariableOffsets[V] += IndexedSize; + } + } + return true; +#endif +} diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index fd0daa7da04d..b3208ee9611c 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -25,6 +25,7 @@ #ifndef ENZYME_UTILS_H #define ENZYME_UTILS_H +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -1754,4 +1755,8 @@ static inline bool isSpecialPtr(llvm::Type *Ty) { return AddressSpace::FirstSpecial <= AS && AS <= AddressSpace::LastSpecial; } +bool collectOffset(llvm::GetElementPtrInst *gep, const llvm::DataLayout &DL, + unsigned BitWidth, + llvm::MapVector<llvm::Value *, llvm::APInt> &VariableOffsets, + llvm::APInt &ConstantOffset); #endif diff --git a/enzyme/test/test_find_package/main.c b/enzyme/test/test_find_package/main.c index a31cccc6cbc4..2d0608b779c8 100644 --- a/enzyme/test/test_find_package/main.c +++ b/enzyme/test/test_find_package/main.c @@ -1,4 +1,4 @@ -#include <stdio.h> +int printf(const char*, ...); extern double __enzyme_autodiff(void*, double);