Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 30, 2024
1 parent 61116dc commit fd52067
Show file tree
Hide file tree
Showing 13 changed files with 194 additions and 108 deletions.
32 changes: 31 additions & 1 deletion enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,35 @@ gentbl(
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"],
td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

gentbl(
name = "llvm-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/LLVMDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

gentbl(
name = "nvvm-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/NVVMDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
Expand Down Expand Up @@ -420,6 +448,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":arith-derivatives",
":llvm-derivatives",
":nvvm-derivatives",
":EnzymeOpsIncGen",
":EnzymePassesIncGen",
":EnzymeTypesIncGen",
Expand Down
93 changes: 23 additions & 70 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -374,6 +373,18 @@ bool mlir::enzyme::ActivityAnalyzer::isFunctionArgumentConstant(
return false;
}

bool mayWriteToMemory(Operation *op) {
auto iface = dyn_cast<MemoryEffectOpInterface>(op);
if (!iface)
return true;

SmallVector<MemoryEffects::EffectInstance> effects;
iface.getEffects(effects);
return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
return isa<MemoryEffects::Write>(effect.getEffect());
});
}

// TODO: better support for known function calls. Ideally, they should become
// operations, but we also need parity with LLVM-enzyme.
/// Call the function propagateFromOperand on all operands of CI
Expand Down Expand Up @@ -468,16 +479,15 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR,
// during adjoint generation)
if (isa<func::ReturnOp>(I))
return true;
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
return true;
}
}

if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(I)) {
if (ifaceOp.isInactive()) {
return true;
}
}

// Branch, unreachable, and previously computed constants are inactive
if (/*|| isa<cf::BranchOp>(I)*/ ||
ConstantOperations.contains(I)) {
if (/*|| isa<cf::BranchOp>(I)*/ ConstantOperations.contains(I)) {
return true;
}

Expand Down Expand Up @@ -600,10 +610,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR,
// *I
// << "\n";

if (isa<NVVM::Barrier0Op>(I)) {
InsertConstantOperation(TR, I);
}

// if (auto II = dyn_cast<IntrinsicInst>(I)) {
// switch (II->getIntrinsicID()) {
// case Intrinsic::nvvm_barrier0:
Expand Down Expand Up @@ -1127,13 +1133,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
if (Val.getType().isa<LLVM::LLVMTokenType>())
return true;

// All function pointers are considered active in case an augmented primal
// or reverse is needed
if (Val.getDefiningOp() &&
isa<func::ConstantOp, LLVM::InlineAsmOp>(Val.getDefiningOp())) {
return false;
}

/// If we've already shown this value to be inactive
if (ConstantValues.find(Val) != ConstantValues.end()) {
return true;
Expand All @@ -1150,10 +1149,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,

if (Operation *definingOp = Val.getDefiningOp()) {
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(definingOp)) {
if (ifaceOp.isInactive()) {
return true;
}
}
if (ifaceOp.isInactive()) {
return true;
}
}
}

if (auto arg = Val.dyn_cast<BlockArgument>()) {
Expand Down Expand Up @@ -1819,17 +1818,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR,
}
}

if (auto iasm = dyn_cast<LLVM::InlineAsmOp>(op)) {
if (iasm.getAsmString().contains("exit") ||
iasm.getAsmString().contains("cpuid"))
return false;
}
if (isa<NVVM::Barrier0Op, LLVM::AssumeOp, LLVM::StackSaveOp,
LLVM::StackRestoreOp, LLVM::LifetimeStartOp, LLVM::LifetimeEndOp,
LLVM::Prefetch>(op)) {
return true;
}

// If this is a malloc or free, this doesn't impact the activity
if (auto CI = dyn_cast<CallOpInterface>(op)) {
if (CI->hasAttr("enzyme_inactive"))
Expand Down Expand Up @@ -2538,17 +2526,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
// llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst <<
// "\n";

// cpuid is explicitly an inactive instruction
if (auto iasm = dyn_cast<LLVM::InlineAsmOp>(op)) {
if (iasm.getAsmString().contains("cpuid")) {
// if (EnzymePrintActivity)
// llvm::errs() << " constant instruction from known cpuid instruction
// "
// << *inst << "\n";
return true;
}
}

if (auto store = dyn_cast<LLVM::StoreOp>(op)) {
if (isConstantValue(TR, store.getValue()) ||
isConstantValue(TR, store.getAddr())) {
Expand Down Expand Up @@ -2640,13 +2617,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin(
return true;
}
}
// Intrinsics known always to be inactive
if (isa<NVVM::Barrier0Op>(op)) {
// if (EnzymePrintActivity)
// llvm::errs() << "constant(" << (int)directions << ") up-intrinsic "
// << *inst << "\n";
return true;
}

if (auto gep = dyn_cast<LLVM::GEPOp>(op)) {
// A gep's only args that could make it active is the pointer operand
Expand Down Expand Up @@ -3355,10 +3325,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(
LLVM::SExtOp,
LLVM::ZExtOp,
LLVM::TruncOp,
LLVM::SIToFPOp,
LLVM::UIToFPOp,
LLVM::FPToSIOp,
LLVM::FPToUIOp,
LLVM::FPExtOp,
LLVM::FPTruncOp
// clang-format on
Expand Down Expand Up @@ -3502,19 +3468,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned(
// TODO: in MLIR, users are always operations
//
if (Operation *inst = a) {
auto mayWriteToMemory = [](Operation *op) {
auto iface = dyn_cast<MemoryEffectOpInterface>(op);
if (!iface)
return true;

SmallVector<MemoryEffects::EffectInstance> effects;
iface.getEffects(effects);
return llvm::any_of(
effects, [](const MemoryEffects::EffectInstance &effect) {
return isa<MemoryEffects::Write>(effect.getEffect());
});
};

if (!mayWriteToMemory(inst) /*||
(isa<CallInst>(inst) && AA.onlyReadsMemory(cast<CallInst>(inst)))*/) {
// // if not written to memory and returning a known constant, this
Expand Down
26 changes: 13 additions & 13 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,13 @@ class DenseForwardActivityAnalysis
join(after, before);
ChangeResult result = ChangeResult::NoChange;

// If we know this is inactive by definition
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
propagateIfChanged(after, result);
return;
}
}
// TODO If we know this is inactive by definition
// if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
// if (ifaceOp.isInactive()) {
// propagateIfChanged(after, result);
// return;
// }
// }

auto memory = dyn_cast<MemoryEffectOpInterface>(op);
// If we can't reason about the memory effects, then conservatively assume
Expand Down Expand Up @@ -669,12 +669,12 @@ class DenseBackwardActivityAnalysis
void visitOperation(Operation *op, const BackwardMemoryActivity &after,
BackwardMemoryActivity *before) override {

// If we know this is inactive by definition
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
if (ifaceOp.isInactive()) {
return;
}
}
// TODO: If we know this is inactive by definition
// if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
// if (ifaceOp.isInactive()) {
// return;
// }
// }

// Initialize the return activity of arguments.
if (op->hasTrait<OpTrait::ReturnLike>() && op->getParentOp() == parentOp) {
Expand Down
25 changes: 1 addition & 24 deletions enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td
Original file line number Diff line number Diff line change
@@ -1,25 +1,5 @@
class MLIRDerivative<string dialect_, string opName_, dag patternToMatch, list<dag> resultOps> {
string dialect = dialect_;
string opName = opName_;
dag PatternToMatch = patternToMatch;
list<dag> ArgDerivatives = resultOps;
}
include "Common.td"

class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
bit usesPrimal = usesPrimal_;
bit usesShadow = usesShadow_;
bit usesCustom = usesCustom_;
}

class DiffeRetIndex<list<int> indices_> {
list<int> indices = indices_;
}
def DiffeRet : DiffeRetIndex<[-1]>;

class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
}
class ArithInst<string m> : Inst<m, "arith">;

def AddF : ArithInst<"arith::AddFOp">;
Expand All @@ -32,9 +12,6 @@ def RemF : ArithInst<"arith::RemFOp">;
def CheckedMulF : ArithInst<"arith::MulFOp">;
def CheckedDivF : ArithInst<"arith::DivFOp">;

def Op {
}

def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y),
[
(DiffeRet),
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ set(LLVM_TARGET_DEFINITIONS LLVMDerivatives.td)
enzyme_tablegen(LLVMDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(LLVMDerivativesIncGen)

set(LLVM_TARGET_DEFINITIONS NVVMDerivatives.td)
enzyme_tablegen(NVVMDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(NVVMDerivativesIncGen)

add_mlir_library(MLIREnzymeImplementations
ArithAutoDiffOpInterfaceImpl.cpp
LLVMAutoDiffOpInterfaceImpl.cpp
NVVMAutoDiffOpInterfaceImpl.cpp
MemRefAutoDiffOpInterfaceImpl.cpp
LinalgAutoDiffOpInterfaceImpl.cpp
BuiltinAutoDiffTypeInterfaceImpl.cpp
Expand All @@ -18,6 +23,8 @@ add_mlir_library(MLIREnzymeImplementations
DEPENDS
MLIRAutoDiffOpInterfaceIncGen
ArithDerivativesIncGen
LLVMDerivativesIncGen
NVVMDerivativesIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
Expand Down
30 changes: 30 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class InactiveOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class MLIRDerivative<string dialect_, string opName_, dag patternToMatch, list<dag> resultOps> {
string dialect = dialect_;
string opName = opName_;
dag PatternToMatch = patternToMatch;
list<dag> ArgDerivatives = resultOps;
}

class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
bit usesPrimal = usesPrimal_;
bit usesShadow = usesShadow_;
bit usesCustom = usesCustom_;
}

class DiffeRetIndex<list<int> indices_> {
list<int> indices = indices_;
}
def DiffeRet : DiffeRetIndex<[-1]>;

class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
}

def Op {
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace enzyme {
void registerArithDialectAutoDiffInterface(DialectRegistry &registry);
void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry);
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry);
void registerNVVMDialectAutoDiffInterface(DialectRegistry &registry);
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry);
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
Expand Down
13 changes: 13 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ namespace {
} // namespace

namespace {
struct InlineAsmActivityInterface
: public ActivityOpInterface::ExternalModel<InlineAsmActivityInterface,
LLVM::InlineAsmOp> {
bool isInactive(Operation *op) const {
auto asmOp = cast<LLVM::InlineAsmOp>(op);
auto str = asmOp.getAsmString();
return str.contains("cpuid") || str.contains("exit");
}
bool isArgInactive(Operation *op, mlir::Value) const {
return isInactive(op);
}
};

struct LoadOpInterface
: public AutoDiffOpInterface::ExternalModel<LoadOpInterface, LLVM::LoadOp> {
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
Expand Down
17 changes: 17 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
include "Common.td"

def : InactiveOp<"LLVM", "SIToFPOp">;
def : InactiveOp<"LLVM", "UIToFPOp">;
def : InactiveOp<"LLVM", "FPToSIOp">;
def : InactiveOp<"LLVM", "FPToUIOp">;
def : InactiveOp<"LLVM", "AssumeOp">;
def : InactiveOp<"LLVM", "StackSaveOp">;
def : InactiveOp<"LLVM", "StackRestoreOp">;
def : InactiveOp<"LLVM", "LifetimeStartOp">;
def : InactiveOp<"LLVM", "LifetimeEndOp">;
def : InactiveOp<"LLVM", "Prefetch">;
def : InactiveOp<"LLVM", "MemsetOp">;

def : InactiveOp<"LLVM", "UndefOp">;
def : InactiveOp<"LLVM", "ConstantOp">;
def : InactiveOp<"LLVM", "UnreachableOp">;
Loading

0 comments on commit fd52067

Please sign in to comment.