Skip to content

Commit

Permalink
Simplify activity analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 8, 2023
1 parent 774367c commit 7d256c0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 41 deletions.
38 changes: 0 additions & 38 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,27 +484,6 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
CI->getArgOperand(0) != val && CI->getArgOperand(1) != val)
return true;

// only the float arg input is potentially active
if (Name == "frexp" || Name == "frexpf" || Name == "frexpl") {
return val != CI->getOperand(0);
}

// The relerr argument is inactive
if (Name == "Faddeeva_erf" || Name == "Faddeeva_erfc" ||
Name == "Faddeeva_erfcx" || Name == "Faddeeva_erfi" ||
Name == "Faddeeva_dawson") {
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 0; i < CI->arg_size() - 1; i++)
#else
for (size_t i = 0; i < CI->getNumArgOperands() - 1; i++)
#endif
{
if (val == CI->getOperand(i))
return false;
}
return true;
}

// only the buffer is active for mpi send/recv
if (Name == "MPI_Recv" || Name == "PMPI_Recv" || Name == "MPI_Send" ||
Name == "PMPI_Send") {
Expand Down Expand Up @@ -550,23 +529,6 @@ static inline void propagateArgumentInformation(
propagateFromOperand(CI.getArgOperand(0));
return;
}
if (Name == "frexp" || Name == "frexpf" || Name == "frexpl") {
propagateFromOperand(CI.getOperand(0));
return;
}
if (Name == "Faddeeva_erf" || Name == "Faddeeva_erfc" ||
Name == "Faddeeva_erfcx" || Name == "Faddeeva_erfi" ||
Name == "Faddeeva_dawson") {
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 0; i < CI.arg_size() - 1; i++)
#else
for (size_t i = 0; i < CI.getNumArgOperands() - 1; i++)
#endif
{
propagateFromOperand(CI.getOperand(i));
}
return;
}

if (Name == "julia.call" || Name == "julia.call2") {
#if LLVM_VERSION_MAJOR >= 14
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def : CallPattern<(Op $x, $tbd),
["Faddeeva_erf"],
[
(ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))),
(AssertingInactiveArg)
(InactiveArg) // relerr
],
(ForwardFromSummedReverse),
[ReadNone, NoUnwind]
Expand All @@ -560,7 +560,7 @@ def : CallPattern<(Op $x, $tbd),
["Faddeeva_erfi"],
[
(ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x))))),
(AssertingInactiveArg)
(InactiveArg) // relerr
],
(ForwardFromSummedReverse),
[ReadNone, NoUnwind]
Expand All @@ -570,7 +570,7 @@ def : CallPattern<(Op $x, $tbd),
["Faddeeva_erfc"],
[
(ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))),
(AssertingInactiveArg)
(InactiveArg) // relerr
],
(ForwardFromSummedReverse),
[ReadNone, NoUnwind]
Expand Down
16 changes: 16 additions & 0 deletions enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,22 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) {
<< attrName << "));\n";
os << " #endif \n";
}
ListInit *argOps = pattern->getValueAsListInit("ArgDerivatives");
for (auto argOpEn : enumerate(*argOps)) {
size_t argIdx = argOpEn.index();
if (DagInit *resultRoot = dyn_cast<DagInit>(argOpEn.value())) {
auto opName = resultRoot->getOperator()->getAsString();
auto Def = cast<DefInit>(resultRoot->getOperator())->getDef();
if (opName == "InactiveArgSpec" ||
Def->isSubClassOf("InactiveArgSpec")) {
if (!Def->getValueAsBit("asserting"))
os << " F.addParamAttr(" << argOpEn.index()
<< ", llvm::Attribute::get(F.getContext(), "
"\"enzyme_inactive\"));\n";
continue;
}
}
}
os << " }\n";
}
}
Expand Down

0 comments on commit 7d256c0

Please sign in to comment.