From 86da3cdae09a5c6d0f877c4d9e01a4491d414501 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 8 Jun 2024 17:39:06 -0400 Subject: [PATCH] Reverse mode apply iterate (#1485) * Reverse mode apply iterate * fixed * fixup * cleanup * debugging fixes * fixup * cleanup * fix tests * fix batch getfield rev * fix tests * more test fix * fix tuple fast path * fix * Update Project.toml * fix sym index rev * fix test * fixup * Fix unionall * fix * fix sym offset * ix constantarray * Update Project.toml --- Project.toml | 4 +- src/Enzyme.jl | 7 + src/compiler.jl | 11 +- src/compiler/validation.jl | 16 +- src/rules/jitrules.jl | 662 +++++++++++++++++++++++---------- src/rules/typeunstablerules.jl | 152 +++++++- src/utils.jl | 2 +- test/applyiter.jl | 491 ++++++++++++++++++++++++ test/runtests.jl | 288 ++------------ 9 files changed, 1145 insertions(+), 488 deletions(-) create mode 100644 test/applyiter.jl diff --git a/Project.toml b/Project.toml index 87c6d55dcc..848c47e7ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.11" +version = "0.12.12" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -20,7 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.7.4" -Enzyme_jll = "0.0.119" +Enzyme_jll = "0.0.121" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7" ObjectFile = "0.4" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7626304944..a6bc604e6a 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -74,6 +74,13 @@ end end)...} end +@inline function vaEltypes(args::Type{Ty}) where {Ty <: Tuple} + return Tuple{(ntuple(Val(length(Ty.parameters))) do i + Base.@_inline_meta + eltype(Ty.parameters[i]) + end)...} +end + @inline function same_or_one_helper(current, next) if current == -1 return next diff --git a/src/compiler.jl b/src/compiler.jl index fac6907b59..cca67bc874 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -380,7 +380,6 @@ end end @inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret} - if T === Any return DupState end @@ -422,7 +421,9 @@ end else inmi = GPUCompiler.methodinstance(typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world) args = Any[EnzymeCore.EnzymeRules.inactive_type, T]; - ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + GC.@preserve T begin + ccall(:jl_invoke, Any, (Any, Ptr{Any}, Cuint, Any), EnzymeCore.EnzymeRules.inactive_type, args, length(args), inmi) + end end if inactivety @@ -480,11 +481,13 @@ end @static if VERSION < v"1.7.0" nT = T else - nT = if is_concrete_tuple(T) + nT = if T <: Tuple && T != Tuple && !(T isa UnionAll) Tuple{(ntuple(length(T.parameters)) do i Base.@_inline_meta sT = T.parameters[i] - if sT isa Core.TypeofVararg + if sT isa TypeVar + Any + elseif sT isa Core.TypeofVararg Any else sT diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 80db3cfb39..caf86cbc03 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -743,7 +743,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end - seen = Dict{LLVM.Value,Tuple}() + seen = Set{Tuple{LLVM.Value,Tuple}}() while length(todo) != 0 cur, off = pop!(todo) @@ -751,11 +751,10 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width cur = operands(cur)[1] end - if cur in keys(seen) - @assert seen[cur] == off + if cur in seen continue end - seen[cur] = off + push!(seen, (cur, off)) if isa(cur, LLVM.PHIInst) for (v, _) in LLVM.incoming(cur) @@ -781,7 +780,7 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width # if inserting at the current desired offset, we have found the value we need if ind == off[1] - push!(todo, (operands(cur)[2], -1)) + push!(todo, (operands(cur)[2], off[2:end])) # otherwise it must be inserted at a different point else push!(todo, (operands(cur)[1], off)) @@ -880,10 +879,15 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width end end + if isa(cur, LLVM.ConstantArray) + push!(todo, (cur[off[1]], off[2:end])) + continue + end + msg = sprint() do io::IO println(io, "Enzyme Internal Error (rewrite_union_returns_as_ref[2])") println(io, string(enzymefn)) - println(io, "cur=", cur) + println(io, "cur=", string(cur)) println(io, "off=", off) end throw(AssertionError(msg)) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 76b72466c1..af12d2bfbc 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1,5 +1,5 @@ -function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) +function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing, iterate=false) primargs = Union{Symbol,Expr}[] shadowargs = Union{Symbol,Expr}[] batchshadowargs = Vector{Union{Symbol,Expr}}[] @@ -59,8 +59,36 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) @assert length(primargs) == N @assert length(primtypes) == N wrapped = Expr[] + modbetween = Expr[:(MB[1])] for i in 1:N - expr = :( + if iterate + push!(modbetween, quote + ntuple(Val(length($(primargs[i])))) do _ + Base.@_inline_meta + MB[$i] + end + end) + end + expr = if iterate + :( + if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) + @assert $(primtypes[i]) !== DataType + if !$forwardMode && active_reg($(primtypes[i])) + iterate_unwrap_augfwd_act($(primargs[i])...) + else + $((Width == 1) ? quote + iterate_unwrap_augfwd_dup(Val($forwardMode), $(primargs[i]), $(shadowargs[i])) + end : quote + iterate_unwrap_augfwd_batchdup(Val($forwardMode), Val($Width), $(primargs[i]), $(shadowargs[i])) + end + ) + end + else + map(Const, $(primargs[i])) + end + ) + else + :( if ActivityTup[$i+1] && !guaranteed_const($(primtypes[i])) @assert $(primtypes[i]) !== DataType if !$forwardMode && active_reg($(primtypes[i])) @@ -73,9 +101,10 @@ function setup_macro_wraps(forwardMode::Bool, N::Int, Width::Int, base=nothing) end ) + end push!(wrapped, expr) end - return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs + return primargs, shadowargs, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween end function body_runtime_generic_fwd(N, Width, wrapped, primtypes) @@ -110,7 +139,6 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end world = codegen_world_age(FT, tt) - forward = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ForwardMode), width, #=ModifiedBetween=#Val($ModifiedBetween), #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) res = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) @@ -131,7 +159,7 @@ function body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function func_runtime_generic_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width) body = body_runtime_generic_fwd(N, Width, wrapped, primtypes) quote @@ -143,14 +171,14 @@ end @generated function runtime_generic_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs) return body_runtime_generic_fwd(N, Width, wrapped, primtypes) end function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) nnothing = ntuple(i->nothing, Val(Width+1)) nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) + nzeros = ntuple(i->:(Ref(make_zero(origRet))), Val(Width)) nres3 = ntuple(i->:(res[3]), Val(Width)) ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) @@ -162,7 +190,13 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) # tt0 = Tuple{$(primtypes...)} tt′ = Tuple{$(Types...)} rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end dupClosure = ActivityTup[1] FT = Core.Typeof(f) @@ -209,19 +243,19 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) end function func_runtime_generic_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(false, N, Width) body = body_runtime_generic_augfwd(N, Width, wrapped, primtypes) quote - function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} + function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} $body end end end -@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} +@generated function runtime_generic_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _= setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_augfwd(N, Width, wrapped, primtypes) end @@ -267,7 +301,13 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) tt = Tuple{$(ElTypes...)} tt′ = Tuple{$(Types...)} rt = Core.Compiler.return_type(f, tt) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if $Width != 1 && annotation0 <: Duplicated + BatchDuplicated{rt, $Width} + else + annotation0 + end dupClosure = ActivityTup[1] FT = Core.Typeof(f) @@ -278,6 +318,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + if tape.shadow_return !== nothing args = (args..., $shadowret) end @@ -290,7 +331,7 @@ function body_runtime_generic_rev(N, Width, wrapped, primttypes, shadowargs) end function func_runtime_generic_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width) body = body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) quote @@ -302,7 +343,7 @@ end @generated function runtime_generic_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, batchshadowargs, _ = setup_macro_wraps(false, N, Width, :allargs) return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) end @@ -323,69 +364,127 @@ end end end +@inline function iterate_unwrap_augfwd_act(args...) + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + if guaranteed_const(Core.Typeof(arg)) + Const(arg) + else + Active(arg) + end + end +end + +@inline function iterate_unwrap_augfwd_dup(::Val{forwardMode}, args, dargs) where forwardMode + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + if guaranteed_const(ty) + Const(arg) + elseif !forwardMode && active_reg(ty) + Active(arg) + else + Duplicated(arg, dargs[i]) + end + end +end + +@inline function iterate_unwrap_augfwd_batchdup(::Val{forwardMode}, ::Val{Width}, args, dargs) where {forwardMode, Width} + ntuple(Val(length(args))) do i + Base.@_inline_meta + arg = args[i] + ty = Core.Typeof(arg) + if guaranteed_const(ty) + Const(arg) + elseif !forwardMode && active_reg(ty) + Active(arg) + else + BatchDuplicated(arg, ntuple(Val(Width)) do j + Base.@_inline_meta + dargs[j][i] + end) + end + end +end + +@inline function allFirst(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + res[1] + end +end + +@inline function allZero(::Val{Width}, res) where Width + ntuple(Val(Width)) do i + Base.@_inline_meta + Ref(make_zero(res)) + end +end + # This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] -function fwddiff_with_return(::Val{width}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {width, Nargs} - tt′ = Enzyme.vaTypeof(args...) +function fwddiff_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, FT, tt′, DF, Nargs} ReturnPrimal = Val(true) - RT = A ModifiedBetween = Val(Enzyme.falses_from_args(Nargs+1)) - - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) - thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), - ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(f, args...) -end + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? Duplicated{FT} : Const{FT} -function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = ntuple(i->:(res[1]), Val(Width+1)) - ModifiedBetween = ntuple(i->false, Val(N+1)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) - return quote - args0 = ($(wrapped...),) - args = concat(iterate_unwrap_fwd(args0...)...) - - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end + tt = Enzyme.vaEltypes(tt′) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} - rt = Core.Compiler.return_type(f, tt) - annotation0 = guess_activity(rt, API.DEM_ForwardMode) + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ForwardMode) - annotation = @static if $Width != 1 - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - BatchDuplicated{rt, $Width} - else - Const{rt} - end + annotation = if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} else - if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated - Duplicated{rt} - else - Const{rt} - end + Const{rt} end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + else + Const{rt} + end + end - res = fwddiff_with_return(Val($Width), dupClosure ? Duplicated(f, df) : Const(f), annotation, args...) - return if annotation <: Const - ReturnType(($(nres...),)) + world = codegen_world_age(FT, tt) + fa = if dupClosure + if width == 1 + Duplicated(f, df) else - if $Width == 1 - ReturnType((res[1], res[2])) - else - ReturnType((res[1], res[2]...)) - end + BatchDuplicated(f, df) + end + else + Const(f) + end + res = thunk(Val(world), FA, annotation, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), + ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), FFIABI)(fa, args...) + return if annotation <: Const + ReturnType(allFirst(Val(width+1), res)) + else + if width == 1 + ReturnType((res[1], res[2])) + else + ReturnType((res[1], res[2]...)) end end end +function body_runtime_iterate_fwd(N, Width, wrapped, primtypes) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + return quote + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) + FT = Core.Typeof(f) + fwddiff_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, FT, tt′, f, df, args...)::ReturnType + end +end + function func_runtime_iterate_fwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(true, N, Width) + _, _, primtypes, allargs, typeargs, wrapped, _, _ = setup_macro_wraps(true, N, Width, #=base=#nothing, #=iterate=#true) body = body_runtime_iterate_fwd(N, Width, wrapped, primtypes) quote @@ -397,75 +496,135 @@ end @generated function runtime_iterate_fwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(true, N, Width, :allargs) + _, _, primtypes, _, _, wrapped, _, _ = setup_macro_wraps(true, N, Width, :allargs, #=iterate=#true) return body_runtime_iterate_fwd(N, Width, wrapped, primtypes) end -function body_runtime_iterate_augfwd(N, Width, wrapped, primttypes) - nnothing = ntuple(i->nothing, Val(Width+1)) - nres = ntuple(i->:(origRet), Val(Width+1)) - nzeros = ntuple(i->:(Ref(zero(resT))), Val(Width)) - nres3 = ntuple(i->:(res[3]), Val(Width)) - ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) - Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) +function primal_tuple(args::Vararg{Annotation, Nargs}) where Nargs + ntuple(Val(Nargs)) do i + Base.@_inline_meta + args[i].val + end +end - return quote - args = ($(wrapped...),) - throw(AssertionError("Runtime iterate augmented forward pass unhandled, f=$f df=$df args=$args")) - - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, Tuple{$(ElTypes...)}) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) +function shadow_tuple(::Val{1}, args::Vararg{Annotation, Nargs}) where Nargs + ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if args[i] isa Const + args[i].val + else + args[i].dval + end + end +end - dupClosure = ActivityTup[1] - FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false +function shadow_tuple(::Val{width}, args::Vararg{Annotation, Nargs}) where {width, Nargs} + ntuple(Val(width)) do w + ntuple(Val(Nargs)) do i + Base.@_inline_meta + @assert !(args[i] isa Active) + if args[i] isa Const + args[i].val + else + args[i].dval[w] end + end + end +end - world = codegen_world_age(FT, Tuple{$(ElTypes...)}) +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function augfwd_with_return(::Val{width}, ::Val{dupClosure0}, ::Type{ReturnType}, ::Val{ModifiedBetween0}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, args::Vararg{Annotation, Nargs})::ReturnType where {width, dupClosure0, ReturnType, ModifiedBetween0, FT, tt′, DF, Nargs} + ReturnPrimal = Val(true) + ModifiedBetween = Val(ModifiedBetween0) - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, - annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + tt = Enzyme.vaEltypes(tt′) + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) - internal_tape, origRet, initShadow = forward(dupClosure ? Duplicated(f, df) : Const(f), args...) - resT = typeof(origRet) - if annotation <: Const - shadow_return = nothing - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - return ReturnType(($(nres...), tape)) - elseif annotation <: Active - if $Width == 1 - shadow_return = Ref(make_zero(origRet)) - else - shadow_return = ($(nzeros...),) - end - tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, shadow_return, tape)) + annotation = if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + end + + internal_tape, origRet, initShadow = if f != Base.tuple + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? Duplicated{FT} : Const{FT} + + fa = if dupClosure + if width == 1 + Duplicated(f, df) else - return ReturnType((origRet, shadow_return..., tape)) + BatchDuplicated(f, df) end + else + Const(f) end + world = codegen_world_age(FT, tt) + forward, adjoint = thunk(Val(world), FA, + annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + forward(fa, args...) + else + nothing, primal_tuple(args...), annotation <: Active ? nothing : shadow_tuple(Val(width), args...) + end - @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed - + resT = typeof(origRet) + if annotation <: Const shadow_return = nothing tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) - if $Width == 1 - return ReturnType((origRet, initShadow, tape)) + return ReturnType((allFirst(Val(width+1), origRet)..., tape)) + elseif annotation <: Active + if width == 1 + shadow_return = Ref(make_zero(origRet)) else - return ReturnType((origRet, initShadow..., tape)) + shadow_return = allZero(Val(width), origRet) end + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if width == 1 + return ReturnType((origRet, shadow_return, tape)) + else + return ReturnType((origRet, shadow_return..., tape)) + end + end + + @assert annotation <: Duplicated || annotation <: DuplicatedNoNeed || annotation <: BatchDuplicated || annotation <: BatchDuplicatedNoNeed + + shadow_return = nothing + tape = Tape{typeof(internal_tape), typeof(shadow_return), resT}(internal_tape, shadow_return) + if width == 1 + return ReturnType((origRet, initShadow, tape)) + else + return ReturnType((origRet, initShadow..., tape)) + end +end + +function body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + return quote + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) + FT = Core.Typeof(f) + augfwd_with_return(Val($Width), Val(ActivityTup[1]), ReturnType, Val(concat($(modbetween...))), FT, tt′, f, df, args...)::ReturnType end end function func_runtime_iterate_augfwd(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, _ = setup_macro_wraps(false, N, Width) - body = body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, allargs, typeargs, wrapped, _, modbetween = setup_macro_wraps(false, N, Width, #=base=#nothing, #=iterate=#true) + body = body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) quote function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, ReturnType, F, DF, $(typeargs...)} @@ -476,11 +635,139 @@ end @generated function runtime_iterate_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, ReturnType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, _ = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_iterate_augfwd(N, Width, wrapped, primtypes) + _, _, primtypes, _, _, wrapped, _ , modbetween, = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + return body_runtime_iterate_augfwd(N, Width, modbetween, wrapped, primtypes) +end + + + +# This is explicitly escaped here to be what is apply generic in total [and thus all the insides are stable] +function rev_with_return(::Val{width}, ::Val{dupClosure0}, ::Val{ModifiedBetween0}, ::Val{lengths}, ::Type{FT}, ::Type{tt′}, f::FT, df::DF, tape, shadowargs, args::Vararg{Annotation, Nargs})::Nothing where {width, dupClosure0, ModifiedBetween0, lengths, FT, tt′, DF, Nargs} + ReturnPrimal = Val(true) + ModifiedBetween = Val(ModifiedBetween0) + + dupClosure = dupClosure0 && !guaranteed_const(FT) + FA = dupClosure ? Duplicated{FT} : Const{FT} + + tt = Enzyme.vaEltypes(tt′) + + rt = Core.Compiler.return_type(f, tt) + annotation0 = guess_activity(rt, API.DEM_ReverseModePrimal) + + annotation = if width != 1 + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + BatchDuplicated{rt, width} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + else + if annotation0 <: DuplicatedNoNeed || annotation0 <: Duplicated + Duplicated{rt} + elseif annotation0 <: Active + Active{rt} + else + Const{rt} + end + end + + tup = if f != Base.tuple + world = codegen_world_age(FT, tt) + + fa = if dupClosure + if width == 1 + Duplicated(f, df) + else + BatchDuplicated(f, df) + end + else + Const(f) + end + forward, adjoint = thunk(Val(world), FA, + annotation, tt′, Val(API.DEM_ReverseModePrimal), Val(width), + ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) + + args2 = if tape.shadow_return !== nothing + if width == 1 + (args..., tape.shadow_return[]) + else + (args..., ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][] + end) + end + else + args + end + + adjoint(fa, args2..., tape.internal_tape)[1] + else + ntuple(Val(Nargs)) do i + Base.@_inline_meta + if args[i] isa Active + if width == 1 + tape.shadow_return[][i] + else + ntuple(Val(width)) do w + Base.@_inline_meta + tape.shadow_return[w][][i] + end + end + else + nothing + end + end + end + + ntuple(Val(Nargs)) do i + Base.@_inline_meta + + ntuple(Val(width)) do w + Base.@_inline_meta + + if tup[i] == nothing + else + expr = if width == 1 + tup[i] + else + tup[i][w] + end + idx_of_vec, idx_in_vec = lengths[i] + vec = @inbounds shadowargs[idx_of_vec][w] + if vec isa Base.RefValue + vecld = vec[] + T = Core.Typeof(vecld) + vec[] = splatnew(T, ntuple(Val(fieldcount(T))) do i + Base.@_inline_meta + prev = getfield(vecld, i) + if i == idx_in_vec + recursive_add(prev, expr) + else + prev + end + end) + else + val = @inbounds vec[idx_in_vec] + if val isa Base.RefValue + val[] = recursive_add(val[], expr) + elseif ismutable(vec) + @inbounds vec[idx_in_vec] = recursive_add(val, expr) + else + error("Enzyme Mutability Error: Cannot in place to immutable value vec[$idx_in_vec] = $val, vec=$vec") + end + end + end + + nothing + end + + nothing + end + nothing end -function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) +function body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, shadowargs) outs = [] for i in 1:N for w in 1:Width @@ -494,7 +781,7 @@ function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) elseif $shad isa Base.RefValue $shad[] = recursive_add($shad[], $expr) else - error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string($shad)) + error("Enzyme Mutability Error: Cannot add in place to immutable value "*string($shad)) end ) push!(outs, out) @@ -514,40 +801,30 @@ function body_runtime_iterate_rev(N, Width, wrapped, primttypes, shadowargs) ElTypes = ntuple(i->:(eltype(Core.Typeof(args[$i]))), Val(N)) Types = ntuple(i->:(Core.Typeof(args[$i])), Val(N)) + wrappedexexpand = ntuple(i->:($(wrapped[i])...), Val(N)) + lengths = ntuple(i->quote + (ntuple(Val(length($(primargs[i])))) do j + Base.@_inline_meta + ($i, j) + end) + end, Val(N)) + + shadowsplat = Expr[] + for s in shadowargs + push!(shadowsplat, :(($(s...),))) + end quote - args = ($(wrapped...),) - throw(AssertionError("Runtime iterate reverse pass unhandled, f=$f df=$df args=$args")) - - # TODO: Annotation of return value - # tt0 = Tuple{$(primtypes...)} - tt = Tuple{$(ElTypes...)} - tt′ = Tuple{$(Types...)} - rt = Core.Compiler.return_type(f, tt) - annotation = guess_activity(rt, API.DEM_ReverseModePrimal) - - dupClosure = ActivityTup[1] + args = ($(wrappedexexpand...),) + tt′ = Enzyme.vaTypeof(args...) FT = Core.Typeof(f) - if dupClosure && guaranteed_const(FT) - dupClosure = false - end - world = codegen_world_age(FT, tt) - - forward, adjoint = thunk(Val(world), (dupClosure ? Duplicated : Const){FT}, annotation, tt′, Val(API.DEM_ReverseModePrimal), width, - ModifiedBetween, #=returnPrimal=#Val(true), #=shadowInit=#Val(false), FFIABI) - if tape.shadow_return !== nothing - args = (args..., $shadowret) - end - - tup = adjoint(dupClosure ? Duplicated(f, df) : Const(f), args..., tape.internal_tape)[1] - - $(outs...) + rev_with_return(Val($Width), Val(ActivityTup[1]), Val(concat($(modbetween...))), Val(concat($(lengths...))), FT, tt′, f, df, tape, ($(shadowsplat...),), args...) return nothing end end function func_runtime_iterate_rev(N, Width) - _, _, primtypes, allargs, typeargs, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width) - body = body_runtime_iterate_rev(N, Width, wrapped, primtypes, batchshadowargs) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, #=body=#nothing, #=iterate=#true) + body = body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) quote function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, $(allargs...)) where {ActivityTup, MB, TapeType, F, DF, $(typeargs...)} @@ -558,8 +835,8 @@ end @generated function runtime_iterate_rev(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, tape::TapeType, f::F, df::DF, allargs...) where {ActivityTup, MB, Width, TapeType, F, DF} N = div(length(allargs)+2, Width+1)-1 - _, _, primtypes, _, _, wrapped, batchshadowargs = setup_macro_wraps(false, N, Width, :allargs) - return body_runtime_generic_rev(N, Width, wrapped, primtypes, batchshadowargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, modbetween = setup_macro_wraps(false, N, Width, :allargs, #=iterate=#true) + return body_runtime_iterate_rev(N, Width, modbetween, wrapped, primargs, batchshadowargs) end # Create specializations @@ -697,7 +974,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, end debug_from_orig!(gutils, cal, orig) - + if tape === nothing llty = convert(LLVMType, ReturnType) cal = LLVM.addrspacecast!(B, cal, LLVM.PointerType(T_jlvalue, Derived)) @@ -778,7 +1055,7 @@ function common_generic_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) width = get_width(gutils) sret = generic_setup(orig, runtime_generic_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset, B, false) AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) - + if unsafe_load(shadowR) != C_NULL if width == 1 gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) @@ -1074,18 +1351,6 @@ function common_apply_iterate_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function error_if_active_iter(arg) - # check if it could contain an active - for v in arg - seen = () - T = Core.Typeof(v) - areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) - if areg == ActiveState - throw(AssertionError("Found unhandled active variable in tuple splat, jl_apply_iterate $T")) - end - end -end - function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, tapeR) needsShadowP = Ref{UInt8}(0) needsPrimalP = Ref{UInt8}(0) @@ -1100,51 +1365,41 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, width = get_width(gutils) - if v && v2 && isiter == Base.iterate && istup == Base.tuple && length(operands(orig)) >= offset+4 - origops = collect(operands(orig)[1:end-1]) - shadowins = [ invert_pointer(gutils, origops[i], B) for i in (offset+3):length(origops) ] - shadowres = if width == 1 - newops = LLVM.Value[] - newvals = API.CValueType[] - for (i, v) in enumerate(origops) - if i >= offset + 3 - shadowin2 = shadowins[i-offset-3+1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) - push!(newops, shadowin2) - push!(newvals, API.VT_Shadow) - else - push!(newops, new_from_original(gutils, origops[i])) - push!(newvals, API.VT_Primal) - end - end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) - callconv!(cal, callconv(orig)) - cal - else - ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) - shadow = LLVM.UndefValue(ST) - for j in 1:width - newops = LLVM.Value[] - newvals = API.CValueType[] - for (i, v) in enumerate(origops) - if i >= offset + 3 - shadowin2 = extract_value!(B, shadowins[i-offset-3+1], j-1) - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_iter), shadowin2]) - push!(newops, shadowin2) - push!(newvals, API.VT_Shadow) - else - push!(newops, new_from_original(gutils, origops[i])) - push!(newvals, API.VT_Primal) - end + if v && isiter == Base.iterate + T_jlvalue = LLVM.StructType(LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + sret = generic_setup(orig, runtime_iterate_augfwd, AnyArray(2+Int(width)), gutils, #=start=#offset+2, B, false) + AT = LLVM.ArrayType(T_prjlvalue, 2+Int(width)) + + if unsafe_load(shadowR) != C_NULL + if width == 1 + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1)]) + shadow = LLVM.load!(B, T_prjlvalue, gep) + else + ST = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))) + shadow = LLVM.UndefValue(ST) + for i in 1:width + gep = LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(i)]) + ld = LLVM.load!(B, T_prjlvalue, gep) + shadow = insert_value!(B, shadow, ld, i-1) end - cal = call_samefunc_with_inverted_bundles!(B, gutils, orig, newops, newvals, #=lookup=#false) - callconv!(cal, callconv(orig)) - shadow = insert_value!(B, shadow, cal, j-1) end - shadow + unsafe_store!(shadowR, shadow.ref) end - unsafe_store!(shadowR, shadowres.ref) + tape = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(1+width)])) + unsafe_store!(tapeR, tape.ref) + + if normalR != C_NULL + normal = LLVM.load!(B, T_prjlvalue, LLVM.inbounds_gep!(B, AT, sret, [LLVM.ConstantInt(0), LLVM.ConstantInt(0)])) + unsafe_store!(normalR, normal.ref) + else + # Delete the primal code + ni = new_from_original(gutils, orig) + erase_with_placeholder(gutils, ni, orig) + end + return false return false end @@ -1155,6 +1410,17 @@ function common_apply_iterate_augfwd(offset, B, orig, gutils, normalR, shadowR, end function common_apply_iterate_rev(offset, B, orig, gutils, tape) + needsShadowP = Ref{UInt8}(0) + needsPrimalP = Ref{UInt8}(0) + activep = API.EnzymeGradientUtilsGetReturnDiffeType(gutils, orig, needsPrimalP, needsShadowP, get_mode(gutils)) + + if (is_constant_value(gutils, orig) || needsShadowP[] == 0 ) && is_constant_inst(gutils, orig) + return nothing + end + + @assert tape !== C_NULL + width = get_width(gutils) + generic_setup(orig, runtime_iterate_rev, Nothing, gutils, #=start=#offset+2, B, true; tape) return nothing end diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index c6639eb8aa..1ee4f0d961 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -249,7 +249,7 @@ function common_jl_getfield_fwd(offset, B, orig, gutils, normalR, shadowR) return false end -function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} +function rt_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname) else @@ -260,40 +260,57 @@ function rt_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs if length(dptrs) == 0 return Ref{RT}(make_zero(res)) else - return ( (Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) + return NT(ntuple(Val(1+length(dptrs))) do i + Base.@_inline_meta + Ref{RT}(make_zero(res)) + end) end else if length(dptrs) == 0 return res else - return (res, (getfield(dv, symname) for dv in dptrs)...) + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + getfield(dv isa Base.RefValue ? dv[] : dv, symname) + end)...)) + return fval end end end -function idx_jl_getfield_aug(dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} +function idx_jl_getfield_aug(::Val{NT}, dptr::T, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {NT, T, T2, Nargs, symname, isconst} res = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) else Base.getfield(dptr, symname+1) end RT = Core.Typeof(res) - if active_reg(RT) + actreg = active_reg(RT) + if actreg if length(dptrs) == 0 - return Ref{RT}(make_zero(res)) + return Ref{RT}(make_zero(res))::Any else - return ( (Ref{RT}(make_zero(res)) for _ in 1:(1+length(dptrs)))..., ) + return NT(ntuple(Val(1+length(dptrs))) do i + Base.@_inline_meta + Ref{RT}(make_zero(res)) + end) end else if length(dptrs) == 0 - return res + return res::Any else - return (res, (getfield(dv, symname) for dv in dptrs)...) + fval = NT((res, (ntuple(Val(length(dptrs))) do i + Base.@_inline_meta + dv = dptrs[i] + getfield(dv isa Base.RefValue ? dv[] : dv, symname+1) + end)...)) + return fval end end end -function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} +function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue getfield(dptr[], symname) else @@ -303,17 +320,65 @@ function rt_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield!(dptr, symname, recursive_add(cur, dret[])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i + Base.@_inline_meta + prev = getfield(vload, i) + if fieldname(dRT, i) == symname + recursive_add(prev, dret[]) + else + prev + end + end) + else + setfield!(dptr, symname, recursive_add(cur, dret[])) + end else - setfield!(dptr, symname, recursive_add(cur, dret[1][])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if fieldname(dRT, j) == symname + recursive_add(prev, dret[1][]) + else + prev + end + end) + else + setfield!(dptr, symname, recursive_add(cur, dret[1][])) + end for i in 1:length(dptrs) - setfield!(dptrs[i], symname, recursive_add(cur, dret[1+i][])) + if dptrs[i] isa Base.RefValue + vload = dptrs[i][] + dRT = Core.Typeof(vload) + dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if fieldname(dRT, j) == symname + recursive_add(prev, dret[1+i][]) + else + prev + end + end) + else + curi = if dptr isa Base.RefValue + Base.getfield(dptrs[i][], symname) + else + Base.getfield(dptrs[i], symname) + end + setfield!(dptrs[i], symname, recursive_add(curi, dret[1+i][])) + end end end end return nothing end -function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs...) where {T, symname, isconst} + +function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst}, dptrs::Vararg{T2, Nargs}) where {T, T2, Nargs, symname, isconst} cur = if dptr isa Base.RefValue Base.getfield(dptr[], symname+1) else @@ -323,11 +388,58 @@ function idx_jl_getfield_rev(dptr::T, dret, ::Type{Val{symname}}, ::Val{isconst} RT = Core.Typeof(cur) if active_reg(RT) && !isconst if length(dptrs) == 0 - setfield!(dptr, symname+1, recursive_add(cur, dret[])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do i + Base.@_inline_meta + prev = getfield(vload, i) + if i == symname+1 + recursive_add(prev, dret[]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[])) + end else - setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + if dptr isa Base.RefValue + vload = dptr[] + dRT = Core.Typeof(vload) + dptr[] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if j == symname+1 + recursive_add(prev, dret[1][]) + else + prev + end + end) + else + setfield!(dptr, symname+1, recursive_add(cur, dret[1][])) + end for i in 1:length(dptrs) - setfield!(dptrs[i], symname+1, recursive_add(cur, dret[1+i][])) + if dptrs[i] isa Base.RefValue + vload = dptrs[i][] + dRT = Core.Typeof(vload) + dptrs[i][] = splatnew(dRT, ntuple(Val(fieldcount(dRT))) do j + Base.@_inline_meta + prev = getfield(vload, j) + if j == symname+1 + recursive_add(prev, dret[1+i][]) + else + prev + end + end) + else + curi = if dptr isa Base.RefValue + Base.getfield(dptrs[i][], symname+1) + else + Base.getfield(dptrs[i], symname+1) + end + setfield!(dptrs[i], symname+1, recursive_add(curi, dret[1+i][])) + end end end end @@ -362,7 +474,8 @@ function common_jl_getfield_augfwd(offset, B, orig, gutils, normalR, shadowR, ta inps = [new_from_original(gutils, ops[2])] end - vals = LLVM.Value[] + AA = Val(AnyArray(Int(width))) + vals = LLVM.Value[unsafe_to_llvm(AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[3]) @@ -539,7 +652,8 @@ function jl_nthfield_augfwd(B, orig, gutils, normalR, shadowR, tapeR) inps = [new_from_original(gutils, ops[1])] end - vals = LLVM.Value[] + AA = Val(AnyArray(Int(width))) + vals = LLVM.Value[unsafe_to_llvm(AA)] push!(vals, inps[1]) sym = new_from_original(gutils, ops[2]) diff --git a/src/utils.jl b/src/utils.jl index a3268c6c94..916818181e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,7 +8,7 @@ @inline unsafe_to_pointer(val::Type{T}) where T = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), Ptr{Cvoid}, (Any,), val) export unsafe_to_pointer -@inline is_concrete_tuple(x::T2) where T2 = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) +@inline is_concrete_tuple(x::Type{T2}) where T2 = (T2 <: Tuple) && !(T2 === Tuple) && !(T2 isa UnionAll) export is_concrete_tuple const Tracked = 10 diff --git a/test/applyiter.jl b/test/applyiter.jl new file mode 100644 index 0000000000..2518e2d829 --- /dev/null +++ b/test/applyiter.jl @@ -0,0 +1,491 @@ +using Enzyme, Test + +concat() = () +concat(a) = a +concat(a, b) = (a..., b...) +concat(a, b, c...) = concat(concat(a, b), c...) + +metaconcat(x) = concat(x...) + +metaconcat2(x, y) = concat(x..., y...) + +midconcat(x, y) = (x, concat(y...)...) + +metaconcat3(x, y, z) = concat(x..., y..., z...) + +function metasumsq(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v::Float64 + res += v*v + end + return res +end + +function metasumsq2(f, args...) + res = 0.0 + x = f(args...) + for v in x + for v2 in v + v2 = v2::Float64 + res += v*v + end + end + return res +end + + +function metasumsq3(f, args...) + res = 0.0 + x = f(args...) + for v in x + v = v + res += v*v + end + return res +end + +function metasumsq4(f, args...) + res = 0.0 + x = f(args...) + for v in x + for v2 in v + v2 = v2 + res += v*v + end + end + return res +end + +function make_byref(out, fn, args...) + out[] = fn(args...) + nothing +end + +function tupapprox(a, b) + if a isa Tuple && b isa Tuple + if length(a) != length(b) + return false + end + for (aa, bb) in zip(a, b) + if !tupapprox(aa, bb) + return false + end + end + return true + end + if a isa Array && b isa Array + if size(a) != size(b) + return false + end + for i in length(a) + if !tupapprox(a[i], b[i]) + return false + end + end + return true + end + return a ≈ b +end + +@testset "Reverse Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + res = Enzyme.autodiff(Reverse, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + dx = [(0.0, 0.0), (0.0, 0.0)] + res = Enzyme.autodiff(ReverseWithPrimal, metasumsq, Active, Const(metaconcat), Duplicated(x, dx)) + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + + res = Enzyme.autodiff(Reverse, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + + dx = [[0.0, 0.0], [0.0, 0.0]] + + res = Enzyme.autodiff(ReverseWithPrimal, metasumsq2, Active, Const(metaconcat), Duplicated(x, dx)) + + @test res[2] ≈ 200.84999999999997 + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + + y = [(13, 17), (25, 31)] + res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + y = [(13, 17), (25, 31)] + dy = [(0, 0), (0, 0)] + res = Enzyme.autodiff(Reverse, metasumsq3, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Const(y)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + dy = [[0, 0], [0, 0]] + res = Enzyme.autodiff(Reverse, metasumsq4, Active, Const(metaconcat2), Duplicated(x, dx), Duplicated(y, dy)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) +end + +@testset "BatchReverse Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test out[] ≈ 200.84999999999997 + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + @test dx ≈ [[4.0, 6.0], [15.8, 22.4]] + @test dx2 ≈ [[3*4.0, 3*6.0], [3*15.8, 3*22.4]] + + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq2), Const(metaconcat), BatchDuplicated(x, (dx, dx2))) + + @test out[] ≈ 200.84999999999997 + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + + y = [(13, 17), (25, 31)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(0.0, 0.0), (0.0, 0.0)] + dx2 = [(0.0, 0.0), (0.0, 0.0)] + y = [(13, 17), (25, 31)] + dy = [(0, 0), (0, 0)] + dy2 = [(0, 0), (0, 0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicatedNoNeed(out, (dout, dout2)), Const(metasumsq3),Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test tupapprox(dx, [(4.0, 6.0), (15.8, 22.4)]) + @test tupapprox(dx2, [(3*4.0, 3*6.0), (3*15.8, 3*22.4)]) + + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), Const(y)) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) + + x = [[2.0, 3.0], [7.9, 11.2]] + dx = [[0.0, 0.0], [0.0, 0.0]] + dx2 = [[0.0, 0.0], [0.0, 0.0]] + y = [[13, 17], [25, 31]] + dy = [[0, 0], [0, 0]] + dy2 = [[0, 0], [0, 0]] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + Enzyme.autodiff(Reverse, make_byref, Const, BatchDuplicated(out, (dout, dout2)), Const(metasumsq4), Const(metaconcat2), BatchDuplicated(x, (dx, dx2)), BatchDuplicated(y, (dy, dy2))) + @test tupapprox(dx, [[4.0, 6.0], [15.8, 22.4]]) + @test tupapprox(dx2, [[3*4.0, 3*6.0], [3*15.8, 3*22.4]]) +end + +@testset "Forward Apply iterate" begin + x = [(2.0, 3.0), (7.9, 11.2)] + dx = [(13.7, 15.2), (100.02, 304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres) == 4 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + + a = [("a", "b"), ("c", "d")] + da = [("e", "f"), ("g", "h")] + + dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) + @test length(res) == 4 + @test res[1] == "a" + @test res[2] == "b" + @test res[3] == "c" + @test res[4] == "d" + @test length(dres) == 4 + @test dres[1] == "a" + @test dres[2] == "b" + @test dres[3] == "c" + @test dres[4] == "d" + + + Enzyme.autodiff(Forward, metaconcat, Const(a)) + +@static if VERSION ≥ v"1.7-" + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + + dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Const(a)) + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" + + res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) + @test length(res) == 5 + @test res[1] ≈ 1.0 + @test res[2] == "a" + @test res[3] == "b" + @test res[4] == "c" + @test res[5] == "d" + @test length(dres) == 5 + @test dres[1] ≈ 7.0 + @test dres[2] == "a" + @test dres[3] == "b" + @test dres[4] == "c" + @test dres[5] == "d" +end + + y = [(-92.0, -93.0), (-97.9, -911.2)] + dy = [(-913.7, -915.2), (-9100.02, -9304.1)] + + dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) + @test length(res) == 8 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test res[5] ≈ -92.0 + @test res[6] ≈ -93.0 + @test res[7] ≈ -97.9 + @test res[8] ≈ -911.2 + @test length(dres) == 8 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + @test dres[5] ≈ -913.7 + @test dres[6] ≈ -915.2 + @test dres[7] ≈ -9100.02 + @test dres[8] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) + @test length(res) == 12 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + + @test res[5] == "a" + @test res[6] == "b" + @test res[7] == "c" + @test res[8] == "d" + + @test res[9] ≈ -92.0 + @test res[10] ≈ -93.0 + @test res[11] ≈ -97.9 + @test res[12] ≈ -911.2 + + @test length(dres) == 12 + @test dres[1] ≈ 13.7 + @test dres[2] ≈ 15.2 + @test dres[3] ≈ 100.02 + @test dres[4] ≈ 304.1 + + @test dres[5] == "a" + @test dres[6] == "b" + @test dres[7] == "c" + @test dres[8] == "d" + + @test dres[9] ≈ -913.7 + @test dres[10] ≈ -915.2 + @test dres[11] ≈ -9100.02 + @test dres[12] ≈ -9304.1 + + + dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 + + res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) + @test length(res) == 4 + @test res[1] ≈ 2.0 + @test res[2] ≈ 3.0 + @test res[3] ≈ 7.9 + @test res[4] ≈ 11.2 + @test length(dres[1]) == 4 + @test dres[1][1] ≈ 13.7 + @test dres[1][2] ≈ 15.2 + @test dres[1][3] ≈ 100.02 + @test dres[1][4] ≈ 304.1 + @test length(dres[2]) == 4 + @test dres[2][1] ≈ -913.7 + @test dres[2][2] ≈ -915.2 + @test dres[2][3] ≈ -9100.02 + @test dres[2][4] ≈ -9304.1 +end + +@testset "legacy reverse apply iterate" begin + function mktup(v) + tup = tuple(v...) + return tup[1][1] * tup[3][1] + end + + data = [[3.0], nothing, [2.0]] + ddata = [[0.0], nothing, [0.0]] + + Enzyme.autodiff(Reverse, mktup, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 2.0 + @test ddata[3][1] ≈ 3.0 + + function mktup2(v) + tup = tuple(v...) + return (tup[1][1] * tup[3])::Float64 + end + + data = [[3.0], nothing, 2.0] + ddata = [[0.0], nothing, 0.0] + + @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) + + function mktup3(v) + tup = tuple(v..., v...) + return tup[1][1] * tup[1][1] + end + + data = [[3.0]] + ddata = [[0.0]] + + Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) + @test ddata[1][1] ≈ 6.0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 0212ec0d83..e931666f90 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -137,6 +137,7 @@ end @assert Enzyme.Compiler.active_reg_inner(Symbol, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(String, (), nothing) == Enzyme.Compiler.AnyState @assert Enzyme.Compiler.active_reg_inner(Tuple{Any,Int64}, (), nothing) == Enzyme.Compiler.DupState + @assert Enzyme.Compiler.active_reg_inner(Tuple{S,Int64} where S, (), Base.get_world_counter()) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing) == Enzyme.Compiler.DupState @assert Enzyme.Compiler.active_reg_inner(Union{Float64,Nothing}, (), nothing, #=justActive=#Val(false), #=unionSret=#Val(true)) == Enzyme.Compiler.ActiveState world = codegen_world_age(typeof(f0), Tuple{Float64}) @@ -1670,232 +1671,38 @@ end end -concat() = () -concat(a) = a -concat(a, b) = (a..., b...) -concat(a, b, c...) = concat(concat(a, b), c...) - -metaconcat(x) = concat(x...) - -metaconcat2(x, y) = concat(x..., y...) - -midconcat(x, y) = (x, concat(y...)...) - -metaconcat3(x, y, z) = concat(x..., y..., z...) - -@testset "Forward Apply iterate" begin - x = [(2.0, 3.0), (7.9, 11.2)] - dx = [(13.7, 15.2), (100.02, 304.1)] - - dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(x, dx)) - @test length(dres) == 4 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(x, dx)) - @test length(res) == 4 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test length(dres) == 4 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - - a = [("a", "b"), ("c", "d")] - da = [("e", "f"), ("g", "h")] - - dres, = Enzyme.autodiff(Forward, metaconcat, Duplicated(a, da)) - @test length(dres) == 4 - @test dres[1] == "a" - @test dres[2] == "b" - @test dres[3] == "c" - @test dres[4] == "d" - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, Duplicated(a, da)) - @test length(res) == 4 - @test res[1] == "a" - @test res[2] == "b" - @test res[3] == "c" - @test res[4] == "d" - @test length(dres) == 4 - @test dres[1] == "a" - @test dres[2] == "b" - @test dres[3] == "c" - @test dres[4] == "d" - - - Enzyme.autodiff(Forward, metaconcat, Const(a)) +function batchgf(out, args) + res = 0.0 + x = Base.inferencebarrier((args[1][1],)) + for v in x + v = v::Float64 + res += v + break + end + out[] = res + nothing +end -@static if VERSION ≥ v"1.7-" - dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Duplicated(a, da)) - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Duplicated(a, da)) - @test length(res) == 5 - @test res[1] ≈ 1.0 - @test res[2] == "a" - @test res[3] == "b" - @test res[4] == "c" - @test res[5] == "d" - - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - - dres, = Enzyme.autodiff(Forward, midconcat, Duplicated(1.0, 7.0), Const(a)) - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" - - res, dres = Enzyme.autodiff(Forward, midconcat, Duplicated, Duplicated(1.0, 7.0), Const(a)) - @test length(res) == 5 - @test res[1] ≈ 1.0 - @test res[2] == "a" - @test res[3] == "b" - @test res[4] == "c" - @test res[5] == "d" - @test length(dres) == 5 - @test dres[1] ≈ 7.0 - @test dres[2] == "a" - @test dres[3] == "b" - @test dres[4] == "c" - @test dres[5] == "d" -end - - y = [(-92.0, -93.0), (-97.9, -911.2)] - dy = [(-913.7, -915.2), (-9100.02, -9304.1)] - - dres, = Enzyme.autodiff(Forward, metaconcat2, Duplicated(x, dx), Duplicated(y, dy)) - @test length(dres) == 8 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - @test dres[5] ≈ -913.7 - @test dres[6] ≈ -915.2 - @test dres[7] ≈ -9100.02 - @test dres[8] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat2, Duplicated, Duplicated(x, dx), Duplicated(y, dy)) - @test length(res) == 8 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test res[5] ≈ -92.0 - @test res[6] ≈ -93.0 - @test res[7] ≈ -97.9 - @test res[8] ≈ -911.2 - @test length(dres) == 8 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - @test dres[5] ≈ -913.7 - @test dres[6] ≈ -915.2 - @test dres[7] ≈ -9100.02 - @test dres[8] ≈ -9304.1 - - - dres, = Enzyme.autodiff(Forward, metaconcat3, Duplicated(x, dx), Const(a), Duplicated(y, dy)) - @test length(dres) == 12 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - @test dres[5] == "a" - @test dres[6] == "b" - @test dres[7] == "c" - @test dres[8] == "d" - - @test dres[9] ≈ -913.7 - @test dres[10] ≈ -915.2 - @test dres[11] ≈ -9100.02 - @test dres[12] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat3, Duplicated, Duplicated(x, dx), Const(a), Duplicated(y, dy)) - @test length(res) == 12 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - - @test res[5] == "a" - @test res[6] == "b" - @test res[7] == "c" - @test res[8] == "d" - - @test res[9] ≈ -92.0 - @test res[10] ≈ -93.0 - @test res[11] ≈ -97.9 - @test res[12] ≈ -911.2 - - @test length(dres) == 12 - @test dres[1] ≈ 13.7 - @test dres[2] ≈ 15.2 - @test dres[3] ≈ 100.02 - @test dres[4] ≈ 304.1 - - @test dres[5] == "a" - @test dres[6] == "b" - @test dres[7] == "c" - @test dres[8] == "d" - - @test dres[9] ≈ -913.7 - @test dres[10] ≈ -915.2 - @test dres[11] ≈ -9100.02 - @test dres[12] ≈ -9304.1 - - - dres, = Enzyme.autodiff(Forward, metaconcat, BatchDuplicated(x, (dx, dy))) - @test length(dres[1]) == 4 - @test dres[1][1] ≈ 13.7 - @test dres[1][2] ≈ 15.2 - @test dres[1][3] ≈ 100.02 - @test dres[1][4] ≈ 304.1 - @test length(dres[2]) == 4 - @test dres[2][1] ≈ -913.7 - @test dres[2][2] ≈ -915.2 - @test dres[2][3] ≈ -9100.02 - @test dres[2][4] ≈ -9304.1 - - res, dres = Enzyme.autodiff(Forward, metaconcat, Duplicated, BatchDuplicated(x, (dx, dy))) - @test length(res) == 4 - @test res[1] ≈ 2.0 - @test res[2] ≈ 3.0 - @test res[3] ≈ 7.9 - @test res[4] ≈ 11.2 - @test length(dres[1]) == 4 - @test dres[1][1] ≈ 13.7 - @test dres[1][2] ≈ 15.2 - @test dres[1][3] ≈ 100.02 - @test dres[1][4] ≈ 304.1 - @test length(dres[2]) == 4 - @test dres[2][1] ≈ -913.7 - @test dres[2][2] ≈ -915.2 - @test dres[2][3] ≈ -9100.02 - @test dres[2][4] ≈ -9304.1 +@testset "Batch Getfield" begin + x = [(2.0, 3.0)] + dx = [(0.0, 0.0)] + dx2 = [(0.0, 0.0)] + dx3 = [(0.0, 0.0)] + out = Ref(0.0) + dout = Ref(1.0) + dout2 = Ref(3.0) + dout3 = Ref(5.0) + Enzyme.autodiff(Reverse, batchgf, Const, BatchDuplicatedNoNeed(out, (dout, dout2, dout3)), BatchDuplicated(x, (dx, dx2, dx3))) + @test dx[1][1] ≈ 1.0 + @test dx[1][2] ≈ 0.0 + @test dx2[1][1] ≈ 3.0 + @test dx2[1][2] ≈ 0.0 + @test dx3[1][1] ≈ 5.0 + @test dx2[1][2] ≈ 0.0 end +include("applyiter.jl") + @testset "Dynamic Val Construction" begin dyn_f(::Val{D}) where D = prod(D) @@ -2566,41 +2373,6 @@ end Enzyme.API.runtimeActivity!(false) end -@testset "apply iterate" begin - function mktup(v) - tup = tuple(v...) - return tup[1][1] * tup[3][1] - end - - data = [[3.0], nothing, [2.0]] - ddata = [[0.0], nothing, [0.0]] - - Enzyme.autodiff(Reverse, mktup, Duplicated(data, ddata)) - @test ddata[1][1] ≈ 2.0 - @test ddata[3][1] ≈ 3.0 - - function mktup2(v) - tup = tuple(v...) - return (tup[1][1] * tup[3])::Float64 - end - - data = [[3.0], nothing, 2.0] - ddata = [[0.0], nothing, 0.0] - - @test_throws AssertionError Enzyme.autodiff(Reverse, mktup2, Duplicated(data, ddata)) - - function mktup3(v) - tup = tuple(v..., v...) - return tup[1][1] * tup[1][1] - end - - data = [[3.0]] - ddata = [[0.0]] - - Enzyme.autodiff(Reverse, mktup3, Duplicated(data, ddata)) - @test ddata[1][1] ≈ 6.0 -end - @testset "BLAS" begin x = [2.0, 3.0] dx = [0.2,0.3]