Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Cleanup ReverseDiff support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 13, 2024
1 parent 1b4d36b commit 2c5d820
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
31 changes: 31 additions & 0 deletions ext/BatchedRoutinesReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module BatchedRoutinesReverseDiffExt

using ADTypes: AutoReverseDiff, AutoForwardDiff
using ArrayInterface: ArrayInterface
using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
Expand All @@ -18,6 +19,36 @@ function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F}
return ReverseDiff.gradient(f, u)
end

# Chain rules integration
function BatchedRoutines.batched_jacobian(
ad, f::F, x::AbstractMatrix{<:ReverseDiff.TrackedReal}) where {F}
return BatchedRoutines.batched_jacobian(ad, f, ArrayInterface.aos_to_soa(x))
end

function BatchedRoutines.batched_jacobian(
ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
return BatchedRoutines.batched_jacobian(
ad, f, ArrayInterface.aos_to_soa(x), ArrayInterface.aos_to_soa(p))
end

function BatchedRoutines.batched_jacobian(
ad, f::F, x::AbstractArray, p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
return BatchedRoutines.batched_jacobian(ad, f, x, ArrayInterface.aos_to_soa(p))
end

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray, p)

# TODO: Fix the gradient call over ReverseDiff

@concrete struct ReverseDiffPullbackFunction <: Function
Expand Down
25 changes: 24 additions & 1 deletion src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ function batched_transpose(X::UniformBlockDiagonalMatrix)
return UniformBlockDiagonalMatrix(batched_transpose(X.data))
end

# To support ReverseDiff
Base.IndexStyle(::Type{<:UniformBlockDiagonalMatrix}) = IndexLinear()

Base.transpose(A::UniformBlockDiagonalMatrix) = batched_transpose(A)

function batched_adjoint(X::UniformBlockDiagonalMatrix)
Expand Down Expand Up @@ -87,6 +90,10 @@ Base.@propagate_inbounds function Base.getindex(
return A.data[i_, j_, k]
end

Base.@propagate_inbounds function Base.getindex(A::UniformBlockDiagonalMatrix, idx::Int)
return getindex(A, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1)
end

Base.@propagate_inbounds function Base.setindex!(
A::UniformBlockDiagonalMatrix, v, i::Int, j::Int)
i_, j_, k = _block_indices(A, i, j)
Expand All @@ -97,6 +104,11 @@ Base.@propagate_inbounds function Base.setindex!(
return v
end

Base.@propagate_inbounds function Base.setindex!(A::UniformBlockDiagonalMatrix, v, idx::Int)
@show size(A)
return setindex!(A, v, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1)
end

function _block_indices(A::UniformBlockDiagonalMatrix, i::Int, j::Int)
all((0, 0) .< (i, j) .<= size(A)) || throw(BoundsError(A, (i, j)))

Expand Down Expand Up @@ -247,6 +259,11 @@ function Base.:*(X::AbstractArray{T, 3}, Y::UniformBlockDiagonalMatrix) where {T
return UniformBlockDiagonalMatrix(batched_mul(X, Y.data))
end

function Base.:*(X::AbstractArray{T, 2}, Y::UniformBlockDiagonalMatrix) where {T}
C = reshape(X, 1, :, nbatches(X)) * Y
return dropdims(C.data; dims=1)
end

# LinearAlgebra
abstract type AbstractBatchedMatrixFactorization end

Expand All @@ -266,6 +283,12 @@ function LinearAlgebra.:\(A::AbstractBatchedMatrixFactorization, b::AbstractVect
return X
end

function LinearAlgebra.:\(A::AbstractBatchedMatrixFactorization, b::AbstractMatrix)
X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1))
LinearAlgebra.ldiv!(X, A, vec(b))
return reshape(X, :, nbatches(b))
end

struct GenericBatchedFactorization{A, F} <: AbstractBatchedMatrixFactorization
alg::A
fact::Vector{F}
Expand Down Expand Up @@ -327,7 +350,7 @@ end
function LinearAlgebra.ldiv!(A::GenericBatchedFactorization, b::AbstractMatrix)
@assert nbatches(A) == nbatches(b)
for i in 1:nbatches(A)
ldiv!(batchview(A, i), batchview(b, i))
LinearAlgebra.ldiv!(batchview(A, i), batchview(b, i))
end
return b
end
Expand Down

0 comments on commit 2c5d820

Please sign in to comment.