
Efficient gradient over gradient implementations
avik-pal committed Mar 13, 2024
1 parent b770666 commit 3b2654e
Showing 8 changed files with 70 additions and 71 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -12,3 +12,5 @@

[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

This is currently a WIP and is being tested internally for some projects. Once such testing is finished, we will register this package in the General registry.
6 changes: 3 additions & 3 deletions ext/BatchedRoutinesFiniteDiffExt.jl
@@ -10,15 +10,15 @@ using FiniteDiff: FiniteDiff

# api.jl
## Exposed API
@inline function BatchedRoutines.batched_jacobian(
@inline function BatchedRoutines._batched_jacobian(
ad::AutoFiniteDiff, f::F, x::AbstractVector{T}) where {F, T}
J = FiniteDiff.finite_difference_jacobian(f, x, ad.fdjtype)
(_assert_type(f) && _assert_type(x) && Base.issingletontype(F)) &&
(return UniformBlockDiagonalMatrix(J::parameterless_type(x){T, 2}))
return UniformBlockDiagonalMatrix(J)
end

@inline function BatchedRoutines.batched_jacobian(
@inline function BatchedRoutines._batched_jacobian(
ad::AutoFiniteDiff, f::F, x::AbstractMatrix) where {F}
f! = @closure (y, x_) -> copyto!(y, f(x_))
fx = f(x)
@@ -31,7 +31,7 @@ end
end

# NOTE: This doesn't exploit batching
@inline function BatchedRoutines.batched_gradient(ad::AutoFiniteDiff, f::F, x) where {F}
@inline function BatchedRoutines._batched_gradient(ad::AutoFiniteDiff, f::F, x) where {F}
return FiniteDiff.finite_difference_gradient(f, x, ad.fdtype)
end

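
For context, a minimal sketch of what the block-diagonal structure above means, assuming `f_single` and `u` as illustrative stand-ins (this is not the extension's code path): when the map acts on each batch column independently, the batched Jacobian is simply one dense Jacobian per column, which is what `UniformBlockDiagonalMatrix` stores as its blocks.

```julia
using FiniteDiff

f_single(x) = x .^ 2 .+ sin.(x)   # illustrative per-batch map on a single column
u = rand(3, 4)                    # 3 states, 4 batches

# One finite-difference Jacobian block per batch column
blocks = [FiniteDiff.finite_difference_jacobian(f_single, u[:, b]) for b in axes(u, 2)]
@assert size(first(blocks)) == (3, 3)   # a 3×3 block for each of the 4 batches
```
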
18 changes: 14 additions & 4 deletions ext/BatchedRoutinesForwardDiffExt.jl
@@ -123,7 +123,7 @@ end
end

## Exposed API
@inline function BatchedRoutines.batched_jacobian(
@inline function BatchedRoutines._batched_jacobian(
ad::AutoForwardDiff{CK}, f::F, u::AbstractVector{T}) where {CK, F, T}
tag = ad.tag === nothing ? ForwardDiff.Tag{F, eltype(u)}() : ad.tag
if CK === nothing || CK ≤ 0
@@ -138,13 +138,13 @@ end
return UniformBlockDiagonalMatrix(J)
end

@inline function BatchedRoutines.batched_jacobian(
@inline function BatchedRoutines._batched_jacobian(
ad::AutoForwardDiff, f::F, u::AbstractMatrix) where {F}
return last(BatchedRoutines.__batched_value_and_jacobian(ad, f, u))
end

@inline function BatchedRoutines.batched_gradient(
ad::AutoForwardDiff{CK}, f::F, u::AbstractMatrix) where {F, CK}
@inline function BatchedRoutines._batched_gradient(
ad::AutoForwardDiff{CK}, f::F, u) where {F, CK}
tag = ad.tag === nothing ? ForwardDiff.Tag{F, eltype(u)}() : ad.tag
if CK === nothing || CK ≤ 0
cfg = ForwardDiff.GradientConfig(
@@ -167,4 +167,14 @@ function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x,
return ForwardDiff.partials.(y_dual, 1)
end

function BatchedRoutines._jacobian_vector_product(
ad::AutoForwardDiff, f::F, x, u, p) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
u_dual = ForwardDiff.Dual{Tag, T, 1}.(u, partials)
y_dual = f(u_dual, p)
return ForwardDiff.partials.(y_dual, 1)
end

end
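
The `_jacobian_vector_product` methods above push a single-partial `ForwardDiff.Dual` through `f` and read the partial back out. A minimal standalone sketch of that pattern, where `jvp_sketch` and the test function are illustrative names rather than package code:

```julia
using ForwardDiff

# Directional derivative J(x) * v via one dual partial, mirroring the extension above.
function jvp_sketch(f, x::AbstractArray{T}, v::AbstractArray{T}) where {T}
    Tag = typeof(ForwardDiff.Tag(f, T))
    partials = ForwardDiff.Partials{1, T}.(tuple.(v))   # seed each element with v
    x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
    y_dual = f(x_dual)
    return ForwardDiff.partials.(y_dual, 1)             # extract J(x) * v
end

x, v = rand(3), rand(3)
jvp_sketch(z -> z .^ 2 .+ sin.(z), x, v) ≈ (2 .* x .+ cos.(x)) .* v   # true
```
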
24 changes: 2 additions & 22 deletions ext/BatchedRoutinesReverseDiffExt.jl
@@ -14,8 +14,7 @@ const CRC = ChainRulesCore
Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:ReverseDiff.TrackedArray})=false
Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{<:ReverseDiff.TrackedReal}})=false

function BatchedRoutines.batched_gradient(
::AutoReverseDiff, f::F, u::AbstractMatrix) where {F}
function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F}
return ReverseDiff.gradient(f, u)
end

@@ -37,7 +36,7 @@ function (pb_f::ReverseDiffPullbackFunction)(Δ)
return pb_f.∂input
end

function _value_and_pullback(f::F, x) where {F}
function BatchedRoutines._value_and_pullback(::AutoReverseDiff, f::F, x) where {F}
tape = ReverseDiff.InstructionTape()
∂x = zero(x)
x_tracked = ReverseDiff.TrackedArray(x, ∂x, tape)
@@ -52,23 +51,4 @@ function _value_and_pullback(f::F, x) where {F}
return y, ReverseDiffPullbackFunction(tape, ∂x, y_tracked)
end

function CRC.rrule(::typeof(BatchedRoutines.batched_gradient),
ad::AutoReverseDiff, f::F, x::AbstractMatrix) where {F}
if BatchedRoutines._is_extension_loaded(Val(:ForwardDiff))
dx = BatchedRoutines.batched_gradient(ad, f, x)
# Use Forward Over Reverse to compute the Hessian Vector Product
∇batched_gradient = @closure Δ -> begin
∂x = BatchedRoutines._jacobian_vector_product(
AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)),
x, reshape(Δ, size(x)))
return NoTangent(), NoTangent(), NoTangent(), ∂x
end
return dx, ∇batched_gradient
end

dx, pb_f = _value_and_pullback(Base.Fix1(ReverseDiff.gradient, f), x)
∇batched_gradient = @closure Δ -> (NoTangent(), NoTangent(), NoTangent(), pb_f(Δ))
return dx, ∇batched_gradient
end

end
24 changes: 2 additions & 22 deletions ext/BatchedRoutinesZygoteExt.jl
@@ -10,30 +10,10 @@ const CRC = ChainRulesCore

@inline BatchedRoutines._is_extension_loaded(::Val{:Zygote}) = true

function BatchedRoutines.batched_gradient(::AutoZygote, f::F, u::AbstractMatrix) where {F}
function BatchedRoutines._batched_gradient(::AutoZygote, f::F, u) where {F}
return only(Zygote.gradient(f, u))
end

function CRC.rrule(::typeof(BatchedRoutines.batched_gradient),
ad::AutoZygote, f::F, x::AbstractMatrix) where {F}
if BatchedRoutines._is_extension_loaded(Val(:ForwardDiff))
dx = BatchedRoutines.batched_gradient(ad, f, x)
# Use Forward Over Reverse to compute the Hessian Vector Product
∇batched_gradient = @closure Δ -> begin
∂x = BatchedRoutines._jacobian_vector_product(
AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)),
x, reshape(Δ, size(x)))
return NoTangent(), NoTangent(), NoTangent(), ∂x
end
return dx, ∇batched_gradient
end

dx, pb_f = Zygote.pullback(@closure(x->only(Zygote.gradient(f, x))), x)
∇batched_gradient = @closure Δ -> begin
∂x = only(pb_f(Δ)) # Else we have to do Zygote over Zygote
return NoTangent(), NoTangent(), NoTangent(), ∂x
end
return dx, ∇batched_gradient
end
BatchedRoutines._value_and_pullback(::AutoZygote, f::F, x) where {F} = Zygote.pullback(f, x)

end
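
`_value_and_pullback` is the backend-agnostic hook the shared chain rules consume, and for Zygote it is literally `Zygote.pullback`. A hedged usage sketch, with an illustrative objective and input:

```julia
using Zygote

# What the AutoZygote method above returns: the primal value and a pullback closure.
y, pb = Zygote.pullback(x -> sum(abs2, x), rand(4))
∂x = only(pb(1.0))   # pull the scalar cotangent back to the input; ∂x == 2x
```
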
21 changes: 9 additions & 12 deletions src/api.jl
@@ -10,7 +10,10 @@ Use the backend `ad` to compute the Jacobian of `f` at `x` in batched mode. Retu
If the batches interact among themselves, then the Jacobian is not block diagonal and
this function will not work as expected.
"""
function batched_jacobian end
function batched_jacobian(ad, f::F, u::AbstractVecOrMat) where {F}
_assert_loaded_backend(ad)
return _batched_jacobian(ad, f, u)
end

@inline function batched_jacobian(ad, f::F, u::AbstractArray) where {F}
B = size(u, ndims(u))
@@ -29,19 +32,13 @@ Use the backend `ad` to compute the batched_gradient of `f` at `x`. For the forw
different from calling the batched_gradient function in the backend. This exists to efficiently swap
backends for computing the `batched_gradient` of the `batched_gradient`.
"""
function batched_gradient end

function batched_gradient(ad, f::F, u::AbstractVector) where {F}
return vec(batched_gradient(ad, f, reshape(u, 1, :)))
function batched_gradient(ad, f::F, u) where {F}
_assert_loaded_backend(ad)
return _batched_gradient(ad, f, u)
end

function batched_gradient(ad, f::F, u::AbstractArray) where {F}
B = size(u, ndims(u))
f_mat = @closure x -> reshape(f(reshape(x, size(u))), :, B)
return reshape(batched_gradient(ad, f_mat, reshape(u, :, B)), size(u))
end

@inline batched_gradient(ad, f::F, u, p) where {F} = batched_gradient(ad, Base.Fix2(f, p), u)
@inline batched_gradient(ad, f::F, u, p) where {F} = batched_gradient(
ad, Base.Fix2(f, p), u)

"""
batched_pickchunksize(X::AbstractArray, n::Int)
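
A hypothetical usage sketch of the public API after this change. It assumes the ADTypes backend structs are available (e.g. via `using ADTypes`) and that loading `ForwardDiff` activates the corresponding package extension; `f` and `u` are illustrative.

```julia
using BatchedRoutines, ADTypes, ForwardDiff   # loading ForwardDiff enables AutoForwardDiff here

f(x) = x .^ 2 .+ 1                  # acts independently on each batch column
u = rand(3, 4)                      # 3 states, batch size 4

J = batched_jacobian(AutoForwardDiff(), f, u)                   # block diagonal, one 3×3 block per batch
g = batched_gradient(AutoForwardDiff(), x -> sum(abs2, x), u)   # gradient with the same shape as u
```
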
35 changes: 27 additions & 8 deletions src/chainrules.jl
@@ -3,7 +3,7 @@ function __batched_value_and_jacobian(ad, f::F, x) where {F}
return f(x), J
end

# Reverse over Forward: Just construct Hessian for now
# FIXME: Gradient of jacobians is really inefficient here
function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) where {F}
N, B = size(x)
J, H = __batched_value_and_jacobian(
@@ -23,7 +23,6 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F}
J, H = __batched_value_and_jacobian(
ad, @closure(y->reshape(batched_jacobian(ad, f, y, p).data, :, B)), x)

# TODO: This can be written as a JVP
p_size = size(p)
_, Jₚ_ = __batched_value_and_jacobian(
ad, @closure(p->reshape(batched_jacobian(ad, f, x, reshape(p, p_size)).data, :, B)),
@@ -40,16 +39,36 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F}
return UniformBlockDiagonalMatrix(reshape(J, :, N, B)), ∇batched_jacobian
end

function CRC.rrule(::typeof(batched_gradient), ad, f::F, x::AbstractMatrix) where {F}
N, B = size(x)
dx, J = BatchedRoutines.__batched_value_and_jacobian(
ad, @closure(x->batched_gradient(ad, f, x)), x)
function CRC.rrule(::typeof(batched_gradient), ad, f::F, x) where {F}
BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) ||
throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \
of `batched_gradient`."))

function ∇batched_gradient(Δ)
∂x = reshape(batched_mul(reshape(Δ, 1, :, nbatches(Δ)), J.data), :, nbatches(Δ))
dx = BatchedRoutines.batched_gradient(ad, f, x)
∇batched_gradient = @closure Δ -> begin
∂x = _jacobian_vector_product(
AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)),
x, reshape(Δ, size(x)))
return NoTangent(), NoTangent(), NoTangent(), ∂x
end
return dx, ∇batched_gradient
end

function CRC.rrule(::typeof(batched_gradient), ad, f::F, x, p) where {F}
BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) ||
throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \
of `batched_gradient`."))

dx = BatchedRoutines.batched_gradient(ad, f, x, p)
∇batched_gradient = @closure Δ -> begin
∂x = _jacobian_vector_product(AutoForwardDiff(),
@closure(x->BatchedRoutines.batched_gradient(ad, Base.Fix2(f, p), x)),
x, reshape(Δ, size(x)))
∂p = _jacobian_vector_product(AutoForwardDiff(),
@closure((x, p)->BatchedRoutines.batched_gradient(ad, Base.Fix1(f, x), p)),
x, reshape(Δ, size(x)), p)
return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
end
return dx, ∇batched_gradient
end

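
The pullbacks of `batched_gradient` above need Hessian-vector products, and calling `_jacobian_vector_product` on the gradient closure is the forward-over-reverse trick the earlier extension comments describe. A minimal sketch of that identity outside the package, with `hvp` and the quadratic test function as illustrative names:

```julia
using ForwardDiff, Zygote

# H(x) * v by pushing a single-partial dual through a reverse-mode gradient.
hvp(f, x, v) = ForwardDiff.partials.(only(Zygote.gradient(f, ForwardDiff.Dual.(x, v))), 1)

f(x) = sum(abs2, x) / 2        # Hessian is the identity matrix
x, v = rand(4), rand(4)
hvp(f, x, v) ≈ v               # true: I * v == v
```
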
11 changes: 11 additions & 0 deletions src/helpers.jl
@@ -117,3 +117,14 @@ end
# Useful for computing the gradient of a gradient
function _jacobian_vector_product end
function _vector_jacobian_product end
function _value_and_pullback end
function _batched_jacobian end
function _batched_gradient end

# Test Loaded AD Backend
_assert_loaded_backend(::AutoForwardDiff) = @assert _is_extension_loaded(Val(:ForwardDiff))
_assert_loaded_backend(::AutoReverseDiff) = @assert _is_extension_loaded(Val(:ReverseDiff))
_assert_loaded_backend(::AutoFiniteDiff) = @assert _is_extension_loaded(Val(:FiniteDiff))
_assert_loaded_backend(::AutoZygote) = @assert _is_extension_loaded(Val(:Zygote))

CRC.@non_differentiable _assert_loaded_backend(::Any...)
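
The guards above rely on the package-extension trait seen earlier (e.g. `_is_extension_loaded(::Val{:Zygote}) = true` in the Zygote extension). A hedged sketch of the pattern in isolation, with illustrative names:

```julia
# Fallback defined in the main package; each weak-dependency extension overrides it when loaded.
_is_extension_loaded(::Val) = false
_is_extension_loaded(::Val{:ForwardDiff}) = true   # what the ForwardDiff extension would add

assert_backend_loaded(name::Symbol) =
    @assert _is_extension_loaded(Val(name)) "Load the $(name) extension first."

assert_backend_loaded(:ForwardDiff)   # passes once the extension method exists
```
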
