Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
ForwardDiff GPU gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 15, 2024
1 parent 951f035 commit a57e936
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 6 deletions.
71 changes: 65 additions & 6 deletions ext/BatchedRoutinesForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,75 @@ end
return last(BatchedRoutines.__batched_value_and_jacobian(ad, f, u))
end

# We don't use ForwardDiff.gradient since it causes GPU compilation errors due to
# scalar indexing; instead we pick a chunk size at compile time and dispatch to a
# manually-seeded chunked gradient (`_forwarddiff_gradient`).
@generated function BatchedRoutines._batched_gradient(
        ad::AutoForwardDiff{CK}, f::F, u) where {F, CK}
    # Build the body as a sequence of expressions: tag selection, chunk selection,
    # then the call into the runtime gradient kernel.
    calls = [:(tag = ad.tag === nothing ? ForwardDiff.Tag{F, eltype(u)}() : ad.tag)]
    if CK === nothing || CK ≤ 0
        # No (valid) user-specified chunk size: derive one from the input length.
        push!(calls, :(ck = ForwardDiff.Chunk{ForwardDiff.pickchunksize(length(u))}()))
    else
        push!(calls, :(ck = ForwardDiff.Chunk{CK}()))
    end
    push!(calls, :(return _forwarddiff_gradient(f, u, typeof(tag), ck)))
    return Expr(:block, calls...)
end

# Chunk-by-chunk forward-mode gradient of a scalar-valued `f` at `u`, seeding
# `ForwardDiff.Dual`s manually (avoids `ForwardDiff.gradient`'s config path,
# which scalar-indexes and therefore fails to compile on GPU arrays).
# Returns an array of the same shape as `u` holding ∂f/∂u.
function _forwarddiff_gradient(f::F, u::AbstractArray{T}, ::Type{Tag},
        ck::ForwardDiff.Chunk{CK}) where {F, T, Tag, CK}
    L = length(u)
    nchunks, remainder = divrem(L, CK)

    Dual = ForwardDiff.Dual{Tag, T, CK}
    Partials = ForwardDiff.Partials{CK, T}

    gs = similar(u)
    # Full-size chunks: each call fills CK consecutive gradient entries.
    for i in 1:nchunks
        _forwarddiff_gradient!(gs, (i - 1) * CK + 1, ck, Tag, Dual, Partials, f, u)
    end

    # Trailing partial chunk when CK does not divide length(u): use a narrower
    # Dual/Partials type sized to the remainder.
    if remainder > 0
        Dual_rem = ForwardDiff.Dual{Tag, T, remainder}
        Partials_rem = ForwardDiff.Partials{remainder, T}
        _forwarddiff_gradient!(gs, nchunks * CK + 1, ForwardDiff.Chunk{remainder}(),
            Tag, Dual_rem, Partials_rem, f, u)
    end

    return gs
end

# Seed one chunk of dual numbers into `u` (positions `idx:idx+CK-1`), evaluate `f`
# once on the fully dual-typed input, and write the extracted partials into
# `gs[idxs]`. Entries outside the active chunk carry zero partials so they
# contribute no derivative. Mutates `gs` in place; returns `nothing`.
@views function _forwarddiff_gradient!(
gs, idx::Int, ::ForwardDiff.Chunk{CK}, ::Type{Tag}, ::Type{Dual},
::Type{Partials}, f::F, u::AbstractArray{T}) where {CK, Tag, Dual, Partials, F, T}
N = length(u)
# Active chunk indices (clamped at the array end) plus the untouched
# prefix/suffix ranges on either side.
idxs = idx:min(idx + CK - 1, N)
idxs_prev = 1:(idx - 1)
idxs_next = (idx + CK):N

# NOTE(review): `get_device` presumably returns a device functor (LuxDeviceUtils
# style) so CPU-built partials can be moved to match `u`'s device — confirm.
dev = get_device(u)

# Identity seeding: the j-th dual in the chunk gets a one-hot partials vector,
# so its partial w.r.t. itself is 1 and 0 w.r.t. every other chunk entry.
partials = dev(map(𝒾 -> Partials(ntuple(𝒿 -> ifelse(𝒾 == 𝒿, oneunit(T), zero(T)), CK)),
1:length(idxs)))
u_part_duals = Dual.(u[idxs], partials)

# All-zero partials template for entries outside the active chunk.
nt = Returns(ntuple(Returns(zero(T)), CK))
if length(idxs_prev) == 0
# Empty prefix: keep an empty array of the same (device) dual type for vcat.
u_part_prev = similar(u_part_duals, 0)
else
u_part_prev = Dual.(u[idxs_prev], dev(Partials.(map(nt, 1:length(idxs_prev)))))
end

if length(idxs_next) == 0
u_part_next = similar(u_part_duals, 0)
else
u_part_next = Dual.(u[idxs_next], dev(Partials.(map(nt, 1:length(idxs_next)))))
end

# Reassemble in original order and shape; single evaluation of `f`. The
# partials of the (scalar-like) result are the gradient entries for `idxs`.
u_duals = reshape(vcat(u_part_prev, u_part_duals, u_part_next), size(u))
y_duals = f(u_duals)

gs[idxs] .= ForwardDiff.partials(y_duals)
return
end

# helpers.jl
Expand Down
56 changes: 56 additions & 0 deletions test/autodiff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,59 @@
@test Matrix(J_fdiff)Matrix(J_fwdiff) atol=1e-3
end
end

@testitem "Gradient" setup=[SharedTestSetup] begin
    using FiniteDiff, ForwardDiff, ReverseDiff, Zygote

    rng = get_stable_rng(1001)

    @testset "$mode" for (mode, aType, device, ongpu) in MODES
        # Scalar objective mixing abs2/abs reductions so the gradient is
        # nontrivial in both X and p.
        simple_batched_function = function (X, p)
            X_ = reshape(X, :, nbatches(X))
            return sum(
                abs2, sum(abs2, X_ .* p; dims=1) .- sum(abs, X_ .* p; dims=1) .+ p .^ 2)
        end

        # 3D batched input: gradients from all four backends must agree.
        # (FiniteDiff/ReverseDiff run on CPU copies; ForwardDiff/Zygote on-device.)
        X = randn(rng, 3, 2, 4) |> aType
        p = randn(rng, 6) |> aType

        gs_fdiff = batched_gradient(
            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
        gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
        gs_rdiff = batched_gradient(
            AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
        gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)

        @test Array(gs_fdiff)≈Array(gs_fwdiff) atol=1e-3
        @test Array(gs_fwdiff)≈Array(gs_rdiff) atol=1e-3
        @test Array(gs_rdiff)≈Array(gs_zygote) atol=1e-3

        # 2D input.
        X = randn(rng, 2, 4) |> aType
        p = randn(rng, 2) |> aType

        gs_fdiff = batched_gradient(
            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
        gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
        gs_rdiff = batched_gradient(
            AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
        gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)

        @test Array(gs_fdiff)≈Array(gs_fwdiff) atol=1e-3
        @test Array(gs_fwdiff)≈Array(gs_rdiff) atol=1e-3
        @test Array(gs_rdiff)≈Array(gs_zygote) atol=1e-3

        # 1D input (single batch).
        X = randn(rng, 3) |> aType
        p = randn(rng, 3) |> aType

        J_fdiff = batched_gradient(
            AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
        J_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
        J_rdiff = batched_gradient(
            AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
        J_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)

        @test Array(J_fdiff)≈Array(J_fwdiff) atol=1e-3
        @test Array(J_fwdiff)≈Array(J_rdiff) atol=1e-3
        @test Array(J_rdiff)≈Array(J_zygote) atol=1e-3
    end
end

0 comments on commit a57e936

Please sign in to comment.