Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into controlflowfixpart1…
Browse files Browse the repository at this point in the history
…try2
  • Loading branch information
tthsqe12 committed Oct 1, 2023
2 parents 543b8b6 + d0ba44d commit 5bcc9e9
Show file tree
Hide file tree
Showing 61 changed files with 5,162 additions and 2,480 deletions.
3 changes: 2 additions & 1 deletion enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2930,7 +2930,8 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,

if (F) {
if (UA == UseActivity::AllStores &&
F->getName() == "julia.write_barrier")
(F->getName() == "julia.write_barrier" ||
F->getName() == "julia.write_barrier_binding"))
continue;
if (F->getIntrinsicID() == Intrinsic::memcpy ||
F->getIntrinsicID() == Intrinsic::memmove) {
Expand Down
95 changes: 72 additions & 23 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Value.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
Expand Down Expand Up @@ -2747,6 +2748,54 @@ class AdjointGenerator
auto &DL = gutils->newFunc->getParent()->getDataLayout();
auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0);

if (!vd.isKnownPastPointer()) {
// If unknown type results, and zeroing known undef allocation, consider
// integers
if (auto CI = dyn_cast<ConstantInt>(MS.getOperand(1)))
if (CI->isZero()) {
auto root = getBaseObject(MS.getOperand(0));
bool writtenTo = false;
bool undefMemory =
isa<AllocaInst>(root) || isAllocationCall(root, gutils->TLI);
if (auto arg = dyn_cast<Argument>(root))
if (arg->hasStructRetAttr())
undefMemory = true;
if (undefMemory) {
Instruction *cur = MS.getPrevNode();
while (cur) {
if (cur == root)
break;
if (auto MCI = dyn_cast<ConstantInt>(MS.getOperand(2))) {
if (auto II = dyn_cast<IntrinsicInst>(cur)) {
// If the start of the lifetime for more memory than being
// memset, its valid.
if (II->getIntrinsicID() == Intrinsic::lifetime_start) {
if (getBaseObject(II->getOperand(1)) == root) {
if (auto CI2 = dyn_cast<ConstantInt>(II->getOperand(0))) {
if (MCI->getValue().ule(CI2->getValue()))
break;
}
}
cur = cur->getPrevNode();
continue;
}
}
}
if (cur->mayWriteToMemory()) {
writtenTo = true;
break;
}
cur = cur->getPrevNode();
}

if (!writtenTo) {
vd = TypeTree(BaseType::Pointer);
vd.insert({-1}, BaseType::Integer);
}
}
}
}

if (!vd.isKnownPastPointer()) {
// If unknown type results, consider the intersection of all incoming.
if (isa<PHINode>(MS.getOperand(0)) || isa<SelectInst>(MS.getOperand(0))) {
Expand Down Expand Up @@ -3648,8 +3697,6 @@ class AdjointGenerator
return false;
std::string s;
llvm::raw_string_ostream ss(s);
ss << *gutils->oldFunc << "\n";
ss << *gutils->newFunc << "\n";
if (Intrinsic::isOverloaded(ID))
#if LLVM_VERSION_MAJOR >= 13
ss << "cannot handle (forward) unknown intrinsic\n"
Expand All @@ -3670,14 +3717,12 @@ class AdjointGenerator
CustomErrorHandler(ss.str().c_str(), wrap(&I),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&Builder2));
setDiffe(&I,
Constant::getNullValue(gutils->getShadowType(I.getType())),
Builder2);
return false;
} else {
EmitFailure("NoDerivative", I.getDebugLoc(), &I, ss.str());
return false;
}
setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())),
Builder2);
return false;
}
return false;
}
Expand Down Expand Up @@ -3848,8 +3893,9 @@ class AdjointGenerator
Mode == DerivativeMode::ReverseModeCombined) {
if (called) {
subdata = &gutils->Logic.CreateAugmentedPrimal(
cast<Function>(called), subretType, argsInverted,
TR.analyzer.interprocedural, /*return is used*/ false,
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer.interprocedural,
/*return is used*/ false,
/*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false,
gutils->getWidth(),
/*AtomicAdd*/ true,
Expand Down Expand Up @@ -4048,6 +4094,7 @@ class AdjointGenerator
}

newcalled = gutils->Logic.CreatePrimalAndGradient(
RequestContext(&call, &Builder2),
(ReverseCacheKey){.todiff = cast<Function>(called),
.retType = subretType,
.constant_args = argsInverted,
Expand Down Expand Up @@ -6803,8 +6850,9 @@ class AdjointGenerator

if (called) {
newcalled = gutils->Logic.CreateForwardDiff(
cast<Function>(called), subretType, argsInverted,
TR.analyzer.interprocedural, /*returnValue*/ subretused, Mode,
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer.interprocedural,
/*returnValue*/ subretused, Mode,
((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(),
tape ? tape->getType() : nullptr, nextTypeInfo, overwritten_args,
/*augmented*/ subdata);
Expand Down Expand Up @@ -7206,10 +7254,10 @@ class AdjointGenerator
if (Mode == DerivativeMode::ReverseModePrimal ||
Mode == DerivativeMode::ReverseModeCombined) {
subdata = &gutils->Logic.CreateAugmentedPrimal(
cast<Function>(called), subretType, argsInverted,
TR.analyzer.interprocedural, /*return is used*/ subretused,
shadowReturnUsed, nextTypeInfo, overwritten_args, false,
gutils->getWidth(), gutils->AtomicAdd);
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer.interprocedural,
/*return is used*/ subretused, shadowReturnUsed, nextTypeInfo,
overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd);
if (Mode == DerivativeMode::ReverseModePrimal) {
assert(augmentedReturn);
auto subaugmentations =
Expand Down Expand Up @@ -7591,6 +7639,7 @@ class AdjointGenerator
}

newcalled = gutils->Logic.CreatePrimalAndGradient(
RequestContext(&call, &Builder2),
(ReverseCacheKey){.todiff = cast<Function>(called),
.retType = subretType,
.constant_args = argsInverted,
Expand Down Expand Up @@ -8109,15 +8158,13 @@ class AdjointGenerator
return;
}

if (!called || called->empty()) {
if (auto blas = extractBLAS(funcName)) {
if (auto blas = extractBLAS(funcName)) {
#if LLVM_VERSION_MAJOR >= 16
if (handleBLAS(call, called, blas.value(), overwritten_args))
if (handleBLAS(call, called, blas.value(), overwritten_args))
#else
if (handleBLAS(call, called, blas.getValue(), overwritten_args))
if (handleBLAS(call, called, blas.getValue(), overwritten_args))
#endif
return;
}
return;
}

if (funcName == "printf" || funcName == "puts" ||
Expand Down Expand Up @@ -8496,7 +8543,8 @@ class AdjointGenerator
}

if (called) {
if (funcName == "julia.write_barrier") {
if (funcName == "julia.write_barrier" ||
funcName == "julia.write_barrier_binding") {
bool backwardsShadow = false;
bool forwardsShadow = true;
for (auto pair : gutils->backwardsOnlyShadows) {
Expand Down Expand Up @@ -10020,7 +10068,8 @@ class AdjointGenerator
auto callval = call.getCalledOperand();
if (!isa<Constant>(callval))
callval = gutils->getNewFromOriginal(callval);
newCall->setCalledOperand(gutils->Logic.CreateNoFree(callval));
newCall->setCalledOperand(gutils->Logic.CreateNoFree(
RequestContext(&call, &BuilderZ), callval));
}
if (gutils->knownRecomputeHeuristic.find(&call) !=
gutils->knownRecomputeHeuristic.end()) {
Expand Down
87 changes: 54 additions & 33 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def tp : MagicInst; // transpose the trans param.
def noop : MagicInst; // gradient is zero
def inactive : MagicInst; // like noop, but assert it's inactive
def Rows : MagicInst; // given a transpose, normal rows, normal cols get the true rows, aka normal rows if N else normal cols
def Concat : MagicInst;

// if !cache_A, then just use $lda.
// if cache_A, then check $transa.
Expand Down Expand Up @@ -109,8 +110,8 @@ def scal : CallBlasPattern<(Op $n, $alpha, $x, $incx),
["x"],[len, fp, vinc<["n"]>],
[
// dot must proceed scal, because scal modifies adj<"x">
(b<"dot"> $n, $x, $incx, adj<"x">, $incx),
(b<"scal"> $n, $alpha, adj<"x">, $incx)
(b<"dot"> $n, $x, adj<"x">),
(b<"scal"> $n, $alpha, adj<"x">)
]
>;

Expand All @@ -134,18 +135,17 @@ def lascl : CallBlasPattern<(Op $layout, $type, $kl, $ku, $cfrom, $cto, $m, $n,
def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy),
["y"],[len, fp, vinc<["n"]>, vinc<["n"]>],
[
// dot must proceed scal, because scal modifies adj<"x">
(inactive),
(inactive),//(b<"scal"> $n, $alpha, adj<"x">, $incy),
(inactive) // y = (Ax) + y, so nothing to do here
(b<"dot"> $n, adj<"y">, $x),
(b<"axpy"> $n, $alpha, adj<"y">, adj<"x">),
(noop) // y = alpha*x + y, so nothing to do here
]
>;

def dot : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
[],[len, vinc<["n"]>, vinc<["n"]>],
[
(b<"axpy"> $n, DiffeRet, $y, $incy, adj<"x">, $incx),
(b<"axpy"> $n, DiffeRet, $x, $incx, adj<"y">, $incy)
(b<"axpy"> $n, DiffeRet, $y, adj<"x">),
(b<"axpy"> $n, DiffeRet, $x, adj<"y">),
]
>;

Expand All @@ -155,13 +155,13 @@ def dot : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
// >;


// def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
// ["y"],[len, vinc, vinc],
// [
// (noop),// copy moves x into y, so x is never modified.
// (b<"axpy"> $n, Constant<"1.0">, adj<"y">, $incy, adj<"x">, $incx)
// ]
// >;
def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
["y"],[len, vinc<["n"]>, vinc<["n"]>],
[
(noop),// copy moves x into y, so x is never modified.
(b<"axpy"> $n, Constant<"1.0">, adj<"y">, adj<"x">)
]
>;

// def swap : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
// ["x","y"],[len, vinc, vinc],
Expand All @@ -185,12 +185,18 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $
["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>],
[
/* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]>
(b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $n), $x, $incx, Constant<"0.0">, use<"Ax">, ConstantInt<1>),
(b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, use<"Ax">, ConstantInt<1>)),
/* A */ (b<"ger"> $layout, $m, $n, $alpha, adj<"y">, $incy, $x, $incx, adj<"A">, $lda),
/* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, $transa, $lda, $m, $n), adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx),
/* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, $incy, input<"y">, $incy),
/* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">, $incy)
(b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>),
(b<"dot"> (Rows $transa, $m, $n), adj<"y">, use<"Ax">, ConstantInt<1>)),

//if (is_normal $transa) {
// call sger(m, n, alpha, ya, incy, x, incx, Aa, lda)
//} else {
// call sger(m, n, alpha, x, incx, ya, incy, Aa, lda)
//}
/* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, (Concat adj<"y">, $x), (Concat $x, adj<"y">)), adj<"A">),
/* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">),
/* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">),
/* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">)
]
>;
//
Expand All @@ -217,11 +223,26 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A

/* alpha */ (Seq<["AB", "product", "m", "n"]>
(b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n
(FrobInnerProd<""> $m, $n, adj<"C">, $ldc, use<"AB">)),
/* A */ (b<"gemm"> $layout, $transa, transpose<"transb">, $m, $k, $n, $alpha, adj<"C">, $ldc, $B, (ld $B, $transb, $ldb, $k, $n), $beta, adj<"A">, $lda),
/* B */ (b<"gemm"> $layout, transpose<"transa">, $transb, $k, $n, $m, $alpha, $A, (ld $A, $transa, $lda, $m, $k), adj<"C">, $ldc, $beta, adj<"B">, $ldb),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, $ldc, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, $ldc, ConstantInt<0>)
(FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)),
/* A */ (b<"gemm"> $layout, (Rows $transa,
(Concat Char<"N">, transpose<"transb">, $m, $k),
(Concat $transb, Char<"T">, $k, $m)),
$n, $alpha,
(Rows $transa,
(Concat adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n)),
(Concat $B, (ld $B, $transb, $ldb, $k, $n), adj<"C">)),
Constant<"1.0">, adj<"A">),

/* B */ (b<"gemm"> $layout, (Rows $transb,
(Concat transpose<"transa">, Char<"N">, $k, $n),
(Concat Char<"T">, $transa, $n, $k)),
$m, $alpha,
(Rows $transb,
(Concat $A, (ld $A, $transa, $lda, $m, $k), adj<"C">),
(Concat adj<"C">, $A, (ld $A, $transa, $lda, $m, $k))),
Constant<"1.0">, adj<"B">),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">)
]
>;

Expand All @@ -230,14 +251,14 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta
[cblas_layout, uplo, len, fp, ap<["n"]>, vinc<["n"]>, fp, vinc<["n"]>],
[
/* alpha */ (Seq<["y0", "triangular", "n"]>
(b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, $incx, Constant<"0.0">, use<"y0">, ConstantInt<1>),
(b<"dot"> $n, adj<"y">, $incy, use<"y0">, ConstantInt<1>)),
(b<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, Constant<"0.0">, use<"y0">, ConstantInt<1>),
(b<"dot"> $n, adj<"y">, use<"y0">, ConstantInt<1>)),
/* ap */ (Seq<[]>
(b<"spr2"> $layout, $uplo, $n, $alpha, $x, $incx, adj<"y">, $incy, adj<"ap">),
(DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, $incx, adj<"y">, $incy, adj<"ap">)),
/* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, $incy, Constant<"1.0">, adj<"x">, $incx),
/* beta */ (b<"dot"> $n, adj<"y">, $incy, input<"y">, $incy),
/* y */ (b<"scal"> $n, $beta, adj<"y">, $incy)
(b<"spr2"> $layout, $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">),
(DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, adj<"y">, adj<"ap">)),
/* x */ (b<"spmv"> $layout, $uplo, $n, $alpha, $ap, adj<"y">, Constant<"1.0">, adj<"x">),
/* beta */ (b<"dot"> $n, adj<"y">, input<"y">),
/* y */ (b<"scal"> $n, $beta, adj<"y">)
]
>;

Expand Down
Loading

0 comments on commit 5bcc9e9

Please sign in to comment.