From fdec97672c1e1e90ba23da4c9f087319f2b2a0af Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 15 Mar 2024 00:30:56 -0400
Subject: [PATCH] Remove extra code

---
 README.md                            |  2 +-
 ext/BatchedRoutinesReverseDiffExt.jl | 33 -----
 ext/BatchedRoutinesZygoteExt.jl      |  2 -
 src/helpers.jl                       |  1 -
 test/autodiff_tests.jl               | 86 ++++++++--------
 5 files changed, 25 insertions(+), 99 deletions(-)

diff --git a/README.md b/README.md
index 4745245..c3a96ea 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
 
 [![CI](https://github.com/LuxDL/BatchedRoutines.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/BatchedRoutines.jl/actions/workflows/CI.yml)
 [![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu)](https://buildkite.com/julialang/lux-dot-jl)
-[![codecov](https://codecov.io/gh/LuxDL/BatchedRoutines.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/BatchedRoutines.jl.jl)
+[![codecov](https://codecov.io/gh/LuxDL/BatchedRoutines.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/BatchedRoutines.jl)
 [![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/BatchedRoutines)](https://pkgs.genieframework.com?packages=BatchedRoutines)
 [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
 
diff --git a/ext/BatchedRoutinesReverseDiffExt.jl b/ext/BatchedRoutinesReverseDiffExt.jl
index fae9a5d..2059813 100644
--- a/ext/BatchedRoutinesReverseDiffExt.jl
+++ b/ext/BatchedRoutinesReverseDiffExt.jl
@@ -88,37 +88,4 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
 ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
     ad, f, x::ReverseDiff.TrackedArray, p)
 
-@concrete struct ReverseDiffPullbackFunction <: Function
-    tape
-    ∂input
-    output
-end
-
-function (pb_f::ReverseDiffPullbackFunction)(Δ)
-    if pb_f.output isa AbstractArray{<:ReverseDiff.TrackedReal}
-        @inbounds for (oᵢ, Δᵢ) in zip(vec(pb_f.output), vec(Δ))
-            oᵢ.deriv = Δᵢ
-        end
-    else
-        vec(pb_f.output.deriv) .= vec(Δ)
-    end
-    ReverseDiff.reverse_pass!(pb_f.tape)
-    return pb_f.∂input
-end
-
-function BatchedRoutines._value_and_pullback(::AutoReverseDiff, f::F, x) where {F}
-    tape = ReverseDiff.InstructionTape()
-    ∂x = zero(x)
-    x_tracked = ReverseDiff.TrackedArray(x, ∂x, tape)
-    y_tracked = f(x_tracked)
-
-    if y_tracked isa AbstractArray{<:ReverseDiff.TrackedReal}
-        y = ReverseDiff.value.(y_tracked)
-    else
-        y = ReverseDiff.value(y_tracked)
-    end
-
-    return y, ReverseDiffPullbackFunction(tape, ∂x, y_tracked)
-end
-
 end
diff --git a/ext/BatchedRoutinesZygoteExt.jl b/ext/BatchedRoutinesZygoteExt.jl
index d37939a..64d1e42 100644
--- a/ext/BatchedRoutinesZygoteExt.jl
+++ b/ext/BatchedRoutinesZygoteExt.jl
@@ -14,6 +14,4 @@ function BatchedRoutines._batched_gradient(::AutoZygote, f::F, u) where {F}
     return only(Zygote.gradient(f, u))
 end
 
-BatchedRoutines._value_and_pullback(::AutoZygote, f::F, x) where {F} = Zygote.pullback(f, x)
-
 end
diff --git a/src/helpers.jl b/src/helpers.jl
index 858dfc1..354806d 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -117,7 +117,6 @@ end
 
 # Useful for computing the gradient of a gradient
 function _jacobian_vector_product end
 function _vector_jacobian_product end
-function _value_and_pullback end
 function _batched_jacobian end
 function _batched_gradient end
diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl
index b0b745c..8f5d93c 100644
--- a/test/autodiff_tests.jl
+++ b/test/autodiff_tests.jl
@@ -9,32 +9,16 @@
             return sum(abs2, X_ .* p; dims=1) .- sum(abs, X_ .* p; dims=1) .+ p .^ 2
         end
 
-        X = randn(rng, 3, 2, 4) |> aType
-        p = randn(rng, 6) |> aType
+        Xs = (aType(randn(rng, 3, 2, 4)), aType(randn(rng, 2, 4)), aType(randn(rng, 3)))
+        ps = (aType(randn(rng, 6)), aType(randn(rng, 2)), aType(randn(rng, 3)))
 
-        J_fdiff = batched_jacobian(
-            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
-        J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)
+        for (X, p) in zip(Xs, ps)
+            J_fdiff = batched_jacobian(
+                AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
+            J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)
 
-        @test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
-
-        X = randn(rng, 2, 4) |> aType
-        p = randn(rng, 2) |> aType
-
-        J_fdiff = batched_jacobian(
-            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
-        J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)
-
-        @test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
-
-        X = randn(rng, 3) |> aType
-        p = randn(rng, 3) |> aType
-
-        J_fdiff = batched_jacobian(
-            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
-        J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)
-
-        @test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
+            @test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
+        end
     end
 end
 
@@ -50,46 +34,24 @@ end
                 abs2, sum(abs2, X_ .* p; dims=1) .- sum(abs, X_ .* p; dims=1) .+ p .^ 2)
         end
 
-        X = randn(rng, 3, 2, 4) |> aType
-        p = randn(rng, 6) |> aType
+        Xs = (aType(randn(rng, 3, 2, 4)), aType(randn(rng, 2, 4)), aType(randn(rng, 3)))
+        ps = (aType(randn(rng, 6)), aType(randn(rng, 2)), aType(randn(rng, 3)))
 
-        gs_fdiff = batched_gradient(
-            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
-        gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
-        gs_rdiff = batched_gradient(
-            AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
-        gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)
+        for (X, p) in zip(Xs, ps)
+            gs_fdiff = batched_gradient(
+                AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
+            gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
+            gs_rdiff = batched_gradient(
+                AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
+            gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)
 
-        @test Array(gs_fdiff)≈Array(gs_fwdiff) atol=1e-3
-        @test Array(gs_fwdiff)≈Array(gs_rdiff) atol=1e-3
-        @test Array(gs_rdiff)≈Array(gs_zygote) atol=1e-3
-
-        X = randn(rng, 2, 4) |> aType
-        p = randn(rng, 2) |> aType
-
-        gs_fdiff = batched_gradient(
-            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
-        gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
-        gs_rdiff = batched_gradient(
-            AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
-        gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)
-
-        @test Array(gs_fdiff)≈Array(gs_fwdiff) atol=1e-3
-        @test Array(gs_fwdiff)≈Array(gs_rdiff) atol=1e-3
-        @test Array(gs_rdiff)≈Array(gs_zygote) atol=1e-3
-
-        X = randn(rng, 3) |> aType
-        p = randn(rng, 3) |> aType
-
-        J_fdiff = batched_gradient(
-            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
-        J_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
-        J_rdiff = batched_gradient(
-            AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
-        J_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)
+            @test Array(gs_fdiff)≈Array(gs_fwdiff) atol=1e-3
+            @test Array(gs_fwdiff)≈Array(gs_rdiff) atol=1e-3
+            @test Array(gs_rdiff)≈Array(gs_zygote) atol=1e-3
+        end
 
-        @test Array(J_fdiff)≈Array(J_fwdiff) atol=1e-3
-        @test Array(J_fwdiff)≈Array(J_rdiff) atol=1e-3
-        @test Array(J_rdiff)≈Array(J_zygote) atol=1e-3
+        @testset "Gradient of Gradient" begin
+
+        end
     end
 end
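
Note on the test refactor above: the three copy-pasted shape cases in each testset collapse into one loop over paired inputs and parameters. What follows is a minimal CPU-only sketch of that pattern outside the test harness, not part of the patch. It assumes FiniteDiff.jl and ForwardDiff.jl are installed so the matching extensions load, that nbatches is exported by BatchedRoutines, that AutoFiniteDiff/AutoForwardDiff come from ADTypes.jl, and it uses Xoshiro as a stand-in for the suite's get_stable_rng helper.

using BatchedRoutines, Random, Test
using ADTypes: AutoFiniteDiff, AutoForwardDiff

rng = Xoshiro(1001)  # stand-in for get_stable_rng(1001)

# Same batched function as the tests: flatten X to (features, batches),
# broadcast against p, and reduce over the feature dimension.
simple_batched_function = function (X, p)
    X_ = reshape(X, :, nbatches(X))
    return sum(abs2, X_ .* p; dims=1) .- sum(abs, X_ .* p; dims=1) .+ p .^ 2
end

# One (input, parameter) pair per shape case, mirroring the refactored loop.
Xs = (randn(rng, 3, 2, 4), randn(rng, 2, 4), randn(rng, 3))
ps = (randn(rng, 6), randn(rng, 2), randn(rng, 3))

for (X, p) in zip(Xs, ps)
    J_fdiff = batched_jacobian(AutoFiniteDiff(), simple_batched_function, X, p)
    J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)
    # Finite differences should agree with forward mode up to the tolerance.
    @test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
end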