This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit 9867ee1

ForwardDiff nested gradient fix

avik-pal committed Mar 15, 2024
1 parent 7fa4ddd
Showing 4 changed files with 46 additions and 29 deletions.
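
For context, here is a minimal standalone illustration (not part of this commit; the function and values are made up) of what a nested ForwardDiff gradient means: an outer derivative taken through an inner ForwardDiff.gradient call, which is the kind of composition the changes below are meant to keep working.

using ForwardDiff

# Inner gradient of a scalar function; its Jacobian is the Hessian.
inner(x) = ForwardDiff.gradient(z -> sum(abs2, z), x)   # equals 2 .* x

# Forward-over-forward nesting: ForwardDiff uses distinct tags for the two
# levels, so this evaluates cleanly and returns the Hessian 2I.
ForwardDiff.jacobian(inner, [1.0, 2.0, 3.0])
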
27 changes: 18 additions & 9 deletions ext/BatchedRoutinesForwardDiffExt.jl
@@ -151,7 +151,10 @@ end
if CK === nothing || CK ≤ 0
push!(calls, :(ck = ForwardDiff.Chunk{ForwardDiff.pickchunksize(length(u))}()))
else
push!(calls, :(ck = ForwardDiff.Chunk{CK}()))
push!(calls, quote
@assert CK ≤ length(u) "Chunk size must be ≤ the length of u"
ck = ForwardDiff.Chunk{CK}()
end)
end
push!(calls, :(return _forwarddiff_gradient(f, u, typeof(tag), ck)))
return Expr(:block, calls...)
@@ -165,22 +168,25 @@ function _forwarddiff_gradient(f::F, u::AbstractArray{T}, ::Type{Tag},
Dual = ForwardDiff.Dual{Tag, T, CK}
Partials = ForwardDiff.Partials{CK, T}

gs = similar(u)
for i in 1:nchunks
_forwarddiff_gradient!(gs, (i - 1) * CK + 1, ck, Tag, Dual, Partials, f, u)
gs_first = _forwarddiff_gradient!!(nothing, 1, ck, Tag, Dual, Partials, f, u)
gs_ = similar(u, eltype(gs_first), size(u))
gs = vec(gs_)
gs[1:CK] .= gs_first
for i in 2:nchunks
_forwarddiff_gradient!!(gs, (i - 1) * CK + 1, ck, Tag, Dual, Partials, f, u)
end

if remainder > 0
Dual_rem = ForwardDiff.Dual{Tag, T, remainder}
Partials_rem = ForwardDiff.Partials{remainder, T}
_forwarddiff_gradient!(gs, nchunks * CK + 1, ForwardDiff.Chunk{remainder}(),
_forwarddiff_gradient!!(gs, nchunks * CK + 1, ForwardDiff.Chunk{remainder}(),
Tag, Dual_rem, Partials_rem, f, u)
end

return gs
return gs_
end

@views function _forwarddiff_gradient!(
@views function _forwarddiff_gradient!!(
gs, idx::Int, ::ForwardDiff.Chunk{CK}, ::Type{Tag}, ::Type{Dual},
::Type{Partials}, f::F, u::AbstractArray{T}) where {CK, Tag, Dual, Partials, F, T}
N = length(u)
@@ -210,6 +216,7 @@ end
u_duals = reshape(vcat(u_part_prev, u_part_duals, u_part_next), size(u))
y_duals = f(u_duals)

gs === nothing && return ForwardDiff.partials(y_duals)
gs[idxs] .= ForwardDiff.partials(y_duals)
return
end
@@ -220,7 +227,8 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{
function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, u) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
dev = get_device(x)
partials = ForwardDiff.Partials{1, T}.(tuple.(u)) |> dev
x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
y_dual = f(x_dual)
return ForwardDiff.partials.(y_dual, 1)
Expand All @@ -230,7 +238,8 @@ function BatchedRoutines._jacobian_vector_product(
ad::AutoForwardDiff, f::F, x, u, p) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
dev = get_device(x)
partials = ForwardDiff.Partials{1, T}.(tuple.(u)) |> dev
x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
y_dual = f(x_dual, p)
return ForwardDiff.partials.(y_dual, 1)
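
Roughly, the chunked gradient path above seeds CK partials at a time and calls f once per chunk, with a smaller remainder chunk at the end. A standalone sketch of that pattern (independent of the package; the helper name and example function are made up), using only public ForwardDiff constructors:

using ForwardDiff

function chunked_gradient(f, u::AbstractVector{T}, chunk::Int=2) where {T}
    N = length(u)
    gs = similar(u)
    idx = 1
    while idx <= N
        ck = min(chunk, N - idx + 1)   # last chunk may be the remainder
        duals = map(1:N) do i
            # Identity seeds for entries in the current chunk, zeros elsewhere.
            seeds = ntuple(j -> T(i == idx + j - 1), ck)
            ForwardDiff.Dual{Nothing}(u[i], ForwardDiff.Partials(seeds))
        end
        y = f(duals)                    # one evaluation per chunk
        gs[idx:(idx + ck - 1)] .= ForwardDiff.partials(y)
        idx += ck
    end
    return gs
end

chunked_gradient(x -> sum(abs2, x), [1.0, 2.0, 3.0])   # ≈ [2.0, 4.0, 6.0]
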
9 changes: 6 additions & 3 deletions src/chainrules.jl
@@ -75,9 +75,12 @@ function CRC.rrule(::typeof(batched_gradient), ad, f::F, x, p) where {F}
∂x = _jacobian_vector_product(
AutoForwardDiff(), @closure(x->batched_gradient(ad, Base.Fix2(f, p), x)),
x, reshape(Δ, size(x)))
∂p = _jacobian_vector_product(
AutoForwardDiff(), @closure((x, p)->batched_gradient(ad, Base.Fix1(f, x), p)),
x, reshape(Δ, size(x)), p)
∂p = _jacobian_vector_product(AutoForwardDiff(),
@closure((x, p)->batched_gradient(
_maybe_remove_chunksize(ad, p), Base.Fix1(f, x), p)),
x,
reshape(Δ, size(x)),
p)
return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
end
return dx, ∇batched_gradient
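
The pullback above obtains ∂x and ∂p by pushing the cotangent through _jacobian_vector_product, i.e. a forward-mode sweep over the gradient (forward-over-reverse). A rough standalone sketch of the Dual-with-one-partial JVP trick it relies on (the function and names are made up; the package version additionally handles tags, parameters, and devices):

using ForwardDiff

function jvp(f, x::AbstractArray{T}, u::AbstractArray{T}) where {T}
    # Attach the tangent u as the single partial of each entry of x,
    # evaluate f once, and read J(x) * u off the output partials.
    partials = ForwardDiff.Partials{1, T}.(tuple.(u))
    x_dual = ForwardDiff.Dual{Nothing, T, 1}.(x, partials)
    y_dual = f(x_dual)
    return ForwardDiff.partials.(y_dual, 1)
end

jvp(x -> x .^ 3, [1.0, 2.0], [1.0, 1.0])   # ≈ [3.0, 12.0], since the Jacobian is Diagonal(3 .* x .^ 2)
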
7 changes: 7 additions & 0 deletions src/helpers.jl
@@ -127,3 +127,10 @@ _assert_loaded_backend(::AutoFiniteDiff) = @assert _is_extension_loaded(Val(:FiniteDiff))
_assert_loaded_backend(::AutoZygote) = @assert _is_extension_loaded(Val(:Zygote))

CRC.@non_differentiable _assert_loaded_backend(::Any...)

# Chunksize remove
_maybe_remove_chunksize(ad, x) = ad
function _maybe_remove_chunksize(ad::AutoAllForwardDiff{CK}, x) where {CK}
(CK === nothing || CK ≤ 0 || CK ≤ length(x)) && return ad
return parameterless_type(ad)()
end
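
A small standalone sketch of what the new helper guards against (names here are made up; the real helper also covers the other ForwardDiff-backed AD types): a fixed chunk size only makes sense when it does not exceed the number of entries being differentiated, so it is dropped otherwise.

using ADTypes: AutoForwardDiff

maybe_remove_chunksize(ad, x) = ad
function maybe_remove_chunksize(ad::AutoForwardDiff{CK}, x) where {CK}
    # Keep the configured chunk size unless it exceeds length(x).
    (CK === nothing || CK ≤ 0 || CK ≤ length(x)) && return ad
    return AutoForwardDiff()   # fall back to an automatically picked chunk size
end

maybe_remove_chunksize(AutoForwardDiff(; chunksize=4), randn(2))    # chunk size dropped
maybe_remove_chunksize(AutoForwardDiff(; chunksize=4), randn(16))   # chunk size kept
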
32 changes: 15 additions & 17 deletions test/autodiff_tests.jl
@@ -16,8 +16,11 @@
J_fdiff = batched_jacobian(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)
J_fwdiff2 = batched_jacobian(
AutoForwardDiff(; chunksize=2), simple_batched_function, X, p)

@test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
@test Matrix(J_fwdiff)≈Matrix(J_fwdiff2)
end
end
end
@@ -63,23 +66,18 @@ end
AutoZygote(), simple_batched_function, Array(X), p)),
Array(p))

for backend in (
# AutoFiniteDiff(), # FIXME: FiniteDiff doesn't work well with ForwardDiff
# AutoForwardDiff(), # FIXME: The return type doesn't match
# AutoReverseDiff(), # FIXME: ReverseDiff with ForwardDiff problematic
AutoZygote(),)
arrType = backend isa AutoFiniteDiff || backend isa AutoReverseDiff ?
Array : identity

gs_zyg = Zygote.gradient(
(x, p) -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p)),
arrType(X),
arrType(p))
gs_rdiff = ReverseDiff.gradient(
(x, p) -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p)),
(Array(X), Array(p)))
@testset "backend: $(backend)" for backend in (
# AutoFiniteDiff(), # FIXME: FiniteDiff doesn't work well with ForwardDiff
AutoForwardDiff(), AutoForwardDiff(; chunksize=3),
# AutoReverseDiff(), # FIXME: ReverseDiff with ForwardDiff problematic
AutoZygote())
(!(backend isa AutoZygote) && ongpu) && continue

__f = (x, p) -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p))

gs_zyg = Zygote.gradient(__f, X, p)
gs_rdiff = ReverseDiff.gradient(__f, (Array(X), Array(p)))

@test Array(gs_fwddiff_x)≈Array(gs_zyg[1]) atol=1e-3
@test Array(gs_fwddiff_p)≈Array(gs_zyg[2]) atol=1e-3
