This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Fix ForwardDiff over FiniteDiff and ReverseDiff
avik-pal committed Mar 15, 2024
1 parent 9867ee1 commit 16513d2
Showing 8 changed files with 61 additions and 17 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -60,6 +60,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
@@ -71,4 +72,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Zygote"]
test = ["Aqua", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Zygote"]
3 changes: 2 additions & 1 deletion ext/BatchedRoutinesFiniteDiffExt.jl
@@ -32,7 +32,8 @@ end

# NOTE: This doesn't exploit batching
@inline function BatchedRoutines._batched_gradient(ad::AutoFiniteDiff, f::F, x) where {F}
return FiniteDiff.finite_difference_gradient(f, x, ad.fdjtype)
returntype = first(BatchedRoutines._resolve_gradient_type(f, f, x, Val(1)))
return FiniteDiff.finite_difference_gradient(f, x, ad.fdjtype, returntype)
end

# TODO: For the gradient call just use FiniteDiff over FiniteDiff
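The change above passes an explicit return type to FiniteDiff. A minimal sketch (not part of this commit; `g` and `p_dual` are hypothetical stand-ins) of why that matters when ForwardDiff drives the outer derivative: the closure captures Dual-typed parameters, so the finite-difference gradient entries are Duals even though `x` itself is a plain `Float64` array, and a buffer allocated with the default `returntype = eltype(x)` cannot hold them faithfully.

using FiniteDiff, ForwardDiff

g(x, p) = sum(abs2, x .* p)
x = rand(3)
p_dual = ForwardDiff.Dual.(rand(3), 1.0)   # stand-in for ForwardDiff-over-FiniteDiff

f = Base.Fix2(g, p_dual)                   # f(x) now returns a Dual

# FiniteDiff.finite_difference_gradient(f, x)   # default Float64 buffer cannot store the Dual entries
T = promote_type(eltype(x), eltype(p_dual))
FiniteDiff.finite_difference_gradient(f, x, Val(:central), T)   # Dual-typed gradient buffer, works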
21 changes: 16 additions & 5 deletions ext/BatchedRoutinesReverseDiffExt.jl
@@ -16,7 +16,18 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:ReverseDiff.Tr
Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{<:ReverseDiff.TrackedReal}})=false

function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F}
return ReverseDiff.gradient(f, u)
Base.issingletontype(f) && return ReverseDiff.gradient(f, u)

∂u = similar(u, first(BatchedRoutines._resolve_gradient_type(f, f, u, Val(1))))
fill!(∂u, false)

tape = ReverseDiff.InstructionTape()
u_tracked = ReverseDiff.TrackedArray(u, ∂u, tape)
y_tracked = f(u_tracked)
y_tracked.deriv = true
ReverseDiff.reverse_pass!(tape)

return ∂u
end
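For reference, a standalone sketch of the manual-tape pattern used above, on plain `Float64` inputs; this is not the package's API, only the same `ReverseDiff` calls that appear in the diff. The point of allocating `∂u` explicitly is that the caller can widen its element type (e.g. to a `ForwardDiff.Dual`) when the closure captures Dual parameters, which `ReverseDiff.gradient` alone does not allow.

using ReverseDiff

f = u -> sum(abs2.(u))
u = rand(3)

∂u = fill!(similar(u), false)                     # adjoint buffer; element type chosen by the caller
tape = ReverseDiff.InstructionTape()
u_tracked = ReverseDiff.TrackedArray(u, ∂u, tape)

y = f(u_tracked)                                  # records the operations onto `tape`
y.deriv = true                                    # seed dL/dy = 1
ReverseDiff.reverse_pass!(tape)                   # accumulates the gradient into ∂u

∂u ≈ ReverseDiff.gradient(f, u)                   # true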

# Chain rules integration
@@ -49,10 +60,10 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x, p::ReverseDiff.TrackedArray)
ad, f, x::AbstractArray, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray, p)
ad, f, x::ReverseDiff.TrackedArray, p::AbstractArray)

function BatchedRoutines.batched_gradient(
ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
@@ -83,9 +94,9 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x, p::ReverseDiff.TrackedArray)
ad, f, x::AbstractArray, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x::ReverseDiff.TrackedArray, p)
ad, f, x::ReverseDiff.TrackedArray, p::AbstractArray)

end
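The signature changes above tighten the previously untyped `x`/`p` arguments to `::AbstractArray`. As a general Julia point (a hypothetical toy, with `Int` standing in for `TrackedArray`), overlapping methods that leave one argument untyped are easy to make ambiguous, and keeping a fully specific method alongside the annotated ones is what resolves dispatch cleanly:

h(x, p::Int) = :p_tracked
h(x::Int, p) = :x_tracked
# h(1, 2)                        # MethodError: ambiguous with only the two methods above
h(x::Int, p::Int) = :both_tracked
h(1, 2)                          # :both_tracked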
2 changes: 1 addition & 1 deletion src/BatchedRoutines.jl
@@ -10,7 +10,7 @@ import PrecompileTools: @recompile_invalidations
using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig
using ConcreteStructs: @concrete
using FastClosures: @closure
using FillArrays: Fill
using FillArrays: Fill, OneElement
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
mul!, pinv
using LuxDeviceUtils: LuxDeviceUtils, get_device
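`OneElement` from FillArrays is newly imported here, presumably for seeding single directions without allocating dense basis vectors. A quick illustration of the type itself, using FillArrays' documented `OneElement(value, index, length)` constructor:

using FillArrays

e = OneElement(1.0, 2, 4)   # lazy length-4 vector: zero everywhere except 1.0 at index 2
collect(e)                  # [0.0, 1.0, 0.0, 0.0]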
5 changes: 5 additions & 0 deletions src/chainrules.jl
@@ -70,6 +70,11 @@ function CRC.rrule(::typeof(batched_gradient), ad, f::F, x, p) where {F}
throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \
of `batched_gradient`."))

if ad isa AutoForwardDiff && get_device(x) isa LuxDeviceUtils.AbstractLuxGPUDevice
@warn "`rrule` of `batched_gradient($(ad))` might fail on GPU. Consider using \
`AutoZygote` instead." maxlog=1
end

dx = batched_gradient(ad, f, x, p)
∇batched_gradient = @closure Δ -> begin
∂x = _jacobian_vector_product(
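For context on the guard added above, a small sketch (not from the commit) of its two ingredients: `get_device` classifying where an array lives, and the `maxlog=1` keyword limiting the warning to a single emission across repeated reverse passes.

using LuxDeviceUtils

x = rand(3)
get_device(x) isa LuxDeviceUtils.AbstractLuxGPUDevice   # false — a plain Array is a CPU array

for _ in 1:3
    @warn "might fail on GPU; consider `AutoZygote`" maxlog=1   # logged only once
end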
13 changes: 13 additions & 0 deletions src/helpers.jl
@@ -134,3 +134,16 @@ function _maybe_remove_chunksize(ad::AutoAllForwardDiff{CK}, x) where {CK}
(CK === nothing || CK ≤ 0 || CK ≤ length(x)) && return ad
return parameterless_type(ad)()
end

# Figure out the type of the gradient
@inline function _resolve_gradient_type(f::F, g::G, x, ::Val{depth}) where {F, G, depth}
Base.issingletontype(f) && return (eltype(x), false)
return promote_type(eltype(x), eltype(g(x))), true
end
@inline function _resolve_gradient_type(
f::Union{Base.Fix1, Base.Fix2}, g::G, x, ::Val{depth}) where {G, depth}
depth ≥ 5 && return promote_type(eltype(x), eltype(f(x))), true
T, resolved = _resolve_gradient_type(f.f, g, x, Val(depth + 1))
resolved && return T, true
return promote_type(T, eltype(f.x)), false
end
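A standalone sketch of what the `Fix1`/`Fix2` unwrapping above accomplishes (hypothetical name `resolve_eltype`, not the package's `_resolve_gradient_type`): walk into the wrapper, promote with the element type of the captured argument, and hand back an element type wide enough for Dual-carrying gradients.

using ForwardDiff

resolve_eltype(f, x) =
    Base.issingletontype(typeof(f)) ? eltype(x) : promote_type(eltype(x), eltype(f(x)))
resolve_eltype(f::Union{Base.Fix1, Base.Fix2}, x) =
    promote_type(resolve_eltype(f.f, x), eltype(f.x))

g(x, p) = sum(abs2, x .* p)
p = ForwardDiff.Dual.(rand(3), 1.0)
resolve_eltype(Base.Fix2(g, p), rand(3))   # returns the Dual element type, not Float64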
29 changes: 20 additions & 9 deletions test/autodiff_tests.jl
@@ -20,7 +20,7 @@
AutoForwardDiff(; chunksize=2), simple_batched_function, X, p)

@test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
@test Matrix(J_fwdiff)≈Matrix(J_fwdiff2)
@test Matrix(J_fwdiff) ≈ Matrix(J_fwdiff2)
end
end
end
@@ -67,22 +67,33 @@ end
Array(p))

@testset "backend: $(backend)" for backend in (
# AutoFiniteDiff(), # FIXME: FiniteDiff doesn't work well with ForwardDiff
AutoForwardDiff(), AutoForwardDiff(; chunksize=3),
# AutoReverseDiff(), # FIXME: ReverseDiff with ForwardDiff problematic
AutoZygote())
AutoFiniteDiff(), AutoForwardDiff(),
AutoForwardDiff(; chunksize=3), AutoReverseDiff(), AutoZygote())
(!(backend isa AutoZygote) && ongpu) && continue
atol = backend isa AutoFiniteDiff ? 1e-1 : 1e-3
rtol = backend isa AutoFiniteDiff ? 1e-1 : 1e-3

__f = (x, p) -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p))

gs_zyg = Zygote.gradient(__f, X, p)
gs_rdiff = ReverseDiff.gradient(__f, (Array(X), Array(p)))

@test Array(gs_fwddiff_x)≈Array(gs_zyg[1]) atol=1e-3
@test Array(gs_fwddiff_p)≈Array(gs_zyg[2]) atol=1e-3
@test Array(gs_fwddiff_x)≈Array(gs_rdiff[1]) atol=1e-3
@test Array(gs_fwddiff_p)≈Array(gs_rdiff[2]) atol=1e-3
@test Array(gs_fwddiff_x)≈Array(gs_zyg[1]) atol=atol rtol=rtol
@test Array(gs_fwddiff_p)≈Array(gs_zyg[2]) atol=atol rtol=rtol
@test Array(gs_fwddiff_x)≈Array(gs_rdiff[1]) atol=atol rtol=rtol
@test Array(gs_fwddiff_p)≈Array(gs_rdiff[2]) atol=atol rtol=rtol

__f1 = x -> sum(
abs2, batched_gradient(backend, simple_batched_function, x, p))
__f2 = x -> sum(abs2,
batched_gradient(backend, simple_batched_function, x, Array(p)))

gs_zyg_x = only(Zygote.gradient(__f1, X))
gs_rdiff_x = ReverseDiff.gradient(__f2, Array(X))

@test Array(gs_zyg_x)≈Array(gs_fwddiff_x) atol=atol rtol=rtol
@test Array(gs_rdiff_x)≈Array(gs_fwddiff_x) atol=atol rtol=rtol
end
end
end
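The looser `atol`/`rtol` for `AutoFiniteDiff` above account for finite-difference truncation error. As a reminder of how the trailing keywords on `@test ... ≈ ...` behave (a tiny self-contained check, unrelated to the package):

using Test

@test 1.0 ≈ 1.05 atol=1e-1 rtol=1e-1       # passes; the macro forwards the keywords...
isapprox(1.0, 1.05; atol=1e-1, rtol=1e-1)  # ...to isapprox, which returns true here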
2 changes: 2 additions & 0 deletions test/integration_tests.jl
@@ -0,0 +1,2 @@
@testitem "Linear Solve" setup=[SharedTestSetup] begin
end
