diff --git a/enzyme/BUILD b/enzyme/BUILD index c9ae1c3cdb71..8954d5433e0a 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -392,7 +392,35 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"], + td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "llvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/LLVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "nvvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/NVVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], deps = [ ":enzyme-tblgen", ], @@ -420,6 +448,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":arith-derivatives", + ":llvm-derivatives", + ":nvvm-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 90b8ab3a11a3..bb8c8f2ccc9c 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -3,7 +3,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" @@ -468,16 +467,15 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // during adjoint generation) if (isa(I)) return true; - - if (auto ifaceOp = dyn_cast(op)) { - if (ifaceOp.isInactive()) { - return true; - } - } + + if (auto ifaceOp = dyn_cast(I)) { + if (ifaceOp.isInactive()) { + return true; + } + } // Branch, unreachable, and previously computed constants are inactive - if (/*|| isa(I)*/ || - ConstantOperations.contains(I)) { + if (/*|| isa(I)*/ ConstantOperations.contains(I)) { return true; } @@ -600,10 +598,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // *I // << "\n"; - if (isa(I)) { - InsertConstantOperation(TR, I); - } - // if (auto II = dyn_cast(I)) { // switch (II->getIntrinsicID()) { // case Intrinsic::nvvm_barrier0: @@ -1127,13 +1121,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (Val.getType().isa()) return true; - // All function pointers are considered active in case an augmented primal - // or reverse is needed - if (Val.getDefiningOp() && - isa(Val.getDefiningOp())) { - return false; - } - /// If we've already shown this value to be inactive if (ConstantValues.find(Val) != ConstantValues.end()) { return true; @@ -1150,10 +1137,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (Operation *definingOp = Val.getDefiningOp()) { if (auto ifaceOp = dyn_cast(definingOp)) { - if (ifaceOp.isInactive()) { - return true; - } - } + if (ifaceOp.isInactive()) { + return true; + } + } } if (auto arg = Val.dyn_cast()) { @@ -1819,17 +1806,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } } - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("exit") || - iasm.getAsmString().contains("cpuid")) - return false; - } - if (isa(op)) { - return true; - } - // If this is a malloc or free, this doesn't impact the activity if (auto CI = dyn_cast(op)) { if (CI->hasAttr("enzyme_inactive")) @@ -2538,17 +2514,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << // "\n"; - // cpuid is explicitly an inactive instruction - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("cpuid")) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from known cpuid instruction - // " - // << *inst << "\n"; - return true; - } - } - if (auto store = dyn_cast(op)) { if (isConstantValue(TR, store.getValue()) || isConstantValue(TR, store.getAddr())) { @@ -2640,13 +2605,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return true; } } - // Intrinsics known always to be inactive - if (isa(op)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - // << *inst << "\n"; - return true; - } if (auto gep = dyn_cast(op)) { // A gep's only args that could make it active is the pointer operand @@ -3355,10 +3313,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( LLVM::SExtOp, LLVM::ZExtOp, LLVM::TruncOp, - LLVM::SIToFPOp, - LLVM::UIToFPOp, - LLVM::FPToSIOp, - LLVM::FPToUIOp, LLVM::FPExtOp, LLVM::FPTruncOp // clang-format on @@ -3502,19 +3456,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // TODO: in MLIR, users are always operations // if (Operation *inst = a) { - auto mayWriteToMemory = [](Operation *op) { - auto iface = dyn_cast(op); - if (!iface) - return true; - - SmallVector effects; - iface.getEffects(effects); - return llvm::any_of( - effects, [](const MemoryEffects::EffectInstance &effect) { - return isa(effect.getEffect()); - }); - }; - if (!mayWriteToMemory(inst) /*|| (isa(inst) && AA.onlyReadsMemory(cast(inst)))*/) { // // if not written to memory and returning a known constant, this diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 19b19ab04c6d..4f922d6eb5d5 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -511,13 +511,13 @@ class DenseForwardActivityAnalysis join(after, before); ChangeResult result = ChangeResult::NoChange; - // If we know this is inactive by definition - if (auto ifaceOp = dyn_cast(op)) { - if (ifaceOp.isInactive()) { - propagateIfChanged(after, result); - return; - } - } + // TODO If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // propagateIfChanged(after, result); + // return; + // } + // } auto memory = dyn_cast(op); // If we can't reason about the memory effects, then conservatively assume @@ -669,12 +669,12 @@ class DenseBackwardActivityAnalysis void visitOperation(Operation *op, const BackwardMemoryActivity &after, BackwardMemoryActivity *before) override { - // If we know this is inactive by definition - if (auto ifaceOp = dyn_cast(op)) { - if (ifaceOp.isInactive()) { - return; - } - } + // TODO: If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // return; + // } + // } // Initialize the return activity of arguments. if (op->hasTrait() && op->getParentOp() == parentOp) { diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index bb713ef61799..fb7f113f16fe 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -1,25 +1,5 @@ -class MLIRDerivative resultOps> { - string dialect = dialect_; - string opName = opName_; - dag PatternToMatch = patternToMatch; - list ArgDerivatives = resultOps; -} +include "Common.td" -class Operation { - bit usesPrimal = usesPrimal_; - bit usesShadow = usesShadow_; - bit usesCustom = usesCustom_; -} - -class DiffeRetIndex indices_> { - list indices = indices_; -} -def DiffeRet : DiffeRetIndex<[-1]>; - -class Inst : Operation { - string name = mnemonic; - string dialect = dialect_; -} class ArithInst : Inst; def AddF : ArithInst<"arith::AddFOp">; @@ -32,9 +12,6 @@ def RemF : ArithInst<"arith::RemFOp">; def CheckedMulF : ArithInst<"arith::MulFOp">; def CheckedDivF : ArithInst<"arith::DivFOp">; -def Op { -} - def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y), [ (DiffeRet), diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index a66b303ca706..521ba76c22bc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -7,9 +7,14 @@ set(LLVM_TARGET_DEFINITIONS LLVMDerivatives.td) enzyme_tablegen(LLVMDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(LLVMDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS NVVMDerivatives.td) +enzyme_tablegen(NVVMDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(NVVMDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations ArithAutoDiffOpInterfaceImpl.cpp LLVMAutoDiffOpInterfaceImpl.cpp + NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -18,6 +23,8 @@ add_mlir_library(MLIREnzymeImplementations DEPENDS MLIRAutoDiffOpInterfaceIncGen ArithDerivativesIncGen + LLVMDerivativesIncGen + NVVMDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td new file mode 100644 index 000000000000..3909405320d1 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -0,0 +1,30 @@ +class InactiveOp { + string dialect = dialect_; + string opName = opName_; +} + +class MLIRDerivative resultOps> { + string dialect = dialect_; + string opName = opName_; + dag PatternToMatch = patternToMatch; + list ArgDerivatives = resultOps; +} + +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} + +class DiffeRetIndex indices_> { + list indices = indices_; +} +def DiffeRet : DiffeRetIndex<[-1]>; + +class Inst : Operation { + string name = mnemonic; + string dialect = dialect_; +} + +def Op { +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 669b028998c6..56af04b30133 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -19,6 +19,7 @@ namespace enzyme { void registerArithDialectAutoDiffInterface(DialectRegistry ®istry); void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry); void registerLLVMDialectAutoDiffInterface(DialectRegistry ®istry); +void registerNVVMDialectAutoDiffInterface(DialectRegistry ®istry); void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry); void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index b39abbbcf50e..079dd1cb64e9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -27,6 +27,19 @@ namespace { } // namespace namespace { +struct InlineAsmActivityInterface + : public ActivityOpInterface::ExternalModel { + bool isInactive(Operation *op) const { + auto asmOp = cast(op); + auto str = asmOp.getAsmString(); + return str.contains("cpuid") || str.contains("exit"); + } + bool isArgInactive(Operation *op, mlir::Value) const { + return isInactive(op); + } +}; + struct LoadOpInterface : public AutoDiffOpInterface::ExternalModel { LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td new file mode 100644 index 000000000000..9e5f28e41665 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -0,0 +1,17 @@ +include "Common.td" + +def : InactiveOp<"LLVM", "SIToFPOp">; +def : InactiveOp<"LLVM", "UIToFPOp">; +def : InactiveOp<"LLVM", "FPToSIOp">; +def : InactiveOp<"LLVM", "FPToUIOp">; +def : InactiveOp<"LLVM", "AssumeOp">; +def : InactiveOp<"LLVM", "StackSaveOp">; +def : InactiveOp<"LLVM", "StackRestoreOp">; +def : InactiveOp<"LLVM", "LifetimeStartOp">; +def : InactiveOp<"LLVM", "LifetimeEndOp">; +def : InactiveOp<"LLVM", "Prefetch">; +def : InactiveOp<"LLVM", "MemsetOp">; + +def : InactiveOp<"LLVM", "UndefOp">; +def : InactiveOp<"LLVM", "ConstantOp">; +def : InactiveOp<"LLVM", "UnreachableOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..4d8116ce011b --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,34 @@ +//===- LLVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/NVVMDerivatives.inc" +} // namespace + +void mlir::enzyme::registerNVVMDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, NVVM::NVVMDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td new file mode 100644 index 000000000000..f34dfb564cbc --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td @@ -0,0 +1,4 @@ +include "Common.td" + +// TODO in reverse replicate in reverse pass +def : InactiveOp<"NVVM", "Barrier0Op">; diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 589ffa610a28..b6bb33c51df9 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -100,6 +100,7 @@ int main(int argc, char **argv) { enzyme::registerArithDialectAutoDiffInterface(registry); enzyme::registerBuiltinDialectAutoDiffInterface(registry); enzyme::registerLLVMDialectAutoDiffInterface(registry); + enzyme::registerNVVMDialectAutoDiffInterface(registry); enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 11384a2284e8..1f4a3dc1b892 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1875,6 +1875,19 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } if (intrinsic == MLIRDerivatives) { + const auto &actpatterns = + recordKeeper.getAllDerivedDefinitions("InactiveOp"); + for (auto &pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "Activity : \n"; + os << " public ActivityOpInterface::ExternalModel<" + << opName << "Activity, " << dialect << "::" << opName << "> {\n"; + os << " bool isInactive(mlir::Operation*) const { return true; }\n"; + os << " bool isArgInactive(mlir::Operation*, mlir::Value) const { " + "return true; }\n"; + os << "};\n"; + } os << "void registerInterfaces(MLIRContext* context) {\n"; for (Record *pattern : patterns) { auto opName = pattern->getValueAsString("opName"); @@ -1884,6 +1897,12 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "RevDerivative>(*context);\n"; } + for (Record *pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "Activity>(*context);\n"; + } os << "}\n"; } }