Skip to content

Commit

Permalink
[MLIR] Add read interface fwd
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 1, 2024
1 parent 724756e commit 1e58352
Show file tree
Hide file tree
Showing 14 changed files with 179 additions and 51 deletions.
30 changes: 30 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,34 @@ gentbl(
],
)

gentbl(
name = "cf-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/CFDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/CFDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/CFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

gentbl(
name = "memref-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/MemRefDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/MemRefDerivatives.td",
td_srcs = ["Enzyme/MLIR/Implementations/MemRefDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"],
deps = [
":enzyme-tblgen",
],
)

cc_library(
name = "EnzymeMLIR",
srcs = glob([
Expand All @@ -482,6 +510,8 @@ cc_library(
":llvm-derivatives",
":nvvm-derivatives",
":scf-derivatives",
":cf-derivatives",
":memref-derivatives",
":EnzymeOpsIncGen",
":EnzymePassesIncGen",
":EnzymeTypesIncGen",
Expand Down
24 changes: 20 additions & 4 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2820,8 +2820,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers(

if (UA != UseActivity::AllStores) {
if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(a)) {
if (ifaceOp.isArgInactive(parent))
return true;
bool allInactive = true;
for (OpOperand &operand : a->getOpOperands()) {
if (parent == operand.get() &&
!ifaceOp.isArgInactive(operand.getOperandNumber())) {
allInactive = false;
break;
}
}
if (allInactive)
continue;
}
}

Expand Down Expand Up @@ -3394,8 +3402,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned(
}

if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(a)) {
if (ifaceOp.isArgInactive(val))
return true;
bool allInactive = true;
for (OpOperand &operand : a->getOpOperands()) {
if (operand.get() == val &&
!ifaceOp.isArgInactive(operand.getOperandNumber())) {
allInactive = false;
break;
}
}
if (allInactive)
continue;
}

if (isa<LLVM::ReturnOp>(a)) {
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ def : ControlFlowOp<"affine", "AffineIfOp", [{
}]>;

def : RegionTerminatorOp<"affine", "AffineYieldOp">;
def : ReadOnlyIdentityOp<"affine", "AffineLoadOp", [0]>;
def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>;
5 changes: 5 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ set(LLVM_TARGET_DEFINITIONS CFDerivatives.td)
enzyme_tablegen(CFDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(CFDerivativesIncGen)

set(LLVM_TARGET_DEFINITIONS MemRefDerivatives.td)
enzyme_tablegen(MemRefDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(MemRefDerivativesIncGen)

add_mlir_library(MLIREnzymeImplementations
AffineAutoDiffOpInterfaceImpl.cpp
ArithAutoDiffOpInterfaceImpl.cpp
Expand All @@ -42,6 +46,7 @@ add_mlir_library(MLIREnzymeImplementations
NVVMDerivativesIncGen
SCFDerivativesIncGen
CFDerivativesIncGen
MemRefDerivativesIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ class ControlFlowOp<string dialect_, string opName_, string impl_> {
string impl = impl_;
}

class ReadOnlyIdentityOp<string dialect_, string opName_, list<int> diffargs_> {
string dialect = dialect_;
string opName = opName_;
list<int> diffargs = diffargs_;
}

class BranchOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,37 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst,
return;
}

LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler(
Operation *orig, OpBuilder &builder, MGradientUtils *gutils) {

auto iface = cast<ActivityOpInterface>(orig);

SmallVector<Value> newOperands;
newOperands.reserve(orig->getNumOperands());
for (OpOperand &operand : orig->getOpOperands()) {
if (iface.isArgInactive(operand.getOperandNumber())) {
newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
} else {
if (gutils->isConstantValue(operand.get()))
return failure();
newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
}
}

// Assuming shadows following the originals are fine.
// TODO: consider extending to have a ShadowableTerminatorOpInterface
Operation *primal = gutils->getNewFromOriginal(orig);
Operation *shadow = builder.clone(*primal);
shadow->setOperands(newOperands);
for (auto &&[oval, sval] :
llvm::zip(orig->getResults(), shadow->getResults())) {
gutils->setDiffe(oval, sval, builder);
}
llvm::errs() << " shadow load: " << *shadow << "\n";

return success();
}

void mlir::enzyme::detail::regionTerminatorForwardHandler(
Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) {
auto termIface = cast<RegionBranchTerminatorOpInterface>(origTerminator);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ void branchingForwardHandler(Operation *op, OpBuilder &builder,
void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder,
MGradientUtils *gutils);

// Implements forward-mode differentiation of read-only (including read-none)
// operations which do not perform computatoin
LogicalResult readOnlyIdentityForwardHandler(Operation *op, OpBuilder &builder,
MGradientUtils *gutils);

// Implements the forward autodiff interface for operations whose derivatives
// are can be inferred by analyzing their control flow and differentiating the
// nested operations.
Expand Down Expand Up @@ -80,6 +85,19 @@ class AutoDiffUsingRegionTerminator
return success();
}
};

// Implements the forward autodiff interface for operations which are
// read only and identity like (aka not computing sin of mem read).
template <typename OpTy>
class AutoDiffUsingReadOnlyIdentity
: public AutoDiffOpInterface::ExternalModel<
AutoDiffUsingReadOnlyIdentity<OpTy>, OpTy> {
public:
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
return readOnlyIdentityForwardHandler(op, builder, gutils);
}
};
} // namespace detail

// Registers AutoDiffUsingControlFlow for the given op.
Expand All @@ -99,6 +117,12 @@ void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) {
OpTy::template attachInterface<detail::AutoDiffUsingRegionTerminator<OpTy>>(
context);
}
// Registers AutoDiffUsingRegionTerminator for the given op.
template <typename OpTy>
void registerAutoDiffUsingReadOnlyIdentityInterface(MLIRContext &context) {
OpTy::template attachInterface<detail::AutoDiffUsingReadOnlyIdentity<OpTy>>(
context);
}

// Interface registration hooks for individual upstream dialects.
void registerAffineDialectAutoDiffInterface(DialectRegistry &registry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,7 @@ struct InlineAsmActivityInterface
auto str = asmOp.getAsmString();
return str.contains("cpuid") || str.contains("exit");
}
bool isArgInactive(Operation *op, mlir::Value) const {
return isInactive(op);
}
};

struct LoadOpInterface
: public AutoDiffOpInterface::ExternalModel<LoadOpInterface, LLVM::LoadOp> {
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto loadOp = cast<LLVM::LoadOp>(op);
if (!gutils->isConstantValue(loadOp)) {
Type shadowType =
cast<AutoDiffTypeInterface>(loadOp.getType()).getShadowType();
mlir::Value res = builder.create<LLVM::LoadOp>(
loadOp.getLoc(), shadowType,
gutils->invertPointerM(loadOp.getAddr(), builder));
gutils->setDiffe(loadOp, res, builder);
}
gutils->eraseIfUnused(op);
return success();
}
bool isArgInactive(Operation *op, size_t) const { return isInactive(op); }
};

struct StoreOpInterface
Expand Down Expand Up @@ -115,7 +95,6 @@ class PointerTypeInterface
void mlir::enzyme::registerLLVMDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) {
LLVM::LoadOp::attachInterface<LoadOpInterface>(*context);
LLVM::StoreOp::attachInterface<StoreOpInterface>(*context);
LLVM::AllocaOp::attachInterface<AllocaOpInterface>(*context);
LLVM::LLVMPointerType::attachInterface<PointerTypeInterface>(*context);
Expand Down
8 changes: 8 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,11 @@ def : InactiveOp<"LLVM", "MemsetOp">;
def : InactiveOp<"LLVM", "UndefOp">;
def : InactiveOp<"LLVM", "ConstantOp">;
def : InactiveOp<"LLVM", "UnreachableOp">;


def : ReadOnlyIdentityOp<"LLVM", "LoadOp", [0]>;
def : ReadOnlyIdentityOp<"LLVM", "AddrSpaceCastOp", [0]>;
def : ReadOnlyIdentityOp<"LLVM", "BitCastOp", [0]>;
def : ReadOnlyIdentityOp<"LLVM", "GEPOp", [0]>;
def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>;
def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>;
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,7 @@ using namespace mlir;
using namespace mlir::enzyme;

namespace {
struct LoadOpInterface
: public AutoDiffOpInterface::ExternalModel<LoadOpInterface,
memref::LoadOp> {
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
MGradientUtils *gutils) const {
auto loadOp = cast<memref::LoadOp>(op);
if (!gutils->isConstantValue(loadOp)) {
SmallVector<mlir::Value> inds;
for (auto ind : loadOp.getIndices())
inds.push_back(gutils->getNewFromOriginal(ind));
mlir::Value res = builder.create<memref::LoadOp>(
loadOp.getLoc(), gutils->invertPointerM(loadOp.getMemref(), builder),
inds);
gutils->setDiffe(loadOp, res, builder);
}
gutils->eraseIfUnused(op);
return success();
}
};
#include "Implementations/MemRefDerivatives.inc"

struct StoreOpInterface
: public AutoDiffOpInterface::ExternalModel<StoreOpInterface,
Expand Down Expand Up @@ -295,7 +277,7 @@ class MemRefTypeInterface
void mlir::enzyme::registerMemRefDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) {
memref::LoadOp::attachInterface<LoadOpInterface>(*context);
registerInterfaces(context);
memref::StoreOp::attachInterface<StoreOpInterface>(*context);
memref::AllocOp::attachInterface<AllocOpInterface>(*context);
MemRefType::attachInterface<MemRefTypeInterface>(*context);
Expand Down
13 changes: 13 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
include "Common.td"

def : ReadOnlyIdentityOp<"memref", "LoadOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "CastOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "CollapseShapeOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "ExpandShapeOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "ReinterpretCastOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "ReshapeOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "TransposeOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "ViewOp", [0]>;
def : ReadOnlyIdentityOp<"memref", "SubViewOp", [0]>;

def : InactiveOp<"memref", "DimOp">;
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def ActivityOpInterface
}],
/*retTy=*/"bool",
/*methodName=*/"isArgInactive",
/*args=*/(ins "::mlir::Value":$val)
/*args=*/(ins "size_t":$opidx)
>
];
}
Expand Down
6 changes: 3 additions & 3 deletions enzyme/test/MLIR/ForwardMode/affine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ module {
%mul = arith.mulf %x, %x : f64
memref.store %mul, %mem[%c0] : memref<1xf64>
}
%r = memref.load %mem[%c0] : memref<1xf64>
%r = affine.load %mem[0] : memref<1xf64>
%res = arith.mulf %c2, %r : f64
return %res : f64
}
Expand All @@ -96,8 +96,8 @@ module {
// CHECK: memref.store %[[v6]], %[[alloc]][%[[c0]]] : memref<1xf64>
// CHECK: memref.store %[[v7]], %[[alloc_2]][%[[c0]]] : memref<1xf64>
// CHECK: }
// CHECK: %[[v0:.+]] = memref.load %[[alloc]][%[[c0]]] : memref<1xf64>
// CHECK: %[[v1:.+]] = memref.load %[[alloc_2]][%[[c0]]] : memref<1xf64>
// CHECK: %[[v0:.+]] = affine.load %[[alloc]][0] : memref<1xf64>
// CHECK: %[[v1:.+]] = affine.load %[[alloc_2]][0] : memref<1xf64>
// CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst_0]] : f64
// CHECK: %[[v3:.+]] = arith.mulf %[[cst_0]], %[[v1]] : f64
// CHECK: return %[[v2]] : f64
Expand Down
34 changes: 33 additions & 1 deletion enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,12 +1884,16 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
os << " public ActivityOpInterface::ExternalModel<"
<< opName << "Activity, " << dialect << "::" << opName << "> {\n";
os << " bool isInactive(mlir::Operation*) const { return true; }\n";
os << " bool isArgInactive(mlir::Operation*, mlir::Value) const { "
os << " bool isArgInactive(mlir::Operation*, size_t) const { "
"return true; }\n";
os << "};\n";
}
const auto &cfpatterns =
recordKeeper.getAllDerivedDefinitions("ControlFlowOp");

const auto &ropatterns =
recordKeeper.getAllDerivedDefinitions("ReadOnlyIdentityOp");

for (auto &pattern : cfpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
Expand All @@ -1902,6 +1906,26 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
os << "};\n";
}

for (auto &pattern : ropatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
auto diffargs = pattern->getValueAsListOfInts("diffargs");
os << "struct " << opName << "ROActivity : \n";
os << " public ActivityOpInterface::ExternalModel<" << opName
<< "ROActivity, " << dialect << "::" << opName << "> {\n";
os << " bool isInactive(mlir::Operation* op) const {\n";
os << " for (size_t i=0, len=op->getNumOperands(); i<len; i++)\n";
os << " if (!isArgInactive(op, i)) return false;\n";
os << " return true;\n";
os << " };\n";
os << " bool isArgInactive(mlir::Operation*, size_t idx) const {\n";
for (auto diffarg : diffargs) {
os << " if (idx == " << diffarg << ") return false;\n";
}
os << " return true;\n }\n";
os << "};\n";
}

const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp");

const auto &regtpatterns =
Expand Down Expand Up @@ -1930,6 +1954,14 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
os << " registerAutoDiffUsingControlFlowInterface<" << dialect
<< "::" << opName << ">(*context);\n";
}
for (Record *pattern : ropatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
os << " " << dialect << "::" << opName << "::attachInterface<" << opName
<< "ROActivity>(*context);\n";
os << " registerAutoDiffUsingReadOnlyIdentityInterface<" << dialect
<< "::" << opName << ">(*context);\n";
}
for (Record *pattern : brpatterns) {
auto opName = pattern->getValueAsString("opName");
auto dialect = pattern->getValueAsString("dialect");
Expand Down

0 comments on commit 1e58352

Please sign in to comment.