diff --git a/Project.toml b/Project.toml
index 0ed8688..da3fc2c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -60,6 +60,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
 FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
@@ -71,4 +72,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Aqua", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Zygote"]
+test = ["Aqua", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Zygote"]
diff --git a/ext/BatchedRoutinesFiniteDiffExt.jl b/ext/BatchedRoutinesFiniteDiffExt.jl
index ba3edff..6b0fe49 100644
--- a/ext/BatchedRoutinesFiniteDiffExt.jl
+++ b/ext/BatchedRoutinesFiniteDiffExt.jl
@@ -32,7 +32,8 @@ end
 
 # NOTE: This doesn't exploit batching
 @inline function BatchedRoutines._batched_gradient(ad::AutoFiniteDiff, f::F, x) where {F}
-    return FiniteDiff.finite_difference_gradient(f, x, ad.fdjtype)
+    returntype = first(BatchedRoutines._resolve_gradient_type(f, f, x, Val(1)))
+    return FiniteDiff.finite_difference_gradient(f, x, ad.fdjtype, returntype)
 end
 
 # TODO: For the gradient call just use FiniteDiff over FiniteDiff
diff --git a/ext/BatchedRoutinesReverseDiffExt.jl b/ext/BatchedRoutinesReverseDiffExt.jl
index 2059813..6c0678c 100644
--- a/ext/BatchedRoutinesReverseDiffExt.jl
+++ b/ext/BatchedRoutinesReverseDiffExt.jl
@@ -16,7 +16,18 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:ReverseDiff.Tr
 Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{<:ReverseDiff.TrackedReal}})=false
 
 function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F}
-    return ReverseDiff.gradient(f, u)
+    Base.issingletontype(f) && return ReverseDiff.gradient(f, u)
+
+    ∂u = similar(u, first(BatchedRoutines._resolve_gradient_type(f, f, u, Val(1))))
+    fill!(∂u, false)
+
+    tape = ReverseDiff.InstructionTape()
+    u_tracked = ReverseDiff.TrackedArray(u, ∂u, tape)
+    y_tracked = f(u_tracked)
+    y_tracked.deriv = true
+    ReverseDiff.reverse_pass!(tape)
+
+    return ∂u
 end
 
 # Chain rules integration
@@ -49,10 +60,10 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
     ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray)
 
 ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
-    ad, f, x, p::ReverseDiff.TrackedArray)
+    ad, f, x::AbstractArray, p::ReverseDiff.TrackedArray)
 
 ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
-    ad, f, x::ReverseDiff.TrackedArray, p)
+    ad, f, x::ReverseDiff.TrackedArray, p::AbstractArray)
 
 function BatchedRoutines.batched_gradient(
         ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
@@ -83,9 +94,9 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
     ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray)
 
 ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
-    ad, f, x, p::ReverseDiff.TrackedArray)
+    ad, f, x::AbstractArray, p::ReverseDiff.TrackedArray)
 
 ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
-    ad, f, x::ReverseDiff.TrackedArray, p)
+    ad, f, x::ReverseDiff.TrackedArray, p::AbstractArray)
 
 end
diff --git a/src/BatchedRoutines.jl b/src/BatchedRoutines.jl
index 44603af..3e10f60 100644
--- a/src/BatchedRoutines.jl
+++ b/src/BatchedRoutines.jl
@@ -10,7 +10,7 @@ import PrecompileTools: @recompile_invalidations
     using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig
     using ConcreteStructs: @concrete
     using FastClosures: @closure
-    using FillArrays: Fill
+    using FillArrays: Fill, OneElement
     using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
         mul!, pinv
    using LuxDeviceUtils: LuxDeviceUtils, get_device
diff --git a/src/chainrules.jl b/src/chainrules.jl
index f6ea0e4..3cb20f9 100644
--- a/src/chainrules.jl
+++ b/src/chainrules.jl
@@ -70,6 +70,11 @@ function CRC.rrule(::typeof(batched_gradient), ad, f::F, x, p) where {F}
         throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \
                              of `batched_gradient`."))
 
+    if ad isa AutoForwardDiff && get_device(x) isa LuxDeviceUtils.AbstractLuxGPUDevice
+        @warn "`rrule` of `batched_gradient($(ad))` might fail on GPU. Consider using \
+               `AutoZygote` instead." maxlog=1
+    end
+
     dx = batched_gradient(ad, f, x, p)
     ∇batched_gradient = @closure Δ -> begin
         ∂x = _jacobian_vector_product(
diff --git a/src/helpers.jl b/src/helpers.jl
index 4999a77..c7b2217 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -134,3 +134,16 @@ function _maybe_remove_chunksize(ad::AutoAllForwardDiff{CK}, x) where {CK}
     (CK === nothing || CK ≤ 0 || CK ≤ length(x)) && return ad
     return parameterless_type(ad)()
 end
+
+# Figure out the type of the gradient
+@inline function _resolve_gradient_type(f::F, g::G, x, ::Val{depth}) where {F, G, depth}
+    Base.issingletontype(f) && return (eltype(x), false)
+    return promote_type(eltype(x), eltype(g(x))), true
+end
+@inline function _resolve_gradient_type(
+        f::Union{Base.Fix1, Base.Fix2}, g::G, x, ::Val{depth}) where {G, depth}
+    depth ≥ 5 && return promote_type(eltype(x), eltype(f(x))), true
+    T, resolved = _resolve_gradient_type(f.f, g, x, Val(depth + 1))
+    resolved && return T, true
+    return promote_type(T, eltype(f.x)), false
+end
diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl
index cec62df..2a13ce0 100644
--- a/test/autodiff_tests.jl
+++ b/test/autodiff_tests.jl
@@ -20,7 +20,7 @@
                 AutoForwardDiff(; chunksize=2), simple_batched_function, X, p)
 
             @test Matrix(J_fdiff)≈Matrix(J_fwdiff) atol=1e-3
-            @test Matrix(J_fwdiff)≈Matrix(J_fwdiff2)
+            @test Matrix(J_fwdiff) ≈ Matrix(J_fwdiff2)
         end
     end
 end
@@ -67,11 +67,11 @@ end
             Array(p))
 
        @testset "backend: $(backend)" for backend in (
-            # AutoFiniteDiff(), # FIXME: FiniteDiff doesn't work well with ForwardDiff
-            AutoForwardDiff(), AutoForwardDiff(; chunksize=3),
-            # AutoReverseDiff(), # FIXME: ReverseDiff with ForwardDiff problematic
-            AutoZygote())
+            AutoFiniteDiff(), AutoForwardDiff(),
+            AutoForwardDiff(; chunksize=3), AutoReverseDiff(), AutoZygote())
             (!(backend isa AutoZygote) && ongpu) && continue
+            atol = backend isa AutoFiniteDiff ? 1e-1 : 1e-3
+            rtol = backend isa AutoFiniteDiff ? 1e-1 : 1e-3
 
             __f = (x, p) -> sum(
                 abs2, batched_gradient(backend, simple_batched_function, x, p))
@@ -79,10 +79,21 @@
             gs_zyg = Zygote.gradient(__f, X, p)
             gs_rdiff = ReverseDiff.gradient(__f, (Array(X), Array(p)))
 
-            @test Array(gs_fwddiff_x)≈Array(gs_zyg[1]) atol=1e-3
-            @test Array(gs_fwddiff_p)≈Array(gs_zyg[2]) atol=1e-3
-            @test Array(gs_fwddiff_x)≈Array(gs_rdiff[1]) atol=1e-3
-            @test Array(gs_fwddiff_p)≈Array(gs_rdiff[2]) atol=1e-3
+            @test Array(gs_fwddiff_x)≈Array(gs_zyg[1]) atol=atol rtol=rtol
+            @test Array(gs_fwddiff_p)≈Array(gs_zyg[2]) atol=atol rtol=rtol
+            @test Array(gs_fwddiff_x)≈Array(gs_rdiff[1]) atol=atol rtol=rtol
+            @test Array(gs_fwddiff_p)≈Array(gs_rdiff[2]) atol=atol rtol=rtol
+
+            __f1 = x -> sum(
+                abs2, batched_gradient(backend, simple_batched_function, x, p))
+            __f2 = x -> sum(abs2,
+                batched_gradient(backend, simple_batched_function, x, Array(p)))
+
+            gs_zyg_x = only(Zygote.gradient(__f1, X))
+            gs_rdiff_x = ReverseDiff.gradient(__f2, Array(X))
+
+            @test Array(gs_zyg_x)≈Array(gs_fwddiff_x) atol=atol rtol=rtol
+            @test Array(gs_rdiff_x)≈Array(gs_fwddiff_x) atol=atol rtol=rtol
         end
     end
 end
diff --git a/test/integration_tests.jl b/test/integration_tests.jl
new file mode 100644
index 0000000..57c7bf1
--- /dev/null
+++ b/test/integration_tests.jl
@@ -0,0 +1,2 @@
+@testitem "Linear Solve" setup=[SharedTestSetup] begin
+end
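Note (not part of the patch): the new `_batched_gradient(::AutoReverseDiff, f, u)` method records an `InstructionTape` itself and allocates the gradient buffer with the eltype returned by the new `_resolve_gradient_type` helper rather than `eltype(u)`, so closures such as `Base.Fix1`/`Base.Fix2` wrappers whose captured parameters have a wider eltype than `u` get a correctly typed gradient. A minimal usage sketch of the case this targets; the `loss` function and the mixed Float32/Float64 eltypes here are illustrative assumptions, not something taken from the patch:

    using BatchedRoutines, ReverseDiff
    using ADTypes: AutoReverseDiff

    # Hypothetical scalar-valued loss; fixing `p` via `Base.Fix2` means `typeof(f)`
    # is not a singleton, so the tape-based path in the extension is exercised.
    loss(x, p) = sum(abs2, p .* x)

    x = rand(Float32, 4, 3)   # batched input: 4 states, batch of 3
    p = rand(Float64, 4)      # parameters captured by the closure

    # `_resolve_gradient_type` should promote the gradient eltype to Float64 here,
    # instead of defaulting to eltype(x) == Float32.
    ∂x = batched_gradient(AutoReverseDiff(), Base.Fix2(loss, p), x)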
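The new test/integration_tests.jl is added as an empty stub, presumably to be filled in alongside the LinearSolve test dependency added in Project.toml. Purely as a hedged sketch of one direction it could take; the function, the sizes, and the densifying `Matrix(J)` call are assumptions of this example, not part of the patch:

    @testitem "Linear Solve" setup=[SharedTestSetup] begin
        using ForwardDiff, LinearSolve

        f(x, p) = p .* exp.(x)   # illustrative map with a nonsingular batched Jacobian
        x = rand(3, 4)           # 3 states, batch of 4
        p = rand(3) .+ 1

        J = batched_jacobian(AutoForwardDiff(), f, x, p)
        A = Matrix(J)            # densified, mirroring how autodiff_tests.jl uses `Matrix`

        # Solve A * u = b and check the residual.
        b = rand(size(A, 1))
        sol = solve(LinearProblem(A, b))
        @test A * sol.u ≈ b
    end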