Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][ActivityAnalysis] create activity interface #1648

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,35 @@ gentbl(
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"],
td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

gentbl(
name = "llvm-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/LLVMDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

gentbl(
name = "nvvm-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/NVVMDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
Expand Down Expand Up @@ -420,6 +448,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":arith-derivatives",
":llvm-derivatives",
":nvvm-derivatives",
":EnzymeOpsIncGen",
":EnzymePassesIncGen",
":EnzymeTypesIncGen",
Expand Down
164 changes: 49 additions & 115 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
Expand All @@ -19,6 +18,8 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/ModRef.h"

#include "Interfaces/AutoDiffOpInterface.h"

const char *KnownInactiveFunctionsStartingWith[] = {
"f90io",
"$ss5print",
Expand Down Expand Up @@ -467,9 +468,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR,
if (isa<func::ReturnOp>(I))
return true;

if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(I)) {
if (ifaceOp.isInactive()) {
return true;
}
}

// Branch, unreachable, and previously computed constants are inactive
if (isa<LLVM::UnreachableOp>(I) /*|| isa<cf::BranchOp>(I)*/ ||
ConstantOperations.contains(I)) {
if (/*|| isa<cf::BranchOp>(I)*/ ConstantOperations.contains(I)) {
return true;
}

Expand Down Expand Up @@ -592,12 +598,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR,
// *I
// << "\n";

if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch, LLVM::MemsetOp>(I)) {
InsertConstantOperation(TR, I);
}

// if (auto II = dyn_cast<IntrinsicInst>(I)) {
// switch (II->getIntrinsicID()) {
// case Intrinsic::nvvm_barrier0:
Expand Down Expand Up @@ -1121,13 +1121,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
if (Val.getType().isa<LLVM::LLVMTokenType>())
return true;

// All function pointers are considered active in case an augmented primal
// or reverse is needed
if (Val.getDefiningOp() &&
isa<func::ConstantOp, LLVM::InlineAsmOp>(Val.getDefiningOp())) {
return false;
}

/// If we've already shown this value to be inactive
if (ConstantValues.find(Val) != ConstantValues.end()) {
return true;
Expand All @@ -1142,44 +1135,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
if (matchPattern(Val, m_Constant()))
return true;

// if (auto CD = dyn_cast<ConstantDataSequential>(Val)) {
// // inductively assume inactive
// ConstantValues.insert(CD);
// for (size_t i = 0, len = CD->getNumElements(); i < len; i++) {
// if (!isConstantValue(TR, CD->getElementAsConstant(i))) {
// ConstantValues.erase(CD);
// ActiveValues.insert(CD);
// return false;
// }
// }
// return true;
// }
// if (auto CD = dyn_cast<ConstantAggregate>(Val)) {
// // inductively assume inactive
// ConstantValues.insert(CD);
// for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) {
// if (!isConstantValue(TR, CD->getOperand(i))) {
// ConstantValues.erase(CD);
// ActiveValues.insert(CD);
// return false;
// }
// }
// return true;
// }

if (Operation *definingOp = Val.getDefiningOp()) {
// Undef and non-global constants are inactive.
if (isa<LLVM::UndefOp, LLVM::ConstantOp>(definingOp)) {
return true;
}

// Ops derived from intrinsics.
// NOTE: this was written with the assumption that Value is-a Operation,
// which is not the case in MLIR.
if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch>(definingOp)) {
return true;
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(definingOp)) {
if (ifaceOp.isInactive()) {
return true;
}
}
}

Expand Down Expand Up @@ -1494,6 +1454,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
}
}

if (auto op = TmpOrig.getDefiningOp())
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
InsertConstantValue(TR, Val);
if (TmpOrig != Val) {
InsertConstantValue(TR, TmpOrig);
}
return true;
}
}

UpHypothesis = std::shared_ptr<mlir::enzyme::ActivityAnalyzer>(
new mlir::enzyme::ActivityAnalyzer(*this, UP));
UpHypothesis->ConstantValues.insert(Val);
Expand Down Expand Up @@ -1828,16 +1799,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
if (notForAnalysis.count(op->getBlock()))
return false;

if (auto iasm = dyn_cast<LLVM::InlineAsmOp>(op)) {
if (iasm.getAsmString().contains("exit") ||
iasm.getAsmString().contains("cpuid"))
return false;
}
if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch>(op)) {
return true;
}
if (auto op = TmpOrig.getDefiningOp())
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
return false;
}
}

// If this is a malloc or free, this doesn't impact the activity
if (auto CI = dyn_cast<CallOpInterface>(op)) {
Expand Down Expand Up @@ -2537,21 +2504,16 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
return false;
}

// if (EnzymePrintActivity)
// llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst <<
// "\n";

// cpuid is explicitly an inactive instruction
if (auto iasm = dyn_cast<LLVM::InlineAsmOp>(op)) {
if (iasm.getAsmString().contains("cpuid")) {
// if (EnzymePrintActivity)
// llvm::errs() << " constant instruction from known cpuid instruction
// "
// << *inst << "\n";
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
return true;
}
}

// if (EnzymePrintActivity)
// llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst <<
// "\n";

if (auto store = dyn_cast<LLVM::StoreOp>(op)) {
if (isConstantValue(TR, store.getValue()) ||
isConstantValue(TR, store.getAddr())) {
Expand Down Expand Up @@ -2643,15 +2605,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
return true;
}
}
// Intrinsics known always to be inactive
if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch, LLVM::MemsetOp>(op)) {
// if (EnzymePrintActivity)
// llvm::errs() << "constant(" << (int)directions << ") up-intrinsic "
// << *inst << "\n";
return true;
}

if (auto gep = dyn_cast<LLVM::GEPOp>(op)) {
// A gep's only args that could make it active is the pointer operand
Expand Down Expand Up @@ -2731,13 +2684,7 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
return false;
}

if (isa<LLVM::SIToFPOp, LLVM::UIToFPOp, LLVM::FPToSIOp, LLVM::FPToUIOp>(op)) {
// if (EnzymePrintActivity)
// llvm::errs() << "constant(" << (int)directions << ") up-fpcst:" <<
// *inst
// << "\n";
return true;
} else {
{
bool seenuse = false;
//! TODO does not consider reading from global memory that is active and not
//! an argument
Expand Down Expand Up @@ -2871,6 +2818,13 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
// }
}

if (UA != UseActivity::AllStores) {
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(a)) {
if (ifaceOp.isArgInactive(parent))
return true;
}
}

// if (EnzymePrintActivity)
// llvm::errs() << " considering use of " << *val << " - " << *a
// << "\n";
Expand Down Expand Up @@ -3078,14 +3032,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
continue;
}

if (isa<LLVM::SIToFPOp, LLVM::UIToFPOp, LLVM::FPToSIOp, LLVM::FPToUIOp>(
a)) {
// if (EnzymePrintActivity)
// llvm::errs() << "found constant(" << (int)directions
// << ") si-fp use:" << *val << " user " << *a << "\n";
continue;
}

//
// TODO: this should not happen in valid MLIR...
//
Expand Down Expand Up @@ -3367,10 +3313,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
LLVM::SExtOp,
LLVM::ZExtOp,
LLVM::TruncOp,
LLVM::SIToFPOp,
LLVM::UIToFPOp,
LLVM::FPToSIOp,
LLVM::FPToUIOp,
LLVM::FPExtOp,
LLVM::FPTruncOp
// clang-format on
Expand Down Expand Up @@ -3451,6 +3393,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned(
continue;
}

if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(a)) {
if (ifaceOp.isArgInactive(val))
return true;
}

if (isa<LLVM::ReturnOp>(a)) {
if (ActiveReturns == DIFFE_TYPE::CONSTANT)
continue;
Expand Down Expand Up @@ -3509,19 +3456,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned(
// TODO: in MLIR, users are always operations
//
if (Operation *inst = a) {
auto mayWriteToMemory = [](Operation *op) {
auto iface = dyn_cast<MemoryEffectOpInterface>(op);
if (!iface)
return true;

SmallVector<MemoryEffects::EffectInstance> effects;
iface.getEffects(effects);
return llvm::any_of(
effects, [](const MemoryEffects::EffectInstance &effect) {
return isa<MemoryEffects::Write>(effect.getEffect());
});
};

if (!mayWriteToMemory(inst) /*||
(isa<CallInst>(inst) && AA.onlyReadsMemory(cast<CallInst>(inst)))*/) {
// // if not written to memory and returning a known constant, this
Expand Down
19 changes: 19 additions & 0 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"

#include "Interfaces/AutoDiffOpInterface.h"

using namespace mlir;
using namespace mlir::dataflow;
using enzyme::AliasClassLattice;
Expand Down Expand Up @@ -508,6 +510,15 @@ class DenseForwardActivityAnalysis
ForwardMemoryActivity *after) override {
join(after, before);
ChangeResult result = ChangeResult::NoChange;

// TODO If we know this is inactive by definition
// if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
// if (ifaceOp.isInactive()) {
// propagateIfChanged(after, result);
// return;
// }
// }

auto memory = dyn_cast<MemoryEffectOpInterface>(op);
// If we can't reason about the memory effects, then conservatively assume
// we can't deduce anything about activity via side-effects.
Expand Down Expand Up @@ -657,6 +668,14 @@ class DenseBackwardActivityAnalysis

void visitOperation(Operation *op, const BackwardMemoryActivity &after,
BackwardMemoryActivity *before) override {

// TODO: If we know this is inactive by definition
// if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
// if (ifaceOp.isInactive()) {
// return;
// }
// }

// Initialize the return activity of arguments.
if (op->hasTrait<OpTrait::ReturnLike>() && op->getParentOp() == parentOp) {
for (const auto &[arg, argActivity] :
Expand Down
Loading
Loading