diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 0b34b51743..5aa795504f 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -126,4 +126,6 @@ struct ForwardMode <: Mode end const Forward = ForwardMode() +include("rules.jl") + end # module EnzymeCore diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl new file mode 100644 index 0000000000..97b53e6caa --- /dev/null +++ b/lib/EnzymeCore/src/rules.jl @@ -0,0 +1,113 @@ +module EnzymeRules + +import EnzymeCore: Annotation +export Config, ConfigWidth +export needs_primal, needs_shadow, width, overwritten + +""" + forward(func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + +Calculate the forward derivative. The first argument `func` is the callable +for which the rule applies to. Either wrapped in a [`Const`](@ref)), or +a [`Duplicated`](@ref) if it is a closure. +The second argument is the return type annotation, and all other arguments are +the annotated function arguments. +""" +function forward end + +struct Config{NeedsPrimal, NeedsShadow, Width, Overwritten} end +const ConfigWidth{Width} = Config{<:Any,<:Any, Width} + +needs_primal(::Config{NeedsPrimal}) where NeedsPrimal = NeedsPrimal +needs_shadow(::Config{<:Any, NeedsShadow}) where NeedsShadow = NeedsShadow +width(::Config{<:Any, <:Any, Width}) where Width = Width +overwritten(::Config{<:Any, <:Any, <:Any, Overwritten}) where Overwritten = Overwritten + +""" + augmented_primal(::Config, func::Annotation{typeof(f)}, RT::Type{<:Annotation}, args::Annotation...) + +Must return a tuple of length 2. +The first-value is primal value and the second is the tape. If no tape is +required return `(val, nothing)`. +""" +function augmented_primal end + +""" + reverse(::Config, func::Annotation{typeof(f)}, dret::Annotation, tape, args::Annotation...) + +Takes gradient of derivative, activity annotation, and tape +""" +function reverse end + +_annotate(T::DataType) = TypeVar(gensym(), Annotation{T}) +_annotate(::Type{T}) where T = TypeVar(gensym(), Annotation{T}) +function _annotate(VA::Core.TypeofVararg) + T = _annotate(VA.T) + if isdefined(VA, :N) + return Vararg{T, VA.N} + else + return Vararg{T} + end +end + +function has_frule_from_sig(@nospecialize(TT); world=Base.get_world_counter()) + TT = Base.unwrap_unionall(TT) + ft = TT.parameters[1] + tt = map(_annotate, TT.parameters[2:end]) + TT = Tuple{<:Annotation{ft}, Type{<:Annotation}, tt...} + isapplicable(forward, TT; world) +end + +function has_rrule_from_sig(@nospecialize(TT); world=Base.get_world_counter()) + TT = Base.unwrap_unionall(TT) + ft = TT.parameters[1] + tt = map(_annotate, TT.parameters[2:end]) + TT = Tuple{<:Config, <:Annotation{ft}, <:Annotation, <:Any, tt...} + isapplicable(reverse, TT; world) +end + +function has_frule(@nospecialize(f); world=Base.get_world_counter()) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, Vararg{<:Annotation}} + isapplicable(forward, TT; world) +end + +# Do we need this one? +function has_frule(@nospecialize(f), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter()) + TT = Base.unwrap_unionall(TT) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, TT.parameters...} + isapplicable(forward, TT; world) +end + +# Do we need this one? +function has_frule(@nospecialize(f), @nospecialize(RT::Type); world=Base.get_world_counter()) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, Vararg{<:Annotation}} + isapplicable(forward, TT; world) +end + +# Do we need this one? +function has_frule(@nospecialize(f), @nospecialize(RT::Type), @nospecialize(TT::Type{<:Tuple}); world=Base.get_world_counter()) + TT = Base.unwrap_unionall(TT) + TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, TT.parameters...} + isapplicable(forward, TT; world) +end + +# Base.hasmethod is a precise match we want the broader query. +function isapplicable(@nospecialize(f), @nospecialize(TT); world=Base.get_world_counter()) + tt = Base.to_tuple_type(TT) + sig = Base.signature_type(f, tt) + return !isempty(Base._methods_by_ftype(sig, -1, world)) # TODO cheaper way of querying? +end + +function has_rrule(@nospecialize(TT), world=Base.get_world_counter()) + return false +end + +function issupported() + @static if VERSION < v"1.7.0" + return false + else + return true + end +end + +end # EnzymeRules diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 1662a243bc..f4f7b36232 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -15,6 +15,8 @@ export markType, batch_size, onehot, chunkedonehot using LinearAlgebra import EnzymeCore: ReverseMode, ForwardMode, Annotation, Mode +import EnzymeCore: EnzymeRules + # Independent code, must be loaded before "compiler.jl" include("pmap.jl") @@ -61,7 +63,6 @@ end end end - include("logic.jl") include("typeanalysis.jl") include("typetree.jl") diff --git a/src/api.jl b/src/api.jl index f926d93373..c88e9a918c 100644 --- a/src/api.jl +++ b/src/api.jl @@ -196,6 +196,7 @@ EnzymeRegisterFwdCallHandler(name, fwdhandle) = ccall((:EnzymeRegisterFwdCallHan EnzymeGetShadowType(width, T) = ccall((:EnzymeGetShadowType, libEnzyme), LLVMTypeRef, (UInt64,LLVMTypeRef), width, T) +EnzymeGradientUtilsGetMode(gutils) = ccall((:EnzymeGradientUtilsGetMode, libEnzyme), CDerivativeMode, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsGetWidth(gutils) = ccall((:EnzymeGradientUtilsGetWidth, libEnzyme), UInt64, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsNewFromOriginal(gutils, val) = ccall((:EnzymeGradientUtilsNewFromOriginal, libEnzyme), LLVMValueRef, (EnzymeGradientUtilsRef, LLVMValueRef), gutils, val) EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, val, orig) = ccall((:EnzymeGradientUtilsSetDebugLocFromOriginal, libEnzyme), Cvoid, (EnzymeGradientUtilsRef, LLVMValueRef, LLVMValueRef), gutils, val, orig) @@ -211,7 +212,11 @@ EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocati EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val) + +EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) +EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) + EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}), gutils, orig, needsPrimalP, needsShadowP) EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme), diff --git a/src/compiler.jl b/src/compiler.jl index 3db79d34b1..ccabb2a941 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5,9 +5,10 @@ import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, Bat Annotation, guess_activity, eltype, API, TypeTree, typetree, only!, shift!, data0!, merge!, TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype - using Enzyme +import EnzymeCore: EnzymeRules + using LLVM, GPUCompiler, Libdl import Enzyme_jll @@ -206,6 +207,7 @@ return_type(::AbstractThunk{F, RT, TT, Width, DF}) where {F, RT, TT, Width, DF} using .JIT +import GPUCompiler: @safe_debug, @safe_info, @safe_warn, @safe_error safe_println(head, tail) = ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring...), "%s%s\n",head, tail) macro safe_show(exs...) @@ -219,16 +221,14 @@ macro safe_show(exs...) end declare_allocobj!(mod) = get_function!(mod, "julia.gc_alloc_obj") do ctx - T_jlvalue = LLVM.StructType(LLVMType[]; ctx) - T_prjlvalue = LLVM.PointerType(T_jlvalue, #= AddressSpace::Tracked =# 10) - - T_pjlvalue = LLVM.PointerType(T_jlvalue) - T_ppjlvalue = LLVM.PointerType(T_pjlvalue) - - #TODO make size_t > 32 => 64, else 32 - T_size = LLVM.IntType((sizeof(Csize_t)*8) > 32 ? 64 : 32; ctx) - LLVM.FunctionType(T_prjlvalue, [LLVM.PointerType(T_ppjlvalue), T_size, T_prjlvalue]) + Tracked = 10 + T_jlvalue = LLVM.StructType(LLVM.LLVMType[]; ctx) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) + T_size_t = convert(LLVM.LLVMType, Int; ctx) + LLVM.FunctionType(T_prjlvalue, [T_ppjlvalue, T_size_t, T_prjlvalue]) end + function emit_allocobj!(B, T, size) curent_bb = position(B) fn = LLVM.parent(curent_bb) @@ -240,14 +240,18 @@ function emit_allocobj!(B, T, size) T_jlvalue = LLVM.StructType(LLVMType[]; ctx) T_prjlvalue = LLVM.PointerType(T_jlvalue, #= AddressSpace::Tracked =# 10) + T_ppjlvalue = LLVM.PointerType(LLVM.PointerType(T_jlvalue)) + + ty = LLVM.const_inttoptr(LLVM.ConstantInt(convert(Int, pointer_from_objref(T)); ctx), LLVM.PointerType(T_jlvalue)) + ty = LLVM.const_addrspacecast(ty, T_prjlvalue) + + pgcstack = reinsert_gcmarker!(fn, B) + ct = inbounds_gep!(B, bitcast!(B, pgcstack, T_ppjlvalue), [LLVM.ConstantInt(current_task_offset(); ctx)]) - ty = inttoptr!(B, LLVM.ConstantInt(convert(Int, pointer_from_objref(T)); ctx), LLVM.PointerType(T_jlvalue)) - ty = addrspacecast!(B, ty, T_prjlvalue) size = LLVM.ConstantInt(T_size, size) - args = [reinsert_gcmarker!(fn), size, ty] - args[1] = bitcast!(B, args[1], parameters(eltype(llvmtype(func)))[1]) - return call!(B, func, args) + return call!(B, func, [ct, size, ty]) end + function emit_allocobj!(B, T) emit_allocobj!(B, T, sizeof(T)) end @@ -864,15 +868,16 @@ function emit_gc_preserve_end(B::LLVM.Builder, token) end function generic_setup(orig, func, ReturnType, gutils, start, ctx::LLVM.Context, B::LLVM.Builder, lookup; sret=nothing, tape=nothing) + mode = API.EnzymeGradientUtilsGetMode(gutils) width = API.EnzymeGradientUtilsGetWidth(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) ops = collect(operands(orig))[(start+1):end-1] if tape === nothing - llvmf = nested_codegen!(mod, func, Tuple{Any, AnyArray(length(ops)), AnyArray(Int64(width)*length(ops)), Ptr{UInt8}, Val{Int64(width)}, Val{ReturnType}}) + llvmf = nested_codegen!(mode, mod, func, Tuple{Any, AnyArray(length(ops)), AnyArray(Int64(width)*length(ops)), Ptr{UInt8}, Val{Int64(width)}, Val{ReturnType}}) else - llvmf = nested_codegen!(mod, func, Tuple{Any, AnyArray(length(ops)), Ptr{Any}, Ptr{UInt8}, Any, Val{Int64(width)}, Val{ReturnType}}) + llvmf = nested_codegen!(mode, mod, func, Tuple{Any, AnyArray(length(ops)), Ptr{Any}, Ptr{UInt8}, Any, Val{Int64(width)}, Val{ReturnType}}) end T_int8 = LLVM.Int8Type(ctx) @@ -1605,7 +1610,7 @@ else const ctxToThreadSafe = Dict{LLVM.Context, LLVM.ThreadSafeContext}() end -function nested_codegen!(mod::LLVM.Module, f, tt) +function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt) # TODO: Put a cache here index on `mod` and f->tt ctx = LLVM.context(mod) @@ -1622,15 +1627,17 @@ end # - When OrcV2 only use a MaterializationUnit to avoid mutation of the module here target = GPUCompiler.NativeCompilerTarget() - params = Compiler.PrimalCompilerParams() + params = Compiler.PrimalCompilerParams(mode) job = CompilerJob(target, funcspec, params) - otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, validate=false, ctx) + # TODO + parent_job = nothing + otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, cleanup=false, validate=false, parent_job=parent_job, ctx) entry = name(meta.entry) # Apply first stage of optimization's so that this module is at the same stage as `mod` optimize!(otherMod, JIT.get_tm()) - + # 4) Link the corresponding module LLVM.link!(mod, otherMod) @@ -1891,7 +1898,8 @@ function threadsfor_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRe else tt = Tuple{funcT, Core.Ptr{Cvoid}, dfuncT, Type{thunkTy}, Bool} end - entry = nested_codegen!(mod, runtime_pfor_fwd, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_pfor_fwd, tt) permit_inlining!(entry) push!(function_attributes(entry), EnumAttribute("alwaysinline"; ctx)) @@ -1932,7 +1940,8 @@ function threadsfor_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu else tt = Tuple{funcT, Core.Ptr{Cvoid}, dfuncT, Type{thunkTy}, Val{any_jltypes(GetTapeType(thunkTy))}, Bool} end - entry = nested_codegen!(mod, runtime_pfor_augfwd, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_pfor_augfwd, tt) permit_inlining!(entry) push!(function_attributes(entry), EnumAttribute("alwaysinline"; ctx)) @@ -1983,7 +1992,8 @@ function threadsfor_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRe else tt = Tuple{funcT, Core.Ptr{Cvoid}, dfuncT, Type{thunkTy}, Val{any_jltypes(GetTapeType(thunkTy))}, STT, Bool} end - entry = nested_codegen!(mod, runtime_pfor_rev, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_pfor_rev, tt) permit_inlining!(entry) push!(function_attributes(entry), EnumAttribute("alwaysinline"; ctx)) @@ -2010,7 +2020,8 @@ function newtask_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) width = API.EnzymeGradientUtilsGetWidth(gutils) - fun = nested_codegen!(mod, runtime_newtask_fwd, Tuple{Any, Any, Any, Int, Val{width}}) + mode = API.EnzymeGradientUtilsGetMode(gutils) + fun = nested_codegen!(mode, mod, runtime_newtask_fwd, Tuple{Any, Any, Any, Int, Val{width}}) B = LLVM.Builder(B) @@ -2053,7 +2064,8 @@ function newtask_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRe GPUCompiler.@safe_warn "active variables passed by value to jl_new_task are not yet supported" width = API.EnzymeGradientUtilsGetWidth(gutils) - fun = nested_codegen!(mod, runtime_newtask_augfwd, Tuple{Any, Any, Any, Int, Val{width}}) + mode = API.EnzymeGradientUtilsGetMode(gutils) + fun = nested_codegen!(mode, mod, runtime_newtask_augfwd, Tuple{Any, Any, Any, Int, Val{width}}) B = LLVM.Builder(B) sret = allocate_sret!(gutils, 2, ctx) @@ -2221,6 +2233,485 @@ function wait_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gut return nothing end +function enzyme_custom_setup_args(B, orig, gutils, mi) + ctx = LLVM.context(orig) + ops = collect(operands(orig)) + called = ops[end] + ops = ops[1:end-1] + width = API.EnzymeGradientUtilsGetWidth(gutils) + + args = LLVM.Value[] + activity = Type[] + overwritten = Bool[] + + actives = LLVM.Value[] + + uncacheable = Vector{UInt8}(undef, length(ops)) + API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) + + sret = false + returnRoots = false + + jlargs = classify_arguments(mi.specTypes, eltype(llvmtype(called)), sret, returnRoots) + + op_idx = 1 + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + + for arg in jlargs + if arg.cc == GPUCompiler.GHOST + push!(activity, Const{arg.typ}) + push!(overwritten, false) + continue + end + + op = ops[op_idx] + push!(overwritten, uncacheable[op_idx] != 0) + op_idx+=1 + + val = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, op)) + + activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, #=isforeign=#false) + + # TODO type analysis deduce if duplicated vs active + if activep == API.DFT_CONSTANT + Ty = Const{arg.typ} + + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + + push!(activity, Ty) + + elseif activep == API.DFT_OUT_DIFF + Ty = Active{arg.typ} + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + + push!(activity, Ty) + push!(actives, op) + else + ival = LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, op, B)) + if width == 1 + if activep == API.DFT_DUP_ARG + Ty = Duplicated{arg.typ} + else + @assert activep == API.DFT_DUP_NONEED + Ty = DuplicatedNoNeed{arg.typ} + end + else + if activep == API.DFT_DUP_ARG + Ty = BatchDuplicated{arg.typ, Int64(width)} + else + @assert activep == API.DFT_DUP_NONEED + Ty = BatchDuplicatedNoNeed{arg.typ, Int64(width)} + end + end + + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + arval = insert_value!(B, arval, ival, 1) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + push!(activity, Ty) + end + + end + + @assert op_idx-1 == length(ops) + + return args, activity, (overwritten...,), actives +end + +function enzyme_custom_setup_ret(gutils, orig, mi, job) + width = API.EnzymeGradientUtilsGetWidth(gutils) + interp = GPUCompiler.get_interpreter(job) + RealRt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) + needsPrimal = needsPrimalP[] != 0 + + if !needsPrimal && activep == API.DFT_DUP_ARG + activep = API.DFT_DUP_NONEED + end + + if activep == API.DFT_CONSTANT + RT = Const{RealRt} + + elseif activep == API.DFT_OUT_DIFF + RT = Active{RealRt} + + elseif activep == API.DFT_DUP_ARG + if width == 1 + RT = Duplicated{RealRt} + else + RT = BatchDuplicated{RealRt, Int64(width)} + end + else + @assert activep == API.DFT_DUP_NONEED + if width == 1 + RT = DuplicatedNoNeed{RealRt} + else + RT = BatchDuplicatedNoNeed{RealRt, Int64(width)} + end + end + return RealRt, RT, needsPrimal, needsShadowP[] != 0 +end + +function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + orig = LLVM.Instruction(OrigCI) + ctx = LLVM.context(orig) + B = LLVM.Builder(B) + + width = API.EnzymeGradientUtilsGetWidth(gutils) + + if shadowR != C_NULL + unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig)))).ref) + end + + # TODO: don't inject the code multiple times for multiple calls + + # 1) extract out the MI from attributes + mi, job = enzyme_custom_extract_mi(orig) + + # 2) Create activity, and annotate function spec + args, activity, overwritten, actives = enzyme_custom_setup_args(B, orig, gutils, mi) + RealRt, RT, needsPrimal, needsShadow = enzyme_custom_setup_ret(gutils, orig, mi, job) + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + mode = API.EnzymeGradientUtilsGetMode(gutils) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + + tt = copy(activity) + insert!(tt, 2, Type{RT}) + TT = Tuple{tt...} + + # TODO get world + world= Base.get_current_world() + if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world) + @safe_debug "Applying custom forward rule" TT + llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT) + else + @safe_debug "No custom forward rule is applicable for" TT + emit_error(B, orig, "Enzyme: No custom rule was appliable for " * string(TT)) + return nothing + end + + + sret = nothing + if !isempty(parameters(llvmf)) && any(map(k->kind(k)==kind(EnumAttribute("sret"; ctx)), collect(parameter_attributes(llvmf, 1)))) + sret = alloca!(alloctx, eltype(llvmtype(parameters(llvmf)[1]))) + pushfirst!(args, sret) + end + + for i in eachindex(args) + party = llvmtype(parameters(llvmf)[i]) + if llvmtype(args[i]) == party + continue + end + if LLVM.addrspace(party) != 0 + args[i] = addrspacecast!(B, args[i], party) + else + GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf + return + end + end + + res = LLVM.call!(B, llvmf, args) + API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, res, orig) + + hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn"; ctx)), collect(function_attributes(llvmf)))) + + if hasNoRet + return nothing + end + + if sret !== nothing + res = load!(B, sret) + end + + shadowV = C_NULL + normalV = C_NULL + + if RT <: Const + if needsPrimal + normalV = res.ref + end + else + if !needsPrimal + shadowV = res.ref + else + if !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + emit_error(B, "Enzyme: incorrect return type of forward custom rule - "*(string(RT))*" "*string(activity)) + return + end + normalV = extract_value!(B, res, 0).ref + shadowV = extract_value!(B, res, 1).ref + end + end + + if shadowR != C_NULL + unsafe_store!(shadowR, shadowV) + end + + # Delete the primal code + if needsPrimal + unsafe_store!(normalR, normalV) + else + LLVM.API.LLVMInstructionEraseFromParent(LLVM.Instruction(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig))) + end + + return nothing +end + +function enzyme_custom_common_rev(forward::Bool, B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR, shadowR, tape)::LLVM.API.LLVMValueRef + + orig = LLVM.Instruction(OrigCI) + ctx = LLVM.context(orig) + B = LLVM.Builder(B) + + width = API.EnzymeGradientUtilsGetWidth(gutils) + + shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig))) + if shadowR != C_NULL + unsafe_store!(shadowR,UndefValue(shadowType).ref) + end + + # TODO: don't inject the code multiple times for multiple calls + + # 1) extract out the MI from attributes + mi, job = enzyme_custom_extract_mi(orig) + + # 2) Create activity, and annotate function spec + args, activity, overwritten, actives = enzyme_custom_setup_args(B, orig, gutils, mi) + RealRt, RT, needsPrimal, needsShadow = enzyme_custom_setup_ret(gutils, orig, mi, job) + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + + # TODO get world + world = Base.get_world_counter() + + C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadow), Int(width), overwritten} + augprimal_tt = copy(activity) + insert!(augprimal_tt, 2, Type{RT}) + pushfirst!(augprimal_tt, C) + augprimal_TT = Tuple{augprimal_tt...} + + @safe_show augprimal_TT + aug_RT = Core.Compiler.return_type(EnzymeRules.augmented_primal, augprimal_TT, world) + @safe_show aug_RT + + if aug_RT === Union{} || + aug_RT isa Union || + !(aug_RT <: Tuple) || + length(aug_RT.parameters) != 2 + + @safe_debug "Custom augmented_primal rule has invalid return type" TT=augprimal_TT RT=aug_RT + emit_error(B, orig, "Enzyme: Custom rule " * string(augprimal_TT) * " resulted in " * string(aug_RT)) + return C_NULL + end + + TapeT = aug_RT.parameters[2] + @safe_show TapeT + + mode = API.EnzymeGradientUtilsGetMode(gutils) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) + + if forward + if EnzymeRules.isapplicable(EnzymeRules.augmented_primal, augprimal_TT; world) + @safe_debug "Applying custom augmented_primal rule" TT=augprimal_TT + llvmf = nested_codegen!(mode, mod, EnzymeRules.augmented_primal, augprimal_TT) + else + @safe_debug "No custom augmented_primal rule is applicable for" augprimal_TT + emit_error(B, orig, "Enzyme: No custom rule was appliable for " * string(augprimal_TT)) + return C_NULL + end + else + tt = copy(activity) + insert!(tt, 2, RT) + insert!(tt, 3, TapeT) + pushfirst!(tt, C) + TT = Tuple{tt...} + + if EnzymeRules.isapplicable(EnzymeRules.reverse, TT; world) + @safe_debug "Applying custom reverse rule" TT + llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, TT) + else + @safe_debug "No custom reverse rule is applicable for" TT + emit_error(B, orig, "Enzyme: No custom rule was appliable for " * string(TT)) + return C_NULL + end + end + + needsTape = !GPUCompiler.isghosttype(TapeT) && !Core.Compiler.isconstType(TapeT) + + tapeV = C_NULL + if forward && needsTape + tapeV = LLVM.UndefValue(convert(LLVMType, TapeT; ctx)).ref + end + + # if !forward + # argTys = copy(activity) + # if RT <: Active + # if width == 1 + # push!(argTys, RealRt) + # else + # push!(argTys, NTuple{RealRt, (Int64)width}) + # end + # end + # push!(argTys, tapeType) + # llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}) + # end + + sret = nothing + if !isempty(parameters(llvmf)) && any(map(k->kind(k)==kind(EnumAttribute("sret"; ctx)), collect(parameter_attributes(llvmf, 1)))) + sret = alloca!(alloctx, eltype(llvmtype(parameters(llvmf)[1]))) + pushfirst!(args, sret) + end + + if !forward + if RT <: Active + push!(args, LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B))) + end + if needsTape + @assert tape != C_NULL + push!(args, LLVM.Value(tape)) + end + end + + for i in 1:length(args) + party = llvmtype(parameters(llvmf)[i]) + if llvmtype(args[i]) == party + continue + end + if LLVM.addrspace(party) != 0 + args[i] = addrspacecast!(B, args[i], party) + else + GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf + return tapeV + end + end + + res = LLVM.call!(B, llvmf, args) + API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, res, orig) + + hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn"; ctx)), collect(function_attributes(llvmf)))) + + if hasNoRet + return tapeV + end + + if sret !== nothing + res = load!(B, sret) + end + + shadowV = C_NULL + normalV = C_NULL + + + if forward + if !needsPrimal && !needsShadow && !needsTape + else + if !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + emit_error(B, "Enzyme: incorrect return type of augmented forward custom rule - "*(string(RT))*" "*string(activity)) + return tapeV + end + idx = 0 + if needsPrimal + normalV = extract_value!(B, res, idx) + if llvmtype(normalV) != llvmtype(orig) + GPUCompiler.@safe_error "Primal calling convention mismatch found ", normalV, " wanted ", llvmtype(orig) + return tapeV + end + normalV = normalV.ref + idx+=1 + end + if needsShadow + shadowV = extract_value!(B, res, idx).ref + if llvmtype(shadowV) != shadowType + GPUCompiler.@safe_error "Shadow calling convention mismatch found ", shadowV, " wanted ",shadowType + return tapeV + end + shadowV = shadowV.ref + idx+=1 + end + if needsTape + tapeV = extract_value!(B, res, idx).ref + idx+=1 + end + end + else + if length(actives) >= 1 && !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found return ", res + return tapeV + end + + idx = 0 + for v in actives + ext = extract_value!(B, res, idx) + shadowVType = LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(v))) + if llvmtype(ext) != shadowType + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found ", ext, " wanted ",shadowVType + return tapeV + end + Typ = C_NULL + API.EnzymeGradientUtilsAddToDiffe(gutils, v, ext, B, Typ) + idx+=1 + end + end + + if forward + if shadowR != C_NULL + unsafe_store!(shadowR, shadowV) + end + + # Delete the primal code + if needsPrimal + unsafe_store!(normalR, normalV) + else + LLVM.API.LLVMInstructionEraseFromParent(LLVM.Instruction(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig))) + end + end + + return tapeV +end + + +function enzyme_custom_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + tape = enzyme_custom_common_rev(#=forward=#true, B, OrigCI, gutils, normalR, shadowR, #=tape=#nothing) + if tape != C_NULL + unsafe_store!(tapeR, tape) + end + return nothing +end + + +function enzyme_custom_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid + enzyme_custom_common_rev(#=forward=#false, B, OrigCI, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) + return nothing +end + function arraycopy_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid orig = LLVM.Instruction(OrigCI) @@ -3874,6 +4365,12 @@ function __init__() @cfunction(enq_work_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)), @cfunction(enq_work_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})) ) + register_handler!( + ("enzyme_custom",), + @cfunction(enzyme_custom_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), + @cfunction(enzyme_custom_rev, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef)), + @cfunction(enzyme_custom_fwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})) + ) register_handler!( ("jl_wait",), @cfunction(wait_augfwd, Cvoid, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef}, Ptr{LLVM.API.LLVMValueRef})), @@ -4019,6 +4516,7 @@ struct EnzymeCompilerParams <: AbstractEnzymeCompilerParams end struct PrimalCompilerParams <: AbstractEnzymeCompilerParams + mode::API.CDerivativeMode end ## job @@ -4036,7 +4534,7 @@ GPUCompiler.runtime_slug(job::CompilerJob{EnzymeTarget}) = "enzyme" # provide a specific interpreter to use. GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = - Interpreter.EnzymeInterpeter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.source.world) + Interpreter.EnzymeInterpeter(GPUCompiler.ci_cache(job), GPUCompiler.method_table(job), job.source.world, job.params.mode) include("compiler/utils.jl") include("compiler/passes.jl") @@ -5385,7 +5883,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if parent_job === nothing primal_target = GPUCompiler.NativeCompilerTarget() - primal_params = Compiler.PrimalCompilerParams() + primal_params = Compiler.PrimalCompilerParams(mode) primal_job = CompilerJob(primal_target, primal, primal_params) else primal_job = similar(parent_job, job.source) @@ -5482,11 +5980,28 @@ end foundTys = Dict{String, Tuple{LLVM.FunctionType, Core.MethodInstance}}() jobref = Ref(job) + world = job.source.world actualRetType = nothing + customDerivativeNames = String[] for (mi, k) in meta.compiled k_name = GPUCompiler.safe_name(k.specfunc) - haskey(functions(mod), k_name) || continue + has_custom_rule = false + if mode == API.DEM_ForwardMode + has_custom_rule = EnzymeRules.has_frule_from_sig(mi.specTypes; world) + if has_custom_rule + @safe_debug "Found frule for" mi.specTypes + end + else + has_custom_rule = EnzymeRules.has_rrule_from_sig(mi.specTypes; world) + if has_custom_rule + @safe_debug "Found rrule for" mi.specTypes + end + end + + if !(haskey(functions(mod), k_name) || has_custom_rule) + continue + end llvmfn = functions(mod)[k_name] if llvmfn == primalf @@ -5507,6 +6022,7 @@ end push!(attributes, a) end push!(attributes, StringAttribute("enzymejl_mi", string(convert(Int, pointer_from_objref(mi))); ctx)) + push!(attributes, StringAttribute("enzymejl_job", string(convert(Int, pointer_from_objref(jobref))); ctx)) push!(attributes, StringAttribute("enzyme_math", name; ctx)) push!(attributes, EnumAttribute("noinline", 0; ctx)) must_wrap |= llvmfn == primalf @@ -5514,10 +6030,14 @@ end end foundTys[k_name] = (eltype(llvmtype(llvmfn)), mi) + if has_custom_rule + handleCustom("enzyme_custom") + continue + end Base.isbindingresolved(jlmod, name) && isdefined(jlmod, name) || continue func = getfield(jlmod, name) - + sparam_vals = mi.specTypes.parameters[2:end] # mi.sparam_vals if func == Base.eps || func == Base.nextfloat || func == Base.prevfloat handleCustom("jl_inactive_inout", [StringAttribute("enzyme_inactive"; ctx), @@ -5648,7 +6168,7 @@ end end primalf = wrapper_f end - + source_sig = GPUCompiler.typed_signature(job)::Type primalf, returnRoots = lower_convention(source_sig, mod, primalf, actualRetType) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index aba5f7456c..3d04eaefd2 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -1,8 +1,10 @@ module Interpreter using Random +import Enzyme: API using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance using GPUCompiler: CodeCache, WorldView import ..Enzyme +import ..EnzymeRules struct EnzymeInterpeter <: AbstractInterpreter global_cache::CodeCache @@ -17,7 +19,9 @@ struct EnzymeInterpeter <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams - function EnzymeInterpeter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt) + mode::API.CDerivativeMode + + function EnzymeInterpeter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode) @assert world <= Base.get_world_counter() return new( @@ -34,6 +38,7 @@ struct EnzymeInterpeter <: AbstractInterpreter InferenceParams(unoptimize_throw_blocks=false), VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() : OptimizationParams(unoptimize_throw_blocks=false), + mode ) end end @@ -136,6 +141,15 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpeter, if is_primitive_func(mi.specTypes) return nothing end + if interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + else + if EnzymeRules.has_rrule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + end return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, src::Any, info::CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) @@ -150,6 +164,15 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpeter, if is_primitive_func(mi.specTypes) return nothing end + if interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + else + if EnzymeRules.has_rrule_from_sig(mi.specTypes; world = interp.world) + return nothing + end + end return Base.@invoke Core.Compiler.inlining_policy(interp::AbstractInterpreter, src::Any, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any}) @@ -162,9 +185,19 @@ enzyme_inlining_policy(@nospecialize(src)) = Core.Compiler.default_inlining_poli Core.Compiler.inlining_policy(::EnzymeInterpeter) = enzyme_inlining_policy function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, T, <:typeof(enzyme_inlining_policy)}) where {S<:Union{Nothing, Core.Compiler.EdgeTracker}, T} mi = todo.mi - if is_primitive_func(mi.specTypes) + if is_primitive_func(mi.specTypes) return Core.Compiler.compileable_specialization(state.et, todo.spec.match) end + interp = state.interp + if interp.mode == API.DEM_ForwardMode + if EnzymeRules.has_frule_from_sig(mi.specTypes; world = interp.world) + return Core.Compiler.compileable_specialization(state.et, todo.spec.match) + end + else + if EnzymeRules.has_rrule_from_sig(mi.specTypes; world = interp.world) + return Core.Compiler.compileable_specialization(state.et, todo.spec.match) + end + end return Base.@invoke Core.Compiler.resolve_todo( todo::InliningTodo, state::InliningState) diff --git a/src/compiler/pmap.jl b/src/compiler/pmap.jl index 2026bdb9d6..caf3db6b63 100644 --- a/src/compiler/pmap.jl +++ b/src/compiler/pmap.jl @@ -184,7 +184,8 @@ end splat, _ = julia_activity(mi.specTypes.parameters, (mode != API.DEM_ReverseModeGradient) ? [Type{thunkTy}, Val{any_jltypes(TapeType)}, Int, funcT, funcT] : [Type{thunkTy}, Val{any_jltypes(TapeType)}, Int, STT, funcT, funcT], ops, gutils) tt = Tuple{splat...} - entry = nested_codegen!(mod, runtime_fn, tt) + mode = API.EnzymeGradientUtilsGetMode(gutils) + entry = nested_codegen!(mode, mod, runtime_fn, tt) # 5) Call the function B = LLVM.Builder(B) diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 547ba07ac5..6f8742c3f5 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -11,7 +11,6 @@ function get_job(@nospecialize(func), @nospecialize(A), @nospecialize(types); return Compiler.CompilerJob(target, primal, params) end - function reflect(@nospecialize(func), @nospecialize(A), @nospecialize(types); optimize::Bool=true, second_stage::Bool=true, ctx=nothing, kwargs...) diff --git a/test/rrules.jl b/test/rrules.jl new file mode 100644 index 0000000000..4125f6d313 --- /dev/null +++ b/test/rrules.jl @@ -0,0 +1,63 @@ +module ReverseRules + +using Enzyme +using Enzyme: EnzymeRules +using Test + +f(x) = x^2 + +function f_ip(x) + x[1] *= x[1] + return nothing +end + +import .EnzymeRules: augmented_primal, reverse, Annotation, has_rrule, has_rrule_from_sig +using .EnzymeRules + +function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(f)}, ::Type{<:Active}, x::Active) + if needs_primal(config) + return (func.val(x.val), nothing) + else + return (nothing, nothing) + end +end + +function reverse(config::ConfigWidth{1}, ::Const{typeof(f)}, dret::Active, tape, x::Active) + if needs_primal(config) + return (10+2*x.val*dret.val,) + else + return (100+2*x.val*dret.val,) + end +end + +function augmented_primal(::Config{false, false, 1}, func::Const{typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) + v = x.val[1] + x.val[1] *= v + return (nothing, v) +end + +function reverse(::Config{false, false, 1}, ::Const{typeof(f_ip)}, ::Const, tape, x::Duplicated) + x.dval[1] = 100 + x.dval[1] * tape + return () +end + +@testset "has_rrule" begin + @test has_rrule_from_sig(Base.signature_type(f, Tuple{Float64})) + @test has_rrule_from_sig(Base.signature_type(f_ip, Tuple{Vector{Float64}})) +end + + +@testset "Custom Reverse Rules" begin + @test Enzyme.autodiff(Enzyme.Reverse, f, Active(2.0))[1] ≈ 104.0 + @test Enzyme.autodiff(Enzyme.Reverse, x->f(x)^2, Active(2.0))[1] ≈ 42.0 + + x = [2.0] + dx = [1.0] + + Enzyme.autodiff(Enzyme.Reverse, f_ip, Duplicated(x, dx)) + + @test x ≈ [4.0] + @test dx ≈ [102.0] +end + +end # ReverseRules diff --git a/test/rules.jl b/test/rules.jl new file mode 100644 index 0000000000..7f041cf1ff --- /dev/null +++ b/test/rules.jl @@ -0,0 +1,92 @@ +module ForwardRules + +using Enzyme +using Enzyme: EnzymeRules +using Test + +import .EnzymeRules: forward, Annotation, has_frule, has_frule_from_sig + +f(x) = x^2 + +function forward(::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) + return 10+2*x.val*x.dval +end + +function forward(::Const{typeof(f)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} + return NTuple{N, T}(1000+2*x.val*dv for dv in x.dval) +end + +function forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, x::Duplicated) + return Duplicated(func.val(x.val), 100+2*x.val*x.dval) +end + +function forward(func::Const{typeof(f)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} + return BatchDuplicated(func.val(x.val), NTuple{N, T}(10000+2*x.val*dv for dv in x.dval)) +end + +@testset "has_frule" begin + @test has_frule_from_sig(Base.signature_type(f, Tuple{Float64})) + @test has_frule_from_sig(Base.signature_type(f_ip, Tuple{Vector{Float64}})) + + @test has_frule(f) + @test has_frule(f, Duplicated) + @test has_frule(f, DuplicatedNoNeed) + @test has_frule(f, BatchDuplicated) + @test has_frule(f, BatchDuplicatedNoNeed) + @test has_frule(f, Duplicated, Tuple{<:Duplicated}) + @test has_frule(f, DuplicatedNoNeed, Tuple{<:Duplicated}) + @test has_frule(f, BatchDuplicated, Tuple{<:BatchDuplicated}) + @test has_frule(f, BatchDuplicatedNoNeed, Tuple{<:BatchDuplicated}) + + @test !has_frule(f, Duplicated, Tuple{<:BatchDuplicated}) + @test !has_frule(f, DuplicatedNoNeed, Tuple{<:BatchDuplicated}) + @test !has_frule(f, BatchDuplicated, Tuple{<:Duplicated}) + @test !has_frule(f, BatchDuplicatedNoNeed, Tuple{<:Duplicated}) + + @test has_frule(f, Tuple{<:Duplicated}) + @test has_frule(f, Tuple{<:BatchDuplicated}) + @test has_frule(f, Tuple{<:Annotation}) + @test has_frule(f, Tuple{<:Annotation{Float64}}) + @test !has_frule(f, Tuple{<:Const}) +end + +@testset "autodiff(Forward, ...) custom rules" begin + @test autodiff(Forward, f, Duplicated(2.0, 1.0))[1] ≈ 14.0 + @test autodiff(Forward, x->f(x)^2, Duplicated(2.0, 1.0))[1] ≈ 832.0 + + res = autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + @test res[1] ≈ 1004.0 + @test res[2] ≈ 1012.0 + + res = Enzyme.autodiff(Forward, x->f(x)^2, BatchDuplicatedNoNeed, BatchDuplicated(2.0, (1.0, 3.0)))[1] + + @test res[1] ≈ 80032.0 + @test res[2] ≈ 80096.0 +end + +function f_ip(x) + x[1] *= x[1] + return nothing +end + +function forward(::Const{Core.typeof(f_ip)}, ::Type{<:Const}, x::Duplicated) + ld = x.val[1] + x.val[1] *= ld + x.dval[1] *= 2 * ld + 10 + return nothing +end + +@testset "In place" begin + vec = [2.0] + dvec = [1.0] + + Enzyme.autodiff(Forward, f_ip, Duplicated(vec, dvec)) + + @test vec ≈ [4.0] + @test dvec ≈ [14.0] +end + +# TODO: Test error for no frule applicable despite frule on Function. + + +end # module ForwardRules diff --git a/test/runtests.jl b/test/runtests.jl index fe856b63fd..4503a30e5c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,14 +41,20 @@ end include("abi.jl") include("typetree.jl") +if Enzyme.EnyzmeRules.issupported() + include("rules.jl") + include("rrules.jl") +end + f0(x) = 1.0 + x - function vrec(start, x) - if start > length(x) - return 1.0 - else - return x[start] * vrec(start+1, x) - end +function vrec(start, x) + if start > length(x) + return 1.0 + else + return x[start] * vrec(start+1, x) end +end + @testset "Internal tests" begin thunk_a = Enzyme.Compiler.thunk(f0, nothing, Active, Tuple{Active{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1)) thunk_b = Enzyme.Compiler.thunk(f0, nothing, Const, Tuple{Const{Float64}}, Val(API.DEM_ReverseModeCombined), Val(1))