From a1dc0943c6192e6bbdaf80dc8378ed40032e9b72 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Sep 2023 00:59:54 -0500 Subject: [PATCH 01/29] Handle memset of undef memory, even if not type analyzable (#1426) --- enzyme/Enzyme/AdjointGenerator.h | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index f6e5bcc529d3..3ce4f6165215 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2747,6 +2747,47 @@ class AdjointGenerator auto &DL = gutils->newFunc->getParent()->getDataLayout(); auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0); + if (!vd.isKnownPastPointer()) { + // If unknown type results, and zeroing known undef allocation, consider + // integers + if (auto CI = dyn_cast(MS.getOperand(1))) + if (CI->isZero()) { + auto root = getBaseObject(MS.getOperand(0)); + bool writtenTo = false; + if (isa(root) || isAllocationCall(root, gutils->TLI)) { + Instruction *cur = MS.getPrevNode(); + while (cur) { + if (cur == root) + break; + if (auto MCI = dyn_cast(MS.getOperand(2))) { + if (auto II = dyn_cast(cur)) { + // If the start of the lifetime for more memory than being + // memset, its valid. + if (II->getIntrinsicID() == Intrinsic::lifetime_start) { + if (getBaseObject(II->getOperand(1)) == root) { + if (auto CI2 = dyn_cast(II->getOperand(0))) { + if (MCI->getValue().ult(CI2->getValue())) + break; + } + } + } + } + } + if (cur->mayWriteToMemory()) { + writtenTo = true; + break; + } + cur = cur->getPrevNode(); + } + + if (!writtenTo) { + vd = TypeTree(BaseType::Pointer); + vd.insert({-1}, BaseType::Integer); + } + } + } + } + if (!vd.isKnownPastPointer()) { // If unknown type results, consider the intersection of all incoming. 
if (isa(MS.getOperand(0)) || isa(MS.getOperand(0))) { From af78947bdc5526591e3e1e1198d33353ce2ad15b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Sep 2023 20:11:48 -0500 Subject: [PATCH 02/29] Do not perform runtime activity check for known non-aliasing pointers (#1427) --- enzyme/Enzyme/DiffeGradientUtils.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 6f6c06b49211..9025f5e076c9 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -46,6 +46,8 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" +#include "LibraryFuncs.h" + using namespace llvm; DiffeGradientUtils::DiffeGradientUtils( @@ -1037,7 +1039,13 @@ void DiffeGradientUtils::addToInvertedPtrDiffe( } if (!isConstantValue(origptr)) { - if (EnzymeRuntimeActivityCheck && !merge) { + auto basePtr = getBaseObject(origptr); + assert(!isConstantValue(basePtr)); + // If runtime activity, first see if we can prove that the shadow/primal + // are distinct statically as they are allocas/mallocs, if not compare + // the pointers and conditionally execute. 
+ if ((!isa(basePtr) && !isAllocationCall(basePtr, TLI)) && + EnzymeRuntimeActivityCheck && !merge) { Value *shadow = Builder2.CreateICmpNE( lookupM(getNewFromOriginal(origptr), Builder2), lookupM(invertPointerM(origptr, Builder2), Builder2)); From 6bd0650f559b49343d2fa950147d695212be8976 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 16 Sep 2023 20:27:16 -0500 Subject: [PATCH 03/29] Print debug info on assertion (#1428) --- enzyme/Enzyme/GradientUtils.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 29f29b5cb3c7..5e6ebc407c02 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -867,6 +867,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, origParent = lookupInst; \ assert(unwrapMode == UnwrapMode::AttemptSingleUnwrap); \ auto found = available.find(v); \ + if (found != available.end() && !found->second) \ + llvm::errs() << *oldFunc << "\n" << *newFunc << "\n" << *v << "\n"; \ assert(found == available.end() || found->second); \ ___res = lookupM(v, Builder, available, v != val, origParent); \ if (___res && ___res->getType() != v->getType()) { \ From 6bc5e0426a393e8e7e493b85cc131b2fb0ada994 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 17 Sep 2023 01:03:41 -0500 Subject: [PATCH 04/29] Handle null unwrap return in available map (#1429) --- enzyme/Enzyme/GradientUtils.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 5e6ebc407c02..e1bdd335e694 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -868,15 +868,16 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, assert(unwrapMode == UnwrapMode::AttemptSingleUnwrap); \ auto found = available.find(v); \ if (found != available.end() && !found->second) \ - llvm::errs() << *oldFunc << "\n" << *newFunc << 
"\n" << *v << "\n"; \ - assert(found == available.end() || found->second); \ - ___res = lookupM(v, Builder, available, v != val, origParent); \ - if (___res && ___res->getType() != v->getType()) { \ - llvm::errs() << *newFunc << "\n"; \ - llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; \ + ___res = nullptr; \ + else { \ + ___res = lookupM(v, Builder, available, v != val, origParent); \ + if (___res && ___res->getType() != v->getType()) { \ + llvm::errs() << *newFunc << "\n"; \ + llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; \ + } \ + if (___res) \ + assert(___res->getType() == v->getType() && "lu"); \ } \ - if (___res) \ - assert(___res->getType() == v->getType() && "lu"); \ } \ ___res; \ }) From 9a087e681a05b8c69555b8d1bd34a450d13f5654 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 17 Sep 2023 19:43:18 -0500 Subject: [PATCH 05/29] Fix use of exit block shared by multiple loops (#1430) * Loop multiexit [wip] * Fix use of exit block shared by multiple loops --- enzyme/Enzyme/CacheUtility.cpp | 2 +- enzyme/Enzyme/GradientUtils.cpp | 1332 +++++++++-------- enzyme/Enzyme/GradientUtils.h | 13 +- .../test/Enzyme/ReverseMode/multiloopexit.ll | 212 +++ 4 files changed, 899 insertions(+), 660 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/multiloopexit.ll diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index 848c8c87a6ef..c189faa5174a 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -459,7 +459,7 @@ llvm::AllocaInst *CacheUtility::getDynamicLoopLimit(llvm::Loop *L, auto Limit = B.CreatePHI(found.var->getType(), 1); for (BasicBlock *Pred : predecessors(ExitBlock)) { - if (LI.getLoopFor(Pred) == L) { + if (L->contains(Pred)) { Limit->addIncoming(found.var, Pred); } else { Limit->addIncoming(UndefValue::get(found.var->getType()), Pred); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index e1bdd335e694..0b7ca4be6368 100644 --- 
a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -2987,722 +2987,740 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, assert(false); } -/// Given an edge from BB to branchingBlock get the corresponding block to -/// branch to in the reverse pass -BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB, - BasicBlock *branchingBlock) { - assert(BB); - // BB should be a forward pass block, assert that - if (reverseBlocks.find(BB) == reverseBlocks.end()) { - llvm::errs() << *oldFunc << "\n"; - llvm::errs() << *newFunc << "\n"; - llvm::errs() << "BB: " << *BB << "\n"; - llvm::errs() << "branchingBlock: " << *branchingBlock << "\n"; - } - assert(reverseBlocks.find(BB) != reverseBlocks.end()); - assert(reverseBlocks.find(branchingBlock) != reverseBlocks.end()); - LoopContext lc; - bool inLoop = getContext(BB, lc); +BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { + auto header = lc.header; + SmallPtrSet loopRematerializations; + SmallPtrSet loopReallocations; + SmallPtrSet loopShadowReallocations; + SmallPtrSet loopShadowZeroInits; + SmallPtrSet loopShadowRematerializations; + Loop *origLI = nullptr; + for (auto pair : rematerializableAllocations) { + if (pair.second.LI && + getNewFromOriginal(pair.second.LI->getHeader()) == header) { + bool rematerialized = false; + std::map Seen; + for (auto pair : knownRecomputeHeuristic) + if (!pair.second) + Seen[UsageKey(pair.first, ValueType::Primal)] = false; - LoopContext branchingContext; - bool inLoopContext = getContext(branchingBlock, branchingContext); - - if (!inLoop) - return reverseBlocks[BB].front(); - - auto tup = std::make_tuple(BB, branchingBlock); - if (newBlocksForLoop_cache.find(tup) != newBlocksForLoop_cache.end()) - return newBlocksForLoop_cache[tup]; - - if (inLoop) { - // If we're reversing a latch edge. 
- bool incEntering = inLoopContext && branchingBlock == lc.header && - lc.header == branchingContext.header; - - auto L = LI.getLoopFor(BB); - auto latches = getLatches(L, lc.exitBlocks); - // If we're reverseing a loop exit. - bool exitEntering = - std::find(latches.begin(), latches.end(), BB) != latches.end() && - std::find(lc.exitBlocks.begin(), lc.exitBlocks.end(), branchingBlock) != - lc.exitBlocks.end(); - - // If we're re-entering a loop, prepare a loop-level forward pass to - // rematerialize any loop-scope rematerialization. - if (incEntering || exitEntering) { - SmallPtrSet loopRematerializations; - SmallPtrSet loopReallocations; - SmallPtrSet loopShadowReallocations; - SmallPtrSet loopShadowZeroInits; - SmallPtrSet loopShadowRematerializations; - Loop *origLI = nullptr; - for (auto pair : rematerializableAllocations) { - if (pair.second.LI && - getNewFromOriginal(pair.second.LI->getHeader()) == L->getHeader()) { - bool rematerialized = false; - std::map Seen; - for (auto pair : knownRecomputeHeuristic) - if (!pair.second) - Seen[UsageKey(pair.first, ValueType::Primal)] = false; - - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - ValueType::Primal>(this, pair.first, mode, Seen, - notForAnalysis)) { - rematerialized = true; + if (DifferentialUseAnalysis::is_value_needed_in_reverse< + ValueType::Primal>(this, pair.first, mode, Seen, + notForAnalysis)) { + rematerialized = true; + } + if (rematerialized) { + if (auto inst = dyn_cast(pair.first)) + if (pair.second.LI->contains(inst->getParent())) { + loopReallocations.insert(inst); } - if (rematerialized) { - if (auto inst = dyn_cast(pair.first)) - if (pair.second.LI->contains(inst->getParent())) { - loopReallocations.insert(inst); - } - for (auto I : pair.second.stores) - loopRematerializations.insert(I); - origLI = pair.second.LI; + for (auto I : pair.second.stores) + loopRematerializations.insert(I); + origLI = pair.second.LI; + } + } + } + for (auto pair : backwardsOnlyShadows) { + if 
(pair.second.LI && + getNewFromOriginal(pair.second.LI->getHeader()) == header) { + if (auto inst = dyn_cast(pair.first)) { + bool restoreStores = false; + if (pair.second.LI->contains(inst->getParent())) { + // TODO later make it so primalInitialize can be restored + // rather than cached from primal + if (!pair.second.primalInitialize) { + loopShadowReallocations.insert(inst); + restoreStores = true; } + } else { + // if (pair.second.primalInitialize) { + // loopShadowZeroInits.insert(inst); + //} + restoreStores = true; } - } - for (auto pair : backwardsOnlyShadows) { - if (pair.second.LI && - getNewFromOriginal(pair.second.LI->getHeader()) == L->getHeader()) { - if (auto inst = dyn_cast(pair.first)) { - bool restoreStores = false; - if (pair.second.LI->contains(inst->getParent())) { - // TODO later make it so primalInitialize can be restored - // rather than cached from primal - if (!pair.second.primalInitialize) { - loopShadowReallocations.insert(inst); - restoreStores = true; - } - } else { - // if (pair.second.primalInitialize) { - // loopShadowZeroInits.insert(inst); - //} - restoreStores = true; - } - if (restoreStores) { - for (auto I : pair.second.stores) { - loopShadowRematerializations.insert(I); - } - } - origLI = pair.second.LI; + if (restoreStores) { + for (auto I : pair.second.stores) { + loopShadowRematerializations.insert(I); } } - } - BasicBlock *resumeblock = reverseBlocks[BB].front(); - if (loopRematerializations.size() != 0 || loopReallocations.size() != 0 || - loopShadowRematerializations.size() != 0 || - loopShadowReallocations.size() != 0 || - loopShadowZeroInits.size() != 0) { - auto found = rematerializedLoops_cache.find(L); - if (found != rematerializedLoops_cache.end()) { - resumeblock = found->second; - } else { - BasicBlock *enterB = BasicBlock::Create( - BB->getContext(), "remat_enter", BB->getParent()); - rematerializedLoops_cache[L] = enterB; - std::map origToNewForward; - for (auto B : origLI->getBlocks()) { - BasicBlock *newB = 
BasicBlock::Create( - B->getContext(), - "remat_" + lc.header->getName() + "_" + B->getName(), - BB->getParent()); - origToNewForward[B] = newB; - reverseBlockToPrimal[newB] = getNewFromOriginal(B); - if (B == origLI->getHeader()) { - IRBuilder<> NB(newB); - for (auto inst : loopShadowZeroInits) { - auto anti = lookupM(invertPointerM(inst, NB), NB); - StringRef funcName; - SmallVector args; - if (auto orig = dyn_cast(inst)) { + origLI = pair.second.LI; + } + } + } + if (loopRematerializations.size() != 0 || loopReallocations.size() != 0 || + loopShadowRematerializations.size() != 0 || + loopShadowReallocations.size() != 0 || loopShadowZeroInits.size() != 0) { + auto found = rematerializedLoops_cache.find(header); + if (found != rematerializedLoops_cache.end()) { + return found->second; + } + + BasicBlock *enterB = + BasicBlock::Create(header->getContext(), "remat_enter", newFunc); + rematerializedLoops_cache[header] = enterB; + std::map origToNewForward; + for (auto B : origLI->getBlocks()) { + BasicBlock *newB = BasicBlock::Create( + B->getContext(), "remat_" + header->getName() + "_" + B->getName(), + newFunc); + origToNewForward[B] = newB; + reverseBlockToPrimal[newB] = getNewFromOriginal(B); + if (B == origLI->getHeader()) { + IRBuilder<> NB(newB); + for (auto inst : loopShadowZeroInits) { + auto anti = lookupM(invertPointerM(inst, NB), NB); + StringRef funcName; + SmallVector args; + if (auto orig = dyn_cast(inst)) { #if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : orig->args()) + for (auto &arg : orig->args()) #else - for (auto &arg : orig->arg_operands()) + for (auto &arg : orig->arg_operands()) #endif - { - args.push_back(lookupM(getNewFromOriginal(arg), NB)); - } - funcName = getFuncNameFromCall(orig); - } else if (auto AI = dyn_cast(inst)) { - funcName = "malloc"; - Value *sz = - lookupM(getNewFromOriginal(AI->getArraySize()), NB); - - auto ci = ConstantInt::get( - sz->getType(), - B->getParent() - ->getParent() - ->getDataLayout() - 
.getTypeAllocSizeInBits(AI->getAllocatedType()) / - 8); - sz = NB.CreateMul(sz, ci); - args.push_back(sz); - } - assert(funcName.size()); - - applyChainRule( - NB, - [&](Value *anti) { - zeroKnownAllocation(NB, anti, args, funcName, TLI, - dyn_cast(inst)); - }, - anti); - } + { + args.push_back(lookupM(getNewFromOriginal(arg), NB)); } + funcName = getFuncNameFromCall(orig); + } else if (auto AI = dyn_cast(inst)) { + funcName = "malloc"; + Value *sz = lookupM(getNewFromOriginal(AI->getArraySize()), NB); + + auto ci = ConstantInt::get( + sz->getType(), + B->getParent() + ->getParent() + ->getDataLayout() + .getTypeAllocSizeInBits(AI->getAllocatedType()) / + 8); + sz = NB.CreateMul(sz, ci); + args.push_back(sz); } + assert(funcName.size()); - ValueToValueMapTy available; + applyChainRule( + NB, + [&](Value *anti) { + zeroKnownAllocation(NB, anti, args, funcName, TLI, + dyn_cast(inst)); + }, + anti); + } + } + } - { - IRBuilder<> NB(enterB); - NB.CreateBr(origToNewForward[origLI->getHeader()]); - } + ValueToValueMapTy available; - std::function handleLoop = [&](Loop *OL, - bool subLoop) { - if (subLoop) { - auto Header = OL->getHeader(); - IRBuilder<> NB(origToNewForward[Header]); - LoopContext flc; - getContext(getNewFromOriginal(Header), flc); + { + IRBuilder<> NB(enterB); + NB.CreateBr(origToNewForward[origLI->getHeader()]); + } - auto iv = NB.CreatePHI(flc.var->getType(), 2, "fiv"); - auto inc = NB.CreateAdd(iv, ConstantInt::get(iv->getType(), 1)); + std::function handleLoop = [&](Loop *OL, bool subLoop) { + if (subLoop) { + auto Header = OL->getHeader(); + IRBuilder<> NB(origToNewForward[Header]); + LoopContext flc; + getContext(getNewFromOriginal(Header), flc); - for (auto PH : predecessors(Header)) { - if (notForAnalysis.count(PH)) - continue; + auto iv = NB.CreatePHI(flc.var->getType(), 2, "fiv"); + auto inc = NB.CreateAdd(iv, ConstantInt::get(iv->getType(), 1)); - if (OL->contains(PH)) - iv->addIncoming(inc, origToNewForward[PH]); - else - 
iv->addIncoming(ConstantInt::get(iv->getType(), 0), - origToNewForward[PH]); - } - available[flc.var] = iv; - available[flc.incvar] = inc; - } - for (auto SL : OL->getSubLoops()) - handleLoop(SL, /*subLoop*/ true); - }; - handleLoop(origLI, /*subLoop*/ false); - - for (auto B : origLI->getBlocks()) { - auto newB = origToNewForward[B]; - IRBuilder<> NB(newB); - - // TODO fill available with relevant IV's surrounding and - // IV's of inner loop phi's - - for (auto &I : *B) { - // Only handle store, memset, and julia.write_barrier - if (loopRematerializations.count(&I)) { - if (auto SI = dyn_cast(&I)) { - auto ts = NB.CreateStore( - lookupM(getNewFromOriginal(SI->getValueOperand()), NB, - available), - lookupM(getNewFromOriginal(SI->getPointerOperand()), NB, - available)); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - ts->copyMetadata(*SI, ToCopy2); - ts->setAlignment(SI->getAlign()); - ts->setVolatile(SI->isVolatile()); - ts->setOrdering(SI->getOrdering()); - ts->setSyncScopeID(SI->getSyncScopeID()); - ts->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - } else if (auto CI = dyn_cast(&I)) { - StringRef funcName = getFuncNameFromCall(CI); - if (funcName == "enzyme_zerotype") - continue; - if (funcName == "julia.write_barrier" || - isa(&I) || isa(&I)) { + for (auto PH : predecessors(Header)) { + if (notForAnalysis.count(PH)) + continue; - // TODO - SmallVector args; + if (OL->contains(PH)) + iv->addIncoming(inc, origToNewForward[PH]); + else + iv->addIncoming(ConstantInt::get(iv->getType(), 0), + origToNewForward[PH]); + } + available[flc.var] = iv; + available[flc.incvar] = inc; + } + for (auto SL : OL->getSubLoops()) + handleLoop(SL, /*subLoop*/ true); + }; + handleLoop(origLI, /*subLoop*/ false); + + for (auto B : origLI->getBlocks()) { + auto newB = origToNewForward[B]; + IRBuilder<> NB(newB); + + // TODO fill available with relevant IV's surrounding and + // IV's of 
inner loop phi's + + for (auto &I : *B) { + // Only handle store, memset, and julia.write_barrier + if (loopRematerializations.count(&I)) { + if (auto SI = dyn_cast(&I)) { + auto ts = NB.CreateStore( + lookupM(getNewFromOriginal(SI->getValueOperand()), NB, + available), + lookupM(getNewFromOriginal(SI->getPointerOperand()), NB, + available)); + llvm::SmallVector ToCopy2(MD_ToCopy); + ToCopy2.push_back(LLVMContext::MD_noalias); + ToCopy2.push_back(LLVMContext::MD_alias_scope); + ts->copyMetadata(*SI, ToCopy2); + ts->setAlignment(SI->getAlign()); + ts->setVolatile(SI->isVolatile()); + ts->setOrdering(SI->getOrdering()); + ts->setSyncScopeID(SI->getSyncScopeID()); + ts->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + } else if (auto CI = dyn_cast(&I)) { + StringRef funcName = getFuncNameFromCall(CI); + if (funcName == "enzyme_zerotype") + continue; + if (funcName == "julia.write_barrier" || isa(&I) || + isa(&I)) { + + // TODO + SmallVector args; #if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) + for (auto &arg : CI->args()) #else - for (auto &arg : CI->arg_operands()) + for (auto &arg : CI->arg_operands()) #endif - args.push_back( - lookupM(getNewFromOriginal(arg), NB, available)); - - SmallVector BundleTypes(args.size(), - ValueType::Primal); - - auto Defs = getInvertedBundles(CI, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = - NB.CreateCall(CI->getFunctionType(), - CI->getCalledOperand(), args, Defs); - cal->setAttributes(CI->getAttributes()); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - } else { - assert(isDeallocationFunction(funcName, TLI)); - continue; - } - } else { - assert(0 && "unhandlable loop rematerialization instruction"); - } - } else if (loopReallocations.count(&I)) { - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - &newFunc->getEntryBlock()); - - auto inst = getNewFromOriginal((Value *)&I); - - auto found = scopeMap.find(inst); - if (found == 
scopeMap.end()) { - AllocaInst *cache = - createCacheForScope(lctx, inst->getType(), - inst->getName(), /*shouldFree*/ true); - assert(cache); - found = insert_or_assign( - scopeMap, inst, - std::pair, LimitContext>(cache, - lctx)); - } - auto cache = found->second.first; - if (auto MD = hasMetadata(&I, "enzyme_fromstack")) { - auto replacement = NB.CreateAlloca( - Type::getInt8Ty(I.getContext()), - lookupM(getNewFromOriginal(I.getOperand(0)), NB, - available)); - auto Alignment = cast(cast( - MD->getOperand(0)) - ->getValue()) - ->getLimitedValue(); - replacement->setAlignment(Align(Alignment)); - replacement->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - storeInstructionInCache(lctx, NB, replacement, cache); - } else if (auto CI = dyn_cast(&I)) { - SmallVector args; + args.push_back(lookupM(getNewFromOriginal(arg), NB, available)); + + SmallVector BundleTypes(args.size(), + ValueType::Primal); + + auto Defs = getInvertedBundles(CI, BundleTypes, NB, + /*lookup*/ true, available); + auto cal = NB.CreateCall(CI->getFunctionType(), + CI->getCalledOperand(), args, Defs); + cal->setAttributes(CI->getAttributes()); + cal->setCallingConv(CI->getCallingConv()); + cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + } else { + assert(isDeallocationFunction(funcName, TLI)); + continue; + } + } else { + assert(0 && "unhandlable loop rematerialization instruction"); + } + } else if (loopReallocations.count(&I)) { + LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, + &newFunc->getEntryBlock()); + + auto inst = getNewFromOriginal((Value *)&I); + + auto found = scopeMap.find(inst); + if (found == scopeMap.end()) { + AllocaInst *cache = createCacheForScope( + lctx, inst->getType(), inst->getName(), /*shouldFree*/ true); + assert(cache); + found = insert_or_assign( + scopeMap, inst, + std::pair, LimitContext>(cache, lctx)); + } + auto cache = found->second.first; + if (auto MD = hasMetadata(&I, "enzyme_fromstack")) { + auto replacement = NB.CreateAlloca( + 
Type::getInt8Ty(I.getContext()), + lookupM(getNewFromOriginal(I.getOperand(0)), NB, available)); + auto Alignment = + cast( + cast(MD->getOperand(0))->getValue()) + ->getLimitedValue(); + replacement->setAlignment(Align(Alignment)); + replacement->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + storeInstructionInCache(lctx, NB, replacement, cache); + } else if (auto CI = dyn_cast(&I)) { + SmallVector args; #if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) + for (auto &arg : CI->args()) #else - for (auto &arg : CI->arg_operands()) + for (auto &arg : CI->arg_operands()) #endif - args.push_back( - lookupM(getNewFromOriginal(arg), NB, available)); - - SmallVector BundleTypes(args.size(), - ValueType::Primal); - - auto Defs = getInvertedBundles(CI, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = NB.CreateCall(CI->getCalledFunction(), args, Defs); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - cal->copyMetadata(*CI, ToCopy2); - cal->setName("remat_" + CI->getName()); - cal->setAttributes(CI->getAttributes()); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - storeInstructionInCache(lctx, NB, cal, cache); + args.push_back(lookupM(getNewFromOriginal(arg), NB, available)); + + SmallVector BundleTypes(args.size(), + ValueType::Primal); + + auto Defs = getInvertedBundles(CI, BundleTypes, NB, + /*lookup*/ true, available); + auto cal = NB.CreateCall(CI->getCalledFunction(), args, Defs); + llvm::SmallVector ToCopy2(MD_ToCopy); + ToCopy2.push_back(LLVMContext::MD_noalias); + ToCopy2.push_back(LLVMContext::MD_alias_scope); + cal->copyMetadata(*CI, ToCopy2); + cal->setName("remat_" + CI->getName()); + cal->setAttributes(CI->getAttributes()); + cal->setCallingConv(CI->getCallingConv()); + cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + storeInstructionInCache(lctx, NB, cal, cache); + } else { + 
llvm::errs() << " realloc: " << I << "\n"; + llvm_unreachable("Unknown loop reallocation"); + } + } + if (loopShadowRematerializations.count(&I)) { + if (auto SI = dyn_cast(&I)) { + Value *orig_ptr = SI->getPointerOperand(); + Value *orig_val = SI->getValueOperand(); + Type *valType = orig_val->getType(); + assert(!isConstantValue(orig_ptr)); + + auto &DL = newFunc->getParent()->getDataLayout(); + + bool constantval = isConstantValue(orig_val) || + parseTBAA(I, DL).Inner0().isIntegral(); + + // TODO allow recognition of other types that could contain + // pointers [e.g. {void*, void*} or <2 x i64> ] + auto storeSize = DL.getTypeSizeInBits(valType) / 8; + + //! Storing a floating point value + Type *FT = nullptr; + if (valType->isFPOrFPVectorTy()) { + FT = valType->getScalarType(); + } else if (!valType->isPointerTy()) { + if (looseTypeAnalysis) { + auto fp = TR.firstPointer(storeSize, orig_ptr, &I, + /*errifnotfound*/ false, + /*pointerIntSame*/ true); + if (fp.isKnown()) { + FT = fp.isFloat(); + } else if (isa(orig_val) || + valType->isIntOrIntVectorTy()) { + llvm::errs() + << "assuming type as integral for store: " << I << "\n"; + FT = nullptr; } else { - llvm::errs() << " realloc: " << I << "\n"; - llvm_unreachable("Unknown loop reallocation"); + TR.firstPointer(storeSize, orig_ptr, &I, + /*errifnotfound*/ true, + /*pointerIntSame*/ true); + llvm::errs() << "cannot deduce type of store " << I << "\n"; + assert(0 && "cannot deduce"); } + } else { + FT = TR.firstPointer(storeSize, orig_ptr, &I, + /*errifnotfound*/ true, + /*pointerIntSame*/ true) + .isFloat(); } - if (loopShadowRematerializations.count(&I)) { - if (auto SI = dyn_cast(&I)) { - Value *orig_ptr = SI->getPointerOperand(); - Value *orig_val = SI->getValueOperand(); - Type *valType = orig_val->getType(); - assert(!isConstantValue(orig_ptr)); - - auto &DL = newFunc->getParent()->getDataLayout(); - - bool constantval = isConstantValue(orig_val) || - parseTBAA(I, DL).Inner0().isIntegral(); - - // TODO 
allow recognition of other types that could contain - // pointers [e.g. {void*, void*} or <2 x i64> ] - auto storeSize = DL.getTypeSizeInBits(valType) / 8; - - //! Storing a floating point value - Type *FT = nullptr; - if (valType->isFPOrFPVectorTy()) { - FT = valType->getScalarType(); - } else if (!valType->isPointerTy()) { - if (looseTypeAnalysis) { - auto fp = TR.firstPointer(storeSize, orig_ptr, &I, - /*errifnotfound*/ false, - /*pointerIntSame*/ true); - if (fp.isKnown()) { - FT = fp.isFloat(); - } else if (isa(orig_val) || - valType->isIntOrIntVectorTy()) { - llvm::errs() - << "assuming type as integral for store: " << I - << "\n"; - FT = nullptr; - } else { - TR.firstPointer(storeSize, orig_ptr, &I, - /*errifnotfound*/ true, - /*pointerIntSame*/ true); - llvm::errs() - << "cannot deduce type of store " << I << "\n"; - assert(0 && "cannot deduce"); - } - } else { - FT = TR.firstPointer(storeSize, orig_ptr, &I, - /*errifnotfound*/ true, - /*pointerIntSame*/ true) - .isFloat(); - } - } - if (!FT) { - Value *valueop = nullptr; - if (constantval) { - Value *val = - lookupM(getNewFromOriginal(orig_val), NB, available); - valueop = val; - if (getWidth() > 1) { - Value *array = - UndefValue::get(getShadowType(val->getType())); - for (unsigned i = 0; i < getWidth(); ++i) { - array = NB.CreateInsertValue(array, val, {i}); - } - valueop = array; - } - } else { - valueop = - lookupM(invertPointerM(orig_val, NB), NB, available); - } - SmallVector prevScopes; - if (auto prev = - SI->getMetadata(LLVMContext::MD_alias_scope)) { - for (auto &M : cast(prev)->operands()) { - prevScopes.push_back(M); - } - } - SmallVector prevNoAlias; - if (auto prev = SI->getMetadata(LLVMContext::MD_noalias)) { - for (auto &M : cast(prev)->operands()) { - prevNoAlias.push_back(M); - } - } - auto align = SI->getAlign(); - setPtrDiffe(SI, orig_ptr, valueop, NB, align, - SI->isVolatile(), SI->getOrdering(), - SI->getSyncScopeID(), - /*mask*/ nullptr, prevNoAlias, prevScopes); - } - // TODO shadow 
memtransfer - } else if (auto MS = dyn_cast(&I)) { - if (!isConstantValue(MS->getArgOperand(0))) { - Value *args[4] = { - lookupM(invertPointerM(MS->getArgOperand(0), NB), NB, - available), - lookupM(getNewFromOriginal(MS->getArgOperand(1)), NB, - available), - lookupM(getNewFromOriginal(MS->getArgOperand(2)), NB, - available), - lookupM(getNewFromOriginal(MS->getArgOperand(3)), NB, - available)}; - - ValueType BundleTypes[4] = { - ValueType::Shadow, ValueType::Primal, ValueType::Primal, - ValueType::Primal}; - auto Defs = getInvertedBundles(MS, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = - NB.CreateCall(MS->getCalledFunction(), args, Defs); - llvm::SmallVector ToCopy2(MD_ToCopy); - ToCopy2.push_back(LLVMContext::MD_noalias); - ToCopy2.push_back(LLVMContext::MD_alias_scope); - cal->copyMetadata(*MS, ToCopy2); - cal->setAttributes(MS->getAttributes()); - cal->setCallingConv(MS->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + } + if (!FT) { + Value *valueop = nullptr; + if (constantval) { + Value *val = + lookupM(getNewFromOriginal(orig_val), NB, available); + valueop = val; + if (getWidth() > 1) { + Value *array = UndefValue::get(getShadowType(val->getType())); + for (unsigned i = 0; i < getWidth(); ++i) { + array = NB.CreateInsertValue(array, val, {i}); } - } else if (auto CI = dyn_cast(&I)) { - StringRef funcName = getFuncNameFromCall(CI); - if (funcName == "julia.write_barrier") { + valueop = array; + } + } else { + valueop = lookupM(invertPointerM(orig_val, NB), NB, available); + } + SmallVector prevScopes; + if (auto prev = SI->getMetadata(LLVMContext::MD_alias_scope)) { + for (auto &M : cast(prev)->operands()) { + prevScopes.push_back(M); + } + } + SmallVector prevNoAlias; + if (auto prev = SI->getMetadata(LLVMContext::MD_noalias)) { + for (auto &M : cast(prev)->operands()) { + prevNoAlias.push_back(M); + } + } + auto align = SI->getAlign(); + setPtrDiffe(SI, orig_ptr, valueop, NB, align, SI->isVolatile(), + 
SI->getOrdering(), SI->getSyncScopeID(), + /*mask*/ nullptr, prevNoAlias, prevScopes); + } + // TODO shadow memtransfer + } else if (auto MS = dyn_cast(&I)) { + if (!isConstantValue(MS->getArgOperand(0))) { + Value *args[4] = { + lookupM(invertPointerM(MS->getArgOperand(0), NB), NB, + available), + lookupM(getNewFromOriginal(MS->getArgOperand(1)), NB, + available), + lookupM(getNewFromOriginal(MS->getArgOperand(2)), NB, + available), + lookupM(getNewFromOriginal(MS->getArgOperand(3)), NB, + available)}; + + ValueType BundleTypes[4] = {ValueType::Shadow, ValueType::Primal, + ValueType::Primal, ValueType::Primal}; + auto Defs = getInvertedBundles(MS, BundleTypes, NB, + /*lookup*/ true, available); + auto cal = NB.CreateCall(MS->getCalledFunction(), args, Defs); + llvm::SmallVector ToCopy2(MD_ToCopy); + ToCopy2.push_back(LLVMContext::MD_noalias); + ToCopy2.push_back(LLVMContext::MD_alias_scope); + cal->copyMetadata(*MS, ToCopy2); + cal->setAttributes(MS->getAttributes()); + cal->setCallingConv(MS->getCallingConv()); + cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + } + } else if (auto CI = dyn_cast(&I)) { + StringRef funcName = getFuncNameFromCall(CI); + if (funcName == "julia.write_barrier") { - // TODO - SmallVector args; + // TODO + SmallVector args; #if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : CI->args()) + for (auto &arg : CI->args()) #else - for (auto &arg : CI->arg_operands()) + for (auto &arg : CI->arg_operands()) #endif - if (!isConstantValue(arg)) - args.push_back( - lookupM(invertPointerM(arg, NB), NB, available)); - - if (args.size()) { - SmallVector BundleTypes(args.size(), - ValueType::Primal); - - auto Defs = - getInvertedBundles(CI, BundleTypes, NB, - /*lookup*/ true, available); - auto cal = - NB.CreateCall(CI->getFunctionType(), - CI->getCalledOperand(), args, Defs); - cal->setAttributes(CI->getAttributes()); - cal->setCallingConv(CI->getCallingConv()); - cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); - } - } else { - 
assert(isDeallocationFunction(funcName, TLI)); - continue; - } - } else { - assert( - 0 && - "unhandlable loop shadow rematerialization instruction"); - } - } else if (loopShadowReallocations.count(&I)) { + if (!isConstantValue(arg)) + args.push_back( + lookupM(invertPointerM(arg, NB), NB, available)); + + if (args.size()) { + SmallVector BundleTypes(args.size(), + ValueType::Primal); + + auto Defs = getInvertedBundles(CI, BundleTypes, NB, + /*lookup*/ true, available); + auto cal = NB.CreateCall(CI->getFunctionType(), + CI->getCalledOperand(), args, Defs); + cal->setAttributes(CI->getAttributes()); + cal->setCallingConv(CI->getCallingConv()); + cal->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + } + } else { + assert(isDeallocationFunction(funcName, TLI)); + continue; + } + } else { + assert(0 && + "unhandlable loop shadow rematerialization instruction"); + } + } else if (loopShadowReallocations.count(&I)) { - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - &newFunc->getEntryBlock()); - auto ipfound = invertedPointers.find(&I); - PHINode *placeholder = cast(&*ipfound->second); - - auto found = scopeMap.find(placeholder); - if (found == scopeMap.end()) { - AllocaInst *cache = createCacheForScope( - lctx, placeholder->getType(), placeholder->getName(), - /*shouldFree*/ true); - assert(cache); - found = insert_or_assign( - scopeMap, (Value *&)placeholder, - std::pair, LimitContext>(cache, - lctx)); - } - auto cache = found->second.first; - Value *anti = nullptr; + LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, + &newFunc->getEntryBlock()); + auto ipfound = invertedPointers.find(&I); + PHINode *placeholder = cast(&*ipfound->second); + + auto found = scopeMap.find(placeholder); + if (found == scopeMap.end()) { + AllocaInst *cache = createCacheForScope( + lctx, placeholder->getType(), placeholder->getName(), + /*shouldFree*/ true); + assert(cache); + found = insert_or_assign( + scopeMap, (Value *&)placeholder, + std::pair, 
LimitContext>(cache, lctx)); + } + auto cache = found->second.first; + Value *anti = nullptr; - if (auto orig = dyn_cast(&I)) { - StringRef funcName = getFuncNameFromCall(orig); - assert(funcName.size()); + if (auto orig = dyn_cast(&I)) { + StringRef funcName = getFuncNameFromCall(orig); + assert(funcName.size()); - auto dbgLoc = getNewFromOriginal(orig)->getDebugLoc(); + auto dbgLoc = getNewFromOriginal(orig)->getDebugLoc(); - SmallVector args; + SmallVector args; #if LLVM_VERSION_MAJOR >= 14 - for (auto &arg : orig->args()) + for (auto &arg : orig->args()) #else - for (auto &arg : orig->arg_operands()) + for (auto &arg : orig->arg_operands()) #endif - { - args.push_back(lookupM(getNewFromOriginal(arg), NB)); - } + { + args.push_back(lookupM(getNewFromOriginal(arg), NB)); + } - placeholder->setName(""); - if (shadowHandlers.find(funcName) != shadowHandlers.end()) { + placeholder->setName(""); + if (shadowHandlers.find(funcName) != shadowHandlers.end()) { - anti = shadowHandlers[funcName](NB, orig, args, this); - } else { - auto rule = [&]() { - Value *anti = NB.CreateCall( - orig->getFunctionType(), orig->getCalledOperand(), - args, orig->getName() + "'mi"); - cast(anti)->setAttributes( - orig->getAttributes()); - cast(anti)->setCallingConv( - orig->getCallingConv()); - cast(anti)->setDebugLoc( - getNewFromOriginal(I.getDebugLoc())); - - cast(anti)->addAttribute( - AttributeList::ReturnIndex, Attribute::NoAlias); - cast(anti)->addAttribute( - AttributeList::ReturnIndex, Attribute::NonNull); - return anti; - }; - - anti = applyChainRule(orig->getType(), NB, rule); - - if (auto MD = hasMetadata(orig, "enzyme_fromstack")) { - auto rule = [&](Value *anti) { - AllocaInst *replacement = NB.CreateAlloca( - Type::getInt8Ty(orig->getContext()), args[0]); - replacement->takeName(anti); - auto Alignment = - cast( - cast(MD->getOperand(0)) - ->getValue()) - ->getLimitedValue(); - replacement->setAlignment(Align(Alignment)); - replacement->setDebugLoc( - 
getNewFromOriginal(I.getDebugLoc())); - return replacement; - }; - - Value *replacement = applyChainRule( - Type::getInt8Ty(orig->getContext()), NB, rule, anti); - - replaceAWithB(cast(anti), replacement); - erase(cast(anti)); - anti = replacement; - } + anti = shadowHandlers[funcName](NB, orig, args, this); + } else { + auto rule = [&]() { + Value *anti = NB.CreateCall(orig->getFunctionType(), + orig->getCalledOperand(), args, + orig->getName() + "'mi"); + cast(anti)->setAttributes(orig->getAttributes()); + cast(anti)->setCallingConv(orig->getCallingConv()); + cast(anti)->setDebugLoc( + getNewFromOriginal(I.getDebugLoc())); + + cast(anti)->addAttribute(AttributeList::ReturnIndex, + Attribute::NoAlias); + cast(anti)->addAttribute(AttributeList::ReturnIndex, + Attribute::NonNull); + return anti; + }; + + anti = applyChainRule(orig->getType(), NB, rule); + + if (auto MD = hasMetadata(orig, "enzyme_fromstack")) { + auto rule = [&](Value *anti) { + AllocaInst *replacement = NB.CreateAlloca( + Type::getInt8Ty(orig->getContext()), args[0]); + replacement->takeName(anti); + auto Alignment = cast(cast( + MD->getOperand(0)) + ->getValue()) + ->getLimitedValue(); + replacement->setAlignment(Align(Alignment)); + replacement->setDebugLoc(getNewFromOriginal(I.getDebugLoc())); + return replacement; + }; - applyChainRule( - NB, - [&](Value *anti) { - zeroKnownAllocation(NB, anti, args, funcName, TLI, - orig); - }, - anti); - } - } else { - llvm_unreachable("Unknown shadow rematerialization value"); - } - assert(anti); - storeInstructionInCache(lctx, NB, anti, cache); + Value *replacement = applyChainRule( + Type::getInt8Ty(orig->getContext()), NB, rule, anti); + + replaceAWithB(cast(anti), replacement); + erase(cast(anti)); + anti = replacement; } + + applyChainRule( + NB, + [&](Value *anti) { + zeroKnownAllocation(NB, anti, args, funcName, TLI, orig); + }, + anti); } + } else { + llvm_unreachable("Unknown shadow rematerialization value"); + } + assert(anti); + 
storeInstructionInCache(lctx, NB, anti, cache); + } + } - llvm::SmallPtrSet origExitBlocks; - getExitBlocks(origLI, origExitBlocks); - // Remap a branch to the header to enter the incremented - // reverse of that block. - auto remap = [&](BasicBlock *rB) { - // Remap of an exit branch is to go to the reverse - // exiting block. - if (origExitBlocks.count(rB)) { - return reverseBlocks[getNewFromOriginal(B)].front(); - } - // Reverse of an incrementing branch is go to the - // reverse of the branching block. - if (rB == origLI->getHeader()) - return reverseBlocks[getNewFromOriginal(B)].front(); - auto found = origToNewForward.find(rB); - if (found == origToNewForward.end()) { - llvm::errs() << *newFunc << "\n"; - llvm::errs() << *origLI << "\n"; - llvm::errs() << *rB << "\n"; - } - assert(found != origToNewForward.end()); - return found->second; - }; + llvm::SmallPtrSet origExitBlocks; + getExitBlocks(origLI, origExitBlocks); + // Remap a branch to the header to enter the incremented + // reverse of that block. + auto remap = [&](BasicBlock *rB) { + // Remap of an exit branch is to go to the reverse + // exiting block. + if (origExitBlocks.count(rB)) { + return reverseBlocks[getNewFromOriginal(B)].front(); + } + // Reverse of an incrementing branch is go to the + // reverse of the branching block. 
+ if (rB == origLI->getHeader()) + return reverseBlocks[getNewFromOriginal(B)].front(); + auto found = origToNewForward.find(rB); + if (found == origToNewForward.end()) { + llvm::errs() << *newFunc << "\n"; + llvm::errs() << *origLI << "\n"; + llvm::errs() << *rB << "\n"; + } + assert(found != origToNewForward.end()); + return found->second; + }; - // TODO clone terminator - auto TI = B->getTerminator(); - assert(TI); - if (notForAnalysis.count(B)) { + // TODO clone terminator + auto TI = B->getTerminator(); + assert(TI); + if (notForAnalysis.count(B)) { + NB.CreateUnreachable(); + } else if (auto BI = dyn_cast(TI)) { + if (BI->isUnconditional()) { + if (notForAnalysis.count(BI->getSuccessor(0))) + NB.CreateUnreachable(); + else + NB.CreateBr(remap(BI->getSuccessor(0))); + } else { + if (notForAnalysis.count(BI->getSuccessor(0))) { + if (notForAnalysis.count(BI->getSuccessor(1))) { NB.CreateUnreachable(); - } else if (auto BI = dyn_cast(TI)) { - if (BI->isUnconditional()) { - if (notForAnalysis.count(BI->getSuccessor(0))) - NB.CreateUnreachable(); - else - NB.CreateBr(remap(BI->getSuccessor(0))); - } else { - if (notForAnalysis.count(BI->getSuccessor(0))) { - if (notForAnalysis.count(BI->getSuccessor(1))) { - NB.CreateUnreachable(); - } else { - NB.CreateBr(remap(BI->getSuccessor(1))); - } - } else if (notForAnalysis.count(BI->getSuccessor(1))) { - NB.CreateBr(remap(BI->getSuccessor(0))); - } else { - NB.CreateCondBr( - lookupM(getNewFromOriginal(BI->getCondition()), NB, - available), - remap(BI->getSuccessor(0)), remap(BI->getSuccessor(1))); - } - } - } else if (auto SI = dyn_cast(TI)) { - auto NSI = NB.CreateSwitch( - lookupM(getNewFromOriginal(SI->getCondition()), NB, - available), - remap(SI->getDefaultDest())); - for (auto cas : SI->cases()) { - if (!notForAnalysis.count(cas.getCaseSuccessor())) - NSI->addCase(cas.getCaseValue(), - remap(cas.getCaseSuccessor())); - } } else { - assert(isa(TI)); - NB.CreateUnreachable(); - } - // Fixup phi nodes that may have 
their predecessors now changed by - // the phi unwrapping - if (!notForAnalysis.count(B) && - NB.GetInsertBlock() != origToNewForward[B]) { - for (auto S0 : successors(B)) { - if (!origToNewForward.count(S0)) - continue; - auto S = origToNewForward[S0]; - assert(S); - for (auto I = S->begin(), E = S->end(); I != E; ++I) { - PHINode *orig = dyn_cast(&*I); - if (orig == nullptr) - break; - for (unsigned Op = 0, NumOps = orig->getNumOperands(); - Op != NumOps; ++Op) - if (orig->getIncomingBlock(Op) == origToNewForward[B]) - orig->setIncomingBlock(Op, NB.GetInsertBlock()); - } - } + NB.CreateBr(remap(BI->getSuccessor(1))); } + } else if (notForAnalysis.count(BI->getSuccessor(1))) { + NB.CreateBr(remap(BI->getSuccessor(0))); + } else { + NB.CreateCondBr( + lookupM(getNewFromOriginal(BI->getCondition()), NB, available), + remap(BI->getSuccessor(0)), remap(BI->getSuccessor(1))); + } + } + } else if (auto SI = dyn_cast(TI)) { + auto NSI = NB.CreateSwitch( + lookupM(getNewFromOriginal(SI->getCondition()), NB, available), + remap(SI->getDefaultDest())); + for (auto cas : SI->cases()) { + if (!notForAnalysis.count(cas.getCaseSuccessor())) + NSI->addCase(cas.getCaseValue(), remap(cas.getCaseSuccessor())); + } + } else { + assert(isa(TI)); + NB.CreateUnreachable(); + } + // Fixup phi nodes that may have their predecessors now changed by + // the phi unwrapping + if (!notForAnalysis.count(B) && + NB.GetInsertBlock() != origToNewForward[B]) { + for (auto S0 : successors(B)) { + if (!origToNewForward.count(S0)) + continue; + auto S = origToNewForward[S0]; + assert(S); + for (auto I = S->begin(), E = S->end(); I != E; ++I) { + PHINode *orig = dyn_cast(&*I); + if (orig == nullptr) + break; + for (unsigned Op = 0, NumOps = orig->getNumOperands(); Op != NumOps; + ++Op) + if (orig->getIncomingBlock(Op) == origToNewForward[B]) + orig->setIncomingBlock(Op, NB.GetInsertBlock()); } - resumeblock = enterB; } } + } + return enterB; + } + return nullptr; +} - if (incEntering) { - BasicBlock 
*incB = BasicBlock::Create( - BB->getContext(), - "inc" + reverseBlocks[lc.header].front()->getName(), - BB->getParent()); - incB->moveAfter(reverseBlocks[lc.header].back()); +/// Given an edge from BB to branchingBlock get the corresponding block to +/// branch to in the reverse pass +BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB, + BasicBlock *branchingBlock) { + assert(BB); + // BB should be a forward pass block, assert that + if (reverseBlocks.find(BB) == reverseBlocks.end()) { + llvm::errs() << *oldFunc << "\n"; + llvm::errs() << *newFunc << "\n"; + llvm::errs() << "BB: " << *BB << "\n"; + llvm::errs() << "branchingBlock: " << *branchingBlock << "\n"; + } + assert(reverseBlocks.find(BB) != reverseBlocks.end()); + assert(reverseBlocks.find(branchingBlock) != reverseBlocks.end()); + LoopContext lc; + bool inLoop = getContext(BB, lc); - IRBuilder<> tbuild(incB); + LoopContext branchingContext; + bool inLoopContext = getContext(branchingBlock, branchingContext); - Value *av = tbuild.CreateLoad(lc.var->getType(), lc.antivaralloc); - Value *sub = - tbuild.CreateAdd(av, ConstantInt::get(av->getType(), -1), "", - /*NUW*/ false, /*NSW*/ true); - tbuild.CreateStore(sub, lc.antivaralloc); - tbuild.CreateBr(resumeblock); - return newBlocksForLoop_cache[tup] = incB; - } else { - assert(exitEntering); - BasicBlock *incB = BasicBlock::Create( - BB->getContext(), - "merge" + reverseBlocks[lc.header].front()->getName() + "_" + - branchingBlock->getName(), - BB->getParent()); - incB->moveAfter(reverseBlocks[branchingBlock].back()); - - IRBuilder<> tbuild(reverseBlocks[branchingBlock].back()); - - Value *lim = nullptr; - if (lc.dynamic && assumeDynamicLoopOfSizeOne(L)) { - lim = ConstantInt::get(lc.var->getType(), 0); - } else if (lc.dynamic) { - // Must be in a reverse pass fashion for a lookup to index bound to be - // legal - assert(/*ReverseLimit*/ reverseBlocks.size() > 0); - LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, - lc.preheader); - 
lim = lookupValueFromCache( - lc.var->getType(), - /*forwardPass*/ false, tbuild, lctx, - getDynamicLoopLimit(LI.getLoopFor(lc.header)), - /*isi1*/ false, /*available*/ ValueToValueMapTy()); - } else { - lim = lookupM(lc.trueLimit, tbuild); - } + if (!inLoop) + return reverseBlocks[BB].front(); + + auto tup = std::make_tuple(BB, branchingBlock); + if (newBlocksForLoop_cache.find(tup) != newBlocksForLoop_cache.end()) + return newBlocksForLoop_cache[tup]; + + // If we're reversing a latch edge. + bool incEntering = inLoopContext && branchingBlock == lc.header && + lc.header == branchingContext.header; + + auto L = LI.getLoopFor(BB); + auto latches = getLatches(L, lc.exitBlocks); + // If we're reverseing a loop exit. + bool exitEntering = + std::find(latches.begin(), latches.end(), BB) != latches.end() && + std::find(lc.exitBlocks.begin(), lc.exitBlocks.end(), branchingBlock) != + lc.exitBlocks.end(); + + // It is illegal to be both an increment into a loop, and exiting the loop. + assert(!(incEntering && exitEntering)); + + // If we're re-entering a loop, prepare a loop-level forward pass to + // rematerialize any loop-scope rematerialization. 
+ + if (incEntering) { + BasicBlock *resumeblock = reverseBlocks[BB].front(); + auto tmp_resumeblock = prepRematerializedLoopEntry(lc); + if (tmp_resumeblock) + resumeblock = tmp_resumeblock; + BasicBlock *incB = BasicBlock::Create( + BB->getContext(), "inc" + reverseBlocks[lc.header].front()->getName(), + BB->getParent()); + incB->moveAfter(reverseBlocks[lc.header].back()); + + IRBuilder<> tbuild(incB); + + Value *av = tbuild.CreateLoad(lc.var->getType(), lc.antivaralloc); + Value *sub = tbuild.CreateAdd(av, ConstantInt::get(av->getType(), -1), "", + /*NUW*/ false, /*NSW*/ true); + tbuild.CreateStore(sub, lc.antivaralloc); + tbuild.CreateBr(resumeblock); + return newBlocksForLoop_cache[tup] = incB; + } + + if (exitEntering) { + SmallVector exitingContexts = {lc}; + + auto L2 = L; + while ((L2 = L2->getParentLoop())) { + LoopContext lc2; + bool inLoop = getContext(L2->getHeader(), lc2); + if (!inLoop) + break; + + auto latches2 = getLatches(L2, lc2.exitBlocks); + + // If we're reverseing a loop exit. 
+ bool exitEntering2 = + std::find(latches2.begin(), latches2.end(), BB) != latches2.end() && + std::find(lc2.exitBlocks.begin(), lc2.exitBlocks.end(), + branchingBlock) != lc2.exitBlocks.end(); + if (exitEntering2) { + exitingContexts.push_back(lc2); + } else + break; + } - tbuild.SetInsertPoint(incB); - tbuild.CreateStore(lim, lc.antivaralloc); - tbuild.CreateBr(resumeblock); + BasicBlock *resumeblock = reverseBlocks[BB].front(); + BasicBlock *prevBlock = reverseBlocks[branchingBlock].back(); - return newBlocksForLoop_cache[tup] = incB; + BasicBlock *outerMerge = nullptr; + + BasicBlock *incB = BasicBlock::Create( + BB->getContext(), + "merge" + reverseBlocks[lc.header].front()->getName() + "_" + + branchingBlock->getName(), + BB->getParent()); + if (!outerMerge) + outerMerge = incB; + incB->moveAfter(prevBlock); + + IRBuilder<> tbuild(prevBlock); + + SmallVector, 1> lims; + for (auto I = exitingContexts.rbegin(), E = exitingContexts.rend(); I != E; + I++) { + auto &lc = *I; + auto L = LI.getLoopFor(lc.header); + Value *lim = nullptr; + if (lc.dynamic && assumeDynamicLoopOfSizeOne(L)) { + lim = ConstantInt::get(lc.var->getType(), 0); + } else if (lc.dynamic) { + // Must be in a reverse pass fashion for a lookup to index bound to be + // legal + assert(/*ReverseLimit*/ reverseBlocks.size() > 0); + LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, + lc.preheader); + lim = lookupValueFromCache( + lc.var->getType(), + /*forwardPass*/ false, tbuild, lctx, getDynamicLoopLimit(L), + /*isi1*/ false, /*available*/ ValueToValueMapTy()); + } else { + lim = lookupM(lc.trueLimit, tbuild); } + lims.push_back(std::make_pair(lim, (Value *)lc.antivaralloc)); + } + + tbuild.SetInsertPoint(incB); + for (auto &pair : lims) { + tbuild.CreateStore(pair.first, pair.second); } + + auto tmp_resumeblock = prepRematerializedLoopEntry(exitingContexts.back()); + if (tmp_resumeblock) + resumeblock = tmp_resumeblock; + + tbuild.CreateBr(resumeblock); + + return 
newBlocksForLoop_cache[tup] = incB; } return newBlocksForLoop_cache[tup] = reverseBlocks[BB].front(); diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index ea5407e6044e..a87f6da9dd69 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -417,11 +417,20 @@ class GradientUtils : public CacheUtility { newBlocksForLoop_cache; //! This cache stores a rematerialized forward pass in the loop - //! specified - std::map rematerializedLoops_cache; + //! specified. The key is the loop header. + std::map + rematerializedLoops_cache; llvm::BasicBlock *getReverseOrLatchMerge(llvm::BasicBlock *BB, llvm::BasicBlock *branchingBlock); +private: + //! Given a loop `lc`, create the rematerialization blocks for the reverse + //! pass, if required, caching if already created. This function will return + //! the new block for the rematerialized loop entry to branch to, if created. + //! Otherwise it will return nullptr. + llvm::BasicBlock *prepRematerializedLoopEntry(LoopContext &lc); + +public: void forceContexts(); void computeMinCache(); diff --git a/enzyme/test/Enzyme/ReverseMode/multiloopexit.ll b/enzyme/test/Enzyme/ReverseMode/multiloopexit.ll new file mode 100644 index 000000000000..463247f0bc7d --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/multiloopexit.ll @@ -0,0 +1,212 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +declare i1 @end() + +define float @todiff(float %a0, i64 %a1) { +entry: + br label %loop1 + +loop1: ; preds = %L19.i, %entry + %sum.0 = phi float [ 0.000000e+00, %entry ], [ %sum.1, %floop1 ] + %sum.1 = fadd float %sum.0, %a0 + %end1 = call i1 @end() + br i1 %end1, label %loop2, label %floop1 + +loop2: ; preds = %L2.i, %pass.i + %sum.2 = phi float [ %sum.3, %pass.i ], [ %sum.1, %loop1 ] + %end2 = call i1 @end() + br i1 %end2, label %exit, label %pass.i 
+ +pass.i: ; preds = %L9.i + %sum.3 = fadd float %sum.2, %a0 + %end3 = call i1 @end() + br i1 %end3, label %loop2, label %floop1 + +floop1: ; preds = %pass.i, %L2.i + br label %loop1 + +exit: ; preds = %L9.i + ret float %sum.2 +} + +declare float @__enzyme_autodiff(...) + +define float @c() { + %c = call float (...) @__enzyme_autodiff(float (float, i64)* @todiff, float 1.0, i64 1) + ret float %c +} + +; CHECK: define internal { float } @diffetodiff(float %a0, i64 %a1, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"iv'ac" = alloca i64, align 8 +; CHECK-NEXT: %loopLimit_cache = alloca i64, align 8 +; CHECK-NEXT: %"iv1'ac" = alloca i64, align 8 +; CHECK-NEXT: %loopLimit_cache2 = alloca i64*, align 8 +; CHECK-NEXT: %"sum.2'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.2'de", align 4 +; CHECK-NEXT: %"a0'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"a0'de", align 4 +; CHECK-NEXT: %"sum.1'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.1'de", align 4 +; CHECK-NEXT: %"sum.0'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.0'de", align 4 +; CHECK-NEXT: %"sum.3'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.3'de", align 4 +; CHECK-NEXT: %end1_cache = alloca i1*, align 8 +; CHECK-NEXT: store i64* null, i64** %loopLimit_cache2, align 8 +; CHECK-NEXT: store i1* null, i1** %end1_cache, align 8 +; CHECK-NEXT: br label %loop1 + +; CHECK: loop1: ; preds = %floop1, %entry +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %floop1 ], [ 0, %entry ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 + +; CHECK: %end1 = call i1 @end() +; CHECK-NEXT: %[[i32:.+]] = load i1*, i1** %end1_cache, align 8 +; CHECK-NEXT: %[[i33:.+]] = getelementptr inbounds i1, i1* %[[i32]], i64 %iv +; CHECK-NEXT: store i1 %end1, i1* %[[i33]], align 1 +; CHECK-NEXT: br i1 %end1, label %loop2.preheader, label %floop1 + +; CHECK: loop2.preheader: 
+; CHECK-NEXT: br label %loop2 + +; CHECK: loop2: +; CHECK-NEXT: %iv1 = phi i64 [ 0, %loop2.preheader ], [ %iv.next2, %pass.i ] +; CHECK-NEXT: %iv.next2 = add nuw nsw i64 %iv1, 1 +; CHECK-NEXT: %end2 = call i1 @end() +; CHECK-NEXT: br i1 %end2, label %exit, label %pass.i + +; CHECK: pass.i: ; preds = %loop2 +; CHECK-NEXT: %end3 = call i1 @end() +; CHECK-NEXT: br i1 %end3, label %loop2, label %floop1.loopexit + +; CHECK: floop1.loopexit: +; CHECK-NEXT: %[[i34:.+]] = phi i64 [ %iv1, %pass.i ] +; CHECK-NEXT: %[[i35:.+]] = load i64*, i64** %loopLimit_cache2, align 8 +; CHECK-NEXT: %[[i36:.+]] = getelementptr inbounds i64, i64* %[[i35]], i64 %iv +; CHECK-NEXT: store i64 %[[i34]], i64* %[[i36]], align 8 +; CHECK-NEXT: br label %floop1 + +; CHECK: floop1: ; preds = %floop1.loopexit, %__enzyme_exponentialallocation.exit15 +; CHECK-NEXT: br label %loop1 + +; CHECK: exit: ; preds = %loop2 +; CHECK-NEXT: %[[i37:.+]] = phi i64 [ %iv1, %loop2 ] +; CHECK-NEXT: %[[i38:.+]] = phi i64 [ %iv, %loop2 ] +; CHECK-NEXT: %[[i39:.+]] = load i64*, i64** %loopLimit_cache2, align 8 +; CHECK-NEXT: %[[i40:.+]] = getelementptr inbounds i64, i64* %[[i39]], i64 %iv +; CHECK-NEXT: store i64 %[[i37]], i64* %[[i40]], align 8 +; CHECK-NEXT: store i64 %[[i38]], i64* %loopLimit_cache, align 8 +; CHECK-NEXT: br label %invertexit + +; CHECK: invertentry: ; preds = %invertloop1 +; CHECK-NEXT: %[[i41:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %forfree = load i64*, i64** %loopLimit_cache2, align 8 +; CHECK-NEXT: %[[i42:.+]] = bitcast i64* %forfree to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[i42]]) +; CHECK-NEXT: %[[i43:.+]] = load float, float* %"a0'de", align 4 +; CHECK-NEXT: %[[i44:.+]] = insertvalue { float } undef, float %[[i43]], 0 +; CHECK-NEXT: %[[i45:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %forfree10 = load i1*, i1** %end1_cache, align 8 +; CHECK-NEXT: %[[i46:.+]] = bitcast i1* %forfree10 to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[i46]]) +; 
CHECK-NEXT: ret { float } %[[i44]] + +; CHECK: invertloop1: ; preds = %invertfloop1, %invertloop2.preheader +; CHECK-NEXT: %[[i47:.+]] = load float, float* %"sum.1'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.1'de", align 4 +; CHECK-NEXT: %[[i48:.+]] = load float, float* %"sum.0'de", align 4 +; CHECK-NEXT: %[[i49:.+]] = fadd fast float %[[i48]], %[[i47]] +; CHECK-NEXT: store float %[[i49]], float* %"sum.0'de", align 4 +; CHECK-NEXT: %[[i50:.+]] = load float, float* %"a0'de", align 4 +; CHECK-NEXT: %[[i51:.+]] = fadd fast float %[[i50]], %[[i47]] +; CHECK-NEXT: store float %[[i51]], float* %"a0'de", align 4 +; CHECK-NEXT: %[[i52:.+]] = load float, float* %"sum.0'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.0'de", align 4 +; CHECK-NEXT: %[[i53:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %[[i54:.+]] = icmp eq i64 %[[i53]], 0 +; CHECK-NEXT: %[[i55:.+]] = xor i1 %[[i54]], true +; CHECK-NEXT: %[[i56:.+]] = select fast i1 %[[i55]], float %[[i52]], float 0.000000e+00 +; CHECK-NEXT: %[[i57:.+]] = load float, float* %"sum.1'de", align 4 +; CHECK-NEXT: %[[i58:.+]] = fadd fast float %[[i57]], %[[i52]] +; CHECK-NEXT: %[[i59:.+]] = select fast i1 %[[i54]], float %[[i57]], float %[[i58]] +; CHECK-NEXT: store float %[[i59]], float* %"sum.1'de", align 4 +; CHECK-NEXT: br i1 %[[i54]], label %invertentry, label %incinvertloop1 + +; CHECK: incinvertloop1: ; preds = %invertloop1 +; CHECK-NEXT: %[[i60:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %[[i61:.+]] = add nsw i64 %[[i60]], -1 +; CHECK-NEXT: store i64 %[[i61]], i64* %"iv'ac", align 4 +; CHECK-NEXT: br label %invertfloop1 + +; CHECK: invertloop2.preheader: ; preds = %invertloop2 +; CHECK-NEXT: br label %invertloop1 + +; CHECK: invertloop2: ; preds = %mergeinvertloop2_exit, %invertpass.i +; CHECK-NEXT: %[[i62:.+]] = load float, float* %"sum.2'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.2'de", align 4 +; CHECK-NEXT: %[[i63:.+]] = load i64, i64* 
%"iv1'ac", align 4 +; CHECK-NEXT: %[[i64:.+]] = icmp eq i64 %[[i63]], 0 +; CHECK-NEXT: %[[i65:.+]] = xor i1 %[[i64]], true +; CHECK-NEXT: %[[i66:.+]] = select fast i1 %[[i64]], float %[[i62]], float 0.000000e+00 +; CHECK-NEXT: %[[i67:.+]] = load float, float* %"sum.1'de", align 4 +; CHECK-NEXT: %[[i68:.+]] = fadd fast float %[[i67]], %[[i62]] +; CHECK-NEXT: %[[i69:.+]] = select fast i1 %[[i64]], float %[[i68]], float %[[i67]] +; CHECK-NEXT: store float %[[i69]], float* %"sum.1'de", align 4 +; CHECK-NEXT: %[[i70:.+]] = select fast i1 %[[i65]], float %[[i62]], float 0.000000e+00 +; CHECK-NEXT: %[[i71:.+]] = load float, float* %"sum.3'de", align 4 +; CHECK-NEXT: %[[i72:.+]] = fadd fast float %[[i71]], %[[i62]] +; CHECK-NEXT: %[[i73:.+]] = select fast i1 %[[i64]], float %[[i71]], float %[[i72]] +; CHECK-NEXT: store float %[[i73]], float* %"sum.3'de", align 4 +; CHECK-NEXT: br i1 %[[i64]], label %invertloop2.preheader, label %incinvertloop2 + +; CHECK: incinvertloop2: ; preds = %invertloop2 +; CHECK-NEXT: %[[i74:.+]] = load i64, i64* %"iv1'ac", align 4 +; CHECK-NEXT: %[[i75:.+]] = add nsw i64 %[[i74]], -1 +; CHECK-NEXT: store i64 %[[i75]], i64* %"iv1'ac", align 4 +; CHECK-NEXT: br label %invertpass.i + +; CHECK: invertpass.i: ; preds = %mergeinvertloop2_floop1.loopexit, %incinvertloop2 +; CHECK-NEXT: %[[i76:.+]] = load float, float* %"sum.3'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"sum.3'de", align 4 +; CHECK-NEXT: %[[i77:.+]] = load float, float* %"sum.2'de", align 4 +; CHECK-NEXT: %[[i78:.+]] = fadd fast float %[[i77]], %[[i76]] +; CHECK-NEXT: store float %[[i78]], float* %"sum.2'de", align 4 +; CHECK-NEXT: %[[i79:.+]] = load float, float* %"a0'de", align 4 +; CHECK-NEXT: %[[i80:.+]] = fadd fast float %[[i79]], %[[i76]] +; CHECK-NEXT: store float %[[i80]], float* %"a0'de", align 4 +; CHECK-NEXT: br label %invertloop2 + +; CHECK: invertfloop1.loopexit: ; preds = %invertfloop1 +; CHECK-NEXT: %[[i81:.+]] = load i64*, i64** %loopLimit_cache2, align 8 
+; CHECK-NEXT: %[[i82:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %[[i83:.+]] = getelementptr inbounds i64, i64* %[[i81]], i64 %[[i82]] +; CHECK-NEXT: %[[i84:.+]] = load i64, i64* %[[i83]], align 8, !invariant.group ! +; CHECK-NEXT: br label %mergeinvertloop2_floop1.loopexit + +; CHECK: mergeinvertloop2_floop1.loopexit: ; preds = %invertfloop1.loopexit +; CHECK-NEXT: store i64 %[[i84]], i64* %"iv1'ac", align 4 +; CHECK-NEXT: br label %invertpass.i + +; CHECK: invertfloop1: ; preds = %incinvertloop1 +; CHECK-NEXT: %[[i85:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %[[i86:.+]] = load i1*, i1** %end1_cache, align 8 +; CHECK-NEXT: %[[i87:.+]] = getelementptr inbounds i1, i1* %[[i86]], i64 %[[i85]] +; CHECK-NEXT: %[[i88:.+]] = load i1, i1* %[[i87]], align 1, !invariant.group ! +; CHECK-NEXT: br i1 %[[i88]], label %invertfloop1.loopexit, label %invertloop1 + +; CHECK: invertexit: ; preds = %exit +; CHECK-NEXT: store float %differeturn, float* %"sum.2'de", align 4 +; CHECK-NEXT: %[[i89:.+]] = load i64, i64* %loopLimit_cache, align 8 +; CHECK-NEXT: %[[i90:.+]] = load i64*, i64** %loopLimit_cache2, align 8 +; CHECK-NEXT: %[[i91:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %[[i92:.+]] = getelementptr inbounds i64, i64* %[[i90]], i64 %[[i91]] +; CHECK-NEXT: %[[i93:.+]] = load i64, i64* %[[i92]], align 8 +; CHECK-NEXT: br label %mergeinvertloop2_exit + +; CHECK: mergeinvertloop2_exit: ; preds = %invertexit +; CHECK-NEXT: store i64 %[[i89]], i64* %"iv'ac", align 4 +; CHECK-NEXT: store i64 %[[i93]], i64* %"iv1'ac", align 4 +; CHECK-NEXT: br label %invertloop2 +; CHECK-NEXT: } From 4c6c0a61e40b20e91a4ce1d33de9a675f4b022f0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 00:03:32 -0500 Subject: [PATCH 06/29] Do simple mem2reg for autodiff fn detection (#1431) --- enzyme/Enzyme/Utils.cpp | 60 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp 
index 7efd06a8bd37..e2d896893537 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2271,6 +2271,66 @@ Function *GetFunctionFromValue(Value *fn) { } } } + if (auto LI = dyn_cast(fn)) { + auto obj = getBaseObject(LI->getPointerOperand()); + if (isa(obj)) { + std::set> done; + SmallVector, 1> todo; + Value *stored = nullptr; + bool legal = true; + for (auto U : obj->users()) { + if (auto I = dyn_cast(U)) + todo.push_back(std::make_pair(I, obj)); + else { + legal = false; + break; + } + } + while (legal && todo.size()) { + auto tup = todo.pop_back_val(); + if (done.count(tup)) + continue; + done.insert(tup); + auto cur = tup.first; + auto prev = tup.second; + if (auto SI = dyn_cast(cur)) + if (SI->getPointerOperand() == prev) { + if (stored == SI->getValueOperand()) + continue; + else if (stored == nullptr) { + stored = SI->getValueOperand(); + continue; + } else { + legal = false; + break; + } + } + + if (isPointerArithmeticInst(cur, /*includephi*/ true)) { + for (auto U : cur->users()) { + if (auto I = dyn_cast(U)) + todo.push_back(std::make_pair(I, cur)); + else { + legal = false; + break; + } + } + continue; + } + + if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) + continue; + + legal = false; + break; + } + + if (legal && stored) { + fn = stored; + continue; + } + } + } break; } From a7302cedfcbc677d3f3cdce9e088a451b152d1f3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 01:41:03 -0500 Subject: [PATCH 07/29] Special case gep of small_typeof (#1433) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index bbf63b9a4c90..2a6c1426aa97 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -1401,6 +1401,16 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { return; } } + if (auto GV = 
dyn_cast(gep.getPointerOperand())) { + // from julia code, do not propagate int to operands + if (GV->getName() == "small_typeof") { + TypeTree T; + T.insert({-1}, BaseType::Pointer); + T.insert({-1, -1}, BaseType::Pointer); + updateAnalysis(&gep, T, &gep); + return; + } + } if (gep.indices().begin() == gep.indices().end()) { if (direction & DOWN) From c19fc9c366c1eebfc634e65c1394791c09965266 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 13:04:00 -0500 Subject: [PATCH 08/29] Speed up and fix type analysis merges (#1432) * Speed up and fix type analysis merges * Fix analysis insertion of anything --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 119 ++------ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h | 2 + enzyme/Enzyme/TypeAnalysis/TypeTree.h | 263 ++++++++++-------- enzyme/test/Enzyme/ReverseMode/infanalysis.ll | 22 ++ 4 files changed, 206 insertions(+), 200 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/infanalysis.ll diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 2a6c1426aa97..ec02c6f22146 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -436,34 +436,8 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA, analysis[Val] = analysis[CE->getOperand(0)]; return; } - if (CE->getOpcode() == Instruction::GetElementPtr && - llvm::all_of(CE->operand_values(), - [](Value *v) { return isa(v); })) { - auto g2 = cast(CE->getAsInstruction()); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - delete g2; - - int off = (int)ai.getLimitedValue(); - - // TODO also allow negative offsets - if (off < 0) { - analysis[Val] = TypeTree(BaseType::Pointer).Only(-1, nullptr); - return; - } - - getConstantAnalysis(CE->getOperand(0), TA, analysis); - auto gepData0 = 
analysis[CE->getOperand(0)].Data0(); - - TypeTree result = - gepData0 - .ShiftIndices(DL, /*init offset*/ off, /*max size*/ -1, - /*new offset*/ 0) - .Only(-1, nullptr); - result.insert({-1}, BaseType::Pointer); - analysis[Val] = result; + if (CE->getOpcode() == Instruction::GetElementPtr) { + TA.visitGEPOperator(*cast(CE)); return; } @@ -678,8 +652,11 @@ void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) { // Attempt to update the underlying analysis bool LegalOr = true; - if (analysis.find(Val) == analysis.end() && isa(Val)) - getConstantAnalysis(cast(Val), *this, analysis); + if (analysis.find(Val) == analysis.end() && isa(Val)) { + if (!isa(Val) || + cast(Val)->getOpcode() != Instruction::GetElementPtr) + getConstantAnalysis(cast(Val), *this, analysis); + } TypeTree prev = analysis[Val]; @@ -1200,53 +1177,8 @@ void TypeAnalyzer::visitConstantExpr(ConstantExpr &CE) { updateAnalysis(CE.getOperand(0), getAnalysis(&CE), &CE); return; } - if (CE.getOpcode() == Instruction::GetElementPtr && - llvm::all_of(CE.operand_values(), - [](Value *v) { return isa(v); })) { - - auto &DL = fntypeinfo.Function->getParent()->getDataLayout(); - auto g2 = cast(CE.getAsInstruction()); - APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0); - g2->accumulateConstantOffset(DL, ai); - // Using destructor rather than eraseFromParent - // as g2 has no parent - - int maxSize = -1; - if (cast(CE.getOperand(1))->getLimitedValue() == 0) { - maxSize = DL.getTypeAllocSizeInBits(g2->getResultElementType()) / 8; - } - - delete g2; - - int off = (int)ai.getLimitedValue(); - - // TODO also allow negative offsets - if (off < 0) { - if (direction & DOWN) - updateAnalysis(&CE, TypeTree(BaseType::Pointer).Only(-1, nullptr), &CE); - if (direction & UP) - updateAnalysis(CE.getOperand(0), - TypeTree(BaseType::Pointer).Only(-1, nullptr), &CE); - return; - } - - if (direction & DOWN) { - auto gepData0 = getAnalysis(CE.getOperand(0)).Data0(); - TypeTree result = - 
gepData0.ShiftIndices(DL, /*init offset*/ off, - /*max size*/ maxSize, /*newoffset*/ 0); - result.insert({}, BaseType::Pointer); - updateAnalysis(&CE, result.Only(-1, nullptr), &CE); - } - if (direction & UP) { - auto pointerData0 = getAnalysis(&CE).Data0(); - - TypeTree result = - pointerData0.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1, - /*new offset*/ off); - result.insert({}, BaseType::Pointer); - updateAnalysis(CE.getOperand(0), result.Only(-1, nullptr), &CE); - } + if (CE.getOpcode() == Instruction::GetElementPtr) { + visitGEPOperator(*cast(&CE)); return; } auto I = CE.getAsInstruction(); @@ -1375,14 +1307,20 @@ std::set> getSet(ArrayRef> todo, size_t idx) { } void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { + visitGEPOperator(*cast(&gep)); +} + +void TypeAnalyzer::visitGEPOperator(GEPOperator &gep) { + auto inst = dyn_cast(&gep); if (isa(gep.getPointerOperand())) { - updateAnalysis(&gep, TypeTree(BaseType::Anything).Only(-1, &gep), &gep); + updateAnalysis(&gep, TypeTree(BaseType::Anything).Only(-1, inst), &gep); return; } if (isa(gep.getPointerOperand())) { bool nonZero = false; bool legal = true; - for (auto &ind : gep.indices()) { + for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) { + auto ind = I->get(); if (auto CI = dyn_cast(ind)) { if (!CI->isZero()) { nonZero = true; @@ -1397,7 +1335,7 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { break; } if (legal && nonZero) { - updateAnalysis(&gep, TypeTree(BaseType::Integer).Only(-1, &gep), &gep); + updateAnalysis(&gep, TypeTree(BaseType::Integer).Only(-1, inst), &gep); return; } } @@ -1412,7 +1350,7 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { } } - if (gep.indices().begin() == gep.indices().end()) { + if (gep.idx_begin() == gep.idx_end()) { if (direction & DOWN) updateAnalysis(&gep, getAnalysis(gep.getPointerOperand()), &gep); if (direction & UP) @@ -1438,8 +1376,9 @@ void 
TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { if (gep.isInBounds() || (!EnzymeStrictAliasing && pointerAnalysis.Inner0() == BaseType::Pointer && getAnalysis(&gep).Inner0() == BaseType::Pointer)) { - for (auto &ind : gep.indices()) { - updateAnalysis(ind, TypeTree(BaseType::Integer).Only(-1, &gep), &gep); + for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) { + auto ind = I->get(); + updateAnalysis(ind, TypeTree(BaseType::Integer).Only(-1, inst), &gep); } } } @@ -1449,7 +1388,8 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { bool pointerPropagate = gep.isInBounds(); if (!pointerPropagate) { bool allIntegral = true; - for (auto &ind : gep.indices()) { + for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) { + auto ind = I->get(); auto CT = getAnalysis(ind).Inner0(); if (CT != BaseType::Integer && CT != BaseType::Anything) { allIntegral = false; @@ -1478,16 +1418,17 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { } } updateAnalysis(&gep, keepMinus, &gep); - updateAnalysis(&gep, TypeTree(pointerAnalysis.Inner0()).Only(-1, &gep), + updateAnalysis(&gep, TypeTree(pointerAnalysis.Inner0()).Only(-1, inst), &gep); } if (direction & UP) updateAnalysis(gep.getPointerOperand(), - TypeTree(getAnalysis(&gep).Inner0()).Only(-1, &gep), &gep); + TypeTree(getAnalysis(&gep).Inner0()).Only(-1, inst), &gep); SmallVector, 4> idnext; - for (auto &a : gep.indices()) { + for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) { + auto a = I->get(); auto iset = fntypeinfo.knownIntegralValues(a, DT, intseen, SE); std::set vset; for (auto i : iset) { @@ -1556,9 +1497,9 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { seenIdx = true; } if (direction & DOWN) - updateAnalysis(&gep, downTree.Only(-1, &gep), &gep); + updateAnalysis(&gep, downTree.Only(-1, inst), &gep); if (direction & UP) - updateAnalysis(gep.getPointerOperand(), upTree.Only(-1, &gep), &gep); + 
updateAnalysis(gep.getPointerOperand(), upTree.Only(-1, inst), &gep); } void TypeAnalyzer::visitPHINode(PHINode &phi) { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h index a299f83921ba..18765487e6d5 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h @@ -290,6 +290,8 @@ class TypeAnalyzer : public llvm::InstVisitor { void visitGetElementPtrInst(llvm::GetElementPtrInst &gep); + void visitGEPOperator(llvm::GEPOperator &gep); + void visitPHINode(llvm::PHINode &phi); void visitTruncInst(llvm::TruncInst &I); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index f0c8fe691009..bf31de9f8b9f 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -146,7 +146,7 @@ class TypeTree : public std::enable_shared_from_this { /// Return if changed bool insert(const std::vector Seq, ConcreteType CT, - bool intsAreLegalSubPointer = false) { + bool PointerIntSame = false) { size_t SeqSize = Seq.size(); if (SeqSize > EnzymeMaxTypeDepth) { if (EnzymeTypeWarning) { @@ -188,83 +188,98 @@ class TypeTree : public std::enable_shared_from_this { } bool changed = false; - - // if this is a ending -1, remove other elems if no more info - for (size_t suffixSize = 1; suffixSize <= SeqSize; suffixSize++) { - if (Seq[SeqSize - suffixSize] != -1) - break; + // Check if there is an existing match, e.g. [-1, -1, -1] and inserting + // [-1, 8, -1] + { std::set> toremove; for (const auto &pair : mapping) { - if (pair.first.size() != SeqSize) - continue; - bool matches = true; - for (unsigned i = 0; i < SeqSize - suffixSize; ++i) { - if (pair.first[i] != Seq[i]) { - matches = false; - break; + if (pair.first.size() == SeqSize) { + // Whether the the inserted val (e.g. [-1, 0] or [0, 0]) is at least + // as general as the existing map val (e.g. [0, 0]). 
+ bool newMoreGeneralThanOld = true; + // Whether the existing val (e.g. [-1, 0] or [0, 0]) is at least + // as general as the inserted map val (e.g. [0, 0]). + bool oldMoreGeneralThanNew = true; + for (unsigned i = 0; i < SeqSize; i++) { + if (pair.first[i] == Seq[i]) + continue; + if (Seq[i] == -1) { + oldMoreGeneralThanNew = false; + } else if (pair.first[i] == -1) { + newMoreGeneralThanOld = false; + } else { + oldMoreGeneralThanNew = false; + newMoreGeneralThanOld = false; + break; + } } - } - if (!matches) - continue; - if (intsAreLegalSubPointer && pair.second == BaseType::Integer && - CT == BaseType::Pointer) { - toremove.insert(pair.first); - } else { - if (CT == pair.second) { - // previous equivalent values or values overwritten by - // an anything are removed - toremove.insert(pair.first); - } else if (pair.second != BaseType::Anything) { + if (oldMoreGeneralThanNew) { + // Inserting an existing or less general version + if (CT == pair.second) + return false; + + // Inserting an existing or less general version (with pointer-int + // equivalence) + if (PointerIntSame) + if ((CT == BaseType::Pointer && + pair.second == BaseType::Integer) || + (CT == BaseType::Integer && pair.second == BaseType::Pointer)) + return false; + + // Inserting into an anything. Since from above we know this is not + // an anything, the inserted value contains no new information + if (pair.second == BaseType::Anything) + return false; + + // Inserting say a [0]:anything into a [-1]:Float + if (CT == BaseType::Anything) + continue; + + // Otherwise, inserting a non-equivalent pair into a more general + // slot. This is invalid. llvm::errs() << "inserting into : " << str() << " with " << to_string(Seq) << " of " << CT.str() << "\n"; llvm_unreachable("illegal insertion"); - } - } - } + } else if (newMoreGeneralThanOld) { + // This new is strictly more general than the old. If they were + // equivalent, the case above would have been hit. 
- for (const auto &val : toremove) { - mapping.erase(val); - changed = true; - } - } + if (CT == BaseType::Anything || CT == pair.second) { + // previous equivalent values or values overwritten by + // an anything are removed + toremove.insert(pair.first); + continue; + } - // if this is a starting -1, remove other -1's - for (size_t prefixSize = 1; prefixSize <= SeqSize; prefixSize++) { - if (Seq[prefixSize - 1] != -1) - break; - std::set> toremove; - for (const auto &pair : mapping) { - if (pair.first.size() != SeqSize) - continue; - bool matches = true; - for (unsigned i = prefixSize; i < SeqSize; ++i) { - if (pair.first[i] != Seq[i]) { - matches = false; - break; - } - } - if (!matches) - continue; - if (intsAreLegalSubPointer && pair.second == BaseType::Integer && - CT == BaseType::Pointer) { - toremove.insert(pair.first); - } else { - if (CT == pair.second) { - // previous equivalent values or values overwritten by - // an anything are removed - toremove.insert(pair.first); - } else if (pair.second != BaseType::Anything) { + // Inserting an existing or less general version (with pointer-int + // equivalence) + if (PointerIntSame) + if ((CT == BaseType::Pointer && + pair.second == BaseType::Integer) || + (CT == BaseType::Integer && + pair.second == BaseType::Pointer)) { + toremove.insert(pair.first); + continue; + } + + // Keep lingering anythings if not being overwritten, even if this + // (e.g. Float) applies to more locations. Therefore it is legal to + // have [-1]:Float, [8]:Anything + if (CT != BaseType::Anything && pair.second == BaseType::Anything) + continue; + + // Otherwise, inserting a more general non-equivalent pair. This is + // invalid. 
llvm::errs() << "inserting into : " << str() << " with " << to_string(Seq) << " of " << CT.str() << "\n"; llvm_unreachable("illegal insertion"); } } } - for (const auto &val : toremove) { - mapping.erase(val); changed = true; + mapping.erase(val); } } @@ -920,72 +935,98 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> { } } - // if this is a ending -1, remove other elems if no more info - for (size_t suffixSize = 1; suffixSize <= SeqSize; suffixSize++) { - if (Seq[SeqSize - suffixSize] != -1) - break; + // Check if there is an existing match, e.g. [-1, -1, -1] and inserting + // [-1, 8, -1] + { std::set<std::vector<int>> toremove; for (const auto &pair : mapping) { if (pair.first.size() == SeqSize) { - bool matches = true; - for (unsigned i = 0; i < SeqSize - suffixSize; ++i) { - if (pair.first[i] != Seq[i]) { - matches = false; + // Whether the inserted val (e.g. [-1, 0] or [0, 0]) is at least + // as general as the existing map val (e.g. [0, 0]). + bool newMoreGeneralThanOld = true; + // Whether the existing val (e.g. [-1, 0] or [0, 0]) is at least + // as general as the inserted map val (e.g. [0, 0]). 
+ bool oldMoreGeneralThanNew = true; + for (unsigned i = 0; i < SeqSize; i++) { + if (pair.first[i] == Seq[i]) + continue; + if (Seq[i] == -1) { + oldMoreGeneralThanNew = false; + } else if (pair.first[i] == -1) { + newMoreGeneralThanOld = false; + } else { + oldMoreGeneralThanNew = false; + newMoreGeneralThanOld = false; break; } } - if (!matches) - continue; - if (CT == BaseType::Anything || CT == pair.second) { - // previous equivalent values or values overwritten by - // an anything are removed - toremove.insert(pair.first); - } else if (CT != BaseType::Anything && - pair.second == BaseType::Anything) { - // keep lingering anythings if not being overwritten - } else { + if (oldMoreGeneralThanNew) { + // Inserting an existing or less general version + if (CT == pair.second) + return false; + + // Inserting an existing or less general version (with pointer-int + // equivalence) + if (PointerIntSame) + if ((CT == BaseType::Pointer && + pair.second == BaseType::Integer) || + (CT == BaseType::Integer && + pair.second == BaseType::Pointer)) + return false; + + // Inserting into an anything. Since from above we know this is + // not an anything, the inserted value contains no new information + if (pair.second == BaseType::Anything) + return false; + + // Inserting say a [0]:anything into a [-1]:Float + if (CT == BaseType::Anything) { + // If both at same index, remove old index + if (newMoreGeneralThanOld) + toremove.insert(pair.first); + continue; + } + + // Otherwise, inserting a non-equivalent pair into a more general + // slot. This is invalid. 
LegalOr = false; return false; - } - } - } - for (const auto &val : toremove) { - mapping.erase(val); - } - } - - // if this is a starting -1, remove other -1's - for (size_t prefixSize = 1; prefixSize <= SeqSize; prefixSize++) { - if (Seq[prefixSize - 1] != -1) - break; - std::set> toremove; - for (const auto &pair : mapping) { - if (pair.first.size() == SeqSize) { - bool matches = true; - for (unsigned i = prefixSize; i < SeqSize; ++i) { - if (pair.first[i] != Seq[i]) { - matches = false; - break; + } else if (newMoreGeneralThanOld) { + // This new is strictly more general than the old. If they were + // equivalent, the case above would have been hit. + + if (CT == BaseType::Anything || CT == pair.second) { + // previous equivalent values or values overwritten by + // an anything are removed + toremove.insert(pair.first); + continue; } - } - if (!matches) - continue; - if (CT == BaseType::Anything || CT == pair.second) { - // previous equivalent values or values overwritten by - // an anything are removed - toremove.insert(pair.first); - } else if (CT != BaseType::Anything && - pair.second == BaseType::Anything) { - // keep lingering anythings if not being overwritten - } else { + // Inserting an existing or less general version (with pointer-int + // equivalence) + if (PointerIntSame) + if ((CT == BaseType::Pointer && + pair.second == BaseType::Integer) || + (CT == BaseType::Integer && + pair.second == BaseType::Pointer)) { + toremove.insert(pair.first); + continue; + } + + // Keep lingering anythings if not being overwritten, even if this + // (e.g. Float) applies to more locations. Therefore it is legal + // to have [-1]:Float, [8]:Anything + if (CT != BaseType::Anything && pair.second == BaseType::Anything) + continue; + + // Otherwise, inserting a more general non-equivalent pair. This + // is invalid. 
LegalOr = false; return false; } } } - for (const auto &val : toremove) { mapping.erase(val); } diff --git a/enzyme/test/Enzyme/ReverseMode/infanalysis.ll b/enzyme/test/Enzyme/ReverseMode/infanalysis.ll new file mode 100644 index 000000000000..3e0e939a3e7b --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/infanalysis.ll @@ -0,0 +1,22 @@ +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +@_ZTV4Test = linkonce_odr dso_local unnamed_addr constant [1 x i8*] [i8* bitcast ({ i8**, i8* }* @_ZTI4Test to i8*)], align 8 +@_ZTVN10__cxxabiv117__class_type_infoE = external global i8* +@_ZTI4Test = linkonce_odr dso_local constant { i8**, i8* } { i8** getelementptr inbounds (i8*, i8** @_ZTVN10__cxxabiv117__class_type_infoE, i64 2), i8* null }, align 8 + +define void @_Z8simulatev() { +entry: + %sys = alloca i8**, align 8 + store i8** getelementptr inbounds ([1 x i8*], [1 x i8*]* @_ZTV4Test, i32 0, i32 1), i8*** %sys, align 8 + ret void +} + +define void @main() { +entry: + %call = call double @_Z17__enzyme_autodiffPv(void ()* @_Z8simulatev) + ret void +} + +declare double @_Z17__enzyme_autodiffPv(void ()*) + +; CHECK: define internal void @diffe_Z8simulatev From eb841acb497cf19211167f9886f2b4507569d39a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 13:30:39 -0500 Subject: [PATCH 09/29] Fix memset no type handling (#1435) --- enzyme/Enzyme/AdjointGenerator.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 3ce4f6165215..b0b11704b416 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2766,10 +2766,12 @@ class AdjointGenerator if (II->getIntrinsicID() == Intrinsic::lifetime_start) { if (getBaseObject(II->getOperand(1)) == root) { if (auto CI2 = dyn_cast(II->getOperand(0))) { - if (MCI->getValue().ult(CI2->getValue())) + if (MCI->getValue().ule(CI2->getValue())) break; } } + cur = 
cur->getPrevNode(); + continue; } } } From 86fc287c5a39632364af2c48bc3efb5ef1f6652d Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 13:44:06 -0500 Subject: [PATCH 10/29] Also memset undef as sret (#1436) --- enzyme/Enzyme/AdjointGenerator.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index b0b11704b416..99a73eca4415 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2754,7 +2754,12 @@ class AdjointGenerator if (CI->isZero()) { auto root = getBaseObject(MS.getOperand(0)); bool writtenTo = false; - if (isa(root) || isAllocationCall(root, gutils->TLI)) { + bool undefMemory = + isa(root) || isAllocationCall(root, gutils->TLI); + if (auto arg = dyn_cast(root)) + if (arg->hasStructRetAttr()) + undefMemory = true; + if (undefMemory) { Instruction *cur = MS.getPrevNode(); while (cur) { if (cur == root) From e6e5a03c899d3d608dfa3d1c468e72e4d8e0952c Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 16:20:14 -0500 Subject: [PATCH 11/29] Fix extract invertpointer ordering (#1437) --- enzyme/Enzyme/GradientUtils.cpp | 7 ++++--- enzyme/test/Enzyme/ForwardMode/augmentedreturn.ll | 2 +- enzyme/test/Enzyme/ForwardMode/freeuse.ll | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 0b7ca4be6368..ef6799c79ba4 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5490,12 +5490,13 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, } goto end; } else if (auto arg = dyn_cast(oval)) { - IRBuilder<> bb(getNewFromOriginal(arg)); + auto newi = getNewFromOriginal(arg); + IRBuilder<> bb(newi->getNextNode()); auto ip = invertPointerM(arg->getOperand(0), bb, nullShadow); - auto rule = [&bb, &arg, this](Value *ip) -> llvm::Value * { + auto rule = [&bb, &arg, &newi, this](Value 
*ip) -> llvm::Value * { if (ip == getNewFromOriginal(arg->getOperand(0))) - return getNewFromOriginal(arg); + return newi; return bb.CreateExtractValue(ip, arg->getIndices(), arg->getName() + "'ipev"); }; diff --git a/enzyme/test/Enzyme/ForwardMode/augmentedreturn.ll b/enzyme/test/Enzyme/ForwardMode/augmentedreturn.ll index b8149fcbc399..1d622397ea4a 100644 --- a/enzyme/test/Enzyme/ForwardMode/augmentedreturn.ll +++ b/enzyme/test/Enzyme/ForwardMode/augmentedreturn.ll @@ -34,8 +34,8 @@ entry: ; CHECK-NEXT: %0 = call { { i8*, double }, { i8*, double } } @fwddiffeaugsquare(double %x, double %"x'") ; CHECK-NEXT: %1 = extractvalue { { i8*, double }, { i8*, double } } %0, 0 ; CHECK-NEXT: %2 = extractvalue { { i8*, double }, { i8*, double } } %0, 1 -; CHECK-NEXT: %[[i3:.+]] = extractvalue { i8*, double } %2, 1 ; CHECK-NEXT: %o = extractvalue { i8*, double } %1, 1 +; CHECK-NEXT: %[[i3:.+]] = extractvalue { i8*, double } %2, 1 ; CHECK-NEXT: %[[i4:.+]] = fmul fast double %[[i3]], %o ; CHECK-NEXT: %[[i5:.+]] = fmul fast double %[[i3]], %o ; CHECK-NEXT: %[[i6:.+]] = fadd fast double %[[i4]], %[[i5]] diff --git a/enzyme/test/Enzyme/ForwardMode/freeuse.ll b/enzyme/test/Enzyme/ForwardMode/freeuse.ll index 6192877a4384..a4472c373334 100644 --- a/enzyme/test/Enzyme/ForwardMode/freeuse.ll +++ b/enzyme/test/Enzyme/ForwardMode/freeuse.ll @@ -31,8 +31,8 @@ declare void @free(i8*) ; CHECK: define internal void @fwddiffejac_rev(double* nocapture readonly %r, { i8* } %tapeArg, { i8* } %"tapeArg'", i1 %cmp) ; CHECK-NEXT: entry: -; CHECK-NEXT: %"arg0'ipev" = extractvalue { i8* } %"tapeArg'", 0 ; CHECK-NEXT: %arg0 = extractvalue { i8* } %tapeArg, 0 +; CHECK-NEXT: %"arg0'ipev" = extractvalue { i8* } %"tapeArg'", 0 ; CHECK-NEXT: store i8 0, i8* %"arg0'ipev", align 8 ; CHECK-NEXT: store i8 0, i8* %arg0, align 8 ; CHECK-NEXT: br i1 %cmp, label %invertbaz, label %invertfoo @@ -52,7 +52,7 @@ declare void @free(i8*) ; CHECK-NEXT: ret void ; CHECK: invertfoo: ; preds = %entry -; CHECK-NEXT: 
%"arg1'ipev" = extractvalue { i8* } %"tapeArg'", 0 ; CHECK-NEXT: %arg1 = extractvalue { i8* } %tapeArg, 0 +; CHECK-NEXT: %"arg1'ipev" = extractvalue { i8* } %"tapeArg'", 0 ; CHECK-NEXT: br label %invertbaz ; CHECK-NEXT: } From 3870f89aea2354a1900bfda0bfa224faac6a9f15 Mon Sep 17 00:00:00 2001 From: jlk9 Date: Mon, 18 Sep 2023 16:20:31 -0500 Subject: [PATCH 12/29] Adding instructions for expm1f / expm1l (#1438) * Update InstructionDerivatives.td Adding the functions expm1f and expm1l to the instruction table gen so they are supported. * Adding a test for new instructions for expm1f --- enzyme/Enzyme/InstructionDerivatives.td | 2 +- enzyme/test/Enzyme/ReverseMode/expm1f.ll | 29 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/expm1f.ll diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 1796645073fc..9de7f8b871ae 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -370,7 +370,7 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x), - ["expm1"], + ["expm1", "expm1f", "expm1l"], [(FMul (Intrinsic<"exp"> $x), (DiffeRet))], (ForwardFromSummedReverse), [ReadNone, NoUnwind] diff --git a/enzyme/test/Enzyme/ReverseMode/expm1f.ll b/enzyme/test/Enzyme/ReverseMode/expm1f.ll new file mode 100644 index 000000000000..e284a0334c16 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/expm1f.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -adce -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify,adce)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @expm1f(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { 
+entry: + %0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* nonnull @tester, float %x) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @expm1f(float) + +; Function Attrs: nounwind +declare float @__enzyme_autodiff(float (float)*, ...) + +; CHECK: define internal { float } @diffetester(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[i0:.+]] = call fast float @llvm.exp.f32(float %x) +; CHECK-NEXT: %[[i1:.+]] = fmul fast float %[[i0]], %differeturn +; CHECK-NEXT: %[[i2:.+]] = insertvalue { float } {{(undef|poison)}}, float %[[i1]], 0 +; CHECK-NEXT: ret { float } %[[i2]] +; CHECK-NEXT: } \ No newline at end of file From e8e4188a6ac3c58ce1336f5b56659388bfd70b66 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 23:27:05 -0500 Subject: [PATCH 13/29] Fix mem2reg on input fn arg (#1442) --- enzyme/Enzyme/Utils.cpp | 3 +++ enzyme/test/Enzyme/ReverseMode/mem2regfn.ll | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 enzyme/test/Enzyme/ReverseMode/mem2regfn.ll diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index e2d896893537..94de076a4af1 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2318,6 +2318,9 @@ Function *GetFunctionFromValue(Value *fn) { continue; } + if (isa(cur)) + continue; + if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) continue; diff --git a/enzyme/test/Enzyme/ReverseMode/mem2regfn.ll b/enzyme/test/Enzyme/ReverseMode/mem2regfn.ll new file mode 100644 index 000000000000..00d9fe10883c --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/mem2regfn.ll @@ -0,0 +1,18 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +define double @_Z3food(double %0) { + ret double %0 +} + +define i32 @main() { + %a5 = alloca i8*, align 8 + store i8* 
bitcast (double (double)* @_Z3food to i8*), i8** %a5, align 8 + %a17 = load i8*, i8** %a5, align 8 + %q = call double (...) @__enzyme_autodiff(i8* %a17, metadata !"enzyme_dup", double 1.0, double 1.0) + ret i32 0 +} + +declare double @__enzyme_autodiff(...) + +; CHECK: define internal void @diffe_Z3food From 7e719eec342d59409dcd735ba0812a0209e1eca2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Sep 2023 23:50:18 -0500 Subject: [PATCH 14/29] Reuse intermediate loop variables on multi exit (#1443) --- enzyme/Enzyme/GradientUtils.cpp | 12 +++++++----- enzyme/test/Enzyme/ReverseMode/multiloopexit.ll | 3 +-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index ef6799c79ba4..8901f782a88c 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3686,6 +3686,7 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB, IRBuilder<> tbuild(prevBlock); SmallVector, 1> lims; + ValueToValueMapTy available; for (auto I = exitingContexts.rbegin(), E = exitingContexts.rend(); I != E; I++) { auto &lc = *I; @@ -3699,13 +3700,14 @@ BasicBlock *GradientUtils::getReverseOrLatchMerge(BasicBlock *BB, assert(/*ReverseLimit*/ reverseBlocks.size() > 0); LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0, lc.preheader); - lim = lookupValueFromCache( - lc.var->getType(), - /*forwardPass*/ false, tbuild, lctx, getDynamicLoopLimit(L), - /*isi1*/ false, /*available*/ ValueToValueMapTy()); + lim = lookupValueFromCache(lc.var->getType(), + /*forwardPass*/ false, tbuild, lctx, + getDynamicLoopLimit(L), + /*isi1*/ false, available); } else { - lim = lookupM(lc.trueLimit, tbuild); + lim = lookupM(lc.trueLimit, tbuild, available); } + available[lc.var] = lim; lims.push_back(std::make_pair(lim, (Value *)lc.antivaralloc)); } diff --git a/enzyme/test/Enzyme/ReverseMode/multiloopexit.ll b/enzyme/test/Enzyme/ReverseMode/multiloopexit.ll index 
463247f0bc7d..b1e71bcb0a69 100644 --- a/enzyme/test/Enzyme/ReverseMode/multiloopexit.ll +++ b/enzyme/test/Enzyme/ReverseMode/multiloopexit.ll @@ -200,8 +200,7 @@ define float @c() { ; CHECK-NEXT: store float %differeturn, float* %"sum.2'de", align 4 ; CHECK-NEXT: %[[i89:.+]] = load i64, i64* %loopLimit_cache, align 8 ; CHECK-NEXT: %[[i90:.+]] = load i64*, i64** %loopLimit_cache2, align 8 -; CHECK-NEXT: %[[i91:.+]] = load i64, i64* %"iv'ac", align 4 -; CHECK-NEXT: %[[i92:.+]] = getelementptr inbounds i64, i64* %[[i90]], i64 %[[i91]] +; CHECK-NEXT: %[[i92:.+]] = getelementptr inbounds i64, i64* %[[i90]], i64 %[[i89]] ; CHECK-NEXT: %[[i93:.+]] = load i64, i64* %[[i92]], align 8 ; CHECK-NEXT: br label %mergeinvertloop2_exit From c2a4d0fa557c921a61c1c8bbfe2507e1164194d2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 19 Sep 2023 01:06:39 -0500 Subject: [PATCH 15/29] [BLAS] fix row transpose arg (#1440) Co-authored-by: Manuel Drehwald --- .../test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll | 6 +++--- .../ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll | 4 ++-- .../test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll | 2 +- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll index f2bef95aec83..a05cdb333fd1 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll @@ -187,7 +187,7 @@ entry: ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.1 ; CHECK-NEXT: %intcast.constant.int.1 = bitcast i64* %byref.constant.int.1 to i8* ; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %fpcast.constant.fp.1.0, i8* %A, i8* %lda_p, i8* %20, i8* %intcast.int.one, i8* %fpcast.constant.fp.0.0, i8* %19, i8* %intcast.constant.int.1) -; CHECK-NEXT: %ld.row.trans = load i8, i8* 
%byref.transpose.transa +; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[c1:.+]] = icmp eq i8 %ld.row.trans, 110 ; CHECK-DAG: %[[c2:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-NEXT: %[[c3:.+]] = or i1 %[[c2]], %[[c1]] @@ -203,7 +203,7 @@ entry: ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.09 ; CHECK-NEXT: %fpcast.constant.fp.1.010 = bitcast double* %byref.constant.fp.1.09 to i8* ; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.010, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans11 = load i8, i8* %byref.transpose.transa +; CHECK-NEXT: %ld.row.trans11 = load i8, i8* %malloccall ; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans11, 110 ; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans11, 78 ; CHECK-NEXT: %41 = or i1 %[[r40]], %[[r39]] @@ -213,7 +213,7 @@ entry: ; CHECK-NEXT: %45 = load double, double* %44 ; CHECK-NEXT: %46 = fadd fast double %45, %43 ; CHECK-NEXT: store double %46, double* %44 -; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %byref.transpose.transa +; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %malloccall ; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans12, 110 ; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans12, 78 ; CHECK-NEXT: %49 = or i1 %[[r48]], %[[r47]] diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll index a214afa93a24..81655aebbe38 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll @@ -178,7 +178,7 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.beta, label %invertentry.beta.done, label %invertentry.beta.active ; CHECK: invertentry.beta.active: ; preds = %invertentry.A.done -; CHECK-NEXT: %ld.row.trans = load i8, i8* %byref.transpose.transa +; 
CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans, 110 ; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] @@ -194,7 +194,7 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %invertentry.beta.done -; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %byref.transpose.transa +; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall ; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans2, 110 ; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans2, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll index 8a1dd30e639e..65857bbba6ec 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll @@ -194,7 +194,7 @@ entry: ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* ; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans = load i8, i8* %byref.transpose.transa +; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[r2:.+]] = icmp eq i8 %ld.row.trans, 110 ; CHECK-DAG: %[[r3:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-NEXT: %26 = or i1 %[[r3]], %[[r2]] diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 92e8c30e3020..a55f3eaf717d 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -1042,8 +1042,8 @@ void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, auto 
tname = Dag->getArgNameStr(0); auto rname = Dag->getArgNameStr(1); auto cname = Dag->getArgNameStr(2); - os << "get_blas_row(Builder2, arg_transposed_" << tname << ", arg_" - << rname << ", arg_" << cname << ", byRef)"; + os << "get_blas_row(Builder2, arg_" << tname << ", arg_" << rname + << ", arg_" << cname << ", byRef)"; } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") { assert(Dag->getNumArgs() == 5); //(ld $A, $transa, $lda, $m, $k) From d009770958d78c4abfde32559e0c62417f37d37f Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 19 Sep 2023 19:20:52 -0400 Subject: [PATCH 16/29] fix gemv rule for A (#1441) * fix gemv rule for A * looking for shadow usage now recursively * fmt * now really * update testcases --- enzyme/Enzyme/BlasDerivatives.td | 12 +++- .../Enzyme/ReverseMode/blas/gemv_c_loop2.ll | 4 +- .../ReverseMode/blas/gemv_c_loop3_matcopy.ll | 20 ++++-- .../blas/gemv_f_c_split_blascpy.ll | 70 ++++++++++++------- .../gemv_f_c_split_blascpy_runtime_act.ll | 34 +++++++-- .../ReverseMode/blas/gemv_f_c_split_memcpy.ll | 40 ++++++++--- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 63 +++++++++++++---- 7 files changed, 180 insertions(+), 63 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index e39f635528ff..6edf464c2372 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -187,7 +187,17 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]> (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $n), $x, $incx, Constant<"0.0">, use<"Ax">, ConstantInt<1>), (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, use<"Ax">, ConstantInt<1>)), - /* A */ (b<"ger"> $layout, $m, $n, $alpha, adj<"y">, $incy, $x, $incx, adj<"A">, $lda), + + //if (is_normal $transa) { + // call sger(m, n, alpha, ya, incy, x, incx, Aa, lda) + //} else { + // call sger(m, n, 
alpha, x, incx, ya, incy, Aa, lda) + //} + /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, adj<"y">, $x), + (Rows $transa, $incy, $incx), + (Rows $transa, $x, adj<"y">), + (Rows $transa, $incx, $incy), + adj<"A">, $lda), /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, $transa, $lda, $m, $n), adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx), /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, input<"y">, $incy), /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">, $incy) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll index 98dcca57f324..872964e04cf7 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll @@ -105,7 +105,9 @@ entry: ; CHECK-NEXT: %[[i20:.+]] = getelementptr inbounds i8*, i8** %[[i19]], i64 %[[i15]] ; CHECK-NEXT: %[[i21:.+]] = load i8*, i8** %[[i20]], align 8, !invariant.group !8 ; CHECK-NEXT: %cache.A_unwrap = bitcast i8* %[[i21]] to double* -; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %"v0'", i32 1, double* %cache.x_unwrap, i32 1, double* %"K'", i32 %N) +; CHECK-DAG: %[[r20:.+]] = select i1 false, double* %"v0'", double* %cache.x_unwrap +; CHECK-DAG: %[[r21:.+]] = select i1 false, double* %cache.x_unwrap, double* %"v0'" +; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r20]], i32 1, double* %[[r21]], i32 1, double* %"K'", i32 %N) ; CHECK-NEXT: %[[i22:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A_unwrap, i32 %[[i22]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i23:.+]] = select i1 false, i32 %N, i32 %N diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll 
b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll index 08bf25bfc360..3819ed2b6eb9 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll @@ -192,11 +192,15 @@ entry: ; CHECK-NEXT: %[[i45:.+]] = load double, double* %"x0'", align 8 ; CHECK-NEXT: %[[i46:.+]] = fadd fast double %[[i45:.+]], %[[i44]] ; CHECK-NEXT: store double %[[i46:.+]], double* %"x0'", align 8 -; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %"v0'", i32 1, double* %x0, i32 1, double* %"K'", i32 %N) +; CHECK-DAG: %[[r39:.+]] = select i1 false, double* %"v0'", double* %x0 +; CHECK-DAG: %[[r40:.+]] = select i1 false, double* %x0, double* %"v0'" +; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r39]], i32 1, double* %[[r40]], i32 1, double* %"K'", i32 %N) ; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %K, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i47:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i47]], double 1.000000e+00, double* %"v0'", i32 1) -; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %"v0'", i32 1, double* %cache.x, i32 1, double* %"K'", i32 %N) +; CHECK-DAG: %[[r42:.+]] = select i1 false, double* %"v0'", double* %cache.x +; CHECK-DAG: %[[r43:.+]] = select i1 false, double* %cache.x, double* %"v0'" +; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r42]], i32 1, double* %[[r43]], i32 1, double* %"K'", i32 %N) ; CHECK-NEXT: %[[i48:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A, i32 %[[i48]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i49:.+]] = select i1 
false, i32 %N, i32 %N @@ -205,7 +209,9 @@ entry: ; CHECK-NEXT: tail call void @free(i8* nonnull %[[i50]]) ; CHECK-NEXT: %[[i51:.+]] = bitcast double* %cache.x to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[i51]]) -; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %"v0'", i32 1, double* %cache.x8, i32 1, double* %"K'", i32 %N) +; CHECK-DAG: %[[r48:.+]] = select i1 false, double* %"v0'", double* %cache.x8 +; CHECK-DAG: %[[r49:.+]] = select i1 false, double* %cache.x8, double* %"v0'" +; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r48]], i32 1, double* %[[r49]], i32 1, double* %"K'", i32 %N) ; CHECK-NEXT: %[[i52:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A5, i32 %[[i52]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i53:.+]] = select i1 false, i32 %N, i32 %N @@ -214,7 +220,9 @@ entry: ; CHECK-NEXT: tail call void @free(i8* nonnull %[[i54]]) ; CHECK-NEXT: %[[i55:.+]] = bitcast double* %cache.x8 to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[i55]]) -; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %"v0'", i32 1, double* %cache.x16, i32 1, double* %"K'", i32 %N) +; CHECK-DAG: %[[r54:.+]] = select i1 false, double* %"v0'", double* %cache.x16 +; CHECK-DAG: %[[r55:.+]] = select i1 false, double* %cache.x16, double* %"v0'" +; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r54]], i32 1, double* %[[r55]], i32 1, double* %"K'", i32 %N) ; CHECK-NEXT: %[[i56:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A13, i32 %[[i56]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i57:.+]] = select i1 false, i32 %N, i32 %N @@ 
-223,7 +231,9 @@ entry: ; CHECK-NEXT: tail call void @free(i8* nonnull %[[i58]]) ; CHECK-NEXT: %[[i59:.+]] = bitcast double* %cache.x16 to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[i59]]) -; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %"v0'", i32 1, double* %cache.x24, i32 1, double* %"K'", i32 %N) +; CHECK-DAG: %[[r60:.+]] = select i1 false, double* %"v0'", double* %cache.x24 +; CHECK-DAG: %[[r61:.+]] = select i1 false, double* %cache.x24, double* %"v0'" +; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r60]], i32 1, double* %[[r61]], i32 1, double* %"K'", i32 %N) ; CHECK-NEXT: %[[i60:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A21, i32 %[[i60]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i61:.+]] = select i1 false, i32 %N, i32 %N diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll index a05cdb333fd1..856451959941 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll @@ -118,7 +118,7 @@ entry: ; CHECK-NEXT: %byref.constant.fp.0.0 = alloca double ; CHECK-NEXT: %byref.constant.int.1 = alloca i64 ; CHECK-NEXT: %byref.constant.int.17 = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.09 = alloca double +; CHECK-NEXT: %byref.constant.fp.1.013 = alloca double ; CHECK-NEXT: %incy = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %incy to i8* ; CHECK-NEXT: %incx = alloca i64, i64 1, align 16 @@ -199,29 +199,49 @@ entry: ; CHECK-NEXT: %37 = load double, double* %36 ; CHECK-NEXT: %38 = fadd fast double %37, %35 ; CHECK-NEXT: store double %38, double* %36 -; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %"y'", i8* %incy_p, i8* %20, 
i8* %intcast.int.one, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.09 -; CHECK-NEXT: %fpcast.constant.fp.1.010 = bitcast double* %byref.constant.fp.1.09 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.010, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans11 = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans11, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans11, 78 -; CHECK-NEXT: %41 = or i1 %[[r40]], %[[r39]] -; CHECK-NEXT: %42 = select i1 %41, i8* %m_p, i8* %n_p -; CHECK-NEXT: %43 = call fast double @ddot_64_(i8* %42, i8* %"y'", i8* %incy_p, i8* %21, i8* %intcast.int.one) -; CHECK-NEXT: %44 = bitcast i8* %"beta'" to double* -; CHECK-NEXT: %45 = load double, double* %44 -; CHECK-NEXT: %46 = fadd fast double %45, %43 -; CHECK-NEXT: store double %46, double* %44 -; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans12, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans12, 78 -; CHECK-NEXT: %49 = or i1 %[[r48]], %[[r47]] -; CHECK-NEXT: %50 = select i1 %49, i8* %m_p, i8* %n_p -; CHECK-NEXT: call void @dscal_64_(i8* %50, i8* %beta, i8* %"y'", i8* %incy_p) -; CHECK-NEXT: %51 = bitcast double* %tape.ext.x to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %51) -; CHECK-NEXT: %52 = bitcast double* %tape.ext.y1 to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %52) +; CHECK-NEXT: %ld.row.trans9 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[i39:.+]] = icmp eq i8 %ld.row.trans9, 110 +; CHECK-DAG: %[[i40:.+]] = icmp eq i8 %ld.row.trans9, 78 +; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]] +; CHECK-NEXT: %[[i42:.+]] = select i1 %41, i8* %"y'", i8* %20 +; CHECK-NEXT: %ld.row.trans10 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[i43:.+]] = icmp eq i8 %ld.row.trans10, 110 +; 
CHECK-DAG: %[[i44:.+]] = icmp eq i8 %ld.row.trans10, 78 +; CHECK-NEXT: %[[i45:.+]] = or i1 %[[i44]], %[[i43]] +; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i45]], i8* %incy_p, i8* %incx_p +; CHECK-NEXT: %ld.row.trans11 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[i47:.+]] = icmp eq i8 %ld.row.trans11, 110 +; CHECK-DAG: %[[i48:.+]] = icmp eq i8 %ld.row.trans11, 78 +; CHECK-NEXT: %[[i49:.+]] = or i1 %[[i48]], %[[i47]] +; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8* %20, i8* %"y'" +; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[i51:.+]] = icmp eq i8 %ld.row.trans12, 110 +; CHECK-NEXT: %[[i52:.+]] = icmp eq i8 %ld.row.trans12, 78 +; CHECK-NEXT: %[[i53:.+]] = or i1 %52, %51 +; CHECK-NEXT: %[[i54:.+]] = select i1 %53, i8* %incx_p, i8* %incy_p +; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[i42]], i8* %[[i46]], i8* %[[i50]], i8* %[[i54]], i8* %"A'", i8* %lda_p) +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.013 +; CHECK-NEXT: %fpcast.constant.fp.1.014 = bitcast double* %byref.constant.fp.1.013 to i8* +; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.014, i8* %"x'", i8* %incx_p) +; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %malloccall +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans15, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans15, 78 +; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] +; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p +; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %21, i8* %intcast.int.one) +; CHECK-NEXT: %[[r44:.+]] = bitcast i8* %"beta'" to double* +; CHECK-NEXT: %[[r45:.+]] = load double, double* %[[r44]] +; CHECK-NEXT: %[[r46:.+]] = fadd fast double %[[r45]], %[[r43]] +; CHECK-NEXT: store double %[[r46]], double* %[[r44]] +; CHECK-NEXT: %ld.row.trans16 = 
load i8, i8* %malloccall +; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans16, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans16, 78 +; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] +; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p +; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) +; CHECK-NEXT: %[[r51:.+]] = bitcast double* %tape.ext.x to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[r51]]) +; CHECK-NEXT: %[[r52:.+]] = bitcast double* %tape.ext.y1 to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[r52]]) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll index 81655aebbe38..2e2ea7b4dac4 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll @@ -171,16 +171,36 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.A, label %invertentry.A.done, label %invertentry.A.active ; CHECK: invertentry.A.active: ; preds = %invertentry -; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %"y'", i8* %incy_p, i8* %[[i11]], i8* %intcast.int.one, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[r22:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-DAG: %[[r23:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r24:.+]] = or i1 %[[r23]], %[[r22]] +; CHECK-NEXT: %[[r25:.+]] = select i1 %[[r24]], i8* %"y'", i8* %11 +; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[r26:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-DAG: %[[r27:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-NEXT: %[[r28:.+]] = or i1 %[[r27]], %[[r26]] +; CHECK-NEXT: %[[r29:.+]] = select i1 %[[r28]], i8* %incy_p, i8* %incx_p +; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall, align 
1 +; CHECK-DAG: %[[r30:.+]] = icmp eq i8 %ld.row.trans3, 110 +; CHECK-DAG: %[[r31:.+]] = icmp eq i8 %ld.row.trans3, 78 +; CHECK-NEXT: %[[r32:.+]] = or i1 %[[r31]], %[[r30]] +; CHECK-NEXT: %[[r33:.+]] = select i1 %[[r32]], i8* %11, i8* %"y'" +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[r34:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-DAG: %[[r35:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %[[r36:.+]] = or i1 %[[r35]], %[[r34]] +; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r36]], i8* %incx_p, i8* %incy_p +; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[r25]], i8* %[[r29]], i8* %[[r33]], i8* %[[r37]], i8* %"A'", i8* %lda_p) ; CHECK-NEXT: br label %invertentry.A.done ; CHECK: invertentry.A.done: ; preds = %invertentry.A.active, %invertentry ; CHECK-NEXT: br i1 %rt.inactive.beta, label %invertentry.beta.done, label %invertentry.beta.active ; CHECK: invertentry.beta.active: ; preds = %invertentry.A.done -; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %ld.row.trans5 = load i8, i8* %malloccall +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans5, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans5, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] ; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p ; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %[[i12]], i8* %intcast.int.one) @@ -194,9 +214,9 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %invertentry.beta.done -; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans2, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %malloccall +; 
CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans6, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans6, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] ; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll index 65857bbba6ec..bef22acf4410 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll @@ -189,18 +189,38 @@ entry: ; CHECK-DAG: %[[r7:.+]] = select i1 %[[r6]], i8 78, i8 %[[r5]] ; CHECK-NEXT: store i8 %23, i8* %byref.transpose.transa ; CHECK-NEXT: store i64 1, i64* %byref.int.one -; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %"y'", i8* %incy_p, i8* %15, i8* %intcast.int.one, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* +; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[r24:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-DAG: %[[r25:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-DAG: %[[r26:.+]] = or i1 %[[r25]], %[[r24]] +; CHECK-NEXT: %[[r27:.+]] = select i1 %[[r26]], i8* %"y'", i8* %15 +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[r28:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-DAG: %[[r29:.+]] = icmp eq i8 %ld.row.trans1, 78 +; CHECK-DAG: %[[r30:.+]] = or i1 %[[r29]], %[[r28]] +; CHECK-NEXT: %[[r31:.+]] = select i1 %[[r30]], i8* %incy_p, i8* %incx_p +; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[r32:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-DAG: %[[r33:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-DAG: %[[r34:.+]] = or i1 %[[r33]], %[[r32]] +; CHECK-NEXT: %[[r35:.+]] = select i1 
%[[r34]], i8* %15, i8* %"y'" +; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall, align 1 +; CHECK-DAG: %[[r36:.+]] = icmp eq i8 %ld.row.trans3, 110 +; CHECK-DAG: %[[r37:.+]] = icmp eq i8 %ld.row.trans3, 78 +; CHECK-DAG: %[[r38:.+]] = or i1 %[[r37]], %[[r36]] +; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r38]], i8* %incx_p, i8* %incy_p +; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %27, i8* %31, i8* %35, i8* %39, i8* %"A'", i8* %lda_p) ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* ; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[r2:.+]] = icmp eq i8 %ld.row.trans, 110 -; CHECK-DAG: %[[r3:.+]] = icmp eq i8 %ld.row.trans, 78 -; CHECK-NEXT: %26 = or i1 %[[r3]], %[[r2]] -; CHECK-NEXT: %27 = select i1 %26, i8* %m_p, i8* %n_p -; CHECK-NEXT: call void @dscal_64_(i8* %27, i8* %beta_p, i8* %"y'", i8* %incy_p) -; CHECK-NEXT: %28 = bitcast double* %0 to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %28) +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-DAG: %[[r41:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %[[r42:.+]] = or i1 %[[r41]], %[[r40]] +; CHECK-NEXT: %[[r43:.+]] = select i1 %[[r42]], i8* %m_p, i8* %n_p +; CHECK-NEXT: call void @dscal_64_(i8* %[[r43]], i8* %beta_p, i8* %"y'", i8* %incy_p) +; CHECK-NEXT: %[[r44:.+]] = bitcast double* %0 to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[r44]]) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index a55f3eaf717d..d8a3d7c7fffa 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ 
b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -862,8 +862,16 @@ void emit_deriv_blas_call(DagInit *ruleDag, if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") { if (!first) typeString += ", "; - typeString += (Twine("type_") + Dag->getArgNameStr(1)).str(); - first = false; + if (DefInit *def = dyn_cast(Dag->getArg(1))) { + const auto Def = def->getDef(); + assert(Def->isSubClassOf("adj")); + typeString += + (Twine("type_") + Def->getValueAsString("name")).str(); + } else { + assert(Dag->getArgNameStr(1) != ""); + typeString += (Twine("type_") + Dag->getArgNameStr(1)).str(); + } + first = false; continue; } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") { if (!first) @@ -985,7 +993,6 @@ void emit_tmp_creation(Record *Def, raw_ostream &os) { void emit_deriv_rule(const StringMap &patternMap, Rule &rule, StringSet<> &handled, raw_ostream &os) { const auto ruleDag = rule.getRuleDag(); - const auto typeMap = rule.getArgTypeMap(); const auto opName = ruleDag->getOperator()->getAsString(); const auto nameMap = rule.getArgNameMap(); const auto Def = cast(ruleDag->getOperator())->getDef(); @@ -1039,11 +1046,24 @@ void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, auto Def = cast(Dag->getOperator())->getDef(); if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") { - auto tname = Dag->getArgNameStr(0); - auto rname = Dag->getArgNameStr(1); - auto cname = Dag->getArgNameStr(2); - os << "get_blas_row(Builder2, arg_" << tname << ", arg_" << rname - << ", arg_" << cname << ", byRef)"; + std::string tname, rname, cname; + tname = (Twine("arg_") + Dag->getArgNameStr(0)).str(); + if (DefInit *Def1 = dyn_cast(Dag->getArg(1))) { + auto Def1Name = Def1->getDef()->getValueAsString("name"); + assert(Def1->getDef()->isSubClassOf("adj")); + rname = (Twine("d_") + Def1Name).str(); + } else { + rname = (Twine("arg_") + Dag->getArgNameStr(1)).str(); + } + if (DefInit *Def2 = dyn_cast(Dag->getArg(2))) { + auto Def2Name =
Def2->getDef()->getValueAsString("name"); + assert(Def2->getDef()->isSubClassOf("adj")); + cname = (Twine("d_") + Def2Name).str(); + } else { + cname = (Twine("arg_") + Dag->getArgNameStr(2)).str(); + } + os << "get_blas_row(Builder2, " << tname << ", " << rname << ", " << cname + << ", byRef)"; } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") { assert(Dag->getNumArgs() == 5); //(ld $A, $transa, $lda, $m, $k) @@ -1191,7 +1211,6 @@ void rev_call_args(StringRef argName, Rule &rule, size_t actArg, raw_ostream &os, int subRule = -1) { const auto nameMap = rule.getArgNameMap(); - const auto typeMap = rule.getArgTypeMap(); auto ruleDag = rule.getRuleDag(); size_t numArgs = ruleDag->getNumArgs(); @@ -1291,20 +1310,36 @@ void emit_runtime_continue(DagInit *ruleDag, StringRef name, StringRef tab, << tab << "}\n"; } -void emit_if_rule_condition(DagInit *ruleDag, StringRef name, StringRef tab, - raw_ostream &os) { - os << tab << "if (active_" << name; +void if_rule_condition_inner(DagInit *ruleDag, StringRef name, StringRef tab, + raw_ostream &os, llvm::StringSet<> &seen) { for (size_t pos = 0; pos < ruleDag->getNumArgs();) { - auto arg = ruleDag->getArg(pos); + Init *arg = ruleDag->getArg(pos); if (DefInit *DefArg = dyn_cast(arg)) { auto Def = DefArg->getDef(); if (Def->isSubClassOf("adj")) { auto name = Def->getValueAsString("name"); - os << " && d_" << name; + seen.insert(name); } + } else if (auto sub_Dag = dyn_cast(arg)) { + if_rule_condition_inner(sub_Dag, name, tab, os, seen); } pos++; } +} + +// primal arguments are always available, +// shadow arguments (d_) might not, so check if they are active +void emit_if_rule_condition(DagInit *ruleDag, StringRef name, StringRef tab, + raw_ostream &os) { + llvm::StringSet<> seen = llvm::StringSet<>(); + + if_rule_condition_inner(ruleDag, name, tab, os, seen); + + // this will only run once, at the end of the outermost call + os << tab << "if (active_" << name; + for (auto name : seen.keys()) { + os << " && 
d_" << name.str(); + } os << ") {\n"; } From c737e0cd92454f320e796a6a71f94b68fcb36d20 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 19 Sep 2023 18:36:05 -0500 Subject: [PATCH 17/29] Add blas integration infra tests (#1439) * Add blas integration infra tests * Now at least testing * no readonly/readnone * Improve blas tests * Fix printing * Now with memory checker capabilities * With nicer testing --- enzyme/Enzyme/AdjointGenerator.h | 10 +- enzyme/test/Integration/ReverseMode/blas.cpp | 820 +++++++++++++++++++ enzyme/tools/enzyme-tblgen/blasDeclUpdater.h | 9 +- 3 files changed, 831 insertions(+), 8 deletions(-) create mode 100644 enzyme/test/Integration/ReverseMode/blas.cpp diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 99a73eca4415..c4fa2b588e5c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -8157,15 +8157,13 @@ class AdjointGenerator return; } - if (!called || called->empty()) { - if (auto blas = extractBLAS(funcName)) { + if (auto blas = extractBLAS(funcName)) { #if LLVM_VERSION_MAJOR >= 16 - if (handleBLAS(call, called, blas.value(), overwritten_args)) + if (handleBLAS(call, called, blas.value(), overwritten_args)) #else - if (handleBLAS(call, called, blas.getValue(), overwritten_args)) + if (handleBLAS(call, called, blas.getValue(), overwritten_args)) #endif - return; - } + return; } if (funcName == "printf" || funcName == "puts" || diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp new file mode 100644 index 000000000000..a225cbd84013 --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -0,0 +1,820 @@ +// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... 
+// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi + +#include "test_utils.h" + +#include +#include +#include +#include + +template +class vector { + T* data; + size_t capacity; + size_t length; +public: + vector() : data(nullptr), capacity(0), length(0) {} + vector(const vector &prev) : data((T*)malloc(sizeof(T)*prev.capacity)), capacity(prev.capacity), length(prev.length) { + memcpy(data, prev.data, prev.length*sizeof(T)); + } + void operator=(const vector &prev) { + free(data); + data = 
(T*)malloc(sizeof(T)*prev.capacity); + capacity = prev.capacity; + length = prev.length; + memcpy(data, prev.data, prev.length*sizeof(T)); + } + // Don't destruct to avoid dso handle in global + // ~vector() { free(data); } + + void push_back(T v) { + if (length == capacity) { + size_t next = capacity == 0 ? 1 : (2 * capacity); + data = (T*)realloc(data, sizeof(T)*next); + capacity = next; + } + data[length] = v; + length++; + } + + T& operator[](size_t index) { + assert(index < length); + return data[index]; + } + + const T& operator[] (size_t index) const { + assert(index < length); + return data[index]; + } + + bool operator==(const vector& rhs) const { + if (length != rhs.length) return false; + for (size_t i=0; i &tr, std::string prefix="") { + printf("%sPrimal:\n", prefix.c_str()); + bool reverse = false; + for (size_t i=0; i +void assert_eq(std::string scope, std::string varName, int i, T expected, T real, BlasCall texpected, BlasCall rcall) { + if (expected == real) return; + printf("Failure on call %d var %s found ", i, varName.c_str()); + printty(expected); + printf(" expected "); + printty(real); + printf("\n"); + exit(1); +} + +void check_equiv(std::string scope, int i, BlasCall expected, BlasCall real) { +#define MAKEASSERT(name) assert_eq(scope, #name, i, expected.name, real.name, expected, real); + MAKEASSERT(inDerivative) + MAKEASSERT(type) + MAKEASSERT(pout_arg1); + MAKEASSERT(pin_arg1); + MAKEASSERT(pin_arg2); + MAKEASSERT(farg1); + MAKEASSERT(farg2); + MAKEASSERT(layout); + MAKEASSERT(targ1); + MAKEASSERT(targ2); + MAKEASSERT(iarg1); + MAKEASSERT(iarg2); + MAKEASSERT(iarg3); + MAKEASSERT(iarg4); + MAKEASSERT(iarg5); + MAKEASSERT(iarg6); +} + +vector calls; +vector foundCalls; + +extern "C" { + +// Y = alpha * op(A) * X + beta * Y +__attribute__((noinline)) +void cblas_dgemv(char layout, char trans, int M, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { + BlasCall call = {inDerivative,
CallType::GEMV, + Y, A, X, + alpha, beta, + layout, + trans, UNUSED_TRANS, + M, N, UNUSED_INT, lda, incx, incy}; + calls.push_back(call); +} + +// C = alpha * A^transA * B^transB + beta * C +__attribute__((noinline)) +void cblas_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { + calls.push_back((BlasCall){inDerivative, CallType::GEMM, + C, A, B, + alpha, beta, + layout, + transA, transB, + M, N, K, lda, ldb, ldc}); +} + +// X = alpha * X +__attribute__((noinline)) +void cblas_dscal(int N, double alpha, double* X, int incX) { + calls.push_back((BlasCall){inDerivative, CallType::SCAL, + X, UNUSED_POINTER, UNUSED_POINTER, + alpha, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incX, UNUSED_INT, UNUSED_INT}); +} + +// A = alpha * X * transpose(Y) + A +__attribute__((noinline)) +void cblas_dger(char layout, int M, int N, double alpha, double* X, int incX, double* Y, int incY, double* A, int lda) { + calls.push_back((BlasCall){inDerivative, CallType::GER, + A, X, Y, + alpha, UNUSED_DOUBLE, + layout, + UNUSED_TRANS, UNUSED_TRANS, + M, N, UNUSED_INT, incX, incY, lda}); +} + +__attribute__((noinline)) +void cblas_dcopy(int N, double* X, int incX, double* Y, int incY) { + + calls.push_back((BlasCall){inDerivative, CallType::COPY, + Y, X, UNUSED_POINTER, + alpha, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incX, incY, UNUSED_INT}); +} +} + +enum class ValueType { + Matrix, + Vector +}; +struct BlasInfo { + ValueType ty; + int vec_length; + int vec_increment; + char mat_layout; + int mat_rows; + int mat_cols; + int mat_ld; + BlasInfo (int length, int increment) { + ty = ValueType::Vector; + vec_length = length; + vec_increment = increment; + mat_layout = '@'; + mat_rows = -1; + mat_cols = -1; + mat_ld = -1; + } + BlasInfo (char layout, int rows, int cols, int ld) { + ty = 
ValueType::Matrix; + vec_length = -1; + vec_increment = -1; + mat_layout = layout; + mat_rows = rows; + mat_cols = cols; + mat_ld = ld; + } +}; + +int pointer_to_index(void* v) { + if (v == A || v == dA) return 0; + if (v == B || v == dB) return 1; + if (v == C || v == dC) return 2; + assert(0 && " illegal pointer to invert"); +} + +void checkVector(BlasInfo info, std::string vecname, int length, int increment, std::string test, BlasCall rcall, const vector & trace) { + if (info.ty != ValueType::Vector) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s is not a vector\n", vecname.c_str()); + exit(1); + } + if (info.vec_length != length) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s length must be ", vecname.c_str()); + printty(info.vec_length); + printf(" found "); + printty(length); + printf("\n"); + exit(1); + } + if (info.vec_increment != increment) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s increment must be ", vecname.c_str()); + printty(info.vec_increment); + printf(" found "); + printty(increment); + printf("\n"); + exit(1); + } +} + +void checkMatrix(BlasInfo info, std::string matname, char layout, int rows, int cols, int ld, std::string test, BlasCall rcall, const vector & trace) { + if (info.ty != ValueType::Matrix) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s is not a matrix\n", matname.c_str()); + exit(1); + } + if (info.mat_layout != layout) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s layout must be ", matname.c_str()); + printty(info.mat_layout); + printf(" found layout="); + printty(layout); + printf("\n"); + exit(1); + } + if (info.mat_rows != rows) { + 
printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s rows must be ", matname.c_str()); + printty(info.mat_rows); + printf(" found "); + printty(rows); + printf("\n"); + exit(1); + } + if (info.mat_cols != cols) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s cols must be ", matname.c_str()); + printty(info.mat_cols); + printf(" found "); + printty(cols); + printf("\n"); + exit(1); + } + if (info.mat_ld != ld) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s leading dimension rows must be ", test.c_str()); + printty(info.mat_ld); + printf(" found "); + printty(ld); + printf("\n"); + exit(1); + } +} + +void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vector & trace) { + switch (rcall.type) { + case CallType::GEMV:{ + // Y = alpha * op(A) * X + beta * Y + auto Y = inputs[pointer_to_index(rcall.pout_arg1)]; + auto A = inputs[pointer_to_index(rcall.pin_arg1)]; + auto X = inputs[pointer_to_index(rcall.pin_arg2)]; + + auto layout = rcall.layout; + auto trans_char = rcall.targ1; + auto trans = (trans_char == 'N' || trans_char == 'n'); + auto M = rcall.iarg1; + auto N =rcall.iarg2; + auto alpha = rcall.farg1; + auto lda = rcall.iarg4; + auto incX = rcall.iarg5; + auto beta = rcall.farg2; + auto incY = rcall.iarg6; + + // A is an m-by-n matrix + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + + // if no trans, X must be N otherwise must be M + // From https://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_gadd421a107a488d524859b4a64c1901a9.html + // X is DOUBLE PRECISION array, dimension at least + // ( 1 + ( n - 1 )*abs( INCX ) ) when TRANS = 'N' or 'n' + // and at least + // ( 1 + ( m - 1 )*abs( INCX ) ) otherwise. 
+ // Before entry, the incremented array X must contain the + // vector x. + auto Xlen = trans ? N : M; + checkVector(X, "X", /*len=*/Xlen, /*inc=*/incX, test, rcall, trace); + + // if no trans, Y must be M otherwise must be N + auto Ylen = trans ? M : N; + checkVector(Y, "Y", /*len=*/Ylen, /*inc=*/incY, test, rcall, trace); + + return; + } + case CallType::GEMM:{ + // C = alpha * A^transA * B^transB + beta * C + auto C = inputs[pointer_to_index(rcall.pout_arg1)]; + auto A = inputs[pointer_to_index(rcall.pin_arg1)]; + auto B = inputs[pointer_to_index(rcall.pin_arg2)]; + + auto layout = rcall.layout; + auto transA = rcall.targ1; + auto transB = rcall.targ2; + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto K = rcall.iarg3; + auto alpha = rcall.farg1; + auto lda = rcall.iarg4; + auto ldb = rcall.iarg5; + auto beta = rcall.farg2; + auto ldc = rcall.iarg6; + + // From https://www.netlib.org/lapack/explore-html/d1/d54/group__double__blas__level3_gaeda3cbd99c8fb834a60a6412878226e1.html + /* + M is INTEGER + On entry, M specifies the number of rows of the matrix + op( A ) and of the matrix C. M must be at least zero. + N is INTEGER + On entry, N specifies the number of columns of the matrix + op( B ) and the number of columns of the matrix C. N must be + at least zero. + K is INTEGER + On entry, K specifies the number of columns of the matrix + op( A ) and the number of rows of the matrix op( B ). K must + be at least zero. + LDA is INTEGER + On entry, LDA specifies the first dimension of A as declared + in the calling (sub) program. When TRANSA = 'N' or 'n' then + LDA must be at least max( 1, m ), otherwise LDA must be at + least max( 1, k ). + */ + checkMatrix(A, "A", layout, /*rows=*/(!transA) ? M : K, /*cols=*/(!transA) ? K : M, /*ld=*/lda, test, rcall, trace); + checkMatrix(B, "B", layout, /*rows=*/(!transB) ? K : N, /*cols=*/(!transB) ? 
N : K, /*ld=*/ldb, test, rcall, trace); + checkMatrix(C, "C", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldc, test, rcall, trace); + return; + } + + case CallType::SCAL: { + auto N = rcall.iarg1; + auto alpha = rcall.farg1; + auto X = inputs[pointer_to_index(rcall.pout_arg1)]; + auto incX = rcall.iarg4; + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + return; + } + case CallType::GER: { + // A = alpha * X * transpose(Y) + A + auto A = inputs[pointer_to_index(rcall.pout_arg1)]; + auto X = inputs[pointer_to_index(rcall.pin_arg1)]; + auto Y = inputs[pointer_to_index(rcall.pin_arg2)]; + + auto layout = rcall.layout; + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto alpha = rcall.farg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + auto incA = rcall.iarg6; + + // From https://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_ga458222e01b4d348e9b52b9343d52f828.html + // x is an m element vector, y is an n element + // vector and A is an m by n matrix. 
+ checkVector(X, "X", /*len=*/M, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/incA, test, rcall, trace); + return; + } + case CallType::COPY: { + auto Y = inputs[pointer_to_index(rcall.pout_arg1)]; + auto X = inputs[pointer_to_index(rcall.pin_arg1)]; + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incX, test, rcall, trace); + return; + } + default: printf("UNKNOWN CALL (%d)", (int)rcall.type); return; + } +} + +void checkMemoryTrace(BlasInfo inputs[3], std::string test, const vector & trace) { + for (size_t i=0; i +void __enzyme_autodiff(void*, T...); + +void my_dgemv(char layout, char trans, int M, int N, double alpha, double* __restrict__ A, int lda, double* __restrict__ X, int incx, double beta, double* __restrict__ Y, int incy) { + cblas_dgemv(layout, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy); + inDerivative = true; +} + +void init() { + inDerivative = false; + calls.clear(); +} + +void checkTest(std::string name) { + if (foundCalls.size() != calls.size()) { + printf("Test %s failed: Expected %zu calls, found %zu\n", name.c_str(), calls.size(), foundCalls.size()); + printf("Expected:\n"); + printTrace(calls, " "); + printf("Found:\n"); + printTrace(foundCalls, " "); + assert(0 && "non-equal call count"); + exit(1); + } + if (foundCalls != calls) { + printf("Test %s failed\n", name.c_str()); + printf("Expected:\n"); + printTrace(calls, " "); + printf("Found:\n"); + printTrace(foundCalls, " "); + } + for (size_t i=0; i= 16\n"; os << " F->setOnlyReadsMemory();\n"; os << "#else\n"; + os << " F->removeFnAttr(llvm::Attribute::ReadNone);\n"; os << " F->addFnAttr(llvm::Attribute::ReadOnly);\n"; os << "#endif\n"; } @@ -76,7 +77,9 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) { typeOfArg == ArgType::fp || 
typeOfArg == ArgType::trans || typeOfArg == ArgType::mldLD || typeOfArg == ArgType::uplo || typeOfArg == ArgType::diag || typeOfArg == ArgType::side) { - os << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") + os << " F->removeParamAttr(" << i << (lv23 ? " + offset" : "") + << ", llvm::Attribute::ReadNone);\n" + << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") << ", llvm::Attribute::ReadOnly);\n" << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") << ", llvm::Attribute::NoCapture);\n"; @@ -95,7 +98,9 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) { << ", llvm::Attribute::NoCapture);\n"; if (mutableArgs.count(argPos) == 0) { // Only emit ReadOnly if the arg isn't mutable - os << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") + os << " F->removeParamAttr(" << i << (lv23 ? " + offset" : "") + << ", llvm::Attribute::ReadNone);\n" + << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") << ", llvm::Attribute::ReadOnly);\n"; } } From 562b7a8439b72dbdf481be9a29a79a50b49bed53 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 19 Sep 2023 20:26:55 -0400 Subject: [PATCH 18/29] add blas axpy support (#1445) --- enzyme/Enzyme/BlasDerivatives.td | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 6edf464c2372..5988f13babaa 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -134,10 +134,9 @@ def lascl : CallBlasPattern<(Op $layout, $type, $kl, $ku, $cfrom, $cto, $m, $n, def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy), ["y"],[len, fp, vinc<["n"]>, vinc<["n"]>], [ - // dot must proceed scal, because scal modifies adj<"x"> - (inactive), - (inactive),//(b<"scal"> $n, $alpha, adj<"x">, $incy), - (inactive) // y = (Ax) + y, so nothing to do here + (b<"dot"> $n, adj<"y">, $incy, $x, $incx), + (b<"axpy"> $n, $alpha, adj<"y">, $incy, adj<"x">, $incx), + (noop) // y = alpha*x + y, so 
nothing to do here ] >; From 53a15adbfe239b994a826fb99ebeb681c3703e58 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 19 Sep 2023 20:54:07 -0400 Subject: [PATCH 19/29] add blas copy support (#1447) --- enzyme/Enzyme/BlasDerivatives.td | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 5988f13babaa..43742a32246a 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -154,13 +154,13 @@ def dot : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), // >; -// def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), -// ["y"],[len, vinc, vinc], -// [ -// (noop),// copy moves x into y, so x is never modified. -// (b<"axpy"> $n, Constant<"1.0">, adj<"y">, $incy, adj<"x">, $incx) -// ] -// >; +def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), + ["y"],[len, vinc<["n"]>, vinc<["n"]>], + [ + (noop),// copy moves x into y, so x is never modified. + (b<"axpy"> $n, Constant<"1.0">, adj<"y">, $incy, adj<"x">, $incx) + ] + >; // def swap : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), // ["x","y"],[len, vinc, vinc], From edd0331deebf5e875ac312357285f251e4d35d88 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 20 Sep 2023 04:28:50 -0500 Subject: [PATCH 20/29] Add dot blas test (#1446) * Add dot blas test * Now adding caching test * Cleanup and progress * cleanup * Fix autodiff ordering with inlining * Rebase * Fix fortran calling conv --- enzyme/Enzyme/BlasDerivatives.td | 52 +- enzyme/Enzyme/Enzyme.cpp | 144 +++-- enzyme/Enzyme/Utils.cpp | 33 +- enzyme/Enzyme/Utils.h | 9 +- enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll | 10 +- .../test/Enzyme/ReverseMode/blas/gemm_f_c.ll | 10 +- .../Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll | 12 +- .../blas/gemm_f_c_lacpy_runtime_act.ll | 12 +- .../Enzyme/ReverseMode/blas/gemm_f_c_loop.ll | 8 +- .../Enzyme/ReverseMode/blas/gemm_f_c_split.ll | 10 +- .../ReverseMode/blas/gemm_f_c_split_lacpy.ll 
| 10 +- .../blas/gemm_f_c_split_transpose_lacpy.ll | 10 +- .../blas/gemm_f_c_transpose_lacpy.ll | 10 +- .../ReverseMode/blas/gemm_f_change_ld.ll | 8 +- .../Enzyme/ReverseMode/blas/gemm_f_lacpy.ll | 10 +- .../Enzyme/ReverseMode/blas/gemm_f_over.ll | 10 +- .../ReverseMode/blas/gemm_f_over_lacpy.ll | 10 +- .../Enzyme/ReverseMode/blas/gemv_c_loop.ll | 2 +- .../Enzyme/ReverseMode/blas/gemv_c_loop2.ll | 5 +- .../ReverseMode/blas/gemv_c_loop3_matcopy.ll | 12 +- .../blas/gemv_f_c_split_blascpy.ll | 52 +- .../gemv_f_c_split_blascpy_runtime_act.ll | 28 +- .../ReverseMode/blas/gemv_f_c_split_memcpy.ll | 36 +- enzyme/test/Integration/ReverseMode/blas.cpp | 609 ++++++++++++++++-- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 371 ++++++----- enzyme/tools/enzyme-tblgen/caching.cpp | 6 +- enzyme/tools/enzyme-tblgen/datastructures.cpp | 65 +- enzyme/tools/enzyme-tblgen/datastructures.h | 9 +- 28 files changed, 1082 insertions(+), 481 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 43742a32246a..9ff3abf65bbd 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -109,8 +109,8 @@ def scal : CallBlasPattern<(Op $n, $alpha, $x, $incx), ["x"],[len, fp, vinc<["n"]>], [ // dot must proceed scal, because scal modifies adj<"x"> - (b<"dot"> $n, $x, $incx, adj<"x">, $incx), - (b<"scal"> $n, $alpha, adj<"x">, $incx) + (b<"dot"> $n, $x, adj<"x">), + (b<"scal"> $n, $alpha, adj<"x">) ] >; @@ -134,8 +134,8 @@ def lascl : CallBlasPattern<(Op $layout, $type, $kl, $ku, $cfrom, $cto, $m, $n, def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy), ["y"],[len, fp, vinc<["n"]>, vinc<["n"]>], [ - (b<"dot"> $n, adj<"y">, $incy, $x, $incx), - (b<"axpy"> $n, $alpha, adj<"y">, $incy, adj<"x">, $incx), + (b<"dot"> $n, adj<"y">, $x), + (b<"axpy"> $n, $alpha, adj<"y">, adj<"x">), (noop) // y = alpha*x + y, so nothing to do here ] >; @@ -143,8 +143,8 @@ def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy), 
def dot : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), [],[len, vinc<["n"]>, vinc<["n"]>], [ - (b<"axpy"> $n, DiffeRet, $y, $incy, adj<"x">, $incx), - (b<"axpy"> $n, DiffeRet, $x, $incx, adj<"y">, $incy) + (b<"axpy"> $n, DiffeRet, $y, adj<"x">), + (b<"axpy"> $n, DiffeRet, $x, adj<"y">), ] >; @@ -158,7 +158,7 @@ def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), ["y"],[len, vinc<["n"]>, vinc<["n"]>], [ (noop),// copy moves x into y, so x is never modified. - (b<"axpy"> $n, Constant<"1.0">, adj<"y">, $incy, adj<"x">, $incx) + (b<"axpy"> $n, Constant<"1.0">, adj<"y">, adj<"x">) ] >; @@ -184,8 +184,8 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ ["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>], [ /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]> - (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $n), $x, $incx, Constant<"0.0">, use<"Ax">, ConstantInt<1>), - (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, use<"Ax">, ConstantInt<1>)), + (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), + (b<"dot"> (Rows $transa, $m, $n), adj<"y">, use<"Ax">, ConstantInt<1>)), //if (is_normal $transa) { // call sger(m, n, alpha, ya, incy, x, incx, Aa, lda) @@ -193,13 +193,11 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ // call sger(m, n, alpha, x, incx, ya, incy, Aa, lda) //} /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, adj<"y">, $x), - (Rows $transa, $incy, $incx), (Rows $transa, $x, adj<"y">), - (Rows $transa, $incx, $incy), - adj<"A">, $lda), - /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, $transa, $lda, $m, $n), adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx), - /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, input<"y">, $incy), - /* y 
*/ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">, $incy) + adj<"A">), + /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">), + /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">), + /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">) ] >; // @@ -226,11 +224,11 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A /* alpha */ (Seq<["AB", "product", "m", "n"]> (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n - (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, use<"AB">)), - /* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $ldc, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">, $lda), - /* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $ldc, $beta, adj<"B">, $ldb), - /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, input<"C">), - /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, $ldc, ConstantInt<0>) + (FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)), + /* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">), + /* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $beta, adj<"B">), + /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">), + /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, ConstantInt<0>) ] >; @@ -239,14 +237,14 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta [cblas_layout, uplo, len, fp, ap<["n"]>, vinc<["n"]>, 
fp, vinc<["n"]>], [ /* alpha */ (Seq<["y0", "triangular", "n"]> - (b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, $incx, Constant<"0.0">, use<"y0">, ConstantInt<1>), - (b<"dot"> $n, adj<"y">, $incy, use<"y0">, ConstantInt<1>)), + (b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, Constant<"0.0">, use<"y0">, ConstantInt<1>), + (b<"dot"> $n, adj<"y">, use<"y0">, ConstantInt<1>)), /* ap */ (Seq<[]> - (b<"spr2"> $layout, $uplo, $n, $alpha, $x, $incx, adj<"y">, $incy, adj<"ap">), - (DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, $incx, adj<"y">, $incy, adj<"ap">)), - /* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx), - /* beta */ (b<"dot"> $n, adj<"y">, $incy, input<"y">, $incy), - /* y */ (b<"scal"> $n, $beta, adj<"y">, $incy) + (b<"spr2"> $layout, $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">), + (DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">)), + /* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, Constant<"1.0">, adj<"x">), + /* beta */ (b<"dot"> $n, adj<"y">, input<"y">), + /* y */ (b<"scal"> $n, $beta, adj<"y">) ] >; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 793d6dacef34..4772edbedfe5 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1409,7 +1409,8 @@ class EnzymeBase { Type *retElemType, SmallVectorImpl &args, const std::map &byVal, const std::vector &constants, Function *fn, - DerivativeMode mode, Options &options, bool sizeOnly) { + DerivativeMode mode, Options &options, bool sizeOnly, + SmallVectorImpl &calls) { auto &differet = options.differet; auto &tape = options.tape; auto &width = options.width; @@ -1702,63 +1703,13 @@ class EnzymeBase { } ReplaceOriginalCall(Builder, ret, retElemType, diffret, CI, mode); - - if (Logic.PostOpt) { - auto Params = llvm::getInlineParams(); - - llvm::SetVector Q; - Q.insert(diffretc); - while (Q.size()) { - auto cur = *Q.begin(); - Function *outerFunc = cur->getParent()->getParent(); - 
llvm::OptimizationRemarkEmitter ORE(outerFunc); - Q.erase(Q.begin()); - if (auto F = cur->getCalledFunction()) { - if (!F->empty()) { - // Garbage collect AC's created - SmallVector ACAlloc; - auto getAC = [&](Function &F) -> llvm::AssumptionCache & { - auto AC = new AssumptionCache(F); - ACAlloc.push_back(AC); - return *AC; - }; - auto GetTLI = - [&](llvm::Function &F) -> const llvm::TargetLibraryInfo & { - return Logic.PPC.FAM.getResult(F); - }; - - auto GetInlineCost = [&](CallBase &CB) { - TargetTransformInfo TTI(F->getParent()->getDataLayout()); - auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI); - return cst; - }; - if (llvm::shouldInline(*cur, GetInlineCost, ORE)) { - InlineFunctionInfo IFI; - InlineResult IR = InlineFunction(*cur, IFI); - if (IR.isSuccess()) { - LowerSparsification(outerFunc, /*replaceAll*/ false); - for (auto U : outerFunc->users()) { - if (auto CI = dyn_cast(U)) { - if (CI->getCalledFunction() == outerFunc) { - Q.insert(CI); - } - } - } - } - } - for (auto AC : ACAlloc) { - delete AC; - } - } - } - } - } - return true; + calls.push_back(diffretc); + return diffret; } /// Return whether successful - bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, - bool sizeOnly) { + bool HandleAutoDiffArguments(CallInst *CI, DerivativeMode mode, bool sizeOnly, + SmallVectorImpl &calls) { // determine function to differentiate Function *fn = parseFunctionParameter(CI); @@ -1796,16 +1747,17 @@ class EnzymeBase { #if LLVM_VERSION_MAJOR >= 16 return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.value(), - sizeOnly); + byVal, constants, fn, mode, options.value(), sizeOnly, + calls); #else return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, byVal, constants, fn, mode, options.getValue(), - sizeOnly); + sizeOnly, calls); #endif } - bool HandleProbProg(CallInst *CI, ProbProgMode mode) { + bool HandleProbProg(CallInst *CI, ProbProgMode mode, + SmallVectorImpl 
&calls) { IRBuilder<> Builder(CI); Function *F = parseFunctionParameter(CI); if (!F) @@ -1928,13 +1880,15 @@ class EnzymeBase { } #if LLVM_VERSION_MAJOR >= 16 - bool status = HandleAutoDiff( - CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, - newFunc, DerivativeMode::ReverseModeCombined, opt.value(), false); + bool status = + HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, + constants, newFunc, DerivativeMode::ReverseModeCombined, + opt.value(), false, calls); #else - bool status = HandleAutoDiff( - CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, - newFunc, DerivativeMode::ReverseModeCombined, opt.getValue(), false); + bool status = + HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, + constants, newFunc, DerivativeMode::ReverseModeCombined, + opt.getValue(), false, calls); #endif delete interface; @@ -2447,17 +2401,19 @@ class EnzymeBase { Changed = true; } + SmallVector calls; + // Perform all the size replacements first to create constants for (auto pair : toSize) { bool successful = HandleAutoDiffArguments(pair.first, pair.second, - /*sizeOnly*/ true); + /*sizeOnly*/ true, calls); Changed = true; if (!successful) break; } for (auto pair : toLower) { bool successful = HandleAutoDiffArguments(pair.first, pair.second, - /*sizeOnly*/ false); + /*sizeOnly*/ false, calls); Changed = true; if (!successful) break; @@ -2495,7 +2451,59 @@ class EnzymeBase { } for (auto &&[call, mode] : toProbProg) { - HandleProbProg(call, mode); + HandleProbProg(call, mode, calls); + } + + if (Logic.PostOpt) { + auto Params = llvm::getInlineParams(); + + llvm::SetVector Q; + for (auto call : calls) + Q.insert(call); + while (Q.size()) { + auto cur = *Q.begin(); + Function *outerFunc = cur->getParent()->getParent(); + llvm::OptimizationRemarkEmitter ORE(outerFunc); + Q.erase(Q.begin()); + if (auto F = cur->getCalledFunction()) { + if (!F->empty()) { + // Garbage collect AC's created + SmallVector 
ACAlloc; + auto getAC = [&](Function &F) -> llvm::AssumptionCache & { + auto AC = new AssumptionCache(F); + ACAlloc.push_back(AC); + return *AC; + }; + auto GetTLI = + [&](llvm::Function &F) -> const llvm::TargetLibraryInfo & { + return Logic.PPC.FAM.getResult(F); + }; + + auto GetInlineCost = [&](CallBase &CB) { + TargetTransformInfo TTI(F->getParent()->getDataLayout()); + auto cst = llvm::getInlineCost(CB, Params, TTI, getAC, GetTLI); + return cst; + }; + if (llvm::shouldInline(*cur, GetInlineCost, ORE)) { + InlineFunctionInfo IFI; + InlineResult IR = InlineFunction(*cur, IFI); + if (IR.isSuccess()) { + LowerSparsification(outerFunc, /*replaceAll*/ false); + for (auto U : outerFunc->users()) { + if (auto CI = dyn_cast(U)) { + if (CI->getCalledFunction() == outerFunc) { + Q.insert(CI); + } + } + } + } + } + for (auto AC : ACAlloc) { + delete AC; + } + } + } + } } if (Changed && EnzymeAttributor) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 94de076a4af1..3445550a1b13 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -645,7 +645,8 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef args, llvm::ArrayRef bundles) { - std::string copy_name = (blas.floatType + "lacpy" + blas.suffix).str(); + std::string copy_name = + (blas.prefix + blas.floatType + "lacpy" + blas.suffix).str(); SmallVector tys; for (auto arg : args) @@ -2554,14 +2555,16 @@ llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V) { // } else { // ld_A = arg_lda; // } -llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *trans, +llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, + llvm::ArrayRef trans, llvm::Value *arg_ld, llvm::Value *dim1, llvm::Value *dim2, bool cacheMat, bool byRef) { if (!cacheMat) return arg_ld; - Value *width = B.CreateSelect(is_normal(B, trans, byRef), dim1, dim2); + assert(trans.size() == 
1); + Value *width = CreateSelect(B, is_normal(B, trans[0], byRef), dim1, dim2); return width; } @@ -2593,19 +2596,27 @@ llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::IntegerType *intType, return B.CreateLoad(intType, VP); } -llvm::Value *get_blas_row(llvm::IRBuilder<> &B, llvm::Value *trans, - llvm::Value *row, llvm::Value *col, bool byRef) { - +SmallVector get_blas_row(llvm::IRBuilder<> &B, + ArrayRef transA, + ArrayRef row, + ArrayRef col, + bool byRef) { + assert(transA.size() == 1); + auto trans = transA[0]; if (byRef) { auto charType = IntegerType::get(trans->getContext(), 8); trans = B.CreateLoad(charType, trans, "ld.row.trans"); } - return B.CreateSelect( - B.CreateOr( - B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')), - B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'))), - row, col); + auto cond = B.CreateOr( + B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')), + B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'))); + assert(row.size() == col.size()); + SmallVector toreturn; + for (size_t i = 0; i < row.size(); i++) { + toreturn.push_back(B.CreateSelect(cond, row[i], col[i])); + } + return toreturn; } // return how many Special pointers are in T (count > 0), diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 3ace54029ca3..8d8d47bcff51 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1637,7 +1637,8 @@ llvm::Value *to_blas_fp_callconv(llvm::IRBuilder<> &B, llvm::Value *V, llvm::IRBuilder<> &entryBuilder, llvm::Twine const & = ""); -llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *trans, +llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, + llvm::ArrayRef trans, llvm::Value *arg_ld, llvm::Value *dim_1, llvm::Value *dim_2, bool cacheMat, bool byRef); @@ -1651,8 +1652,10 @@ llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V); llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, llvm::IntegerType *IT, llvm::IRBuilder<> 
&entryBuilder, const llvm::Twine &name); -llvm::Value *get_blas_row(llvm::IRBuilder<> &B, llvm::Value *trans, - llvm::Value *row, llvm::Value *col, bool byRef); +llvm::SmallVector +get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef trans, + llvm::ArrayRef row, + llvm::ArrayRef col, bool byRef); // Parameter attributes from the original function/call that // we should preserve on the primal of the derivative code. diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll index 6cd71c23be62..9b514b2be53b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -84,7 +84,7 @@ entry: ; CHECK-NEXT: store i64 8, i64* 
%ldb, align 16 ; CHECK-NEXT: store double 0.000000e+00, double* %beta ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry @@ -110,8 +110,8 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll index 5a3d1793dbbe..991ae30a792c 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll +++ 
b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -176,7 +176,7 @@ entry: ; CHECK-NEXT: br i1 %39, label %[[enzyme_memcpy_double_mat_64_exit21]], label %[[init_idx]] ; CHECK: [[enzyme_memcpy_double_mat_64_exit21]]: ; preds = %__enzyme_memcpy_double_mat_64.exit, %[[init_end_i18]] -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, 
i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !10, !noalias !13 @@ -213,13 +213,13 @@ entry: ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans4, 110 ; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; CHECK-DAG: %[[r21:.+]] = select i1 %[[r20]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i43]], i8* %[[r21]], i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i43]], i8* %[[r21]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transa ; CHECK-DAG: %[[r22:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-DAG: %[[r23:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-DAG: %[[r24:.+]] = or i1 %[[r23]], %[[r22]] ; CHECK-DAG: %[[r25:.+]] = select i1 %[[r24]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i42]], i8* %[[r25]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i42]], i8* %[[r25]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll index 9b09574ce220..c9e40aac8e36 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll +++ 
b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B, i8* noalias %alpha, i8* noalias %beta) { entry: @@ -27,7 +27,7 @@ entry: store i64 4, i64* %lda, align 16 store i64 8, i64* %ldb, align 16 store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize5) ; CHECK-NEXT: %mat_AB = bitcast i8* %malloccall6 to double* ; CHECK-NEXT: %[[i21:.+]] = bitcast double* %mat_AB to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* 
%k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !0, !noalias !3 @@ -159,7 +159,7 @@ entry: ; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p, i64 1, i64 1) ; CHECK: %[[i45:.+]] = bitcast i64* %byref.constant.one.i to i8* ; CHECK: %[[i46:.+]] = bitcast i64* %byref.mat.size.i to i8* ; CHECK: store i64 1, i64* %byref.constant.one.i @@ -207,13 +207,13 @@ entry: ; CHECK-NEXT: %[[i62:.+]] = load double, double* %[[i61]] ; CHECK-NEXT: %[[i63:.+]] = fadd fast double %[[i62]], %res.i ; CHECK-NEXT: store double %[[i63]], double* %[[i61]] -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans8 = load i8, i8* %transa ; CHECK-DAG: %[[i64:.+]] = icmp eq i8 %loaded.trans8, 78 ; CHECK-DAG: %[[i65:.+]] = icmp eq i8 %loaded.trans8, 110 ; CHECK-DAG: %[[i66:.+]] = or i1 
%[[i65]], %[[i64]] ; CHECK-NEXT: %[[i67:.+]] = select i1 %[[i66]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK: %[[i68:.+]] = bitcast i64* %byref.constant.one.i15 to i8* ; CHECK: %[[i69:.+]] = bitcast i64* %byref.mat.size.i18 to i8* ; CHECK: store i64 1, i64* %byref.constant.one.i15 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index 1aae37d5b5d7..f8d13293ef97 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -enzyme-runtime-activity=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -enzyme-runtime-activity=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B, i8* noalias %alpha, i8* noalias %beta) { entry: @@ -27,7 +27,7 @@ entry: 
store i64 4, i64* %lda, align 16 store i64 8, i64* %ldb, align 16 store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -122,7 +122,7 @@ entry: ; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize5) ; CHECK-NEXT: %mat_AB = bitcast i8* %malloccall6 to double* ; CHECK-NEXT: %[[i21:.+]] = bitcast double* %mat_AB to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !0, !noalias !3 @@ -167,7 +167,7 @@ entry: ; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* 
%[[i21]], i8* %m_p, i64 1, i64 1) ; CHECK: %[[i45:.+]] = bitcast i64* %byref.constant.one.i to i8* ; CHECK: %[[i46:.+]] = bitcast i64* %byref.mat.size.i to i8* ; CHECK: store i64 1, i64* %byref.constant.one.i @@ -221,7 +221,7 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.A, label %invertentry.A.done, label %invertentry.A.active ; CHECK: invertentry.A.active: ; preds = %invertentry.alpha.done -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry.A.done ; CHECK: invertentry.A.done: ; preds = %invertentry.A.active, %invertentry.alpha.done @@ -233,7 +233,7 @@ entry: ; CHECK-DAG: %[[i65:.+]] = icmp eq i8 %loaded.trans8, 110 ; CHECK-DAG: %[[i66:.+]] = or i1 %[[i65]], %[[i64]] ; CHECK-NEXT: %[[i67:.+]] = select i1 %[[i66]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry.B.done ; CHECK: invertentry.B.done: ; preds = %invertentry.B.active, %invertentry.A.done diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll index ca9e23085eaf..44cc10d80949 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver 
-lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) declare i8* @AData(i64) declare i8* @Aldap(i64) @@ -43,7 +43,7 @@ loop: store i64 4, i64* %ldc, align 16 %A = call i8* @AData(i64 %i) "enzyme_inactive" %lda_p = call i8* @Aldap(i64 %i) "enzyme_inactive" - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) call void @free(i8* %m_p) %cmp = icmp eq i64 %inc, 10 br i1 %cmp, label %exit, label %loop @@ -187,7 +187,7 @@ entry: ; CHECK-NEXT: br i1 %23, label %__enzyme_memcpy_double_mat_64.exit, label %init.idx.i ; CHECK: __enzyme_memcpy_double_mat_64.exit: ; preds = %loop, %init.end.i -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: call void @free(i8* %m_p) ; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, 10 ; 
CHECK-NEXT: br i1 %cmp, label %exit, label %loop @@ -275,7 +275,7 @@ entry: ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans30, 110 ; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; CHECK-DAG: %[[r21:.+]] = select i1 %[[r20]], i8* %[[i46]], i8* %cast.k -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %byref.transb, i8* %cast.k, i8* %n_p_unwrap, i8* %[[i46]], i8* %cast.alpha, i8* %[[i44]], i8* %[[r21]], i8* %"C'", i8* %cast.ldc, i8* %cast.beta, i8* %"B'", i8* %cast.ldb) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %byref.transb, i8* %cast.k, i8* %n_p_unwrap, i8* %[[i46]], i8* %cast.alpha, i8* %[[i44]], i8* %[[r21]], i8* %"C'", i8* %cast.ldc, i8* %cast.beta, i8* %"B'", i8* %cast.ldb, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll index d65e7c4aba61..b3a6f1743944 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, 
i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -146,7 +146,7 @@ entry: ; CHECK-NEXT: br i1 %36, label %__enzyme_memcpy_double_mat_64.exit, label %init.idx.i ; CHECK: __enzyme_memcpy_double_mat_64.exit: ; preds = %entry, %init.end.i -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %37 = load double*, double** %0 ; CHECK-NEXT: ret double* %37 ; CHECK-NEXT: } @@ -232,13 +232,13 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans, 110 ; 
CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; CHECK-NEXT: %[[r21:.+]] = select i1 %[[r20]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index 752f85319613..b15204f9b31d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta 
store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -118,7 +118,7 @@ entry: ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage ; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] ; CHECK-NEXT: } @@ -204,13 +204,13 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] ; 
CHECK-NEXT: %[[r21:.+]] = select i1 %[[r20]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index 689904cc286b..4561682b3ae9 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 
4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -118,7 +118,7 @@ entry: ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage ; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] ; CHECK-NEXT: } @@ -204,13 +204,13 @@ entry: ; CHECK-NEXT: store i8 %[[i33]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[i34:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[i35:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-DAG: %36 = or i1 %[[i35]], %[[i34]] ; CHECK-NEXT: %37 = select i1 
%36, i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %37, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %37, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index 7241127eb8bb..d6c145ee54ae 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* 
%k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %A to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -120,7 +120,7 @@ entry: ; CHECK-NEXT: %cache.B = bitcast i8* %[[malloccall2]] to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage4 ; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %13, i8* %14, i8* %B, i8* %ldb_p, double* %cache.B, i8* %13) -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !0, !noalias !3 @@ -157,13 +157,13 @@ entry: ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[i42:.+]] = or i1 %[[i41]], %[[i40]] ; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i25]], i8* %[[i43]], i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i25]], i8* %[[i43]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: %[[cachedtrans2:.+]] = load i8, i8* %transa ; CHECK-DAG: %[[i54:.+]] = icmp eq i8 %[[cachedtrans2]], 78 ; CHECK-DAG: %[[i55:.+]] = icmp eq i8 
%[[cachedtrans2]], 110 ; CHECK-NEXT: %[[i56:.+]] = or i1 %[[i55]], %[[i54]] ; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i24]], i8* %[[i57]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i24]], i8* %[[i57]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %[[intcast0:.+]] = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index 2afc2b06f78a..51879c3537da 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-lapack-copy=1 -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 16, i64* %ldb, align 16 store double 0.000000e+00, double* 
%beta store i64 8, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) %ptr = bitcast i8* %B to double* store double 0.0000000e+00, double* %ptr, align 8 ret void @@ -104,7 +104,7 @@ entry: ; CHECK-NEXT: %cache.B = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage ; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %3, i8* %4, i8* %B, i8* %ldb_p, double* %cache.B, i8* %3) -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %ptr = bitcast i8* %B to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8 ; CHECK-NEXT: br label %invertentry @@ -138,7 +138,7 @@ entry: ; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans1, 110 ; CHECK-NEXT: %[[r18:.+]] = or i1 %[[r17]], %[[r16]] ; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %10, i8* %[[r19]], i8* %beta_p, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %10, i8* %[[r19]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; 
CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll index c7571f1ea8ef..f07275c6df1e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ret void } @@ -84,7 +84,7 @@ entry: ; CHECK-NEXT: store i64 8, i64* %ldb, align 16 ; CHECK-NEXT: store double 0.000000e+00, double* %beta ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* 
%beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry @@ -110,8 +110,8 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %[[int00]] ; CHECK-NEXT: %[[intcast00:.+]] = bitcast i64* %[[int00]] to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll index c95f61c138a4..5f5ff35f0659 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture 
readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) store i64 0, i64* %m, align 16 ret void } @@ -88,7 +88,7 @@ entry: ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 ; CHECK-NEXT: %pcld.m = bitcast i8* %m_p to i64* ; CHECK-NEXT: %avld.m = load i64, i64* %pcld.m -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: store i64 0, i64* %m ; CHECK-NEXT: br label %invertentry @@ -117,8 +117,8 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, 
i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll index 63283c5e271c..1c1c0cea8c2c 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll @@ -1,7 +1,7 @@ ;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly) +declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, 
i64) define void @f(i8* %C, i8* %A, i8* %B) { entry: @@ -33,7 +33,7 @@ entry: store i64 8, i64* %ldb, align 16 store double 0.000000e+00, double* %beta store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) + call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) store i64 0, i64* %m, align 16 ret void } @@ -88,7 +88,7 @@ entry: ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 ; CHECK-NEXT: %pcld.m = bitcast i8* %m_p to i64* ; CHECK-NEXT: %avld.m = load i64, i64* %pcld.m -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: store i64 0, i64* %m ; CHECK-NEXT: br label %invertentry @@ -117,8 +117,8 @@ entry: ; CHECK-DAG: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p) +; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, 
i64 1, i64 1) +; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll index 21807cc0b4de..8997057fde01 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll @@ -55,7 +55,7 @@ entry: ; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %0, 8 ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* -; CHECK-NEXT: call void @dlacpy(i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) +; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) ; CHECK-NEXT: %1 = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8 ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll index 872964e04cf7..8501555f6e06 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll @@ -55,7 +55,7 @@ entry: ; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %0, 8 ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* -; CHECK-NEXT: call void @dlacpy(i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) +; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, 
double* %cache.A, i32 %N) ; CHECK-NEXT: %1 = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8 ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1) @@ -108,8 +108,7 @@ entry: ; CHECK-DAG: %[[r20:.+]] = select i1 false, double* %"v0'", double* %cache.x_unwrap ; CHECK-DAG: %[[r21:.+]] = select i1 false, double* %cache.x_unwrap, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r20]], i32 1, double* %[[r21]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i22:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A_unwrap, i32 %[[i22]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A_unwrap, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i23:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i23]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i24:.+]] = bitcast double* %cache.A_unwrap to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll index 3819ed2b6eb9..5133760023d2 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll @@ -201,8 +201,7 @@ entry: ; CHECK-DAG: %[[r42:.+]] = select i1 false, double* %"v0'", double* %cache.x ; CHECK-DAG: %[[r43:.+]] = select i1 false, double* %cache.x, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r42]], i32 1, double* %[[r43]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i48:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, 
i32 %N, i32 %N, double 1.000000e-03, double* %cache.A, i32 %[[i48]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i49:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i49]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i50:.+]] = bitcast double* %cache.A to i8* @@ -212,8 +211,7 @@ entry: ; CHECK-DAG: %[[r48:.+]] = select i1 false, double* %"v0'", double* %cache.x8 ; CHECK-DAG: %[[r49:.+]] = select i1 false, double* %cache.x8, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r48]], i32 1, double* %[[r49]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i52:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A5, i32 %[[i52]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A5, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i53:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i53]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i54:.+]] = bitcast double* %cache.A5 to i8* @@ -223,8 +221,7 @@ entry: ; CHECK-DAG: %[[r54:.+]] = select i1 false, double* %"v0'", double* %cache.x16 ; CHECK-DAG: %[[r55:.+]] = select i1 false, double* %cache.x16, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r54]], i32 1, double* %[[r55]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i56:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, 
i32 %N, double 1.000000e-03, double* %cache.A13, i32 %[[i56]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A13, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i57:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i57]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i58:.+]] = bitcast double* %cache.A13 to i8* @@ -234,8 +231,7 @@ entry: ; CHECK-DAG: %[[r60:.+]] = select i1 false, double* %"v0'", double* %cache.x24 ; CHECK-DAG: %[[r61:.+]] = select i1 false, double* %cache.x24, double* %"v0'" ; CHECK-NEXT: call void @cblas_dger(i32 101, i32 %N, i32 %N, double 1.000000e-03, double* %[[r60]], i32 1, double* %[[r61]], i32 1, double* %"K'", i32 %N) -; CHECK-NEXT: %[[i60:.+]] = select i1 false, i32 %N, i32 %N -; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A21, i32 %[[i60]], double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) +; CHECK-NEXT: call void @cblas_dgemv(i32 101, i32 112, i32 %N, i32 %N, double 1.000000e-03, double* %cache.A21, i32 %N, double* %"v0'", i32 1, double 1.000000e+00, double* %"x0'", i32 1) ; CHECK-NEXT: %[[i61:.+]] = select i1 false, i32 %N, i32 %N ; CHECK-NEXT: call void @cblas_dscal(i32 %[[i61]], double 1.000000e+00, double* %"v0'", i32 1) ; CHECK-NEXT: %[[i62:.+]] = bitcast double* %cache.A21 to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll index 856451959941..7f70c9865b6e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll @@ -4,7 +4,7 @@ ; Here we don't transpose the matrix a (78 equals 'N' in ASCII) and we therefore also don't transpose x. 
; Therfore the first arg to dcopy is n_p, as opposed to the gemv_transpose test. ; trans, M, N, alpha, A, lda, x, , incx, beta, y, incy -declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly) +declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i64) define void @f(i8* noalias %y, i8* noalias %A, i8* noalias %x, i8* noalias %alpha, i8* noalias %beta) { entry: @@ -25,7 +25,7 @@ entry: store i64 4, i64* %lda, align 16 store i64 2, i64* %incx, align 16 store i64 1, i64* %incy, align 16 - call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p) + call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p, i64 1) ret void } @@ -104,7 +104,7 @@ entry: ; CHECK-NEXT: %23 = insertvalue { double*, double* } undef, double* %cache.x, 0 ; CHECK-NEXT: %24 = insertvalue { double*, double* } %23, double* %cache.y, 1 ; CHECK-NEXT: store { double*, double* } %24, { double*, double* }* %0 -; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p) +; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta, i8* %y, i8* %incy_p, i64 1) ; CHECK-NEXT: %25 = load { double*, double* }, { double*, double* }* %0 ; CHECK-NEXT: ret { double*, double* } %25 ; CHECK-NEXT: } @@ -114,11 +114,13 @@ entry: ; CHECK-NEXT: %ret = alloca double ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.int.one 
= alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %byref.constant.fp.0.0 = alloca double +; CHECK-DAG: %byref.constant.fp.1.0 = alloca double +; CHECK-DAG: %byref.constant.char.N = alloca i8, align 1 +; CHECK-DAG: %byref.constant.fp.0.0 = alloca double ; CHECK-NEXT: %byref.constant.int.1 = alloca i64 ; CHECK-NEXT: %byref.constant.int.17 = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.013 = alloca double +; CHECK-NEXT: %byref.constant.char.N11 = alloca i8, align 1 +; CHECK-NEXT: %[[byrefconstantfp1:.+]] = alloca double ; CHECK-NEXT: %incy = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %incy to i8* ; CHECK-NEXT: %incx = alloca i64, i64 1, align 16 @@ -182,11 +184,12 @@ entry: ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.1 ; CHECK-NEXT: %intcast.constant.int.1 = bitcast i64* %byref.constant.int.1 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %fpcast.constant.fp.1.0, i8* %A, i8* %lda_p, i8* %20, i8* %intcast.int.one, i8* %fpcast.constant.fp.0.0, i8* %19, i8* %intcast.constant.int.1) +; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %fpcast.constant.fp.1.0, i8* %A, i8* %lda_p, i8* %20, i8* %intcast.int.one, i8* %fpcast.constant.fp.0.0, i8* %19, i8* %intcast.constant.int.1, i64 1) ; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[c1:.+]] = icmp eq i8 %ld.row.trans, 110 ; CHECK-DAG: %[[c2:.+]] = icmp eq i8 %ld.row.trans, 78 @@ -204,28 +207,21 @@ entry: ; CHECK-DAG: %[[i40:.+]] = icmp eq i8 
%ld.row.trans9, 78 ; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]] ; CHECK-NEXT: %[[i42:.+]] = select i1 %41, i8* %"y'", i8* %20 +; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i41]], i8* %incy_p, i8* %intcast.int.one ; CHECK-NEXT: %ld.row.trans10 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[i43:.+]] = icmp eq i8 %ld.row.trans10, 110 -; CHECK-DAG: %[[i44:.+]] = icmp eq i8 %ld.row.trans10, 78 -; CHECK-NEXT: %[[i45:.+]] = or i1 %[[i44]], %[[i43]] -; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i45]], i8* %incy_p, i8* %incx_p -; CHECK-NEXT: %ld.row.trans11 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[i47:.+]] = icmp eq i8 %ld.row.trans11, 110 -; CHECK-DAG: %[[i48:.+]] = icmp eq i8 %ld.row.trans11, 78 +; CHECK-DAG: %[[i47:.+]] = icmp eq i8 %ld.row.trans10, 110 +; CHECK-DAG: %[[i48:.+]] = icmp eq i8 %ld.row.trans10, 78 ; CHECK-NEXT: %[[i49:.+]] = or i1 %[[i48]], %[[i47]] ; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8* %20, i8* %"y'" -; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %malloccall, align 1 -; CHECK-NEXT: %[[i51:.+]] = icmp eq i8 %ld.row.trans12, 110 -; CHECK-NEXT: %[[i52:.+]] = icmp eq i8 %ld.row.trans12, 78 -; CHECK-NEXT: %[[i53:.+]] = or i1 %52, %51 -; CHECK-NEXT: %[[i54:.+]] = select i1 %53, i8* %incx_p, i8* %incy_p +; CHECK-NEXT: %[[i54:.+]] = select i1 %[[i49]], i8* %intcast.int.one, i8* %incy_p ; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[i42]], i8* %[[i46]], i8* %[[i50]], i8* %[[i54]], i8* %"A'", i8* %lda_p) -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.013 -; CHECK-NEXT: %fpcast.constant.fp.1.014 = bitcast double* %byref.constant.fp.1.013 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.014, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans15, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq 
i8 %ld.row.trans15, 78 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N11, align 1 +; CHECK-NEXT: store double 1.000000e+00, double* %[[byrefconstantfp1]] +; CHECK-NEXT: %[[fpcast14:.+]] = bitcast double* %[[byrefconstantfp1]] to i8* +; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %[[fpcast14]], i8* %"x'", i8* %incx_p, i64 1) +; CHECK-NEXT: %ld.row.trans14 = load i8, i8* %malloccall +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans14, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans14, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] ; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p ; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %21, i8* %intcast.int.one) @@ -233,9 +229,9 @@ entry: ; CHECK-NEXT: %[[r45:.+]] = load double, double* %[[r44]] ; CHECK-NEXT: %[[r46:.+]] = fadd fast double %[[r45]], %[[r43]] ; CHECK-NEXT: store double %[[r46]], double* %[[r44]] -; CHECK-NEXT: %ld.row.trans16 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans16, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans16, 78 +; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %malloccall +; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans15, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans15, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] ; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll index 2e2ea7b4dac4..7337b69b7794 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll @@ -176,21 +176,13 @@ 
entry: ; CHECK-DAG: %[[r23:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-NEXT: %[[r24:.+]] = or i1 %[[r23]], %[[r22]] ; CHECK-NEXT: %[[r25:.+]] = select i1 %[[r24]], i8* %"y'", i8* %11 +; CHECK-NEXT: %[[r29:.+]] = select i1 %[[r24]], i8* %incy_p, i8* %intcast.int.one ; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r26:.+]] = icmp eq i8 %ld.row.trans2, 110 -; CHECK-DAG: %[[r27:.+]] = icmp eq i8 %ld.row.trans2, 78 -; CHECK-NEXT: %[[r28:.+]] = or i1 %[[r27]], %[[r26]] -; CHECK-NEXT: %[[r29:.+]] = select i1 %[[r28]], i8* %incy_p, i8* %incx_p -; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r30:.+]] = icmp eq i8 %ld.row.trans3, 110 -; CHECK-DAG: %[[r31:.+]] = icmp eq i8 %ld.row.trans3, 78 +; CHECK-DAG: %[[r30:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-DAG: %[[r31:.+]] = icmp eq i8 %ld.row.trans2, 78 ; CHECK-NEXT: %[[r32:.+]] = or i1 %[[r31]], %[[r30]] ; CHECK-NEXT: %[[r33:.+]] = select i1 %[[r32]], i8* %11, i8* %"y'" -; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r34:.+]] = icmp eq i8 %ld.row.trans4, 110 -; CHECK-DAG: %[[r35:.+]] = icmp eq i8 %ld.row.trans4, 78 -; CHECK-NEXT: %[[r36:.+]] = or i1 %[[r35]], %[[r34]] -; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r36]], i8* %incx_p, i8* %incy_p +; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r32]], i8* %intcast.int.one, i8* %incy_p ; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[r25]], i8* %[[r29]], i8* %[[r33]], i8* %[[r37]], i8* %"A'", i8* %lda_p) ; CHECK-NEXT: br label %invertentry.A.done @@ -198,9 +190,9 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.beta, label %invertentry.beta.done, label %invertentry.beta.active ; CHECK: invertentry.beta.active: ; preds = %invertentry.A.done -; CHECK-NEXT: %ld.row.trans5 = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans5, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans5, 78 +; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall 
+; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans3, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans3, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] ; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p ; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %[[i12]], i8* %intcast.int.one) @@ -214,9 +206,9 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %invertentry.beta.done -; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans6, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans6, 78 +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall +; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans4, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] ; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll index bef22acf4410..cd853b3f207a 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll @@ -2,7 +2,7 @@ ;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-blas-copy=0 -enzyme-lapack-copy=1 -S | FileCheck %s ; trans, M, N, alpha, A, lda, x, , incx, beta, y, incy -declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly) +declare void @dgemv_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i8*, i8* nocapture 
readonly, i8* nocapture readonly, i8* , i8* nocapture readonly, i64) define void @f(i8* noalias %y, i8* noalias %A, i8* noalias %x) { entry: @@ -29,7 +29,7 @@ entry: store i64 2, i64* %incx, align 16 store double 0.000000e+00, double* %beta store i64 1, i64* %incy, align 16 - call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p) + call void @dgemv_64_(i8* %transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p, i64 1) ret void } @@ -126,7 +126,7 @@ entry: ; CHECK-NEXT: br i1 %25, label %__enzyme_memcpy_double_64_da0sa0stride.exit, label %for.body.i ; CHECK: __enzyme_memcpy_double_64_da0sa0stride.exit: ; preds = %entry, %for.body.i -; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p) +; CHECK-NEXT: call void @dgemv_64_(i8* %malloccall, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %x, i8* %incx_p, i8* %beta_p, i8* %y, i8* %incy_p, i64 1) ; CHECK-NEXT: %26 = load double*, double** %0 ; CHECK-NEXT: ret double* %26 ; CHECK-NEXT: } @@ -136,6 +136,7 @@ entry: ; CHECK-NEXT: %ret = alloca double ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 ; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double ; CHECK-NEXT: %incy = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %incy to i8* @@ -195,28 +196,21 @@ entry: ; CHECK-DAG: %[[r25:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-DAG: %[[r26:.+]] = or i1 %[[r25]], %[[r24]] ; CHECK-NEXT: %[[r27:.+]] = select i1 %[[r26]], i8* %"y'", i8* %15 -; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r28:.+]] = icmp eq i8 %ld.row.trans1, 110 -; CHECK-DAG: %[[r29:.+]] = icmp eq i8 %ld.row.trans1, 78 -; CHECK-DAG: %[[r30:.+]] = or i1 
%[[r29]], %[[r28]] -; CHECK-NEXT: %[[r31:.+]] = select i1 %[[r30]], i8* %incy_p, i8* %incx_p -; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r32:.+]] = icmp eq i8 %ld.row.trans2, 110 -; CHECK-DAG: %[[r33:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-NEXT: %[[r31:.+]] = select i1 %[[r26]], i8* %incy_p, i8* %intcast.int.one +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall +; CHECK-DAG: %[[r32:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-DAG: %[[r33:.+]] = icmp eq i8 %ld.row.trans1, 78 ; CHECK-DAG: %[[r34:.+]] = or i1 %[[r33]], %[[r32]] ; CHECK-NEXT: %[[r35:.+]] = select i1 %[[r34]], i8* %15, i8* %"y'" -; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r36:.+]] = icmp eq i8 %ld.row.trans3, 110 -; CHECK-DAG: %[[r37:.+]] = icmp eq i8 %ld.row.trans3, 78 -; CHECK-DAG: %[[r38:.+]] = or i1 %[[r37]], %[[r36]] -; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r38]], i8* %incx_p, i8* %incy_p -; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %27, i8* %31, i8* %35, i8* %39, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r34]], i8* %intcast.int.one, i8* %incy_p +; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %[[r27]], i8* %[[r31]], i8* %[[r35]], i8* %[[r39]], i8* %"A'", i8* %lda_p) +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p) -; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans4, 110 -; CHECK-DAG: %[[r41:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, 
i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p, i64 1) +; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-DAG: %[[r41:.+]] = icmp eq i8 %ld.row.trans2, 78 ; CHECK-NEXT: %[[r42:.+]] = or i1 %[[r41]], %[[r40]] ; CHECK-NEXT: %[[r43:.+]] = select i1 %[[r42]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r43]], i8* %beta_p, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index a225cbd84013..1561cda793f2 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -1,17 +1,17 @@ // This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... -// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - -// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S 
-emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - +// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ 
-fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi #include "test_utils.h" @@ -98,8 +98,9 @@ bool inDerivative = false; double alpha = 2.71828; double beta = 47.56; - int M = 105688; - int N = 78412; + int M = 105688; + int N = 78412; + int K = 5013424; int lda = 3416; char UNUSED_TRANS = 'A'; int UNUSED_INT = -1; @@ -110,7 +111,11 @@ enum class CallType { GEMM, SCAL, GER, - COPY + DOT, + AXPY, + LASCL, + COPY, + LACPY, }; struct BlasCall { @@ -168,6 +173,10 @@ void printty(CallType v) { case CallType::SCAL: printf("SCAL"); return; case CallType::GER: printf("GER"); return; case CallType::COPY: printf("COPY"); return; + case CallType::LACPY: printf("LACPY"); return; + case CallType::DOT: printf("DOT"); return; + case CallType::AXPY: printf("AXPY"); return; + case CallType::LASCL: printf("LASCL"); return; default: printf("UNKNOWN CALL (%d)", (int)v); } } @@ -190,6 +199,7 @@ void printty(int v) { else if (v == incC) printf("incC"); else if (v == M) printf("M"); else if (v == N) printf("N"); + else if (v == K) printf("K"); else if (v == lda) printf("lda"); else if (v == UNUSED_INT) printf("UNUSED_INT"); else printf("Unknown int"); @@ -226,6 +236,77 @@ void printty(double v) { void printcall(BlasCall rcall) { switch (rcall.type) { + case CallType::LACPY: + printf("LACPY(layout="); + printty(rcall.layout); + printf(", uplo="); + printty(rcall.targ1); + printf(", M="); + printty(rcall.iarg1); + printf(", N="); + printty(rcall.iarg2); + printf(", A="); + printty(rcall.pin_arg1); + printf(", lda="); + printty(rcall.iarg4); + printf(", B="); + printty(rcall.pout_arg1); + printf(", ldb="); + printty(rcall.iarg5); + printf(")"); + return; + case CallType::LASCL: + printf("LASCL(layout="); + printty(rcall.layout); + printf(", type="); + 
printty(rcall.targ1); + printf(", KL="); + printty(rcall.iarg5); + printf(", KU="); + printty(rcall.iarg6); + printf(", cfrom="); + printty(rcall.farg1); + printf(", cto="); + printty(rcall.farg2); + + printf(", M="); + printty(rcall.iarg1); + printf(", N="); + printty(rcall.iarg2); + printf(", A="); + printty(rcall.pout_arg1); + printf(", lda="); + printty(rcall.iarg4); + printf(")"); + return; + case CallType::AXPY: + printf("DOT(N="); + printty(rcall.iarg1); + printf(", alpha="); + printty(rcall.farg1); + printf(", X="); + printty(rcall.pin_arg1); + printf(", incx="); + printty(rcall.iarg4); + printf(", Y="); + printty(rcall.pout_arg1); + printf(", incy="); + printty(rcall.iarg5); + printf(")"); + return; + case CallType::DOT: + printf("DOT(N="); + printty(rcall.iarg1); + printf(", X="); + printty(rcall.pin_arg1); + printf(", incx="); + printty(rcall.iarg4); + printf(", Y="); + printty(rcall.pin_arg2); + printf(", incy="); + printty(rcall.iarg5); + printf(")"); + return; case CallType::GEMV: printf("GEMV(layout="); printty(rcall.layout); @@ -388,6 +469,43 @@ vector foundCalls; extern "C" { +// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-0/lascl.html +// technically LAPACKE_dlascl +__attribute__((noinline)) +void cblas_dlascl(char layout, char type, int KL, int KU, double cfrom, double cto, int M, int N, double* A, int lda) { + BlasCall call = {inDerivative, CallType::LASCL, + A, UNUSED_POINTER, UNUSED_POINTER, + cfrom, cto, + layout, + type, UNUSED_TRANS, + M, N, UNUSED_INT, lda, KL, KU}; + calls.push_back(call); +} + +__attribute__((noinline)) +double cblas_ddot(int N, double* X, int incx, double* Y, int incy) { + BlasCall call = {inDerivative, CallType::DOT, + UNUSED_POINTER, X, Y, + UNUSED_DOUBLE, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; + calls.push_back(call); + return 3.15+N; +} + +// Y += alpha * X +__attribute__((noinline)) +void cblas_daxpy(int 
N, double alpha, double* X, int incx, double* Y, int incy) { + BlasCall call = {inDerivative, CallType::AXPY, + Y, X, UNUSED_POINTER, + alpha, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; + calls.push_back(call); +} + // Y = alpha * op(A) * X + beta * Y __attribute__((noinline)) void cblas_dgemv(char layout, char trans, int M, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { @@ -443,6 +561,22 @@ void cblas_dcopy(int N, double* X, int incX, double* Y, int incY) { UNUSED_TRANS, UNUSED_TRANS, N, UNUSED_INT, UNUSED_INT, incX, incY, UNUSED_INT}); } + +__attribute__((noinline)) +void cblas_dlacpy(char layout, char uplo, int M, int N, double* A, int lda, double* B, int ldb) { + calls.push_back((BlasCall){inDerivative, CallType::LACPY, + B, A, UNUSED_POINTER, + UNUSED_DOUBLE, UNUSED_DOUBLE, + layout, + uplo, UNUSED_TRANS, + M, N, UNUSED_INT, lda, ldb, UNUSED_INT}); +} + +__attribute__((noinline)) +void dlacpy(char *uplo, int *M, int* N, double* A, int *lda, double* B, int* ldb) { + cblas_dlacpy(CblasColMajor, *uplo, *M, *N, A, *lda, B, *ldb); +} + } enum class ValueType { @@ -450,6 +584,7 @@ enum class ValueType { Vector }; struct BlasInfo { + void* ptr; ValueType ty; int vec_length; int vec_increment; @@ -457,7 +592,8 @@ struct BlasInfo { int mat_rows; int mat_cols; int mat_ld; - BlasInfo (int length, int increment) { + BlasInfo (void* v_ptr, int length, int increment) { + ptr = v_ptr; ty = ValueType::Vector; vec_length = length; vec_increment = increment; @@ -466,7 +602,8 @@ struct BlasInfo { mat_cols = -1; mat_ld = -1; } - BlasInfo (char layout, int rows, int cols, int ld) { + BlasInfo (void* v_ptr, char layout, int rows, int cols, int ld) { + ptr = v_ptr; ty = ValueType::Matrix; vec_length = -1; vec_increment = -1; @@ -475,12 +612,25 @@ struct BlasInfo { mat_cols = cols; mat_ld = ld; } + BlasInfo () { + ptr = (void*)(-1); + ty = ValueType::Matrix; + 
vec_length = -1; + vec_increment = -1; + mat_layout = -1; + mat_rows = -1; + mat_cols = -1; + mat_ld = -1; + } }; -int pointer_to_index(void* v) { - if (v == A || v == dA) return 0; - if (v == B || v == dB) return 1; - if (v == C || v == dC) return 2; +BlasInfo pointer_to_index(void* v, BlasInfo inputs[6]) { + if (v == A || v == dA) return inputs[0]; + if (v == B || v == dB) return inputs[1]; + if (v == C || v == dC) return inputs[2]; + for (int i=3; i<6; i++) + if (inputs[i].ptr == v) + return inputs[i]; assert(0 && " illegal pointer to invert"); } @@ -570,17 +720,66 @@ void checkMatrix(BlasInfo info, std::string matname, char layout, int rows, int } } -void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vector & trace) { +void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test, const vector & trace) { switch (rcall.type) { + return; + case CallType::LASCL: { + auto A = pointer_to_index(rcall.pout_arg1, inputs); + + auto layout = rcall.layout; + auto type = rcall.targ1; + auto KL = rcall.iarg5; + auto KU = rcall.iarg6; + auto cfrom = rcall.farg1; + auto cto = rcall.farg2; + + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto lda = rcall.iarg4; + + // = 'G': A is a full matrix. 
+ assert(type == 'G'); + + // A is an m-by-n matrix + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + return; + } + case CallType::AXPY: { + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + + auto X = pointer_to_index(rcall.pin_arg1, inputs); + + auto alpha = rcall.farg1; + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } + case CallType::DOT: { + auto X = pointer_to_index(rcall.pin_arg1, inputs); + auto Y = pointer_to_index(rcall.pin_arg2, inputs); + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } case CallType::GEMV:{ // Y = alpha * op(A) * X + beta * Y - auto Y = inputs[pointer_to_index(rcall.pout_arg1)]; - auto A = inputs[pointer_to_index(rcall.pin_arg1)]; - auto X = inputs[pointer_to_index(rcall.pin_arg2)]; + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg2, inputs); auto layout = rcall.layout; auto trans_char = rcall.targ1; - auto trans = (trans_char == 'N' || trans_char == 'n'); + auto trans = !(trans_char == 'N' || trans_char == 'n'); auto M = rcall.iarg1; auto N =rcall.iarg2; auto alpha = rcall.farg1; @@ -600,24 +799,26 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vec // ( 1 + ( m - 1 )*abs( INCX ) ) otherwise. // Before entry, the incremented array X must contain the // vector x. - auto Xlen = trans ? N : M; + auto Xlen = trans ? M : N; checkVector(X, "X", /*len=*/Xlen, /*inc=*/incX, test, rcall, trace); // if no trans, Y must be M otherwise must be N - auto Ylen = trans ? M : N; + auto Ylen = trans ? 
N : M; checkVector(Y, "Y", /*len=*/Ylen, /*inc=*/incY, test, rcall, trace); return; } case CallType::GEMM:{ // C = alpha * A^transA * B^transB + beta * C - auto C = inputs[pointer_to_index(rcall.pout_arg1)]; - auto A = inputs[pointer_to_index(rcall.pin_arg1)]; - auto B = inputs[pointer_to_index(rcall.pin_arg2)]; + auto C = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + auto B = pointer_to_index(rcall.pin_arg2, inputs); auto layout = rcall.layout; - auto transA = rcall.targ1; - auto transB = rcall.targ2; + auto transA_char = rcall.targ1; + auto transA = !(transA_char == 'N' || transA_char == 'n'); + auto transB_char = rcall.targ2; + auto transB = !(transB_char == 'N' || transB_char == 'n'); auto M = rcall.iarg1; auto N = rcall.iarg2; auto K = rcall.iarg3; @@ -655,16 +856,16 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vec case CallType::SCAL: { auto N = rcall.iarg1; auto alpha = rcall.farg1; - auto X = inputs[pointer_to_index(rcall.pout_arg1)]; + auto X = pointer_to_index(rcall.pout_arg1, inputs); auto incX = rcall.iarg4; checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); return; } case CallType::GER: { // A = alpha * X * transpose(Y) + A - auto A = inputs[pointer_to_index(rcall.pout_arg1)]; - auto X = inputs[pointer_to_index(rcall.pin_arg1)]; - auto Y = inputs[pointer_to_index(rcall.pin_arg2)]; + auto A = pointer_to_index(rcall.pout_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg1, inputs); + auto Y = pointer_to_index(rcall.pin_arg2, inputs); auto layout = rcall.layout; auto M = rcall.iarg1; @@ -683,24 +884,44 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[3], std::string test, const vec return; } case CallType::COPY: { - auto Y = inputs[pointer_to_index(rcall.pout_arg1)]; - auto X = inputs[pointer_to_index(rcall.pin_arg1)]; + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg1, inputs); auto N = rcall.iarg1; 
auto incX = rcall.iarg4; + auto incY = rcall.iarg5; checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); - checkVector(Y, "Y", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } + case CallType::LACPY: { + auto B = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + + auto layout = rcall.layout; + auto uplo = rcall.targ1; + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto lda = rcall.iarg4; + auto ldb = rcall.iarg5; + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + checkMatrix(B, "B", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldb, test, rcall, trace); return; } default: printf("UNKNOWN CALL (%d)", (int)rcall.type); return; } } -void checkMemoryTrace(BlasInfo inputs[3], std::string test, const vector & trace) { +void checkMemoryTrace(BlasInfo inputs[6], std::string test, const vector & trace) { for (size_t i=0; i 2); + auto A_cache = (double*)foundCalls[0].pout_arg1; + cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, M); + inputs[4] = BlasInfo(A_cache, layout, M, N, M); + auto B_cache = (double*)foundCalls[1].pout_arg1; + cblas_dcopy(trans ? M : N, B, incB, B_cache, 1); + inputs[5] = BlasInfo(B_cache, trans ? M : N, 1); + + ow_dgemv(layout, transA, M, N, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + // dC = alpha * X * transpose(Y) + A + cblas_dger(layout, M, N, alpha, + trans ? B_cache : dC, + trans ? 1 : incC, + trans ? dC : B_cache, + trans ? incC : 1, dA, + lda); + + // dB = alpha * trans(A) * dC + dB + cblas_dgemv(layout, transpose(transA), M, N, alpha, A_cache, M, dC, incC, 1.0, dB, incB); + + // dY = beta * dY + cblas_dscal(trans ? 
N : M, beta, dC, incC); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + inputs[4] = BlasInfo(); + inputs[5] = BlasInfo(); + } + + + } + } +} + +static void gemmTests() { + // N means normal matrix, T means transposed + for (char layout : { CblasRowMajor, CblasColMajor }) { + for (char transA : {'N', 'n', 'T', 't'}) { + for (char transB : {'N', 'n', 'T', 't'}) { + + { + + bool transA_bool = !(transA == 'N' || transA == 'n'); + bool transB_bool = !(transA == 'N' || transA == 'n'); + std::string Test = "GEMM"; + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, transA_bool ? K : M, transA_bool ? M : K, lda), + /*B*/ BlasInfo(B, layout, transB_bool ? N : K , transA_bool ? K : N, incB), + /*C*/ BlasInfo(C, layout, M, N, incC), + BlasInfo(), + BlasInfo(), + BlasInfo() + }; + init(); + my_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + assert(calls.size() == 1); + assert(calls[0].inDerivative == false); + assert(calls[0].type == CallType::GEMM); + assert(calls[0].pout_arg1 == C); + assert(calls[0].pin_arg1 == A); + assert(calls[0].pin_arg2 == B); + assert(calls[0].farg1 == alpha); + assert(calls[0].farg2 == beta); + assert(calls[0].layout == layout); + assert(calls[0].targ1 == transA); + assert(calls[0].targ2 == transB); + assert(calls[0].iarg1 == M); + assert(calls[0].iarg2 == N); + assert(calls[0].iarg3 == K); + assert(calls[0].iarg4 == lda); + assert(calls[0].iarg5 == incB); + assert(calls[0].iarg6 == incC); + + // Check memory of primal on own. 
+ checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void*) my_dgemm, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, transB, + enzyme_const, M, + enzyme_const, N, + enzyme_const, K, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + + my_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + // dA = + my_dgemm(layout, + transA_bool ? transpose(transB) : transA, + transA_bool ? transA : transpose(transB), + transA_bool ? K : M, + transA_bool ? M : K, + N, + alpha, + transA_bool ? B : dC, + transA_bool ? incB : incC, + transA_bool ? C : dB, + transA_bool ? incC : incB, + 1.0, dA, lda); + + // dB = + my_dgemm(layout, + transB_bool ? transB : transpose(transA), + transB_bool ? transA : transB, + transB_bool ? N : K, + transB_bool ? K : N, + M, + alpha, + transB_bool ? dC : A, + transB_bool ? incC : lda, + transB_bool ? A : dC, + transB_bool ? lda : incC, + 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC /*, extra 0*/ ); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). 
+ checkMemoryTrace(inputs, "Found " + Test, foundCalls); + } } } + } +} + +int main() { + + dotTests(); + + gemvTests(); + + // gemmTests(); } diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index d8a3d7c7fffa..5123aec615f1 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -53,24 +53,6 @@ static void checkBlasCallsInDag(const RecordKeeper &RK, checkBlasCallsInDag(RK, blasPatterns, blasName, arg); } } - - auto Def = cast(toSearch->getOperator())->getDef(); - if (Def->isSubClassOf("b")) { - auto numArgs = toSearch->getNumArgs(); - auto opName = Def->getValueAsString("s"); - auto CalledBlas = RK.getDef(opName); - if (!CalledBlas) - errs() << " opName: " << opName << "\n"; - assert(CalledBlas); - auto expectedNumArgs = - CalledBlas->getValueAsDag("PatternToMatch")->getNumArgs(); - if (expectedNumArgs != numArgs) { - errs() << "failed calling " << opName << " in the derivative of " - << blasName << " incorrect number of params. Expected " - << expectedNumArgs << " but got " << numArgs << "\n"; - assert(expectedNumArgs == numArgs); - } - } } /// Here we check that all the Blas derivatives who call another @@ -803,6 +785,7 @@ std::string get_blas_ret_ty(StringRef dfnc_name) { return "Builder2.getVoidTy()"; } +/* void emit_deriv_blas_call(DagInit *ruleDag, const StringMap &patternMap, StringSet<> &handled, raw_ostream &os) { @@ -834,7 +817,11 @@ void emit_deriv_blas_call(DagInit *ruleDag, if (Def->isSubClassOf("DiffeRetIndex")) { typeToAdd = "byRef ? 
PointerType::getUnqual(call.getType()) : " "call.getType()\n"; - } else if (Def->isSubClassOf("input") || Def->isSubClassOf("adj")) { + } else if (Def->isSubClassOf("adj")) { + auto argStr = Def->getValueAsString("name"); + // primary and adj have the same type + typeToAdd = (Twine("type_") + argStr).str(); + } else if (Def->isSubClassOf("input")) { auto argStr = Def->getValueAsString("name"); // primary and adj have the same type typeToAdd = (Twine("type_") + argStr).str(); @@ -926,6 +913,7 @@ void emit_deriv_blas_call(DagInit *ruleDag, << " }\n\n"; return; } +*/ void emit_tmp_creation(Record *Def, raw_ostream &os) { const auto args = Def->getValueAsListOfStrings("args"); @@ -997,7 +985,7 @@ void emit_deriv_rule(const StringMap &patternMap, Rule &rule, const auto nameMap = rule.getArgNameMap(); const auto Def = cast(ruleDag->getOperator())->getDef(); if (Def->isSubClassOf("b")) { - emit_deriv_blas_call(ruleDag, patternMap, handled, os); + // emit_deriv_blas_call(ruleDag, patternMap, handled, os); } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "noop") { // nothing to prepare } else if (Def->isSubClassOf("DiffeRetIndex")) { @@ -1015,103 +1003,137 @@ void emit_deriv_rule(const StringMap &patternMap, Rule &rule, const auto sub_Def = sub_def->getDef(); if (sub_Def->isSubClassOf("b")) { os << " //handling nested blas: " << std::to_string(i) << "\n"; - emit_deriv_blas_call(sub_Dag, patternMap, handled, os); + // emit_deriv_blas_call(sub_Dag, patternMap, handled, os); os << " //handled nested blas: " << std::to_string(i) << "\n"; } else if (sub_Def->isSubClassOf("FrobInnerProd")) { // nothing to prepare - assert(sub_Dag->getNumArgs() == 5); + assert(sub_Dag->getNumArgs() == 4); } else if (sub_Def->isSubClassOf("DiagUpdateSPMV")) { // nothing to prepare - assert(sub_Dag->getNumArgs() == 8); + assert(sub_Dag->getNumArgs() == 6); } } } } else if (Def->isSubClassOf("FrobInnerProd")) { // nothing to prepare - assert(ruleDag->getNumArgs() == 5); + 
assert(ruleDag->getNumArgs() == 4); } else if (Def->isSubClassOf("DiagUpdateSPMV")) { // nothing to prepare - assert(ruleDag->getNumArgs() == 8); + assert(ruleDag->getNumArgs() == 6); } else { PrintFatalError("Unhandled deriv Rule!"); } } -void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, - size_t actArg, size_t &pos, raw_ostream &os) { +// Emit the corresponding code rom (ruleDag arg # pos), given +// that the arg being differentiated is argAct. +// The map offsetToBaseNames takes vinc, ld, and maps them to +// the arg name of the original vector/matrix +void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, + raw_ostream &os) { const auto nameMap = rule.getArgNameMap(); const auto typeMap = rule.getArgTypeMap(); auto arg = ruleDag->getArg(pos); if (auto Dag = dyn_cast(arg)) { auto Def = cast(Dag->getOperator())->getDef(); - if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") { - std::string tname, rname, cname; - tname = (Twine("arg_") + Dag->getArgNameStr(0)).str(); - if (DefInit *Def1 = dyn_cast(Dag->getArg(1))) { - auto Def1Name = Def1->getDef()->getValueAsString("name"); - assert(Def1->getDef()->isSubClassOf("adj")); - rname = (Twine("d_") + Def1Name).str(); - } else { - rname = (Twine("arg_") + Dag->getArgNameStr(1)).str(); + if (Def->isSubClassOf("MagicInst")) { + if (Def->getName() == "Rows") { + os << "get_blas_row(Builder2, "; + for (size_t i = 0; i < Dag->getNumArgs(); i++) { + rev_call_arg(Dag, rule, actArg, i, os); + os << ", "; + } + os << "byRef)"; + return; } - if (DefInit *Def2 = dyn_cast(Dag->getArg(2))) { - auto Def2Name = Def2->getDef()->getValueAsString("name"); - assert(Def2->getDef()->isSubClassOf("adj")); - cname = (Twine("d_") + Def2Name).str(); - } else { - cname = (Twine("arg_") + Dag->getArgNameStr(2)).str(); + if (Def->getName() == "ld") { + assert(Dag->getNumArgs() == 5); + //(ld $A, $transa, $lda, $m, $k) + const auto ldName = Dag->getArgNameStr(2); + const auto dim1Name = 
Dag->getArgNameStr(3); + const auto dim2Name = Dag->getArgNameStr(4); + const auto matName = Dag->getArgNameStr(0); + os << "{get_cached_mat_width(Builder2, "; + rev_call_arg(Dag, rule, actArg, 1, os); + os << ", arg_" << ldName << ", arg_" << dim1Name << ", arg_" << dim2Name + << ", cache_" << matName << ", byRef)}"; + return; } - os << "get_blas_row(Builder2, " << tname << ", " << rname << ", " << cname - << ", byRef)"; - } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") { - assert(Dag->getNumArgs() == 5); - //(ld $A, $transa, $lda, $m, $k) - const auto transName = Dag->getArgNameStr(1); - const auto ldName = Dag->getArgNameStr(2); - const auto dim1Name = Dag->getArgNameStr(3); - const auto dim2Name = Dag->getArgNameStr(4); - const auto matName = Dag->getArgNameStr(0); - os << "get_cached_mat_width(Builder2, " - << "arg_" << transName << ", arg_" << ldName << ", arg_" << dim1Name - << ", arg_" << dim2Name << ", cache_" << matName << ", byRef)"; - } else { - errs() << Def->getName() << "\n"; - PrintFatalError("Dag/Def that isn't a DiffeRet!!"); } + + errs() << Def->getName() << "\n"; + PrintFatalError("Dag/Def that isn't a DiffeRet!!"); } else if (DefInit *DefArg = dyn_cast(arg)) { auto Def = DefArg->getDef(); if (Def->isSubClassOf("DiffeRetIndex")) { - os << "dif"; + os << "{dif}"; } else if (Def->isSubClassOf("adj")) { auto name = Def->getValueAsString("name"); - os << "d_" << name; + os << "{d_" << name; + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < rule.nameVec.size(); i++) { + if (rule.nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = rule.argTypesFull.lookup(argPosition); + auto incName = rule.nameVec[argPosition + 1]; + if (ty == ArgType::vincData || ty == ArgType::mldData) + os << ", arg_" << incName; + else + assert(ty == ArgType::fp 
|| ty == ArgType::ap); + os << "}"; } else if (Def->isSubClassOf("input")) { auto name = Def->getValueAsString("name"); - os << "input_" << name; + os << "{input_" << name; + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < rule.nameVec.size(); i++) { + if (rule.nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = rule.argTypesFull.lookup(argPosition); + auto incName = rule.nameVec[argPosition + 1]; + if (ty == ArgType::vincData) + os << ", (cache_" << name << " ? const_one : arg_" << incName << ")"; + else + assert(ty == ArgType::fp || ty == ArgType::ap || + ty == ArgType::mldData); + os << "}"; } else if (Def->isSubClassOf("use")) { auto name = Def->getValueAsString("name"); - os << "mat_" << name; - } else if (Def->isSubClassOf("MagicInst")) { - errs() << "MagicInst\n"; + os << "{mat_" << name << "}"; } else if (Def->isSubClassOf("Constant")) { auto val = Def->getValueAsString("value"); - os << "to_blas_fp_callconv(Builder2, ConstantFP::get(fpType, " << val + os << "{to_blas_fp_callconv(Builder2, ConstantFP::get(fpType, " << val << "), byRef, blasFPType, allocationBuilder, \"constant.fp." << val - << "\")"; + << "\")}"; } else if (Def->isSubClassOf("Char")) { auto val = Def->getValueAsString("value"); - os << "to_blas_callconv(Builder2, ConstantInt::get(charType, '" << val + os << "{to_blas_callconv(Builder2, ConstantInt::get(charType, '" << val << "'), byRef, nullptr, allocationBuilder, \"constant.char." << val - << "\")"; + << "\")}"; } else if (Def->isSubClassOf("ConstantInt")) { auto val = Def->getValueAsInt("value"); - os << "to_blas_callconv(Builder2, ConstantInt::get(intType, " << val + os << "{to_blas_callconv(Builder2, ConstantInt::get(intType, " << val << "), byRef, intType, allocationBuilder, \"constant.int." 
<< val - << "\")"; + << "\")}"; } else if (Def->isSubClassOf("transpose")) { auto name = Def->getValueAsString("name"); - os << "arg_transposed_" << name; + os << "{arg_transposed_" << name << "}"; } else { errs() << Def->getName() << "\n"; PrintFatalError("Def that isn't a DiffeRet!"); @@ -1132,74 +1154,42 @@ void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, // and based on that get the fp/int + scalar/vector type auto ty = typeMap.lookup(argPosition); - // Now we create the adj call args through concating type and primal name - if (ty == ArgType::len) { - os << "arg_" << name; - } else if (ty == ArgType::fp || ty == ArgType::ap || - ty == ArgType::vincData) { + switch (ty) { + case ArgType::cblas_layout: + case ArgType::len: + case ArgType::fp: + case ArgType::ap: + case ArgType::trans: + case ArgType::diag: + case ArgType::uplo: + case ArgType::side: + case ArgType::vincInc: + case ArgType::vincData: + case ArgType::mldData: { + os << "{"; if (argPosition == actArg) { os << "d_" << name; } else { os << "arg_" << name; } - } else if (ty == ArgType::vincInc) { - auto prevArg = ruleDag->getArg(pos - 1); - if (DefInit *DefArg = dyn_cast(prevArg)) { - auto Def = DefArg->getDef(); - if (Def->isSubClassOf("adj")) { - // all ok, single inc after shadow of vec - // use original inc, since shadow is never cached - os << "arg_" << name; - } else { - auto prevName = Def->getValueAsString("name"); - os << "(cache_" << prevName << " ? const_one : arg_" << name << ")"; - } - } else { - auto prevName = ruleDag->getArgNameStr(pos - 1); - os << "(cache_" << prevName << " ? 
const_one : arg_" << name << ")"; - } - } else if (ty == ArgType::mldData) { - // TODO: update this to use width_ instead of true_, - // similar to the vector inc case - auto nextName = ruleDag->getArgNameStr(pos + 1); - // get the position of the argument in the primary blas call - auto nextArgPosition = nameMap.lookup(nextName); - // and based on that get the fp/int + scalar/vector type - auto nextTy = typeMap.lookup(nextArgPosition); - if (pos == actArg) { - assert(nextTy == ArgType::mldLD); - os << "d_" << name << ", true_" << nextName; - pos++; // extra ++ due to also handling mldLD - } else { - // if this matrix got cached, we need more complex logic - // to determine the next arg. Thus handle it once we reach it - os << "arg_" << name; + if (ty == ArgType::vincData) { + auto incName = rule.nameVec[argPosition + 1]; + os << ", (cache_" << name << " ? const_one : arg_" << incName << ")"; } - } else if (ty == ArgType::mldLD) { - auto prevArg = ruleDag->getArg(pos - 1); - if (DefInit *DefArg = dyn_cast(prevArg)) { - auto Def = DefArg->getDef(); - if (Def->isSubClassOf("adj")) { - // all ok, single LD after shadow of mat - // use original ld, since shadow is never cached - os << "arg_" << name; + if (ty == ArgType::mldData) { + auto ldName = rule.nameVec[argPosition + 1]; + if (argPosition == actArg) { + os << ", true_" << ldName; } else { - errs() << rule.to_string() << "\n"; - PrintFatalError("sholdn't be hit?\n"); + // if this matrix got cached, we need more complex logic + // to determine the next arg. 
Thus handle it once we reach it } - } else { - errs() << rule.to_string() << "\n"; - llvm::errs() << "name: " << name << " typename: " << ty << "\n"; - PrintFatalError("shouldn't be hit??\n"); } - } else if (ty == ArgType::trans || ty == ArgType::diag || - ty == ArgType::uplo || ty == ArgType::side) { - os << "arg_" << name; - // Extra handled in the calling function, so - // if we are here for a layout arg something went wrong (error) - //} else if (ty == ArgType::cblas_layout) { - // os << "arg_" << name; - } else { + + os << "}"; + return; + } + default: errs() << "name: " << name << " typename: " << ty << "\n"; llvm_unreachable("unimplemented input type in reverse mode!\n"); } @@ -1208,7 +1198,7 @@ void rev_call_arg(StringRef argName, DagInit *ruleDag, Rule &rule, // fill the result string and return the number of added args void rev_call_args(StringRef argName, Rule &rule, size_t actArg, - raw_ostream &os, int subRule = -1) { + raw_ostream &os, int subRule, StringRef func) { const auto nameMap = rule.getArgNameMap(); @@ -1221,42 +1211,36 @@ void rev_call_args(StringRef argName, Rule &rule, size_t actArg, numArgs = ruleDag->getNumArgs(); } + os << " std::vector" << argName << ";\n"; + // layout exist only under the cBLas ABI and not for all fncs. bool fncHasLayout = (ruleDag->getArgNameStr(0) == "layout"); - if (!fncHasLayout) { - os << " std::vector" << argName << " = {"; - for (size_t pos = 0; pos < numArgs;) { - if (pos > 0) { - os << ", "; - } - - rev_call_arg(argName, ruleDag, rule, actArg, pos, os); - pos++; - } - os << "};\n"; - return; + if (fncHasLayout) { + // Fnc has a layout if cBLAS, that makes it more complex. + // Distinguish later trough byRef if it is cblas (thus has layout) + os << " if (!byRef) " << argName << ".push_back(arg_layout);\n"; } - // Fnc has a layout if cBLAS, that makes it more complex. 
- // Distinguish later trough byRef if it is cblas (thus has layout) - - os << " std::vector" << argName << ";\n"; - os << " if (!byRef) " << argName << ".push_back(arg_layout);\n"; - os << " auto tmp = {\n"; - // just replace argOps with rule - for (size_t pos = 1; pos < numArgs;) { - if (pos > 1) { - os << ", "; - } - rev_call_arg(argName, ruleDag, rule, actArg, pos, os); - pos++; + for (size_t pos = fncHasLayout ? 1 : 0; pos < numArgs; pos++) { + os << " for (auto item : "; + rev_call_arg(ruleDag, rule, actArg, pos, os); + os << ") " << argName << ".push_back(item);\n"; } - os << "};\n"; - os << " for (auto val : tmp) " << argName << ".push_back(val);\n"; + os << " if (byRef) {\n"; + int n = 0; + if (func == "gemv") + n = 1; + if (func == "gemm") + n = 2; + for (int i = 0; i < n; i++) + os << " " << argName + << ".push_back(ConstantInt::get(intType, 1));\n"; + os << " }\n"; } void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name, StringRef bb, raw_ostream &os) { + os << "{\n"; if (dfnc_name == "inner_prod") { os << " auto derivcall_inner_prod = \n" " getorInsertInnerProd(" @@ -1268,6 +1252,21 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name, << " CallInst *cubcall = " "cast(derivcall_inner_prod);\n"; } else { + os << " SmallVector tys; for (auto arg : " << argName + << ") tys.push_back(arg->getType());\n"; + std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); + os << " llvm::FunctionType *FT" << dfnc_name << " = FunctionType::get(" + << dfnc_ret_ty << ", tys, false);\n"; + os << " auto derivcall_" << dfnc_name + << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" + << " (blas.prefix + blas.floatType + \"" << dfnc_name + << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; + + os << " if (auto F = dyn_cast(derivcall_" << dfnc_name + << ".getCallee()))\n" + << " {\n" + << " attribute_" << dfnc_name << "(blas, F);\n" + << " }\n\n"; os << " CallInst *cubcall = " "cast(" << bb << 
".CreateCall(derivcall_" << dfnc_name << ", " << argName @@ -1282,6 +1281,7 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name, << " addToDiffe(orig_" << name << ", cubcall, " << bb << ", fpType);\n" << " }\n"; + os << "}\n"; } // todo: update rt_active_ to use actual dag requirements, @@ -1500,11 +1500,11 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, emit_if_rule_condition(ruleDag, name, " ", os); emit_runtime_condition(ruleDag, name, " ", "Builder2", (ty == ArgType::fp), os); - rev_call_args("args1", rule, actArg, os); + const auto dfnc_name = Def->getValueAsString("s"); + rev_call_args("args1", rule, actArg, os, -1, dfnc_name); os << " const auto Defs = gutils->getInvertedBundles(&call, {" << valueTypes << "}, Builder2, /* lookup */ true);\n"; - const auto dfnc_name = Def->getValueAsString("s"); if (ty == ArgType::fp) { // extra handling, since we will update only a fp scalar as part of the // return struct it's presumably done by setting it to the value @@ -1512,8 +1512,23 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, emit_fret_call(dfnc_name, "ArrayRef(args1)", name, "Builder2", os); } else { + os << " SmallVector tys; for (auto arg : args1) " + "tys.push_back(arg->getType());\n"; + std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); + os << " llvm::FunctionType *FT" << dfnc_name + << " = FunctionType::get(" << dfnc_ret_ty << ", tys, false);\n"; + os << " auto derivcall_" << dfnc_name + << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" + << " (blas.prefix + blas.floatType + \"" << dfnc_name + << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; + + os << " if (auto F = dyn_cast(derivcall_" << dfnc_name + << ".getCallee()))\n" + << " {\n" + << " attribute_" << dfnc_name << "(blas, F);\n" + << " }\n\n"; os << " Builder2.CreateCall(derivcall_" << dfnc_name - << ", ArrayRef(args1), Defs);\n"; + << ", args1, Defs);\n"; } emit_runtime_continue(ruleDag, name, " ", "Builder2", (ty == ArgType::fp), 
os); @@ -1525,7 +1540,7 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, os << " // DiagUpdateSPMV\n"; emit_if_rule_condition(ruleDag, name, " ", os); emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os); - rev_call_args("args1", rule, actArg, os); + rev_call_args("args1", rule, actArg, os, -1, ""); os << " const auto Defs = gutils->getInvertedBundles(&call, {" << valueTypes << "}, Builder2, /* lookup */ true);\n"; // Now that we have the defs, we can create the call @@ -1541,7 +1556,7 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, os << " // FrobInnerProd\n"; emit_if_rule_condition(ruleDag, name, " ", os); emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os); - rev_call_args("args1", rule, actArg, os); + rev_call_args("args1", rule, actArg, os, -1, ""); os << " const auto Defs = gutils->getInvertedBundles(&call, {" << valueTypes << "}, Builder2, /* lookup */ true);\n"; // Now that we have the defs, we can create the call @@ -1565,32 +1580,52 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, // handle seq rules for (size_t i = 0; i < ruleDag->getNumArgs(); i++) { - std::string argName = "args" + std::to_string(i); - rev_call_args(argName, rule, actArg, os, i); Init *subArg = ruleDag->getArg(i); DagInit *sub_Dag = cast(subArg); if (auto sub_def = dyn_cast(sub_Dag->getOperator())) { const auto sub_Def = sub_def->getDef(); if (sub_Def->isSubClassOf("b")) { const auto dfnc_name = sub_Def->getValueAsString("s"); + std::string argName = "args" + std::to_string(i); + rev_call_args(argName, rule, actArg, os, i, dfnc_name); os << " //handling nested blas: " << std::to_string(i) << "\n"; - emit_deriv_blas_call(sub_Dag, patternMap, handled, os); + // emit_deriv_blas_call(sub_Dag, patternMap, handled, os); if (get_blas_ret_ty(dfnc_name) == "fpType") { // returns, so assume it's the last step of the sequence // and update the diffe accordingly assert(i == ruleDag->getNumArgs() - 1); emit_fret_call(dfnc_name, argName, 
name, "Builder2", os); } else { + os << " SmallVector tys; for (auto arg : " << argName + << ") tys.push_back(arg->getType());\n"; + std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); + os << " llvm::FunctionType *FT" << dfnc_name + << " = FunctionType::get(" << dfnc_ret_ty + << ", tys, false);\n"; + os << " auto derivcall_" << dfnc_name + << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" + << " (blas.prefix + blas.floatType + \"" << dfnc_name + << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; + + os << " if (auto F = dyn_cast(derivcall_" + << dfnc_name << ".getCallee()))\n" + << " {\n" + << " attribute_" << dfnc_name << "(blas, F);\n" + << " }\n\n"; os << " Builder2.CreateCall(derivcall_" << dfnc_name << ", " << argName << ", Defs);\n"; } os << " //handled nested blas: " << std::to_string(i) << "\n"; } else if (sub_Def->isSubClassOf("FrobInnerProd")) { - assert(sub_Dag->getNumArgs() == 5); + std::string argName = "args" + std::to_string(i); + rev_call_args(argName, rule, actArg, os, i, ""); + assert(sub_Dag->getNumArgs() == 4); assert(ty == ArgType::fp); emit_fret_call("inner_prod", argName, name, "Builder2", os); } else if (sub_Def->isSubClassOf("DiagUpdateSPMV")) { - assert(sub_Dag->getNumArgs() == 8); + std::string argName = "args" + std::to_string(i); + rev_call_args(argName, rule, actArg, os, i, ""); + assert(sub_Dag->getNumArgs() == 6); assert(ty == ArgType::ap); os << "callSPMVDiagUpdate(Builder2, *gutils->oldFunc->getParent(), " "blas, intType, blasCharType, blasFPType, type_vec_like, " diff --git a/enzyme/tools/enzyme-tblgen/caching.cpp b/enzyme/tools/enzyme-tblgen/caching.cpp index cefbec9b0ba7..bede971bc5af 100644 --- a/enzyme/tools/enzyme-tblgen/caching.cpp +++ b/enzyme/tools/enzyme-tblgen/caching.cpp @@ -281,7 +281,7 @@ os << " if (EnzymeBlasCopy) {\n" << " auto *len2 = load_if_ref(BuilderZ, intType, N, byRef);\n" << " auto *matSize = BuilderZ.CreateMul(len1, len2);\n" << " auto malins = CreateAllocation(BuilderZ, fpType, 
matSize, \"cache." << matName << "\");\n" -<< " ValueType valueTypes[] = {" << valueTypes << "};\n" +<< " SmallVector valueTypes = {" << valueTypes << "};\n" <<" valueTypes[" << argIdx << "] = ValueType::Primal;\n" << " if (byRef) valueTypes[" << argIdx+1 << "] = ValueType::Primal;\n"; for (auto len_pos : dimensions ) { @@ -290,7 +290,9 @@ os << " if (byRef) valueTypes[" << len_pos << "] = ValueType::Primal;\n"; os << " if (EnzymeLapackCopy) {\n" << " Value *uplo = llvm::ConstantInt::get(charTy, 0);\n" // garbage data, just should not match U or L << " uplo = to_blas_callconv(BuilderZ, uplo, byRef, nullptr, allocationBuilder, \"copy.garbage\");\n" -<< " Value *args[7] = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" +<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" +<< " if (!byRef) {\n" +<< " args.insert(args.begin(), arg_layout); valueTypes.insert(valueTypes.begin(), ValueType::Primal); }\n" << " callMemcpyStridedLapack(BuilderZ, *gutils->oldFunc->getParent(), blas, args, gutils->getInvertedBundles(&call, valueTypes, BuilderZ, /*lookup*/false));\n" << " } else {\n" << " auto dmemcpy = getOrInsertMemcpyMat(*gutils->oldFunc->getParent(), fpType, cast(malins->getType()), intType, 0, 0);\n" diff --git a/enzyme/tools/enzyme-tblgen/datastructures.cpp b/enzyme/tools/enzyme-tblgen/datastructures.cpp index d38dddaa456c..2d7a14d26e9f 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.cpp +++ b/enzyme/tools/enzyme-tblgen/datastructures.cpp @@ -60,11 +60,13 @@ bool isVecLikeArg(ArgType ty) { return false; } -bool isArgUsed(StringRef toFind, const DagInit *toSearch) { +bool isArgUsed(StringRef toFind, const DagInit *toSearch, + ArrayRef nameVec, + const DenseMap &argTypesFull) { for (size_t i = 0; i < toSearch->getNumArgs(); i++) { if (DagInit *arg = dyn_cast(toSearch->getArg(i))) { // os << " Recursing. 
Magic!\n"; - if (isArgUsed(toFind, arg)) + if (isArgUsed(toFind, arg, nameVec, argTypesFull)) return true; } else { auto name = toSearch->getArgNameStr(i); @@ -80,30 +82,79 @@ bool isArgUsed(StringRef toFind, const DagInit *toSearch) { if (toFind == transName) { return true; } - } else if (opName == "adj" || Def->isSubClassOf("adj")) { + } else if (opName == "adj" || Def->isSubClassOf("adj") || + opName == "input" || Def->isSubClassOf("input")) { // shadow is unrelated, ignore it + // However, consider the extra added inc. + + auto name = Def->getValueAsString("name"); + + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < nameVec.size(); i++) { + if (nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = argTypesFull.lookup(argPosition); + if (ty == ArgType::vincData || + ((opName == "adj" || Def->isSubClassOf("adj")) && + ty == ArgType::mldData)) { + auto incName = nameVec[argPosition + 1]; + if (incName == toFind) + return true; + } } } else { if (name == toFind) { return true; } + size_t argPosition = (size_t)(-1); + for (size_t i = 0; i < nameVec.size(); i++) { + if (nameVec[i] == name) { + argPosition = i; + break; + } + } + if (argPosition == (size_t)(-1)) { + errs() << "couldn't find name: " << name << " ap=" << argPosition + << "\n"; + PrintFatalError("arg not in inverted nameMap!"); + } + auto ty = argTypesFull.lookup(argPosition); + if (ty == ArgType::vincData || ty == ArgType::mldData) { + auto incName = nameVec[argPosition + 1]; + if (incName == toFind) + return true; + } } } } return false; } -Rule::Rule(DagInit *dag, size_t activeArgIdx, +Rule::Rule(ArrayRef nameVec, DagInit *dag, size_t activeArgIdx, const StringMap &patternArgs, const DenseMap &patternTypes, const DenseSet &patternMutables) - : rewriteRule(dag), activeArg(activeArgIdx) { + : 
rewriteRule(dag), activeArg(activeArgIdx), + nameVec(nameVec.begin(), nameVec.end()) { // For each arg found in the dag: // 1) copy patternArgs to ruleArgs if arg shows up in this rule for (auto argName : patternArgs.keys()) { assert(patternArgs.count(argName) == 1); size_t argPos = patternArgs.lookup(argName); - bool argUsedInRule = isArgUsed(argName, rewriteRule); + argTypesFull.insert(*patternTypes.find(argPos)); + } + for (auto argName : patternArgs.keys()) { + assert(patternArgs.count(argName) == 1); + size_t argPos = patternArgs.lookup(argName); + bool argUsedInRule = isArgUsed(argName, rewriteRule, nameVec, argTypesFull); if (argUsedInRule) { argNameToPos.insert(std::pair(argName, argPos)); // 2) look up and copy the corresponding argType @@ -331,7 +382,7 @@ TGPattern::TGPattern(Record *r) : blasName(r->getNameInitAsString()) { DagInit *derivRule = cast(derivOp.value()); size_t actIdx = posActArgs[derivOp.index()]; rules.push_back( - Rule(derivRule, actIdx, argNameToPos, argTypes, mutables)); + Rule(args, derivRule, actIdx, argNameToPos, argTypes, mutables)); } } diff --git a/enzyme/tools/enzyme-tblgen/datastructures.h b/enzyme/tools/enzyme-tblgen/datastructures.h index 3f18e541e870..3416f3105491 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.h +++ b/enzyme/tools/enzyme-tblgen/datastructures.h @@ -40,7 +40,9 @@ using namespace llvm; const char *TyToString(ArgType ty); bool isVecLikeArg(ArgType ty); -bool isArgUsed(StringRef toFind, const DagInit *toSearch); +bool isArgUsed(StringRef toFind, const DagInit *toSearch, + llvm::ArrayRef nameVec, + const llvm::DenseMap &argTypesFull); /// Subset of the general pattern info, /// but only the part that affects the specific argument being active. 
@@ -55,7 +57,10 @@ class Rule { bool BLASLevel2or3; public: - Rule(DagInit *dag, size_t activeArgIdx, const StringMap &patternArgs, + SmallVector nameVec; + DenseMap argTypesFull; + Rule(ArrayRef nameVec, DagInit *dag, size_t activeArgIdx, + const StringMap &patternArgs, const DenseMap &patternTypes, const DenseSet &patternMutables); bool isBLASLevel2or3() const; From b14e0abb4dc2718f87d1ad7ce49cb9c5bc3cd859 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 20 Sep 2023 14:03:44 -0400 Subject: [PATCH 21/29] Fix test infra for blas [new pm not yet supported per lapack copy arg --- enzyme/test/Integration/ReverseMode/blas.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 1561cda793f2..66b161bc635e 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -6,12 +6,12 @@ // RUN: %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - // RUN: %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - // RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -S | %lli - -// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi -// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi 
-// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -// TORUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi #include "test_utils.h" From 21dcb51d0eb4c297356d5661e414ab5d15d30ae2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 20 Sep 2023 13:24:17 -0500 Subject: [PATCH 22/29] Implement concat (#1450) --- enzyme/Enzyme/Utils.h | 13 ++ enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 140 ++------------------- 2 files changed, 23 insertions(+), 130 deletions(-) diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 8d8d47bcff51..793803635556 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1642,6 +1642,19 @@ llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *arg_ld, llvm::Value *dim_1, llvm::Value *dim_2, bool cacheMat, bool byRef); + +template static inline void nothing(T...){}; +template +static inline llvm::SmallVector 
concat_values(T... t) { + llvm::SmallVector res; + auto append = [&](llvm::ArrayRef V) { + res.append(V.begin(), V.end()); + return 0; + }; + nothing(append(t)...); + return res; +} + llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef); llvm::Value *is_uper(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef); llvm::Value *select_vec_dims(llvm::IRBuilder<> &B, llvm::Value *trans, diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 5123aec615f1..3e903ed98061 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -785,136 +785,6 @@ std::string get_blas_ret_ty(StringRef dfnc_name) { return "Builder2.getVoidTy()"; } -/* -void emit_deriv_blas_call(DagInit *ruleDag, - const StringMap &patternMap, - StringSet<> &handled, raw_ostream &os) { - - const auto Def = cast(ruleDag->getOperator())->getDef(); - const auto dfnc_name = Def->getValueAsString("s"); - if (patternMap.find(dfnc_name) == patternMap.end()) { - PrintFatalError("calling unknown Blas function"); - } - TGPattern calledPattern = patternMap.find(dfnc_name)->getValue(); - bool derivlv23 = calledPattern.isBLASLevel2or3(); - DenseSet mutableArgs = calledPattern.getMutableArgs(); - - if (handled.find(dfnc_name) != handled.end()) - return; - else - handled.insert(dfnc_name); - - auto retTy = get_blas_ret_ty(dfnc_name); - - // insert arg types based on .td file - std::string typeString = ""; - bool first = true; - for (size_t i = 0; i < ruleDag->getNumArgs(); i++) { - Init *subArg = ruleDag->getArg(i); - if (DefInit *def = dyn_cast(subArg)) { - const auto Def = def->getDef(); - std::string typeToAdd = ""; - if (Def->isSubClassOf("DiffeRetIndex")) { - typeToAdd = "byRef ? 
PointerType::getUnqual(call.getType()) : " - "call.getType()\n"; - } else if (Def->isSubClassOf("adj")) { - auto argStr = Def->getValueAsString("name"); - // primary and adj have the same type - typeToAdd = (Twine("type_") + argStr).str(); - } else if (Def->isSubClassOf("input")) { - auto argStr = Def->getValueAsString("name"); - // primary and adj have the same type - typeToAdd = (Twine("type_") + argStr).str(); - } else if (Def->isSubClassOf("Constant")) { - typeToAdd = "blasFPType"; - } else if (Def->isSubClassOf("Char")) { - typeToAdd = "byRef ? (Type*)PointerType::getUnqual(charType) : " - "(Type*)charType"; - } else if (Def->isSubClassOf("ConstantInt")) { - typeToAdd = "byRef ? (Type*)blasIntType : (Type*)intType"; - } else if (Def->isSubClassOf("transpose")) { - auto argStr = Def->getValueAsString("name"); - // transpose the given trans arg, but type stays - typeToAdd = (Twine("type_") + argStr).str(); - } else if (Def->isSubClassOf("use")) { - // we only use tmp matrices, so mat type - typeToAdd = "type_vec_like"; - } else { - PrintFatalError(Def->getLoc(), "PANIC! Unsupported Definit"); - } - typeString += ((first) ? 
"" : ", ") + typeToAdd; - } else { - if (auto Dag = dyn_cast(subArg)) { - auto Def = cast(Dag->getOperator())->getDef(); - if (Def->isSubClassOf("MagicInst") && Def->getName() == "Rows") { - if (!first) - typeString += ", "; - if (DefInit *def = dyn_cast(Dag->getArg(1))) { - const auto Def = def->getDef(); - assert(Def->isSubClassOf("adj")); - typeString += - (Twine("type_") + Def->getValueAsString("name")).str(); - } else { - assert(Dag->getArgNameStr(1) != ""); - typeString += (Twine("type_") + Dag->getArgNameStr(1)).str(); - first = false; - } - continue; - } else if (Def->isSubClassOf("MagicInst") && Def->getName() == "ld") { - if (!first) - typeString += ", "; - //(ld $A, $transa, $lda, $m, $k) - // Either of 2,3,4 would work - typeString += (Twine("type_") + Dag->getArgNameStr(2)).str(); - first = false; - continue; - } - } - const auto argStr = ruleDag->getArgNameStr(i); - // skip layout because it is cblas only, - // so not relevant for the byRef Fortran abi. - // Optionally add it later as first arg for byRef. - if (argStr == "layout") - continue; - typeString += (first ? 
"" : ", "); - typeString += (Twine("type_") + argStr).str(); - } - first = false; - } - - std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name); - os << " llvm::FunctionType *FT" << dfnc_name << " = nullptr;\n"; - if (derivlv23) { - os << " if(byRef) {\n" - << " Type* tys" << dfnc_name << "[] = {" << typeString << "};\n" - << " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty - << ", tys" << dfnc_name << ", false);\n" - << " } else {\n" - << " Type* tys" << dfnc_name << "[] = {type_layout, " << typeString - << "};\n" - << " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty - << ", tys" << dfnc_name << ", false);\n" - << " }\n"; - } else { - os << " Type* tys" << dfnc_name << "[] = {" << typeString << "};\n" - << " FT" << dfnc_name << " = FunctionType::get(" << dfnc_ret_ty - << ", tys" << dfnc_name << ", false);\n"; - } - - os << " auto derivcall_" << dfnc_name - << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" - << " (blas.prefix + blas.floatType + \"" << dfnc_name - << "\" + blas.suffix).str(), FT" << dfnc_name << ");\n"; - - os << " if (auto F = dyn_cast(derivcall_" << dfnc_name - << ".getCallee()))\n" - << " {\n" - << " attribute_" << dfnc_name << "(blas, F);\n" - << " }\n\n"; - return; -} -*/ - void emit_tmp_creation(Record *Def, raw_ostream &os) { const auto args = Def->getValueAsListOfStrings("args"); // allocating tmp variables is optional, return if not required @@ -1047,6 +917,16 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, os << "byRef)"; return; } + if (Def->getName() == "Concat") { + os << "concat_values("; + for (size_t i = 0; i < Dag->getNumArgs(); i++) { + if (i != 0) + os << ", "; + rev_call_arg(Dag, rule, actArg, i, os); + } + os << ")"; + return; + } if (Def->getName() == "ld") { assert(Dag->getNumArgs() == 5); //(ld $A, $transa, $lda, $m, $k) From eb0943967a47f5d356e58c9a88fffbf707b3736a Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 21 Sep 2023 01:01:46 -0500 Subject: [PATCH 
23/29] Nice error message for undifferentiable functions (#1451) * Nice error message for undifferentiable functions * Don't clone if empty --- enzyme/Enzyme/AdjointGenerator.h | 23 +- enzyme/Enzyme/CApi.cpp | 51 ++- enzyme/Enzyme/CApi.h | 36 +- enzyme/Enzyme/DiffeGradientUtils.cpp | 6 +- enzyme/Enzyme/Enzyme.cpp | 26 +- enzyme/Enzyme/EnzymeLogic.cpp | 307 ++++++++++++++---- enzyme/Enzyme/EnzymeLogic.h | 94 ++++-- enzyme/Enzyme/FunctionUtils.cpp | 35 +- enzyme/Enzyme/FunctionUtils.h | 2 + enzyme/Enzyme/GradientUtils.cpp | 96 +++--- enzyme/Enzyme/GradientUtils.h | 18 +- enzyme/Enzyme/InstructionBatcher.cpp | 3 +- enzyme/Enzyme/TraceGenerator.cpp | 5 +- enzyme/Enzyme/TraceUtils.cpp | 17 +- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 3 +- .../test/Integration/ForwardMode/err_empty.c | 20 ++ .../test/Integration/ReverseMode/err_empty.c | 20 ++ 17 files changed, 534 insertions(+), 228 deletions(-) create mode 100644 enzyme/test/Integration/ForwardMode/err_empty.c create mode 100644 enzyme/test/Integration/ReverseMode/err_empty.c diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index c4fa2b588e5c..5eef7104043c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -3896,8 +3896,9 @@ class AdjointGenerator Mode == DerivativeMode::ReverseModeCombined) { if (called) { subdata = &gutils->Logic.CreateAugmentedPrimal( - cast(called), subretType, argsInverted, - TR.analyzer.interprocedural, /*return is used*/ false, + RequestContext(&call, &BuilderZ), cast(called), + subretType, argsInverted, TR.analyzer.interprocedural, + /*return is used*/ false, /*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false, gutils->getWidth(), /*AtomicAdd*/ true, @@ -4096,6 +4097,7 @@ class AdjointGenerator } newcalled = gutils->Logic.CreatePrimalAndGradient( + RequestContext(&call, &Builder2), (ReverseCacheKey){.todiff = cast(called), .retType = subretType, .constant_args = argsInverted, @@ -6851,8 +6853,9 @@ class 
AdjointGenerator if (called) { newcalled = gutils->Logic.CreateForwardDiff( - cast(called), subretType, argsInverted, - TR.analyzer.interprocedural, /*returnValue*/ subretused, Mode, + RequestContext(&call, &BuilderZ), cast(called), + subretType, argsInverted, TR.analyzer.interprocedural, + /*returnValue*/ subretused, Mode, ((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(), tape ? tape->getType() : nullptr, nextTypeInfo, overwritten_args, /*augmented*/ subdata); @@ -7254,10 +7257,10 @@ class AdjointGenerator if (Mode == DerivativeMode::ReverseModePrimal || Mode == DerivativeMode::ReverseModeCombined) { subdata = &gutils->Logic.CreateAugmentedPrimal( - cast(called), subretType, argsInverted, - TR.analyzer.interprocedural, /*return is used*/ subretused, - shadowReturnUsed, nextTypeInfo, overwritten_args, false, - gutils->getWidth(), gutils->AtomicAdd); + RequestContext(&call, &BuilderZ), cast(called), + subretType, argsInverted, TR.analyzer.interprocedural, + /*return is used*/ subretused, shadowReturnUsed, nextTypeInfo, + overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd); if (Mode == DerivativeMode::ReverseModePrimal) { assert(augmentedReturn); auto subaugmentations = @@ -7639,6 +7642,7 @@ class AdjointGenerator } newcalled = gutils->Logic.CreatePrimalAndGradient( + RequestContext(&call, &Builder2), (ReverseCacheKey){.todiff = cast(called), .retType = subretType, .constant_args = argsInverted, @@ -10066,7 +10070,8 @@ class AdjointGenerator auto callval = call.getCalledOperand(); if (!isa(callval)) callval = gutils->getNewFromOriginal(callval); - newCall->setCalledOperand(gutils->Logic.CreateNoFree(callval)); + newCall->setCalledOperand(gutils->Logic.CreateNoFree( + RequestContext(&call, &BuilderZ), callval)); } if (gutils->knownRecomputeHeuristic.find(&call) != gutils->knownRecomputeHeuristic.end()) { diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 026d4abf223d..89c6a266cfe6 100644 --- a/enzyme/Enzyme/CApi.cpp +++ 
b/enzyme/Enzyme/CApi.cpp @@ -541,11 +541,11 @@ void EnzymeGradientUtilsSubTransferHelper( } LLVMValueRef EnzymeCreateForwardDiff( - EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, - CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode, - uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg, - CFnTypeInfo typeInfo, uint8_t *_overwritten_args, + EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, + LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, + size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, + CDerivativeMode mode, uint8_t freeMemory, unsigned width, + LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented) { SmallVector nconstant_args((DIFFE_TYPE *)constant_args, (DIFFE_TYPE *)constant_args + @@ -556,16 +556,18 @@ LLVMValueRef EnzymeCreateForwardDiff( overwritten_args.push_back(_overwritten_args[i]); } return wrap(eunwrap(Logic).CreateForwardDiff( + RequestContext(cast_or_null(unwrap(request_req)), + unwrap(request_ip)), cast(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, width, unwrap(additionalArg), eunwrap(typeInfo, cast(unwrap(todiff))), overwritten_args, eunwrap(augmented))); } LLVMValueRef EnzymeCreatePrimalAndGradient( - EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, - CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed, - CDerivativeMode mode, unsigned width, uint8_t freeMemory, + EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, + LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, + size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue, + uint8_t dretUsed, CDerivativeMode mode, unsigned 
width, uint8_t freeMemory, LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo, uint8_t *_overwritten_args, size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd) { @@ -578,6 +580,8 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( overwritten_args.push_back(_overwritten_args[i]); } return wrap(eunwrap(Logic).CreatePrimalAndGradient( + RequestContext(cast(unwrap(request_req)), + unwrap(request_ip)), (ReverseCacheKey){ .todiff = cast(unwrap(todiff)), .retType = (DIFFE_TYPE)retType, @@ -596,10 +600,10 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( eunwrap(TA), eunwrap(augmented))); } EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( - EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType, - CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed, - CFnTypeInfo typeInfo, uint8_t *_overwritten_args, + EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, + LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, + size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed, + uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, uint8_t *_overwritten_args, size_t overwritten_args_size, uint8_t forceAnonymousTape, unsigned width, uint8_t AtomicAdd) { @@ -612,14 +616,31 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( overwritten_args.push_back(_overwritten_args[i]); } return ewrap(eunwrap(Logic).CreateAugmentedPrimal( + RequestContext(cast_or_null(unwrap(request_req)), + unwrap(request_ip)), cast(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args, eunwrap(TA), returnUsed, shadowReturnUsed, eunwrap(typeInfo, cast(unwrap(todiff))), overwritten_args, forceAnonymousTape, width, AtomicAdd)); } +LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req, + LLVMBuilderRef request_ip, LLVMValueRef tobatch, + unsigned width, CBATCH_TYPE *arg_types, + size_t arg_types_size, 
CBATCH_TYPE retType) { + + return wrap(eunwrap(Logic).CreateBatch( + RequestContext(cast_or_null(unwrap(request_req)), + unwrap(request_ip)), + cast(unwrap(tobatch)), width, + ArrayRef((BATCH_TYPE *)arg_types, + (BATCH_TYPE *)arg_types + arg_types_size), + (BATCH_TYPE)retType)); +} + LLVMValueRef EnzymeCreateTrace( - EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef *sample_functions, + EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip, + LLVMValueRef totrace, LLVMValueRef *sample_functions, size_t sample_functions_size, LLVMValueRef *observe_functions, size_t observe_functions_size, const char *active_random_variables[], size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff, @@ -641,6 +662,8 @@ LLVMValueRef EnzymeCreateTrace( } return wrap(eunwrap(Logic).CreateTrace( + RequestContext(cast_or_null(unwrap(request_req)), + unwrap(request_ip)), cast(unwrap(totrace)), SampleFunctions, ObserveFunctions, ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff, eunwrap(interface))); diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index 0f7903437f35..3ffced41dee2 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -119,6 +119,8 @@ typedef enum { // but don't need the forward } CDIFFE_TYPE; +typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE; + typedef enum { DEM_ForwardMode = 0, DEM_ReverseModePrimal = 1, @@ -132,40 +134,6 @@ typedef enum { DEM_Condition = 1, } CProbProgMode; -LLVMValueRef EnzymeCreateForwardDiff( - EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, - CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode, - uint8_t freeMemory, unsigned width, LLVMTypeRef additionalArg, - struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, - size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented); - -LLVMValueRef EnzymeCreatePrimalAndGradient( - EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, 
- CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed, - CDerivativeMode mode, unsigned width, uint8_t freeMemory, - LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, - struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, - size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented, - uint8_t AtomicAdd); - -EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( - EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType, - CDIFFE_TYPE *constant_args, size_t constant_args_size, - EnzymeTypeAnalysisRef TA, uint8_t returnUsed, uint8_t shadowReturnUsed, - struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args, - size_t uncacheable_args_size, uint8_t forceAnonymousTape, unsigned width, - uint8_t AtomicAdd); - -LLVMValueRef CreateTrace( - EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef *sample_functions, - size_t sample_functions_size, LLVMValueRef *observe_functions, - size_t observe_functions_size, LLVMValueRef *generative_functions, - size_t generative_functions_size, const char *active_random_variables[], - size_t active_random_variables_size, CProbProgMode mode, uint8_t autodiff, - EnzymeTraceInterfaceRef interface); - typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, CTypeTreeRef * /*args*/, struct IntList * /*knownValues*/, diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 9025f5e076c9..c88ce960b5df 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -62,6 +62,8 @@ DiffeGradientUtils::DiffeGradientUtils( : GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_, constantvalues_, returnvals_, ActiveReturn, constant_values, origToNew_, mode, width, omp) { + if (oldFunc_->empty()) + return; assert(reverseBlocks.size() == 0); if (mode == DerivativeMode::ForwardMode || mode == DerivativeMode::ForwardModeSplit) { @@ -83,7 +85,6 @@ DiffeGradientUtils 
*DiffeGradientUtils::CreateFromClone( TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, ReturnType returnValue, Type *additionalArg, bool omp) { - assert(!todiff->empty()); Function *oldFunc = todiff; assert(mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ReverseModeCombined || @@ -149,7 +150,8 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( } TypeResults TR = TA.analyzeFunction(typeInfo); - assert(TR.getFunction() == oldFunc); + if (!oldFunc->empty()) + assert(TR.getFunction() == oldFunc); auto res = new DiffeGradientUtils(Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values, diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 4772edbedfe5..275951c1d638 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1377,7 +1377,8 @@ class EnzymeBase { ? BATCH_TYPE::SCALAR : BATCH_TYPE::VECTOR; - auto newFunc = Logic.CreateBatch(F, width, arg_types, ret_type); + auto newFunc = Logic.CreateBatch(RequestContext(CI, &Builder), F, width, + arg_types, ret_type); if (!newFunc) return false; @@ -1432,6 +1433,7 @@ class EnzymeBase { populate_overwritten_args(TA, fn, mode, overwritten_args); IRBuilder Builder(CI); + RequestContext context(CI, &Builder); // differentiate fn Function *newFunc = nullptr; @@ -1440,7 +1442,7 @@ class EnzymeBase { switch (mode) { case DerivativeMode::ForwardMode: newFunc = Logic.CreateForwardDiff( - fn, retType, constants, TA, + context, fn, retType, constants, TA, /*should return*/ primalReturn, mode, freeMemory, width, /*addedType*/ nullptr, type_args, overwritten_args, /*augmented*/ nullptr); @@ -1448,7 +1450,7 @@ class EnzymeBase { case DerivativeMode::ForwardModeSplit: { bool forceAnonymousTape = !sizeOnly && allocatedTapeSize == -1; aug = &Logic.CreateAugmentedPrimal( - fn, retType, constants, TA, + context, fn, retType, constants, TA, /*returnUsed*/ false, /*shadowReturnUsed*/ 
false, type_args, overwritten_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd); auto &DL = fn->getParent()->getDataLayout(); @@ -1484,7 +1486,7 @@ class EnzymeBase { tapeType = PointerType::getInt8PtrTy(fn->getContext()); } newFunc = Logic.CreateForwardDiff( - fn, retType, constants, TA, + context, fn, retType, constants, TA, /*should return*/ primalReturn, mode, freeMemory, width, /*addedType*/ tapeType, type_args, overwritten_args, aug); break; @@ -1492,6 +1494,7 @@ class EnzymeBase { case DerivativeMode::ReverseModeCombined: assert(freeMemory); newFunc = Logic.CreatePrimalAndGradient( + context, (ReverseCacheKey){.todiff = fn, .retType = retType, .constant_args = constants, @@ -1518,8 +1521,8 @@ class EnzymeBase { bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED); aug = &Logic.CreateAugmentedPrimal( - fn, retType, constants, TA, returnUsed, shadowReturnUsed, type_args, - overwritten_args, forceAnonymousTape, width, + context, fn, retType, constants, TA, returnUsed, shadowReturnUsed, + type_args, overwritten_args, forceAnonymousTape, width, /*atomicAdd*/ AtomicAdd); auto &DL = fn->getParent()->getDataLayout(); if (!forceAnonymousTape) { @@ -1557,6 +1560,7 @@ class EnzymeBase { newFunc = aug->fn; else newFunc = Logic.CreatePrimalAndGradient( + context, (ReverseCacheKey){.todiff = fn, .retType = retType, .constant_args = constants, @@ -1856,9 +1860,9 @@ class EnzymeBase { constants.push_back(DIFFE_TYPE::CONSTANT); } - auto newFunc = Logic.CreateTrace(F, sampleFunctions, observeFunctions, - opt->ActiveRandomVariables, mode, autodiff, - interface); + auto newFunc = Logic.CreateTrace( + RequestContext(CI, &Builder), F, sampleFunctions, observeFunctions, + opt->ActiveRandomVariables, mode, autodiff, interface); if (!autodiff) { auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args); @@ -2438,8 +2442,10 @@ class EnzymeBase { bool AtomicAdd = Arch == Triple::nvptx || Arch == 
Triple::nvptx64 || Arch == Triple::amdgcn; + IRBuilder<> Builder(CI); auto val = GradientUtils::GetOrCreateShadowConstant( - Logic, Logic.PPC.FAM.getResult(F), TA, fn, + RequestContext(CI, &Builder), Logic, + Logic.PPC.FAM.getResult(F), TA, fn, pair.second, /*width*/ 1, AtomicAdd); CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType())); CI->eraseFromParent(); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 56dbff198e58..420939ec2bc8 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -1918,10 +1918,11 @@ void restoreCache( //! return structtype if recursive function const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( - Function *todiff, DIFFE_TYPE retType, ArrayRef constant_args, - TypeAnalysis &TA, bool returnUsed, bool shadowReturnUsed, - const FnTypeInfo &oldTypeInfo_, const std::vector _overwritten_args, - bool forceAnonymousTape, unsigned width, bool AtomicAdd, bool omp) { + RequestContext context, Function *todiff, DIFFE_TYPE retType, + ArrayRef constant_args, TypeAnalysis &TA, bool returnUsed, + bool shadowReturnUsed, const FnTypeInfo &oldTypeInfo_, + const std::vector _overwritten_args, bool forceAnonymousTape, + unsigned width, bool AtomicAdd, bool omp) { if (returnUsed) assert(!todiff->getReturnType()->isEmptyTy() && !todiff->getReturnType()->isVoidTy()); @@ -1999,9 +2000,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( } auto &aug = CreateAugmentedPrimal( - todiff, retType, next_constant_args, TA, returnUsed, shadowReturnUsed, - oldTypeInfo_, _overwritten_args, forceAnonymousTape, width, AtomicAdd, - omp); + context, todiff, retType, next_constant_args, TA, returnUsed, + shadowReturnUsed, oldTypeInfo_, _overwritten_args, forceAnonymousTape, + width, AtomicAdd, omp); FunctionType *FTy = FunctionType::get(aug.fn->getReturnType(), dupargs, @@ -2253,26 +2254,54 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( ->second; // 
dyn_cast(st->getElementType(0))); } + std::map returnMapping; + + GradientUtils *gutils = GradientUtils::CreateFromClone( + *this, width, todiff, TLI, TA, oldTypeInfo, retType, constant_args, + /*returnUsed*/ returnUsed, /*shadowReturnUsed*/ shadowReturnUsed, + returnMapping, omp); + if (todiff->empty()) { - if (todiff->empty() && CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No augmented forward pass found for " + todiff->getName() << "\n"; + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No augmented forward pass found for " + todiff->getName() << "\n"; + llvm::Value *toshow = todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { ss << *todiff << "\n"; - CustomErrorHandler(ss.str().c_str(), wrap(todiff), - ErrorType::NoDerivative, nullptr, nullptr, nullptr); + } + (IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable(); + DeleteDeadBlock(gutils->inversionAllocs); + clearFunctionAttributes(gutils->newFunc); + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(todiff), + wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + auto newFunc = gutils->newFunc; + delete gutils; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; } llvm::errs() << "mod: " << *todiff->getParent() << "\n"; llvm::errs() << *todiff << "\n"; - assert(0 && "attempting to differentiate function without definition"); llvm_unreachable("attempting to differentiate function without definition"); } - std::map returnMapping; - - GradientUtils *gutils = 
GradientUtils::CreateFromClone( - *this, width, todiff, TLI, TA, oldTypeInfo, retType, constant_args, - /*returnUsed*/ returnUsed, /*shadowReturnUsed*/ shadowReturnUsed, - returnMapping, omp); gutils->AtomicAdd = AtomicAdd; const SmallPtrSet guaranteedUnreachable = getGuaranteedUnreachable(gutils->oldFunc); @@ -2915,7 +2944,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( auto GV = pair.first; GV->setName("_tmp"); auto R = gutils->GetOrCreateShadowFunction( - *this, TLI, TA, todiff, pair.second, width, gutils->AtomicAdd); + context, *this, TLI, TA, todiff, pair.second, width, gutils->AtomicAdd); SmallVector, 1> users; GV->replaceAllUsesWith(ConstantExpr::getPointerCast(R, GV->getType())); GV->eraseFromParent(); @@ -3462,7 +3491,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, } Function *EnzymeLogic::CreatePrimalAndGradient( - const ReverseCacheKey &&key, TypeAnalysis &TA, + RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, const AugmentedReturn *augmenteddata, bool omp) { assert(key.mode == DerivativeMode::ReverseModeCombined || @@ -3536,8 +3565,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient( IRBuilder<> bb(BB); auto &aug = CreateAugmentedPrimal( - key.todiff, key.retType, key.constant_args, TA, key.returnUsed, - key.shadowReturnUsed, key.typeInfo, key.overwritten_args, + context, key.todiff, key.retType, key.constant_args, TA, + key.returnUsed, key.shadowReturnUsed, key.typeInfo, + key.overwritten_args, /*forceAnonymousTape*/ false, key.width, key.AtomicAdd, omp); SmallVector fwdargs; @@ -3578,6 +3608,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } auto revfn = CreatePrimalAndGradient( + context, (ReverseCacheKey){.todiff = key.todiff, .retType = key.retType, .constant_args = key.constant_args, @@ -3658,6 +3689,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } auto revfn = CreatePrimalAndGradient( + context, (ReverseCacheKey){.todiff = key.todiff, .retType = key.retType, .constant_args = 
next_constant_args, @@ -3872,21 +3904,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient( assert(augmenteddata->constant_args == key.constant_args); } - if (key.todiff->empty()) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No reverse pass found for " + key.todiff->getName() << "\n"; - ss << *key.todiff << "\n"; - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(key.todiff), - ErrorType::NoDerivative, nullptr, nullptr, nullptr); - return nullptr; - } else { - llvm_unreachable(ss.str().c_str()); - } - } - assert(!key.todiff->empty()); - ReturnType retVal = key.returnUsed ? (key.shadowReturnUsed ? ReturnType::ArgsWithTwoReturns : ReturnType::ArgsWithReturn) @@ -3904,6 +3921,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient( insert_or_assign2(ReverseCachedFunctions, key, gutils->newFunc); + if (key.todiff->empty()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No reverse pass found for " + key.todiff->getName() << "\n"; + llvm::Value *toshow = key.todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *key.todiff << "\n"; + } + BasicBlock *entry = &gutils->newFunc->getEntryBlock(); + cleanupInversionAllocs(gutils, entry); + clearFunctionAttributes(gutils->newFunc); + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(key.todiff), + wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; + } + llvm::errs() << "mod: " << *key.todiff->getParent() << "\n"; + llvm::errs() << *key.todiff << "\n"; + llvm_unreachable("attempting to differentiate function without definition"); + } + if (augmenteddata && !augmenteddata->isComplete) { auto nf = gutils->newFunc; delete gutils; @@ -4293,9 
+4344,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } Function *EnzymeLogic::CreateForwardDiff( - Function *todiff, DIFFE_TYPE retType, ArrayRef constant_args, - TypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, - unsigned width, llvm::Type *additionalArg, const FnTypeInfo &oldTypeInfo_, + RequestContext context, Function *todiff, DIFFE_TYPE retType, + ArrayRef constant_args, TypeAnalysis &TA, bool returnUsed, + DerivativeMode mode, bool freeMemory, unsigned width, + llvm::Type *additionalArg, const FnTypeInfo &oldTypeInfo_, const std::vector _overwritten_args, const AugmentedReturn *augmenteddata, bool omp) { assert(retType != DIFFE_TYPE::OUT_DIFF); @@ -4478,17 +4530,6 @@ Function *EnzymeLogic::CreateForwardDiff( EmitWarning("NoCustom", *todiff, "Cannot use provided custom derivative pass"); } - if (todiff->empty() && CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No forward derivative found for " + todiff->getName() << "\n"; - ss << *todiff << "\n"; - CustomErrorHandler(s.c_str(), wrap(todiff), ErrorType::NoDerivative, - nullptr, nullptr, nullptr); - } - if (todiff->empty()) - llvm::errs() << *todiff << "\n"; - assert(!todiff->empty()); bool retActive = retType != DIFFE_TYPE::CONSTANT; @@ -4505,6 +4546,45 @@ Function *EnzymeLogic::CreateForwardDiff( insert_or_assign2(ForwardCachedFunctions, tup, gutils->newFunc); + if (todiff->empty()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No forward mode derivative found for " + todiff->getName() << "\n"; + llvm::Value *toshow = todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *todiff << "\n"; + } + BasicBlock *entry = &gutils->newFunc->getEntryBlock(); + cleanupInversionAllocs(gutils, entry); + clearFunctionAttributes(gutils->newFunc); + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(todiff), + 
wrap(context.ip)); + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + + if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) { + llvm::errs() << *gutils->oldFunc << "\n"; + llvm::errs() << *gutils->newFunc << "\n"; + report_fatal_error("function failed verification (r6)"); + } + auto newFunc = gutils->newFunc; + delete gutils; + return newFunc; + } + llvm::errs() << "mod: " << *todiff->getParent() << "\n"; + llvm::errs() << *todiff << "\n"; + llvm_unreachable("attempting to differentiate function without definition"); + } gutils->FreeMemory = freeMemory; const SmallPtrSet guaranteedUnreachable = @@ -4692,7 +4772,8 @@ Function *EnzymeLogic::CreateForwardDiff( return nf; } -llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width, +llvm::Function *EnzymeLogic::CreateBatch(RequestContext context, + Function *tobatch, unsigned width, ArrayRef arg_types, BATCH_TYPE ret_type) { @@ -4722,6 +4803,33 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width, Function::Create(FTy, tobatch->getLinkage(), "batch_" + tobatch->getName(), tobatch->getParent()); + if (tobatch->empty()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No batch mode found for " + tobatch->getName() << "\n"; + llvm::Value *toshow = tobatch; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *tobatch << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(tobatch), + wrap(context.ip)); + return NewF; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + return NewF; + } + llvm::errs() << "mod: " << *tobatch->getParent() << "\n"; + llvm::errs() << *tobatch << "\n"; + llvm_unreachable("attempting to batch function without 
definition"); + } + NewF->setLinkage(Function::LinkageTypes::InternalLinkage); ValueToValueMapTy originalToNewFn; @@ -4949,11 +5057,13 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width, return BatchCachedFunctions[tup] = NewF; }; -llvm::Function *EnzymeLogic::CreateTrace( - llvm::Function *totrace, const SmallPtrSetImpl &sampleFunctions, - const SmallPtrSetImpl &observeFunctions, - const StringSet<> &ActiveRandomVariables, ProbProgMode mode, bool autodiff, - TraceInterface *interface) { +llvm::Function * +EnzymeLogic::CreateTrace(RequestContext context, llvm::Function *totrace, + const SmallPtrSetImpl &sampleFunctions, + const SmallPtrSetImpl &observeFunctions, + const StringSet<> &ActiveRandomVariables, + ProbProgMode mode, bool autodiff, + TraceInterface *interface) { TraceCacheKey tup(totrace, mode, autodiff, interface); if (TraceCachedFunctions.find(tup) != TraceCachedFunctions.end()) { return TraceCachedFunctions.find(tup)->second; @@ -4989,6 +5099,39 @@ llvm::Function *EnzymeLogic::CreateTrace( new TraceGenerator(*this, tutils, autodiff, originalToNewFn, GenerativeFunctions, ActiveRandomVariables); + if (totrace->empty()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No tracer found for " + totrace->getName() << "\n"; + llvm::Value *toshow = totrace; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *totrace << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(totrace), + wrap(context.ip)); + auto newFunc = tutils->newFunc; + delete tracer; + delete tutils; + return newFunc; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + auto newFunc = tutils->newFunc; + delete tracer; + delete tutils; + return newFunc; + } + llvm::errs() << "mod: " << *totrace->getParent() << "\n"; + llvm::errs() << *totrace << "\n"; + 
llvm_unreachable("attempting to trace function without definition"); + } + tracer->visit(totrace); if (verifyFunction(*tutils->newFunc, &errs())) { @@ -5015,13 +5158,14 @@ llvm::Function *EnzymeLogic::CreateTrace( return TraceCachedFunctions[tup] = NewF; } -llvm::Value *EnzymeLogic::CreateNoFree(llvm::Value *todiff) { +llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context, + llvm::Value *todiff) { if (auto F = dyn_cast(todiff)) - return CreateNoFree(F); + return CreateNoFree(context, F); if (auto castinst = dyn_cast(todiff)) if (castinst->isCast()) { llvm::Constant *reps[] = { - cast(CreateNoFree(castinst->getOperand(0)))}; + cast(CreateNoFree(context, castinst->getOperand(0)))}; return castinst->getWithOperands(reps); } if (CustomErrorHandler) { @@ -5029,14 +5173,24 @@ llvm::Value *EnzymeLogic::CreateNoFree(llvm::Value *todiff) { llvm::raw_string_ostream ss(s); ss << "No create nofree of unknown value\n"; ss << *todiff << "\n"; - CustomErrorHandler(ss.str().c_str(), wrap(todiff), ErrorType::NoDerivative, - nullptr, nullptr, nullptr); + if (context.req) { + ss << " at context: " << *context.req; + } + CustomErrorHandler(ss.str().c_str(), wrap(context.req), + ErrorType::NoDerivative, nullptr, wrap(todiff), + wrap(context.ip)); + return todiff; } + if (context.req) { + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, + "Cannot create nofree of instruction-created value: ", *todiff); + return todiff; + } if (auto arg = dyn_cast(todiff)) { auto loc = arg->getDebugLoc(); EmitFailure("IllegalNoFree", loc, arg, - "Cannot create nofree of instruction-created value: ", *arg); + "Cannot create nofree of instruction-created value: ", *todiff); return todiff; } @@ -5044,7 +5198,7 @@ llvm::Value *EnzymeLogic::CreateNoFree(llvm::Value *todiff) { llvm_unreachable("unhandled, create no free"); } -llvm::Function *EnzymeLogic::CreateNoFree(Function *F) { +llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if 
(NoFreeCachedFunctions.find(F) != NoFreeCachedFunctions.end()) { return NoFreeCachedFunctions.find(F)->second; } @@ -5141,9 +5295,20 @@ llvm::Function *EnzymeLogic::CreateNoFree(Function *F) { std::string s; llvm::raw_string_ostream ss(s); ss << "No create nofree of empty function " << F->getName() << "\n"; - ss << *F << "\n"; - CustomErrorHandler(ss.str().c_str(), wrap(F), ErrorType::NoDerivative, - nullptr, nullptr, nullptr); + if (context.req) { + ss << " at context: " << *context.req; + } else { + ss << *F << "\n"; + } + CustomErrorHandler(ss.str().c_str(), wrap(context.req), + ErrorType::NoDerivative, nullptr, wrap(F), + wrap(context.ip)); + return F; + } + if (context.req) { + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, + "Cannot create nofree of empty function: ", *F); + return F; } llvm::errs() << " unhandled, create no free of empty function: " << *F << "\n"; @@ -5202,11 +5367,11 @@ llvm::Function *EnzymeLogic::CreateNoFree(Function *F) { else { if (auto CI = dyn_cast(&I)) { auto callval = CI->getCalledOperand(); - CI->setCalledOperand(CreateNoFree(callval)); + CI->setCalledOperand(CreateNoFree(context, callval)); } if (auto CI = dyn_cast(&I)) { auto callval = CI->getCalledOperand(); - CI->setCalledOperand(CreateNoFree(callval)); + CI->setCalledOperand(CreateNoFree(context, callval)); } } } diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index a8bcc84dd13a..c7f7c4bae86e 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -240,6 +240,18 @@ struct ReverseCacheKey { } }; +// Holder class to represent a context in which a derivative +// or batch is being requested. This contains the instruction +// (or null) that led to the request, and a builder (or null) +// of the insertion point for code. 
+struct RequestContext { + llvm::Instruction *req; + llvm::IRBuilder<> *ip; + RequestContext(llvm::Instruction *req = nullptr, + llvm::IRBuilder<> *ip = nullptr) + : req(req), ip(ip) {} +}; + class EnzymeLogic { public: PreProcessCache PPC; @@ -333,12 +345,13 @@ class EnzymeLogic { }; std::map NoFreeCachedFunctions; - llvm::Function *CreateNoFree(llvm::Function *todiff); - llvm::Value *CreateNoFree(llvm::Value *todiff); + llvm::Function *CreateNoFree(RequestContext context, llvm::Function *todiff); + llvm::Value *CreateNoFree(RequestContext context, llvm::Value *todiff); std::map AugmentedCachedFunctions; /// Create an augmented forward pass. + /// \p context the instruction which requested this derivative (or null). /// \p todiff is the function to differentiate /// \p retType is the activity info of the return /// \p constant_args is the activity info of the arguments @@ -350,7 +363,7 @@ class EnzymeLogic { /// structure \p AtomicAdd is whether to perform all adjoint updates to /// memory in an atomic way const AugmentedReturn &CreateAugmentedPrimal( - llvm::Function *todiff, DIFFE_TYPE retType, + RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, llvm::ArrayRef constant_args, TypeAnalysis &TA, bool returnUsed, bool shadowReturnUsed, const FnTypeInfo &typeInfo, const std::vector _overwritten_args, bool forceAnonymousTape, @@ -437,40 +450,77 @@ class EnzymeLogic { std::tuple; std::map TraceCachedFunctions; - /// Create the derivative function itself. + /// Create the reverse pass, or combined forward+reverse derivative function. + /// \p context the instruction which requested this derivative (or null). 
/// \p todiff is the function to differentiate /// \p retType is the activity info of the return /// \p constant_args is the activity info of the arguments /// \p returnValue is whether the primal's return should also be returned /// \p dretUsed is whether the shadow return value should also be returned /// \p additionalArg is the type (or null) of an additional type in the - /// signature to hold the tape. \p typeInfo is the type info information - /// about the calling context \p _overwritten_args marks whether an argument - /// may be rewritten before loads in the generated function (and thus cannot - /// be cached). \p augmented is the data structure created by prior call to - /// an augmented forward pass \p AtomicAdd is whether to perform all adjoint + /// signature to hold the tape. + /// \p typeInfo is the type info information about the calling context + /// \p _overwritten_args marks whether an argument may be rewritten + /// before loads in the generated function (and thus cannot be cached). + /// \p augmented is the data structure created by prior call to an + /// augmented forward pass + /// \p AtomicAdd is whether to perform all adjoint /// updates to memory in an atomic way - llvm::Function *CreatePrimalAndGradient(const ReverseCacheKey &&key, + llvm::Function *CreatePrimalAndGradient(RequestContext context, + const ReverseCacheKey &&key, TypeAnalysis &TA, const AugmentedReturn *augmented, bool omp = false); - llvm::Function *CreateForwardDiff(llvm::Function *todiff, DIFFE_TYPE retType, - llvm::ArrayRef constant_args, - TypeAnalysis &TA, bool returnValue, - DerivativeMode mode, bool freeMemory, - unsigned width, llvm::Type *additionalArg, - const FnTypeInfo &typeInfo, - const std::vector _overwritten_args, - const AugmentedReturn *augmented, - bool omp = false); - - llvm::Function *CreateBatch(llvm::Function *tobatch, unsigned width, + /// Create the forward (or forward split) mode derivative function. 
+ /// \p context the instruction which requested this derivative (or null). + /// \p todiff is the function to differentiate + /// \p retType is the activity info of the return + /// \p constant_args is the activity info of the arguments + /// \p TA is the type analysis results + /// \p returnValue is whether the primal's return should also be returned + /// \p mode is the requested derivative mode + /// \p is whether we should free memory allocated here (and could be + /// accessed externally). + /// \p width is the vector width requested. + /// \p additionalArg is the type (or null) of an additional type in the + /// signature to hold the tape. + /// \p FnTypeInfo is the known types of the argument and returns + /// \p _overwritten_args marks whether an argument may be rewritten + /// before loads in the generated function (and thus cannot be cached). + /// \p augmented is the data structure created by prior call to an + /// augmented forward pass + /// \p omp is whether this function is an OpenMP closure body. + llvm::Function *CreateForwardDiff( + RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, + llvm::ArrayRef constant_args, TypeAnalysis &TA, + bool returnValue, DerivativeMode mode, bool freeMemory, unsigned width, + llvm::Type *additionalArg, const FnTypeInfo &typeInfo, + const std::vector _overwritten_args, + const AugmentedReturn *augmented, bool omp = false); + + /// Create a function batched in its inputs. + /// \p context the instruction which requested this batch (or null). + /// \p tobatch is the function to batch + /// \p width is the vector width requested. + /// \p arg_types denotes which arguments are batched. + /// \p ret_type denotes whether to batch the return. + llvm::Function *CreateBatch(RequestContext context, llvm::Function *tobatch, + unsigned width, llvm::ArrayRef arg_types, BATCH_TYPE ret_type); + /// Create a traced version of a function + /// \p context the instruction which requested this trace (or null). 
+ /// \p totrace is the function to trace + /// \p sampleFunctions is a set of the functions to sample + /// \p observeFunctions is a set of the functions to observe + /// \p ActiveRandomVariables is a set of which variables are active + /// \p mode is the mode to use + /// \p autodiff is whether to also differentiate + /// \p interface specifies the ABI to use. llvm::Function * - CreateTrace(llvm::Function *totrace, + CreateTrace(RequestContext context, llvm::Function *totrace, const llvm::SmallPtrSetImpl &sampleFunctions, const llvm::SmallPtrSetImpl &observeFunctions, const llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode, diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 53a615049ed6..6189ab1870bc 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -1325,16 +1325,18 @@ Function *PreProcessCache::preprocessForClone(Function *F, SmallVector Returns; + if (!F->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto( - NewF, F, VMap, - /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, Returns, - "", nullptr); + CloneFunctionInto( + NewF, F, VMap, + /*ModuleLevelChanges*/ CloneFunctionChangeType::LocalChangesOnly, + Returns, "", nullptr); #else - CloneFunctionInto(NewF, F, VMap, - /*ModuleLevelChanges*/ F->getSubprogram() != nullptr, - Returns, "", nullptr); + CloneFunctionInto(NewF, F, VMap, + /*ModuleLevelChanges*/ F->getSubprogram() != nullptr, + Returns, "", nullptr); #endif + } CloneOrigin[NewF] = F; NewF->setAttributes(F->getAttributes()); if (EnzymeNoAlias) @@ -2060,8 +2062,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( DIFFE_TYPE returnType, const Twine &name, llvm::ValueMap *VMapO, bool diffeReturnArg, llvm::Type *additionalArg) { - assert(!F->empty()); - F = preprocessForClone(F, mode); + if (!F->empty()) + F = preprocessForClone(F, mode); llvm::ValueToValueMapTy VMap; llvm::FunctionType *FTy = getFunctionTypeForClone( F->getFunctionType(), mode, 
width, additionalArg, constant_args, @@ -2113,13 +2115,20 @@ Function *PreProcessCache::CloneFunctionWithReturns( VMap[&I] = &*DestI++; // Add mapping to VMap } SmallVector Returns; + if (!F->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, - Returns, "", nullptr); + CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly, + Returns, "", nullptr); #else - CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "", - nullptr); + CloneFunctionInto(NewF, F, VMap, F->getSubprogram() != nullptr, Returns, "", + nullptr); #endif + } + if (NewF->empty()) { + auto entry = BasicBlock::Create(NewF->getContext(), "entry", NewF); + IRBuilder<> B(entry); + B.CreateUnreachable(); + } CloneOrigin[NewF] = F; if (VMapO) { for (const auto &data : VMap) diff --git a/enzyme/Enzyme/FunctionUtils.h b/enzyme/Enzyme/FunctionUtils.h index c3c63dc2dafe..7e6b4856379d 100644 --- a/enzyme/Enzyme/FunctionUtils.h +++ b/enzyme/Enzyme/FunctionUtils.h @@ -173,6 +173,8 @@ getLatches(const llvm::Loop *L, static inline llvm::SmallPtrSet getGuaranteedUnreachable(llvm::Function *F) { llvm::SmallPtrSet knownUnreachables; + if (F->empty()) + return knownUnreachables; std::deque todo; for (auto &BB : *F) { todo.push_back(&BB); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 8901f782a88c..7c00ac376bad 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -168,19 +168,35 @@ GradientUtils::GradientUtils( DerivativeMode mode, unsigned width, bool omp) : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(), - OrigDT(Logic.PPC.FAM.getResult(*oldFunc_)), - OrigPDT( - Logic.PPC.FAM.getResult(*oldFunc_)), - OrigLI(Logic.PPC.FAM.getResult(*oldFunc_)), - OrigSE(Logic.PPC.FAM.getResult(*oldFunc_)), + OrigDT(oldFunc_->empty() + ? 
*((DominatorTree *)nullptr) + : Logic.PPC.FAM.getResult( + *oldFunc_)), + OrigPDT(oldFunc_->empty() + ? *((PostDominatorTree *)nullptr) + : Logic.PPC.FAM.getResult( + *oldFunc_)), + OrigLI(oldFunc_->empty() + ? *((LoopInfo *)nullptr) + : Logic.PPC.FAM.getResult(*oldFunc_)), + OrigSE(oldFunc_->empty() + ? *((ScalarEvolution *)nullptr) + : Logic.PPC.FAM.getResult( + *oldFunc_)), notForAnalysis(getGuaranteedUnreachable(oldFunc_)), - ATA(new ActivityAnalyzer( - Logic.PPC, Logic.PPC.getAAResultsFromFunction(oldFunc_), - notForAnalysis, TLI_, constantvalues_, activevals_, ReturnActivity)), + ATA(oldFunc_->empty() + ? nullptr + : new ActivityAnalyzer( + Logic.PPC, Logic.PPC.getAAResultsFromFunction(oldFunc_), + notForAnalysis, TLI_, constantvalues_, activevals_, + ReturnActivity)), tid(nullptr), numThreads(nullptr), - OrigAA(Logic.PPC.getAAResultsFromFunction(oldFunc_)), TA(TA_), TR(TR_), - omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), + OrigAA(oldFunc_->empty() ? *((AAResults *)nullptr) + : Logic.PPC.getAAResultsFromFunction(oldFunc_)), + TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), overwritten_args_map_ptr(nullptr) { + if (oldFunc_->empty()) + return; if (oldFunc_->getSubprogram()) { assert(originalToNewFn_.hasMD()); } @@ -4192,7 +4208,6 @@ GradientUtils *GradientUtils::CreateFromClone( DIFFE_TYPE retType, ArrayRef constant_args, bool returnUsed, bool shadowReturnUsed, std::map &returnMapping, bool omp) { - assert(!todiff->empty()); Function *oldFunc = todiff; // Since this is forward pass this should always return the tape (at index 0) @@ -4276,7 +4291,8 @@ GradientUtils *GradientUtils::CreateFromClone( } TypeResults TR = TA.analyzeFunction(typeInfo); - assert(TR.getFunction() == oldFunc); + if (!oldFunc->empty()) + assert(TR.getFunction() == oldFunc); auto res = new GradientUtils( Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values, @@ -4366,8 +4382,9 @@ DIFFE_TYPE GradientUtils::getDiffeType(Value *v, bool 
foreignFunction) const { } Constant *GradientUtils::GetOrCreateShadowConstant( - EnzymeLogic &Logic, TargetLibraryInfo &TLI, TypeAnalysis &TA, - Constant *oval, DerivativeMode mode, unsigned width, bool AtomicAdd) { + RequestContext context, EnzymeLogic &Logic, TargetLibraryInfo &TLI, + TypeAnalysis &TA, Constant *oval, DerivativeMode mode, unsigned width, + bool AtomicAdd) { if (isa(oval)) { return oval; } else if (isa(oval)) { @@ -4377,36 +4394,38 @@ Constant *GradientUtils::GetOrCreateShadowConstant( } else if (auto CD = dyn_cast(oval)) { SmallVector Vals; for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { - Vals.push_back(GetOrCreateShadowConstant( - Logic, TLI, TA, CD->getElementAsConstant(i), mode, width, AtomicAdd)); + Vals.push_back(GetOrCreateShadowConstant(context, Logic, TLI, TA, + CD->getElementAsConstant(i), + mode, width, AtomicAdd)); } return ConstantArray::get(CD->getType(), Vals); } else if (auto CD = dyn_cast(oval)) { SmallVector Vals; for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { Vals.push_back(GetOrCreateShadowConstant( - Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); + context, Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); } return ConstantArray::get(CD->getType(), Vals); } else if (auto CD = dyn_cast(oval)) { SmallVector Vals; for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { Vals.push_back(GetOrCreateShadowConstant( - Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); + context, Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); } return ConstantStruct::get(CD->getType(), Vals); } else if (auto CD = dyn_cast(oval)) { SmallVector Vals; for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { Vals.push_back(GetOrCreateShadowConstant( - Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); + context, Logic, TLI, TA, CD->getOperand(i), mode, width, AtomicAdd)); } return ConstantVector::get(Vals); } else if (auto F = dyn_cast(oval)) { - return 
GetOrCreateShadowFunction(Logic, TLI, TA, F, mode, width, AtomicAdd); + return GetOrCreateShadowFunction(context, Logic, TLI, TA, F, mode, width, + AtomicAdd); } else if (auto arg = dyn_cast(oval)) { - auto C = GetOrCreateShadowConstant(Logic, TLI, TA, arg->getOperand(0), mode, - width, AtomicAdd); + auto C = GetOrCreateShadowConstant( + context, Logic, TLI, TA, arg->getOperand(0), mode, width, AtomicAdd); if (arg->isCast() || arg->getOpcode() == Instruction::GetElementPtr || arg->getOpcode() == Instruction::Add) { SmallVector NewOps; @@ -4415,8 +4434,8 @@ Constant *GradientUtils::GetOrCreateShadowConstant( return arg->getWithOperands(NewOps); } } else if (auto arg = dyn_cast(oval)) { - return GetOrCreateShadowConstant(Logic, TLI, TA, arg->getAliasee(), mode, - width, AtomicAdd); + return GetOrCreateShadowConstant(context, Logic, TLI, TA, arg->getAliasee(), + mode, width, AtomicAdd); } else if (auto arg = dyn_cast(oval)) { if (arg->getName() == "_ZTVN10__cxxabiv120__si_class_type_infoE" || arg->getName() == "_ZTVN10__cxxabiv117__class_type_infoE" || @@ -4469,8 +4488,8 @@ Constant *GradientUtils::GetOrCreateShadowConstant( shadow->setUnnamedAddr(arg->getUnnamedAddr()); if (arg->hasInitializer()) shadow->setInitializer(GetOrCreateShadowConstant( - Logic, TLI, TA, cast(arg->getOperand(0)), mode, width, - AtomicAdd)); + context, Logic, TLI, TA, cast(arg->getOperand(0)), mode, + width, AtomicAdd)); return shadow; } } @@ -4479,8 +4498,9 @@ Constant *GradientUtils::GetOrCreateShadowConstant( } Constant *GradientUtils::GetOrCreateShadowFunction( - EnzymeLogic &Logic, TargetLibraryInfo &TLI, TypeAnalysis &TA, Function *fn, - DerivativeMode mode, unsigned width, bool AtomicAdd) { + RequestContext context, EnzymeLogic &Logic, TargetLibraryInfo &TLI, + TypeAnalysis &TA, Function *fn, DerivativeMode mode, unsigned width, + bool AtomicAdd) { //! 
Todo allow tape propagation // Note that specifically this should _not_ be called with topLevel=true // (since it may not be valid to always assume we can recompute the @@ -4584,8 +4604,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction( switch (mode) { case DerivativeMode::ForwardMode: { Constant *newf = Logic.CreateForwardDiff( - fn, retType, types, TA, false, mode, /*freeMemory*/ true, width, - nullptr, type_args, overwritten_args, /*augmented*/ nullptr); + context, fn, retType, types, TA, false, mode, /*freeMemory*/ true, + width, nullptr, type_args, overwritten_args, /*augmented*/ nullptr); assert(newf); @@ -4608,14 +4628,14 @@ Constant *GradientUtils::GetOrCreateShadowFunction( } case DerivativeMode::ForwardModeSplit: { auto &augdata = Logic.CreateAugmentedPrimal( - fn, retType, /*constant_args*/ types, TA, + context, fn, retType, /*constant_args*/ types, TA, /*returnUsed*/ !fn->getReturnType()->isEmptyTy() && !fn->getReturnType()->isVoidTy(), /*shadowReturnUsed*/ false, type_args, overwritten_args, /*forceAnonymousTape*/ true, width, AtomicAdd); Constant *newf = Logic.CreateForwardDiff( - fn, retType, types, TA, false, mode, /*freeMemory*/ true, width, - nullptr, type_args, overwritten_args, /*augmented*/ &augdata); + context, fn, retType, types, TA, false, mode, /*freeMemory*/ true, + width, nullptr, type_args, overwritten_args, /*augmented*/ &augdata); assert(newf); @@ -4651,10 +4671,11 @@ Constant *GradientUtils::GetOrCreateShadowFunction( bool shadowReturnUsed = returnUsed && (retType == DIFFE_TYPE::DUP_ARG || retType == DIFFE_TYPE::DUP_NONEED); auto &augdata = Logic.CreateAugmentedPrimal( - fn, retType, /*constant_args*/ types, TA, returnUsed, shadowReturnUsed, - type_args, overwritten_args, /*forceAnonymousTape*/ true, width, - AtomicAdd); + context, fn, retType, /*constant_args*/ types, TA, returnUsed, + shadowReturnUsed, type_args, overwritten_args, + /*forceAnonymousTape*/ true, width, AtomicAdd); Constant *newf = Logic.CreatePrimalAndGradient( 
+ context, (ReverseCacheKey){.todiff = fn, .retType = retType, .constant_args = types, @@ -5398,7 +5419,8 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, } } else if (auto fn = dyn_cast(oval)) { Constant *shadow = - GetOrCreateShadowFunction(Logic, TLI, TA, fn, mode, width, AtomicAdd); + GetOrCreateShadowFunction(RequestContext(nullptr, &BuilderM), Logic, + TLI, TA, fn, mode, width, AtomicAdd); if (width > 1) { SmallVector arr; for (unsigned i = 0; i < width; ++i) { diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index a87f6da9dd69..1635540fa91a 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -485,13 +485,17 @@ class GradientUtils : public CacheUtility { llvm::Value *invertPointerM(llvm::Value *val, llvm::IRBuilder<> &BuilderM, bool nullShadow = false); - static llvm::Constant *GetOrCreateShadowConstant( - EnzymeLogic &Logic, llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, - llvm::Constant *F, DerivativeMode mode, unsigned width, bool AtomicAdd); - - static llvm::Constant *GetOrCreateShadowFunction( - EnzymeLogic &Logic, llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, - llvm::Function *F, DerivativeMode mode, unsigned width, bool AtomicAdd); + static llvm::Constant * + GetOrCreateShadowConstant(RequestContext context, EnzymeLogic &Logic, + llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, + llvm::Constant *F, DerivativeMode mode, + unsigned width, bool AtomicAdd); + + static llvm::Constant * + GetOrCreateShadowFunction(RequestContext context, EnzymeLogic &Logic, + llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA, + llvm::Function *F, DerivativeMode mode, + unsigned width, bool AtomicAdd); void branchToCorrespondingTarget( llvm::BasicBlock *ctx, llvm::IRBuilder<> &BuilderM, diff --git a/enzyme/Enzyme/InstructionBatcher.cpp b/enzyme/Enzyme/InstructionBatcher.cpp index 9aeaca0beab4..36972f5e79dd 100644 --- a/enzyme/Enzyme/InstructionBatcher.cpp +++ 
b/enzyme/Enzyme/InstructionBatcher.cpp @@ -260,7 +260,8 @@ void InstructionBatcher::visitCallInst(llvm::CallInst &call) { ? BATCH_TYPE::SCALAR : BATCH_TYPE::VECTOR; - Function *new_func = Logic.CreateBatch(orig_func, width, arg_types, ret_type); + Function *new_func = Logic.CreateBatch(RequestContext(&call, &Builder2), + orig_func, width, arg_types, ret_type); CallInst *new_call = Builder2.CreateCall(new_func->getFunctionType(), new_func, args, call.getName()); diff --git a/enzyme/Enzyme/TraceGenerator.cpp b/enzyme/Enzyme/TraceGenerator.cpp index 2e3b12c716f4..b19e91fb6d5c 100644 --- a/enzyme/Enzyme/TraceGenerator.cpp +++ b/enzyme/Enzyme/TraceGenerator.cpp @@ -332,8 +332,9 @@ void TraceGenerator::handleArbitraryCall(CallInst &call, CallInst *new_call) { assert(called); Function *samplefn = Logic.CreateTrace( - called, tutils->sampleFunctions, tutils->observeFunctions, - activeRandomVariables, mode, autodiff, tutils->interface); + RequestContext(&call, &Builder), called, tutils->sampleFunctions, + tutils->observeFunctions, activeRandomVariables, mode, autodiff, + tutils->interface); Instruction *replacement; switch (mode) { diff --git a/enzyme/Enzyme/TraceUtils.cpp b/enzyme/Enzyme/TraceUtils.cpp index ec99f497052e..5e7626c4a247 100644 --- a/enzyme/Enzyme/TraceUtils.cpp +++ b/enzyme/Enzyme/TraceUtils.cpp @@ -114,14 +114,21 @@ TraceUtils::FromClone(ProbProgMode mode, } SmallVector Returns; + if (!oldFunc->empty()) { #if LLVM_VERSION_MAJOR >= 13 - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, - CloneFunctionChangeType::LocalChangesOnly, Returns, "", - nullptr); + CloneFunctionInto(newFunc, oldFunc, originalToNewFn, + CloneFunctionChangeType::LocalChangesOnly, Returns, "", + nullptr); #else - CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "", - nullptr); + CloneFunctionInto(newFunc, oldFunc, originalToNewFn, true, Returns, "", + nullptr); #endif + } + if (newFunc->empty()) { + auto entry = BasicBlock::Create(newFunc->getContext(), "entry", 
newFunc); + IRBuilder<> B(entry); + B.CreateUnreachable(); + } newFunc->setLinkage(Function::LinkageTypes::InternalLinkage); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index ec02c6f22146..a389ad94c292 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -5296,7 +5296,6 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { assert(fn.KnownValues.size() == fn.Function->getFunctionType()->getNumParams()); assert(fn.Function); - assert(!fn.Function->empty()); auto found = analyzedFunctions.find(fn); if (found != analyzedFunctions.end()) { auto &analysis = *found->second; @@ -5309,6 +5308,8 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { return TypeResults(analysis); } + if (fn.Function->empty()) + return TypeResults(*(TypeAnalyzer *)nullptr); auto res = analyzedFunctions.emplace(fn, new TypeAnalyzer(fn, *this)); auto &analysis = *res.first->second; diff --git a/enzyme/test/Integration/ForwardMode/err_empty.c b/enzyme/test/Integration/ForwardMode/err_empty.c new file mode 100644 index 000000000000..14820e047061 --- /dev/null +++ b/enzyme/test/Integration/ForwardMode/err_empty.c @@ -0,0 +1,20 @@ +// RUN: %clang -std=c11 -g -O0 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: %clang -std=c11 -g -O1 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: %clang -std=c11 -g -O2 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi +// RUN: if [ %llvmver 
-ge 12 ]; then %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi + +extern double __enzyme_fwddiff(void*, double, double); + +double unknown(double in); + +double g(double in) { + return unknown(in); // expected-error {{Enzyme: No forward mode derivative found for unknown}} +} + +double square(double x, double dx) { + return __enzyme_fwddiff((void*)g, x, dx); +} diff --git a/enzyme/test/Integration/ReverseMode/err_empty.c b/enzyme/test/Integration/ReverseMode/err_empty.c new file mode 100644 index 000000000000..899b1cb364af --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/err_empty.c @@ -0,0 +1,20 @@ +// RUN: %clang -std=c11 -g -O0 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: %clang -std=c11 -g -O1 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: %clang -std=c11 -g -O2 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %loadClangEnzyme -Xclang -verify +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi + +extern double __enzyme_autodiff(void*, double); + +double unknown(double in); + +double g(double in) { + return unknown(unknown(in)); // expected-error {{Enzyme: No reverse pass found for unknown}} expected-error {{Enzyme: No augmented forward pass found for unknown}} expected-error {{Enzyme: No reverse pass found for unknown}} +} + +double square(double x) { + return __enzyme_autodiff((void*)g, x); +} From 4369b893fe4686ef6b5708e590f02f87ffa7707e Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: 
Thu, 21 Sep 2023 03:23:07 -0400 Subject: [PATCH 24/29] start fixing gemm (#1448) * start fixing gemm * update rule for gemm * start fixing gemm * update rule for gemm * Fix concat * fix * Gemm passes * fix constant 1 * Adding runtime activity tests * fix attributore * return early if output arg is runtime inactive * make runtime activity more aggressive * simplify generated IR * update some tests * update some tests * fix minor bugs * fix sdot * gemf * update some tests * f c tp lacpy * update some tests * remove test * byref over * update some tests * fix test --------- Co-authored-by: William S. Moses --- enzyme/Enzyme/BlasDerivatives.td | 26 +- enzyme/Enzyme/Utils.h | 17 +- .../blas/cblas_sdot_runtime_act.ll | 32 +- enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll | 112 +- .../test/Enzyme/ReverseMode/blas/gemm_f_c.ll | 97 +- .../Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll | 329 ------ .../blas/gemm_f_c_lacpy_runtime_act.ll | 181 ++-- .../Enzyme/ReverseMode/blas/gemm_f_c_loop.ll | 87 +- .../Enzyme/ReverseMode/blas/gemm_f_c_split.ll | 92 +- .../ReverseMode/blas/gemm_f_c_split_lacpy.ll | 108 +- .../blas/gemm_f_c_split_transpose_lacpy.ll | 148 ++- .../blas/gemm_f_c_transpose_lacpy.ll | 78 +- .../ReverseMode/blas/gemm_f_change_ld.ll | 46 +- .../Enzyme/ReverseMode/blas/gemm_f_lacpy.ll | 126 --- .../Enzyme/ReverseMode/blas/gemm_f_over.ll | 122 ++- .../ReverseMode/blas/gemm_f_over_lacpy.ll | 75 +- .../blas/gemv_f_c_split_blascpy.ll | 78 +- .../gemv_f_c_split_blascpy_runtime_act.ll | 38 +- .../ReverseMode/blas/gemv_f_c_split_memcpy.ll | 14 +- enzyme/test/Integration/ReverseMode/blas.cpp | 958 +----------------- .../Integration/ReverseMode/blas_runtime.cpp | 376 +++++++ enzyme/test/Integration/blasinfra.h | 941 +++++++++++++++++ enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 52 +- enzyme/tools/enzyme-tblgen/blasDeclUpdater.h | 34 +- enzyme/tools/enzyme-tblgen/datastructures.cpp | 7 + enzyme/tools/enzyme-tblgen/datastructures.h | 2 + 26 files changed, 2330 insertions(+), 1846 
deletions(-) delete mode 100644 enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll delete mode 100644 enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll create mode 100644 enzyme/test/Integration/ReverseMode/blas_runtime.cpp create mode 100644 enzyme/test/Integration/blasinfra.h diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 9ff3abf65bbd..33e7d0e963c4 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -57,6 +57,7 @@ def tp : MagicInst; // transpose the trans param. def noop : MagicInst; // gradient is zero def inactive : MagicInst; // like noop, but assert it's inactive def Rows : MagicInst; // given a transpose, normal rows, normal cols get the true rows, aka normal rows if N else normal cols +def Concat : MagicInst; // if !cache_A, then just use $lda. // if cache_A, then check $transa. @@ -192,9 +193,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ //} else { // call sger(m, n, alpha, x, incx, ya, incy, Aa, lda) //} - /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, adj<"y">, $x), - (Rows $transa, $x, adj<"y">), - adj<"A">), + /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, (Concat adj<"y">, $x), (Concat $x, adj<"y">)), adj<"A">), /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">), /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">), /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">) @@ -225,10 +224,25 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A /* alpha */ (Seq<["AB", "product", "m", "n"]> (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)), - /* A */ (b<"gemm"> 
$layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">), - /* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $beta, adj<"B">), + /* A */ (b<"gemm"> $layout, (Rows $transa, + (Concat Char<"N">, transpose<"transb">, $m, $k), + (Concat $transb, Char<"T">, $k, $m)), + $n, $alpha, + (Rows $transa, + (Concat adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n)), + (Concat $B, (ld $B, $transb, $ldb, $k, $n), adj<"C">)), + Constant<"1.0">, adj<"A">), + + /* B */ (b<"gemm"> $layout, (Rows $transb, + (Concat transpose<"transa">, Char<"N">, $k, $n), + (Concat Char<"T">, $transa, $n, $k)), + $m, $alpha, + (Rows $transb, + (Concat $A, (ld $A, $transa, $lda, $m, $k), adj<"C">), + (Concat adj<"C">, $A, (ld $A, $transa, $lda, $m, $k))), + Constant<"1.0">, adj<"B">), /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">), - /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, ConstantInt<0>) + /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">) ] >; diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 793803635556..fd0daa7da04d 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1643,15 +1643,18 @@ llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B, llvm::Value *dim_2, bool cacheMat, bool byRef); -template static inline void nothing(T...){}; +template +static inline void append(llvm::SmallVectorImpl &vec) {} +template +static inline void append(llvm::SmallVectorImpl &vec, llvm::ArrayRef vals, + T2 &&...ts) { + vec.append(vals.begin(), vals.end()); + append(vec, std::forward(ts)...); +} template -static inline llvm::SmallVector concat_values(T... 
t) { +static inline llvm::SmallVector concat_values(T &&...t) { llvm::SmallVector res; - auto append = [&](llvm::ArrayRef V) { - res.append(V.begin(), V.end()); - return 0; - }; - nothing(append(t)...); + append(res, std::forward(t)...); return res; } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sdot_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sdot_runtime_act.ll index 902a6a784019..3f60b2d5cc84 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sdot_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sdot_runtime_act.ll @@ -87,16 +87,16 @@ entry: ; CHECK: define internal void @[[active]](i32 %len, float* noalias %m, float* %"m'", i32 %incm, float* noalias %n, float* %"n'", i32 %incn, float %differeturn) ; CHECK-NEXT: entry: -; CHECK-NEXT: %rt.inactive.x = icmp eq float* %"m'", %m -; CHECK-NEXT: %rt.inactive.y = icmp eq float* %"n'", %n -; CHECK-NEXT: br i1 %rt.inactive.x, label %invertentry.x.done, label %invertentry.x.active +; CHECK-NEXT: %[[rtinactivex:.+]] = icmp eq float* %"m'", %m +; CHECK-NEXT: %[[rtinactivey:.+]] = icmp eq float* %"n'", %n +; CHECK-NEXT: br i1 %[[rtinactivex]], label %invertentry.x.done, label %invertentry.x.active ; CHECK: invertentry.x.active: ; preds = %entry ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %n, i32 %incn, float* %"m'", i32 %incm) ; CHECK-NEXT: br label %invertentry.x.done ; CHECK: invertentry.x.done: ; preds = %invertentry.x.active, %entry -; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active +; CHECK-NEXT: br i1 %[[rtinactivey]], label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %invertentry.x.done ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %m, i32 %incm, float* %"n'", i32 %incn) @@ -108,8 +108,8 @@ entry: ; CHECK: define internal void @[[inactiveFirst]](i32 %len, float* noalias %m, i32 %incm, float* noalias %n, float* %"n'", i32 %incn, float 
%differeturn) ; CHECK-NEXT: entry: -; CHECK-NEXT: %rt.inactive.y = icmp eq float* %"n'", %n -; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active +; CHECK-NEXT: %[[rtinactivey:.+]] = icmp eq float* %"n'", %n +; CHECK-NEXT: br i1 %[[rtinactivey]], label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %entry ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %m, i32 %incm, float* %"n'", i32 %incn) @@ -121,8 +121,8 @@ entry: ; CHECK: define internal void @[[inactiveSecond]](i32 %len, float* noalias %m, float* %"m'", i32 %incm, float* noalias %n, i32 %incn, float %differeturn) ; CHECK-NEXT: entry: -; CHECK-NEXT: %rt.inactive.x = icmp eq float* %"m'", %m -; CHECK-NEXT: br i1 %rt.inactive.x, label %invertentry.x.done, label %invertentry.x.active +; CHECK-NEXT: %[[rtinactivex]] = icmp eq float* %"m'", %m +; CHECK-NEXT: br i1 %[[rtinactivex]], label %invertentry.x.done, label %invertentry.x.active ; CHECK: invertentry.x.active: ; preds = %entry ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %n, i32 %incn, float* %"m'", i32 %incm) @@ -156,18 +156,18 @@ entry: ; CHECK: define internal void @[[revMod]](i32 %len, float* noalias %m, float* %"m'", i32 %incm, float* noalias %n, float* %"n'", i32 %incn, float %differeturn, { float*, float* } ; CHECK-NEXT: entry: -; CHECK-NEXT: %rt.inactive.x = icmp eq float* %"m'", %m -; CHECK-NEXT: %rt.inactive.y = icmp eq float* %"n'", %n +; CHECK-NEXT: %[[rtinactivex:.+]] = icmp eq float* %"m'", %m +; CHECK-NEXT: %[[rtinactivey:.+]] = icmp eq float* %"n'", %n ; CHECK-NEXT: %tape.ext.x = extractvalue { float*, float* } %0, 0 ; CHECK-NEXT: %tape.ext.y = extractvalue { float*, float* } %0, 1 -; CHECK-NEXT: br i1 %rt.inactive.x, label %invertentry.x.done, label %invertentry.x.active +; CHECK-NEXT: br i1 %[[rtinactivex]], label %invertentry.x.done, label %invertentry.x.active ; CHECK: invertentry.x.active: ; preds = 
%entry ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %tape.ext.y, i32 1, float* %"m'", i32 %incm) ; CHECK-NEXT: br label %invertentry.x.done ; CHECK: invertentry.x.done: ; preds = %invertentry.x.active, %entry -; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active +; CHECK-NEXT: br i1 %[[rtinactivey]], label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %invertentry.x.done ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %tape.ext.x, i32 1, float* %"n'", i32 %incn) @@ -199,8 +199,8 @@ entry: ; CHECK: define internal void @[[revModFirst]](i32 %len, float* noalias %m, i32 %incm, float* noalias %n, float* %"n'", i32 %incn, float %differeturn, float* ; CHECK-NEXT: entry: -; CHECK-NEXT: %rt.inactive.y = icmp eq float* %"n'", %n -; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active +; CHECK-NEXT: %[[rtinactivey:.+]] = icmp eq float* %"n'", %n +; CHECK-NEXT: br i1 %[[rtinactivey]], label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %entry ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %0, i32 1, float* %"n'", i32 %incn) @@ -231,8 +231,8 @@ entry: ; CHECK: define internal void @[[revModSecond]](i32 %len, float* noalias %m, float* %"m'", i32 %incm, float* noalias %n, i32 %incn, float %differeturn, float* ; CHECK-NEXT: entry: -; CHECK-NEXT: %rt.inactive.x = icmp eq float* %"m'", %m -; CHECK-NEXT: br i1 %rt.inactive.x, label %invertentry.x.done, label %invertentry.x.active +; CHECK-NEXT: %[[rtinactivex:.+]] = icmp eq float* %"m'", %m +; CHECK-NEXT: br i1 %[[rtinactivex]], label %invertentry.x.done, label %invertentry.x.active ; CHECK: invertentry.x.active: ; preds = %entry ; CHECK-NEXT: call void @cblas_saxpy(i32 %len, float %differeturn, float* %0, i32 1, float* %"m'", i32 %incm) diff --git 
a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll index 9b514b2be53b..c6736fdf0982 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f.ll @@ -47,33 +47,38 @@ entry: ; CHECK: define internal void @diffef(i8* %C, i8* %"C'", i8* %A, i8* %"A'", i8* %B, i8* %"B'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %ret = alloca double -; CHECK-NEXT: %byref.transpose.transa = alloca i8 -; CHECK-NEXT: %byref.transpose.transb = alloca i8 -; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint1:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint2:.+]] = alloca i64 -; CHECK-NEXT: %transa = alloca i8, align 1 -; CHECK-NEXT: %transb = alloca i8, align 1 -; CHECK-NEXT: %m = alloca i64, align 16 -; CHECK-NEXT: %m_p = bitcast i64* %m to i8* -; CHECK-NEXT: %n = alloca i64, align 16 -; CHECK-NEXT: %n_p = bitcast i64* %n to i8* -; CHECK-NEXT: %k = alloca i64, align 16 -; CHECK-NEXT: %k_p = bitcast i64* %k to i8* -; CHECK-NEXT: %alpha = alloca double, align 16 -; CHECK-NEXT: %alpha_p = bitcast double* %alpha to i8* -; CHECK-NEXT: %lda = alloca i64, align 16 -; CHECK-NEXT: %lda_p = bitcast i64* %lda to i8* -; CHECK-NEXT: %ldb = alloca i64, align 16 -; CHECK-NEXT: %ldb_p = bitcast i64* %ldb to i8* -; CHECK-NEXT: %beta = alloca double, align 16 -; CHECK-NEXT: %beta_p = bitcast double* %beta to i8* -; CHECK-NEXT: %ldc = alloca i64, align 16 -; CHECK-NEXT: %ldc_p = bitcast i64* %ldc to i8* +; CHECK-DAG: %ret = alloca double +; CHECK-DAG: %byref.transpose.transa = alloca i8 +; CHECK-DAG: %byref.transpose.transb = alloca i8 +; CHECK-DAG: %byref.int.one = alloca i64 +; CHECK-DAG: %byref.constant.char.T = alloca i8, align 1 +; CHECK-DAG: %byref.constant.char.N = alloca i8, align 1 +; CHECK-DAG: %byref.constant.fp.1.0 = alloca double 
+; CHECK-DAG: %byref.constant.char.T2 = alloca i8, align 1 +; CHECK-DAG: %byref.constant.char.N3 = alloca i8, align 1 +; CHECK-DAG: %byref.constant.fp.1.06 = alloca double +; CHECK-DAG: %byref.constant.char.G = alloca i8 +; CHECK-DAG: %byref.constant.int.0 = alloca i64 +; CHECK-DAG: %[[byrefconstantint1:.+]] = alloca i64 +; CHECK-DAG: %byref.constant.fp.1.010 = alloca double +; CHECK-DAG: %transa = alloca i8, align 1 +; CHECK-DAG: %transb = alloca i8, align 1 +; CHECK-DAG: %m = alloca i64, align 16 +; CHECK-DAG: %m_p = bitcast i64* %m to i8* +; CHECK-DAG: %n = alloca i64, align 16 +; CHECK-DAG: %n_p = bitcast i64* %n to i8* +; CHECK-DAG: %k = alloca i64, align 16 +; CHECK-DAG: %k_p = bitcast i64* %k to i8* +; CHECK-DAG: %alpha = alloca double, align 16 +; CHECK-DAG: %alpha_p = bitcast double* %alpha to i8* +; CHECK-DAG: %lda = alloca i64, align 16 +; CHECK-DAG: %lda_p = bitcast i64* %lda to i8* +; CHECK-DAG: %ldb = alloca i64, align 16 +; CHECK-DAG: %ldb_p = bitcast i64* %ldb to i8* +; CHECK-DAG: %beta = alloca double, align 16 +; CHECK-DAG: %beta_p = bitcast double* %beta to i8* +; CHECK-DAG: %ldc = alloca i64, align 16 +; CHECK-DAG: %ldc_p = bitcast i64* %ldc to i8* ; CHECK-NEXT: store i8 78, i8* %transa, align 1 ; CHECK-NEXT: store i8 78, i8* %transb, align 1 ; CHECK-NEXT: store i64 4, i64* %m, align 16 @@ -110,17 +115,56 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) + +; CHECK-NEXT: store i8 84, i8* 
%byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a16:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[a17:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[a18:.+]] = or i1 %[[a17]], %[[a16]] +; CHECK-NEXT: %[[a19:.+]] = select i1 %[[a18]], i8* %byref.constant.char.N, i8* %transb +; CHECK-NEXT: %[[a20:.+]] = select i1 %[[a18]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[a21:.+]] = select i1 %[[a18]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[a22:.+]] = select i1 %[[a18]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a23:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-NEXT: %[[a24:.+]] = icmp eq i8 %ld.row.trans1, 78 +; CHECK-NEXT: %[[a25:.+]] = or i1 %[[a24]], %[[a23]] +; CHECK-NEXT: %[[a26:.+]] = select i1 %[[a25]], i8* %"C'", i8* %B +; CHECK-NEXT: %[[a27:.+]] = select i1 %[[a25]], i8* %ldc_p, i8* %ldb_p +; CHECK-NEXT: %[[a28:.+]] = select i1 %[[a25]], i8* %B, i8* %"C'" +; CHECK-NEXT: %[[a29:.+]] = select i1 %[[a25]], i8* %ldb_p, i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[a19]], i8* %[[a20]], i8* %[[a21]], i8* %[[a22]], i8* %n_p, i8* %alpha_p, i8* %[[a26]], i8* %[[a27]], i8* %[[a28]], i8* %[[a29]], i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T2, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N3, align 1 +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a30:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-NEXT: %[[a31:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %[[a32:.+]] = or i1 %[[a31]], %[[a30]] +; CHECK-NEXT: %[[a33:.+]] = select i1 %[[a32]], i8* 
%byref.transpose.transa, i8* %byref.constant.char.T2 +; CHECK-NEXT: %[[a34:.+]] = select i1 %[[a32]], i8* %byref.constant.char.N3, i8* %transa +; CHECK-NEXT: %[[a35:.+]] = select i1 %[[a32]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[a36:.+]] = select i1 %[[a32]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans5 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a37:.+]] = icmp eq i8 %ld.row.trans5, 110 +; CHECK-NEXT: %[[a38:.+]] = icmp eq i8 %ld.row.trans5, 78 +; CHECK-NEXT: %[[a39:.+]] = or i1 %[[a38]], %[[a37]] +; CHECK-NEXT: %[[a40:.+]] = select i1 %[[a39]], i8* %A, i8* %"C'" +; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a39]], i8* %lda_p, i8* %ldc_p +; CHECK-NEXT: %[[a42:.+]] = select i1 %[[a39]], i8* %"C'", i8* %A +; CHECK-NEXT: %[[a43:.+]] = select i1 %[[a39]], i8* %ldc_p, i8* %lda_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.06, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.07 = bitcast double* %byref.constant.fp.1.06 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[a33]], i8* %[[a34]], i8* %[[a35]], i8* %[[a36]], i8* %m_p, i8* %alpha_p, i8* %[[a40]], i8* %[[a41]], i8* %[[a42]], i8* %[[a43]], i8* %fpcast.constant.fp.1.07, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* ; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint1]] -; CHECK-NEXT: %intcast.constant.int.02 = bitcast i64* %byref.constant.int.01 to i8* +; CHECK-NEXT: %[[int02:.+]] = bitcast i64* %[[byrefconstantint1]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 -; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint2]] -; CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %byref.constant.int.03 to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* 
%intcast.constant.int.0, i8* %intcast.constant.int.02, i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.04) +; CHECK-NEXT: %[[fp11:.+]] = bitcast double* %byref.constant.fp.1.010 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int02]], i8* %[[fp11]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll index 991ae30a792c..f32418423c1c 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll @@ -53,11 +53,16 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint4:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint5:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.T7 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N8 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.013 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 +; CHECK-NEXT: %[[byrefconstantint4:.+]] = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.fp.1.017 = alloca double, align 8 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -208,28 +213,76 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; 
CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: %loaded.trans4 = load i8, i8* %transb -; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans4, 78 -; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans4, 110 -; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] -; CHECK-DAG: %[[r21:.+]] = select i1 %[[r20]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i43]], i8* %[[r21]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transa -; CHECK-DAG: %[[r22:.+]] = icmp eq i8 %loaded.trans5, 78 -; CHECK-DAG: %[[r23:.+]] = icmp eq i8 %loaded.trans5, 110 -; CHECK-DAG: %[[r24:.+]] = or i1 %[[r23]], %[[r22]] -; CHECK-DAG: %[[r25:.+]] = select i1 %[[r24]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i42]], i8* %[[r25]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r58:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r60:.+]] = or i1 %[[r59]], %[[r58]] +; CHECK-NEXT: %[[r61:.+]] = select i1 %[[r60]], i8* %byref.constant.char.N, i8* %transb +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r60]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[r63:.+]] = select i1 %[[r60]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[r64:.+]] = select i1 %[[r60]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %loaded.trans4 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r65:.+]] = icmp eq i8 %loaded.trans4, 78 +; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 
%loaded.trans4, 110 +; CHECK-NEXT: %[[r67:.+]] = or i1 %[[r66]], %[[r65]] +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r67]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r69:.+]] = icmp eq i8 %loaded.trans5, 78 +; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %loaded.trans5, 110 +; CHECK-NEXT: %[[r71:.+]] = or i1 %[[r70]], %[[r69]] +; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r73:.+]] = icmp eq i8 %ld.row.trans6, 110 +; CHECK-NEXT: %[[r74:.+]] = icmp eq i8 %ld.row.trans6, 78 +; CHECK-NEXT: %[[r75:.+]] = or i1 %[[r74]], %[[r73]] +; CHECK-NEXT: %[[r76:.+]] = select i1 %[[r75]], i8* %"C'", i8* %41 +; CHECK-NEXT: %[[r77:.+]] = select i1 %[[r75]], i8* %ldc_p, i8* %68 +; CHECK-NEXT: %[[r78:.+]] = select i1 %[[r75]], i8* %[[i43]], i8* %"C'" +; CHECK-NEXT: %[[r79:.+]] = select i1 %[[r75]], i8* %[[r72]], i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r61]], i8* %[[r62]], i8* %[[r63]], i8* %[[r64]], i8* %n_p, i8* %alpha_p, i8* %[[r76]], i8* %[[r77]], i8* %[[r78]], i8* %[[r79]], i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T7, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N8, align 1 +; CHECK-NEXT: %ld.row.trans9 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r80:.+]] = icmp eq i8 %ld.row.trans9, 110 +; CHECK-NEXT: %[[r81:.+]] = icmp eq i8 %ld.row.trans9, 78 +; CHECK-NEXT: %[[r82:.+]] = or i1 %[[r81]], %[[r80]] +; CHECK-NEXT: %[[r83:.+]] = select i1 %[[r82]], i8* %byref.transpose.transa, i8* %byref.constant.char.T7 +; CHECK-NEXT: %[[r84:.+]] = select i1 %[[r82]], i8* %byref.constant.char.N8, i8* %transa +; CHECK-NEXT: %[[r85:.+]] = select i1 %[[r82]], i8* %k_p, i8* %n_p 
+; CHECK-NEXT: %[[r86:.+]] = select i1 %[[r82]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %loaded.trans10 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %loaded.trans10, 78 +; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %loaded.trans10, 110 +; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %loaded.trans11 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %loaded.trans11, 78 +; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %loaded.trans11, 110 +; CHECK-NEXT: %[[r93:.+]] = or i1 %[[r92]], %[[r91]] +; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r95:.+]] = icmp eq i8 %ld.row.trans12, 110 +; CHECK-NEXT: %[[r96:.+]] = icmp eq i8 %ld.row.trans12, 78 +; CHECK-NEXT: %[[r97:.+]] = or i1 %[[r96]], %[[r95]] +; CHECK-NEXT: %[[r98:.+]] = select i1 %[[r97]], i8* %[[i42]], i8* %"C'" +; CHECK-NEXT: %[[r99:.+]] = select i1 %[[r97]], i8* %[[r94]], i8* %ldc_p +; CHECK-NEXT: %[[r100:.+]] = select i1 %[[r97]], i8* %"C'", i8* %[[i42]] +; CHECK-NEXT: %[[r101:.+]] = select i1 %[[r97]], i8* %ldc_p, i8* %[[r90]] +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.013, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.014 = bitcast double* %byref.constant.fp.1.013 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r83]], i8* %[[r84]], i8* %[[r85]], i8* %[[r86]], i8* %m_p, i8* %alpha_p, i8* %[[r98]], i8* %[[r99]], i8* %[[r100]], i8* %[[r101]], i8* %fpcast.constant.fp.1.014, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* ; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint4]] ; CHECK-NEXT: %[[intcast07:.+]] = bitcast i64* %[[byrefconstantint4]] to i8* -; CHECK-NEXT: store double 1.000000e+00, 
double* %byref.constant.fp.1.0 -; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint5]] -; CHECK-NEXT: %[[intcast09:.+]] = bitcast i64* %[[byrefconstantint5]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast07]], i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[intcast09]]) +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.017 +; CHECK-NEXT: %fpcast.constant.fp.1.018 = bitcast double* %byref.constant.fp.1.017 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[intcast07]], i8* %fpcast.constant.fp.1.018, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) ; CHECK-NEXT: %[[ret1:.+]] = bitcast double* %cache.A to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[ret1]]) ; CHECK-NEXT: %[[ret2:.+]] = bitcast double* %cache.B to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll deleted file mode 100644 index c9e40aac8e36..000000000000 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy.ll +++ /dev/null @@ -1,329 +0,0 @@ -;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi -;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s - -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) - -define void @f(i8* noalias %C, i8* noalias %A, i8* noalias %B, i8* noalias %alpha, i8* noalias %beta) { -entry: - %transa = alloca i8, align 1 - %transb = alloca i8, align 1 - %m = alloca i64, align 16 - 
%m_p = bitcast i64* %m to i8* - %n = alloca i64, align 16 - %n_p = bitcast i64* %n to i8* - %k = alloca i64, align 16 - %k_p = bitcast i64* %k to i8* - %lda = alloca i64, align 16 - %lda_p = bitcast i64* %lda to i8* - %ldb = alloca i64, align 16 - %ldb_p = bitcast i64* %ldb to i8* - %ldc = alloca i64, align 16 - %ldc_p = bitcast i64* %ldc to i8* - store i8 78, i8* %transa, align 1 - store i8 78, i8* %transb, align 1 - store i64 4, i64* %m, align 16 - store i64 4, i64* %n, align 16 - store i64 8, i64* %k, align 16 - store i64 4, i64* %lda, align 16 - store i64 8, i64* %ldb, align 16 - store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) - %ptr = bitcast i8* %A to double* - store double 0.0000000e+00, double* %ptr, align 8 - ret void -} - -declare dso_local void @__enzyme_autodiff(...) - -define void @active(i8* %C, i8* %dC, i8* %A, i8* %dA, i8* %B, i8* %dB, i8* %alpha, i8* %dalpha, i8* %beta, i8* %dbeta) { -entry: - call void (...) 
@__enzyme_autodiff(void (i8*,i8*,i8*,i8*,i8*)* @f, metadata !"enzyme_dup", i8* %C, i8* %dC, metadata !"enzyme_dup", i8* %A, i8* %dA, metadata !"enzyme_dup", i8* %B, i8* %dB, metadata !"enzyme_dup", i8* %alpha, i8* %dalpha, metadata !"enzyme_dup", i8* %beta, i8* %dbeta) - ret void -} - -; CHECK: define internal void @diffef(i8* noalias %C, i8* %"C'", i8* noalias %A, i8* %"A'", i8* noalias %B, i8* %"B'", i8* noalias %alpha, i8* %"alpha'", i8* noalias %beta, i8* -; CHECK-NEXT: entry: -; CHECK-NEXT: %byref.constant.one.i15 = alloca i64 -; CHECK-NEXT: %byref.mat.size.i18 = alloca i64 -; CHECK-NEXT: %byref.constant.one.i = alloca i64 -; CHECK-NEXT: %byref.mat.size.i = alloca i64 -; CHECK-NEXT: %[[byrefgarbage:.+]] = alloca i8 -; CHECK-NEXT: %[[byrefgarbage2:.+]] = alloca i8 -; CHECK-NEXT: %ret = alloca double -; CHECK-NEXT: %byref.transpose.transa = alloca i8 -; CHECK-NEXT: %byref.transpose.transb = alloca i8 -; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %byref.constant.fp.0.0 = alloca double -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %byref.constant.int.09 = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.011 = alloca double -; CHECK-NEXT: %byref.constant.int.013 = alloca i64 -; CHECK-NEXT: %transa = alloca i8, align 1 -; CHECK-NEXT: %transb = alloca i8, align 1 -; CHECK-NEXT: %m = alloca i64, align 16 -; CHECK-NEXT: %m_p = bitcast i64* %m to i8* -; CHECK-NEXT: %n = alloca i64, align 16 -; CHECK-NEXT: %n_p = bitcast i64* %n to i8* -; CHECK-NEXT: %k = alloca i64, align 16 -; CHECK-NEXT: %k_p = bitcast i64* %k to i8* -; CHECK-NEXT: %lda = alloca i64, align 16 -; CHECK-NEXT: %lda_p = bitcast i64* %lda to i8* -; CHECK-NEXT: %ldb = alloca i64, align 16 -; CHECK-NEXT: %ldb_p = bitcast i64* %ldb to i8* -; CHECK-NEXT: %ldc = alloca i64, align 16 -; CHECK-NEXT: %ldc_p = bitcast i64* %ldc to i8* -; CHECK-NEXT: store i8 78, i8* %transa, align 1 
-; CHECK-NEXT: store i8 78, i8* %transb, align 1 -; CHECK-NEXT: store i64 4, i64* %m, align 16 -; CHECK-NEXT: store i64 4, i64* %n, align 16 -; CHECK-NEXT: store i64 8, i64* %k, align 16 -; CHECK-NEXT: store i64 4, i64* %lda, align 16 -; CHECK-NEXT: store i64 8, i64* %ldb, align 16 -; CHECK-NEXT: store i64 4, i64* %ldc, align 16 -; CHECK-NEXT: %loaded.trans = load i8, i8* %transa -; CHECK-DAG: %[[i0:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[i1:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-NEXT: %2 = or i1 %[[i1]], %[[i0]] -; CHECK-NEXT: %3 = select i1 %2, i8* %m_p, i8* %k_p -; CHECK-NEXT: %4 = select i1 %2, i8* %k_p, i8* %m_p -; CHECK-NEXT: %[[i5:.+]] = bitcast i8* %3 to i64* -; CHECK-NEXT: %[[i6:.+]] = load i64, i64* %[[i5]] -; CHECK-NEXT: %[[i7:.+]] = bitcast i8* %4 to i64* -; CHECK-NEXT: %[[i8:.+]] = load i64, i64* %[[i7]] -; CHECK-NEXT: %9 = mul i64 %[[i6]], %[[i8]] -; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %9, 8 -; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) -; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* -; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %3, i8* %4, i8* %A, i8* %lda_p, double* %cache.A, i8* %3) -; CHECK-NEXT: %10 = bitcast i8* %m_p to i64* -; CHECK-NEXT: %11 = load i64, i64* %10 -; CHECK-NEXT: %12 = bitcast i8* %n_p to i64* -; CHECK-NEXT: %13 = load i64, i64* %12 -; CHECK-NEXT: %14 = mul i64 %11, %13 -; CHECK-NEXT: %mallocsize1 = mul nuw nsw i64 %14, 8 -; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i64 %mallocsize1) -; CHECK-NEXT: %cache.C = bitcast i8* %malloccall2 to double* -; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage3 -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %m_p) -; CHECK-NEXT: %[[i17:.+]] = bitcast i8* %m_p to i64* -; CHECK-NEXT: %[[i18:.+]] = load i64, i64* %[[i17]] -; CHECK-NEXT: %[[i19:.+]] = 
bitcast i8* %n_p to i64* -; CHECK-NEXT: %[[i20:.+]] = load i64, i64* %[[i19]] -; CHECK-NEXT: %size_AB = mul nuw i64 %[[i18]], %[[i20]] -; CHECK-NEXT: %mallocsize5 = mul nuw nsw i64 %size_AB, 8 -; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize5) -; CHECK-NEXT: %mat_AB = bitcast i8* %malloccall6 to double* -; CHECK-NEXT: %[[i21:.+]] = bitcast double* %mat_AB to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) -; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* -; CHECK-NEXT: %ptr = bitcast i8* %A to double* -; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8, !alias.scope !0, !noalias !3 -; CHECK-NEXT: br label %invertentry - -; CHECK: invertentry: ; preds = %entry -; CHECK-NEXT: store double 0.000000e+00, double* %"ptr'ipc", align 8, !alias.scope !3, !noalias !0 -; CHECK-NEXT: %[[matA:.+]] = bitcast double* %cache.A to i8* -; CHECK-NEXT: %[[matC0:.+]] = bitcast double* %cache.C to i8* -; CHECK-NEXT: %[[matC:.+]] = bitcast double* %cache.C to i8* -; CHECK-NEXT: %ld.transa = load i8, i8* %transa -; CHECK-DAG: %[[i25:.+]] = icmp eq i8 %ld.transa, 110 -; CHECK-DAG: %[[i26:.+]] = select i1 %[[i25]], i8 116, i8 0 -; CHECK-DAG: %[[i27:.+]] = icmp eq i8 %ld.transa, 78 -; CHECK-DAG: %[[i28:.+]] = select i1 %[[i27]], i8 84, i8 %[[i26]] -; CHECK-DAG: %[[i29:.+]] = icmp eq i8 %ld.transa, 116 -; CHECK-DAG: %[[i30:.+]] = select i1 %[[i29]], i8 110, i8 %[[i28]] -; CHECK-DAG: %[[i31:.+]] = icmp eq i8 %ld.transa, 84 -; CHECK-DAG: %[[i32:.+]] = select i1 %[[i31]], i8 78, i8 %[[i30]] -; CHECK-NEXT: store i8 %[[i32]], i8* %byref.transpose.transa -; CHECK-NEXT: %ld.transb = load i8, i8* %transb -; CHECK-DAG: %[[i33:.+]] = icmp eq i8 %ld.transb, 110 -; CHECK-DAG: %[[i34:.+]] = select i1 %[[i33]], i8 116, i8 0 -; CHECK-DAG: %[[i35:.+]] = icmp eq i8 %ld.transb, 78 -; CHECK-DAG: %[[i36:.+]] = select 
i1 %[[i35]], i8 84, i8 %[[i34]] -; CHECK-DAG: %[[i37:.+]] = icmp eq i8 %ld.transb, 116 -; CHECK-DAG: %[[i38:.+]] = select i1 %[[i37]], i8 110, i8 %[[i36]] -; CHECK-DAG: %[[i39:.+]] = icmp eq i8 %ld.transb, 84 -; CHECK-DAG: %[[i40:.+]] = select i1 %[[i39]], i8 78, i8 %[[i38]] -; CHECK-NEXT: store i8 %[[i40]], i8* %byref.transpose.transb -; CHECK-NEXT: store i64 1, i64* %byref.int.one -; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 -; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: %loaded.trans7 = load i8, i8* %transa -; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans7, 78 -; CHECK-DAG: %[[i42:.+]] = icmp eq i8 %loaded.trans7, 110 -; CHECK-NEXT: %[[i43:.+]] = or i1 %[[i42]], %[[i41]] -; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p -; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 -; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p, i64 1, i64 1) -; CHECK: %[[i45:.+]] = bitcast i64* %byref.constant.one.i to i8* -; CHECK: %[[i46:.+]] = bitcast i64* %byref.mat.size.i to i8* -; CHECK: store i64 1, i64* %byref.constant.one.i -; CHECK-NEXT: %intcast.constant.one.i = bitcast i64* %byref.constant.one.i to i8* -; CHECK-DAG: %[[i47:.+]] = load i64, i64* %m -; CHECK-DAG: %[[i48:.+]] = load i64, i64* %n -; CHECK-DAG: %mat.size.i = mul nuw i64 %[[i47]], %[[i48]] -; CHECK-NEXT: store i64 %mat.size.i, i64* %byref.mat.size.i -; CHECK-NEXT: %intcast.mat.size.i = bitcast i64* %byref.mat.size.i to i8* -; CHECK-NEXT: %[[i49:.+]] = icmp eq i64 %mat.size.i, 0 -; CHECK-NEXT: br i1 %[[i49]], label %__enzyme_inner_prodd_64_.exit, label %init.idx.i - -; 
CHECK: init.idx.i: ; preds = %invertentry -; CHECK-NEXT: %[[i50:.+]] = load i64, i64* %ldc -; CHECK-NEXT: %[[i51:.+]] = bitcast i8* %"C'" to double* -; CHECK-NEXT: %[[i52:.+]] = icmp eq i64 %[[i47]], %[[i50]] -; CHECK-NEXT: br i1 %[[i52]], label %fast.path.i, label %for.body.i - -; CHECK: fast.path.i: ; preds = %init.idx.i -; CHECK-NEXT: %[[i53:.+]] = call fast double @ddot_64_(i8* %intcast.mat.size.i, i8* %"C'", i8* %intcast.constant.one.i, i8* %[[i21]], i8* %intcast.constant.one.i) -; CHECK-NEXT: br label %__enzyme_inner_prodd_64_.exit - -; CHECK: for.body.i: ; preds = %for.body.i, %init.idx.i -; CHECK-NEXT: %Aidx.i = phi i64 [ 0, %init.idx.i ], [ %Aidx.next.i, %for.body.i ] -; CHECK-NEXT: %Bidx.i = phi i64 [ 0, %init.idx.i ], [ %Bidx.next.i, %for.body.i ] -; CHECK-NEXT: %iteration.i = phi i64 [ 0, %init.idx.i ], [ %iter.next.i, %for.body.i ] -; CHECK-NEXT: %sum.i = phi{{( fast)?}} double [ 0.000000e+00, %init.idx.i ], [ %[[i57:.+]], %for.body.i ] -; CHECK-NEXT: %A.i.i = getelementptr inbounds double, double* %[[i51]], i64 %Aidx.i -; CHECK-NEXT: %B.i.i = getelementptr inbounds double, double* %mat_AB, i64 %Bidx.i -; CHECK-NEXT: %[[i54:.+]] = bitcast double* %A.i.i to i8* -; CHECK-NEXT: %[[i55:.+]] = bitcast double* %B.i.i to i8* -; CHECK-NEXT: %[[i56:.+]] = call fast double @ddot_64_(i8* %m_p, i8* %[[i54]], i8* %intcast.constant.one.i, i8* %[[i55]], i8* %intcast.constant.one.i) -; CHECK-NEXT: %Aidx.next.i = add nuw i64 %Aidx.i, %[[i50]] -; CHECK-NEXT: %Bidx.next.i = add nuw i64 %Aidx.i, %[[i47]] -; CHECK-NEXT: %iter.next.i = add i64 %iteration.i, 1 -; CHECK-NEXT: %[[i57]] = fadd fast double %sum.i, %[[i56]] -; CHECK-NEXT: %[[i58:.+]] = icmp eq i64 %iteration.i, %[[i48]] -; CHECK-NEXT: br i1 %[[i58]], label %__enzyme_inner_prodd_64_.exit, label %for.body.i - -; CHECK: __enzyme_inner_prodd_64_.exit: ; preds = %invertentry, %fast.path.i, %for.body.i -; CHECK-NEXT: %res.i = phi double [ 0.000000e+00, %invertentry ], [ %sum.i, %for.body.i ], [ %[[i53]], %fast.path.i ] 
-; CHECK-NEXT: %[[i59:.+]] = bitcast i64* %byref.constant.one.i to i8* -; CHECK: %[[i60:.+]] = bitcast i64* %byref.mat.size.i to i8* -; CHECK: %[[i61:.+]] = bitcast i8* %"alpha'" to double* -; CHECK-NEXT: %[[i62:.+]] = load double, double* %[[i61]] -; CHECK-NEXT: %[[i63:.+]] = fadd fast double %[[i62]], %res.i -; CHECK-NEXT: store double %[[i63]], double* %[[i61]] -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: %loaded.trans8 = load i8, i8* %transa -; CHECK-DAG: %[[i64:.+]] = icmp eq i8 %loaded.trans8, 78 -; CHECK-DAG: %[[i65:.+]] = icmp eq i8 %loaded.trans8, 110 -; CHECK-DAG: %[[i66:.+]] = or i1 %[[i65]], %[[i64]] -; CHECK-NEXT: %[[i67:.+]] = select i1 %[[i66]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p, i64 1, i64 1) -; CHECK: %[[i68:.+]] = bitcast i64* %byref.constant.one.i15 to i8* -; CHECK: %[[i69:.+]] = bitcast i64* %byref.mat.size.i18 to i8* -; CHECK: store i64 1, i64* %byref.constant.one.i15 -; CHECK-NEXT: %intcast.constant.one.i16 = bitcast i64* %byref.constant.one.i15 to i8* -; CHECK-NEXT: %[[i70:.+]] = load i64, i64* %m -; CHECK-NEXT: %[[i71:.+]] = load i64, i64* %n -; CHECK-NEXT: %mat.size.i17 = mul nuw i64 %[[i70]], %[[i71]] -; CHECK-NEXT: store i64 %mat.size.i17, i64* %byref.mat.size.i18 -; CHECK-NEXT: %intcast.mat.size.i19 = bitcast i64* %byref.mat.size.i18 to i8* -; CHECK-NEXT: %[[i72:.+]] = icmp eq i64 %mat.size.i17, 0 -; CHECK-NEXT: br i1 %[[i72]], label %__enzyme_inner_prodd_64_.exit33, label %init.idx.i20 - -; CHECK: init.idx.i20: ; preds = %__enzyme_inner_prodd_64_.exit -; CHECK-NEXT: %[[i73:.+]] = load i64, i64* %ldc -; CHECK-NEXT: %[[i74:.+]] = bitcast i8* %"C'" to double* -; CHECK-NEXT: 
%[[i75:.+]] = icmp eq i64 %[[i70]], %[[i73]] -; CHECK-NEXT: br i1 %[[i75]], label %fast.path.i21, label %for.body.i31 - -; CHECK: fast.path.i21: ; preds = %init.idx.i20 -; CHECK-NEXT: %[[i76:.+]] = call fast double @ddot_64_(i8* %intcast.mat.size.i19, i8* %"C'", i8* %intcast.constant.one.i16, i8* %[[matC0]], i8* %intcast.constant.one.i16) -; CHECK-NEXT: br label %__enzyme_inner_prodd_64_.exit33 - -; CHECK: for.body.i31: ; preds = %for.body.i31, %init.idx.i20 -; CHECK-NEXT: %Aidx.i22 = phi i64 [ 0, %init.idx.i20 ], [ %Aidx.next.i28, %for.body.i31 ] -; CHECK-NEXT: %Bidx.i23 = phi i64 [ 0, %init.idx.i20 ], [ %Bidx.next.i29, %for.body.i31 ] -; CHECK-NEXT: %iteration.i24 = phi i64 [ 0, %init.idx.i20 ], [ %iter.next.i30, %for.body.i31 ] -; CHECK-NEXT: %sum.i25 = phi{{( fast)?}} double [ 0.000000e+00, %init.idx.i20 ], [ %[[i80:.+]], %for.body.i31 ] -; CHECK-NEXT: %A.i.i26 = getelementptr inbounds double, double* %[[i74]], i64 %Aidx.i22 -; CHECK-NEXT: %B.i.i27 = getelementptr inbounds double, double* %cache.C, i64 %Bidx.i23 -; CHECK-NEXT: %[[i77:.+]] = bitcast double* %A.i.i26 to i8* -; CHECK-NEXT: %[[i78:.+]] = bitcast double* %B.i.i27 to i8* -; CHECK-NEXT: %[[i79:.+]] = call fast double @ddot_64_(i8* %m_p, i8* %[[i77]], i8* %intcast.constant.one.i16, i8* %[[i78]], i8* %intcast.constant.one.i16) -; CHECK-NEXT: %Aidx.next.i28 = add nuw i64 %Aidx.i22, %[[i73]] -; CHECK-NEXT: %Bidx.next.i29 = add nuw i64 %Aidx.i22, %[[i70]] -; CHECK-NEXT: %iter.next.i30 = add i64 %iteration.i24, 1 -; CHECK-NEXT: %[[i80]] = fadd fast double %sum.i25, %[[i79]] -; CHECK-NEXT: %[[i81:.+]] = icmp eq i64 %iteration.i24, %[[i71]] -; CHECK-NEXT: br i1 %[[i81:.+]], label %__enzyme_inner_prodd_64_.exit33, label %for.body.i31 - -; CHECK: __enzyme_inner_prodd_64_.exit33: ; preds = %__enzyme_inner_prodd_64_.exit, %fast.path.i21, %for.body.i31 -; CHECK-NEXT: %res.i32 = phi double [ 0.000000e+00, %__enzyme_inner_prodd_64_.exit ], [ %sum.i25, %for.body.i31 ], [ %[[i76]], %fast.path.i21 ] -; CHECK-NEXT: 
%[[i82:.+]] = bitcast i64* %byref.constant.one.i15 to i8* -; CHECK: %[[i83:.+]] = bitcast i64* %byref.mat.size.i18 to i8* -; CHECK: %[[i84:.+]] = bitcast i8* %"beta'" to double* -; CHECK-NEXT: %[[i85:.+]] = load double, double* %[[i84]] -; CHECK-NEXT: %[[i86:.+]] = fadd fast double %[[i85]], %res.i32 -; CHECK-NEXT: store double %[[i86]], double* %[[i84]] -; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 -; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.09 -; CHECK-NEXT: %intcast.constant.int.010 = bitcast i64* %byref.constant.int.09 to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.011 -; CHECK-NEXT: %fpcast.constant.fp.1.012 = bitcast double* %byref.constant.fp.1.011 to i8* -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.013 -; CHECK-NEXT: %intcast.constant.int.014 = bitcast i64* %byref.constant.int.013 to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.010, i8* %fpcast.constant.fp.1.012, i8* %beta, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.014) -; CHECK-NEXT: %[[i87:.+]] = bitcast double* %cache.A to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %[[i87]]) -; CHECK-NEXT: %[[i88:.+]] = bitcast double* %cache.C to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %[[i88]]) -; CHECK-NEXT: ret void -; CHECK-NEXT: } - -; CHECK: define internal double @__enzyme_inner_prodd_64_(i8* %blasm, i8* %blasn, i8* noalias nocapture readonly %A, i8* %lda, i8* noalias nocapture readonly %B -; CHECK-NEXT: entry: -; CHECK-NEXT: %byref.constant.one = alloca i64 -; CHECK-NEXT: store i64 1, i64* %byref.constant.one -; CHECK-NEXT: %intcast.constant.one = bitcast i64* %byref.constant.one to i8* -; CHECK-NEXT: %0 = bitcast i8* %blasm to i64* -; CHECK-NEXT: %1 = load i64, i64* %0 -; CHECK-NEXT: %2 = 
bitcast i8* %blasn to i64* -; CHECK-NEXT: %3 = load i64, i64* %2 -; CHECK-NEXT: %mat.size = mul nuw i64 %1, %3 -; CHECK-NEXT: %byref.mat.size = alloca i64 -; CHECK-NEXT: store i64 %mat.size, i64* %byref.mat.size -; CHECK-NEXT: %intcast.mat.size = bitcast i64* %byref.mat.size to i8* -; CHECK-NEXT: %4 = icmp eq i64 %mat.size, 0 -; CHECK-NEXT: br i1 %4, label %for.end, label %init.idx - -; CHECK: init.idx: ; preds = %entry -; CHECK-NEXT: %5 = bitcast i8* %lda to i64* -; CHECK-NEXT: %6 = load i64, i64* %5 -; CHECK-NEXT: %7 = bitcast i8* %A to double* -; CHECK-NEXT: %8 = bitcast i8* %B to double* -; CHECK-NEXT: %9 = icmp eq i64 %1, %6 -; CHECK-NEXT: br i1 %9, label %fast.path, label %for.body - -; CHECK: fast.path: ; preds = %init.idx -; CHECK-NEXT: %10 = call fast double @ddot_64_(i8* %intcast.mat.size, i8* %A, i8* %intcast.constant.one, i8* %B, i8* %intcast.constant.one) -; CHECK-NEXT: br label %for.end - -; CHECK: for.body: ; preds = %for.body, %init.idx -; CHECK-NEXT: %Aidx = phi i64 [ 0, %init.idx ], [ %Aidx.next, %for.body ] -; CHECK-NEXT: %Bidx = phi i64 [ 0, %init.idx ], [ %Bidx.next, %for.body ] -; CHECK-NEXT: %iteration = phi i64 [ 0, %init.idx ], [ %iter.next, %for.body ] -; CHECK-NEXT: %sum = phi{{( fast)?}} double [ 0.000000e+00, %init.idx ], [ %14, %for.body ] -; CHECK-NEXT: %A.i = getelementptr inbounds double, double* %7, i64 %Aidx -; CHECK-NEXT: %B.i = getelementptr inbounds double, double* %8, i64 %Bidx -; CHECK-NEXT: %11 = bitcast double* %A.i to i8* -; CHECK-NEXT: %12 = bitcast double* %B.i to i8* -; CHECK-NEXT: %13 = call fast double @ddot_64_(i8* %blasm, i8* %11, i8* %intcast.constant.one, i8* %12, i8* %intcast.constant.one) -; CHECK-NEXT: %Aidx.next = add nuw i64 %Aidx, %6 -; CHECK-NEXT: %Bidx.next = add nuw i64 %Aidx, %1 -; CHECK-NEXT: %iter.next = add i64 %iteration, 1 -; CHECK-NEXT: %14 = fadd fast double %sum, %13 -; CHECK-NEXT: %15 = icmp eq i64 %iteration, %3 -; CHECK-NEXT: br i1 %15, label %for.end, label %for.body - -; CHECK: for.end: ; 
preds = %for.body, %fast.path, %entry -; CHECK-NEXT: %res = phi double [ 0.000000e+00, %entry ], [ %sum, %for.body ], [ %10, %fast.path ] -; CHECK-NEXT: ret double %res -; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index f8d13293ef97..2e983e3ceb6d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -43,8 +43,8 @@ entry: ; CHECK: define internal void @diffef(i8* noalias %C, i8* %"C'", i8* noalias %A, i8* %"A'", i8* noalias %B, i8* %"B'", i8* noalias %alpha, i8* %"alpha'", i8* noalias %beta, i8* ; CHECK-NEXT: entry: -; CHECK-NEXT: %byref.constant.one.i15 = alloca i64 -; CHECK-NEXT: %byref.mat.size.i18 = alloca i64 +; CHECK-NEXT: %byref.constant.one.i23 = alloca i64 +; CHECK-NEXT: %byref.mat.size.i26 = alloca i64 ; CHECK-NEXT: %byref.constant.one.i = alloca i64 ; CHECK-NEXT: %byref.mat.size.i = alloca i64 ; CHECK-NEXT: %[[byrefgarbage:.+]] = alloca i8 @@ -55,11 +55,16 @@ entry: ; CHECK-NEXT: %byref.int.one = alloca i64 ; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double ; CHECK-NEXT: %byref.constant.fp.0.0 = alloca double -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %byref.constant.int.09 = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.011 = alloca double -; CHECK-NEXT: %byref.constant.int.013 = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.09 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.T11 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N12 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.017 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 
8 +; CHECK-NEXT: %byref.constant.int.019 = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.fp.1.021 = alloca double, align 8 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -82,11 +87,16 @@ entry: ; CHECK-NEXT: store i64 4, i64* %lda, align 16 ; CHECK-NEXT: store i64 8, i64* %ldb, align 16 ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 -; CHECK-NEXT: %rt.inactive.alpha = icmp eq i8* %"alpha'", %alpha -; CHECK-NEXT: %rt.inactive.A = icmp eq i8* %"A'", %A -; CHECK-NEXT: %rt.inactive.B = icmp eq i8* %"B'", %B -; CHECK-NEXT: %rt.inactive.beta = icmp eq i8* %"beta'", %beta -; CHECK-NEXT: %rt.inactive.C = icmp eq i8* %"C'", %C +; CHECK-NEXT: %rt.tmp.inactive.alpha = icmp eq i8* %"alpha'", %alpha +; CHECK-NEXT: %rt.tmp.inactive.A = icmp eq i8* %"A'", %A +; CHECK-NEXT: %rt.tmp.inactive.B = icmp eq i8* %"B'", %B +; CHECK-NEXT: %rt.tmp.inactive.beta = icmp eq i8* %"beta'", %beta +; CHECK-NEXT: %rt.tmp.inactive.C = icmp eq i8* %"C'", %C +; CHECK-NEXT: %rt.inactive.alpha = or i1 %rt.tmp.inactive.alpha, %rt.tmp.inactive.C +; CHECK-NEXT: %rt.inactive.A = or i1 %rt.tmp.inactive.A, %rt.tmp.inactive.C +; CHECK-NEXT: %rt.inactive.B = or i1 %rt.tmp.inactive.B, %rt.tmp.inactive.C +; CHECK-NEXT: %rt.inactive.beta = or i1 %rt.tmp.inactive.beta, %rt.tmp.inactive.C +; CHECK-NEXT: %rt.inactive.C = or i1 %rt.tmp.inactive.C, %rt.tmp.inactive.C ; CHECK-NEXT: %loaded.trans = load i8, i8* %transa ; CHECK-DAG: %[[i0:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[i1:.+]] = icmp eq i8 %loaded.trans, 110 @@ -97,8 +107,8 @@ entry: ; CHECK-NEXT: %[[i6:.+]] = load i64, i64* %[[i5]] ; CHECK-NEXT: %[[i7:.+]] = bitcast i8* %4 to i64* ; CHECK-NEXT: %[[i8:.+]] = load i64, i64* %[[i7]] -; CHECK-NEXT: %9 = mul i64 %[[i6]], %[[i8]] -; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %9, 8 +; CHECK-NEXT: %[[i9:.+]] = mul i64 %[[i6]], %[[i8]] +; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %[[i9]], 8 ; CHECK-NEXT: %malloccall = tail 
call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] @@ -221,88 +231,131 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.A, label %invertentry.A.done, label %invertentry.A.active ; CHECK: invertentry.A.active: ; preds = %invertentry.alpha.done -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r62:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r64:.+]] = or i1 %[[r63]], %[[r62]] +; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %byref.constant.char.N, i8* %transb +; CHECK-NEXT: %[[r66:.+]] = select i1 %[[r64]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[r67:.+]] = select i1 %[[r64]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r64]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %ld.row.trans8 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r69:.+]] = icmp eq i8 %ld.row.trans8, 110 +; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %ld.row.trans8, 78 +; CHECK-NEXT: %[[r71:.+]] = or i1 %[[r70]], %[[r69]] +; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %"C'", i8* %B +; CHECK-NEXT: %[[r73:.+]] = select i1 %[[r71]], i8* %ldc_p, i8* %ldb_p +; CHECK-NEXT: %[[r74:.+]] = select i1 %[[r71]], i8* %B, i8* %"C'" +; CHECK-NEXT: %[[r75:.+]] = select i1 %[[r71]], i8* %ldb_p, i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.09, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.010 = bitcast double* %byref.constant.fp.1.09 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r65]], i8* %[[r66]], i8* 
%[[r67]], i8* %[[r68]], i8* %n_p, i8* %alpha, i8* %[[r72]], i8* %[[r73]], i8* %[[r74]], i8* %[[r75]], i8* %fpcast.constant.fp.1.010, i8* %"A'", i8* %lda_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry.A.done ; CHECK: invertentry.A.done: ; preds = %invertentry.A.active, %invertentry.alpha.done ; CHECK-NEXT: br i1 %rt.inactive.B, label %invertentry.B.done, label %invertentry.B.active ; CHECK: invertentry.B.active: ; preds = %invertentry.A.done -; CHECK-NEXT: %loaded.trans8 = load i8, i8* %transa -; CHECK-DAG: %[[i64:.+]] = icmp eq i8 %loaded.trans8, 78 -; CHECK-DAG: %[[i65:.+]] = icmp eq i8 %loaded.trans8, 110 -; CHECK-DAG: %[[i66:.+]] = or i1 %[[i65]], %[[i64]] -; CHECK-NEXT: %[[i67:.+]] = select i1 %[[i66]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha, i8* %[[matA]], i8* %[[i67]], i8* %"C'", i8* %ldc_p, i8* %beta, i8* %"B'", i8* %ldb_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T11, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N12, align 1 +; CHECK-NEXT: %ld.row.trans13 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r76:.+]] = icmp eq i8 %ld.row.trans13, 110 +; CHECK-NEXT: %[[r77:.+]] = icmp eq i8 %ld.row.trans13, 78 +; CHECK-NEXT: %[[r78:.+]] = or i1 %[[r77]], %[[r76]] +; CHECK-NEXT: %[[r79:.+]] = select i1 %[[r78]], i8* %byref.transpose.transa, i8* %byref.constant.char.T11 +; CHECK-NEXT: %[[r80:.+]] = select i1 %[[r78]], i8* %byref.constant.char.N12, i8* %transa +; CHECK-NEXT: %[[r81:.+]] = select i1 %[[r78]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[r82:.+]] = select i1 %[[r78]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %loaded.trans14 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r83:.+]] = icmp eq i8 %loaded.trans14, 78 +; CHECK-NEXT: %[[r84:.+]] = icmp eq i8 %loaded.trans14, 110 +; CHECK-NEXT: %[[r85:.+]] = or i1 %84, %83 +; CHECK-NEXT: %[[r86:.+]] = select i1 %85, i8* %m_p, i8* %k_p +; CHECK-NEXT: %loaded.trans15 = load i8, 
i8* %transa, align 1 +; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %loaded.trans15, 78 +; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %loaded.trans15, 110 +; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans16 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %ld.row.trans16, 110 +; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %ld.row.trans16, 78 +; CHECK-NEXT: %[[r93:.+]] = or i1 %[[r92]], %[[r91]] +; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %[[matA]], i8* %"C'" +; CHECK-NEXT: %[[r95:.+]] = select i1 %[[r93]], i8* %[[r90]], i8* %ldc_p +; CHECK-NEXT: %[[r96:.+]] = select i1 %[[r93]], i8* %"C'", i8* %[[matA]] +; CHECK-NEXT: %[[r97:.+]] = select i1 %[[r93]], i8* %ldc_p, i8* %[[r86]] +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.017, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.018 = bitcast double* %byref.constant.fp.1.017 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %79, i8* %80, i8* %81, i8* %82, i8* %m_p, i8* %alpha, i8* %94, i8* %95, i8* %96, i8* %97, i8* %fpcast.constant.fp.1.018, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: br label %invertentry.B.done ; CHECK: invertentry.B.done: ; preds = %invertentry.B.active, %invertentry.A.done ; CHECK-NEXT: br i1 %rt.inactive.beta, label %invertentry.beta.done, label %invertentry.beta.active ; CHECK: invertentry.beta.active: ; preds = %invertentry.B.done -; CHECK: %[[i68:.+]] = bitcast i64* %byref.constant.one.i15 to i8* -; CHECK: %[[i69:.+]] = bitcast i64* %byref.mat.size.i18 to i8* -; CHECK: store i64 1, i64* %byref.constant.one.i15 -; CHECK-NEXT: %intcast.constant.one.i16 = bitcast i64* %byref.constant.one.i15 to i8* +; CHECK: %[[i68:.+]] = bitcast i64* %[[byrefconstantonei15:.+]] to i8* +; CHECK: %[[i69:.+]] = bitcast i64* %[[byrefmatsizei18:.+]] to i8* +; CHECK: store i64 1, i64* %byref.constant.one.i23 +; CHECK-NEXT: %intcast.constant.one.i24 = bitcast i64* 
%byref.constant.one.i23 to i8* ; CHECK-NEXT: %[[i70:.+]] = load i64, i64* %m ; CHECK-NEXT: %[[i71:.+]] = load i64, i64* %n -; CHECK-NEXT: %mat.size.i17 = mul nuw i64 %[[i70]], %[[i71]] -; CHECK-NEXT: store i64 %mat.size.i17, i64* %byref.mat.size.i18 -; CHECK-NEXT: %intcast.mat.size.i19 = bitcast i64* %byref.mat.size.i18 to i8* -; CHECK-NEXT: %[[i72:.+]] = icmp eq i64 %mat.size.i17, 0 -; CHECK-NEXT: br i1 %[[i72]], label %__enzyme_inner_prodd_64_.exit33, label %init.idx.i20 +; CHECK-NEXT: %mat.size.i25 = mul nuw i64 %[[i70]], %[[i71]] +; CHECK-NEXT: store i64 %mat.size.i25, i64* %[[byrefmatsizei18]] +; CHECK-NEXT: %intcast.mat.size.i27 = bitcast i64* %byref.mat.size.i26 to i8* +; CHECK-NEXT: %[[i72:.+]] = icmp eq i64 %mat.size.i25, 0 +; CHECK-NEXT: br i1 %[[i72]], label %__enzyme_inner_prodd_64_.exit41, label %init.idx.i28 -; CHECK: init.idx.i20: ; preds = %invertentry.beta.active +; CHECK: init.idx.i28: ; preds = %invertentry.beta.active ; CHECK-NEXT: %[[i73:.+]] = load i64, i64* %ldc ; CHECK-NEXT: %[[i74:.+]] = bitcast i8* %"C'" to double* ; CHECK-NEXT: %[[i75:.+]] = icmp eq i64 %[[i70]], %[[i73]] -; CHECK-NEXT: br i1 %[[i75]], label %fast.path.i21, label %for.body.i31 +; CHECK-NEXT: br i1 %[[i75]], label %fast.path.i29, label %for.body.i39 -; CHECK: fast.path.i21: ; preds = %init.idx.i20 -; CHECK-NEXT: %[[i76:.+]] = call fast double @ddot_64_(i8* %intcast.mat.size.i19, i8* %"C'", i8* %intcast.constant.one.i16, i8* %[[matC0]], i8* %intcast.constant.one.i16) -; CHECK-NEXT: br label %__enzyme_inner_prodd_64_.exit33 +; CHECK: fast.path.i29: ; preds = %init.idx.i28 +; CHECK-NEXT: %[[i76:.+]] = call fast double @ddot_64_(i8* %intcast.mat.size.i27, i8* %"C'", i8* %intcast.constant.one.i24, i8* %[[matC0]], i8* %intcast.constant.one.i24) +; CHECK-NEXT: br label %__enzyme_inner_prodd_64_.exit41 -; CHECK: for.body.i31: ; preds = %for.body.i31, %init.idx.i20 -; CHECK-NEXT: %Aidx.i22 = phi i64 [ 0, %init.idx.i20 ], [ %Aidx.next.i28, %for.body.i31 ] -; CHECK-NEXT: %Bidx.i23 = 
phi i64 [ 0, %init.idx.i20 ], [ %Bidx.next.i29, %for.body.i31 ] -; CHECK-NEXT: %iteration.i24 = phi i64 [ 0, %init.idx.i20 ], [ %iter.next.i30, %for.body.i31 ] -; CHECK-NEXT: %sum.i25 = phi{{( fast)?}} double [ 0.000000e+00, %init.idx.i20 ], [ %[[i80:.+]], %for.body.i31 ] -; CHECK-NEXT: %A.i.i26 = getelementptr inbounds double, double* %[[i74]], i64 %Aidx.i22 -; CHECK-NEXT: %B.i.i27 = getelementptr inbounds double, double* %cache.C, i64 %Bidx.i23 -; CHECK-NEXT: %[[i77:.+]] = bitcast double* %A.i.i26 to i8* -; CHECK-NEXT: %[[i78:.+]] = bitcast double* %B.i.i27 to i8* -; CHECK-NEXT: %[[i79:.+]] = call fast double @ddot_64_(i8* %m_p, i8* %[[i77]], i8* %intcast.constant.one.i16, i8* %[[i78]], i8* %intcast.constant.one.i16) -; CHECK-NEXT: %Aidx.next.i28 = add nuw i64 %Aidx.i22, %[[i73]] -; CHECK-NEXT: %Bidx.next.i29 = add nuw i64 %Aidx.i22, %[[i70]] -; CHECK-NEXT: %iter.next.i30 = add i64 %iteration.i24, 1 -; CHECK-NEXT: %[[i80]] = fadd fast double %sum.i25, %[[i79]] -; CHECK-NEXT: %[[i81:.+]] = icmp eq i64 %iteration.i24, %[[i71]] -; CHECK-NEXT: br i1 %[[i81]], label %__enzyme_inner_prodd_64_.exit33, label %for.body.i31 +; CHECK: [[forbodyi31:.+]]: ; preds = %[[forbodyi31]], %[[initidxi20:.+]] +; CHECK-NEXT: %[[Aidxi22:.+]] = phi i64 [ 0, %[[initidxi20]] ], [ %[[Aidxnexti28:.+]], %[[forbodyi31]] ] +; CHECK-NEXT: %[[Bidxi23:.+]] = phi i64 [ 0, %[[initidxi20]] ], [ %[[Bidxnexti29:.+]], %[[forbodyi31]] ] +; CHECK-NEXT: %[[iterationi24:.+]] = phi i64 [ 0, %[[initidxi20]] ], [ %[[iternexti30:.+]], %[[forbodyi31]] ] +; CHECK-NEXT: %[[sumi25:.+]] = phi fast double [ 0.000000e+00, %[[initidxi20]] ], [ %[[i80:.+]], %[[forbodyi31]] ] +; CHECK-NEXT: %[[Aii26:.+]] = getelementptr inbounds double, double* %[[i74]], i64 %[[Aidxi22]] +; CHECK-NEXT: %[[Bii27:.+]] = getelementptr inbounds double, double* %cache.C, i64 %[[Bidxi23]] +; CHECK-NEXT: %[[i77:.+]] = bitcast double* %[[Aii26]] to i8* +; CHECK-NEXT: %[[i78:.+]] = bitcast double* %[[Bii27]] to i8* +; CHECK-NEXT: %[[i79:.+]] = 
call fast double @ddot_64_(i8* %m_p, i8* %[[i77]], i8* %intcast.constant.one.i24, i8* %[[i78]], i8* %intcast.constant.one.i24) +; CHECK-NEXT: %[[Aidxnexti28]] = add nuw i64 %[[Aidxi22]], %[[i73]] +; CHECK-NEXT: %[[Bidxnexti29]] = add nuw i64 %[[Aidxi22]], %[[i70]] +; CHECK-NEXT: %[[iternexti30]] = add i64 %[[iterationi24]], 1 +; CHECK-NEXT: %[[i80]] = fadd fast double %[[sumi25]], %[[i79]] +; CHECK-NEXT: %[[i81:.+]] = icmp eq i64 %[[iterationi24]], %[[i71]] +; CHECK-NEXT: br i1 %[[i81]], label %__enzyme_inner_prodd_64_.exit41, label %[[forbodyi31]] -; CHECK: __enzyme_inner_prodd_64_.exit33: ; preds = %invertentry.beta.active, %fast.path.i21, %for.body.i31 -; CHECK-NEXT: %res.i32 = phi double [ 0.000000e+00, %invertentry.beta.active ], [ %sum.i25, %for.body.i31 ], [ %[[i76]], %fast.path.i21 ] -; CHECK-NEXT: %[[i82:.+]] = bitcast i64* %byref.constant.one.i15 to i8* -; CHECK: %[[i83:.+]] = bitcast i64* %byref.mat.size.i18 to i8* +; CHECK: __enzyme_inner_prodd_64_.exit41: ; preds = %invertentry.beta.active, %fast.path.i29, %for.body.i39 +; CHECK-NEXT: %[[resi32:.+]] = phi double [ 0.000000e+00, %invertentry.beta.active ], [ %[[sumi25]], %[[forbodyi31]] ], [ %[[i76]], %fast.path.i29 ] +; CHECK-NEXT: %[[i82:.+]] = bitcast i64* %[[byrefconstantonei15]] to i8* +; CHECK: %[[i83:.+]] = bitcast i64* %[[byrefmatsizei18]] to i8* ; CHECK: %[[i84:.+]] = bitcast i8* %"beta'" to double* ; CHECK-NEXT: %[[i85:.+]] = load double, double* %[[i84]] -; CHECK-NEXT: %[[i86:.+]] = fadd fast double %[[i85]], %res.i32 +; CHECK-NEXT: %[[i86:.+]] = fadd fast double %[[i85]], %[[resi32]] ; CHECK-NEXT: store double %[[i86]], double* %[[i84]] ; CHECK-NEXT: br label %invertentry.beta.done -; CHECK: invertentry.beta.done: ; preds = %__enzyme_inner_prodd_64_.exit33, %invertentry.B.done +; CHECK: invertentry.beta.done: ; preds = %__enzyme_inner_prodd_64_.exit41, %invertentry.B.done ; CHECK-NEXT: br i1 %rt.inactive.C, label %invertentry.C.done, label %invertentry.C.active ; CHECK: invertentry.C.active: 
; preds = %invertentry.beta.done ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.09 -; CHECK-NEXT: %intcast.constant.int.010 = bitcast i64* %byref.constant.int.09 to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.011 -; CHECK-NEXT: %fpcast.constant.fp.1.012 = bitcast double* %byref.constant.fp.1.011 to i8* -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.013 -; CHECK-NEXT: %intcast.constant.int.014 = bitcast i64* %byref.constant.int.013 to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.010, i8* %fpcast.constant.fp.1.012, i8* %beta, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.014) +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.019 +; CHECK-NEXT: %intcast.constant.int.020 = bitcast i64* %byref.constant.int.019 to i8* +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.021 +; CHECK-NEXT: %fpcast.constant.fp.1.022 = bitcast double* %byref.constant.fp.1.021 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.020, i8* %fpcast.constant.fp.1.022, i8* %beta, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) ; CHECK-NEXT: br label %invertentry.C.done ; CHECK: invertentry.C.done: ; preds = %invertentry.C.active, %invertentry.beta.done diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll index 44cc10d80949..4ebbe1e6a322 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll @@ -78,11 +78,13 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: 
%byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint31:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint33:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.int.033 = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.fp.1.035 = alloca double, align 8 ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %n = alloca i64, align 16 @@ -243,10 +245,10 @@ entry: ; CHECK-NEXT: %cast.beta = bitcast double* %byref.beta to i8* ; CHECK-NEXT: store i64 %avld.ldc_unwrap, i64* %byref.ldc ; CHECK-NEXT: %cast.ldc = bitcast i64* %byref.ldc to i8* -; CHECK-NEXT: %[[i44:.+]] = bitcast double* %cache.A_unwrap to i8* -; CHECK-NEXT: %[[i45:.+]] = load i64, i64* %"iv'ac" -; CHECK-NEXT: %[[i46:.+]] = load i8*, i8** %m_p_cache, align 8, !invariant.group !6 -; CHECK-NEXT: %[[i47:.+]] = load i64, i64* %"iv'ac" +; CHECK-NEXT: %[[r35:.+]] = bitcast double* %cache.A_unwrap to i8* +; CHECK-NEXT: %[[r36:.+]] = load i64, i64* %"iv'ac" +; CHECK-NEXT: %[[r37:.+]] = load i8*, i8** %m_p_cache, align 8, !invariant.group !6 +; CHECK-NEXT: %[[r38:.+]] = load i64, i64* %"iv'ac" ; CHECK-NEXT: %n_p_unwrap = bitcast i64* %n to i8* ; CHECK-NEXT: %ld.transa = load i8, i8* %byref.transa ; CHECK-DAG: %[[r0:.+]] = icmp eq i8 %ld.transa, 110 @@ -270,29 +272,52 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: %loaded.trans30 = load i8, i8* %byref.transa -; 
CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans30, 78 -; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans30, 110 -; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] -; CHECK-DAG: %[[r21:.+]] = select i1 %[[r20]], i8* %[[i46]], i8* %cast.k -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %byref.transb, i8* %cast.k, i8* %n_p_unwrap, i8* %[[i46]], i8* %cast.alpha, i8* %[[i44]], i8* %[[r21]], i8* %"C'", i8* %cast.ldc, i8* %cast.beta, i8* %"B'", i8* %cast.ldb, i64 1, i64 1) -; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 -; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint31]] -; CHECK-NEXT: %intcast.constant.int.032 = bitcast i64* %[[byrefconstantint31]] to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %byref.transb, align 1 +; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %byref.transpose.transa, i8* %byref.constant.char.T +; CHECK-NEXT: %[[r59:.+]] = select i1 %[[r57]], i8* %byref.constant.char.N, i8* %byref.transa +; CHECK-NEXT: %[[r60:.+]] = select i1 %[[r57]], i8* %cast.k, i8* %n_p_unwrap +; CHECK-NEXT: %[[r61:.+]] = select i1 %[[r57]], i8* %n_p_unwrap, i8* %cast.k +; CHECK-NEXT: %loaded.trans30 = load i8, i8* %byref.transa, align 1 +; CHECK-NEXT: %[[r62:.+]] = icmp eq i8 %loaded.trans30, 78 +; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %loaded.trans30, 110 +; CHECK-NEXT: %[[r64:.+]] = or i1 %[[r63]], %[[r62]] +; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %[[r37]], i8* %cast.k +; CHECK-NEXT: %loaded.trans31 = load i8, i8* 
%byref.transa, align 1 +; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 %loaded.trans31, 78 +; CHECK-NEXT: %[[r67:.+]] = icmp eq i8 %loaded.trans31, 110 +; CHECK-NEXT: %[[r68:.+]] = or i1 %[[r67]], %[[r66]] +; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r68]], i8* %[[r37]], i8* %cast.k +; CHECK-NEXT: %ld.row.trans32 = load i8, i8* %byref.transb, align 1 +; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %ld.row.trans32, 110 +; CHECK-NEXT: %[[r71:.+]] = icmp eq i8 %ld.row.trans32, 78 +; CHECK-NEXT: %[[r72:.+]] = or i1 %[[r71]], %[[r70]] +; CHECK-NEXT: %[[r73:.+]] = select i1 %[[r72]], i8* %[[r35]], i8* %"C'" +; CHECK-NEXT: %[[r74:.+]] = select i1 %[[r72]], i8* %[[r69]], i8* %cast.ldc +; CHECK-NEXT: %[[r75:.+]] = select i1 %[[r72]], i8* %"C'", i8* %[[r35]] +; CHECK-NEXT: %[[r76:.+]] = select i1 %[[r72]], i8* %cast.ldc, i8* %[[r65]] +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint33]] -; CHECK-NEXT: %intcast.constant.int.034 = bitcast i64* %[[byrefconstantint33]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.032, i8* %fpcast.constant.fp.1.0, i8* %cast.beta, i8* %[[i46]], i8* %n_p_unwrap, i8* %"C'", i8* %cast.ldc, i8* %intcast.constant.int.034) -; CHECK-NEXT: %[[i68:.+]] = bitcast double* %cache.A_unwrap to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %[[i68]]) -; CHECK-NEXT: call void @free(i8* %[[i46]]) -; CHECK-NEXT: %[[i69:.+]] = load i64, i64* %"iv'ac" -; CHECK-NEXT: %[[i70:.+]] = icmp eq i64 %[[i69]], 0 -; CHECK-NEXT: %[[i71:.+]] = xor i1 %[[i70]], true -; CHECK-NEXT: br i1 %[[i70]], label %invertentry, label %incinvertloop +; CHECK-NEXT: call void @dgemm_64_(i8* %58, i8* %59, i8* %60, i8* %61, i8* %37, i8* %cast.alpha, i8* %73, i8* %74, i8* %75, i8* %76, i8* %fpcast.constant.fp.1.0, i8* %"B'", i8* %cast.ldb, i64 1, i64 1) +; 
CHECK-NEXT: store i8 71, i8* %byref.constant.char.G, align 1 +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0, align 4 +; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.033, align 4 +; CHECK-NEXT: %intcast.constant.int.034 = bitcast i64* %byref.constant.int.033 to i8* +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.035, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.036 = bitcast double* %byref.constant.fp.1.035 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.034, i8* %fpcast.constant.fp.1.036, i8* %cast.beta, i8* %37, i8* %n_p_unwrap, i8* %"C'", i8* %cast.ldc, i64 1) +; CHECK-NEXT: %[[r77:.+]] = bitcast double* %cache.A_unwrap to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %77) +; CHECK-NEXT: call void @free(i8* %37) +; CHECK-NEXT: %[[r78:.+]] = load i64, i64* %"iv'ac", align 4 +; CHECK-NEXT: %[[r79:.+]] = icmp eq i64 %[[r78]], 0 +; CHECK-NEXT: %[[r80:.+]] = xor i1 %[[r79]], true +; CHECK-NEXT: br i1 %[[r79]], label %invertentry, label %incinvertloop ; CHECK: incinvertloop: ; preds = %invertloop ; CHECK-NEXT: %[[i72:.+]] = load i64, i64* %"iv'ac" diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll index b3a6f1743944..80e2b91f724e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll @@ -157,11 +157,16 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8 +; CHECK-NEXT: %byref.constant.char.N = alloca i8 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double +; CHECK-NEXT: %byref.constant.char.T2 = alloca i8 +; CHECK-NEXT: %byref.constant.char.N3 = alloca i8 +; CHECK-NEXT: 
%byref.constant.fp.1.07 = alloca double ; CHECK-NEXT: %byref.constant.char.G = alloca i8 ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint1:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint2:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.int.09 = alloca i64 +; CHECK-NEXT: %byref.constant.fp.1.011 = alloca double ; CHECK-NEXT: %ldc = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %ldc to i8* ; CHECK-NEXT: %beta = alloca double, i64 1, align 16 @@ -209,7 +214,7 @@ entry: ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry -; CHECK-NEXT: %17 = bitcast double* %0 to i8* +; CHECK-NEXT: %[[r17:.+]] = bitcast double* %0 to i8* ; CHECK-NEXT: %ld.transa = load i8, i8* %malloccall ; CHECK-DAG: %[[r0:.+]] = icmp eq i8 %ld.transa, 110 ; CHECK-DAG: %[[r1:.+]] = select i1 %[[r0]], i8 116, i8 0 @@ -232,24 +237,67 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-DAG: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] -; CHECK-NEXT: %[[r21:.+]] = select i1 %[[r20]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) -; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 -; CHECK-NEXT: 
%intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint1]] -; CHECK-NEXT: %intcast.constant.int.02 = bitcast i64* %[[byrefconstantint1]] to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r34:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[r35:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r36:.+]] = or i1 %[[r35]], %[[r34]] +; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r36:.+]], i8* %byref.constant.char.N, i8* %malloccall1 +; CHECK-NEXT: %[[r38:.+]] = select i1 %[[r36:.+]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r36:.+]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[r40:.+]] = select i1 %[[r36:.+]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r41:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-NEXT: %[[r42:.+]] = icmp eq i8 %ld.row.trans1, 78 +; CHECK-NEXT: %[[r43:.+]] = or i1 %[[r42]], %[[r41]] +; CHECK-NEXT: %[[r44:.+]] = select i1 %[[r43:.+]], i8* %"C'", i8* %B +; CHECK-NEXT: %[[r45:.+]] = select i1 %[[r43:.+]], i8* %ldc_p, i8* %ldb_p +; CHECK-NEXT: %[[r46:.+]] = select i1 %[[r43:.+]], i8* %B, i8* %"C'" +; CHECK-NEXT: %[[r47:.+]] = select i1 %[[r43:.+]], i8* %ldb_p, i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint2]] -; CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %[[byrefconstantint2]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.02, i8* %fpcast.constant.fp.1.0, i8* %beta_p, 
i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.04) -; CHECK-NEXT: %[[i34:.+]] = bitcast double* %0 to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %[[i34]]) +; CHECK-NEXT: call void @dgemm_64_(i8* %37, i8* %38, i8* %39, i8* %40, i8* %n_p, i8* %alpha_p, i8* %44, i8* %45, i8* %46, i8* %47, i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T2, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N3, align 1 +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall1, align 1 +; CHECK-NEXT: %[[r48:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-NEXT: %[[r49:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %[[r50:.+]] = or i1 %[[r49]], %[[r48]] +; CHECK-NEXT: %[[r51:.+]] = select i1 %[[r50]], i8* %byref.transpose.transa, i8* %byref.constant.char.T2 +; CHECK-NEXT: %[[r52:.+]] = select i1 %[[r50]], i8* %byref.constant.char.N3, i8* %malloccall +; CHECK-NEXT: %[[r53:.+]] = select i1 %[[r50]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[r54:.+]] = select i1 %[[r50]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 +; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 +; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %loaded.trans5 = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %loaded.trans5, 78 +; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %loaded.trans5, 110 +; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %malloccall1, align 1 +; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %ld.row.trans6, 110 +; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %ld.row.trans6, 78 +; CHECK-NEXT: %[[r65:.+]] = or i1 %[[r64]], %[[r63]] +; CHECK-NEXT: %[[r66:.+]] = select i1 %[[r65]], i8* 
%[[r17]], i8* %"C'" +; CHECK-NEXT: %[[r67:.+]] = select i1 %[[r65]], i8* %[[r62]], i8* %ldc_p +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r65]], i8* %"C'", i8* %[[r17]] +; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r65]], i8* %ldc_p, i8* %[[r58]] +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.07, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.08 = bitcast double* %byref.constant.fp.1.07 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r51]], i8* %[[r52]], i8* %[[r53]], i8* %[[r54]], i8* %m_p, i8* %alpha_p, i8* %[[r66]], i8* %[[r67]], i8* %[[r68]], i8* %[[r69]], i8* %fpcast.constant.fp.1.08, i8* %"B'", i8* %ldb_p, i64 1, i64 1) +; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G, align 1 +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0, align 4 +; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.09, align 4 +; CHECK-NEXT: %intcast.constant.int.010 = bitcast i64* %byref.constant.int.09 to i8* +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.011, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.012 = bitcast double* %byref.constant.fp.1.011 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.010, i8* %fpcast.constant.fp.1.012, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) +; CHECK-NEXT: %[[r70:.+]] = bitcast double* %0 to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[r70]]) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index b15204f9b31d..755b366a68df 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -102,16 +102,16 @@ entry: ; CHECK-NEXT: store double 0.000000e+00, double* %15 ; CHECK-NEXT: store i64 4, i64* %16, align 16 ; 
CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[i17:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[i18:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-NEXT: %19 = or i1 %[[i18]], %[[i17]] +; CHECK-DAG: %[[ri17:.+]] = icmp eq i8 %loaded.trans, 78 +; CHECK-DAG: %[[ri18:.+]] = icmp eq i8 %loaded.trans, 110 +; CHECK-NEXT: %19 = or i1 %[[ri18]], %[[ri17]] ; CHECK-NEXT: %20 = select i1 %19, i8* %m_p, i8* %k_p ; CHECK-NEXT: %21 = select i1 %19, i8* %k_p, i8* %m_p -; CHECK-NEXT: %[[i22:.+]] = bitcast i8* %20 to i64* -; CHECK-NEXT: %[[i24:.+]] = load i64, i64* %[[i22]] -; CHECK-NEXT: %[[i23:.+]] = bitcast i8* %21 to i64* -; CHECK-NEXT: %[[i25:.+]] = load i64, i64* %[[i23]] -; CHECK-NEXT: %26 = mul i64 %[[i24]], %[[i25]] +; CHECK-NEXT: %[[ri22:.+]] = bitcast i8* %20 to i64* +; CHECK-NEXT: %[[ri24:.+]] = load i64, i64* %[[ri22]] +; CHECK-NEXT: %[[ri23:.+]] = bitcast i8* %21 to i64* +; CHECK-NEXT: %[[ri25:.+]] = load i64, i64* %[[ri23]] +; CHECK-NEXT: %26 = mul i64 %[[ri24]], %[[ri25]] ; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %26, 8 ; CHECK-NEXT: %malloccall10 = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* @@ -129,11 +129,16 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8 +; CHECK-NEXT: %byref.constant.char.N = alloca i8 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double +; CHECK-NEXT: %byref.constant.char.T2 = alloca i8 +; CHECK-NEXT: %byref.constant.char.N3 = alloca i8 +; CHECK-NEXT: %byref.constant.fp.1.07 = alloca double ; CHECK-NEXT: %byref.constant.char.G = alloca i8 ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint1:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint2:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.int.09 = 
alloca i64 +; CHECK-NEXT: %byref.constant.fp.1.011 = alloca double ; CHECK-NEXT: %ldc = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %ldc to i8* ; CHECK-NEXT: %beta = alloca double, i64 1, align 16 @@ -181,7 +186,7 @@ entry: ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry -; CHECK-NEXT: %17 = bitcast double* %0 to i8* +; CHECK-NEXT: %[[r17:.+]] = bitcast double* %0 to i8* ; CHECK-NEXT: %ld.transa = load i8, i8* %malloccall ; CHECK-DAG: %[[r0:.+]] = icmp eq i8 %ld.transa, 110 ; CHECK-DAG: %[[r1:.+]] = select i1 %[[r0]], i8 116, i8 0 @@ -204,24 +209,67 @@ entry: ; CHECK-NEXT: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[r19:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-NEXT: %[[r20:.+]] = or i1 %[[r19]], %[[r18]] -; CHECK-NEXT: %[[r21:.+]] = select i1 %[[r20]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %[[r21]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) -; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 -; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint1]] -; CHECK-NEXT: %intcast.constant.int.02 = bitcast i64* %[[byrefconstantint1]] to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; 
CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r34:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[r35:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r36:.+]] = or i1 %[[r35]], %[[r34]] +; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r36]], i8* %byref.constant.char.N, i8* %malloccall1 +; CHECK-NEXT: %[[r38:.+]] = select i1 %[[r36]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r36]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[r40:.+]] = select i1 %[[r36]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r41:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-NEXT: %[[r42:.+]] = icmp eq i8 %ld.row.trans1, 78 +; CHECK-NEXT: %[[r43:.+]] = or i1 %[[r42]], %[[r41]] +; CHECK-NEXT: %[[r44:.+]] = select i1 %[[r43]], i8* %"C'", i8* %B +; CHECK-NEXT: %[[r45:.+]] = select i1 %[[r43]], i8* %ldc_p, i8* %ldb_p +; CHECK-NEXT: %[[r46:.+]] = select i1 %[[r43]], i8* %B, i8* %"C'" +; CHECK-NEXT: %[[r47:.+]] = select i1 %[[r43]], i8* %ldb_p, i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint2]] -; CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %[[byrefconstantint2]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.02, i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.04) -; CHECK-NEXT: %[[i34:.+]] = bitcast double* %0 to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %[[i34]]) +; CHECK-NEXT: call void @dgemm_64_(i8* %37, i8* %38, i8* %39, i8* %40, i8* %n_p, i8* %alpha_p, i8* %44, i8* %45, i8* %46, i8* %47, i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 
1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T2, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N3, align 1 +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall1, align 1 +; CHECK-NEXT: %[[r48:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-NEXT: %[[r49:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %[[r50:.+]] = or i1 %[[r49]], %[[r48]] +; CHECK-NEXT: %[[r51:.+]] = select i1 %[[r50]], i8* %byref.transpose.transa, i8* %byref.constant.char.T2 +; CHECK-NEXT: %[[r52:.+]] = select i1 %[[r50]], i8* %byref.constant.char.N3, i8* %malloccall +; CHECK-NEXT: %[[r53:.+]] = select i1 %[[r50]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[r54:.+]] = select i1 %[[r50]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 +; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 +; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %loaded.trans5 = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %loaded.trans5, 78 +; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %loaded.trans5, 110 +; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %malloccall1, align 1 +; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %ld.row.trans6, 110 +; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %ld.row.trans6, 78 +; CHECK-NEXT: %[[r65:.+]] = or i1 %[[r64]], %[[r63]] +; CHECK-NEXT: %[[r66:.+]] = select i1 %[[r65]], i8* %[[r17]], i8* %"C'" +; CHECK-NEXT: %[[r67:.+]] = select i1 %[[r65]], i8* %[[r62]], i8* %ldc_p +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r65]], i8* %"C'", i8* %[[r17]] +; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r65]], i8* %ldc_p, i8* %[[r58]] +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.07, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.08 = bitcast 
double* %byref.constant.fp.1.07 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %51, i8* %52, i8* %53, i8* %54, i8* %m_p, i8* %alpha_p, i8* %66, i8* %67, i8* %68, i8* %69, i8* %fpcast.constant.fp.1.08, i8* %"B'", i8* %ldb_p, i64 1, i64 1) +; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G, align 1 +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0, align 4 +; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.09, align 4 +; CHECK-NEXT: %intcast.constant.int.010 = bitcast i64* %byref.constant.int.09 to i8* +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.011, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.012 = bitcast double* %byref.constant.fp.1.011 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.010, i8* %fpcast.constant.fp.1.012, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) +; CHECK-NEXT: %70 = bitcast double* %0 to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %70) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index 4561682b3ae9..709419e244bc 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -102,16 +102,16 @@ entry: ; CHECK-NEXT: store double 0.000000e+00, double* %15 ; CHECK-NEXT: store i64 4, i64* %16, align 16 ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[i17:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[i18:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-NEXT: %19 = or i1 %[[i18]], %[[i17]] +; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans, 78 +; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans, 110 +; CHECK-NEXT: %19 = or i1 %[[r18]], %[[r17]] ; CHECK-NEXT: 
%20 = select i1 %19, i8* %m_p, i8* %k_p ; CHECK-NEXT: %21 = select i1 %19, i8* %k_p, i8* %m_p -; CHECK-NEXT: %[[i22:.+]] = bitcast i8* %20 to i64* -; CHECK-NEXT: %[[i24:.+]] = load i64, i64* %[[i22]] -; CHECK-NEXT: %[[i23:.+]] = bitcast i8* %21 to i64* -; CHECK-NEXT: %[[i25:.+]] = load i64, i64* %[[i23]] -; CHECK-NEXT: %26 = mul i64 %[[i24]], %[[i25]] +; CHECK-NEXT: %[[r22:.+]] = bitcast i8* %20 to i64* +; CHECK-NEXT: %[[r24:.+]] = load i64, i64* %[[r22]] +; CHECK-NEXT: %[[r23:.+]] = bitcast i8* %21 to i64* +; CHECK-NEXT: %[[r25:.+]] = load i64, i64* %[[r23]] +; CHECK-NEXT: %26 = mul i64 %[[r24]], %[[r25]] ; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %26, 8 ; CHECK-NEXT: %malloccall10 = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* @@ -129,11 +129,16 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint1:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint2:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.T2 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N3 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.07 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.int.09 = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.fp.1.011 = alloca double, align 8 ; CHECK-NEXT: %ldc = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %ldc to i8* ; CHECK-NEXT: %beta = alloca double, i64 1, align 16 
@@ -181,47 +186,90 @@ entry: ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry -; CHECK-NEXT: %17 = bitcast double* %0 to i8* +; CHECK-NEXT: %[[r17]] = bitcast double* %0 to i8* ; CHECK-NEXT: %ld.transa = load i8, i8* %malloccall -; CHECK-DAG: %[[i18:.+]] = icmp eq i8 %ld.transa, 110 -; CHECK-DAG: %[[i19:.+]] = select i1 %[[i18]], i8 116, i8 0 -; CHECK-DAG: %[[i20:.+]] = icmp eq i8 %ld.transa, 78 -; CHECK-DAG: %[[i21:.+]] = select i1 %[[i20]], i8 84, i8 %[[i19]] -; CHECK-DAG: %[[i22:.+]] = icmp eq i8 %ld.transa, 116 -; CHECK-DAG: %[[i23:.+]] = select i1 %[[i22]], i8 110, i8 %[[i21]] -; CHECK-DAG: %[[i24:.+]] = icmp eq i8 %ld.transa, 84 -; CHECK-DAG: %[[i25:.+]] = select i1 %[[i24]], i8 78, i8 %[[i23]] -; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transa +; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %ld.transa, 110 +; CHECK-DAG: %[[r19:.+]] = select i1 %[[r18]], i8 116, i8 0 +; CHECK-DAG: %[[r20:.+]] = icmp eq i8 %ld.transa, 78 +; CHECK-DAG: %[[r21:.+]] = select i1 %[[r20]], i8 84, i8 %[[r19]] +; CHECK-DAG: %[[r22:.+]] = icmp eq i8 %ld.transa, 116 +; CHECK-DAG: %[[r23:.+]] = select i1 %[[r22]], i8 110, i8 %[[r21]] +; CHECK-DAG: %[[r24:.+]] = icmp eq i8 %ld.transa, 84 +; CHECK-DAG: %[[r25:.+]] = select i1 %[[r24]], i8 78, i8 %[[r23]] +; CHECK-NEXT: store i8 %[[r25]], i8* %byref.transpose.transa ; CHECK-NEXT: %ld.transb = load i8, i8* %malloccall1 -; CHECK-DAG: %[[i26:.+]] = icmp eq i8 %ld.transb, 110 -; CHECK-DAG: %[[i27:.+]] = select i1 %[[i26]], i8 116, i8 0 -; CHECK-DAG: %[[i28:.+]] = icmp eq i8 %ld.transb, 78 -; CHECK-DAG: %[[i29:.+]] = select i1 %[[i28]], i8 84, i8 %[[i27]] -; CHECK-DAG: %[[i30:.+]] = icmp eq i8 %ld.transb, 116 -; CHECK-DAG: %[[i31:.+]] = select i1 %[[i30]], i8 110, i8 %[[i29]] -; CHECK-DAG: %[[i32:.+]] = icmp eq i8 %ld.transb, 84 -; CHECK-DAG: %[[i33:.+]] = select i1 %[[i32]], i8 78, i8 %[[i31]] -; CHECK-NEXT: store i8 %[[i33]], i8* %byref.transpose.transb +; CHECK-DAG: %[[r26:.+]] = icmp eq i8 %ld.transb, 110 +; 
CHECK-DAG: %[[r27:.+]] = select i1 %[[r26]], i8 116, i8 0 +; CHECK-DAG: %[[r28:.+]] = icmp eq i8 %ld.transb, 78 +; CHECK-DAG: %[[r29:.+]] = select i1 %[[r28]], i8 84, i8 %[[r27]] +; CHECK-DAG: %[[r30:.+]] = icmp eq i8 %ld.transb, 116 +; CHECK-DAG: %[[r31:.+]] = select i1 %[[r30]], i8 110, i8 %[[r29]] +; CHECK-DAG: %[[r32:.+]] = icmp eq i8 %ld.transb, 84 +; CHECK-DAG: %[[r33:.+]] = select i1 %[[r32]], i8 78, i8 %[[r31]] +; CHECK-NEXT: store i8 %[[r33]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall -; CHECK-DAG: %[[i34:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[i35:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-DAG: %36 = or i1 %[[i35]], %[[i34]] -; CHECK-NEXT: %37 = select i1 %36, i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %malloccall1, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %17, i8* %37, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) -; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 -; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.01 -; CHECK-NEXT: %intcast.constant.int.02 = bitcast i64* %[[byrefconstantint1]] to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r34:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[r35:.+]] = 
icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r36:.+]] = or i1 %[[r35]], %[[r34]] +; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r36]], i8* %byref.constant.char.N, i8* %malloccall1 +; CHECK-NEXT: %[[r38:.+]] = select i1 %[[r36]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r36]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[r40:.+]] = select i1 %[[r36]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r41:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-NEXT: %[[r42:.+]] = icmp eq i8 %ld.row.trans1, 78 +; CHECK-NEXT: %[[r43:.+]] = or i1 %[[r42]], %[[r41]] +; CHECK-NEXT: %[[r44:.+]] = select i1 %[[r43]], i8* %"C'", i8* %B +; CHECK-NEXT: %[[r45:.+]] = select i1 %[[r43]], i8* %ldc_p, i8* %ldb_p +; CHECK-NEXT: %[[r46:.+]] = select i1 %[[r43]], i8* %B, i8* %"C'" +; CHECK-NEXT: %[[r47:.+]] = select i1 %[[r43]], i8* %ldb_p, i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint2]] -; CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %[[byrefconstantint2]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.02, i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.04) -; CHECK-NEXT: %[[i34:.+]] = bitcast double* %0 to i8* -; CHECK-NEXT: tail call void @free(i8* nonnull %[[i34]]) +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r37]], i8* %[[r38]], i8* %[[r39]], i8* %[[r40]], i8* %n_p, i8* %alpha_p, i8* %[[r44]], i8* %[[r45]], i8* %[[r46]], i8* %[[r47]], i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T2, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N3, align 1 +; CHECK-NEXT: %ld.row.trans4 = 
load i8, i8* %malloccall1, align 1 +; CHECK-NEXT: %[[r48:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-NEXT: %[[r49:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %[[r50:.+]] = or i1 %[[r49]], %[[r48]] +; CHECK-NEXT: %[[r51:.+]] = select i1 %[[r50]], i8* %byref.transpose.transa, i8* %byref.constant.char.T2 +; CHECK-NEXT: %[[r52:.+]] = select i1 %[[r50]], i8* %byref.constant.char.N3, i8* %malloccall +; CHECK-NEXT: %[[r53:.+]] = select i1 %[[r50]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[r54:.+]] = select i1 %[[r50]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 +; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 +; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %loaded.trans5 = load i8, i8* %malloccall, align 1 +; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %loaded.trans5, 78 +; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %loaded.trans5, 110 +; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %malloccall1, align 1 +; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %ld.row.trans6, 110 +; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %ld.row.trans6, 78 +; CHECK-NEXT: %[[r65:.+]] = or i1 %[[r64]], %[[r63]] +; CHECK-NEXT: %[[r66:.+]] = select i1 %[[r65]], i8* %[[r17]], i8* %"C'" +; CHECK-NEXT: %[[r67:.+]] = select i1 %[[r65]], i8* %[[r62]], i8* %ldc_p +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r65]], i8* %"C'", i8* %[[r17]] +; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r65]], i8* %ldc_p, i8* %[[r58]] +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.07, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.08 = bitcast double* %byref.constant.fp.1.07 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r51]], i8* %[[r52]], i8* %[[r53]], i8* %[[r54]], i8* %m_p, i8* %alpha_p, i8* %[[r66]], i8* 
%[[r67]], i8* %[[r68]], i8* %[[r69]], i8* %fpcast.constant.fp.1.08, i8* %"B'", i8* %ldb_p, i64 1, i64 1) +; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G, align 1 +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0, align 4 +; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.09, align 4 +; CHECK-NEXT: %intcast.constant.int.010 = bitcast i64* %byref.constant.int.09 to i8* +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.011, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.012 = bitcast double* %byref.constant.fp.1.011 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.010, i8* %fpcast.constant.fp.1.012, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) +; CHECK-NEXT: %[[r70:.+]] = bitcast double* %0 to i8* +; CHECK-NEXT: tail call void @free(i8* nonnull %[[r70]]) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index d6c145ee54ae..6b8599dedeb5 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -55,11 +55,16 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.T8 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N9 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.014 = alloca double, align 8 ; CHECK-NEXT: %byref.constant.char.G = alloca i8 ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 ; CHECK-NEXT: %[[int04:.+]] = 
alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[int05:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.fp.1.018 = alloca double ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -152,28 +157,85 @@ entry: ; CHECK-NEXT: store i8 %[[i41]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* + +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a38:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[a39:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[a40:.+]] = or i1 %[[a39]], %[[a38]] +; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %byref.constant.char.N, i8* %transb +; CHECK-NEXT: %[[a42:.+]] = select i1 %[[a40]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[a43:.+]] = select i1 %[[a40]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[a44:.+]] = select i1 %[[a40]], i8* %k_p, i8* %m_p + ; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transb ; CHECK-DAG: %[[i40:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[i42:.+]] = or i1 %[[i41]], %[[i40]] ; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %[[i25]], i8* %[[i43]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) + +; CHECK-NEXT: %loaded.trans6 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a49:.+]] = icmp eq i8 %loaded.trans6, 78 +; CHECK-NEXT: %[[a50:.+]] = icmp eq i8 %loaded.trans6, 110 +; CHECK-NEXT: %[[a51:.+]] = or i1 %[[a50]], %[[a49]] +; CHECK-NEXT: %[[a52:.+]] = select i1 %[[a51]], i8* %k_p, i8* 
%n_p +; CHECK-NEXT: ld.row.trans7 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a53:.+]] = icmp eq i8 %ld.row.trans7, 110 +; CHECK-NEXT: %[[a54:.+]] = icmp eq i8 %ld.row.trans7, 78 +; CHECK-NEXT: %[[a55:.+]] = or i1 %[[a54]], %[[a53]] +; CHECK-NEXT: %[[a56:.+]] = select i1 %[[a55]], i8* %"C'", i8* %[[i25]] +; CHECK-NEXT: %[[a57:.+]] = select i1 %[[a55]], i8* %ldc_p, i8* %[[i43]] +; CHECK-NEXT: %[[a58:.+]] = select i1 %[[a55]], i8* %[[i25]], i8* %"C'" +; CHECK-NEXT: %[[a59:.+]] = select i1 %[[a55]], i8* %[[a52]], i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* + +; CHECK-NEXT: call void @dgemm_64_(i8* %[[a41]], i8* %[[a42]], i8* %[[a43]], i8* %[[a44]], i8* %n_p, i8* %alpha_p, i8* %[[a56]], i8* %[[a57]], i8* %[[a58]], i8* %[[a59]], i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) + +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T8, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N9, align 1 +; CHECK-NEXT: %ld.row.trans10 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a60:.+]] = icmp eq i8 %ld.row.trans10, 110 +; CHECK-NEXT: %[[a61:.+]] = icmp eq i8 %ld.row.trans10, 78 +; CHECK-NEXT: %[[a62:.+]] = or i1 %[[a61]], %[[a60]] +; CHECK-NEXT: %[[a63:.+]] = select i1 %[[a62]], i8* %byref.transpose.transa, i8* %byref.constant.char.T8 +; CHECK-NEXT: %[[a64:.+]] = select i1 %[[a62]], i8* %byref.constant.char.N9, i8* %transa +; CHECK-NEXT: %[[a65:.+]] = select i1 %[[a62]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[a66:.+]] = select i1 %[[a62]], i8* %n_p, i8* %k_p + + ; CHECK-NEXT: %[[cachedtrans2:.+]] = load i8, i8* %transa ; CHECK-DAG: %[[i54:.+]] = icmp eq i8 %[[cachedtrans2]], 78 ; CHECK-DAG: %[[i55:.+]] = icmp eq i8 %[[cachedtrans2]], 110 ; CHECK-NEXT: %[[i56:.+]] = or i1 %[[i55]], %[[i54]] ; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %m_p, i8* %k_p -; CHECK-NEXT: call void @dgemm_64_(i8* 
%byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %[[i24]], i8* %[[i57]], i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) + +; CHECK-NEXT: %loaded.trans12 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a71:.+]] = icmp eq i8 %loaded.trans12, 78 +; CHECK-NEXT: %[[a72:.+]] = icmp eq i8 %loaded.trans12, 110 +; CHECK-NEXT: %[[a73:.+]] = or i1 %[[a72]], %[[a71]] +; CHECK-NEXT: %[[a74:.+]] = select i1 %[[a73]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans13 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a75:.+]] = icmp eq i8 %ld.row.trans13, 110 +; CHECK-NEXT: %[[a76:.+]] = icmp eq i8 %ld.row.trans13, 78 +; CHECK-NEXT: %[[a77:.+]] = or i1 %[[a76]], %[[a75]] +; CHECK-NEXT: %[[a78:.+]] = select i1 %[[a77]], i8* %[[i24]], i8* %"C'" +; CHECK-NEXT: %[[a79:.+]] = select i1 %[[a77]], i8* %[[a74]], i8* %ldc_p +; CHECK-NEXT: %[[a80:.+]] = select i1 %[[a77]], i8* %"C'", i8* %[[i24]] +; CHECK-NEXT: %[[a81:.+]] = select i1 %[[a77]], i8* %ldc_p, i8* %[[i57]] +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.014, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.015 = bitcast double* %byref.constant.fp.1.014 to i8* + +; CHECK-NEXT: call void @dgemm_64_(i8* %[[a63]], i8* %[[a64]], i8* %[[a65]], i8* %[[a66]], i8* %m_p, i8* %alpha_p, i8* %[[a78]], i8* %[[a79]], i8* %[[a80]], i8* %[[a81]], i8* %fpcast.constant.fp.1.015, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %[[intcast0:.+]] = bitcast i64* %byref.constant.int.0 to i8* ; CHECK-NEXT: store i64 0, i64* %[[int04]] ; CHECK-NEXT: %[[intcast08:.+]] = bitcast i64* %[[int04]] to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 -; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[int05]] -; CHECK-NEXT: %[[intcast010:.+]] = bitcast i64* %[[int05]] to 
i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %[[intcast0]], i8* %[[intcast08]], i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[intcast010]]) +; CHECK-NEXT: %[[fp19:.+]] = bitcast double* %byref.constant.fp.1.018 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %[[intcast0]], i8* %[[intcast08]], i8* %[[fp19]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) ; CHECK-NEXT: %[[free1:.+]] = bitcast double* %cache.A to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[free1]]) ; CHECK-NEXT: %[[free2:.+]] = bitcast double* %cache.B to i8* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index 51879c3537da..bfd9b9be2aba 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -55,11 +55,13 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8 ; CHECK-NEXT: %byref.constant.char.G = alloca i8 ; CHECK-NEXT: %byref.constant.int.0 = alloca i64 ; CHECK-NEXT: %[[byrefint03:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefint04:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.fp.1.06 = alloca double ; CHECK-NEXT: %transa = alloca i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -110,7 +112,7 @@ entry: ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry -; CHECK-NEXT: %10 = bitcast double* %cache.B to i8* +; CHECK-NEXT: %[[a10:.+]] = bitcast double* %cache.B to i8* ; CHECK-NEXT: %ld.transa = load i8, i8* %transa ; 
CHECK-DAG: %[[r0:.+]] = icmp eq i8 %ld.transa, 110 ; CHECK-DAG: %[[r1:.+]] = select i1 %[[r0]], i8 116, i8 0 @@ -133,22 +135,48 @@ entry: ; CHECK-DAG: store i8 %[[r15]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a27:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[a28:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[a29:.+]] = or i1 %[[a28]], %[[a27]] +; CHECK-NEXT: %[[a30:.+]] = select i1 %[[a29]], i8* %byref.constant.char.N, i8* %transb +; CHECK-NEXT: %[[a31:.+]] = select i1 %[[a29]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[a32:.+]] = select i1 %[[a29]], i8* %m_p, i8* %k_p +; CHECK-NEXT: %[[a33:.+]] = select i1 %[[a29]], i8* %k_p, i8* %m_p ; CHECK-NEXT: %loaded.trans1 = load i8, i8* %transb ; CHECK-DAG: %[[r16:.+]] = icmp eq i8 %loaded.trans1, 78 ; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans1, 110 ; CHECK-NEXT: %[[r18:.+]] = or i1 %[[r17]], %[[r16]] ; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %k_p, i8* %n_p -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %10, i8* %[[r19]], i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) + +; CHECK-NEXT: %loaded.trans2 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a38:.+]] = icmp eq i8 %loaded.trans2, 78 +; CHECK-NEXT: %[[a39:.+]] = icmp eq i8 %loaded.trans2, 110 +; CHECK-NEXT: %[[a40:.+]] = or i1 %[[a39]], %[[a38]] +; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a42:.+]] = icmp eq i8 %ld.row.trans3, 110 +; CHECK-NEXT: %[[a43:.+]] = icmp eq i8 %ld.row.trans3, 78 +; 
CHECK-NEXT: %[[a44:.+]] = or i1 %[[a43]], %[[a42]] +; CHECK-NEXT: %[[a45:.+]] = select i1 %[[a44]], i8* %"C'", i8* %[[a10]] +; CHECK-NEXT: %[[a46:.+]] = select i1 %[[a44]], i8* %ldc_p, i8* %[[r19]] +; CHECK-NEXT: %[[a47:.+]] = select i1 %[[a44]], i8* %[[a10]], i8* %"C'" +; CHECK-NEXT: %[[a48:.+]] = select i1 %[[a44]], i8* %[[a41]], i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* + +; CHECK-NEXT: call void @dgemm_64_(i8* %[[a30]], i8* %[[a31]], i8* %[[a32]], i8* %[[a33]], i8* %n_p, i8* %alpha_p, i8* %[[a45]], i8* %[[a46]], i8* %[[a47]], i8* %[[a48]], i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) + ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.02 -; CHECK-NEXT: %intcast.constant.int.03 = bitcast i64* %byref.constant.int.02 to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 -; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.04 ; CHECK-NEXT: %intcast.constant.int.05 = bitcast i64* %byref.constant.int.04 to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.03, i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.05) +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.06 +; CHECK-NEXT: %fpcast.constant.fp.1.07 = bitcast double* %byref.constant.fp.1.06 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.05, i8* %fpcast.constant.fp.1.07, i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, 
i64 1) ; CHECK-NEXT: %[[ret:.+]] = bitcast double* %cache.B to i8* ; CHECK-NEXT: tail call void @free(i8* nonnull %[[ret]]) ; CHECK-NEXT: ret void diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll deleted file mode 100644 index f07275c6df1e..000000000000 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_lacpy.ll +++ /dev/null @@ -1,126 +0,0 @@ -;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-lapack-copy=1 -S | FileCheck %s; fi -;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-lapack-copy=1 -S | FileCheck %s - -declare void @dgemm_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i8*, i8* nocapture readonly, i8* nocapture readonly, i8*, i8* nocapture readonly, i64, i64) - -define void @f(i8* %C, i8* %A, i8* %B) { -entry: - %transa = alloca i8, align 1 - %transb = alloca i8, align 1 - %m = alloca i64, align 16 - %m_p = bitcast i64* %m to i8* - %n = alloca i64, align 16 - %n_p = bitcast i64* %n to i8* - %k = alloca i64, align 16 - %k_p = bitcast i64* %k to i8* - %alpha = alloca double, align 16 - %alpha_p = bitcast double* %alpha to i8* - %lda = alloca i64, align 16 - %lda_p = bitcast i64* %lda to i8* - %ldb = alloca i64, align 16 - %ldb_p = bitcast i64* %ldb to i8* - %beta = alloca double, align 16 - %beta_p = bitcast double* %beta to i8* - %ldc = alloca i64, align 16 - %ldc_p = bitcast i64* %ldc to i8* - store i8 78, i8* %transa, align 1 - store i8 78, i8* %transb, align 1 - store i64 4, i64* %m, align 16 - store i64 4, i64* %n, align 16 - store i64 8, i64* %k, align 16 - store double 1.000000e+00, double* %alpha, align 16 - store i64 4, i64* %lda, align 16 - store i64 8, i64* %ldb, align 16 - store double 0.000000e+00, double* %beta - store i64 4, i64* %ldc, align 16 - call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* 
%k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) - ret void -} - -declare dso_local void @__enzyme_autodiff(...) - -define void @active(i8* %C, i8* %dC, i8* %A, i8* %dA, i8* %B, i8* %dB) { -entry: - call void (...) @__enzyme_autodiff(void (i8*,i8*,i8*)* @f, metadata !"enzyme_dup", i8* %C, i8* %dC, metadata !"enzyme_dup", i8* %A, i8* %dA, metadata !"enzyme_dup", i8* %B, i8* %dB) - ret void -} - -; CHECK: define internal void @diffef(i8* %C, i8* %"C'", i8* %A, i8* %"A'", i8* %B, i8* %"B'") -; CHECK-NEXT: entry: -; CHECK-NEXT: %ret = alloca double -; CHECK-NEXT: %byref.transpose.transa = alloca i8 -; CHECK-NEXT: %byref.transpose.transb = alloca i8 -; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %[[int00:.+]] = alloca i64 -; CHECK-NEXT: %[[int01:.+]] = alloca i64 -; CHECK-NEXT: %[[fp10:.+]] = alloca double -; CHECK-NEXT: %[[int02:.+]] = alloca i64 -; CHECK-NEXT: %transa = alloca i8, align 1 -; CHECK-NEXT: %transb = alloca i8, align 1 -; CHECK-NEXT: %m = alloca i64, align 16 -; CHECK-NEXT: %m_p = bitcast i64* %m to i8* -; CHECK-NEXT: %n = alloca i64, align 16 -; CHECK-NEXT: %n_p = bitcast i64* %n to i8* -; CHECK-NEXT: %k = alloca i64, align 16 -; CHECK-NEXT: %k_p = bitcast i64* %k to i8* -; CHECK-NEXT: %alpha = alloca double, align 16 -; CHECK-NEXT: %alpha_p = bitcast double* %alpha to i8* -; CHECK-NEXT: %lda = alloca i64, align 16 -; CHECK-NEXT: %lda_p = bitcast i64* %lda to i8* -; CHECK-NEXT: %ldb = alloca i64, align 16 -; CHECK-NEXT: %ldb_p = bitcast i64* %ldb to i8* -; CHECK-NEXT: %beta = alloca double, align 16 -; CHECK-NEXT: %beta_p = bitcast double* %beta to i8* -; CHECK-NEXT: %ldc = alloca i64, align 16 -; CHECK-NEXT: %ldc_p = bitcast i64* %ldc to i8* -; CHECK-NEXT: store i8 78, i8* %transa, align 1 -; CHECK-NEXT: store i8 78, i8* %transb, align 1 -; CHECK-NEXT: store i64 4, i64* %m, align 16 -; CHECK-NEXT: store i64 4, i64* %n, align 16 -; 
CHECK-NEXT: store i64 8, i64* %k, align 16 -; CHECK-NEXT: store double 1.000000e+00, double* %alpha, align 16 -; CHECK-NEXT: store i64 4, i64* %lda, align 16 -; CHECK-NEXT: store i64 8, i64* %ldb, align 16 -; CHECK-NEXT: store double 0.000000e+00, double* %beta -; CHECK-NEXT: store i64 4, i64* %ldc, align 16 -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) -; CHECK-NEXT: br label %invertentry - -; CHECK: invertentry: ; preds = %entry -; CHECK-NEXT: %ld.transa = load i8, i8* %transa -; CHECK-DAG: %[[i10:.+]] = icmp eq i8 %ld.transa, 110 -; CHECK-DAG: %[[i11:.+]] = select i1 %[[i10]], i8 116, i8 0 -; CHECK-DAG: %[[i12:.+]] = icmp eq i8 %ld.transa, 78 -; CHECK-DAG: %[[i13:.+]] = select i1 %[[i12]], i8 84, i8 %[[i11]] -; CHECK-DAG: %[[i14:.+]] = icmp eq i8 %ld.transa, 116 -; CHECK-DAG: %[[i15:.+]] = select i1 %[[i14]], i8 110, i8 %[[i13]] -; CHECK-DAG: %[[i16:.+]] = icmp eq i8 %ld.transa, 84 -; CHECK-DAG: %[[i17:.+]] = select i1 %[[i16]], i8 78, i8 %[[i15]] -; CHECK-NEXT: store i8 %[[i17]], i8* %byref.transpose.transa -; CHECK-NEXT: %ld.transb = load i8, i8* %transb -; CHECK-DAG: %[[i18:.+]] = icmp eq i8 %ld.transb, 110 -; CHECK-DAG: %[[i19:.+]] = select i1 %[[i18:.+]], i8 116, i8 0 -; CHECK-DAG: %[[i20:.+]] = icmp eq i8 %ld.transb, 78 -; CHECK-DAG: %[[i21:.+]] = select i1 %[[i20:.+]], i8 84, i8 %[[i19]] -; CHECK-DAG: %[[i22:.+]] = icmp eq i8 %ld.transb, 116 -; CHECK-DAG: %[[i23:.+]] = select i1 %[[i22:.+]], i8 110, i8 %[[i21]] -; CHECK-DAG: %[[i24:.+]] = icmp eq i8 %ld.transb, 84 -; CHECK-DAG: %[[i25:.+]] = select i1 %[[i24:.+]], i8 78, i8 %[[i23]] -; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb -; CHECK-NEXT: store i64 1, i64* %byref.int.one -; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %m_p, i8* %k_p, i8* 
%n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %m_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) -; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G -; CHECK-NEXT: store i64 0, i64* %[[int00]] -; CHECK-NEXT: %[[intcast00:.+]] = bitcast i64* %[[int00]] to i8* -; CHECK-NEXT: store i64 0, i64* %[[int01]] -; CHECK-NEXT: %[[intcast02:.+]] = bitcast i64* %[[int01]] to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %[[fp10]] -; CHECK-NEXT: %[[fpcast10:.+]] = bitcast double* %[[fp10]] to i8* -; CHECK-NEXT: store i64 0, i64* %[[int02]] -; CHECK-NEXT: %[[intcast04:.+]] = bitcast i64* %[[int02]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %[[intcast00]], i8* %[[intcast02]], i8* %[[fpcast10]], i8* %beta_p, i8* %m_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %[[intcast04]]) -; CHECK-NEXT: ret void -; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll index 5f5ff35f0659..a13387ca8002 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over.ll @@ -48,34 +48,39 @@ entry: ; CHECK: define internal void @diffef(i8* %C, i8* %"C'", i8* %A, i8* %"A'", i8* %B, i8* %"B'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %byref.m = alloca i64 -; CHECK-NEXT: %ret = alloca double -; CHECK-NEXT: %byref.transpose.transa = alloca i8 -; CHECK-NEXT: %byref.transpose.transb = alloca i8 -; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint1:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint2:.+]] = alloca i64 -; CHECK-NEXT: %transa = 
alloca i8, align 1 -; CHECK-NEXT: %transb = alloca i8, align 1 -; CHECK-NEXT: %m = alloca i64, align 16 -; CHECK-NEXT: %m_p = bitcast i64* %m to i8* -; CHECK-NEXT: %n = alloca i64, align 16 -; CHECK-NEXT: %n_p = bitcast i64* %n to i8* -; CHECK-NEXT: %k = alloca i64, align 16 -; CHECK-NEXT: %k_p = bitcast i64* %k to i8* -; CHECK-NEXT: %alpha = alloca double, align 16 -; CHECK-NEXT: %alpha_p = bitcast double* %alpha to i8* -; CHECK-NEXT: %lda = alloca i64, align 16 -; CHECK-NEXT: %lda_p = bitcast i64* %lda to i8* -; CHECK-NEXT: %ldb = alloca i64, align 16 -; CHECK-NEXT: %ldb_p = bitcast i64* %ldb to i8* -; CHECK-NEXT: %beta = alloca double, align 16 -; CHECK-NEXT: %beta_p = bitcast double* %beta to i8* -; CHECK-NEXT: %ldc = alloca i64, align 16 -; CHECK-NEXT: %ldc_p = bitcast i64* %ldc to i8* +; CHECK-DAG: %ret = alloca double +; CHECK-DAG: %byref.transpose.transa = alloca i8 +; CHECK-DAG: %byref.transpose.transb = alloca i8 +; CHECK-DAG: %byref.int.one = alloca i64 +; CHECK-DAG: %byref.constant.char.T = alloca i8, align 1 +; CHECK-DAG: %byref.constant.char.N = alloca i8, align 1 +; CHECK-DAG: %byref.constant.fp.1.0 = alloca double +; CHECK-DAG: %byref.constant.char.T2 = alloca i8, align 1 +; CHECK-DAG: %byref.constant.char.N3 = alloca i8, align 1 +; CHECK-DAG: %byref.constant.fp.1.06 = alloca double +; CHECK-DAG: %byref.constant.char.G = alloca i8 +; CHECK-DAG: %byref.constant.int.0 = alloca i64 +; CHECK-DAG: %byref.constant.int.08 = alloca i64, align 8 +; CHECK-DAG: %byref.m = alloca i64 +; CHECK-DAG: %byref.constant.fp.1.010 = alloca double +; CHECK-DAG: %transa = alloca i8, align 1 +; CHECK-DAG: %transb = alloca i8, align 1 +; CHECK-DAG: %m = alloca i64, align 16 +; CHECK-DAG: %m_p = bitcast i64* %m to i8* +; CHECK-DAG: %n = alloca i64, align 16 +; CHECK-DAG: %n_p = bitcast i64* %n to i8* +; CHECK-DAG: %k = alloca i64, align 16 +; CHECK-DAG: %k_p = bitcast i64* %k to i8* +; CHECK-DAG: %alpha = alloca double, align 16 +; CHECK-DAG: %alpha_p = bitcast double* 
%alpha to i8* +; CHECK-DAG: %lda = alloca i64, align 16 +; CHECK-DAG: %lda_p = bitcast i64* %lda to i8* +; CHECK-DAG: %ldb = alloca i64, align 16 +; CHECK-DAG: %ldb_p = bitcast i64* %ldb to i8* +; CHECK-DAG: %beta = alloca double, align 16 +; CHECK-DAG: %beta_p = bitcast double* %beta to i8* +; CHECK-DAG: %ldc = alloca i64, align 16 +; CHECK-DAG: %ldc_p = bitcast i64* %ldc to i8* ; CHECK-NEXT: store i8 78, i8* %transa, align 1 ; CHECK-NEXT: store i8 78, i8* %transb, align 1 ; CHECK-NEXT: store i64 4, i64* %m, align 16 @@ -87,13 +92,13 @@ entry: ; CHECK-NEXT: store double 0.000000e+00, double* %beta ; CHECK-NEXT: store i64 4, i64* %ldc, align 16 ; CHECK-NEXT: %pcld.m = bitcast i8* %m_p to i64* -; CHECK-NEXT: %avld.m = load i64, i64* %pcld.m +; CHECK-NEXT: %avld.m = load i64, i64* %pcld.m, align 4 ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) -; CHECK-NEXT: store i64 0, i64* %m +; CHECK-NEXT: store i64 0, i64* %m, align 16 ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry -; CHECK-NEXT: store i64 %avld.m, i64* %byref.m +; CHECK-NEXT: store i64 %avld.m, i64* %byref.m, align 4 ; CHECK-NEXT: %cast.m = bitcast i64* %byref.m to i8* ; CHECK-NEXT: %ld.transa = load i8, i8* %transa ; CHECK-DAG: %[[i10:.+]] = icmp eq i8 %ld.transa, 110 @@ -117,17 +122,56 @@ entry: ; CHECK-NEXT: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* 
%ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) + +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a16:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[a17:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[a18:.+]] = or i1 %[[a17]], %[[a16]] +; CHECK-NEXT: %[[a19:.+]] = select i1 %[[a18]], i8* %byref.constant.char.N, i8* %transb +; CHECK-NEXT: %[[a20:.+]] = select i1 %[[a18]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[a21:.+]] = select i1 %[[a18]], i8* %cast.m, i8* %k_p +; CHECK-NEXT: %[[a22:.+]] = select i1 %[[a18]], i8* %k_p, i8* %cast.m +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[a23:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-NEXT: %[[a24:.+]] = icmp eq i8 %ld.row.trans1, 78 +; CHECK-NEXT: %[[a25:.+]] = or i1 %[[a24]], %[[a23]] +; CHECK-NEXT: %[[a26:.+]] = select i1 %[[a25]], i8* %"C'", i8* %B +; CHECK-NEXT: %[[a27:.+]] = select i1 %[[a25]], i8* %ldc_p, i8* %ldb_p +; CHECK-NEXT: %[[a28:.+]] = select i1 %[[a25]], i8* %B, i8* %"C'" +; CHECK-NEXT: %[[a29:.+]] = select i1 %[[a25]], i8* %ldb_p, i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[a19]], i8* %[[a20]], i8* %[[a21]], i8* %[[a22]], i8* %n_p, i8* %alpha_p, i8* %[[a26]], i8* %[[a27]], i8* %[[a28]], i8* %[[a29]], i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T2, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N3, align 1 +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a30:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-NEXT: %[[a31:.+]] = icmp eq i8 %ld.row.trans4, 78 +; 
CHECK-NEXT: %[[a32:.+]] = or i1 %[[a31]], %[[a30]] +; CHECK-NEXT: %[[a33:.+]] = select i1 %[[a32]], i8* %byref.transpose.transa, i8* %byref.constant.char.T2 +; CHECK-NEXT: %[[a34:.+]] = select i1 %[[a32]], i8* %byref.constant.char.N3, i8* %transa +; CHECK-NEXT: %[[a35:.+]] = select i1 %[[a32]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[a36:.+]] = select i1 %[[a32]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans5 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[a37:.+]] = icmp eq i8 %ld.row.trans5, 110 +; CHECK-NEXT: %[[a38:.+]] = icmp eq i8 %ld.row.trans5, 78 +; CHECK-NEXT: %[[a39:.+]] = or i1 %[[a38]], %[[a37]] +; CHECK-NEXT: %[[a40:.+]] = select i1 %[[a39]], i8* %A, i8* %"C'" +; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a39]], i8* %lda_p, i8* %ldc_p +; CHECK-NEXT: %[[a42:.+]] = select i1 %[[a39]], i8* %"C'", i8* %A +; CHECK-NEXT: %[[a43:.+]] = select i1 %[[a39]], i8* %ldc_p, i8* %lda_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.06, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.07 = bitcast double* %byref.constant.fp.1.06 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[a33]], i8* %[[a34]], i8* %[[a35]], i8* %[[a36]], i8* %cast.m, i8* %alpha_p, i8* %[[a40]], i8* %[[a41]], i8* %[[a42]], i8* %[[a43]], i8* %fpcast.constant.fp.1.07, i8* %"B'", i8* %ldb_p, i64 1, i64 1) ; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint1]] -; CHECK-NEXT: %intcast.constant.int.02 = bitcast i64* %byref.constant.int.01 to i8* +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.08 +; CHECK-NEXT: %[[int02:.+]] = bitcast i64* %byref.constant.int.08 to i8* ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 -; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint2]] -; 
CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %byref.constant.int.03 to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.02, i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.04) +; CHECK-NEXT: %[[fp11:.+]] = bitcast double* %byref.constant.fp.1.010 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %[[int02]], i8* %[[fp11]], i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll index 1c1c0cea8c2c..5ff039dfa7f8 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_over_lacpy.ll @@ -53,11 +53,16 @@ entry: ; CHECK-NEXT: %byref.transpose.transa = alloca i8 ; CHECK-NEXT: %byref.transpose.transb = alloca i8 ; CHECK-NEXT: %byref.int.one = alloca i64 -; CHECK-NEXT: %byref.constant.char.G = alloca i8 -; CHECK-NEXT: %byref.constant.int.0 = alloca i64 -; CHECK-NEXT: %[[byrefconstantint1:.+]] = alloca i64 -; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double -; CHECK-NEXT: %[[byrefconstantint2:.+]] = alloca i64 +; CHECK-NEXT: %byref.constant.char.T = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.T2 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N3 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.fp.1.06 = alloca double, align 8 +; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.int.08 = alloca i64, align 8 +; CHECK-NEXT: %byref.constant.fp.1.010 = alloca double, align 8 ; CHECK-NEXT: %transa = alloca 
i8, align 1 ; CHECK-NEXT: %transb = alloca i8, align 1 ; CHECK-NEXT: %m = alloca i64, align 16 @@ -117,18 +122,56 @@ entry: ; CHECK-DAG: store i8 %[[i25]], i8* %byref.transpose.transb ; CHECK-NEXT: store i64 1, i64* %byref.int.one ; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8* -; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %byref.transpose.transb, i8* %cast.m, i8* %k_p, i8* %n_p, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %"A'", i8* %lda_p, i64 1, i64 1) -; CHECK-NEXT: call void @dgemm_64_(i8* %byref.transpose.transa, i8* %transb, i8* %k_p, i8* %n_p, i8* %cast.m, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"C'", i8* %ldc_p, i8* %beta_p, i8* %"B'", i8* %ldb_p, i64 1, i64 1) -; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G -; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0 -; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint1]] -; CHECK-NEXT: %intcast.constant.int.02 = bitcast i64* %[[byrefconstantint1]] to i8* -; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N, align 1 +; CHECK-NEXT: %ld.row.trans = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r16:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[r17:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[r18:.+]] = or i1 %[[r17]], %[[r16]] +; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %byref.constant.char.N, i8* %transb +; CHECK-NEXT: %[[r20:.+]] = select i1 %[[r18]], i8* %byref.transpose.transb, i8* %byref.constant.char.T +; CHECK-NEXT: %[[r21:.+]] = select i1 %[[r18]], i8* %cast.m, i8* %k_p +; CHECK-NEXT: %[[r22:.+]] = select i1 %[[r18]], i8* %k_p, i8* %cast.m +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %transa, align 1 +; CHECK-NEXT: %[[r23:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-NEXT: %[[r24:.+]] = icmp eq i8 
%ld.row.trans1, 78 +; CHECK-NEXT: %[[r25:.+]] = or i1 %[[r24]], %[[r23]] +; CHECK-NEXT: %[[r26:.+]] = select i1 %[[r25]], i8* %"C'", i8* %B +; CHECK-NEXT: %[[r27:.+]] = select i1 %[[r25]], i8* %ldc_p, i8* %ldb_p +; CHECK-NEXT: %[[r28:.+]] = select i1 %[[r25]], i8* %B, i8* %"C'" +; CHECK-NEXT: %[[r29:.+]] = select i1 %[[r25]], i8* %ldb_p, i8* %ldc_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* -; CHECK-NEXT: store i64 0, i64* %[[byrefconstantint2]] -; CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %[[byrefconstantint2]] to i8* -; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.02, i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i8* %intcast.constant.int.04) +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r19]], i8* %[[r20]], i8* %[[r21]], i8* %[[r22]], i8* %n_p, i8* %alpha_p, i8* %[[r26]], i8* %[[r27]], i8* %[[r28]], i8* %[[r29]], i8* %fpcast.constant.fp.1.0, i8* %"A'", i8* %lda_p, i64 1, i64 1) +; CHECK-NEXT: store i8 84, i8* %byref.constant.char.T2, align 1 +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N3, align 1 +; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r30:.+]] = icmp eq i8 %ld.row.trans4, 110 +; CHECK-NEXT: %[[r31:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %[[r32:.+]] = or i1 %[[r31]], %[[r30]] +; CHECK-NEXT: %[[r33:.+]] = select i1 %[[r32]], i8* %byref.transpose.transa, i8* %byref.constant.char.T2 +; CHECK-NEXT: %[[r34:.+]] = select i1 %[[r32]], i8* %byref.constant.char.N3, i8* %transa +; CHECK-NEXT: %[[r35:.+]] = select i1 %[[r32]], i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[r36:.+]] = select i1 %[[r32]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %ld.row.trans5 = load i8, i8* %transb, align 1 +; CHECK-NEXT: %[[r37:.+]] = icmp eq i8 %ld.row.trans5, 110 +; CHECK-NEXT: %[[r38:.+]] = icmp eq i8 %ld.row.trans5, 78 +; CHECK-NEXT: 
%[[r39:.+]] = or i1 %[[r38]], %[[r37]] +; CHECK-NEXT: %[[r40:.+]] = select i1 %[[r39]], i8* %A, i8* %"C'" +; CHECK-NEXT: %[[r41:.+]] = select i1 %[[r39]], i8* %lda_p, i8* %ldc_p +; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r39]], i8* %"C'", i8* %A +; CHECK-NEXT: %[[r43:.+]] = select i1 %[[r39]], i8* %ldc_p, i8* %lda_p +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.06, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.07 = bitcast double* %byref.constant.fp.1.06 to i8* +; CHECK-NEXT: call void @dgemm_64_(i8* %[[r33]], i8* %[[r34]], i8* %[[r35]], i8* %[[r36]], i8* %cast.m, i8* %alpha_p, i8* %[[r40]], i8* %[[r41]], i8* %[[r42]], i8* %[[r43]], i8* %fpcast.constant.fp.1.07, i8* %"B'", i8* %ldb_p, i64 1, i64 1) +; CHECK-NEXT: store i8 71, i8* %byref.constant.char.G, align 1 +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0, align 4 +; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* +; CHECK-NEXT: store i64 0, i64* %byref.constant.int.08, align 4 +; CHECK-NEXT: %intcast.constant.int.09 = bitcast i64* %byref.constant.int.08 to i8* +; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.010, align 8 +; CHECK-NEXT: %fpcast.constant.fp.1.011 = bitcast double* %byref.constant.fp.1.010 to i8* +; CHECK-NEXT: call void @dlascl_64_(i8* %byref.constant.char.G, i8* %intcast.constant.int.0, i8* %intcast.constant.int.09, i8* %fpcast.constant.fp.1.011, i8* %beta_p, i8* %cast.m, i8* %n_p, i8* %"C'", i8* %ldc_p, i64 1) ; CHECK-NEXT: ret void ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll index 7f70c9865b6e..4d3d6018dc64 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy.ll @@ -78,25 +78,25 @@ entry: ; CHECK-NEXT: store i64 2, i64* %9, align 16 ; CHECK-NEXT: store i64 1, i64* %10, align 16 ; CHECK-NEXT: %loaded.trans = load i8, 
i8* %malloccall -; CHECK-DAG: %[[i11:.+]] = icmp eq i8 %loaded.trans, 78 -; CHECK-DAG: %[[i12:.+]] = icmp eq i8 %loaded.trans, 110 -; CHECK-NEXT: %[[i13:.+]] = or i1 %[[i12]], %[[i11]] -; CHECK-NEXT: %[[i14:.+]] = select i1 %[[i13]], i8* %n_p, i8* %m_p -; CHECK-NEXT: %[[i15:.+]] = bitcast i8* %[[i14]] to i64* -; CHECK-NEXT: %[[i16:.+]] = load i64, i64* %[[i15]] -; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %[[i16]], 8 +; CHECK-DAG: %[[r11:.+]] = icmp eq i8 %loaded.trans, 78 +; CHECK-DAG: %[[r12:.+]] = icmp eq i8 %loaded.trans, 110 +; CHECK-NEXT: %[[r13:.+]] = or i1 %[[r12]], %[[r11]] +; CHECK-NEXT: %[[r14:.+]] = select i1 %[[r13]], i8* %n_p, i8* %m_p +; CHECK-NEXT: %[[r15:.+]] = bitcast i8* %[[r14]] to i64* +; CHECK-NEXT: %[[r16:.+]] = load i64, i64* %[[r15]] +; CHECK-NEXT: %mallocsize = mul nuw nsw i64 %[[r16]], 8 ; CHECK-NEXT: %malloccall6 = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.x = bitcast i8* %malloccall6 to double* ; CHECK-NEXT: store i64 1, i64* %byref. -; CHECK-NEXT: call void @dcopy_64_(i8* %[[i14]], i8* %x, i8* %incx_p, double* %cache.x, i64* %byref.) +; CHECK-NEXT: call void @dcopy_64_(i8* %[[r14]], i8* %x, i8* %incx_p, double* %cache.x, i64* %byref.) 
; CHECK-NEXT: %loaded.trans7 = load i8, i8* %malloccall -; CHECK-DAG: %[[i17:.+]] = icmp eq i8 %loaded.trans7, 78 -; CHECK-DAG: %[[i18:.+]] = icmp eq i8 %loaded.trans7, 110 -; CHECK-NEXT: %[[i19:.+]] = or i1 %[[i18]], %[[i17]] -; CHECK-NEXT: %[[i20:.+]] = select i1 %[[i19]], i8* %m_p, i8* %n_p -; CHECK-NEXT: %[[i21:.+]] = bitcast i8* %[[i20]] to i64* -; CHECK-NEXT: %[[i22:.+]] = load i64, i64* %[[i21]] -; CHECK-NEXT: %mallocsize8 = mul nuw nsw i64 %[[i22]], 8 +; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans7, 78 +; CHECK-DAG: %[[r18:.+]] = icmp eq i8 %loaded.trans7, 110 +; CHECK-NEXT: %[[r19:.+]] = or i1 %[[r18]], %[[r17]] +; CHECK-NEXT: %[[r20:.+]] = select i1 %[[r19]], i8* %m_p, i8* %n_p +; CHECK-NEXT: %[[r21:.+]] = bitcast i8* %[[r20]] to i64* +; CHECK-NEXT: %[[r22:.+]] = load i64, i64* %[[r21]] +; CHECK-NEXT: %mallocsize8 = mul nuw nsw i64 %[[r22]], 8 ; CHECK-NEXT: %malloccall9 = tail call noalias nonnull i8* @malloc(i64 %mallocsize8) ; CHECK-NEXT: %cache.y = bitcast i8* %malloccall9 to double* ; CHECK-NEXT: store i64 1, i64* %byref.10 @@ -119,7 +119,7 @@ entry: ; CHECK-DAG: %byref.constant.fp.0.0 = alloca double ; CHECK-NEXT: %byref.constant.int.1 = alloca i64 ; CHECK-NEXT: %byref.constant.int.17 = alloca i64 -; CHECK-NEXT: %byref.constant.char.N11 = alloca i8, align 1 +; CHECK-NEXT: %byref.constant.char.N10 = alloca i8, align 1 ; CHECK-NEXT: %[[byrefconstantfp1:.+]] = alloca double ; CHECK-NEXT: %incy = alloca i64, i64 1, align 16 ; CHECK-NEXT: %1 = bitcast i64* %incy to i8* @@ -194,44 +194,40 @@ entry: ; CHECK-DAG: %[[c1:.+]] = icmp eq i8 %ld.row.trans, 110 ; CHECK-DAG: %[[c2:.+]] = icmp eq i8 %ld.row.trans, 78 ; CHECK-NEXT: %[[c3:.+]] = or i1 %[[c2]], %[[c1]] -; CHECK-NEXT: %34 = select i1 %[[c3]], i8* %m_p, i8* %n_p +; CHECK-NEXT: %[[r34:.+]] = select i1 %[[c3]], i8* %m_p, i8* %n_p ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.17 ; CHECK-NEXT: %intcast.constant.int.18 = bitcast i64* %byref.constant.int.17 to i8* -; CHECK-NEXT: %35 = call fast 
double @ddot_64_(i8* %34, i8* %"y'", i8* %incy_p, i8* %19, i8* %intcast.constant.int.18) -; CHECK-NEXT: %36 = bitcast i8* %"alpha'" to double* -; CHECK-NEXT: %37 = load double, double* %36 -; CHECK-NEXT: %38 = fadd fast double %37, %35 +; CHECK-NEXT: %[[r35:.+]] = call fast double @ddot_64_(i8* %[[r34]], i8* %"y'", i8* %incy_p, i8* %19, i8* %intcast.constant.int.18) +; CHECK-NEXT: %[[r36:.+]] = bitcast i8* %"alpha'" to double* +; CHECK-NEXT: %[[r37:.+]] = load double, double* %[[r36]] +; CHECK-NEXT: %[[r38:.+]] = fadd fast double %[[r37]], %[[r35]] ; CHECK-NEXT: store double %38, double* %36 ; CHECK-NEXT: %ld.row.trans9 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[i39:.+]] = icmp eq i8 %ld.row.trans9, 110 -; CHECK-DAG: %[[i40:.+]] = icmp eq i8 %ld.row.trans9, 78 -; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]] -; CHECK-NEXT: %[[i42:.+]] = select i1 %41, i8* %"y'", i8* %20 -; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i41]], i8* %incy_p, i8* %intcast.int.one -; CHECK-NEXT: %ld.row.trans10 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[i47:.+]] = icmp eq i8 %ld.row.trans10, 110 -; CHECK-DAG: %[[i48:.+]] = icmp eq i8 %ld.row.trans10, 78 -; CHECK-NEXT: %[[i49:.+]] = or i1 %[[i48]], %[[i47]] -; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8* %20, i8* %"y'" -; CHECK-NEXT: %[[i54:.+]] = select i1 %[[i49]], i8* %intcast.int.one, i8* %incy_p -; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[i42]], i8* %[[i46]], i8* %[[i50]], i8* %[[i54]], i8* %"A'", i8* %lda_p) -; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N11, align 1 +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans9, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans9, 78 +; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] +; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %"y'", i8* %[[r20]] +; CHECK-NEXT: %[[r43:.+]] = select i1 %[[r41]], i8* %incy_p, i8* %intcast.int.one +; CHECK-NEXT: %[[r44:.+]] = select i1 %[[r41]], i8* %[[r20]], i8* %"y'" +; CHECK-NEXT: 
%[[r45:.+]] = select i1 %[[r41]], i8* %intcast.int.one, i8* %incy_p +; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[r42]], i8* %[[r43]], i8* %[[r44]], i8* %[[r45]], i8* %"A'", i8* %lda_p) +; CHECK-NEXT: store i8 78, i8* %byref.constant.char.N10, align 1 ; CHECK-NEXT: store double 1.000000e+00, double* %[[byrefconstantfp1]] ; CHECK-NEXT: %[[fpcast14:.+]] = bitcast double* %[[byrefconstantfp1]] to i8* ; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %[[fpcast14]], i8* %"x'", i8* %incx_p, i64 1) -; CHECK-NEXT: %ld.row.trans14 = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans14, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans14, 78 +; CHECK-NEXT: %ld.row.trans13 = load i8, i8* %malloccall +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans13, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans13, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] ; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p -; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %21, i8* %intcast.int.one) +; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %[[r21]], i8* %intcast.int.one) ; CHECK-NEXT: %[[r44:.+]] = bitcast i8* %"beta'" to double* ; CHECK-NEXT: %[[r45:.+]] = load double, double* %[[r44]] ; CHECK-NEXT: %[[r46:.+]] = fadd fast double %[[r45]], %[[r43]] ; CHECK-NEXT: store double %[[r46]], double* %[[r44]] -; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans15, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans15, 78 +; CHECK-NEXT: %ld.row.trans14 = load i8, i8* %malloccall +; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans14, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans14, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] ; 
CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll index 7337b69b7794..bc9db26018b3 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_blascpy_runtime_act.ll @@ -77,9 +77,12 @@ entry: ; CHECK-NEXT: store i64 4, i64* %8, align 16 ; CHECK-NEXT: store i64 2, i64* %9, align 16 ; CHECK-NEXT: store i64 1, i64* %10, align 16 -; CHECK-NEXT: %rt.inactive.A = icmp eq i8* %"A'", %A -; CHECK-NEXT: %rt.inactive.beta = icmp eq i8* %"beta'", %beta -; CHECK-NEXT: %rt.inactive.y = icmp eq i8* %"y'", %y +; CHECK-NEXT: %rt.tmp.inactive.A = icmp eq i8* %"A'", %A +; CHECK-NEXT: %rt.tmp.inactive.beta = icmp eq i8* %"beta'", %beta +; CHECK-NEXT: %rt.tmp.inactive.y = icmp eq i8* %"y'", %y +; CHECK-NEXT: %rt.inactive.A = or i1 %rt.tmp.inactive.A, %rt.tmp.inactive.y +; CHECK-NEXT: %rt.inactive.beta = or i1 %rt.tmp.inactive.beta, %rt.tmp.inactive.y +; CHECK-NEXT: %rt.inactive.y = or i1 %rt.tmp.inactive.y, %rt.tmp.inactive.y ; CHECK-NEXT: %loaded.trans = load i8, i8* %malloccall ; CHECK-DAG: %[[i11:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-DAG: %[[i12:.+]] = icmp eq i8 %loaded.trans, 110 @@ -144,9 +147,12 @@ entry: ; CHECK-NEXT: store i64 4, i64* %8, align 16 ; CHECK-NEXT: store i64 2, i64* %9, align 16 ; CHECK-NEXT: store i64 1, i64* %10, align 16 -; CHECK-NEXT: %rt.inactive.A = icmp eq i8* %"A'", %A -; CHECK-NEXT: %rt.inactive.beta = icmp eq i8* %"beta'", %beta -; CHECK-NEXT: %rt.inactive.y = icmp eq i8* %"y'", %y +; CHECK-NEXT: %rt.tmp.inactive.A = icmp eq i8* %"A'", %A +; CHECK-NEXT: %rt.tmp.inactive.beta = icmp eq i8* %"beta'", %beta +; CHECK-NEXT: %rt.tmp.inactive.y = icmp eq i8* %"y'", %y +; CHECK-NEXT: %rt.inactive.A = or i1 
%rt.tmp.inactive.A, %rt.tmp.inactive.y +; CHECK-NEXT: %rt.inactive.beta = or i1 %rt.tmp.inactive.beta, %rt.tmp.inactive.y +; CHECK-NEXT: %rt.inactive.y = or i1 %rt.tmp.inactive.y, %rt.tmp.inactive.y ; CHECK-NEXT: br label %invertentry ; CHECK: invertentry: ; preds = %entry @@ -177,12 +183,8 @@ entry: ; CHECK-NEXT: %[[r24:.+]] = or i1 %[[r23]], %[[r22]] ; CHECK-NEXT: %[[r25:.+]] = select i1 %[[r24]], i8* %"y'", i8* %11 ; CHECK-NEXT: %[[r29:.+]] = select i1 %[[r24]], i8* %incy_p, i8* %intcast.int.one -; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall, align 1 -; CHECK-DAG: %[[r30:.+]] = icmp eq i8 %ld.row.trans2, 110 -; CHECK-DAG: %[[r31:.+]] = icmp eq i8 %ld.row.trans2, 78 -; CHECK-NEXT: %[[r32:.+]] = or i1 %[[r31]], %[[r30]] -; CHECK-NEXT: %[[r33:.+]] = select i1 %[[r32]], i8* %11, i8* %"y'" -; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r32]], i8* %intcast.int.one, i8* %incy_p +; CHECK-NEXT: %[[r33:.+]] = select i1 %[[r24]], i8* %11, i8* %"y'" +; CHECK-NEXT: %[[r37:.+]] = select i1 %[[r24]], i8* %intcast.int.one, i8* %incy_p ; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha, i8* %[[r25]], i8* %[[r29]], i8* %[[r33]], i8* %[[r37]], i8* %"A'", i8* %lda_p) ; CHECK-NEXT: br label %invertentry.A.done @@ -190,9 +192,9 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.beta, label %invertentry.beta.done, label %invertentry.beta.active ; CHECK: invertentry.beta.active: ; preds = %invertentry.A.done -; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall -; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans3, 110 -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans3, 78 +; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall +; CHECK-DAG: %[[r39:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans2, 78 ; CHECK-NEXT: %[[r41:.+]] = or i1 %[[r40]], %[[r39]] ; CHECK-NEXT: %[[r42:.+]] = select i1 %[[r41]], i8* %m_p, i8* %n_p ; CHECK-NEXT: %[[r43:.+]] = call fast double @ddot_64_(i8* %[[r42]], i8* %"y'", i8* %incy_p, i8* %[[i12]], 
i8* %intcast.int.one) @@ -206,9 +208,9 @@ entry: ; CHECK-NEXT: br i1 %rt.inactive.y, label %invertentry.y.done, label %invertentry.y.active ; CHECK: invertentry.y.active: ; preds = %invertentry.beta.done -; CHECK-NEXT: %ld.row.trans4 = load i8, i8* %malloccall -; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans4, 110 -; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans4, 78 +; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %malloccall +; CHECK-DAG: %[[r47:.+]] = icmp eq i8 %ld.row.trans3, 110 +; CHECK-DAG: %[[r48:.+]] = icmp eq i8 %ld.row.trans3, 78 ; CHECK-NEXT: %[[r49:.+]] = or i1 %[[r48]], %[[r47]] ; CHECK-NEXT: %[[r50:.+]] = select i1 %[[r49]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r50]], i8* %beta, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll index cd853b3f207a..27f6f79c8f38 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_f_c_split_memcpy.ll @@ -197,20 +197,16 @@ entry: ; CHECK-DAG: %[[r26:.+]] = or i1 %[[r25]], %[[r24]] ; CHECK-NEXT: %[[r27:.+]] = select i1 %[[r26]], i8* %"y'", i8* %15 ; CHECK-NEXT: %[[r31:.+]] = select i1 %[[r26]], i8* %incy_p, i8* %intcast.int.one -; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall -; CHECK-DAG: %[[r32:.+]] = icmp eq i8 %ld.row.trans1, 110 -; CHECK-DAG: %[[r33:.+]] = icmp eq i8 %ld.row.trans1, 78 -; CHECK-DAG: %[[r34:.+]] = or i1 %[[r33]], %[[r32]] -; CHECK-NEXT: %[[r35:.+]] = select i1 %[[r34]], i8* %15, i8* %"y'" -; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r34]], i8* %intcast.int.one, i8* %incy_p +; CHECK-NEXT: %[[r35:.+]] = select i1 %[[r26]], i8* %15, i8* %"y'" +; CHECK-NEXT: %[[r39:.+]] = select i1 %[[r26]], i8* %intcast.int.one, i8* %incy_p ; CHECK-NEXT: call void @dger_64_(i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %[[r27]], i8* %[[r31]], i8* %[[r35]], i8* %[[r39]], i8* %"A'", i8* %lda_p) ; CHECK-NEXT: store i8 78, 
i8* %byref.constant.char.N, align 1 ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0 ; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8* ; CHECK-NEXT: call void @dgemv_64_(i8* %byref.transpose.transa, i8* %m_p, i8* %n_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %"y'", i8* %incy_p, i8* %fpcast.constant.fp.1.0, i8* %"x'", i8* %incx_p, i64 1) -; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %malloccall -; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans2, 110 -; CHECK-DAG: %[[r41:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %malloccall +; CHECK-DAG: %[[r40:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-DAG: %[[r41:.+]] = icmp eq i8 %ld.row.trans1, 78 ; CHECK-NEXT: %[[r42:.+]] = or i1 %[[r41]], %[[r40]] ; CHECK-NEXT: %[[r43:.+]] = select i1 %[[r42]], i8* %m_p, i8* %n_p ; CHECK-NEXT: call void @dscal_64_(i8* %[[r43]], i8* %beta_p, i8* %"y'", i8* %incy_p) diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 66b161bc635e..8203f577c4c2 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -13,914 +13,7 @@ // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // TODO: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include "test_utils.h" - -#include -#include -#include -#include - -template -class vector { - T* data; - size_t capacity; - size_t length; -public: - vector() : data(nullptr), capacity(0), length(0) {} - vector(const vector &prev) : data((T*)malloc(sizeof(T)*prev.capacity)), capacity(prev.capacity), length(prev.length) { - memcpy(data, prev.data, prev.length*sizeof(T)); - } - void operator=(const vector &prev) { - free(data); - data = 
(T*)malloc(sizeof(T)*prev.capacity); - capacity = prev.capacity; - length = prev.length; - memcpy(data, prev.data, prev.length*sizeof(T)); - } - // Don't destruct to avoi dso handle in global - // ~vector() { free(data); } - - void push_back(T v) { - if (length == capacity) { - size_t next = capacity == 0 ? 1 : (2 * capacity); - data = (T*)realloc(data, sizeof(T)*next); - capacity = next; - } - data[length] = v; - length++; - } - - T& operator[](size_t index) { - assert(index < length); - return data[index]; - } - - const T& operator[] (size_t index) const { - assert(index < length); - return data[index]; - } - - bool operator==(const vector& rhs) const { - if (length != rhs.length) return false; - for (size_t i=0; i &tr, std::string prefix="") { - printf("%sPrimal:\n", prefix.c_str()); - bool reverse = false; - for (size_t i=0; i -void assert_eq(std::string scope, std::string varName, int i, T expected, T real, BlasCall texpected, BlasCall rcall) { - if (expected == real) return; - printf("Failure on call %d var %s found ", i, varName.c_str()); - printty(expected); - printf(" expected "); - printty(real); - printf("\n"); - exit(1); -} - -void check_equiv(std::string scope, int i, BlasCall expected, BlasCall real) { -#define MAKEASSERT(name) assert_eq(scope, #name, i, expected.name, real.name, expected, real); - MAKEASSERT(inDerivative) - MAKEASSERT(type) - MAKEASSERT(pout_arg1); - MAKEASSERT(pin_arg1); - MAKEASSERT(pin_arg2); - MAKEASSERT(farg1); - MAKEASSERT(farg2); - MAKEASSERT(layout); - MAKEASSERT(targ1); - MAKEASSERT(targ2); - MAKEASSERT(iarg1); - MAKEASSERT(iarg2); - MAKEASSERT(iarg3); - MAKEASSERT(iarg4); - MAKEASSERT(iarg5); - MAKEASSERT(iarg6); -} - -vector calls; -vector foundCalls; - -extern "C" { - -// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-0/lascl.html -// technically LAPACKE_dlascl -__attribute__((noinline)) -void cblas_dlascl(char layout, char type, int KL, int KU, double cfrom, double cto, int M, int N, 
double* A, int lda) { - BlasCall call = {inDerivative, CallType::LASCL, - A, UNUSED_POINTER, UNUSED_POINTER, - cfrom, cto, - layout, - type, UNUSED_TRANS, - M, N, UNUSED_INT, lda, KL, KU}; - calls.push_back(call); -} - -__attribute__((noinline)) -double cblas_ddot(int N, double* X, int incx, double* Y, int incy) { - BlasCall call = {inDerivative, CallType::DOT, - UNUSED_POINTER, X, Y, - UNUSED_DOUBLE, UNUSED_DOUBLE, - UNUSED_TRANS, - UNUSED_TRANS, UNUSED_TRANS, - N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; - calls.push_back(call); - return 3.15+N; -} - -// Y += alpha * X -__attribute__((noinline)) -void cblas_daxpy(int N, double alpha, double* X, int incx, double* Y, int incy) { - BlasCall call = {inDerivative, CallType::AXPY, - Y, X, UNUSED_POINTER, - alpha, UNUSED_DOUBLE, - UNUSED_TRANS, - UNUSED_TRANS, UNUSED_TRANS, - N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; - calls.push_back(call); -} - -// Y = alpha * op(A) * X + beta * Y -__attribute__((noinline)) -void cblas_dgemv(char layout, char trans, int M, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { - BlasCall call = {inDerivative, CallType::GEMV, - Y, A, X, - alpha, beta, - layout, - trans, UNUSED_TRANS, - M, N, UNUSED_INT, lda, incx, incy}; - calls.push_back(call); -} - -// C = alpha * A^transA * B^transB + beta * C -__attribute__((noinline)) -void cblas_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { - calls.push_back((BlasCall){inDerivative, CallType::GEMM, - C, A, B, - alpha, beta, - layout, - transA, transB, - M, N, K, lda, ldb, ldc}); -} - -// X = alpha * X -__attribute__((noinline)) -void cblas_dscal(int N, double alpha, double* X, int incX) { - calls.push_back((BlasCall){inDerivative, CallType::SCAL, - X, UNUSED_POINTER, UNUSED_POINTER, - alpha, UNUSED_DOUBLE, - UNUSED_TRANS, - UNUSED_TRANS, UNUSED_TRANS, - N, UNUSED_INT, UNUSED_INT, 
incX, UNUSED_INT, UNUSED_INT}); -} - -// A = alpha * X * transpose(Y) + A -__attribute__((noinline)) -void cblas_dger(char layout, int M, int N, double alpha, double* X, int incX, double* Y, int incY, double* A, int lda) { - calls.push_back((BlasCall){inDerivative, CallType::GER, - A, X, Y, - alpha, UNUSED_DOUBLE, - layout, - UNUSED_TRANS, UNUSED_TRANS, - M, N, UNUSED_INT, incX, incY, lda}); -} - -__attribute__((noinline)) -void cblas_dcopy(int N, double* X, int incX, double* Y, int incY) { - - calls.push_back((BlasCall){inDerivative, CallType::COPY, - Y, X, UNUSED_POINTER, - alpha, UNUSED_DOUBLE, - UNUSED_TRANS, - UNUSED_TRANS, UNUSED_TRANS, - N, UNUSED_INT, UNUSED_INT, incX, incY, UNUSED_INT}); -} - -__attribute__((noinline)) -void cblas_dlacpy(char layout, char uplo, int M, int N, double* A, int lda, double* B, int ldb) { - calls.push_back((BlasCall){inDerivative, CallType::LACPY, - B, A, UNUSED_POINTER, - UNUSED_DOUBLE, UNUSED_DOUBLE, - layout, - uplo, UNUSED_TRANS, - M, N, UNUSED_INT, lda, ldb, UNUSED_INT}); -} - -__attribute__((noinline)) -void dlacpy(char *uplo, int *M, int* N, double* A, int *lda, double* B, int* ldb) { - cblas_dlacpy(CblasColMajor, *uplo, *M, *N, A, *lda, B, *ldb); -} - -} - -enum class ValueType { - Matrix, - Vector -}; -struct BlasInfo { - void* ptr; - ValueType ty; - int vec_length; - int vec_increment; - char mat_layout; - int mat_rows; - int mat_cols; - int mat_ld; - BlasInfo (void* v_ptr, int length, int increment) { - ptr = v_ptr; - ty = ValueType::Vector; - vec_length = length; - vec_increment = increment; - mat_layout = '@'; - mat_rows = -1; - mat_cols = -1; - mat_ld = -1; - } - BlasInfo (void* v_ptr, char layout, int rows, int cols, int ld) { - ptr = v_ptr; - ty = ValueType::Matrix; - vec_length = -1; - vec_increment = -1; - mat_layout = layout; - mat_rows = rows; - mat_cols = cols; - mat_ld = ld; - } - BlasInfo () { - ptr = (void*)(-1); - ty = ValueType::Matrix; - vec_length = -1; - vec_increment = -1; - mat_layout = -1; - 
mat_rows = -1; - mat_cols = -1; - mat_ld = -1; - } -}; - -BlasInfo pointer_to_index(void* v, BlasInfo inputs[6]) { - if (v == A || v == dA) return inputs[0]; - if (v == B || v == dB) return inputs[1]; - if (v == C || v == dC) return inputs[2]; - for (int i=3; i<6; i++) - if (inputs[i].ptr == v) - return inputs[i]; - assert(0 && " illegal pointer to invert"); -} - -void checkVector(BlasInfo info, std::string vecname, int length, int increment, std::string test, BlasCall rcall, const vector & trace) { - if (info.ty != ValueType::Vector) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s is not a vector\n", vecname.c_str()); - exit(1); - } - if (info.vec_length != length) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s length must be ", vecname.c_str()); - printty(info.vec_length); - printf(" found "); - printty(length); - printf("\n"); - exit(1); - } - if (info.vec_increment != increment) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s increment must be ", vecname.c_str()); - printty(info.vec_increment); - printf(" found "); - printty(increment); - printf("\n"); - exit(1); - } -} - -void checkMatrix(BlasInfo info, std::string matname, char layout, int rows, int cols, int ld, std::string test, BlasCall rcall, const vector & trace) { - if (info.ty != ValueType::Matrix) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s is not a matrix\n", matname.c_str()); - exit(1); - } - if (info.mat_layout != layout) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s layout must be ", matname.c_str()); - printty(info.mat_layout); - printf(" found layout="); - printty(layout); - printf("\n"); - exit(1); - } - if 
(info.mat_rows != rows) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s rows must be ", matname.c_str()); - printty(info.mat_rows); - printf(" found "); - printty(rows); - printf("\n"); - exit(1); - } - if (info.mat_cols != cols) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s cols must be ", matname.c_str()); - printty(info.mat_cols); - printf(" found "); - printty(cols); - printf("\n"); - exit(1); - } - if (info.mat_ld != ld) { - printf("Error in test %s, invalid memory\n", test.c_str()); - printTrace(trace); - printcall(rcall); - printf(" Input %s leading dimension rows must be ", test.c_str()); - printty(info.mat_ld); - printf(" found "); - printty(ld); - printf("\n"); - exit(1); - } -} - -void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test, const vector & trace) { - switch (rcall.type) { - return; - case CallType::LASCL: { - auto A = pointer_to_index(rcall.pout_arg1, inputs); - - auto layout = rcall.layout; - auto type = rcall.targ1; - auto KL = rcall.iarg5; - auto KU = rcall.iarg6; - auto cfrom = rcall.farg1; - auto cto = rcall.farg2; - - auto M = rcall.iarg1; - auto N = rcall.iarg2; - auto lda = rcall.iarg4; - - // = 'G': A is a full matrix. 
- assert(type == 'G'); - - // A is an m-by-n matrix - checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); - return; - } - case CallType::AXPY: { - auto Y = pointer_to_index(rcall.pout_arg1, inputs); - - auto X = pointer_to_index(rcall.pin_arg1, inputs); - - auto alpha = rcall.farg1; - - auto N = rcall.iarg1; - auto incX = rcall.iarg4; - auto incY = rcall.iarg5; - - checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); - checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); - return; - } - case CallType::DOT: { - auto X = pointer_to_index(rcall.pin_arg1, inputs); - auto Y = pointer_to_index(rcall.pin_arg2, inputs); - - auto N = rcall.iarg1; - auto incX = rcall.iarg4; - auto incY = rcall.iarg5; - - checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); - checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); - return; - } - case CallType::GEMV:{ - // Y = alpha * op(A) * X + beta * Y - auto Y = pointer_to_index(rcall.pout_arg1, inputs); - auto A = pointer_to_index(rcall.pin_arg1, inputs); - auto X = pointer_to_index(rcall.pin_arg2, inputs); - - auto layout = rcall.layout; - auto trans_char = rcall.targ1; - auto trans = !(trans_char == 'N' || trans_char == 'n'); - auto M = rcall.iarg1; - auto N =rcall.iarg2; - auto alpha = rcall.farg1; - auto lda = rcall.iarg4; - auto incX = rcall.iarg5; - auto beta = rcall.farg2; - auto incY = rcall.iarg6; - - // A is an m-by-n matrix - checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); - - // if no trans, X must be N otherwise must be M - // From https://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_gadd421a107a488d524859b4a64c1901a9.html - // X is DOUBLE PRECISION array, dimension at least - // ( 1 + ( n - 1 )*abs( INCX ) ) when TRANS = 'N' or 'n' - // and at least - // ( 1 + ( m - 1 )*abs( INCX ) ) otherwise. - // Before entry, the incremented array X must contain the - // vector x. 
- auto Xlen = trans ? M : N; - checkVector(X, "X", /*len=*/Xlen, /*inc=*/incX, test, rcall, trace); - - // if no trans, Y must be M otherwise must be N - auto Ylen = trans ? N : M; - checkVector(Y, "Y", /*len=*/Ylen, /*inc=*/incY, test, rcall, trace); - - return; - } - case CallType::GEMM:{ - // C = alpha * A^transA * B^transB + beta * C - auto C = pointer_to_index(rcall.pout_arg1, inputs); - auto A = pointer_to_index(rcall.pin_arg1, inputs); - auto B = pointer_to_index(rcall.pin_arg2, inputs); - - auto layout = rcall.layout; - auto transA_char = rcall.targ1; - auto transA = !(transA_char == 'N' || transA_char == 'n'); - auto transB_char = rcall.targ2; - auto transB = !(transB_char == 'N' || transB_char == 'n'); - auto M = rcall.iarg1; - auto N = rcall.iarg2; - auto K = rcall.iarg3; - auto alpha = rcall.farg1; - auto lda = rcall.iarg4; - auto ldb = rcall.iarg5; - auto beta = rcall.farg2; - auto ldc = rcall.iarg6; - - // From https://www.netlib.org/lapack/explore-html/d1/d54/group__double__blas__level3_gaeda3cbd99c8fb834a60a6412878226e1.html - /* - M is INTEGER - On entry, M specifies the number of rows of the matrix - op( A ) and of the matrix C. M must be at least zero. - N is INTEGER - On entry, N specifies the number of columns of the matrix - op( B ) and the number of columns of the matrix C. N must be - at least zero. - K is INTEGER - On entry, K specifies the number of columns of the matrix - op( A ) and the number of rows of the matrix op( B ). K must - be at least zero. - LDA is INTEGER - On entry, LDA specifies the first dimension of A as declared - in the calling (sub) program. When TRANSA = 'N' or 'n' then - LDA must be at least max( 1, m ), otherwise LDA must be at - least max( 1, k ). - */ - checkMatrix(A, "A", layout, /*rows=*/(!transA) ? M : K, /*cols=*/(!transA) ? K : M, /*ld=*/lda, test, rcall, trace); - checkMatrix(B, "B", layout, /*rows=*/(!transB) ? K : N, /*cols=*/(!transB) ? 
N : K, /*ld=*/ldb, test, rcall, trace); - checkMatrix(C, "C", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldc, test, rcall, trace); - return; - } - - case CallType::SCAL: { - auto N = rcall.iarg1; - auto alpha = rcall.farg1; - auto X = pointer_to_index(rcall.pout_arg1, inputs); - auto incX = rcall.iarg4; - checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); - return; - } - case CallType::GER: { - // A = alpha * X * transpose(Y) + A - auto A = pointer_to_index(rcall.pout_arg1, inputs); - auto X = pointer_to_index(rcall.pin_arg1, inputs); - auto Y = pointer_to_index(rcall.pin_arg2, inputs); - - auto layout = rcall.layout; - auto M = rcall.iarg1; - auto N = rcall.iarg2; - auto alpha = rcall.farg1; - auto incX = rcall.iarg4; - auto incY = rcall.iarg5; - auto incA = rcall.iarg6; - - // From https://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_ga458222e01b4d348e9b52b9343d52f828.html - // x is an m element vector, y is an n element - // vector and A is an m by n matrix. 
- checkVector(X, "X", /*len=*/M, /*inc=*/incX, test, rcall, trace); - checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); - checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/incA, test, rcall, trace); - return; - } - case CallType::COPY: { - auto Y = pointer_to_index(rcall.pout_arg1, inputs); - auto X = pointer_to_index(rcall.pin_arg1, inputs); - - auto N = rcall.iarg1; - auto incX = rcall.iarg4; - auto incY = rcall.iarg5; - checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); - checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); - return; - } - case CallType::LACPY: { - auto B = pointer_to_index(rcall.pout_arg1, inputs); - auto A = pointer_to_index(rcall.pin_arg1, inputs); - - auto layout = rcall.layout; - auto uplo = rcall.targ1; - auto M = rcall.iarg1; - auto N = rcall.iarg2; - auto lda = rcall.iarg4; - auto ldb = rcall.iarg5; - checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); - checkMatrix(B, "B", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldb, test, rcall, trace); - return; - } - default: printf("UNKNOWN CALL (%d)", (int)rcall.type); return; - } -} - -void checkMemoryTrace(BlasInfo inputs[6], std::string test, const vector & trace) { - for (size_t i=0; i +void __enzyme_autodiff(void*, T...); + +void my_dgemv(char layout, char trans, int M, int N, double alpha, double* __restrict__ A, int lda, double* __restrict__ X, int incx, double beta, double* __restrict__ Y, int incy) { + cblas_dgemv(layout, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy); + inDerivative = true; +} + +void ow_dgemv(char layout, char trans, int M, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { + cblas_dgemv(layout, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy); + inDerivative = true; +} + +double my_ddot(int N, double* __restrict__ X, int incx, double* __restrict__ Y, int incy) { + double res = cblas_ddot(N, X, incx, Y, incy); + inDerivative = true; 
+ return res; +} + +void my_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* __restrict__ A, int lda, double* __restrict__ B, int ldb, double beta, double* __restrict__ C, int ldc) { + cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + inDerivative = true; +} + + +static void dotTests() { + + std::string Test = "DOT active both "; + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, N, incA), + /*B*/ BlasInfo(B, N, incB), + /*C*/ BlasInfo(C, M, incC), + BlasInfo(), + BlasInfo(), + BlasInfo(), + }; + init(); + my_ddot(N, A, incA, B, incB); + + // Check memory of primal on own. + checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void*) my_ddot, + enzyme_const, N, + enzyme_dup, A, dA, + enzyme_const, incA, + enzyme_dup, B, B, + enzyme_const, incB); + foundCalls = calls; + init(); + + my_ddot(N, A, incA, B, incB); + + inDerivative = true; + + cblas_daxpy(N, 1.0, B, incB, dA, incA); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); +} + +static void gemvTests() { + // N means normal matrix, T means transposed + for (char layout : { CblasRowMajor, CblasColMajor }) { + for (char transA : {'N', 'n', 'T', 't'}) { + + { + + bool trans = !(transA == 'N' || transA == 'n'); + std::string Test = "GEMV active A, C [runtime const B] "; + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, M, N, lda), + /*B*/ BlasInfo(B, trans ? M : N, incB), + /*C*/ BlasInfo(C, trans ? 
N : M, incC), + BlasInfo(), + BlasInfo(), + BlasInfo() + }; + init(); + my_dgemv(layout, transA, M, N, alpha, A, lda, B, incB, beta, C, incC); + + assert(calls.size() == 1); + assert(calls[0].inDerivative == false); + assert(calls[0].type == CallType::GEMV); + assert(calls[0].pout_arg1 == C); + assert(calls[0].pin_arg1 == A); + assert(calls[0].pin_arg2 == B); + assert(calls[0].farg1 == alpha); + assert(calls[0].farg2 == beta); + assert(calls[0].layout == layout); + assert(calls[0].targ1 == transA); + assert(calls[0].targ2 == UNUSED_TRANS); + assert(calls[0].iarg1 == M); + assert(calls[0].iarg2 == N); + assert(calls[0].iarg3 == UNUSED_INT); + assert(calls[0].iarg4 == lda); + assert(calls[0].iarg5 == incB); + assert(calls[0].iarg6 == incC); + + // Check memory of primal on own. + checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void*) my_dgemv, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, M, + enzyme_const, N, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, B, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + my_dgemv(layout, transA, M, N, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + // dC = alpha * X * transpose(Y) + A + cblas_dger(layout, M, N, alpha, trans ? B : dC, trans ? incB : incC, trans ? dC : B, trans ? incC : incB, dA, lda); + // dY = beta * dY + cblas_dscal(trans ? N : M, beta, dC, incC); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). 
+ checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + Test = "GEMV active B, C [Runtime Const A]"; + + init(); + __enzyme_autodiff((void*) my_dgemv, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, M, + enzyme_const, N, + enzyme_const, alpha, + enzyme_dup, A, A, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + my_dgemv(layout, transA, M, N, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + // dB = alpha * trans(A) * dC + dB + cblas_dgemv(layout, transpose(transA), M, N, alpha, A, lda, dC, incC, 1.0, dB, incB); + + // dY = beta * dY + cblas_dscal(trans ? N : M, beta, dC, incC); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + + Test = "GEMV active A B [Runtime Const C]"; + + init(); + __enzyme_autodiff((void*) my_dgemv, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, M, + enzyme_const, N, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, C, + enzyme_const, incC); + foundCalls = calls; + init(); + + my_dgemv(layout, transA, M, N, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + // dC = alpha * X * transpose(Y) + A + // cblas_dger(layout, M, N, alpha, trans ? B : dC, trans ? incB : incC, trans ? dC : B, trans ? incC : incB, dA, lda); + + // dB = alpha * trans(A) * dC + dB + // cblas_dgemv(layout, transpose(transA), M, N, alpha, A, lda, dC, incC, 1.0, dB, incB); + + // dY = beta * dY + // cblas_dscal(trans ? 
N : M, beta, dC, incC); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + inputs[4] = BlasInfo(); + inputs[5] = BlasInfo(); + } + + + } + } +} + +static void gemmTests() { + // N means normal matrix, T means transposed + for (char layout : { CblasRowMajor, CblasColMajor }) { + for (char transA : {'N', 'n', 'T', 't'}) { + for (char transB : {'N', 'n', 'T', 't'}) { + + { + + bool transA_bool = !(transA == 'N' || transA == 'n'); + bool transB_bool = !(transB == 'N' || transB == 'n'); + std::string Test = "GEMM"; + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, transA_bool ? K : M, transA_bool ? M : K, lda), + /*B*/ BlasInfo(B, layout, transB_bool ? N : K , transB_bool ? K : N, incB), + /*C*/ BlasInfo(C, layout, M, N, incC), + BlasInfo(), + BlasInfo(), + BlasInfo() + }; + + printf("TODO GEMM runtime activity\n"); + init(); + my_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + assert(calls.size() == 1); + assert(calls[0].inDerivative == false); + assert(calls[0].type == CallType::GEMM); + assert(calls[0].pout_arg1 == C); + assert(calls[0].pin_arg1 == A); + assert(calls[0].pin_arg2 == B); + assert(calls[0].farg1 == alpha); + assert(calls[0].farg2 == beta); + assert(calls[0].layout == layout); + assert(calls[0].targ1 == transA); + assert(calls[0].targ2 == transB); + assert(calls[0].iarg1 == M); + assert(calls[0].iarg2 == N); + assert(calls[0].iarg3 == K); + assert(calls[0].iarg4 == lda); + assert(calls[0].iarg5 == incB); + assert(calls[0].iarg6 == incC); + + // Check memory of primal on own. 
+ checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void*) my_dgemm, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, transB, + enzyme_const, M, + enzyme_const, N, + enzyme_const, K, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + + my_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + // dA = + my_dgemm(layout, + transA_bool ? transB : 'N', + transA_bool ? 'T' : transpose(transB), + transA_bool ? K : M, + transA_bool ? M : K, + N, + alpha, + transA_bool ? B : dC, + transA_bool ? incB : incC, + transA_bool ? dC : B, + transA_bool ? incC : incB, + 1.0, dA, lda); + + // dB = + my_dgemm(layout, + transB_bool ? 'T' : transpose(transA), + transB_bool ? transA : 'N', //transB, + transB_bool ? N : K, + transB_bool ? K : N, + M, + alpha, + transB_bool ? dC : A, + transB_bool ? incC : lda, + transB_bool ? A : dC, + transB_bool ? lda : incC, + 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC /*, extra 0*/ ); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). 
+ checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + } + + + } + } + } +} + +int main() { + + dotTests(); + + gemvTests(); + + gemmTests(); + +} diff --git a/enzyme/test/Integration/blasinfra.h b/enzyme/test/Integration/blasinfra.h new file mode 100644 index 000000000000..de2deb3f60da --- /dev/null +++ b/enzyme/test/Integration/blasinfra.h @@ -0,0 +1,941 @@ + +#include +#include +#include +#include + +template +class vector { + T* data; + size_t capacity; + size_t length; +public: + vector() : data(nullptr), capacity(0), length(0) {} + vector(const vector &prev) : data((T*)malloc(sizeof(T)*prev.capacity)), capacity(prev.capacity), length(prev.length) { + memcpy(data, prev.data, prev.length*sizeof(T)); + } + void operator=(const vector &prev) { + free(data); + data = (T*)malloc(sizeof(T)*prev.capacity); + capacity = prev.capacity; + length = prev.length; + memcpy(data, prev.data, prev.length*sizeof(T)); + } + // Don't destruct to avoi dso handle in global + // ~vector() { free(data); } + + void push_back(T v) { + if (length == capacity) { + size_t next = capacity == 0 ? 
1 : (2 * capacity); + data = (T*)realloc(data, sizeof(T)*next); + capacity = next; + } + data[length] = v; + length++; + } + + T& operator[](size_t index) { + assert(index < length); + return data[index]; + } + + const T& operator[] (size_t index) const { + assert(index < length); + return data[index]; + } + + bool operator==(const vector& rhs) const { + if (length != rhs.length) return false; + for (size_t i=0; i &tr, std::string prefix="") { + printf("%sPrimal:\n", prefix.c_str()); + bool reverse = false; + for (size_t i=0; i +void assert_eq(std::string scope, std::string varName, int i, T expected, T real, BlasCall texpected, BlasCall rcall) { + if (expected == real) return; + printf("Failure on call %d var %s found ", i, varName.c_str()); + printty(expected); + printf(" expected "); + printty(real); + printf("\n"); + exit(1); +} + +void check_equiv(std::string scope, int i, BlasCall expected, BlasCall real) { +#define MAKEASSERT(name) assert_eq(scope, #name, i, expected.name, real.name, expected, real); + MAKEASSERT(inDerivative) + MAKEASSERT(type) + MAKEASSERT(pout_arg1); + MAKEASSERT(pin_arg1); + MAKEASSERT(pin_arg2); + MAKEASSERT(farg1); + MAKEASSERT(farg2); + MAKEASSERT(layout); + MAKEASSERT(targ1); + MAKEASSERT(targ2); + MAKEASSERT(iarg1); + MAKEASSERT(iarg2); + MAKEASSERT(iarg3); + MAKEASSERT(iarg4); + MAKEASSERT(iarg5); + MAKEASSERT(iarg6); +} + +vector calls; +vector foundCalls; + +extern "C" { + +// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-0/lascl.html +// technically LAPACKE_dlascl +__attribute__((noinline)) +void cblas_dlascl(char layout, char type, int KL, int KU, double cfrom, double cto, int M, int N, double* A, int lda) { + BlasCall call = {inDerivative, CallType::LASCL, + A, UNUSED_POINTER, UNUSED_POINTER, + cfrom, cto, + layout, + type, UNUSED_TRANS, + M, N, UNUSED_INT, lda, KL, KU}; + calls.push_back(call); +} + +__attribute__((noinline)) +double cblas_ddot(int N, double* X, int incx, double* Y, int incy) 
{ + BlasCall call = {inDerivative, CallType::DOT, + UNUSED_POINTER, X, Y, + UNUSED_DOUBLE, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; + calls.push_back(call); + return 3.15+N; +} + +// Y += alpha * X +__attribute__((noinline)) +void cblas_daxpy(int N, double alpha, double* X, int incx, double* Y, int incy) { + BlasCall call = {inDerivative, CallType::AXPY, + Y, X, UNUSED_POINTER, + alpha, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incx, incy, UNUSED_INT}; + calls.push_back(call); +} + +// Y = alpha * op(A) * X + beta * Y +__attribute__((noinline)) +void cblas_dgemv(char layout, char trans, int M, int N, double alpha, double* A, int lda, double* X, int incx, double beta, double* Y, int incy) { + BlasCall call = {inDerivative, CallType::GEMV, + Y, A, X, + alpha, beta, + layout, + trans, UNUSED_TRANS, + M, N, UNUSED_INT, lda, incx, incy}; + calls.push_back(call); +} + +// C = alpha * A^transA * B^transB + beta * C +__attribute__((noinline)) +void cblas_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { + calls.push_back((BlasCall){inDerivative, CallType::GEMM, + C, A, B, + alpha, beta, + layout, + transA, transB, + M, N, K, lda, ldb, ldc}); +} + +// X = alpha * X +__attribute__((noinline)) +void cblas_dscal(int N, double alpha, double* X, int incX) { + calls.push_back((BlasCall){inDerivative, CallType::SCAL, + X, UNUSED_POINTER, UNUSED_POINTER, + alpha, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incX, UNUSED_INT, UNUSED_INT}); +} + +// A = alpha * X * transpose(Y) + A +__attribute__((noinline)) +void cblas_dger(char layout, int M, int N, double alpha, double* X, int incX, double* Y, int incY, double* A, int lda) { + calls.push_back((BlasCall){inDerivative, CallType::GER, + A, X, Y, + alpha, 
UNUSED_DOUBLE, + layout, + UNUSED_TRANS, UNUSED_TRANS, + M, N, UNUSED_INT, incX, incY, lda}); +} + +__attribute__((noinline)) +void cblas_dcopy(int N, double* X, int incX, double* Y, int incY) { + + calls.push_back((BlasCall){inDerivative, CallType::COPY, + Y, X, UNUSED_POINTER, + alpha, UNUSED_DOUBLE, + UNUSED_TRANS, + UNUSED_TRANS, UNUSED_TRANS, + N, UNUSED_INT, UNUSED_INT, incX, incY, UNUSED_INT}); +} + +__attribute__((noinline)) +void cblas_dlacpy(char layout, char uplo, int M, int N, double* A, int lda, double* B, int ldb) { + calls.push_back((BlasCall){inDerivative, CallType::LACPY, + B, A, UNUSED_POINTER, + UNUSED_DOUBLE, UNUSED_DOUBLE, + layout, + uplo, UNUSED_TRANS, + M, N, UNUSED_INT, lda, ldb, UNUSED_INT}); +} + +__attribute__((noinline)) +void dlacpy(char *uplo, int *M, int* N, double* A, int *lda, double* B, int* ldb) { + cblas_dlacpy(CblasColMajor, *uplo, *M, *N, A, *lda, B, *ldb); +} + +} + +enum class ValueType { + Matrix, + Vector +}; +struct BlasInfo { + void* ptr; + ValueType ty; + int vec_length; + int vec_increment; + char mat_layout; + int mat_rows; + int mat_cols; + int mat_ld; + BlasInfo (void* v_ptr, int length, int increment) { + ptr = v_ptr; + ty = ValueType::Vector; + vec_length = length; + vec_increment = increment; + mat_layout = '@'; + mat_rows = -1; + mat_cols = -1; + mat_ld = -1; + } + BlasInfo (void* v_ptr, char layout, int rows, int cols, int ld) { + ptr = v_ptr; + ty = ValueType::Matrix; + vec_length = -1; + vec_increment = -1; + mat_layout = layout; + mat_rows = rows; + mat_cols = cols; + mat_ld = ld; + } + BlasInfo () { + ptr = (void*)(-1); + ty = ValueType::Matrix; + vec_length = -1; + vec_increment = -1; + mat_layout = -1; + mat_rows = -1; + mat_cols = -1; + mat_ld = -1; + } +}; + +BlasInfo pointer_to_index(void* v, BlasInfo inputs[6]) { + if (v == A || v == dA) return inputs[0]; + if (v == B || v == dB) return inputs[1]; + if (v == C || v == dC) return inputs[2]; + for (int i=3; i<6; i++) + if (inputs[i].ptr == v) + return 
inputs[i]; + assert(0 && " illegal pointer to invert"); +} + +void checkVector(BlasInfo info, std::string vecname, int length, int increment, std::string test, BlasCall rcall, const vector & trace) { + if (info.ty != ValueType::Vector) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s is not a vector\n", vecname.c_str()); + exit(1); + } + if (info.vec_length != length) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s length must be ", vecname.c_str()); + printty(info.vec_length); + printf(" found "); + printty(length); + printf("\n"); + exit(1); + } + if (info.vec_increment != increment) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s increment must be ", vecname.c_str()); + printty(info.vec_increment); + printf(" found "); + printty(increment); + printf("\n"); + exit(1); + } +} + +void checkMatrix(BlasInfo info, std::string matname, char layout, int rows, int cols, int ld, std::string test, BlasCall rcall, const vector & trace) { + if (info.ty != ValueType::Matrix) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s is not a matrix\n", matname.c_str()); + exit(1); + } + if (info.mat_layout != layout) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s layout must be ", matname.c_str()); + printty(info.mat_layout); + printf(" found layout="); + printty(layout); + printf("\n"); + exit(1); + } + if (info.mat_rows != rows) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s rows must be ", matname.c_str()); + printty(info.mat_rows); + printf(" found "); + printty(rows); + printf("\n"); + exit(1); + } + if (info.mat_cols != cols) 
{ + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s cols must be ", matname.c_str()); + printty(info.mat_cols); + printf(" found "); + printty(cols); + printf("\n"); + exit(1); + } + if (info.mat_ld != ld) { + printf("Error in test %s, invalid memory\n", test.c_str()); + printTrace(trace); + printcall(rcall); + printf(" Input %s leading dimension rows must be ", test.c_str()); + printty(info.mat_ld); + printf(" found "); + printty(ld); + printf("\n"); + exit(1); + } +} + +void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test, const vector & trace) { + switch (rcall.type) { + return; + case CallType::LASCL: { + auto A = pointer_to_index(rcall.pout_arg1, inputs); + + auto layout = rcall.layout; + auto type = rcall.targ1; + auto KL = rcall.iarg5; + auto KU = rcall.iarg6; + auto cfrom = rcall.farg1; + auto cto = rcall.farg2; + + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto lda = rcall.iarg4; + + // = 'G': A is a full matrix. 
+ assert(type == 'G'); + + // A is an m-by-n matrix + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + return; + } + case CallType::AXPY: { + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + + auto X = pointer_to_index(rcall.pin_arg1, inputs); + + auto alpha = rcall.farg1; + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } + case CallType::DOT: { + auto X = pointer_to_index(rcall.pin_arg1, inputs); + auto Y = pointer_to_index(rcall.pin_arg2, inputs); + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } + case CallType::GEMV:{ + // Y = alpha * op(A) * X + beta * Y + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg2, inputs); + + auto layout = rcall.layout; + auto trans_char = rcall.targ1; + auto trans = !(trans_char == 'N' || trans_char == 'n'); + auto M = rcall.iarg1; + auto N =rcall.iarg2; + auto alpha = rcall.farg1; + auto lda = rcall.iarg4; + auto incX = rcall.iarg5; + auto beta = rcall.farg2; + auto incY = rcall.iarg6; + + // A is an m-by-n matrix + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + + // if no trans, X must be N otherwise must be M + // From https://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_gadd421a107a488d524859b4a64c1901a9.html + // X is DOUBLE PRECISION array, dimension at least + // ( 1 + ( n - 1 )*abs( INCX ) ) when TRANS = 'N' or 'n' + // and at least + // ( 1 + ( m - 1 )*abs( INCX ) ) otherwise. + // Before entry, the incremented array X must contain the + // vector x. 
+ auto Xlen = trans ? M : N; + checkVector(X, "X", /*len=*/Xlen, /*inc=*/incX, test, rcall, trace); + + // if no trans, Y must be M otherwise must be N + auto Ylen = trans ? N : M; + checkVector(Y, "Y", /*len=*/Ylen, /*inc=*/incY, test, rcall, trace); + + return; + } + case CallType::GEMM:{ + // C = alpha * A^transA * B^transB + beta * C + auto C = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + auto B = pointer_to_index(rcall.pin_arg2, inputs); + + auto layout = rcall.layout; + auto transA_char = rcall.targ1; + auto transA = !(transA_char == 'N' || transA_char == 'n'); + auto transB_char = rcall.targ2; + auto transB = !(transB_char == 'N' || transB_char == 'n'); + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto K = rcall.iarg3; + auto alpha = rcall.farg1; + auto lda = rcall.iarg4; + auto ldb = rcall.iarg5; + auto beta = rcall.farg2; + auto ldc = rcall.iarg6; + + // From https://www.netlib.org/lapack/explore-html/d1/d54/group__double__blas__level3_gaeda3cbd99c8fb834a60a6412878226e1.html + /* + M is INTEGER + On entry, M specifies the number of rows of the matrix + op( A ) and of the matrix C. M must be at least zero. + N is INTEGER + On entry, N specifies the number of columns of the matrix + op( B ) and the number of columns of the matrix C. N must be + at least zero. + K is INTEGER + On entry, K specifies the number of columns of the matrix + op( A ) and the number of rows of the matrix op( B ). K must + be at least zero. + LDA is INTEGER + On entry, LDA specifies the first dimension of A as declared + in the calling (sub) program. When TRANSA = 'N' or 'n' then + LDA must be at least max( 1, m ), otherwise LDA must be at + least max( 1, k ). + */ + checkMatrix(A, "A", layout, /*rows=*/(!transA) ? M : K, /*cols=*/(!transA) ? K : M, /*ld=*/lda, test, rcall, trace); + checkMatrix(B, "B", layout, /*rows=*/(!transB) ? K : N, /*cols=*/(!transB) ? 
N : K, /*ld=*/ldb, test, rcall, trace); + checkMatrix(C, "C", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldc, test, rcall, trace); + return; + } + + case CallType::SCAL: { + auto N = rcall.iarg1; + auto alpha = rcall.farg1; + auto X = pointer_to_index(rcall.pout_arg1, inputs); + auto incX = rcall.iarg4; + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + return; + } + case CallType::GER: { + // A = alpha * X * transpose(Y) + A + auto A = pointer_to_index(rcall.pout_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg1, inputs); + auto Y = pointer_to_index(rcall.pin_arg2, inputs); + + auto layout = rcall.layout; + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto alpha = rcall.farg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + auto incA = rcall.iarg6; + + // From https://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_ga458222e01b4d348e9b52b9343d52f828.html + // x is an m element vector, y is an n element + // vector and A is an m by n matrix. 
+ checkVector(X, "X", /*len=*/M, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/incA, test, rcall, trace); + return; + } + case CallType::COPY: { + auto Y = pointer_to_index(rcall.pout_arg1, inputs); + auto X = pointer_to_index(rcall.pin_arg1, inputs); + + auto N = rcall.iarg1; + auto incX = rcall.iarg4; + auto incY = rcall.iarg5; + checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace); + checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace); + return; + } + case CallType::LACPY: { + auto B = pointer_to_index(rcall.pout_arg1, inputs); + auto A = pointer_to_index(rcall.pin_arg1, inputs); + + auto layout = rcall.layout; + auto uplo = rcall.targ1; + auto M = rcall.iarg1; + auto N = rcall.iarg2; + auto lda = rcall.iarg4; + auto ldb = rcall.iarg5; + checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall, trace); + checkMatrix(B, "B", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldb, test, rcall, trace); + return; + } + default: printf("UNKNOWN CALL (%d)", (int)rcall.type); return; + } +} + +void checkMemoryTrace(BlasInfo inputs[6], std::string test, const vector & trace) { + for (size_t i=0; i 0); auto argTypeMap = pattern.getArgTypeMap(); bool lv23 = pattern.isBLASLevel2or3(); + const auto mutArgSet = pattern.getMutableArgs(); os << " const bool byRef = blas.prefix == \"\";\n"; os << " Value *cacheval = nullptr;\n\n"; @@ -256,10 +257,39 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) { << " auto shadow_" << name << " = gutils->invertPointerM(orig_" << name << ", BuilderZ);\n" << " rt_inactive_" << name << " = BuilderZ.CreateICmpEQ(shadow_" - << name << ", arg_" << name << ", (Twine(\"rt.inactive.\") + \"" << name - << "\").str());\n" + << name << ", arg_" << name << ", \"rt.tmp.inactive.\" \"" << name + << "\");\n" << " }\n"; } + // Blas functions return one float XOR modify one output arg. 
+ // If we have runtimeActivity and the output arg is inactive, + // we don't need to do anything here and can return early. + if (mutArgSet.size() == 1) { + for (auto pos : mutArgSet) { + auto name = nameVec[pos]; + os << " Value *rt_inactive_out = nullptr;\n"; + os << " if (active_" << name << ") {\n" + << " rt_inactive_out = rt_inactive_" << name << ";\n" + << " } else {\n" + << " rt_inactive_out = " + "ConstantInt::getTrue(BuilderZ.getContext());\n" + << " }\n"; + break; + } + for (size_t i = 0; i < actArgs.size(); i++) { + auto name = nameVec[actArgs[i]]; + // floats are passed by calue, except of the Fortran Abi (byRef) + auto ty = argTypeMap.lookup(actArgs[i]); + os << " if ("; + if (ty == ArgType::fp) + os << "byRef && "; + os << "active_" << name << ") {\n" + << " rt_inactive_" << name << " = BuilderZ.CreateOr(rt_inactive_" + << name << ", rt_inactive_out, \"rt.inactive.\" \"" << name << "\");\n" + << " }\n"; + } + } + os << " }\n"; bool hasFP = false; @@ -918,7 +948,13 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, return; } if (Def->getName() == "Concat") { - os << "concat_values("; + os << "concat_values<"; + for (size_t i = 0; i < Dag->getNumArgs(); i++) { + if (i != 0) + os << ", "; + os << "ArrayRef"; + } + os << ">("; for (size_t i = 0; i < Dag->getNumArgs(); i++) { if (i != 0) os << ", "; @@ -1108,7 +1144,7 @@ void rev_call_args(StringRef argName, Rule &rule, size_t actArg, } os << " if (byRef) {\n"; int n = 0; - if (func == "gemv") + if (func == "gemv" || func == "lascl") n = 1; if (func == "gemm") n = 2; @@ -1334,8 +1370,8 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, for (auto arg : activeArgs) { const auto name = nameVec[arg]; const auto ty = typeMap.lookup(arg); - // We don't pass in shaddows of fp values, - // we just create and struct-return the shaddows + // We don't pass in shadows of fp values, + // we just create and struct-return the shadows if (ty == ArgType::fp) continue; os << ((first) ? 
"" : ", ") << "Value *" @@ -1529,8 +1565,8 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, for (auto arg : activeArgs) { const auto name = nameVec[arg]; const auto ty = typeMap.lookup(arg); - // We don't pass in shaddows of fp values, - // we just create and struct-return the shaddows + // We don't pass in shadows of fp values, + // we just create and struct-return the shadows if (ty == ArgType::fp) continue; os << ((first) ? "" : ", ") << "d_" + name; diff --git a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h index c3ffc794e31f..3556f35cbba0 100644 --- a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h @@ -70,19 +70,29 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) { } os << " if (byRef) {\n"; - for (size_t argPos = 0; argPos < argTypeMap.size(); argPos++) { + int numCharArgs = 0; + size_t numArgs = argTypeMap.size(); + for (size_t argPos = 0; argPos < numArgs; argPos++) { + const auto typeOfArg = argTypeMap.lookup(argPos); + if (is_char_arg(typeOfArg)) + numCharArgs++; + } + + for (size_t argPos = 0; argPos < numArgs; argPos++) { const auto typeOfArg = argTypeMap.lookup(argPos); size_t i = (lv23 ? argPos - 1 : argPos); - if (typeOfArg == ArgType::len || typeOfArg == ArgType::vincInc || - typeOfArg == ArgType::fp || typeOfArg == ArgType::trans || - typeOfArg == ArgType::mldLD || typeOfArg == ArgType::uplo || - typeOfArg == ArgType::diag || typeOfArg == ArgType::side) { - os << " F->removeParamAttr(" << i << (lv23 ? " + offset" : "") - << ", llvm::Attribute::ReadNone);\n" - << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") - << ", llvm::Attribute::ReadOnly);\n" - << " F->addParamAttr(" << i << (lv23 ? 
" + offset" : "") - << ", llvm::Attribute::NoCapture);\n"; + + if (is_char_arg(typeOfArg) || typeOfArg == ArgType::len || + typeOfArg == ArgType::vincInc || typeOfArg == ArgType::fp || + typeOfArg == ArgType::mldLD) { + if (is_char_arg(typeOfArg) && numArgs - argPos <= numCharArgs) { + os << " F->removeParamAttr(" << i << (lv23 ? " + offset" : "") + << ", llvm::Attribute::ReadNone);\n" + << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") + << ", llvm::Attribute::ReadOnly);\n" + << " F->addParamAttr(" << i << (lv23 ? " + offset" : "") + << ", llvm::Attribute::NoCapture);\n"; + } } } @@ -90,7 +100,7 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) { << " // Julia declares double* pointers as Int64,\n" << " // so LLVM won't let us add these Attributes.\n" << " if (!julia_decl) {\n"; - for (size_t argPos = 0; argPos < argTypeMap.size(); argPos++) { + for (size_t argPos = 0; argPos < numArgs; argPos++) { auto typeOfArg = argTypeMap.lookup(argPos); size_t i = (lv23 ? 
argPos - 1 : argPos); if (typeOfArg == ArgType::vincData || typeOfArg == ArgType::mldData) { diff --git a/enzyme/tools/enzyme-tblgen/datastructures.cpp b/enzyme/tools/enzyme-tblgen/datastructures.cpp index 2d7a14d26e9f..310a853bb0d8 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.cpp +++ b/enzyme/tools/enzyme-tblgen/datastructures.cpp @@ -23,6 +23,13 @@ raw_ostream &operator<<(raw_fd_ostream &os, ArgType arg) { using namespace llvm; +bool is_char_arg(ArgType ty) { + if (ty == ArgType::side || ty == ArgType::diag || ty == ArgType::trans || + ty == ArgType::uplo) + return true; + return false; +} + const char *TyToString(ArgType ty) { switch (ty) { case ArgType::fp: diff --git a/enzyme/tools/enzyme-tblgen/datastructures.h b/enzyme/tools/enzyme-tblgen/datastructures.h index 3416f3105491..c0df6d57c96a 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.h +++ b/enzyme/tools/enzyme-tblgen/datastructures.h @@ -30,6 +30,8 @@ enum class ArgType { side }; +bool is_char_arg(ArgType ty); + namespace llvm { raw_ostream &operator<<(raw_ostream &os, ArgType arg); raw_ostream &operator<<(raw_fd_ostream &os, ArgType arg); From 3d953157f56bc8ac594ca4f7f60adc91ed2edeeb Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 21 Sep 2023 12:43:26 -0400 Subject: [PATCH 25/29] fix c++ test --- enzyme/test/Integration/ReverseMode/blas_runtime.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/test/Integration/ReverseMode/blas_runtime.cpp b/enzyme/test/Integration/ReverseMode/blas_runtime.cpp index ccae8b6ef818..267228476308 100644 --- a/enzyme/test/Integration/ReverseMode/blas_runtime.cpp +++ b/enzyme/test/Integration/ReverseMode/blas_runtime.cpp @@ -2,7 +2,7 @@ // a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... 
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 -mllvm -enzyme-runtime-activity | %lli -; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 -mllvm -enzyme-runtime-activity | %lli -; fi -// RUN: %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 -mllvm -enzyme-runtime-activity | %lli -; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 -mllvm -enzyme-runtime-activity | %lli -; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -mllvm -enzyme-runtime-activity -S | %lli -; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -mllvm -enzyme-runtime-activity -S | %lli -; fi // RUN: if [ %llvmver -ge 12 ]; then %clang++ -fno-exceptions -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -mllvm -enzyme-lapack-copy=1 -mllvm -enzyme-runtime-activity -S | %lli -; fi From 0b4092c8c357832a6b6489d0c022c6773b1e1dc2 Mon Sep 17 00:00:00 2001 From: "William S. 
Moses" Date: Thu, 21 Sep 2023 12:59:02 -0400 Subject: [PATCH 26/29] Fix C ABI --- enzyme/Enzyme/CApi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 89c6a266cfe6..212843852ef7 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -580,7 +580,7 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( overwritten_args.push_back(_overwritten_args[i]); } return wrap(eunwrap(Logic).CreatePrimalAndGradient( - RequestContext(cast(unwrap(request_req)), + RequestContext(cast_or_null(unwrap(request_req)), unwrap(request_ip)), (ReverseCacheKey){ .todiff = cast(unwrap(todiff)), From 57dff0ecf5c754e04795effe5175364acc3a6f8b Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 26 Sep 2023 16:39:46 -0500 Subject: [PATCH 27/29] Don't store inverted global into map [fixes module linking state] (#1456) --- enzyme/Enzyme/GradientUtils.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 7c00ac376bad..974e3fad1887 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5319,8 +5319,6 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, MDTuple::get(shadow->getContext(), {})); shadow->setAlignment(arg->getAlign()); shadow->setUnnamedAddr(arg->getUnnamedAddr()); - invertedPointers.insert(std::make_pair( - (const Value *)oval, InvertedPointerVH(this, shadow))); return shadow; } From 2021e178a7e81e8ee714c31764b56efaf8ff972a Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 26 Sep 2023 22:32:09 -0500 Subject: [PATCH 28/29] Add write barrier binding support (#1457) --- enzyme/Enzyme/ActivityAnalysis.cpp | 3 ++- enzyme/Enzyme/AdjointGenerator.h | 3 ++- enzyme/Enzyme/DifferentialUseAnalysis.h | 3 ++- enzyme/Enzyme/EnzymeLogic.cpp | 9 ++++++++- enzyme/Enzyme/FunctionUtils.cpp | 4 ++++ enzyme/Enzyme/GradientUtils.cpp | 11 +++++++---- 6 files changed, 25 insertions(+), 8 
deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 8ce7dba7b44c..0d9a2bf7be84 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -2930,7 +2930,8 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR, if (F) { if (UA == UseActivity::AllStores && - F->getName() == "julia.write_barrier") + (F->getName() == "julia.write_barrier" || + F->getName() == "julia.write_barrier_binding")) continue; if (F->getIntrinsicID() == Intrinsic::memcpy || F->getIntrinsicID() == Intrinsic::memmove) { diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 5eef7104043c..9d6bc93bc094 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -8546,7 +8546,8 @@ class AdjointGenerator } if (called) { - if (funcName == "julia.write_barrier") { + if (funcName == "julia.write_barrier" || + funcName == "julia.write_barrier_binding") { bool backwardsShadow = false; bool forwardsShadow = true; for (auto pair : gutils->backwardsOnlyShadows) { diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 9539b9dd8e5f..5ebfa36328e5 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -296,7 +296,8 @@ inline bool is_value_needed_in_reverse( // Use in a write barrier requires the shadow in the forward, even // though the instruction is active. 
if (mode != DerivativeMode::ReverseModeGradient && - funcName == "julia.write_barrier") { + (funcName == "julia.write_barrier" || + funcName == "julia.write_barrier_binding")) { if (EnzymePrintDiffUse) llvm::errs() << " Need: " << to_string(VT) << " of " << *inst << " in reverse as shadow write_barrier " << *CI diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 420939ec2bc8..51c4aa300f36 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -430,6 +430,9 @@ struct CacheAnalysis { if (funcName == "julia.write_barrier") return {}; + if (funcName == "julia.write_barrier_binding") + return {}; + if (funcName == "enzyme_zerotype") return {}; @@ -762,6 +765,9 @@ void calculateUnusedValuesInFunction( if (getFuncNameFromCall(CI) == "julia.write_barrier") { continue; } + if (getFuncNameFromCall(CI) == "julia.write_barrier_binding") { + continue; + } bool writeOnlyNoCapture = true; if (shouldDisableNoWrite(CI)) { writeOnlyNoCapture = false; @@ -1004,7 +1010,8 @@ void calculateUnusedValuesInFunction( const Function *CF = CI ? getFunctionFromCall(CI) : nullptr; StringRef funcName = CF ? 
CF->getName() : ""; if (isa(inst) || isa(inst) || - isa(inst) || funcName == "julia.write_barrier") { + isa(inst) || funcName == "julia.write_barrier" || + funcName == "julia.write_barrier_binding") { for (auto pair : gutils->rematerializableAllocations) { if (pair.second.stores.count(inst)) { if (DifferentialUseAnalysis::is_value_needed_in_reverse< diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 6189ab1870bc..1a05c4d709ae 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -409,6 +409,10 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) { toErase.push_back(CI); continue; } + if (F->getName() == "julia.write_barrier_binding" && legal) { + toErase.push_back(CI); + continue; + } } IRBuilder<> B(CI); auto Addr = B.CreateAddrSpaceCast(rep, prev->getType()); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 974e3fad1887..dd5bce668daf 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3188,8 +3188,9 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { StringRef funcName = getFuncNameFromCall(CI); if (funcName == "enzyme_zerotype") continue; - if (funcName == "julia.write_barrier" || isa(&I) || - isa(&I)) { + if (funcName == "julia.write_barrier" || + funcName == "julia.write_barrier_binding" || + isa(&I) || isa(&I)) { // TODO SmallVector args; @@ -3380,7 +3381,8 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { } } else if (auto CI = dyn_cast(&I)) { StringRef funcName = getFuncNameFromCall(CI); - if (funcName == "julia.write_barrier") { + if (funcName == "julia.write_barrier" || + funcName == "julia.write_barrier_binding") { // TODO SmallVector args; @@ -8514,7 +8516,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { frees.insert(CI); continue; } - if (funcName == "julia.write_barrier") { + if (funcName == "julia.write_barrier" || + 
funcName == "julia.write_barrier_binding") { stores.insert(CI); continue; } From d0ba44d2c895fe98800566720234e10de526eddb Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 26 Sep 2023 22:32:27 -0500 Subject: [PATCH 29/29] Fix forward mode nice c++ error (#1458) * Fix forward mode nice c++ error * actually fix sse2 * Add test --- enzyme/Enzyme/AdjointGenerator.h | 11 ++++------ enzyme/Enzyme/InstructionDerivatives.td | 4 ++-- enzyme/test/Enzyme/ForwardMode/maxpd.ll | 27 +++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 9 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/maxpd.ll diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 9d6bc93bc094..ddc91787f3c6 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -30,6 +30,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/Value.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -3696,8 +3697,6 @@ class AdjointGenerator return false; std::string s; llvm::raw_string_ostream ss(s); - ss << *gutils->oldFunc << "\n"; - ss << *gutils->newFunc << "\n"; if (Intrinsic::isOverloaded(ID)) #if LLVM_VERSION_MAJOR >= 13 ss << "cannot handle (forward) unknown intrinsic\n" @@ -3718,14 +3717,12 @@ class AdjointGenerator CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoDerivative, gutils, nullptr, wrap(&Builder2)); - setDiffe(&I, - Constant::getNullValue(gutils->getShadowType(I.getType())), - Builder2); - return false; } else { EmitFailure("NoDerivative", I.getDebugLoc(), &I, ss.str()); - return false; } + setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), + Builder2); + return false; } return false; } diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 9de7f8b871ae..c19452f48e5f 100644 --- 
a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -777,7 +777,7 @@ def : IntrPattern<(Op $x, $y), >; def : IntrPattern<(Op $x, $y), - [["minnum"], ["nvvm_fmin_f"], ["nvvm_fmin_d"], ["nvvm_fmin_ftz_f"], ["x86_sse_min_ss", "", "9"], ["x86_sse_min_ps", "", "9"], ["minimum", "15", ""]], + [["minnum"], ["nvvm_fmin_f"], ["nvvm_fmin_d"], ["nvvm_fmin_ftz_f"], ["x86_sse_min_ss", "", "9"], ["x86_sse_min_ps", "", "9"], ["x86_sse2_min_pd", "", ""], ["minimum", "15", ""]], [ (Select (FCmpOLT $x, $y), (DiffeRet), (ConstantFP<"0"> $x)), (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $x), (DiffeRet)) @@ -786,7 +786,7 @@ def : IntrPattern<(Op $x, $y), >; def : IntrPattern<(Op $x, $y), - [["maxnum"], ["nvvm_fmax_f"], ["nvvm_fmax_d"], ["nvvm_fmax_ftz_f"], ["x86_sse_max_ss", "", "9"], ["x86_sse_max_ps", "", "9"], ["maximum", "15", ""]], + [["maxnum"], ["nvvm_fmax_f"], ["nvvm_fmax_d"], ["nvvm_fmax_ftz_f"], ["x86_sse_max_ss", "", "9"], ["x86_sse_max_ps", "", "9"], ["x86_sse2_max_pd", "", ""], ["maximum", "15", ""]], [ (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $x), (DiffeRet)), (Select (FCmpOLT $x, $y), (DiffeRet), (ConstantFP<"0"> $x)) diff --git a/enzyme/test/Enzyme/ForwardMode/maxpd.ll b/enzyme/test/Enzyme/ForwardMode/maxpd.ll new file mode 100644 index 000000000000..fe1e75c36f33 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/maxpd.ll @@ -0,0 +1,27 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +declare <2 x double> @llvm.x86.sse2.max.pd(<2 x double>, <2 x double>) + +define dso_local <2 x double> @max(<2 x double> %x, <2 x double> %y) { +entry: + %res = call <2 x double> @llvm.x86.sse2.max.pd(<2 x double> %x, <2 x double> %y) + ret <2 x double> %res +} + +; Function Attrs: nounwind uwtable +define dso_local <2 x double> @test_derivative(<2 x double> %x, <2 x double> %y) local_unnamed_addr 
#1 { +entry: + %0 = tail call <2 x double> (...) @__enzyme_fwddiff(<2 x double> (<2 x double>, <2 x double>)* nonnull @max, <2 x double> %x, <2 x double> %x, <2 x double> %y, <2 x double> %y) + ret <2 x double> %0 +} + +; Function Attrs: nounwind +declare <2 x double> @__enzyme_fwddiff(...) + +; CHECK: define internal <2 x double> @fwddiffemax(<2 x double> %x, <2 x double> %"x'", <2 x double> %y, <2 x double> %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fcmp{{( fast)?}} olt <2 x double> %x, %y +; CHECK-NEXT: %1 = select{{( fast)?}} <2 x i1> %0, <2 x double> %"y'", <2 x double> %"x'" +; CHECK-NEXT: ret <2 x double> %1 +; CHECK-NEXT: }