Skip to content

Commit

Permalink
[WIP] Simplify MLIR (#1646)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 31, 2024
1 parent ff15cd8 commit 53a0bd8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 87 deletions.
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ class MDiffeGradientUtils : public MGradientUtils {
FunctionOpInterface oldFunc_, MTypeAnalysis &TA,
MTypeResults TR, IRMapping &invertedPointers_,
const SmallPtrSetImpl<mlir::Value> &constantvalues_,
const SmallPtrSetImpl<mlir::Value> &returnvals_,
const SmallPtrSetImpl<mlir::Value> &activevals_,
DIFFE_TYPE ActiveReturn,
ArrayRef<DIFFE_TYPE> constant_values,
IRMapping &origToNew_,
std::map<Operation *, Operation *> &origToNewOps_,
DerivativeMode mode, unsigned width, bool omp)
: MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_,
constantvalues_, returnvals_, ActiveReturn,
constantvalues_, activevals_, ActiveReturn,
constant_values, origToNew_, origToNewOps_, mode, width,
omp) {}

Expand Down
64 changes: 5 additions & 59 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_, IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_)
: newFunc(newFunc_), oldFunc(oldFunc_), Logic(Logic), mode(mode_),
originalToNewFn(originalToNewFn_),
originalToNewFnOps(originalToNewFnOps_), TA(TA_), width(width),
ArgDiffeTypes(ArgDiffeTypes_), symbolTable(symbolTable_) {
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, constantvalues_, activevals_,
ReturnActivity, ArgDiffeTypes_, originalToNewFn_,
originalToNewFnOps_, mode_, width, /*omp*/ false),
symbolTable(symbolTable_) {

initInitializationBlock(invertedPointers_, ArgDiffeTypes_);
}
Expand Down Expand Up @@ -136,43 +137,6 @@ Value mlir::enzyme::MGradientUtilsReverse::insertInitShadowedGradient(
return gradient;
}

Value mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal(
const mlir::Value originst) const {
if (!originalToNewFn.contains(originst)) {
llvm::errs() << oldFunc << "\n";
llvm::errs() << newFunc << "\n";
llvm::errs() << originst << "\n";
llvm_unreachable("Could not get new val from original");
}
return originalToNewFn.lookupOrNull(originst);
}

Block *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal(
mlir::Block *originst) const {
if (!originalToNewFn.contains(originst)) {
llvm::errs() << oldFunc << "\n";
llvm::errs() << newFunc << "\n";
llvm::errs() << originst << "\n";
llvm_unreachable("Could not get new blk from original");
}
return originalToNewFn.lookupOrNull(originst);
}

Operation *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal(
Operation *originst) const {
auto found = originalToNewFnOps.find(originst);
if (found == originalToNewFnOps.end()) {
llvm::errs() << oldFunc << "\n";
llvm::errs() << newFunc << "\n";
for (auto &pair : originalToNewFnOps) {
llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n";
}
llvm::errs() << originst << " - " << *originst << "\n";
llvm_unreachable("Could not get new op from original");
}
return found->second;
}

Operation *
mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B,
Operation *op) {
Expand All @@ -182,24 +146,6 @@ mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B,
return B.clone(*op, map);
}

bool mlir::enzyme::MGradientUtilsReverse::isConstantInstruction(
Operation *op) const {
return false;
}

bool mlir::enzyme::MGradientUtilsReverse::isConstantValue(Value v) const {
if (isa<mlir::IntegerType>(v.getType()))
return true;
if (isa<mlir::IndexType>(v.getType()))
return true;

if (matchPattern(v, m_Constant()))
return true;

// TODO
return false;
}

bool mlir::enzyme::MGradientUtilsReverse::requiresShadow(Type t) {
if (auto iface = dyn_cast<AutoDiffTypeInterface>(t)) {
return iface.requiresShadow();
Expand Down
29 changes: 3 additions & 26 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

#include <functional>

#include "GradientUtils.h"

namespace mlir {
namespace enzyme {

class MGradientUtilsReverse {
class MGradientUtilsReverse : public MDiffeGradientUtils {
public:
MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
Expand All @@ -32,13 +34,6 @@ class MGradientUtilsReverse {
DerivativeMode mode_, unsigned width,
SymbolTableCollection &symbolTable_);

// From CacheUtility
FunctionOpInterface newFunc;
FunctionOpInterface oldFunc;

MEnzymeLogic &Logic;
bool AtomicAdd;
DerivativeMode mode;
IRMapping invertedPointersGlobal;
IRMapping invertedPointersShadow;
IRMapping shadowValues;
Expand All @@ -47,26 +42,8 @@ class MGradientUtilsReverse {
IRMapping mapReverseModeBlocks;
DenseMap<Block *, SmallVector<std::pair<Value, Value>>> mapBlockArguments;

IRMapping originalToNewFn;
std::map<Operation *, Operation *> originalToNewFnOps;

MTypeAnalysis &TA;

unsigned width;
ArrayRef<DIFFE_TYPE> ArgDiffeTypes;

SymbolTableCollection &symbolTable;

mlir::Value getNewFromOriginal(const mlir::Value originst) const;
mlir::Block *getNewFromOriginal(mlir::Block *originst) const;
Operation *getNewFromOriginal(Operation *originst) const;

void erase(Operation *op) { op->erase(); }
void eraseIfUnused(Operation *op, bool erase = true, bool check = true) {
// TODO
}
bool isConstantValue(mlir::Value v) const;
bool isConstantInstruction(mlir::Operation *v) const;
bool hasInvertPointer(mlir::Value v);
mlir::Value invertPointerM(mlir::Value v, OpBuilder &builder);
mlir::Value diffe(mlir::Value v, OpBuilder &builder);
Expand Down

0 comments on commit 53a0bd8

Please sign in to comment.