From ad86689a8a6d7fd895f8427b3a7b977602d7828e Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 10 Oct 2024 12:40:14 -0500 Subject: [PATCH] Multi arg fwd gradient (#1952) * Multi arg fwd gradient * multi arg deriv * fix * fix * Update Enzyme.jl * cleanup * fix * Update Enzyme.jl * Update Enzyme.jl --- src/Enzyme.jl | 322 +++++++++++++++-------- test/runtests.jl | 469 +--------------------------------- test/sugar.jl | 646 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 869 insertions(+), 568 deletions(-) create mode 100644 test/sugar.jl diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2e7789d660..598dc872e9 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1794,16 +1794,30 @@ end @inline tupleconcat(x, y) = (x..., y...) @inline tupleconcat(x, y, z...) = (x..., tupleconcat(y, z...)...) -function create_shadows(::Nothing, x) - return (onehot(x),) -end - -function create_shadows(::Val{1}, x) - return (onehot(x),) -end - -function create_shadows(::Val{chunk}, x) where {chunk} - return (chunkedonehot(x, Val(chunk)),) +@generated function create_shadows(chunk::ChunkTy, x::X, vargs::Vararg{Any,N}) where {ChunkTy, X, N} + args = Union{Symbol,Expr}[:x] + tys = Type[X] + for i in 1:N + push!(args, :(vargs[$i])) + push!(tys, vargs[i]) + end + + exprs = Union{Symbol,Expr}[] + for (arg, ty) in zip(args, tys) + if ty <: Enzyme.Const + push!(exprs, :(nothing)) + elseif ty <: AbstractFloat + push!(exprs, :(nothing)) + elseif ChunkTy == Nothing || ChunkTy == Val{1} + push!(exprs, :(onehot($arg))) + else + push!(exprs, :(chunkedonehot($arg, chunk))) + end + end + return quote + Base.@_inline_meta + ($(exprs...),) + end end struct TupleArray{T,Shape,Length,N} <: AbstractArray{T,N} @@ -1890,7 +1904,7 @@ gradient(ForwardWithPrimal, f, [2.0, 3.0]; chunk=Val(1)) (derivs = ([3.0, 2.0],), val = 6.0) ``` -For functions which return an AbstractArray or scalar, this function will return an AbstracttArray +For functions which return an AbstractArray or scalar, this function will return an AbstractArray whose shape is `(size(output)..., size(input)...)`. No guarantees are presently made about the type of the AbstractArray returned by this function (which may or may not be the same as the input AbstractArray if provided). @@ -1905,119 +1919,227 @@ grad = gradient(Forward, f, [2.0, 3.0, 4.0]) # output ([3.0 2.0 0.0; 0.0 1.0 1.0],) ``` + +This function supports multiple arguments and computes the gradient with respect to each + +```jldoctest gradfwd2 +mul(x, y) = x[1]*y[2] + x[2]*y[1] + +gradient(Forward, mul, [2.0, 3.0], [2.7, 3.1]) + +# output + +([3.1, 2.7], [3.0, 2.0]) +``` + +This includes the ability to mark some arguments as `Const` if its derivative is not needed, returning nothing in the corresponding derivative map. + +```jldoctest gradfwd2 +gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) + +# output + +([3.1, 2.7], nothing) +``` """ -@inline function gradient( +@generated function gradient( fm::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, - f, - x; + f::F, + x::ty_0, + args::Vararg{Any,N}; chunk::CS = nothing, - shadows = create_shadows(chunk, x), -) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS} - if length(shadows[1]) == 0 - return if ReturnPrimal - (; derivs = (x,), val = f(x.val)) + shadows::ST = create_shadows(chunk, x, args...), +) where {F, ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity,CS,ST, ty_0, N} + + syms = Union{Symbol,Expr}[:x] + shads = Union{Symbol,Expr}[:(shadows[1])] + tys = Type[ty_0] + for i in 1:N + push!(syms, :(args[$i])) + push!(tys, args[i]) + push!(shads, :(shadows[1+$i])) + end + fval = if F <: Annotation + :(f.val) + else + :f + end + + vals = Union{Symbol,Expr}[] + consts = Union{Symbol,Expr}[] + for (arg, ty) in zip(syms, tys) + if ty <: Const + push!(vals, :($arg.val)) + push!(consts, arg) else - (x,) + push!(vals, arg) + push!(consts, :(Const($arg))) end end - if chunk == Val(0) - throw(ErrorException("Cannot differentiate with a batch size of 0")) + + if CS == Val{0} + return quote + Base.@_inline_meta + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end end - gradtup = if chunk == nothing - resp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1])) + exprs = Union{Symbol,Expr}[] + primal = nothing + derivatives = Union{Symbol,Expr}[] - res = values(resp[1]) - dres = if x isa AbstractFloat - res[1] - else - res + primmode = :(fm) + for (i, (arg, ty)) in enumerate(zip(syms, tys)) + if ty <: Const + push!(derivatives, :(nothing)) + continue end - if ReturnPrimal - ((dres,), resp[2]) - else - (dres,) - end - elseif chunk == Val(1) - if ReturnPrimal - rp = autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][1])) - dres1 = rp[1] - fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# - res = ntuple(length(shadows[1]) - 1) do i - autodiff(fm2, f, Duplicated, Duplicated(x, shadows[1][i+1]))[1] + argnum = length(ST.parameters[i].parameters) + + argderivative = if ty <: AbstractFloat + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(Duplicated($arg, one($arg)))) + else + push!(dargs, consts[j]) + end end - gres = if x isa AbstractFloat - dres1[1] - else - (dres1, res...) + + resp = Symbol("resp_$i") + push!(exprs, quote + $resp = autodiff($primmode, f, Duplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) end - ((gres,), rp[2]) - else - res = ntuple(length(shadows[1])) do i - autodiff(fm, f, Duplicated, Duplicated(x, shadows[1][i]))[1] + + :($resp[1]) + elseif argnum == 0 + vals[i] + elseif CS == Nothing + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(BatchDuplicated($arg, $(shads[i])))) + else + push!(dargs, consts[j]) + end end - (if x isa AbstractFloat - res[1] - else - res - end,) - end - else - if ReturnPrimal - rp = autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][1])) - dres1 = values(rp[1]) - gres = if x isa AbstractFloat - dres1[1] - else - fm2 = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() #=ReturnPrimal=# - tmp = ntuple(length(shadows[1]) - 1) do i - values( - autodiff( - fm2, - f, - BatchDuplicated, - BatchDuplicated(x, shadows[1][i+1]), - )[1], - ) + + df = :f + if F <: Enzyme.Duplicated + zeros = Expr[] + for i in 1:argnum + push!(zeros, :(f.dval)) end - tupleconcat(dres1, tmp...) + df = :(BatchDuplicated(f.val, ($(zeros...),) )) + end + + resp = Symbol("resp_$i") + push!(exprs, quote + $resp = autodiff($primmode, $df, BatchDuplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) end - ((gres,), rp[2]) + + :(values($resp[1])) + elseif CS == Val{1} + subderivatives = Union{Symbol,Expr}[] + for an in 1:argnum + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(Duplicated($arg, $(shads[i])[$an]))) + else + push!(dargs, consts[j]) + end + end + + resp = Symbol("resp_$i"*"_"*string(an)) + push!(exprs, quote + $resp = autodiff($primmode, f, Duplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + end + + push!(subderivatives, :(values($resp[1]))) + end + :(($(subderivatives...),)) else - tmp = ntuple(length(shadows[1])) do i - values(autodiff(fm, f, BatchDuplicated, BatchDuplicated(x, shadows[1][i]))[1]) + subderivatives = Union{Symbol,Expr}[] + for an in 1:argnum + dargs = Union{Symbol,Expr}[] + for (j, arg2) in enumerate(syms) + if i == j + push!(dargs, :(BatchDuplicated($arg, $(shads[i])[$an]))) + else + push!(dargs, consts[j]) + end + end + + resp = Symbol("resp_$i"*"_"*string(an)) + push!(exprs, quote + $resp = autodiff($primmode, f, BatchDuplicated, $(dargs...)) + end) + if ReturnPrimal && primal == nothing + primal = :($resp[2]) + primmode = NoPrimal(fm()) + end + + push!(subderivatives, :(values($resp[1]))) end - res = tupleconcat(tmp...) - (if x isa AbstractFloat - res[1] + :(tupleconcat($(subderivatives...))) + end + + deriv = if ty <: AbstractFloat + argderivative + else + tmp = Symbol("tmp_$i") + push!(exprs, :($tmp = $argderivative)) + if ty <: AbstractArray + if argnum > 0 + quote + if $tmp[1] isa AbstractArray + inshape = size($(vals[1])) + outshape = size($tmp[1]) + # st : outshape x total inputs + tupstack($tmp, outshape, inshape) + else + TupleArray($tmp, size($arg)) + end + end + else + :(TupleArray($tmp, size($arg))) + end else - res - end,) + tmp + end end + push!(derivatives, deriv) end - cols = if ReturnPrimal - gradtup[1][1] - else - gradtup[1] - end - res = if x isa AbstractFloat - cols - elseif length(cols) > 0 && cols[1] isa AbstractArray && x isa AbstractArray - inshape = size(x) - outshape = size(cols[1]) - # st : outshape x total inputs - tupstack(cols, outshape, inshape) - elseif x isa AbstractArray - TupleArray(cols, size(x)) - else - cols + # We weirdly asked for no derivatives + if ReturnPrimal && primal == nothing + primal = :($fval($(vals...))) end - if ReturnPrimal - (; derivs = (res,), val = gradtup[2]) + + result = if ReturnPrimal + :((; derivs = ($(derivatives...),), val = $primal)) else - (res,) + :(($(derivatives...),)) + end + + return quote + Base.@_inline_meta + $(exprs...) + $result end end diff --git a/test/runtests.jl b/test/runtests.jl index 902b9e4f65..c3856aabf1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,15 +16,6 @@ using InlineStrings using Enzyme_jll @info "Testing against" Enzyme_jll.libEnzyme -# symbol is \simeq -# this is basically a more flexible version of ≈ -(≃)(a, b) = (≈)(a, b) -(≃)(a::Tuple, b::Tuple) = all(xy -> xy[1] ≃ xy[2], zip(a,b)) -function (≃)(a::AbstractArray{<:Tuple}, b::AbstractArray{<:Tuple}) - size(a) == size(b) || return false - all(xy -> xy[1] ≃ xy[2], zip(a,b)) -end - function isapproxfn(fn, args...; kwargs...) isapprox(args...; kwargs...) end @@ -2938,465 +2929,7 @@ end @test dx ≈ [-1.0, 43.74, 0] end - -# these are used in gradient and jacobian tests -struct InpStruct - i1::Float64 - i2::Float64 - i3::Float64 -end -struct OutStruct - i1::Float64 - i2::Float64 - i3::Float64 -end - -for A ∈ (:InpStruct, :OutStruct) - @eval (≃)(a::$A, b::$A) = (a.i1 ≃ b.i1) && (a.i2 ≃ b.i2) && (a.i3 ≃ b.i3) - @eval function (≃)(a::AbstractArray{<:$A}, b::AbstractArray{<:$A}) - size(a) == size(b) || return false - all(xy -> xy[1] ≃ xy[2], zip(a, b)) - end -end - - -#NOTE: this is needed because of problems with hvcat on 1.10 and something inexplicable on 1.6 -# suffice it to say it's not good that this is required, please remove when possible -mkarray(sz, args...) = reshape(vcat(args...), sz) - -@testset "Gradient and Jacobian Outputs" begin - - scalar = 3.0 - - # ∂ scalar / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 - @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 - @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 - - # ∂ vector / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] - - - # ∂ tuple / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≈ [2.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) - - mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) - - # ∂ matrix / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - - @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] - - # ∂ struct / ∂ scalar - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) - - - - vector = [2.7, 3.1] - - # ∂ scalar / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2],vector[1]] - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] - - - # ∂ vector / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ - [vector[2] vector[1]; -sin(vector[1]) 1.0] - - # ∂ tuple / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ - [(vector[2], -sin(vector[1])), (vector[1], 1.0)] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ - ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ - [(vector[2], -sin(vector[1])), (vector[1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] - - mkarray2 = x -> mkarray((2,2), x[1]*x[2], exp(x[2]), cos(x[1])+x[2], x[1]) - - # ∂ matrix / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector)[1] ≈ - mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector)[1] - @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector)[1] ≈ - mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector)[1] ≈ - mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) - - # ∂ struct / ∂ vector - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ - [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - - @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ - [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) - - - - tuplev = (2.7, 3.1) - - # ∂ scalar / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) - - # ∂ vector / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ - ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ - [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ - [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] - - # ∂ tuple / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ - ((vector[2], -sin(vector[1])), (vector[1], 1.0)) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ - ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ - [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] - - # ∂ matrix / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev)[1] ≃ - ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev)[1] ≈ - [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev)[1] ≈ - [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] - - # ∂ struct / ∂ tuple - @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ - (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ - [OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) - - - - matrix = [2.7 3.1; 4.7 5.6] - - # ∂ scalar / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] - - # ∂ vector / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ - mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] - # again we can't use array construction syntax because of 1.6 - @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ - mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ - mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) - - # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ - [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] - @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) - @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ - [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] - - mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) - - # ∂ matrix / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix)[1] ≈ - mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, - matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix)[1] - # array construction syntax broken on 1.6 - @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix)[1] ≈ - mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, - matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix)[1] ≈ - mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, - matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) - - # ∂ tuple / ∂ matrix - @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ - [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] - @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] - @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ - [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] - - - istruct = InpStruct(2.7, 3.1, 4.7) - - # ∂ scalar / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] - @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] - @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) - - # ∂ vector / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] - @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] - - # ∂ tuple / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] - - mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) - - # ∂ matrix / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] - @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct)[1] ≃ - [InpStruct(istruct.i2, istruct.i1, 0.0) InpStruct(1.0, 0.0, -sin(istruct.i3)); - InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] - - # ∂ struct / ∂ struct - @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] - @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] - @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] -end - -@testset "Simple Jacobian" begin - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0)[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0)[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0])[1] ≈ [4.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(1))[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(1))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(1))[1] ≈ [4.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(2))[1] ≈ 2.0 - @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(2))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(2))[1] ≈ [4.0, 6.0] - - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(1))[1] ≈ [1.0, 2.0] - @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(2))[1] ≈ [1.0, 2.0] - - x = float.(reshape(1:6, 2, 3)) - - fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] - - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x)[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(1))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(2))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(1))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(2))[1] - - @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] - @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] - @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] - @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] - - fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] - - x2 = InpStruct(1.0, 2.0, 3.0) - - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(1))[1] - - @test jac[1] == InpStruct(2.0, 4.0, 6.0) - @test jac[2] == InpStruct(20.0, 40.0, 60.0) - @test jac[3] == InpStruct(200.0, 400.0, 600.0) - @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - - jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(2))[1] - - @test jac[1] == InpStruct(2.0, 4.0, 6.0) - @test jac[2] == InpStruct(20.0, 40.0, 60.0) - @test jac[3] == InpStruct(200.0, 400.0, 600.0) - @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) - - filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) - - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x)[1] - - @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) - @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) - - @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) - @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) - - @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) - @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(1))[1] - - @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) - @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) - - @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) - @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) - - @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) - @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) - - jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(2))[1] - - @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) - @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) - - @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) - @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) - - @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) - @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) -end - - -@testset "Jacobian" begin - function inout(v) - [v[2], v[1]*v[1], v[1]*v[1]*v[1]] - end - - jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(1))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(1))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])[1] - - jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(2))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(2))[1] - @test size(jac) == (3, 2) - @test jac ≈ [ 0.0 1.0; - 4.0 0.0; - 12.0 0.0] - - function f_test_1(A, x) - utmp = A*x[2:end] .+ x[1] - return utmp - end - - function f_test_2(A, x) - utmp = Vector{Float64}(undef, length(x)-1) - utmp .= A*x[2:end] .+ x[1] - return utmp - end - - function f_test_3!(u, A, x) - utmp .= A*x[2:end] .+ x[1] - end - - J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, n_outs=Val((5,)))[1] - J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, n_outs=Val((5,)))[1] - J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, n_outs=Val((5,)))[1] - - J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x)[1] - J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x)[1] - J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x)[1] - - x = ones(6) - A = Matrix{Float64}(LinearAlgebra.I, 5, 5) - u = Vector{Float64}(undef, 5) - - @test J_r_1(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - - @test J_r_2(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - - @test J_f_1(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - @test J_f_2(A, x) == [ - 1.0 1.0 0.0 0.0 0.0 0.0; - 1.0 0.0 1.0 0.0 0.0 0.0; - 1.0 0.0 0.0 1.0 0.0 0.0; - 1.0 0.0 0.0 0.0 1.0 0.0; - 1.0 0.0 0.0 0.0 0.0 1.0; - ] - - # @show J_r_3(u, A, x) - # @show J_f_3(u, A, x) -end +include("sugar.jl") @testset "Forward on Reverse" begin diff --git a/test/sugar.jl b/test/sugar.jl new file mode 100644 index 0000000000..c558fd813e --- /dev/null +++ b/test/sugar.jl @@ -0,0 +1,646 @@ +using Enzyme, Test + + +mul_scalar(x, y) = x[1]*y[2] + x[2]*y[1] +mul_vector(x, y) = [x[1]*y[2], x[2]*y[1]] + +@testset "Forward Multi-Arg Gradient" begin + res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]) + @test res[1] ≈ [3.1, 2.7] + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(2)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], [2.7, 3.1]) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] ≈ [3.0, 2.0] + + + + res = gradient(Forward, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]) + @test res[1] == nothing + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res[1] == nothing + @test res[2] ≈ [3.0, 2.0] + + res = gradient(Forward, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res[1] == nothing + @test res[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [3.0, 2.0] + + res = gradient(ForwardWithPrimal, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [3.0, 2.0] + + + res = gradient(Forward, mul_scalar, [2.0, 3.0], Const([2.7, 3.1])) + @test res[1] ≈ [3.1, 2.7] + @test res[2] == nothing + + res = gradient(Forward, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] == nothing + + res = gradient(Forward, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res[1] ≈ [3.1, 2.7] + @test res[2] == nothing + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], Const([2.7, 3.1])) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_scalar, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res.val ≈ mul_scalar([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1, 2.7] + @test res.derivs[2] == nothing + + + + res = gradient(Forward, mul_vector, [2.0, 3.0], [2.7, 3.1]) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, [2.0, 3.0], [2.7, 3.1]; chunk=Val(2)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], [2.7, 3.1]) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + + + res = gradient(Forward, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]) + @test res[1] == nothing + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res[1] == nothing + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(Forward, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res[1] == nothing + @test res[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(1)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + res = gradient(ForwardWithPrimal, mul_vector, Const([2.0, 3.0]), [2.7, 3.1]; chunk=Val(2)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] == nothing + @test res.derivs[2] ≈ [0.0 2.0; 3.0 0.0] + + + res = gradient(Forward, mul_vector, [2.0, 3.0], Const([2.7, 3.1])) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] == nothing + + res = gradient(Forward, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] == nothing + + res = gradient(Forward, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res[1] ≈ [3.1 0.0; 0.0 2.7] + @test res[2] == nothing + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], Const([2.7, 3.1])) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(1)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] == nothing + + res = gradient(ForwardWithPrimal, mul_vector, [2.0, 3.0], Const([2.7, 3.1]); chunk=Val(2)) + @test res.val ≈ mul_vector([2.0, 3.0], [2.7, 3.1]) + @test res.derivs[1] ≈ [3.1 0.0; 0.0 2.7] + @test res.derivs[2] == nothing + +end + +# these are used in gradient and jacobian tests +struct InpStruct + i1::Float64 + i2::Float64 + i3::Float64 +end +struct OutStruct + i1::Float64 + i2::Float64 + i3::Float64 +end + +# symbol is \simeq +# this is basically a more flexible version of ≈ +(≃)(a, b) = (≈)(a, b) +(≃)(a::Tuple, b::Tuple) = all(xy -> xy[1] ≃ xy[2], zip(a,b)) +function (≃)(a::AbstractArray{<:Tuple}, b::AbstractArray{<:Tuple}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a,b)) +end + +for A ∈ (:InpStruct, :OutStruct) + @eval (≃)(a::$A, b::$A) = (a.i1 ≃ b.i1) && (a.i2 ≃ b.i2) && (a.i3 ≃ b.i3) + @eval function (≃)(a::AbstractArray{<:$A}, b::AbstractArray{<:$A}) + size(a) == size(b) || return false + all(xy -> xy[1] ≃ xy[2], zip(a, b)) + end +end + + +#NOTE: this is needed because of problems with hvcat on 1.10 and something inexplicable on 1.6 +# suffice it to say it's not good that this is required, please remove when possible +mkarray(sz, args...) = reshape(vcat(args...), sz) + +@testset "Gradient and Jacobian Outputs" begin + + scalar = 3.0 + + # ∂ scalar / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> x^2, scalar)[1] ≈ 6.0 + @test Enzyme.gradient(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.gradient(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x -> 2*x, scalar)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Reverse, x -> 2*x, scalar)[1] ≈ 2.0 + + # ∂ vector / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [2*x, x^2], scalar)[1] ≈ [2.0, 6.0] + + + # ∂ tuple / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≈ [2.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (2*x, x^2), scalar)[1] ≃ (2.0, 6.0) + + mkarray1 = x -> mkarray((2,2),2*x,sin(x),x^2,exp(x)) + + # ∂ matrix / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + @test Enzyme.jacobian(Enzyme.Forward, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray1, scalar)[1] ≈ [2.0 6.0; cos(scalar) exp(scalar)] + + # ∂ struct / ∂ scalar + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x, x^2, x^3), scalar)[1] == OutStruct(1.0,2*scalar,3*scalar^2) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> InpStruct(x, x^2, x^3), scalar)[1] == (OutStruct(1.0,2.0,3.0),) + + + + vector = [2.7, 3.1] + + # ∂ scalar / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2],vector[1]] + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], vector)[1] ≈ [vector[2], vector[1]] + + + # ∂ vector / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], vector)[1] ≈ + [vector[2] vector[1]; -sin(vector[1]) 1.0] + + # ∂ tuple / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ + ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≃ + [(vector[2], -sin(vector[1])), (vector[1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] + + mkarray2 = x -> mkarray((2,2), x[1]*x[2], exp(x[2]), cos(x[1])+x[2], x[1]) + + # ∂ matrix / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, vector)[1] + @test Enzyme.jacobian(Enzyme.Forward, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray2, vector)[1] ≈ + mkarray((2,2,2), vector[2], 0.0, -sin(vector[1]), 1.0, vector[1], exp(vector[2]), 1.0, 0.0) + + # ∂ struct / ∂ vector + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + @test Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), vector)[1] ≃ + [OutStruct(vector[2], -sin(vector[1]), 0.0), OutStruct(vector[1], 1.0, exp(vector[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), vector)[1] ≈ ([vector[2], -sin(vector[1])], [vector[1], 1.0]) + + + + tuplev = (2.7, 3.1) + + # ∂ scalar / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.gradient(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Forward, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + @test Enzyme.jacobian(Enzyme.Reverse, x -> x[1] * x[2], tuplev)[1] ≃ (tuplev[2],tuplev[1]) + + # ∂ vector / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ + ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x[1] * x[2], cos(x[1]) + x[2]], tuplev)[1] ≃ + [(tuplev[2], tuplev[1]), (-sin(tuplev[1]), 1.0)] + + # ∂ tuple / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ + ((vector[2], -sin(vector[1])), (vector[1], 1.0)) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + @test Enzyme.jacobian(Enzyme.Forward, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≃ + ((tuplev[2], -sin(tuplev[1])), (tuplev[1], 1.0)) + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ + [tuplev[2] tuplev[1]; -sin(tuplev[1]) 1.0] + + # ∂ matrix / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, mkarray2, tuplev)[1] ≃ + ([tuplev[2] -sin(tuplev[1]); 0.0 1.0], [tuplev[1] 1.0; exp(tuplev[2]) 0.0]) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray2, tuplev)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, mkarray2, tuplev)[1] ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> mkarray2, tuplev)[1] ≈ + [tuplev[2] -sin(tuplev[1]); 0.0 1.0;;; tuplev[1] 1.0; exp(tuplev[2]) 0.0] + + # ∂ struct / ∂ tuple + @test Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ + (OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))) + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x[1] * x[2], cos(x[1]) + x[2], exp(x[2])), tuplev)[1] ≃ + [OutStruct(tuplev[2], -sin(tuplev[1]), 0.0), OutStruct(tuplev[1], 1.0, exp(tuplev[2]))] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x[1] * x[2], cos(x[1]) + x[2]), tuplev)[1] ≈ ([tuplev[2], -sin(tuplev[1])], [tuplev[1], 1.0]) + + + + matrix = [2.7 3.1; 4.7 5.6] + + # ∂ scalar / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.gradient(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Forward, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + @test Enzyme.jacobian(Enzyme.Reverse, x->x[1,1]*x[1,2]+x[2,1]*x[2,2], matrix)[1] ≈ [matrix[1,2] matrix[1,1]; matrix[2,2] matrix[2,1]] + + # ∂ vector / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test_broken Enzyme.gradient(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] + # again we can't use array construction syntax because of 1.6 + @test Enzyme.jacobian(Enzyme.Forward, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + @test Enzyme.jacobian(Enzyme.Reverse, x->[x[1,1]*x[1,2],x[2,1]*x[2,2]], matrix)[1] ≈ + mkarray((2,2,2), matrix[1,2], 0.0, 0.0, matrix[2,2], matrix[1,1], 0.0, 0.0, matrix[2,1]) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] + @test_broken Enzyme.gradient(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix) + @test Enzyme.jacobian(Enzyme.Forward, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] ≃ + [(matrix[1,2],0.0) (matrix[1,1],0.0); (0.0,matrix[2,2]) (0.0,matrix[2,1])] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->(x[1,1]*x[1,2],x[2,1]*x[2,2]), matrix)[1] + + mkarray3 = x -> mkarray((2,2), x[1,1]*x[1,2], exp(x[1,1])+x[2,2], x[2,1]*x[2,2], sin(x[1,2])+x[2,1]) + + # ∂ matrix / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test_broken Enzyme.gradient(Enzyme.Reverse, mkarray3, matrix)[1] + # array construction syntax broken on 1.6 + @test Enzyme.jacobian(Enzyme.Forward, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + @test Enzyme.jacobian(Enzyme.Reverse, mkarray3, matrix)[1] ≈ + mkarray((2,2,2,2), matrix[1,2],exp(matrix[1,1]),0.0,0.0,0.0,0.0,matrix[2,2],1.0, + matrix[1,1],0.0,0.0,cos(matrix[1,2]),0.0,1.0,matrix[2,1],0.0) + + # ∂ tuple / ∂ matrix + @test Enzyme.gradient(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.gradient(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] + @test Enzyme.jacobian(Enzyme.Forward, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] ≃ + [OutStruct(matrix[1,2],0.0, exp(matrix[1,1])) OutStruct(matrix[1,1],0.0,0.0); OutStruct(0.0,matrix[2,2],0.0) OutStruct(0.0,matrix[2,1], 1.0)] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x->OutStruct(x[1,1]*x[1,2],x[2,1]*x[2,2], exp(x[1,1])+x[2,2]), matrix)[1] + + + istruct = InpStruct(2.7, 3.1, 4.7) + + # ∂ scalar / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.gradient(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> x.i1 * x.i2 + x.i3, istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> x.i1 * x.i2 + x.i3, istruct)[1] ≃ InpStruct(istruct.i2, istruct.i1, 1.0) + + # ∂ vector / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, x -> [x.i1 * x.i2, cos(x.i3) + x.i1], istruct)[1] ≃ [InpStruct(istruct.i2, istruct.i1, 0.0), InpStruct(1.0, 0.0, -sin(istruct.i3))] + + # ∂ tuple / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> (x.i1 * x.i2, cos(x.i3) + x.i1), istruct)[1] + + mkarray4 = x -> mkarray((2,2), x.i1*x.i2, exp(x.i2), cos(x.i3)+x.i1, x.i1) + + # ∂ matrix / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> [x.i1 * x.i2 cos(x.i3) + x.i1; exp(x.i2) x.i1], istruct)[1] + @test Enzyme.jacobian(Enzyme.Reverse, mkarray4, istruct)[1] ≃ + [InpStruct(istruct.i2, istruct.i1, 0.0) InpStruct(1.0, 0.0, -sin(istruct.i3)); + InpStruct(0.0, exp(istruct.i2), 0.0) InpStruct(1.0, 0.0, 0.0)] + + # ∂ struct / ∂ struct + @test_broken Enzyme.gradient(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.gradient(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Forward, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] + @test_broken Enzyme.jacobian(Enzyme.Reverse, x -> OutStruct(x.i1 * x.i2, cos(x.i3) + x.i1, exp(x.i2)), istruct)[1] +end + +@testset "Simple Jacobian" begin + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0)[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0)[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0])[1] ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(1))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(1))[1] ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Forward, x->2*x, 3.0, chunk=Val(2))[1] ≈ 2.0 + @test Enzyme.jacobian(Enzyme.Forward, x->[x, 2*x], 3.0, chunk=Val(2))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Forward, x->sum(abs2, x), [2.0, 3.0], chunk=Val(2))[1] ≈ [4.0, 6.0] + + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(1))[1] ≈ [1.0, 2.0] + @test Enzyme.jacobian(Enzyme.Reverse, x->[x, 2*x], 3.0, n_outs=Val((2,)), chunk=Val(2))[1] ≈ [1.0, 2.0] + + x = float.(reshape(1:6, 2, 3)) + + fillabs2(x) = [sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x), 1000*sum(abs2, x)] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x)[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(1))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Forward, fillabs2, x, chunk=Val(2))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(1))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + jac = Enzyme.jacobian(Enzyme.Reverse, fillabs2, x, n_outs=Val((4,)), chunk=Val(2))[1] + + @test jac[1, :, :] ≈ [2.0 6.0 10.0; 4.0 8.0 12.0] + @test jac[2, :, :] ≈ [20.0 60.0 100.0; 40.0 80.0 120.0] + @test jac[3, :, :] ≈ [200.0 600.0 1000.0; 400.0 800.0 1200.0] + @test jac[4, :, :] ≈ [2000.0 6000.0 10000.0; 4000.0 8000.0 12000.0] + + fillinpabs2(x) = [(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 10*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 100*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3), 1000*(x.i1*x.i1+x.i2*x.i2+x.i3*x.i3)] + + x2 = InpStruct(1.0, 2.0, 3.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(1))[1] + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + jac = Enzyme.jacobian(Enzyme.Reverse, fillinpabs2, x2, n_outs=Val((4,)), chunk=Val(2))[1] + + @test jac[1] == InpStruct(2.0, 4.0, 6.0) + @test jac[2] == InpStruct(20.0, 40.0, 60.0) + @test jac[3] == InpStruct(200.0, 400.0, 600.0) + @test jac[4] == InpStruct(2000.0, 4000.0, 6000.0) + + filloutabs2(x) = OutStruct(sum(abs2, x), 10*sum(abs2, x), 100*sum(abs2, x)) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x)[1] + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(1))[1] + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) + + jac = Enzyme.jacobian(Enzyme.Forward, filloutabs2, x, chunk=Val(2))[1] + + @test jac[1, 1] == OutStruct(2.0, 20.0, 200.0) + @test jac[2, 1] == OutStruct(4.0, 40.0, 400.0) + + @test jac[1, 2] == OutStruct(6.0, 60.0, 600.0) + @test jac[2, 2] == OutStruct(8.0, 80.0, 800.0) + + @test jac[1, 3] == OutStruct(10.0, 100.0, 1000.0) + @test jac[2, 3] == OutStruct(12.0, 120.0, 1200.0) +end + + +@testset "Jacobian" begin + function inout(v) + [v[2], v[1]*v[1], v[1]*v[1]*v[1]] + end + + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(1))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(1))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + @test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])[1] + + jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], n_outs=Val((3,)), chunk=Val(2))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], chunk=Val(2))[1] + @test size(jac) == (3, 2) + @test jac ≈ [ 0.0 1.0; + 4.0 0.0; + 12.0 0.0] + + function f_test_1(A, x) + utmp = A*x[2:end] .+ x[1] + return utmp + end + + function f_test_2(A, x) + utmp = Vector{Float64}(undef, length(x)-1) + utmp .= A*x[2:end] .+ x[1] + return utmp + end + + function f_test_3!(u, A, x) + utmp .= A*x[2:end] .+ x[1] + end + + J_r_1(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_1(A, θ), x, n_outs=Val((5,)))[1] + J_r_2(A, x) = Enzyme.jacobian(Reverse, θ -> f_test_2(A, θ), x, n_outs=Val((5,)))[1] + J_r_3(u, A, x) = Enzyme.jacobian(Reverse, θ -> f_test_3!(u, A, θ), x, n_outs=Val((5,)))[1] + + J_f_1(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_1(A, θ)), x)[1] + J_f_2(A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_2(A, θ)), x)[1] + J_f_3(u, A, x) = Enzyme.jacobian(Forward, Const(θ -> f_test_3!(u, A, θ)), x)[1] + + x = ones(6) + A = Matrix{Float64}(LinearAlgebra.I, 5, 5) + u = Vector{Float64}(undef, 5) + + @test J_r_1(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + + @test J_r_2(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + + @test J_f_1(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + @test J_f_2(A, x) == [ + 1.0 1.0 0.0 0.0 0.0 0.0; + 1.0 0.0 1.0 0.0 0.0 0.0; + 1.0 0.0 0.0 1.0 0.0 0.0; + 1.0 0.0 0.0 0.0 1.0 0.0; + 1.0 0.0 0.0 0.0 0.0 1.0; + ] + + # @show J_r_3(u, A, x) + # @show J_f_3(u, A, x) +end