Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Special case for handling CAs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 1, 2024
1 parent 92e957b commit 256f0cc
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -26,6 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BatchedRoutinesCUDAExt = ["CUDA"]
BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"]
BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"]
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
Expand Down
12 changes: 12 additions & 0 deletions ext/BatchedRoutinesComponentArraysForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module BatchedRoutinesComponentArraysForwardDiffExt

using BatchedRoutines: BatchedRoutines
using ComponentArrays: ComponentArrays, ComponentArray
using ForwardDiff: ForwardDiff

@inline function BatchedRoutines._restructure(y, x::ComponentArray)
x_data = ComponentArrays.getdata(x)
return ComponentArray(reshape(y, size(x_data)), ComponentArrays.getaxes(x))

Check warning on line 9 in ext/BatchedRoutinesComponentArraysForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/BatchedRoutinesComponentArraysForwardDiffExt.jl#L7-L9

Added lines #L7 - L9 were not covered by tests
end

end
8 changes: 4 additions & 4 deletions ext/BatchedRoutinesForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ end
u_part_next = Dual.(u[idxs_next], dev(Partials.(map(nt, 1:length(idxs_next)))))
end

u_duals = reshape(vcat(u_part_prev, u_part_duals, u_part_next), size(u))
u_duals = BatchedRoutines._restructure(vcat(u_part_prev, u_part_duals, u_part_next), u)
y_duals = f(u_duals)

gs === nothing && return ForwardDiff.partials(y_duals)
Expand All @@ -224,20 +224,20 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{

function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, u) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
x_dual = _construct_jvp_duals(Tag, x, u)
x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual)
return ForwardDiff.partials.(y_dual, 1)
end

function BatchedRoutines._jacobian_vector_product(
ad::AutoForwardDiff, f::F, x, u, p) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
x_dual = _construct_jvp_duals(Tag, x, u)
x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual, p)
return ForwardDiff.partials.(y_dual, 1)
end

@inline function _construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
@inline function BatchedRoutines._construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
Expand Down
3 changes: 3 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ function _jacobian_vector_product end
function _vector_jacobian_product end
function _batched_jacobian end
function _batched_gradient end
function _construct_jvp_duals end

@inline _restructure(y, x) = reshape(y, size(x))

# Test Loaded AD Backend
_assert_loaded_backend(::AutoForwardDiff) = @assert _is_extension_loaded(Val(:ForwardDiff))
Expand Down
2 changes: 1 addition & 1 deletion src/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
end

@inline function Base.:*(x::AbstractMatrix, op::UniformBlockDiagonalOperator)
return (reshape(x, :, 1, nbatches(x)) * op) |> (dropdims$(; dims=2))
return dropdims(reshape(x, :, 1, nbatches(x)) * op; dims=1)

Check warning on line 54 in src/operator.jl

View check run for this annotation

Codecov / codecov/patch

src/operator.jl#L53-L54

Added lines #L53 - L54 were not covered by tests
end

@inline function Base.:*(x::AbstractArray{T, 3}, op::UniformBlockDiagonalOperator) where {T}
Expand Down

0 comments on commit 256f0cc

Please sign in to comment.