From a37f69464492f1d7bc2fb22ea093b86698fd3e0d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 14:18:03 -0400 Subject: [PATCH] Make nested AD work for non-compiled ReverseDiff --- ext/BatchedRoutinesReverseDiffExt.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/ext/BatchedRoutinesReverseDiffExt.jl b/ext/BatchedRoutinesReverseDiffExt.jl index 6c0678c..de2f982 100644 --- a/ext/BatchedRoutinesReverseDiffExt.jl +++ b/ext/BatchedRoutinesReverseDiffExt.jl @@ -2,7 +2,8 @@ module BatchedRoutinesReverseDiffExt using ADTypes: AutoReverseDiff, AutoForwardDiff using ArrayInterface: ArrayInterface -using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type +using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type, + UniformBlockDiagonalOperator, getdata using ChainRulesCore: ChainRulesCore, NoTangent using ConcreteStructs: @concrete using FastClosures: @closure @@ -30,6 +31,21 @@ function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F} return ∂u end +# ReverseDiff compatible `UniformBlockDiagonalOperator` +@inline function ReverseDiff.track( + op::UniformBlockDiagonalOperator, tp::ReverseDiff.InstructionTape) + return UniformBlockDiagonalOperator(ReverseDiff.track(getdata(op), tp)) +end + +@inline function ReverseDiff.deriv(x::UniformBlockDiagonalOperator) + return UniformBlockDiagonalOperator(ReverseDiff.deriv(getdata(x))) +end + +@inline function ReverseDiff.value!( + op::UniformBlockDiagonalOperator, val::UniformBlockDiagonalOperator) + ReverseDiff.value!(getdata(op), getdata(val)) +end + # Chain rules integration function BatchedRoutines.batched_jacobian( ad, f::F, x::AbstractMatrix{<:ReverseDiff.TrackedReal}) where {F}