From 2c5d820f8e52cef18abff07a580dcb977da88cf0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Mar 2024 18:20:33 -0400 Subject: [PATCH] Cleanup ReverseDiff support --- ext/BatchedRoutinesReverseDiffExt.jl | 31 ++++++++++++++++++++++++++++ src/matrix.jl | 25 +++++++++++++++++++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/ext/BatchedRoutinesReverseDiffExt.jl b/ext/BatchedRoutinesReverseDiffExt.jl index ce9b3da..ad8de30 100644 --- a/ext/BatchedRoutinesReverseDiffExt.jl +++ b/ext/BatchedRoutinesReverseDiffExt.jl @@ -1,6 +1,7 @@ module BatchedRoutinesReverseDiffExt using ADTypes: AutoReverseDiff, AutoForwardDiff +using ArrayInterface: ArrayInterface using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type using ChainRulesCore: ChainRulesCore, NoTangent using ConcreteStructs: @concrete @@ -18,6 +19,36 @@ function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F} return ReverseDiff.gradient(f, u) end +# Chain rules integration +function BatchedRoutines.batched_jacobian( + ad, f::F, x::AbstractMatrix{<:ReverseDiff.TrackedReal}) where {F} + return BatchedRoutines.batched_jacobian(ad, f, ArrayInterface.aos_to_soa(x)) +end + +function BatchedRoutines.batched_jacobian( + ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F} + return BatchedRoutines.batched_jacobian( + ad, f, ArrayInterface.aos_to_soa(x), ArrayInterface.aos_to_soa(p)) +end + +function BatchedRoutines.batched_jacobian( + ad, f::F, x::AbstractArray, p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F} + return BatchedRoutines.batched_jacobian(ad, f, x, ArrayInterface.aos_to_soa(p)) +end + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( + ad, f, x::ReverseDiff.TrackedArray) + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( + ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray) + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( + ad, f, x, p::ReverseDiff.TrackedArray) + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( + ad, f, x::ReverseDiff.TrackedArray, p) + # TODO: Fix the gradient call over ReverseDiff @concrete struct ReverseDiffPullbackFunction <: Function diff --git a/src/matrix.jl b/src/matrix.jl index 511959b..edb0618 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -18,6 +18,9 @@ function batched_transpose(X::UniformBlockDiagonalMatrix) return UniformBlockDiagonalMatrix(batched_transpose(X.data)) end +# To support ReverseDiff +Base.IndexStyle(::Type{<:UniformBlockDiagonalMatrix}) = IndexLinear() + Base.transpose(A::UniformBlockDiagonalMatrix) = batched_transpose(A) function batched_adjoint(X::UniformBlockDiagonalMatrix) @@ -87,6 +90,10 @@ Base.@propagate_inbounds function Base.getindex( return A.data[i_, j_, k] end +Base.@propagate_inbounds function Base.getindex(A::UniformBlockDiagonalMatrix, idx::Int) + return getindex(A, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1) +end + Base.@propagate_inbounds function Base.setindex!( A::UniformBlockDiagonalMatrix, v, i::Int, j::Int) i_, j_, k = _block_indices(A, i, j) @@ -97,6 +104,11 @@ Base.@propagate_inbounds function Base.setindex!( return v end +Base.@propagate_inbounds function Base.setindex!(A::UniformBlockDiagonalMatrix, v, idx::Int) + @show size(A) + return setindex!(A, v, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1) +end + function _block_indices(A::UniformBlockDiagonalMatrix, i::Int, j::Int) all((0, 0) .< (i, j) .<= size(A)) || throw(BoundsError(A, (i, j))) @@ -247,6 +259,11 @@ function Base.:*(X::AbstractArray{T, 3}, Y::UniformBlockDiagonalMatrix) where {T return UniformBlockDiagonalMatrix(batched_mul(X, Y.data)) end +function Base.:*(X::AbstractArray{T, 2}, Y::UniformBlockDiagonalMatrix) where {T} + C = reshape(X, 1, :, nbatches(X)) * Y + return dropdims(C.data; dims=1) +end + # LinearAlgebra abstract type AbstractBatchedMatrixFactorization end @@ -266,6 +283,12 @@ function LinearAlgebra.:\(A::AbstractBatchedMatrixFactorization, b::AbstractVect return X end +function LinearAlgebra.:\(A::AbstractBatchedMatrixFactorization, b::AbstractMatrix) + X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1)) + LinearAlgebra.ldiv!(X, A, vec(b)) + return reshape(X, :, nbatches(b)) +end + struct GenericBatchedFactorization{A, F} <: AbstractBatchedMatrixFactorization alg::A fact::Vector{F} @@ -327,7 +350,7 @@ end function LinearAlgebra.ldiv!(A::GenericBatchedFactorization, b::AbstractMatrix) @assert nbatches(A) == nbatches(b) for i in 1:nbatches(A) - ldiv!(batchview(A, i), batchview(b, i)) + LinearAlgebra.ldiv!(batchview(A, i), batchview(b, i)) end return b end