From e199976f07bbbf62377d21bae9802e59eaf75855 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 3 Feb 2023 13:53:01 -0500 Subject: [PATCH] Implement EnzymeRules --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 2 + lib/EnzymeCore/src/rules.jl | 117 +++++++ src/Enzyme.jl | 3 +- src/api.jl | 5 + src/compiler.jl | 561 ++++++++++++++++++++++++++++++- src/compiler/interpreter.jl | 81 +++-- src/compiler/pmap.jl | 3 +- src/compiler/reflection.jl | 1 - test/rrules.jl | 63 ++++ test/rules.jl | 92 +++++ test/runtests.jl | 18 +- 13 files changed, 901 insertions(+), 49 deletions(-) create mode 100644 lib/EnzymeCore/src/rules.jl create mode 100644 test/rrules.jl create mode 100644 test/rules.jl diff --git a/Project.toml b/Project.toml index c3ec4c9c76..888c8d0f29 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] CEnum = "0.4" -EnzymeCore = "0.1" +EnzymeCore = "0.2" Enzyme_jll = "0.0.48" GPUCompiler = "0.16.7, 0.17" LLVM = "4.14" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 6f1a4ce039..7c0470458b 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.1.0" +version = "0.2.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 1b63664b4f..d8b81e0a85 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -126,4 +126,6 @@ struct ForwardMode <: Mode end const Forward = ForwardMode() +include("rules.jl") + end # module EnzymeCore diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl new file mode 100644 index 0000000000..18599c7221 --- /dev/null +++ b/lib/EnzymeCore/src/rules.jl @@ -0,0 +1,117 @@ +module EnzymeRules + +import 
EnzymeCore: Annotation +export Config, ConfigWidth +export needs_primal, needs_shadow, width, overwritten + +import Base: unwrapva, isvarargtype, unwrap_unionall, rewrap_unionall + +""" + forward(func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + +Calculate the forward derivative. The first argument `func` is the callable +to which the rule applies, either wrapped in a [`Const`](@ref), or +a [`Duplicated`](@ref) if it is a closure. +The second argument is the return type annotation, and all other arguments are +the annotated function arguments. +""" +function forward end + +struct Config{NeedsPrimal, NeedsShadow, Width, Overwritten} end +const ConfigWidth{Width} = Config{<:Any,<:Any, Width} + +needs_primal(::Config{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +needs_shadow(::Config{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow +width(::Config{<:Any, <:Any, Width}) where Width = Width +overwritten(::Config{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten + +""" + augmented_primal(::Config, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + +Must return a tuple of length 2. +The first value is the primal value and the second is the tape. If no tape is +required, return `(val, nothing)`. +""" +function augmented_primal end + +""" + reverse(::Config, func::Annotation{typeof(f)}, dret::Annotation, tape, args::Annotation...) 
+ +Takes gradient of derivative, activity annotation, and tape +""" +function reverse end + +function _annotate(@nospecialize(T)) + if isvarargtype(T) + VA = T + T = _annotate(VA.T) + if isdefined(VA, :N) + return Vararg{T, VA.N} + else + return Vararg{T} + end + else + return TypeVar(gensym(), Annotation{T}) + end +end + +function _annotate_tt(@nospecialize(TT0)) + TT = Base.unwrap_unionall(TT0) + ft = TT.parameters[1] + tt = map(T->_annotate(Base.rewrap_unionall(T, TT0)), TT.parameters[2:end]) + return ft, tt +end + +function has_frule_from_sig(@nospecialize(TT); world=Base.get_world_counter()) + ft, tt = _annotate_tt(TT) + TT = Tuple{<:Annotation{ft}, Type{<:Annotation}, tt...} + isapplicable(forward, TT; world) +end + +function has_rrule_from_sig(@nospecialize(TT); world=Base.get_world_counter()) + ft, tt = _annotate_tt(TT) + TT = Tuple{<:Config, <:Annotation{ft}, <:Annotation, <:Any, tt...} + isapplicable(reverse, TT; world) +end + +function has_frule(@nospecialize(f); world=Base.get_world_counter()) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, Vararg{<:Annotation}} + isapplicable(forward, TT; world) +end + +# Do we need this one? +function has_frule(@nospecialize(f), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter()) + TT = Base.unwrap_unionall(TT) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, TT.parameters...} + isapplicable(forward, TT; world) +end + +# Do we need this one? +function has_frule(@nospecialize(f), @nospecialize(RT::Type); world=Base.get_world_counter()) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, Vararg{<:Annotation}} + isapplicable(forward, TT; world) +end + +# Do we need this one? 
+function has_frule(@nospecialize(f), @nospecialize(RT::Type), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter()) + TT = Base.unwrap_unionall(TT) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, TT.parameters...} + isapplicable(forward, TT; world) +end + +# Base.hasmethod is a precise match we want the broader query. +function isapplicable(@nospecialize(f), @nospecialize(TT); world=Base.get_world_counter()) + tt = Base.to_tuple_type(TT) + sig = Base.signature_type(f, tt) + return !isempty(Base._methods_by_ftype(sig, -1, world)) # TODO cheaper way of querying? +end + +function issupported() + @static if VERSION < v"1.7.0" + return false + else + return true + end +end + +end # EnzymeRules diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 3f16a0ae79..7d9b022c72 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -15,6 +15,8 @@ export markType, batch_size, onehot, chunkedonehot using LinearAlgebra import EnzymeCore: ReverseMode, ForwardMode, Annotation, Mode +import EnzymeCore: EnzymeRules + # Independent code, must be loaded before "compiler.jl" include("pmap.jl") @@ -79,7 +81,6 @@ end end end - include("logic.jl") include("typeanalysis.jl") include("typetree.jl") diff --git a/src/api.jl b/src/api.jl index 97ff73fbcf..118ddd3c62 100644 --- a/src/api.jl +++ b/src/api.jl @@ -201,6 +201,7 @@ EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHan EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T) +EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) 
EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig) @@ -216,7 +217,11 @@ EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocati EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val) + +EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) +EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) + EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}), gutils, orig, needsPrimalP, needsShadowP) EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme), diff --git a/src/compiler.jl b/src/compiler.jl index e46c1c23a4..e341811b2b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5,9 +5,10 @@ import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, Bat Annotation, guess_activity, eltype, API, TypeTree, typetree, only!, shift!, data0!, merge!, TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype - using Enzyme +import 
EnzymeCore: EnzymeRules + using LLVM, GPUCompiler, Libdl import Enzyme_jll @@ -2186,7 +2187,7 @@ else const ctxToThreadSafe = Dict{LLVM.Context, LLVM.ThreadSafeContext}() end -function nested_codegen!(mod::LLVM.Module, f, tt) +function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt) # TODO: Put a cache here index on `mod` and f->tt ctx = LLVM.context(mod) @@ -2203,15 +2204,17 @@ end # - When OrcV2 only use a MaterializationUnit to avoid mutation of the module here target = DefaultCompilerTarget() - params = PrimalCompilerParams() + params = PrimalCompilerParams(mode) job = CompilerJob(target, funcspec, params) - otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, validate=false, ctx) + # TODO + parent_job = nothing + otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, cleanup=false, validate=false, parent_job=parent_job, ctx) entry = name(meta.entry) # Apply first stage of optimization's so that this module is at the same stage as `mod` optimize!(otherMod, JIT.get_tm()) - + # 4) Link the corresponding module LLVM.link!(mod, otherMod) @@ -2472,7 +2475,8 @@ function threadsfor_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRe else tt = Tuple{funcT, Core.Ptr{Cvoid}, dfuncT, Type{thunkTy}, Bool} end - entry = nested_codegen!(mod, runtime_pfor_fwd, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt) permit_inlining!(entry) push!(function_attributes(entry), EnumAttribute("alwaysinline"; ctx)) @@ -2513,7 +2517,8 @@ function threadsfor_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu else tt = Tuple{funcT, Core.Ptr{Cvoid}, dfuncT, Type{thunkTy}, Val{any_jltypes(GetTapeType(thunkTy))}, Bool} end - entry = nested_codegen!(mod, runtime_pfor_augfwd, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt) permit_inlining!(entry) push!(function_attributes(entry), EnumAttribute("alwaysinline"; 
ctx)) @@ -2564,7 +2569,8 @@ function threadsfor_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRe else tt = Tuple{funcT, Core.Ptr{Cvoid}, dfuncT, Type{thunkTy}, Val{any_jltypes(GetTapeType(thunkTy))}, STT, Bool} end - entry = nested_codegen!(mod, runtime_pfor_rev, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt) permit_inlining!(entry) push!(function_attributes(entry), EnumAttribute("alwaysinline"; ctx)) @@ -2804,6 +2810,500 @@ function wait_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gut return nothing end +function enzyme_custom_setup_args(B, orig, gutils, mi) + ctx = LLVM.context(orig) + ops = collect(operands(orig)) + called = ops[end] + ops = ops[1:end-1] + width = API.EnzymeGradientUtilsGetWidth(gutils) + + args = LLVM.Value[] + activity = Type[] + overwritten = Bool[] + + actives = LLVM.Value[] + + uncacheable = Vector{UInt8}(undef, length(ops)) + API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) + + sret = false + returnRoots = false + + jlargs = classify_arguments(mi.specTypes, eltype(llvmtype(called)), sret, returnRoots) + + op_idx = 1 + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + + for arg in jlargs + if arg.cc == GPUCompiler.GHOST + push!(activity, Const{arg.typ}) + push!(overwritten, false) + continue + end + + op = ops[op_idx] + push!(overwritten, uncacheable[op_idx] != 0) + op_idx+=1 + + val = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, op)) + + activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, #=isforeign=#false) + + # TODO type analysis deduce if duplicated vs active + if activep == API.DFT_CONSTANT + Ty = Const{arg.typ} + + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + + 
push!(activity, Ty) + + elseif activep == API.DFT_OUT_DIFF + Ty = Active{arg.typ} + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + + push!(activity, Ty) + push!(actives, op) + else + ival = LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, op, B)) + if width == 1 + if activep == API.DFT_DUP_ARG + Ty = Duplicated{arg.typ} + else + @assert activep == API.DFT_DUP_NONEED + Ty = DuplicatedNoNeed{arg.typ} + end + else + if activep == API.DFT_DUP_ARG + Ty = BatchDuplicated{arg.typ, Int64(width)} + else + @assert activep == API.DFT_DUP_NONEED + Ty = BatchDuplicatedNoNeed{arg.typ, Int64(width)} + end + end + + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + arval = insert_value!(B, arval, ival, 1) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + push!(activity, Ty) + end + + end + + @assert op_idx-1 == length(ops) + + return args, activity, (overwritten...,), actives +end + +function enzyme_custom_setup_ret(gutils, orig, mi, job) + width = API.EnzymeGradientUtilsGetWidth(gutils) + interp = GPUCompiler.get_interpreter(job) + RealRt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) + needsPrimal = needsPrimalP[] != 0 + + if !needsPrimal && activep == API.DFT_DUP_ARG + activep = API.DFT_DUP_NONEED + end + + if activep == API.DFT_CONSTANT + RT = Const{RealRt} + + elseif activep == API.DFT_OUT_DIFF + RT = Active{RealRt} + + elseif activep == API.DFT_DUP_ARG + if width == 1 + RT = Duplicated{RealRt} + else + RT = BatchDuplicated{RealRt, Int64(width)} + end + else + @assert activep == API.DFT_DUP_NONEED + if width == 1 + RT = DuplicatedNoNeed{RealRt} + else + RT = 
BatchDuplicatedNoNeed{RealRt, Int64(width)} + end + end + return RealRt, RT, needsPrimal, needsShadowP[] != 0 +end + +function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + orig = LLVM.Instruction(OrigCI) + ctx = LLVM.context(orig) + B = LLVM.Builder(B) + + width = API.EnzymeGradientUtilsGetWidth(gutils) + + if shadowR != C_NULL + unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig)))).ref) + end + + # TODO: don't inject the code multiple times for multiple calls + + # 1) extract out the MI from attributes + mi, job = enzyme_custom_extract_mi(orig) + + # 2) Create activity, and annotate function spec + args, activity, overwritten, actives = enzyme_custom_setup_args(B, orig, gutils, mi) + RealRt, RT, needsPrimal, needsShadow = enzyme_custom_setup_ret(gutils, orig, mi, job) + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + mode = API.EnzymeGradientUtilsGetMode(gutils) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + + tt = copy(activity) + insert!(tt, 2, Type{RT}) + TT = Tuple{tt...} + + # TODO get world + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + world = enzyme_extract_world(fn) + if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world) + @safe_debug "Applying custom forward rule" TT + llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT) + else + @safe_debug "No custom forward rule is applicable for" TT + emit_error(B, orig, "Enzyme: No custom rule was appliable for " * string(TT)) + return nothing + end + + + sret = nothing + if !isempty(parameters(llvmf)) && any(map(k->kind(k)==kind(EnumAttribute("sret"; ctx)), collect(parameter_attributes(llvmf, 1)))) + sret = alloca!(alloctx, eltype(llvmtype(parameters(llvmf)[1]))) + pushfirst!(args, sret) + end + + for i in eachindex(args) + 
party = llvmtype(parameters(llvmf)[i]) + if llvmtype(args[i]) == party + continue + end + if isa(party, LLVM.PointerType) && LLVM.addrspace(party) != 0 # addrspace is only defined for pointer types + args[i] = addrspacecast!(B, args[i], party) + else + GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf + return + end + end + + res = LLVM.call!(B, llvmf, args) + API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, res, orig) + + hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn"; ctx)), collect(function_attributes(llvmf)))) + + if hasNoRet + return nothing + end + + if sret !== nothing + res = load!(B, sret) + end + + shadowV = C_NULL + normalV = C_NULL + + if RT <: Const + if needsPrimal + normalV = res.ref + end + else + if !needsPrimal + shadowV = res.ref + else + if !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + emit_error(B, orig, "Enzyme: incorrect return type of forward custom rule - "*(string(RT))*" "*string(activity)) # pass `orig` like the other emit_error calls + return + end + normalV = extract_value!(B, res, 0).ref + shadowV = extract_value!(B, res, 1).ref + end + end + + if shadowR != C_NULL + unsafe_store!(shadowR, shadowV) + end + + # Delete the primal code + if needsPrimal + unsafe_store!(normalR, normalV) + else + LLVM.API.LLVMInstructionEraseFromParent(LLVM.Instruction(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig))) + end + + return nothing +end + +function enzyme_custom_common_rev(forward::Bool, B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR, shadowR, tape)::LLVM.API.LLVMValueRef + + orig = LLVM.Instruction(OrigCI) + ctx = LLVM.context(orig) + B = LLVM.Builder(B) + + width = API.EnzymeGradientUtilsGetWidth(gutils) + + shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig))) + if shadowR != C_NULL + unsafe_store!(shadowR,UndefValue(shadowType).ref) + end + + # TODO: don't inject the code multiple times for multiple calls + + # 1) extract out the MI from attributes + mi, job = enzyme_custom_extract_mi(orig) + + # 2) Create 
activity, and annotate function spec + args, activity, overwritten, actives = enzyme_custom_setup_args(B, orig, gutils, mi) + RealRt, RT, needsPrimal, needsShadow = enzyme_custom_setup_ret(gutils, orig, mi, job) + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + + curent_bb = position(B) + fn = LLVM.parent(curent_bb) + world = enzyme_extract_world(fn) + + C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadow), Int(width), overwritten} + augprimal_tt = copy(activity) + insert!(augprimal_tt, 2, Type{RT}) + pushfirst!(augprimal_tt, C) + augprimal_TT = Tuple{augprimal_tt...} + + rev_TT = nothing + aug_RT = Core.Compiler.return_type(EnzymeRules.augmented_primal, augprimal_TT, world) + + if aug_RT === Union{} || + aug_RT isa Union || + !(aug_RT <: Tuple) || + length(aug_RT.parameters) != 2 + + @safe_debug "Custom augmented_primal rule has invalid return type" TT=augprimal_TT RT=aug_RT + emit_error(B, orig, "Enzyme: Custom rule " * string(augprimal_TT) * " resulted in " * string(aug_RT)) + return C_NULL + end + + TapeT = aug_RT.parameters[2] + + mode = API.EnzymeGradientUtilsGetMode(gutils) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + + if forward + if EnzymeRules.isapplicable(EnzymeRules.augmented_primal, augprimal_TT; world) + @safe_debug "Applying custom augmented_primal rule" TT=augprimal_TT + llvmf = nested_codegen!(mode, mod, EnzymeRules.augmented_primal, augprimal_TT) + else + @safe_debug "No custom augmented_primal rule is applicable for" augprimal_TT + emit_error(B, orig, "Enzyme: No custom rule was appliable for " * string(augprimal_TT)) + return C_NULL + end + else + tt = copy(activity) + insert!(tt, 2, RT) + insert!(tt, 3, TapeT) + pushfirst!(tt, C) + TT = Tuple{tt...} + rev_TT = TT + + if EnzymeRules.isapplicable(EnzymeRules.reverse, TT; world) + @safe_debug "Applying custom reverse rule" TT + llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, TT) + else + @safe_debug "No 
custom reverse rule is applicable for" TT + emit_error(B, orig, "Enzyme: No custom rule was appliable for " * string(TT)) + return C_NULL + end + end + + needsTape = !GPUCompiler.isghosttype(TapeT) && !Core.Compiler.isconstType(TapeT) + + tapeV = C_NULL + if forward && needsTape + tapeV = LLVM.UndefValue(convert(LLVMType, TapeT; ctx)).ref + end + + # if !forward + # argTys = copy(activity) + # if RT <: Active + # if width == 1 + # push!(argTys, RealRt) + # else + # push!(argTys, NTuple{RealRt, (Int64)width}) + # end + # end + # push!(argTys, tapeType) + # llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}) + # end + + sret = nothing + if !isempty(parameters(llvmf)) && any(map(k->kind(k)==kind(EnumAttribute("sret"; ctx)), collect(parameter_attributes(llvmf, 1)))) + sret = alloca!(alloctx, eltype(llvmtype(parameters(llvmf)[1]))) + pushfirst!(args, sret) + end + + if !forward + if RT <: Active + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + + llty = convert(LLVMType, RT; ctx) + + arval = LLVM.UndefValue(llty) + val = LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B)) + arval = insert_value!(B, arval, val, 0) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + + pushfirst!(args, al) + end + if needsTape + @assert tape != C_NULL + pushfirst!(args, LLVM.Value(tape)) + end + end + + for i in 1:length(args) + party = llvmtype(parameters(llvmf)[i]) + if llvmtype(args[i]) == party + continue + end + if isa(party, LLVM.PointerType) && LLVM.addrspace(party) != 0 + args[i] = addrspacecast!(B, args[i], party) + else + GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf, augprimal_TT, rev_TT + return tapeV + end + end + + res = LLVM.call!(B, llvmf, args) + API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, res, orig) + + hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn"; ctx)), collect(function_attributes(llvmf)))) + + if hasNoRet + 
return tapeV + end + + if sret !== nothing + res = load!(B, sret) + end + + shadowV = C_NULL + normalV = C_NULL + + + if forward + if !needsPrimal && !needsShadow && !needsTape + else + if !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + emit_error(B, orig, "Enzyme: incorrect return type of augmented forward custom rule - "*(string(RT))*" "*string(activity)) # pass `orig` like the other emit_error calls + return tapeV + end + idx = 0 + if needsPrimal + normalV = extract_value!(B, res, idx) + if llvmtype(normalV) != llvmtype(orig) + GPUCompiler.@safe_error "Primal calling convention mismatch found ", normalV, " wanted ", llvmtype(orig) + return tapeV + end + normalV = normalV.ref + idx+=1 + end + if needsShadow + shadowV = extract_value!(B, res, idx) # keep the LLVM.Value for the type check; unwrapped to .ref below + if llvmtype(shadowV) != shadowType + GPUCompiler.@safe_error "Shadow calling convention mismatch found ", shadowV, " wanted ",shadowType + return tapeV + end + shadowV = shadowV.ref + idx+=1 + end + if needsTape + tapeV = extract_value!(B, res, idx).ref + idx+=1 + end + end + else + if length(actives) >= 1 && !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found return ", res + return tapeV + end + + idx = 0 + for v in actives + ext = extract_value!(B, res, idx) + shadowVType = LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(v))) + if llvmtype(ext) != shadowVType # compare against this argument's shadow type, not the return's + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found ", ext, " wanted ",shadowVType + return tapeV + end + Typ = C_NULL + API.EnzymeGradientUtilsAddToDiffe(gutils, v, ext, B, Typ) + idx+=1 + end + end + + if forward + if shadowR != C_NULL + unsafe_store!(shadowR, shadowV) + end + + # Delete the primal code + if needsPrimal + unsafe_store!(normalR, normalV) + else + LLVM.API.LLVMInstructionEraseFromParent(LLVM.Instruction(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig))) + end + end + + return tapeV +end + + +function 
enzyme_custom_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + tape = enzyme_custom_common_rev(#=forward=#true, B, OrigCI, gutils, normalR, shadowR, #=tape=#nothing) + if tape != C_NULL + unsafe_store!(tapeR, tape) + end + return nothing +end + + +function enzyme_custom_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid + enzyme_custom_common_rev(#=forward=#false, B, OrigCI, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) + return nothing +end + function arraycopy_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid orig = LLVM.Instruction(OrigCI) @@ -4544,6 +5044,12 @@ function __init__() @cfunction(enq_work_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)), @cfunction(enq_work_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})) ) + register_handler!( + ("enzyme_custom",), + @cfunction(enzyme_custom_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), + @cfunction(enzyme_custom_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)), + @cfunction(enzyme_custom_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})) + ) register_handler!( ("jl_wait",), @cfunction(wait_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, 
Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), @@ -4716,6 +5222,7 @@ struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams end struct PrimalCompilerParams <: AbstractEnzymeCompilerParams + mode::API.CDerivativeMode end DefaultCompilerTarget(;kwargs...) = GPUCompiler.NativeCompilerTarget(;jlruntime=true, kwargs...) @@ -4735,7 +5242,7 @@ GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme" # provide a specific interpreter to use. GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpeter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.source.world) + Interpreter.EnzymeInterpreter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.source.world, job.params.mode) include("compiler/utils.jl") include("compiler/passes.jl") @@ -5039,11 +5546,11 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT return UInt8(false) end -function enzyme_extract_world(fn::LLVM.Function) +function enzyme_extract_world(fn::LLVM.Function)::UInt64 for fattr in collect(function_attributes(fn)) if isa(fattr, LLVM.StringAttribute) if kind(fattr) == "enzymejl_world" - return parse(Int, LLVM.value(fattr)) + return parse(UInt64, LLVM.value(fattr)) end end end @@ -5733,7 +6240,7 @@ function create_abi_wrapper(enzymefn::LLVM.Function, F, argtypes, rettype, actua #cf = add_one_in_place_gen(eltype(rettype)) #cf = inttoptr!(builder, cf, LLVM.PointerType(LLVM.FunctionType(T_void, [convert(LLVMType, eltype(rettype); ctx)]))) - cf = nested_codegen!(mod, add_one_in_place, Tuple{Any}) + cf = nested_codegen!(Mode, mod, add_one_in_place, Tuple{Any}) push!(function_attributes(cf), EnumAttribute("alwaysinline", 0; ctx)) permit_inlining!(cf) for shadowv in shadows @@ -6271,7 +6778,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if parent_job === nothing primal_target = DefaultCompilerTarget() - primal_params 
= PrimalCompilerParams() + primal_params = PrimalCompilerParams(mode) primal_job = CompilerJob(primal_target, primal, primal_params) else primal_job = similar(parent_job, job.source) @@ -6368,11 +6875,28 @@ end foundTys = Dict{String, Tuple{LLVM.FunctionType, Core.MethodInstance}}() jobref = Ref(job) + world = job.source.world actualRetType = nothing + customDerivativeNames = String[] for (mi, k) in meta.compiled k_name = GPUCompiler.safe_name(k.specfunc) - haskey(functions(mod), k_name) || continue + has_custom_rule = false + if mode == API.DEM_ForwardMode + has_custom_rule = EnzymeRules.has_frule_from_sig(mi.specTypes; world) + if has_custom_rule + @safe_debug "Found frule for" mi.specTypes + end + else + has_custom_rule = EnzymeRules.has_rrule_from_sig(mi.specTypes; world) + if has_custom_rule + @safe_debug "Found rrule for" mi.specTypes + end + end + + if !(haskey(functions(mod), k_name) || has_custom_rule) + continue + end llvmfn = functions(mod)[k_name] if llvmfn == primalf @@ -6393,6 +6917,7 @@ end push!(attributes, a) end push!(attributes, StringAttribute("enzymejl_mi", string(convert(Int, pointer_from_objref(mi))); ctx)) + push!(attributes, StringAttribute("enzymejl_job", string(convert(Int, pointer_from_objref(jobref))); ctx)) push!(attributes, StringAttribute("enzyme_math", name; ctx)) push!(attributes, EnumAttribute("noinline", 0; ctx)) must_wrap |= llvmfn == primalf @@ -6400,10 +6925,14 @@ end end foundTys[k_name] = (eltype(llvmtype(llvmfn)), mi) + if has_custom_rule + handleCustom("enzyme_custom") + continue + end Base.isbindingresolved(jlmod, name) && isdefined(jlmod, name) || continue func = getfield(jlmod, name) - + sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals if func == Base.eps || func == Base.nextfloat || func == Base.prevfloat handleCustom("jl_inactive_inout", [StringAttribute("enzyme_inactive"; ctx), @@ -6535,7 +7064,7 @@ end end primalf = wrapper_f end - + source_sig = GPUCompiler.typed_signature(job)::Type primalf, 
returnRoots = lower_convention(source_sig, mod, primalf, actualRetType) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 5436d5cc70..a6347489f3 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,10 +1,12 @@ module Interpreter using Random +import Enzyme: API using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance using GPUCompiler: CodeCache, WorldView import ..Enzyme +import ..EnzymeRules -struct EnzymeInterpeter <: AbstractInterpreter +struct EnzymeInterpreter <: AbstractInterpreter global_cache::CodeCache method_table::Union{Nothing,Core.MethodTable} @@ -17,7 +19,9 @@ struct EnzymeInterpeter <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams - function EnzymeInterpeter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt) + mode::API.CDerivativeMode + + function EnzymeInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) @assert world <= Base.get_world_counter() return new( @@ -34,42 +38,43 @@ struct EnzymeInterpeter <: AbstractInterpreter InferenceParams(unoptimize_throw_blocks=false), VERSION >= v"1.8.0-DEV.486" ? 
OptimizationParams() : OptimizationParams(unoptimize_throw_blocks=false), + mode ) end end -Core.Compiler.InferenceParams(interp::EnzymeInterpeter) = interp.inf_params -Core.Compiler.OptimizationParams(interp::EnzymeInterpeter) = interp.opt_params -Core.Compiler.get_world_counter(interp::EnzymeInterpeter) = interp.world -Core.Compiler.get_inference_cache(interp::EnzymeInterpeter) = interp.local_cache -Core.Compiler.code_cache(interp::EnzymeInterpeter) = WorldView(interp.global_cache, interp.world) +Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params +Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params +Core.Compiler.get_world_counter(interp::EnzymeInterpreter) = interp.world +Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cache +Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.global_cache, interp.world) # No need to do any locking since we're not putting our results into the runtime cache -Core.Compiler.lock_mi_inference(interp::EnzymeInterpeter, mi::MethodInstance) = nothing -Core.Compiler.unlock_mi_inference(interp::EnzymeInterpeter, mi::MethodInstance) = nothing +Core.Compiler.lock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing +Core.Compiler.unlock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing -function Core.Compiler.add_remark!(interp::EnzymeInterpeter, sv::InferenceState, msg) +function Core.Compiler.add_remark!(interp::EnzymeInterpreter, sv::InferenceState, msg) end -Core.Compiler.may_optimize(interp::EnzymeInterpeter) = true -Core.Compiler.may_compress(interp::EnzymeInterpeter) = true +Core.Compiler.may_optimize(interp::EnzymeInterpreter) = true +Core.Compiler.may_compress(interp::EnzymeInterpreter) = true # From @aviatesk: # `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded, # but as far as I understand Enzyme wants "always inlining, except special cased 
functions", # so I guess we really don't want to discard sources? -Core.Compiler.may_discard_trees(interp::EnzymeInterpeter) = false +Core.Compiler.may_discard_trees(interp::EnzymeInterpreter) = false if VERSION >= v"1.7.0-DEV.577" -Core.Compiler.verbose_stmt_info(interp::EnzymeInterpeter) = false +Core.Compiler.verbose_stmt_info(interp::EnzymeInterpreter) = false end if isdefined(Base.Experimental, Symbol("@overlay")) -Core.Compiler.method_table(interp::EnzymeInterpeter, sv::InferenceState) = +Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = Core.Compiler.OverlayMethodTable(interp.world, interp.method_table) else # On 1.6- CUDA.jl will poison the method table at the end of the world # using GPUCompiler: WorldOverlayMethodTable -# Core.Compiler.method_table(interp::EnzymeInterpeter, sv::InferenceState) = +# Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) = # WorldOverlayMethodTable(interp.world) end @@ -143,12 +148,21 @@ end @static if VERSION ≥ v"1.9.0-DEV.1535" import Core.Compiler: CallInfo -function Core.Compiler.inlining_policy(interp::EnzymeInterpeter, +function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, @nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) if is_primitive_func(mi.specTypes) return nothing end + if interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + else + if EnzymeRules.has_rrule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + end return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, src::Any, info::CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) @@ -157,12 +171,21 @@ end # https://github.com/JuliaLang/julia/pull/41328 elseif isdefined(Core.Compiler, :is_stmt_inline) -function Core.Compiler.inlining_policy(interp::EnzymeInterpeter, +function 
Core.Compiler.inlining_policy(interp::EnzymeInterpreter, @nospecialize(src), stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) if is_primitive_func(mi.specTypes) return nothing end + if interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + else + if EnzymeRules.has_rrule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + end return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, src::Any, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) @@ -171,13 +194,27 @@ end elseif isdefined(Core.Compiler, :inlining_policy) import Core.Compiler: InliningTodo, InliningState -enzyme_inlining_policy(@nospecialize(src)) = Core.Compiler.default_inlining_policy(src) -Core.Compiler.inlining_policy(::EnzymeInterpeter) = enzyme_inlining_policy -function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, T, <:typeof(enzyme_inlining_policy)}) where {S<:Union{Nothing, Core.Compiler.EdgeTracker}, T} +struct EnzymeInliningPolicy + interp::EnzymeInterpreter +end +(::EnzymeInliningPolicy)(@nospecialize(src)) = Core.Compiler.default_inlining_policy(src) +Core.Compiler.inlining_policy(interp::EnzymeInterpreter) = EnzymeInliningPolicy(interp) + +function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, T, <:EnzymeInliningPolicy}) where {S<:Union{Nothing, Core.Compiler.EdgeTracker}, T} mi = todo.mi - if is_primitive_func(mi.specTypes) + if is_primitive_func(mi.specTypes) return Core.Compiler.compileable_specialization(state.et, todo.spec.match) end + interp = state.policy.interp + if interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(mi.specTypes; world = interp.world) + return Core.Compiler.compileable_specialization(state.et, todo.spec.match) + end + else + if EnzymeRules.has_rrule_from_sig(mi.specTypes; world = interp.world) + return Core.Compiler.compileable_specialization(state.et, 
todo.spec.match) + end + end return Base.@invoke Core.Compiler.resolve_todo( todo::InliningTodo, state::InliningState) diff --git a/src/compiler/pmap.jl b/src/compiler/pmap.jl index 501c52e0ed..e33fc54d07 100644 --- a/src/compiler/pmap.jl +++ b/src/compiler/pmap.jl @@ -184,7 +184,8 @@ end splat, _ = julia_activity(mi.specTypes.parameters, (mode != API.DEM_ReverseModeGradient) ? [Type{thunkTy}, Val{any_jltypes(TapeType)}, Int, funcT, funcT] : [Type{thunkTy}, Val{any_jltypes(TapeType)}, Int, STT, funcT, funcT], ops, gutils) tt = Tuple{splat...} - entry = nested_codegen!(mod, runtime_fn, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_fn, tt) # 5) Call the function B = LLVM.Builder(B) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 8fa8e04425..1c71ccf699 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -11,7 +11,6 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); return Compiler.CompilerJob(target, primal, params) end - function reflect(@nospecialize(func), @nospecialize(A), @nospecialize(types); optimize::Bool=true, second_stage::Bool=true, ctx=nothing, kwargs...) 
diff --git a/test/rrules.jl b/test/rrules.jl new file mode 100644 index 0000000000..2476445ff5 --- /dev/null +++ b/test/rrules.jl @@ -0,0 +1,63 @@ +module ReverseRules + +using Enzyme +using Enzyme: EnzymeRules +using Test + +f(x) = x^2 + +function f_ip(x) + x[1] *= x[1] + return nothing +end + +import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule, has_rrule_from_sig +using .EnzymeRules + +function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active) + if needs_primal(config) + return (func.val(x.val), nothing) + else + return (nothing, nothing) + end +end + +function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active) + if needs_primal(config) + return (10+2*x.val*dret.val,) + else + return (100+2*x.val*dret.val,) + end +end + +function augmented_primal(::Config{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) + v = x.val[1] + x.val[1] *= v + return (nothing, v) +end + +function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Const, tape, x::Duplicated) + x.dval[1] = 100 + x.dval[1] * tape + return () +end + +@testset "has_rrule" begin + @test has_rrule_from_sig(Base.signature_type(f, Tuple{Float64})) + @test has_rrule_from_sig(Base.signature_type(f_ip, Tuple{Vector{Float64}})) +end + + +@testset "Custom Reverse Rules" begin + @test Enzyme.autodiff(Enzyme.Reverse, f, Active(2.0))[1][1] ≈ 104.0 + @test Enzyme.autodiff(Enzyme.Reverse, x->f(x)^2, Active(2.0))[1][1] ≈ 42.0 + + x = [2.0] + dx = [1.0] + + Enzyme.autodiff(Enzyme.Reverse, f_ip, Duplicated(x, dx)) + + @test x ≈ [4.0] + @test dx ≈ [102.0] +end + +end # ReverseRules diff --git a/test/rules.jl b/test/rules.jl new file mode 100644 index 0000000000..6f22e3ff5c --- /dev/null +++ b/test/rules.jl @@ -0,0 +1,92 @@ +module ForwardRules + +using Enzyme +using Enzyme: EnzymeRules +using Test + +import .EnzymeRules: forward, Annotation, has_frule, has_frule_from_sig + +f(x) = x^2 + +function 
f_ip(x) + x[1] *= x[1] + return nothing +end + +function forward(::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) + return 10+2*x.val*x.dval +end + +function forward(::Const{typeof(f)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} + return NTuple{N, T}(1000+2*x.val*dv for dv in x.dval) +end + +function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, x::Duplicated) + return Duplicated(func.val(x.val), 100+2*x.val*x.dval) +end + +function forward(func::Const{typeof(f)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} + return BatchDuplicated(func.val(x.val), NTuple{N, T}(10000+2*x.val*dv for dv in x.dval)) +end + +function forward(::Const{Core.typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) + ld = x.val[1] + x.val[1] *= ld + x.dval[1] *= 2 * ld + 10 + return nothing +end + +@testset "has_frule" begin + @test has_frule_from_sig(Base.signature_type(f, Tuple{Float64})) + @test has_frule_from_sig(Base.signature_type(f_ip, Tuple{Vector{Float64}})) + + @test has_frule(f) + @test has_frule(f, Duplicated) + @test has_frule(f, DuplicatedNoNeed) + @test has_frule(f, BatchDuplicated) + @test has_frule(f, BatchDuplicatedNoNeed) + @test has_frule(f, Duplicated, Tuple{<:Duplicated}) + @test has_frule(f, DuplicatedNoNeed, Tuple{<:Duplicated}) + @test has_frule(f, BatchDuplicated, Tuple{<:BatchDuplicated}) + @test has_frule(f, BatchDuplicatedNoNeed, Tuple{<:BatchDuplicated}) + + @test !has_frule(f, Duplicated, Tuple{<:BatchDuplicated}) + @test !has_frule(f, DuplicatedNoNeed, Tuple{<:BatchDuplicated}) + @test !has_frule(f, BatchDuplicated, Tuple{<:Duplicated}) + @test !has_frule(f, BatchDuplicatedNoNeed, Tuple{<:Duplicated}) + + @test has_frule(f, Tuple{<:Duplicated}) + @test has_frule(f, Tuple{<:BatchDuplicated}) + @test has_frule(f, Tuple{<:Annotation}) + @test has_frule(f, Tuple{<:Annotation{Float64}}) + @test !has_frule(f, Tuple{<:Const}) +end + +@testset "autodiff(Forward, ...) 
custom rules" begin + @test autodiff(Forward, f, Duplicated(2.0, 1.0))[1] ≈ 14.0 + @test autodiff(Forward, x->f(x)^2, Duplicated(2.0, 1.0))[1] ≈ 832.0 + + res = autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + @test res[1] ≈ 1004.0 + @test res[2] ≈ 1012.0 + + res = Enzyme.autodiff(Forward, x->f(x)^2, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + + @test res[1] ≈ 80032.0 + @test res[2] ≈ 80096.0 +end + +@testset "In place" begin + vec = [2.0] + dvec = [1.0] + + Enzyme.autodiff(Forward, f_ip, Duplicated(vec, dvec)) + + @test vec ≈ [4.0] + @test dvec ≈ [14.0] +end + +# TODO: Test error for no frule applicable despite frule on Function. + + +end # module ForwardRules diff --git a/test/runtests.jl b/test/runtests.jl index ef1ec5af3e..4651d2dd03 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,14 +41,20 @@ end include("abi.jl") include("typetree.jl") +if Enzyme.EnzymeRules.issupported() + include("rules.jl") + include("rrules.jl") +end + f0(x) = 1.0 + x - function vrec(start, x) - if start > length(x) - return 1.0 - else - return x[start] * vrec(start+1, x) - end +function vrec(start, x) + if start > length(x) + return 1.0 + else + return x[start] * vrec(start+1, x) end +end + @testset "Internal tests" begin thunk_a = Enzyme.Compiler.thunk(f0, nothing, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1)) thunk_b = Enzyme.Compiler.thunk(f0, nothing, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1))