From 480833c727e8a4e007a845e25f70d0bdd911944c Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 2 Jan 2022 03:17:52 -0500
Subject: [PATCH] second derivative tests

---
 test/layers/basic.jl | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index b362f55f16..0b9f340142 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -269,3 +269,18 @@ import Flux: activations
     @test_throws DimensionMismatch m(OneHotVector(3, 1000))
   end
 end
+
+@testset "second derivatives" begin
+  m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
+  @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3])
+
+  # NNlib's softmax gradient writes in-place
+  m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
+  @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3])
+
+  # https://github.com/FluxML/NNlib.jl/issues/362
+  m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
+  x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
+  @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
+end
+