From 7fa4ddd8c2421ca28c91dad92bb657ec15834523 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 15 Mar 2024 00:48:39 -0400 Subject: [PATCH] Add nested AD tests --- test/autodiff_tests.jl | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 8f5d93c..22f2530 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -51,7 +51,42 @@ end end @testset "Gradient of Gradient" begin - + for (X, p) in zip(Xs, ps) + gs_fwddiff_x = ForwardDiff.gradient( + x -> sum(abs2, + batched_gradient( + AutoZygote(), simple_batched_function, x, Array(p))), + Array(X)) + gs_fwddiff_p = ForwardDiff.gradient( + p -> sum(abs2, + batched_gradient( + AutoZygote(), simple_batched_function, Array(X), p)), + Array(p)) + + for backend in ( + # AutoFiniteDiff(), # FIXME: FiniteDiff doesn't work well with ForwardDiff + # AutoForwardDiff(), # FIXME: The return type doesn't match + # AutoReverseDiff(), # FIXME: ReverseDiff with ForwardDiff problematic + AutoZygote(),) + arrType = backend isa AutoFiniteDiff || backend isa AutoReverseDiff ? + Array : identity + + gs_zyg = Zygote.gradient( + (x, p) -> sum( + abs2, batched_gradient(backend, simple_batched_function, x, p)), + arrType(X), + arrType(p)) + gs_rdiff = ReverseDiff.gradient( + (x, p) -> sum( + abs2, batched_gradient(backend, simple_batched_function, x, p)), + (Array(X), Array(p))) + + @test Array(gs_fwddiff_x)≈Array(gs_zyg[1]) atol=1e-3 + @test Array(gs_fwddiff_p)≈Array(gs_zyg[2]) atol=1e-3 + @test Array(gs_fwddiff_x)≈Array(gs_rdiff[1]) atol=1e-3 + @test Array(gs_fwddiff_p)≈Array(gs_rdiff[2]) atol=1e-3 + end + end end end end