From 263afd2e3e34a40b5f9c7306bb9f9a86637c08a5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 31 Oct 2024 12:22:40 +0000 Subject: [PATCH 1/5] Add integration tests for Bijectors --- .github/workflows/CI.yml | 5 +- test/integration/Bijectors/Project.toml | 9 + test/integration/Bijectors/runtests.jl | 222 ++++++++++++++++++ .../{ => DynamicExpressions}/Project.toml | 0 .../runtests.jl} | 0 5 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 test/integration/Bijectors/Project.toml create mode 100644 test/integration/Bijectors/runtests.jl rename test/integration/{ => DynamicExpressions}/Project.toml (100%) rename test/integration/{DynamicExpressions.jl => DynamicExpressions/runtests.jl} (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 643cb2c043..40445b4df3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -264,6 +264,7 @@ jobs: - ubuntu-latest test: - DynamicExpressions + - Bijectors steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -273,8 +274,8 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" run: | - julia --color=yes --project=test/integration -e 'using Pkg; Pkg.develop([PackageSpec(; path) for path in (".", "lib/EnzymeCore")]); Pkg.instantiate()' - julia --color=yes --project=test/integration --threads=auto --check-bounds=yes test/integration/${{ matrix.test }}.jl + julia --color=yes --project=test/integration/${{ matrix.test }} -e 'using Pkg; Pkg.develop([PackageSpec(; path) for path in (".", "lib/EnzymeCore")]); Pkg.instantiate()' + julia --color=yes --project=test/integration/${{ matrix.test }} --threads=auto --check-bounds=yes test/integration/${{ matrix.test }}/runtests.jl shell: bash docs: name: Documentation diff --git a/test/integration/Bijectors/Project.toml b/test/integration/Bijectors/Project.toml new file mode 100644 index 0000000000..2b8c2f46c2 --- /dev/null +++ b/test/integration/Bijectors/Project.toml @@ -0,0 +1,9 @@ +[deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" + +[compat] +Bijectors = "=0.13.16" +FiniteDifferences = "0.12.32" +StableRNGs = "1.0.2" diff --git a/test/integration/Bijectors/runtests.jl b/test/integration/Bijectors/runtests.jl new file mode 100644 index 0000000000..8bcb293166 --- /dev/null +++ b/test/integration/Bijectors/runtests.jl @@ -0,0 +1,222 @@ +module BijectorsIntegrationTests + +using Bijectors: Bijectors +using Enzyme: Enzyme +using FiniteDifferences: FiniteDifferences +using LinearAlgebra: LinearAlgebra +using Random: randn +using StableRNGs: StableRNG +using Test: @test, @test_broken, @testset + +rng = StableRNG(23) + +""" +Enum type for choosing Enzyme autodiff modes. +""" +@enum ModeSelector Neither Forward Reverse Both + +""" +Type for specifying a test case for `Enzyme.gradient`. + +The test will check the accuracy of the gradient of `func` at `value` against `finitediff`, +with both forward and reverse mode autodiff. `name` is for diagnostic printing. +`runtime_activity`, `broken`, `skip` are for specifying whether to use +`Enzyme.set_runtime_activity` or not, whether the test is broken, and whether the test is so +broken we can't even run `@test_broken` on it (because it crashes Julia). All of them take +values `Neither`, `Forward`, `Reverse` or `Both`, to specify which mode to apply the setting +to. `splat` is for specifying whether to call the function as `func(value)` or as +`func(value...)`. + +Default values are `name=nothing`, `runtime_activity=Neither`, `broken=Neither`, +`skip=Neither`, and `splat=false`. +""" +struct TestCase + func::Function + value + name::Union{String, Nothing} + runtime_activity::ModeSelector + broken::ModeSelector + skip::ModeSelector + splat::Bool +end + +# Default values for most arguments. +function TestCase( + f, value; + name=nothing, runtime_activity=Neither, broken=Neither, skip=Neither, splat=false +) + return TestCase(f, value, name, runtime_activity, broken, skip, splat) +end + +""" +Test Enzyme.gradient, both Forward and Reverse mode, against FiniteDifferences.grad. +""" +function test_grad(case::TestCase; rtol=1e-6, atol=1e-6) + @nospecialize + f = case.func + # We'll call the function as f(x...), so wrap in a singleton tuple if need be. + x = case.splat ? case.value : (case.value,) + finitediff = collect( + FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, x...)[1] + ) + + f_mode = if (case.runtime_activity === Both || case.runtime_activity === Forward) + Enzyme.set_runtime_activity(Enzyme.Forward) + else + Enzyme.Forward + end + r_mode = if (case.runtime_activity === Both || case.runtime_activity === Reverse) + Enzyme.set_runtime_activity(Enzyme.Reverse) + else + Enzyme.Reverse + end + + if !(case.skip === Forward) && !(case.skip === Both) + if case.broken === Both || case.broken === Forward + @test_broken( + collect(Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + rtol = rtol, + atol = atol, + ) + else + @test( + collect(Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + rtol = rtol, + atol = atol, + ) + end + end + + if !(case.skip === Reverse) && !(case.skip === Both) + if case.broken === Both || case.broken === Reverse + @test_broken( + collect(Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + rtol = rtol, + atol = atol, + ) + else + @test( + collect(Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + rtol = rtol, + atol = atol, + ) + end + end + return nothing +end + +""" +A helper function that returns a TestCase that evaluates sum(bijector(inverse(bijector)(x))) +""" +function sum_b_binv_test_case( + bijector, dim; runtime_activity=Neither, name=nothing, broken=Neither, skip=Neither +) + if name === nothing + name = string(bijector) + end + b_inv = Bijectors.inverse(bijector) + return TestCase( + x -> sum(bijector(b_inv(x))), + randn(rng, dim); + runtime_activity=runtime_activity, name=name, broken=broken, skip=skip + ) +end + +@testset "Bijectors integration tests" begin + test_cases = TestCase[ + sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3), + sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0), + sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), + # TODO(mhauru) Skip Reverse because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2033 + sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0); skip=Reverse), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0), + sum_b_binv_test_case(Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), 3), + sum_b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), + sum_b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), + sum_b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), + # TODO(mhauru) Skip Reverse because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2034 + sum_b_binv_test_case(Bijectors.PDBijector(), (3, 3); skip=Reverse), + sum_b_binv_test_case(Bijectors.PDVecBijector(), 3), + sum_b_binv_test_case( + Bijectors.Permute([ + 0 1 0; + 1 0 0; + 0 0 1 + ]), + (3, 3), + ), + # TODO(mhauru) Both modes broken because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2035 + sum_b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3); broken=Both), + # TODO(mhauru) Skip Reverse because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2036 + sum_b_binv_test_case(Bijectors.RadialLayer(3), 3; skip=Reverse), + sum_b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), + sum_b_binv_test_case(Bijectors.Scale(0.2), 3), + sum_b_binv_test_case(Bijectors.Shift(-0.4), 3), + sum_b_binv_test_case(Bijectors.SignFlip(), 3), + sum_b_binv_test_case(Bijectors.SimplexBijector(), 3), + sum_b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3), + + # Below, some test cases that don't fit the sum_b_binv_test_case mold. + + TestCase( + function (x) + b = Bijectors.RationalQuadraticSpline([-0.2, 0.1, 0.5], [-0.3, 0.3, 0.9], [1.0, 0.2, 1.0]) + binv = Bijectors.inverse(b) + return sum(binv(b(x))) + end, + randn(rng); + name="RationalQuadraticSpline on scalar", + ), + + TestCase( + function (x) + b = Bijectors.OrderedBijector() + binv = Bijectors.inverse(b) + return sum(binv(b(x))) + end, + randn(rng, 7); + name="OrderedBijector", + ), + + # TODO(mhauru) Skip Reverse because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2029 + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer) + x = x[6:7] + return Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x) + end, + randn(rng, 7); + name="PlanarLayer7", + skip=Reverse, + ), + + # TODO(mhauru) Reverse mode broken because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2030 + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer) + x = reshape(x[6:end], 2, :) + return sum(Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x)) + end, + randn(rng, 11); + name="PlanarLayer11", + broken=Reverse, + ), + ] + + @testset "$(case.name)" for case in test_cases + test_grad(case) + end +end + +end diff --git a/test/integration/Project.toml b/test/integration/DynamicExpressions/Project.toml similarity index 100% rename from test/integration/Project.toml rename to test/integration/DynamicExpressions/Project.toml diff --git a/test/integration/DynamicExpressions.jl b/test/integration/DynamicExpressions/runtests.jl similarity index 100% rename from test/integration/DynamicExpressions.jl rename to test/integration/DynamicExpressions/runtests.jl From 459d2d07857c893edf1b2cdcacf277344d381d84 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 1 Nov 2024 17:30:52 +0000 Subject: [PATCH 2/5] Remove unnecessary collect calls --- test/integration/Bijectors/runtests.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/test/integration/Bijectors/runtests.jl b/test/integration/Bijectors/runtests.jl index 8bcb293166..395ec8d3f6 100644 --- a/test/integration/Bijectors/runtests.jl +++ b/test/integration/Bijectors/runtests.jl @@ -56,9 +56,8 @@ function test_grad(case::TestCase; rtol=1e-6, atol=1e-6) f = case.func # We'll call the function as f(x...), so wrap in a singleton tuple if need be. x = case.splat ? case.value : (case.value,) - finitediff = collect( - FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, x...)[1] - ) + finitediff = FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, x...)[1] + f_mode = if (case.runtime_activity === Both || case.runtime_activity === Forward) Enzyme.set_runtime_activity(Enzyme.Forward) @@ -74,13 +73,13 @@ function test_grad(case::TestCase; rtol=1e-6, atol=1e-6) if !(case.skip === Forward) && !(case.skip === Both) if case.broken === Both || case.broken === Forward @test_broken( - collect(Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, rtol = rtol, atol = atol, ) else @test( - collect(Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, rtol = rtol, atol = atol, ) @@ -90,13 +89,13 @@ function test_grad(case::TestCase; rtol=1e-6, atol=1e-6) if !(case.skip === Reverse) && !(case.skip === Both) if case.broken === Both || case.broken === Reverse @test_broken( - collect(Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, rtol = rtol, atol = atol, ) else @test( - collect(Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1]) ≈ finitediff, + Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1] ≈ finitediff, rtol = rtol, atol = atol, ) From 733ac292a64cd5448bdad64b90469f9d0058d267 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 1 Nov 2024 17:58:10 +0000 Subject: [PATCH 3/5] Mark one more broken test --- test/integration/Bijectors/runtests.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/integration/Bijectors/runtests.jl b/test/integration/Bijectors/runtests.jl index 395ec8d3f6..0141aeaf87 100644 --- a/test/integration/Bijectors/runtests.jl +++ b/test/integration/Bijectors/runtests.jl @@ -58,7 +58,6 @@ function test_grad(case::TestCase; rtol=1e-6, atol=1e-6) x = case.splat ? case.value : (case.value,) finitediff = FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, x...)[1] - f_mode = if (case.runtime_activity === Both || case.runtime_activity === Forward) Enzyme.set_runtime_activity(Enzyme.Forward) else @@ -125,7 +124,9 @@ end test_cases = TestCase[ sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3), sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0), - sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), + # TODO(mhauru) Skip Reverse because of + # https://github.com/EnzymeAD/Enzyme.jl/issues/2041 + sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3); skip=Reverse), # TODO(mhauru) Skip Reverse because of # https://github.com/EnzymeAD/Enzyme.jl/issues/2033 sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0); skip=Reverse), From 96c1e2328f11e0d711eac33e3b705d22a37ccae1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 5 Nov 2024 10:28:21 +0000 Subject: [PATCH 4/5] Update to which Bijectors tests to run --- test/integration/Bijectors/runtests.jl | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/test/integration/Bijectors/runtests.jl b/test/integration/Bijectors/runtests.jl index 0141aeaf87..47d5be6c12 100644 --- a/test/integration/Bijectors/runtests.jl +++ b/test/integration/Bijectors/runtests.jl @@ -124,9 +124,7 @@ end test_cases = TestCase[ sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3), sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0), - # TODO(mhauru) Skip Reverse because of - # https://github.com/EnzymeAD/Enzyme.jl/issues/2041 - sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3); skip=Reverse), + sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), # TODO(mhauru) Skip Reverse because of # https://github.com/EnzymeAD/Enzyme.jl/issues/2033 sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0); skip=Reverse), @@ -138,9 +136,7 @@ end sum_b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), sum_b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), sum_b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), - # TODO(mhauru) Skip Reverse because of - # https://github.com/EnzymeAD/Enzyme.jl/issues/2034 - sum_b_binv_test_case(Bijectors.PDBijector(), (3, 3); skip=Reverse), + sum_b_binv_test_case(Bijectors.PDBijector(), (3, 3)), sum_b_binv_test_case(Bijectors.PDVecBijector(), 3), sum_b_binv_test_case( Bijectors.Permute([ @@ -153,9 +149,7 @@ end # TODO(mhauru) Both modes broken because of # https://github.com/EnzymeAD/Enzyme.jl/issues/2035 sum_b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3); broken=Both), - # TODO(mhauru) Skip Reverse because of - # https://github.com/EnzymeAD/Enzyme.jl/issues/2036 - sum_b_binv_test_case(Bijectors.RadialLayer(3), 3; skip=Reverse), + sum_b_binv_test_case(Bijectors.RadialLayer(3), 3), sum_b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), sum_b_binv_test_case(Bijectors.Scale(0.2), 3), sum_b_binv_test_case(Bijectors.Shift(-0.4), 3), @@ -185,8 +179,6 @@ end name="OrderedBijector", ), - # TODO(mhauru) Skip Reverse because of - # https://github.com/EnzymeAD/Enzyme.jl/issues/2029 TestCase( function (x) layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) @@ -196,11 +188,8 @@ end end, randn(rng, 7); name="PlanarLayer7", - skip=Reverse, ), - # TODO(mhauru) Reverse mode broken because of - # https://github.com/EnzymeAD/Enzyme.jl/issues/2030 TestCase( function (x) layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) @@ -210,7 +199,6 @@ end end, randn(rng, 11); name="PlanarLayer11", - broken=Reverse, ), ] From 3595ad09f37ff5838d6b3538f1852c8eef22f354 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 6 Nov 2024 17:38:47 +0000 Subject: [PATCH 5/5] Bring in another fixed test --- test/integration/Bijectors/runtests.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/integration/Bijectors/runtests.jl b/test/integration/Bijectors/runtests.jl index 47d5be6c12..23e6136561 100644 --- a/test/integration/Bijectors/runtests.jl +++ b/test/integration/Bijectors/runtests.jl @@ -125,9 +125,7 @@ end sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3), sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0), sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), - # TODO(mhauru) Skip Reverse because of - # https://github.com/EnzymeAD/Enzyme.jl/issues/2033 - sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0); skip=Reverse), + sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0)), sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3), sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0), sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3),