From e37105bab64701f7749d3fc26efc4de18a5bf91a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 26 Aug 2022 23:17:29 -0400 Subject: [PATCH] Add reverse rules --- src/api.jl | 7 + src/compiler.jl | 402 +++++++++++++++++++++++++++++++++-------------- src/rules.jl | 2 +- test/rrules.jl | 64 ++++++++ test/runtests.jl | 1 + 5 files changed, 361 insertions(+), 115 deletions(-) create mode 100644 test/rrules.jl diff --git a/src/api.jl b/src/api.jl index 5a70d1e9f1..79ce5a8b0e 100644 --- a/src/api.jl +++ b/src/api.jl @@ -205,6 +205,13 @@ EnzymeGradientUtilsAllocationBlock(gutils) = ccall((:EnzymeGradientUtilsAllocati EnzymeGradientUtilsTypeAnalyzer(gutils) = ccall((:EnzymeGradientUtilsTypeAnalyzer, libEnzyme), EnzymeTypeAnalyzerRef, (EnzymeGradientUtilsRef,), gutils) EnzymeGradientUtilsAllocAndGetTypeTree(gutils, val) = ccall((:EnzymeGradientUtilsAllocAndGetTypeTree, libEnzyme), CTypeTreeRef, (EnzymeGradientUtilsRef,LLVMValueRef), gutils, val) + +EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, size) = ccall((:EnzymeGradientUtilsGetUncacheableArgs, libEnzyme), Cvoid, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, UInt64), gutils, orig, uncacheable, size) + +EnzymeGradientUtilsGetDiffeType(gutils, op, isforeign) = ccall((:EnzymeGradientUtilsGetDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, UInt8), gutils, op, isforeign) + +EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) = ccall((:EnzymeGradientUtilsGetReturnDiffeType, libEnzyme), CDIFFE_TYPE, (EnzymeGradientUtilsRef,LLVMValueRef, Ptr{UInt8}, Ptr{UInt8}), gutils, orig, needsPrimalP, needsShadowP) + EnzymeGradientUtilsSubTransferHelper(gutils, mode, secretty, intrinsic, dstAlign, srcAlign, offset, dstConstant, origdst, srcConstant, origsrc, length, isVolatile, MTI, allowForward, shadowsLookedUp) = ccall((:EnzymeGradientUtilsSubTransferHelper, libEnzyme), Cvoid, diff --git a/src/compiler.jl b/src/compiler.jl index 8ea5221e09..b3223c4ed1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1801,51 +1801,22 @@ function wait_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gut return nothing end -function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid - orig = LLVM.Instruction(OrigCI) +function enzyme_custom_setup_args(B, orig, gutils, mi) ctx = LLVM.context(orig) - - width = API.EnzymeGradientUtilsGetWidth(gutils) - - if shadowR != C_NULL - unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig)))).ref) - end - - # 1) extract out the MI from attributes - mi = nothing - job = nothing - for fattr in collect(function_attributes(LLVM.called_value(orig))) - if isa(fattr, LLVM.StringAttribute) - if kind(fattr) == "enzymejl_mi" - ptr = reinterpret(Ptr{Cvoid}, parse(Int, LLVM.value(fattr))) - mi = Base.unsafe_pointer_to_objref(ptr) - end - if kind(fattr) == "enzymejl_job" - ptr = reinterpret(Ptr{Cvoid}, parse(Int, LLVM.value(fattr))) - job = Base.unsafe_pointer_to_objref(ptr)[] - end - end - end - B = LLVM.Builder(B) - - if mi === nothing - emit_error(builder, "Enzyme: Custom forward handler, could not find MI") - return - end - - # TODO: don't inject the code multiple times for multiple calls - - # 2) Create activity, and annotate function spec - active = API.EnzymeGradientUtilsIsConstantValue(gutils, orig) == 0 - ops = collect(operands(orig)) called = ops[end] ops = ops[1:end-1] + width = API.EnzymeGradientUtilsGetWidth(gutils) args = LLVM.Value[] - activity = Type[] - + overwritten = Bool[] + + actives = LLVM.Value[] + + uncacheable = Vector{UInt8}(undef, length(ops)) + API.EnzymeGradientUtilsGetUncacheableArgs(gutils, orig, uncacheable, length(uncacheable)) + sret = false returnRoots = false @@ -1853,30 +1824,68 @@ function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu op_idx = 1 - alloctx = LLVM.Builder(ctx) position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) for arg in jlargs if arg.cc == GPUCompiler.GHOST push!(activity, Const{arg.typ}) + push!(overwritten, false) continue end op = ops[op_idx] + push!(overwritten, uncacheable[op_idx] != 0) op_idx+=1 val = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, op)) - activep = API.EnzymeGradientUtilsIsConstantValue(gutils, op) == 0 + activep = API.EnzymeGradientUtilsGetDiffeType(gutils, op, #=isforeign=#false) + # TODO type analysis deduce if duplicated vs active - if activep + if activep == API.DFT_CONSTANT + Ty = Const{arg.typ} + + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + + push!(activity, Ty) + + elseif activep == API.DFT_OUT_DIFF + Ty = Active{arg.typ} + llty = convert(LLVMType, Ty; ctx) + + arval = LLVM.UndefValue(llty) + arval = insert_value!(B, arval, val, 0) + + al = alloca!(alloctx, llvmtype(arval)) + store!(B, arval, al) + push!(args, al) + + push!(activity, Ty) + push!(actives, op) + else ival = LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, op, B)) if width == 1 - Ty = Duplicated{arg.typ} + if activep == API.DFT_DUP_ARG + Ty = Duplicated{arg.typ} + else + @assert activep == API.DFT_DUP_NONEED + Ty = DuplicatedNoNeed{arg.typ} + end else - Ty = BatchDuplicated{arg.typ, Int64(width)} + if activep == API.DFT_DUP_ARG + Ty = BatchDuplicated{arg.typ, Int64(width)} + else + @assert activep == API.DFT_DUP_NONEED + Ty = BatchDuplicatedNoNeed{arg.typ, Int64(width)} + end end llty = convert(LLVMType, Ty; ctx) @@ -1889,35 +1898,97 @@ function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu store!(B, arval, al) push!(args, al) push!(activity, Ty) - else - push!(args, val) - push!(activity, Const{arg.typ}) end end @assert op_idx-1 == length(ops) + return args, activity, overwritten, actives +end + +function enzyme_custom_extract_mi(orig) + mi = nothing + job = nothing + for fattr in collect(function_attributes(LLVM.called_value(orig))) + if isa(fattr, LLVM.StringAttribute) + if kind(fattr) == "enzymejl_mi" + ptr = reinterpret(Ptr{Cvoid}, parse(Int, LLVM.value(fattr))) + mi = Base.unsafe_pointer_to_objref(ptr) + end + if kind(fattr) == "enzymejl_job" + ptr = reinterpret(Ptr{Cvoid}, parse(Int, LLVM.value(fattr))) + job = Base.unsafe_pointer_to_objref(ptr)[] + end + end + end + if mi === nothing + GPUCompiler.@safe_error "Enzyme: Custom handler, could not find mi", orig, LLVM.called_value(orig) + end + return mi, job +end + +function enzyme_custom_setup_ret(gutils, orig, mi, job) + width = API.EnzymeGradientUtilsGetWidth(gutils) interp = GPUCompiler.get_interpreter(job) RealRt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype - needsPrimal = (unsafe_load(normalR) != C_NULL) - RT = Const{RealRt} - if active - if needsPrimal - if width == 1 - RT = Duplicated{RealRt} - else - RT = BatchDuplicated{RealRt, Int64(width)} - end + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP) + needsPrimal = needsPrimalP[] != 0 + + if !needsPrimal && activep == API.DFT_DUP_ARG + activep = API.DFT_DUP_NONEED + end + + if activep == API.DFT_CONSTANT + RT = Const{RealRt} + + elseif activep == API.DFT_OUT_DIFF + RT = Active{RealRt} + + elseif activep == API.DFT_DUP_ARG + if width == 1 + RT = Duplicated{RealRt} else - if width == 1 - RT = DuplicatedNoNeed{RealRt} - else - RT = BatchDuplicatedNoNeed{RealRt, Int64(width)} - end + RT = BatchDuplicated{RealRt, Int64(width)} + end + else + @assert activep == API.DFT_DUP_NONEED + if width == 1 + RT = DuplicatedNoNeed{RealRt} + else + RT = BatchDuplicatedNoNeed{RealRt, Int64(width)} end end + return RealRt, RT, needsPrimal, needsShadowP[] != 0 +end + +function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + orig = LLVM.Instruction(OrigCI) + ctx = LLVM.context(orig) + B = LLVM.Builder(B) + + width = API.EnzymeGradientUtilsGetWidth(gutils) + + if shadowR != C_NULL + unsafe_store!(shadowR,UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig)))).ref) + end + + # TODO: don't inject the code multiple times for multiple calls + + # 1) extract out the MI from attributes + mi, job = enzyme_custom_extract_mi(orig) + + # 2) Create activity, and annotate function spec + args, activity, overwritten, actives = enzyme_custom_setup_args(B, orig, gutils, mi) + RealRt, RT, needsPrimal, needsShadow = enzyme_custom_setup_ret(gutils, orig, mi, job) + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) + func = EnzymeRules.forward(mi.specTypes, RT, activity) if func == nothing @@ -1949,6 +2020,7 @@ function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu end res = LLVM.call!(B, llvmf, args) + API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, res, orig) hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn"; ctx)), collect(function_attributes(llvmf)))) @@ -1963,7 +2035,7 @@ function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu shadowV = C_NULL normalV = C_NULL - if !active + if RT <: Const if needsPrimal normalV = res.ref end @@ -1994,86 +2066,188 @@ function enzyme_custom_fwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValu return nothing end -function enzyme_custom_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::Cvoid +function enzyme_custom_common_rev(forward::Bool, B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR, shadowR, tape)::LLVM.API.LLVMValueRef + orig = LLVM.Instruction(OrigCI) ctx = LLVM.context(orig) - normal = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig)) + B = LLVM.Builder(B) + + width = API.EnzymeGradientUtilsGetWidth(gutils) + + shadowType = LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(orig))) if shadowR != C_NULL - unsafe_store!(shadowR, normal.ref) + unsafe_store!(shadowR,UndefValue(shadowType).ref) end + + # TODO: don't inject the code multiple times for multiple calls + + # 1) extract out the MI from attributes + mi, job = enzyme_custom_extract_mi(orig) - orig = LLVM.Instruction(OrigCI) + # 2) Create activity, and annotate function spec + args, activity, overwritten, actives = enzyme_custom_setup_args(B, orig, gutils, mi) + RealRt, RT, needsPrimal, needsShadow = enzyme_custom_setup_ret(gutils, orig, mi, job) + + alloctx = LLVM.Builder(ctx) + position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) - # 1) extract out the MI from attributes - llvmfn = LLVM.called_value(orig) - mi = nothing - for fattr in collect(function_attributes(llvmfn)) - if isa(fattr, LLVM.StringAttribute) - if kind(fattr) == "enzymejl_mi" - ptr = reinterpret(Ptr{Cvoid}, parse(Int, LLVM.value(fattr))) - mi = Base.unsafe_pointer_to_objref(ptr) - break - end - end + tup = EnzymeRules.reverse(mi.specTypes, RT, activity, needsPrimal, needsShadow, width, overwritten) + + if tup == nothing + emit_error(B, "Enzyme: activity setting not provided for "*(string(RT))*" "*string(activity)) + return C_NULL end - builder = LLVM.Builder(B) - - if mi === nothing - emit_error(builder, "Enzyme: Custom augmented forward handler, could not find MI") + aug_func, rev_func, tapeType = tup + needsTape = !GPUCompiler.isghosttype(tapeType) && !Core.Compiler.isconstType(tapeType) + + tapeV = C_NULL + if forward && needsTape + tapeV = LLVM.UndefValue(convert(LLVMType, tapeType; ctx)).ref end - emit_error(builder, "Enzyme: Custom augmented forward handler, not yet implemented") - return nothing - # TODO: don't inject the code multiple times for multiple calls - - # 2) Create activity, and annotate function spec - active = API.EnzymeGradientUtilsIsConstantValue(gutils, orig) == 0 - ops = collect(operands(orig))[1:end-1] + mode = API.EnzymeGradientUtilsGetMode(gutils) + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) - args = LLVM.Value[] + if forward + llvmf = nested_codegen!(mode, mod, aug_func, Tuple{activity...}) + else + argTys = copy(activity) + if RT <: Active + if width == 1 + push!(argTys, RealRt) + else + push!(argTys, NTuple{RealRt, (Int64)width}) + end + end + push!(argTys, tapeType) + llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}) + end + + sret = nothing + if !isempty(parameters(llvmf)) && any(map(k->kind(k)==kind(EnumAttribute("sret"; ctx)), collect(parameter_attributes(llvmf, 1)))) + sret = alloca!(alloctx, eltype(llvmtype(parameters(llvmf)[1]))) + pushfirst!(args, sret) + end - for op in ops - val = LLVM.Value(API.EnzymeGradientUtilsNewFromOriginal(gutils, op)) - push!(args, val) + if !forward + if RT <: Active + push!(args, LLVM.Value(API.EnzymeGradientUtilsDiffe(gutils, orig, B))) + end + if needsTape + @assert tape != C_NULL + push!(args, LLVM.Value(tape)) + end + end - active = API.EnzymeGradientUtilsIsConstantValue(gutils, op) == 0 - # TODO type analysis deduce if duplicated vs active - if active - push!(args, LLVM.Value(API.EnzymeGradientUtilsInvertPointer(gutils, op, B))) + for i in 1:length(args) + party = llvmtype(parameters(llvmf)[i]) + if llvmtype(args[i]) == party + continue end + if LLVM.addrspace(party) != 0 + args[i] = addrspacecast!(B, args[i], party) + else + GPUCompiler.@safe_error "Calling convention mismatch", party, args[i], i, llvmf + return tapeV + end + end + + res = LLVM.call!(B, llvmf, args) + API.EnzymeGradientUtilsSetDebugLocFromOriginal(gutils, res, orig) + + hasNoRet = any(map(k->kind(k)==kind(EnumAttribute("noreturn"; ctx)), collect(function_attributes(llvmf)))) + + if hasNoRet + return tapeV end - tt = annotate_tuple_type(mi.specTypes, activity) - funcspec = FunctionSpec(EnzymeRules.augmented_forward, tt, #=kernel=# false, #=name=# nothing) + if sret !== nothing + res = load!(B, sret) + end - # 3) Use the MI to create the correct augmented fwd/reverse - # TODO: - # - GPU support - # - When OrcV2 only use a MaterializationUnit to avoid mutation of the module here + shadowV = C_NULL + normalV = C_NULL - target = GPUCompiler.NativeCompilerTarget() - params = Compiler.PrimalCompilerParams(mode) - job = CompilerJob(target, funcspec, params) - otherMod, meta = GPUCompiler.codegen(:llvm, job, optimize=false, validate=false) - entry = name(meta.entry) + if forward + if !needsPrimal && !needsShadow && !needsTape + else + if !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + emit_error(B, "Enzyme: incorrect return type of augmented forward custom rule - "*(string(RT))*" "*string(activity)) + return tapeV + end + idx = 0 + if needsPrimal + normalV = extract_value!(B, res, idx) + if llvmtype(normalV) != llvmtype(orig) + GPUCompiler.@safe_error "Primal calling convention mismatch found ", normalV, " wanted ", llvmtype(orig) + return tapeV + end + normalV = normalV.ref + idx+=1 + end + if needsShadow + shadowV = extract_value!(B, res, idx).ref + if llvmtype(shadowV) != shadowType + GPUCompiler.@safe_error "Shadow calling convention mismatch found ", shadowV, " wanted ",shadowType + return tapeV + end + shadowV = shadowV.ref + idx+=1 + end + if needsTape + tapeV = extract_value!(B, res, idx).ref + idx+=1 + end + end + else + if length(actives) >= 1 && !isa(llvmtype(res), LLVM.StructType) && !isa(llvmtype(res), LLVM.ArrayType) + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found return ", res + return tapeV + end + + idx = 0 + for v in actives + ext = extract_value!(B, res, idx) + shadowVType = LLVM.LLVMType(API.EnzymeGetShadowType(width, llvmtype(v))) + if llvmtype(ext) != shadowType + GPUCompiler.@safe_error "Shadow arg calling convention mismatch found ", ext, " wanted ",shadowVType + return tapeV + end + Typ = C_NULL + API.EnzymeGradientUtilsAddToDiffe(gutils, v, ext, B, Typ) + idx+=1 + end + end - # 4) Link the corresponding module - bb = LLVM.position(builder) - mod = LLVM.parent(LLVM.parent(bb)) - LLVM.link!(mod, otherMod) + if forward + if shadowR != C_NULL + unsafe_store!(shadowR, shadowV) + end + + # Delete the primal code + if needsPrimal + unsafe_store!(normalR, normalV) + else + LLVM.API.LLVMInstructionEraseFromParent(LLVM.Instruction(API.EnzymeGradientUtilsNewFromOriginal(gutils, orig))) + end + end - # 5) Call the function - entry = functions(mod)[entry] + return tapeV +end - emit_error(builder, "Enzyme: Not yet implemented custom augmented forward handler") +function enzyme_custom_augfwd(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, normalR::Ptr{LLVM.API.LLVMValueRef}, shadowR::Ptr{LLVM.API.LLVMValueRef}, tapeR::Ptr{LLVM.API.LLVMValueRef})::Cvoid + tape = enzyme_custom_common_rev(#=forward=#true, B, OrigCI, gutils, normalR, shadowR, #=tape=#nothing) + if tape != C_NULL + unsafe_store!(tapeR, tape) + end return nothing end function enzyme_custom_rev(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, tape::LLVM.API.LLVMValueRef)::Cvoid - emit_error(LLVM.Builder(B), "Enzyme: Not yet implemented custom reverse handler") + enzyme_custom_common_rev(#=forward=#false, B, OrigCI, gutils, #=normalR=#C_NULL, #=shadowR=#C_NULL, #=tape=#tape) return nothing end diff --git a/src/rules.jl b/src/rules.jl index 9639f35771..c4cbede1b1 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -31,7 +31,7 @@ function has_frule(@nospecialize(TT), world=Base.get_world_counter()) end function has_rrule(@nospecialize(TT), world=Base.get_world_counter()) - atype = Tuple{typeof(EnzymeRules.reverse), Type{TT}, Type, Vector{Type}} + atype = Tuple{typeof(EnzymeRules.reverse), Type{TT}, Type, Vector{Type}, Bool, Bool, UInt64, Vector{Bool}} if VERSION < v"1.8.0-" res = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), atype, world) diff --git a/test/rrules.jl b/test/rrules.jl new file mode 100644 index 0000000000..402fa609aa --- /dev/null +++ b/test/rrules.jl @@ -0,0 +1,64 @@ +using Enzyme +using Enzyme: EnzymeRules +using Test + +@testset "Custom Reverse Rules" begin + rule_f(x) = x^2 + + function rule_fip(x) + x[1] *= x[1] + return nothing + end + + function Enzyme.EnzymeRules.reverse(::Type{Tuple{typeof(rule_f), Float64}}, RT, Args, needsPrimal, needsShadow, width, overwritten) + if width != 1 + return nothing + end + @assert Args[1] <: Const + @assert !needsShadow + if RT <: Active && Args[2] <: Active && needsPrimal + tmp1_aug(func, x) = (func.val(x.val),nothing) + tmp1_rev(func, x, dret, tape) = (10+2*x.val*dret,) + return tmp1_aug, tmp1_rev, typeof(nothing) + end + if RT <: Active && Args[2] <: Active && !needsPrimal + tmp2_aug(func, x) = (nothing,) + tmp2_rev(func, x, dret, tape) = (100+2*x.val*dret,) + return tmp2_aug, tmp2_rev, typeof(nothing) + end + return nothing + end + + function Enzyme.EnzymeRules.reverse(::Type{Tuple{typeof(rule_fip), T}}, RT, Args, needsPrimal, needsShadow, width, overwritten) where {T} + if width != 1 + return nothing + end + @assert Args[1] <: Const + @assert !needsPrimal + @assert !needsShadow + if RT <: Const && Args[2] <: Duplicated + function tmp1_aug(func, x) + v = x.val[1] + x.val[1] *= v + return (v,) + end + function tmp1_rev(func, x, tape) + x.dval[1] = 100 + x.dval[1] * tape + return () + end + return tmp1_aug, tmp1_rev, eltype(T) + end + return nothing + end + + @test Enzyme.autodiff(Enzyme.Reverse, rule_f, Active(2.0))[1] ≈ 104.0 + @test Enzyme.autodiff(Enzyme.Reverse, x->rule_f(x)^2, Active(2.0))[1] ≈ 42.0 + + x = [2.0] + dx = [1.0] + + Enzyme.autodiff(Enzyme.Reverse, rule_fip, Duplicated(x, dx)) + + @test x ≈ [4.0] + @test dx ≈ [102.0] +end diff --git a/test/runtests.jl b/test/runtests.jl index 439b773cd0..0c9562dda8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,7 @@ end include("abi.jl") include("typetree.jl") include("rules.jl") +include("rrules.jl") f0(x) = 1.0 + x @testset "Internal tests" begin