Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Add nested AD tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 15, 2024
1 parent fdec976 commit 7fa4ddd
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion test/autodiff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,42 @@ end
end

@testset "Gradient of Gradient" begin

for (X, p) in zip(Xs, ps)
gs_fwddiff_x = ForwardDiff.gradient(
x -> sum(abs2,
batched_gradient(
AutoZygote(), simple_batched_function, x, Array(p))),
Array(X))
gs_fwddiff_p = ForwardDiff.gradient(
p -> sum(abs2,
batched_gradient(
AutoZygote(), simple_batched_function, Array(X), p)),
Array(p))

for backend in (
# AutoFiniteDiff(), # FIXME: FiniteDiff doesn't work well with ForwardDiff
# AutoForwardDiff(), # FIXME: The return type doesn't match
# AutoReverseDiff(), # FIXME: ReverseDiff with ForwardDiff problematic
AutoZygote(),)
arrType = backend isa AutoFiniteDiff || backend isa AutoReverseDiff ?
Array : identity

gs_zyg = Zygote.gradient(
(x, p) -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p)),
arrType(X),
arrType(p))
gs_rdiff = ReverseDiff.gradient(
(x, p) -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p)),
(Array(X), Array(p)))

@test Array(gs_fwddiff_x)Array(gs_zyg[1]) atol=1e-3
@test Array(gs_fwddiff_p)Array(gs_zyg[2]) atol=1e-3
@test Array(gs_fwddiff_x)Array(gs_rdiff[1]) atol=1e-3
@test Array(gs_fwddiff_p)Array(gs_rdiff[2]) atol=1e-3
end
end
end
end
end

0 comments on commit 7fa4ddd

Please sign in to comment.