Skip to content

Commit

Permalink
[MLIR][ActivityAnalysis] create activity interface (#1648)
Browse files Browse the repository at this point in the history
* [MLIR][ActivityAnalysis] create activity interface

* wip

* fixup
  • Loading branch information
wsmoses authored Jan 30, 2024
1 parent ee0c678 commit ff15cd8
Show file tree
Hide file tree
Showing 15 changed files with 272 additions and 140 deletions.
32 changes: 31 additions & 1 deletion enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,35 @@ gentbl(
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"],
td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

gentbl(
name = "llvm-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/LLVMDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

gentbl(
name = "nvvm-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/NVVMDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
Expand Down Expand Up @@ -420,6 +448,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":arith-derivatives",
":llvm-derivatives",
":nvvm-derivatives",
":EnzymeOpsIncGen",
":EnzymePassesIncGen",
":EnzymeTypesIncGen",
Expand Down
164 changes: 49 additions & 115 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
Expand All @@ -19,6 +18,8 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/ModRef.h"

#include "Interfaces/AutoDiffOpInterface.h"

const char *KnownInactiveFunctionsStartingWith[] = {
"f90io",
"$ss5print",
Expand Down Expand Up @@ -467,9 +468,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR,
if (isa<func::ReturnOp>(I))
return true;

if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(I)) {
if (ifaceOp.isInactive()) {
return true;
}
}

// Branch, unreachable, and previously computed constants are inactive
if (isa<LLVM::UnreachableOp>(I) /*|| isa<cf::BranchOp>(I)*/ ||
ConstantOperations.contains(I)) {
if (/*|| isa<cf::BranchOp>(I)*/ ConstantOperations.contains(I)) {
return true;
}

Expand Down Expand Up @@ -592,12 +598,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR,
// *I
// << "\n";

if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch, LLVM::MemsetOp>(I)) {
InsertConstantOperation(TR, I);
}

// if (auto II = dyn_cast<IntrinsicInst>(I)) {
// switch (II->getIntrinsicID()) {
// case Intrinsic::nvvm_barrier0:
Expand Down Expand Up @@ -1121,13 +1121,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
if (Val.getType().isa<LLVM::LLVMTokenType>())
return true;

// All function pointers are considered active in case an augmented primal
// or reverse is needed
if (Val.getDefiningOp() &&
isa<func::ConstantOp, LLVM::InlineAsmOp>(Val.getDefiningOp())) {
return false;
}

/// If we've already shown this value to be inactive
if (ConstantValues.find(Val) != ConstantValues.end()) {
return true;
Expand All @@ -1142,44 +1135,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
if (matchPattern(Val, m_Constant()))
return true;

// if (auto CD = dyn_cast<ConstantDataSequential>(Val)) {
// // inductively assume inactive
// ConstantValues.insert(CD);
// for (size_t i = 0, len = CD->getNumElements(); i < len; i++) {
// if (!isConstantValue(TR, CD->getElementAsConstant(i))) {
// ConstantValues.erase(CD);
// ActiveValues.insert(CD);
// return false;
// }
// }
// return true;
// }
// if (auto CD = dyn_cast<ConstantAggregate>(Val)) {
// // inductively assume inactive
// ConstantValues.insert(CD);
// for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) {
// if (!isConstantValue(TR, CD->getOperand(i))) {
// ConstantValues.erase(CD);
// ActiveValues.insert(CD);
// return false;
// }
// }
// return true;
// }

if (Operation *definingOp = Val.getDefiningOp()) {
// Undef and non-global constants are inactive.
if (isa<LLVM::UndefOp, LLVM::ConstantOp>(definingOp)) {
return true;
}

// Ops derived from intrinsics.
// NOTE: this was written with the assumption that Value is-a Operation,
// which is not the case in MLIR.
if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch>(definingOp)) {
return true;
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(definingOp)) {
if (ifaceOp.isInactive()) {
return true;
}
}
}

Expand Down Expand Up @@ -1494,6 +1454,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
}
}

if (auto op = TmpOrig.getDefiningOp())
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
InsertConstantValue(TR, Val);
if (TmpOrig != Val) {
InsertConstantValue(TR, TmpOrig);
}
return true;
}
}

UpHypothesis = std::shared_ptr<mlir::enzyme::ActivityAnalyzer>(
new mlir::enzyme::ActivityAnalyzer(*this, UP));
UpHypothesis->ConstantValues.insert(Val);
Expand Down Expand Up @@ -1828,16 +1799,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
if (notForAnalysis.count(op->getBlock()))
return false;

if (auto iasm = dyn_cast<LLVM::InlineAsmOp>(op)) {
if (iasm.getAsmString().contains("exit") ||
iasm.getAsmString().contains("cpuid"))
return false;
}
if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch>(op)) {
return true;
}
if (auto op = TmpOrig.getDefiningOp())
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
return false;
}
}

// If this is a malloc or free, this doesn't impact the activity
if (auto CI = dyn_cast<CallOpInterface>(op)) {
Expand Down Expand Up @@ -2537,21 +2504,16 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
return false;
}

// if (EnzymePrintActivity)
// llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst <<
// "\n";

// cpuid is explicitly an inactive instruction
if (auto iasm = dyn_cast<LLVM::InlineAsmOp>(op)) {
if (iasm.getAsmString().contains("cpuid")) {
// if (EnzymePrintActivity)
// llvm::errs() << " constant instruction from known cpuid instruction
// "
// << *inst << "\n";
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
return true;
}
}

// if (EnzymePrintActivity)
// llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst <<
// "\n";

if (auto store = dyn_cast<LLVM::StoreOp>(op)) {
if (isConstantValue(TR, store.getValue()) ||
isConstantValue(TR, store.getAddr())) {
Expand Down Expand Up @@ -2643,15 +2605,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
return true;
}
}
// Intrinsics known always to be inactive
if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch, LLVM::MemsetOp>(op)) {
// if (EnzymePrintActivity)
// llvm::errs() << "constant(" << (int)directions << ") up-intrinsic "
// << *inst << "\n";
return true;
}

if (auto gep = dyn_cast<LLVM::GEPOp>(op)) {
// A gep's only args that could make it active is the pointer operand
Expand Down Expand Up @@ -2731,13 +2684,7 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
return false;
}

if (isa<LLVM::SIToFPOp, LLVM::UIToFPOp, LLVM::FPToSIOp, LLVM::FPToUIOp>(op)) {
// if (EnzymePrintActivity)
// llvm::errs() << "constant(" << (int)directions << ") up-fpcst:" <<
// *inst
// << "\n";
return true;
} else {
{
bool seenuse = false;
//! TODO does not consider reading from global memory that is active and not
//! an argument
Expand Down Expand Up @@ -2871,6 +2818,13 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
// }
}

if (UA != UseActivity::AllStores) {
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(a)) {
if (ifaceOp.isArgInactive(parent))
return true;
}
}

// if (EnzymePrintActivity)
// llvm::errs() << " considering use of " << *val << " - " << *a
// << "\n";
Expand Down Expand Up @@ -3078,14 +3032,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
continue;
}

if (isa<LLVM::SIToFPOp, LLVM::UIToFPOp, LLVM::FPToSIOp, LLVM::FPToUIOp>(
a)) {
// if (EnzymePrintActivity)
// llvm::errs() << "found constant(" << (int)directions
// << ") si-fp use:" << *val << " user " << *a << "\n";
continue;
}

//
// TODO: this should not happen in valid MLIR...
//
Expand Down Expand Up @@ -3367,10 +3313,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
LLVM::SExtOp,
LLVM::ZExtOp,
LLVM::TruncOp,
LLVM::SIToFPOp,
LLVM::UIToFPOp,
LLVM::FPToSIOp,
LLVM::FPToUIOp,
LLVM::FPExtOp,
LLVM::FPTruncOp
// clang-format on
Expand Down Expand Up @@ -3451,6 +3393,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned(
continue;
}

if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(a)) {
if (ifaceOp.isArgInactive(val))
return true;
}

if (isa<LLVM::ReturnOp>(a)) {
if (ActiveReturns == DIFFE_TYPE::CONSTANT)
continue;
Expand Down Expand Up @@ -3509,19 +3456,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned(
// TODO: in MLIR, users are always operations
//
if (Operation *inst = a) {
auto mayWriteToMemory = [](Operation *op) {
auto iface = dyn_cast<MemoryEffectOpInterface>(op);
if (!iface)
return true;

SmallVector<MemoryEffects::EffectInstance> effects;
iface.getEffects(effects);
return llvm::any_of(
effects, [](const MemoryEffects::EffectInstance &effect) {
return isa<MemoryEffects::Write>(effect.getEffect());
});
};

if (!mayWriteToMemory(inst) /*||
(isa<CallInst>(inst) && AA.onlyReadsMemory(cast<CallInst>(inst)))*/) {
// // if not written to memory and returning a known constant, this
Expand Down
19 changes: 19 additions & 0 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"

#include "Interfaces/AutoDiffOpInterface.h"

using namespace mlir;
using namespace mlir::dataflow;
using enzyme::AliasClassLattice;
Expand Down Expand Up @@ -508,6 +510,15 @@ class DenseForwardActivityAnalysis
ForwardMemoryActivity *after) override {
join(after, before);
ChangeResult result = ChangeResult::NoChange;

// TODO If we know this is inactive by definition
// if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
// if (ifaceOp.isInactive()) {
// propagateIfChanged(after, result);
// return;
// }
// }

auto memory = dyn_cast<MemoryEffectOpInterface>(op);
// If we can't reason about the memory effects, then conservatively assume
// we can't deduce anything about activity via side-effects.
Expand Down Expand Up @@ -657,6 +668,14 @@ class DenseBackwardActivityAnalysis

void visitOperation(Operation *op, const BackwardMemoryActivity &after,
BackwardMemoryActivity *before) override {

// TODO: If we know this is inactive by definition
// if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
// if (ifaceOp.isInactive()) {
// return;
// }
// }

// Initialize the return activity of arguments.
if (op->hasTrait<OpTrait::ReturnLike>() && op->getParentOp() == parentOp) {
for (const auto &[arg, argActivity] :
Expand Down
Loading

0 comments on commit ff15cd8

Please sign in to comment.