From 3b2654ebdf14f7710e76314523d89504fcd538d0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 12 Mar 2024 19:58:15 -0400 Subject: [PATCH] Efficient gradient over gradient implementations --- README.md | 2 ++ ext/BatchedRoutinesFiniteDiffExt.jl | 6 ++--- ext/BatchedRoutinesForwardDiffExt.jl | 18 ++++++++++---- ext/BatchedRoutinesReverseDiffExt.jl | 24 ++----------------- ext/BatchedRoutinesZygoteExt.jl | 24 ++----------------- src/api.jl | 21 +++++++---------- src/chainrules.jl | 35 +++++++++++++++++++++------- src/helpers.jl | 11 +++++++++ 8 files changed, 70 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index 183bc06..4745245 100644 --- a/README.md +++ b/README.md @@ -12,3 +12,5 @@ [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) + +This is currently a WIP and is being tested internally for some projects. Once such testing is finish we will register this package in the General registry. diff --git a/ext/BatchedRoutinesFiniteDiffExt.jl b/ext/BatchedRoutinesFiniteDiffExt.jl index 1d6d45c..e164730 100644 --- a/ext/BatchedRoutinesFiniteDiffExt.jl +++ b/ext/BatchedRoutinesFiniteDiffExt.jl @@ -10,7 +10,7 @@ using FiniteDiff: FiniteDiff # api.jl ## Exposed API -@inline function BatchedRoutines.batched_jacobian( +@inline function BatchedRoutines._batched_jacobian( ad::AutoFiniteDiff, f::F, x::AbstractVector{T}) where {F, T} J = FiniteDiff.finite_difference_jacobian(f, x, ad.fdjtype) (_assert_type(f) && _assert_type(x) && Base.issingletontype(F)) && @@ -18,7 +18,7 @@ using FiniteDiff: FiniteDiff return UniformBlockDiagonalMatrix(J) end -@inline function BatchedRoutines.batched_jacobian( +@inline function BatchedRoutines._batched_jacobian( ad::AutoFiniteDiff, f::F, x::AbstractMatrix) where {F} f! = @closure (y, x_) -> copyto!(y, f(x_)) fx = f(x) @@ -31,7 +31,7 @@ end end # NOTE: This doesn't exploit batching -@inline function BatchedRoutines.batched_gradient(ad::AutoFiniteDiff, f::F, x) where {F} +@inline function BatchedRoutines._batched_gradient(ad::AutoFiniteDiff, f::F, x) where {F} return FiniteDiff.finite_difference_batched_gradient(f, x, ad.fdjtype) end diff --git a/ext/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt.jl index 207cc07..7d371be 100644 --- a/ext/BatchedRoutinesForwardDiffExt.jl +++ b/ext/BatchedRoutinesForwardDiffExt.jl @@ -123,7 +123,7 @@ end end ## Exposed API -@inline function BatchedRoutines.batched_jacobian( +@inline function BatchedRoutines._batched_jacobian( ad::AutoForwardDiff{CK}, f::F, u::AbstractVector{T}) where {CK, F, T} tag = ad.tag === nothing ? ForwardDiff.Tag{F, eltype(u)}() : ad.tag if CK === nothing || CK ≤ 0 @@ -138,13 +138,13 @@ end return UniformBlockDiagonalMatrix(J) end -@inline function BatchedRoutines.batched_jacobian( +@inline function BatchedRoutines._batched_jacobian( ad::AutoForwardDiff, f::F, u::AbstractMatrix) where {F} return last(BatchedRoutines.__batched_value_and_jacobian(ad, f, u)) end -@inline function BatchedRoutines.batched_gradient( - ad::AutoForwardDiff{CK}, f::F, u::AbstractMatrix) where {F, CK} +@inline function BatchedRoutines._batched_gradient( + ad::AutoForwardDiff{CK}, f::F, u) where {F, CK} tag = ad.tag === nothing ? ForwardDiff.Tag{F, eltype(u)}() : ad.tag if CK === nothing || CK ≤ 0 cfg = ForwardDiff.GradientConfig( @@ -167,4 +167,14 @@ function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, return ForwardDiff.partials.(y_dual, 1) end +function BatchedRoutines._jacobian_vector_product( + ad::AutoForwardDiff, f::F, x, u, p) where {F} + Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag) + T = promote_type(eltype(x), eltype(u)) + partials = ForwardDiff.Partials{1, T}.(tuple.(u)) + u_dual = ForwardDiff.Dual{Tag, T, 1}.(u, partials) + y_dual = f(u_dual, p) + return ForwardDiff.partials.(y_dual, 1) +end + end diff --git a/ext/BatchedRoutinesReverseDiffExt.jl b/ext/BatchedRoutinesReverseDiffExt.jl index cae1bdf..fbfca73 100644 --- a/ext/BatchedRoutinesReverseDiffExt.jl +++ b/ext/BatchedRoutinesReverseDiffExt.jl @@ -14,8 +14,7 @@ const CRC = ChainRulesCore Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:ReverseDiff.TrackedArray})=false Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{<:ReverseDiff.TrackedReal}})=false -function BatchedRoutines.batched_gradient( - ::AutoReverseDiff, f::F, u::AbstractMatrix) where {F} +function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F} return ReverseDiff.gradient(f, u) end @@ -37,7 +36,7 @@ function (pb_f::ReverseDiffPullbackFunction)(Δ) return pb_f.∂input end -function _value_and_pullback(f::F, x) where {F} +function BatchedRoutines._value_and_pullback(::AutoReverseDiff, f::F, x) where {F} tape = ReverseDiff.InstructionTape() ∂x = zero(x) x_tracked = ReverseDiff.TrackedArray(x, ∂x, tape) @@ -52,23 +51,4 @@ function _value_and_pullback(f::F, x) where {F} return y, ReverseDiffPullbackFunction(tape, ∂x, y_tracked) end -function CRC.rrule(::typeof(BatchedRoutines.batched_gradient), - ad::AutoReverseDiff, f::F, x::AbstractMatrix) where {F} - if BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) - dx = BatchedRoutines.batched_gradient(ad, f, x) - # Use Forward Over Reverse to compute the Hessian Vector Product - ∇batched_gradient = @closure Δ -> begin - ∂x = BatchedRoutines._jacobian_vector_product( - AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)), - x, reshape(Δ, size(x))) - return NoTangent(), NoTangent(), NoTangent(), ∂x - end - return dx, ∇batched_gradient - end - - dx, pb_f = _value_and_pullback(Base.Fix1(ReverseDiff.gradient, f), x) - ∇batched_gradient = @closure Δ -> (NoTangent(), NoTangent(), NoTangent(), pb_f(Δ)) - return dx, ∇batched_gradient -end - end diff --git a/ext/BatchedRoutinesZygoteExt.jl b/ext/BatchedRoutinesZygoteExt.jl index dbcd57d..d37939a 100644 --- a/ext/BatchedRoutinesZygoteExt.jl +++ b/ext/BatchedRoutinesZygoteExt.jl @@ -10,30 +10,10 @@ const CRC = ChainRulesCore @inline BatchedRoutines._is_extension_loaded(::Val{:Zygote}) = true -function BatchedRoutines.batched_gradient(::AutoZygote, f::F, u::AbstractMatrix) where {F} +function BatchedRoutines._batched_gradient(::AutoZygote, f::F, u) where {F} return only(Zygote.gradient(f, u)) end -function CRC.rrule(::typeof(BatchedRoutines.batched_gradient), - ad::AutoZygote, f::F, x::AbstractMatrix) where {F} - if BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) - dx = BatchedRoutines.batched_gradient(ad, f, x) - # Use Forward Over Reverse to compute the Hessian Vector Product - ∇batched_gradient = @closure Δ -> begin - ∂x = BatchedRoutines._jacobian_vector_product( - AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)), - x, reshape(Δ, size(x))) - return NoTangent(), NoTangent(), NoTangent(), ∂x - end - return dx, ∇batched_gradient - end - - dx, pb_f = Zygote.pullback(@closure(x->only(Zygote.gradient(f, x))), x) - ∇batched_gradient = @closure Δ -> begin - ∂x = only(pb_f(Δ)) # Else we have to do Zygote over Zygote - return NoTangent(), NoTangent(), NoTangent(), ∂x - end - return dx, ∇batched_gradient -end +BatchedRoutines._value_and_pullback(::AutoZygote, f::F, x) where {F} = Zygote.pullback(f, x) end diff --git a/src/api.jl b/src/api.jl index 7dbf151..3722fc7 100644 --- a/src/api.jl +++ b/src/api.jl @@ -10,7 +10,10 @@ Use the backend `ad` to compute the Jacobian of `f` at `x` in batched mode. Retu If the batches interact among themselves, then the Jacobian is not block diagonal and this function will not work as expected. """ -function batched_jacobian end +function batched_jacobian(ad, f::F, u::AbstractVecOrMat) where {F} + _assert_loaded_backend(ad) + return _batched_jacobian(ad, f, u) +end @inline function batched_jacobian(ad, f::F, u::AbstractArray) where {F} B = size(u, ndims(u)) @@ -29,19 +32,13 @@ Use the backend `ad` to compute the batched_gradient of `f` at `x`. For the forw different from calling the batched_gradient function in the backend. This exists to efficiently swap backends for computing the `batched_gradient` of the `batched_gradient`. """ -function batched_gradient end - -function batched_gradient(ad, f::F, u::AbstractVector) where {F} - return vec(batched_gradient(ad, f, reshape(u, 1, :))) +function batched_gradient(ad, f::F, u) where {F} + _assert_loaded_backend(ad) + return _batched_gradient(ad, f, u) end -function batched_gradient(ad, f::F, u::AbstractArray) where {F} - B = size(u, ndims(u)) - f_mat = @closure x -> reshape(f(reshape(x, size(u))), :, B) - return reshape(batched_gradient(ad, f_mat, reshape(u, :, B)), size(u)) -end - -@inline batched_gradient(ad, f::F, u, p) where {F} = batched_gradient(ad, Base.Fix2(f, p), u) +@inline batched_gradient(ad, f::F, u, p) where {F} = batched_gradient( + ad, Base.Fix2(f, p), u) """ batched_pickchunksize(X::AbstractArray, n::Int) diff --git a/src/chainrules.jl b/src/chainrules.jl index e198ca6..bacec05 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -3,7 +3,7 @@ function __batched_value_and_jacobian(ad, f::F, x) where {F} return f(x), J end -# Reverse over Forward: Just construct Hessian for now +# FIXME: Gradient of jacobians is really in-efficient here function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) where {F} N, B = size(x) J, H = __batched_value_and_jacobian( @@ -23,7 +23,6 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F} J, H = __batched_value_and_jacobian( ad, @closure(y->reshape(batched_jacobian(ad, f, y, p).data, :, B)), x) - # TODO: This can be written as a JVP p_size = size(p) _, Jₚ_ = __batched_value_and_jacobian( ad, @closure(p->reshape(batched_jacobian(ad, f, x, reshape(p, p_size)).data, :, B)), @@ -40,16 +39,36 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F} return UniformBlockDiagonalMatrix(reshape(J, :, N, B)), ∇batched_jacobian end -function CRC.rrule(::typeof(batched_gradient), ad, f::F, x::AbstractMatrix) where {F} - N, B = size(x) - dx, J = BatchedRoutines.__batched_value_and_jacobian( - ad, @closure(x->batched_gradient(ad, f, x)), x) +function CRC.rrule(::typeof(batched_gradient), ad, f::F, x) where {F} + BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) || + throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \ + of `batched_gradient`.")) - function ∇batched_gradient(Δ) - ∂x = reshape(batched_mul(reshape(Δ, 1, :, nbatches(Δ)), J.data), :, nbatches(Δ)) + dx = BatchedRoutines.batched_gradient(ad, f, x) + ∇batched_gradient = @closure Δ -> begin + ∂x = _jacobian_vector_product( + AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)), + x, reshape(Δ, size(x))) return NoTangent(), NoTangent(), NoTangent(), ∂x end + return dx, ∇batched_gradient +end +function CRC.rrule(::typeof(batched_gradient), ad, f::F, x, p) where {F} + BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) || + throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \ + of `batched_gradient`.")) + + dx = BatchedRoutines.batched_gradient(ad, f, x, p) + ∇batched_gradient = @closure Δ -> begin + ∂x = _jacobian_vector_product(AutoForwardDiff(), + @closure(x->BatchedRoutines.batched_gradient(ad, Base.Fix2(f, p), x)), + x, reshape(Δ, size(x))) + ∂p = _jacobian_vector_product(AutoForwardDiff(), + @closure((x, p)->BatchedRoutines.batched_gradient(ad, Base.Fix1(f, x), p)), + x, reshape(Δ, size(x)), p) + return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p + end return dx, ∇batched_gradient end diff --git a/src/helpers.jl b/src/helpers.jl index 3a5f975..858dfc1 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -117,3 +117,14 @@ end # Useful for computing the gradient of a gradient function _jacobian_vector_product end function _vector_jacobian_product end +function _value_and_pullback end +function _batched_jacobian end +function _batched_gradient end + +# Test Loaded AD Backend +_assert_loaded_backend(::AutoForwardDiff) = @assert _is_extension_loaded(Val(:ForwardDiff)) +_assert_loaded_backend(::AutoReverseDiff) = @assert _is_extension_loaded(Val(:ReverseDiff)) +_assert_loaded_backend(::AutoFiniteDiff) = @assert _is_extension_loaded(Val(:FiniteDiff)) +_assert_loaded_backend(::AutoZygote) = @assert _is_extension_loaded(Val(:Zygote)) + +CRC.@non_differentiable _assert_loaded_backend(::Any...)