Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Add a pullback for reversediff
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 12, 2024
1 parent dec1acd commit b770666
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
59 changes: 36 additions & 23 deletions ext/BatchedRoutinesReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module BatchedRoutinesReverseDiffExt
using ADTypes: AutoReverseDiff, AutoForwardDiff
using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using FastClosures: @closure
using ReverseDiff: ReverseDiff

Expand All @@ -18,43 +19,55 @@ function BatchedRoutines.batched_gradient(
return ReverseDiff.gradient(f, u)
end

# Callable container holding everything needed to replay a recorded
# ReverseDiff tape as a pullback. Constructed by `_value_and_pullback`.
# (The leading lines of the original span were leftover diff residue of the
# removed `CRC.rrule` method, which lives in full further below.)
@concrete struct ReverseDiffPullbackFunction <: Function
    tape    # ReverseDiff.InstructionTape recorded during the forward pass
    ∂input  # derivative buffer aliased by the tracked input (filled by the reverse pass)
    output  # tracked forward-pass output whose `deriv` gets seeded with the cotangent
end

"""
    (pb_f::ReverseDiffPullbackFunction)(Δ)

Seed the recorded output's derivative with the cotangent `Δ`, run the reverse
pass over the stored tape, and return the accumulated input derivative
(`pb_f.∂input`).
"""
function (pb_f::ReverseDiffPullbackFunction)(Δ)
    if pb_f.output isa AbstractArray{<:ReverseDiff.TrackedReal}
        # Output is an array of scalar trackers: seed each one's `deriv` field.
        # NOTE: the original line was garbled diff residue (`vec))`); the second
        # zip argument must be `vec(Δ)`.
        @inbounds for (oᵢ, Δᵢ) in zip(vec(pb_f.output), vec(Δ))
            oᵢ.deriv = Δᵢ
        end
    else
        # Output is a TrackedArray: seed its derivative buffer in one broadcast.
        vec(pb_f.output.deriv) .= vec(Δ)
    end
    ReverseDiff.reverse_pass!(pb_f.tape)
    return pb_f.∂input
end

# Record `f(x)` onto a fresh ReverseDiff instruction tape and return the
# primal value `y` together with a `ReverseDiffPullbackFunction` that maps a
# cotangent `Δ` to the derivative w.r.t. `x`.
# (The original span contained interleaved removed-diff lines — duplicate
# `dx = ...` assignments and a dead `∇batched_gradient` closure — stripped here.)
function _value_and_pullback(f::F, x) where {F}
    tape = ReverseDiff.InstructionTape()
    ∂x = zero(x)
    # `x_tracked` aliases `∂x`, so the reverse pass writes the gradient into it.
    x_tracked = ReverseDiff.TrackedArray(x, ∂x, tape)
    y_tracked = f(x_tracked)

    # Extract the untracked primal value; the output may be either an array of
    # TrackedReals or a single tracked object.
    if y_tracked isa AbstractArray{<:ReverseDiff.TrackedReal}
        y = ReverseDiff.value.(y_tracked)
    else
        y = ReverseDiff.value(y_tracked)
    end

    return y, ReverseDiffPullbackFunction(tape, ∂x, y_tracked)
end

"""
    CRC.rrule(::typeof(BatchedRoutines.batched_gradient),
        ad::AutoReverseDiff, f, x::AbstractMatrix)

ChainRules reverse rule for `batched_gradient` with a ReverseDiff backend.

If the ForwardDiff extension is loaded, uses Forward-over-Reverse: the
pullback is a ForwardDiff JVP through the ReverseDiff gradient (a
Hessian-vector product). Otherwise falls back to recording
`ReverseDiff.gradient(f, ·)` on a tape via `_value_and_pullback`.
"""
function CRC.rrule(::typeof(BatchedRoutines.batched_gradient),
        ad::AutoReverseDiff, f::F, x::AbstractMatrix) where {F}
    if BatchedRoutines._is_extension_loaded(Val(:ForwardDiff))
        dx = BatchedRoutines.batched_gradient(ad, f, x)
        # Use Forward-over-Reverse to compute the Hessian-vector product.
        ∇batched_gradient = @closure Δ -> begin
            ∂x = BatchedRoutines._jacobian_vector_product(
                AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)),
                x, reshape(Δ, size(x)))
            return NoTangent(), NoTangent(), NoTangent(), ∂x
        end
        # (Stray removed-diff lines referencing an undefined `tape` and an
        # unreachable duplicate return were dropped from this branch.)
        return dx, ∇batched_gradient
    end

    # Pure-ReverseDiff fallback: tape the gradient computation itself.
    dx, pb_f = _value_and_pullback(Base.Fix1(ReverseDiff.gradient, f), x)
    ∇batched_gradient = @closure Δ -> (NoTangent(), NoTangent(), NoTangent(), pb_f(Δ))
    return dx, ∇batched_gradient
end

Expand Down
2 changes: 1 addition & 1 deletion src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F}
J, H = __batched_value_and_jacobian(
ad, @closure(y->reshape(batched_jacobian(ad, f, y, p).data, :, B)), x)

# TODO: For `CPU` arrays we can do ReverseDiff over ForwardDiff
# TODO: This can be written as a JVP
p_size = size(p)
_, Jₚ_ = __batched_value_and_jacobian(
ad, @closure(p->reshape(batched_jacobian(ad, f, x, reshape(p, p_size)).data, :, B)),
Expand Down

0 comments on commit b770666

Please sign in to comment.