Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Make nested AD work for non-compiled ReverseDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 18, 2024
1 parent 3022cf9 commit a37f694
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion ext/BatchedRoutinesReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module BatchedRoutinesReverseDiffExt

using ADTypes: AutoReverseDiff, AutoForwardDiff
using ArrayInterface: ArrayInterface
using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type
using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type,
UniformBlockDiagonalOperator, getdata
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using FastClosures: @closure
Expand Down Expand Up @@ -30,6 +31,21 @@ function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F}
return ∂u
end

# ReverseDiff compatible `UniformBlockDiagonalOperator`
@inline function ReverseDiff.track(

Check warning on line 35 in ext/BatchedRoutinesReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BatchedRoutinesReverseDiffExt.jl#L35

Added line #L35 was not covered by tests
op::UniformBlockDiagonalOperator, tp::ReverseDiff.InstructionTape)
return UniformBlockDiagonalOperator(ReverseDiff.track(getdata(op), tp))

Check warning on line 37 in ext/BatchedRoutinesReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BatchedRoutinesReverseDiffExt.jl#L37

Added line #L37 was not covered by tests
end

@inline function ReverseDiff.deriv(x::UniformBlockDiagonalOperator)
return UniformBlockDiagonalOperator(ReverseDiff.deriv(getdata(x)))

Check warning on line 41 in ext/BatchedRoutinesReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BatchedRoutinesReverseDiffExt.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
end

@inline function ReverseDiff.value!(

Check warning on line 44 in ext/BatchedRoutinesReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BatchedRoutinesReverseDiffExt.jl#L44

Added line #L44 was not covered by tests
op::UniformBlockDiagonalOperator, val::UniformBlockDiagonalOperator)
ReverseDiff.value!(getdata(op), getdata(val))

Check warning on line 46 in ext/BatchedRoutinesReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BatchedRoutinesReverseDiffExt.jl#L46

Added line #L46 was not covered by tests
end

# Chain rules integration
function BatchedRoutines.batched_jacobian(
ad, f::F, x::AbstractMatrix{<:ReverseDiff.TrackedReal}) where {F}
Expand Down

0 comments on commit a37f694

Please sign in to comment.