This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Batched Nonlinear Solvers for Forward AD rules
- Loading branch information
Showing
10 changed files
with
203 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
module BatchedRoutinesForwardDiffExt | ||
|
||
using ADTypes: AutoForwardDiff | ||
using ArrayInterface: parameterless_type | ||
using BatchedRoutines: BatchedRoutines, AbstractBatchedNonlinearAlgorithm, | ||
UniformBlockDiagonalOperator, batched_jacobian, batched_mul, | ||
batched_pickchunksize, _assert_type | ||
using ChainRulesCore: ChainRulesCore | ||
using FastClosures: @closure | ||
using ForwardDiff: ForwardDiff, Dual | ||
using LinearAlgebra: LinearAlgebra | ||
using LuxDeviceUtils: LuxDeviceUtils, get_device | ||
using SciMLBase: SciMLBase, NonlinearProblem | ||
|
||
const CRC = ChainRulesCore | ||
|
||
@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true | ||
|
||
@inline BatchedRoutines.__can_forwarddiff_dual(::Type{T}) where {T} = ForwardDiff.can_dual(T) | ||
|
||
include("jacobian.jl") | ||
include("nonlinearsolve_ad.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
function SciMLBase.solve( | ||
prob::NonlinearProblem{<:AbstractArray, iip, | ||
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, | ||
alg::AbstractBatchedNonlinearAlgorithm, | ||
args...; | ||
kwargs...) where {T, V, P, iip} | ||
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) | ||
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) | ||
return SciMLBase.build_solution( | ||
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) | ||
end | ||
|
||
function __nlsolve_ad(prob::NonlinearProblem, alg, args...; kwargs...) | ||
p = ForwardDiff.value.(prob.p) | ||
u0 = ForwardDiff.value.(prob.u0) | ||
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...) | ||
|
||
sol = SciMLBase.solve(newprob, alg, args...; kwargs...) | ||
|
||
uu = sol.u | ||
Jₚ = ForwardDiff.jacobian(Base.Fix1(prob.f, uu), p) | ||
Jᵤ = if prob.f.jac === nothing | ||
BatchedRoutines.batched_jacobian(AutoForwardDiff(), prob.f, uu, p) | ||
else | ||
BatchedRoutines._wrap_batched_operator(prob.f.jac(uu, p)) | ||
end | ||
|
||
Jᵤ_fact = LinearAlgebra.lu!(Jᵤ) | ||
|
||
map_fn = @closure zp -> begin | ||
Jₚᵢ, p = zp | ||
LinearAlgebra.ldiv!(Jᵤ_fact, Jₚᵢ) | ||
Jₚᵢ .*= -1 | ||
return map(Base.Fix2(*, ForwardDiff.partials(p)), Jₚᵢ) | ||
end | ||
|
||
return sol, sum(map_fn, zip(eachcol(Jₚ), prob.p)) | ||
end | ||
|
||
@inline function __nlsolve_dual_soln(u::AbstractArray, partials, | ||
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} | ||
_partials = reshape(partials, size(u)) | ||
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
@kwdef @concrete struct BatchedSimpleNewtonRaphson <: AbstractBatchedNonlinearAlgorithm | ||
autodiff = nothing | ||
end | ||
|
||
const BatchedSimpleGaussNewton = BatchedSimpleNewtonRaphson | ||
|
||
function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, | ||
alg::BatchedSimpleNewtonRaphson, args...; | ||
abstol=nothing, maxiters::Int=1000, kwargs...) | ||
@assert !SciMLBase.isinplace(prob) "BatchedSimpleNewtonRaphson does not support inplace." | ||
|
||
x = deepcopy(prob.u0) | ||
fx = prob.f(x, prob.p) | ||
@assert (ndims(x) == 2)&&(ndims(fx) == 2) "BatchedSimpleNewtonRaphson only supports matrices." | ||
|
||
autodiff = __get_concrete_autodiff(prob, alg.autodiff) | ||
abstol = __get_tolerance(abstol, x) | ||
|
||
maximum(abs, fx) < abstol && | ||
return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success) | ||
|
||
for _ in 1:maxiters | ||
fx, J = __value_and_jacobian(prob, x, autodiff) | ||
δx = J \ fx | ||
|
||
maximum(abs, fx) < abstol && | ||
return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success) | ||
|
||
@. x -= δx | ||
end | ||
|
||
return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.MaxIters) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
@inline __get_concrete_autodiff(prob, ad::AbstractADType; kwargs...) = ad | ||
@inline function __get_concrete_autodiff(prob, ::Nothing; kwargs...) | ||
prob.f.jac !== nothing && return nothing | ||
if _is_extension_loaded(Val(:ForwardDiff)) && __can_forwarddiff_dual(eltype(prob.u0)) | ||
return AutoForwardDiff() | ||
elseif _is_extension_loaded(Val(:FiniteDiff)) | ||
return AutoFiniteDiff() | ||
else | ||
error("No AD backend loaded. Please load an AD backend first.") | ||
end | ||
end | ||
|
||
function __can_forwarddiff_dual end | ||
|
||
@inline function __value_and_jacobian(prob, x, autodiff) | ||
if prob.f.jac === nothing | ||
return prob.f(x, prob.p), batched_jacobian(autodiff, prob.f, x, prob.p) | ||
else | ||
return prob.f(x, prob.p), _wrap_batched_operator(prob.f.jac(x, prob.p)) | ||
end | ||
end | ||
|
||
@inline __get_tolerance(abstol, u0) = __get_tolerance(abstol, eltype(u0)) | ||
@inline function __get_tolerance(abstol, ::Type{T}) where {T} | ||
return abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
@testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin | ||
using Chairmarks, ForwardDiff, SciMLBase, SimpleNonlinearSolve, Statistics, Zygote | ||
|
||
testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p | ||
|
||
u0 = rand(3, 128) | ||
p = rand(1, 128) | ||
|
||
prob = NonlinearProblem(testing_f, u0, p) | ||
|
||
sol_nlsolve = solve(prob, SimpleNewtonRaphson()) | ||
sol_batched = solve(prob, BatchedSimpleNewtonRaphson()) | ||
|
||
@test abs.(sol_nlsolve.u) ≈ abs.(sol_batched.u) | ||
|
||
nlsolve_timing = @be solve($prob, $SimpleNewtonRaphson()) | ||
batched_timing = @be solve($prob, $BatchedSimpleNewtonRaphson()) | ||
|
||
@info "SimpleNonlinearSolve Timing: $(median(nlsolve_timing))." | ||
@info "BatchedSimpleNewtonRaphson Timing: $(median(batched_timing))." | ||
|
||
∂p1 = ForwardDiff.gradient(p) do p | ||
prob = NonlinearProblem(testing_f, u0, p) | ||
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) | ||
end | ||
|
||
∂p2 = ForwardDiff.gradient(p) do p | ||
prob = NonlinearProblem(testing_f, u0, p) | ||
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) | ||
end | ||
|
||
@test ∂p1 ≈ ∂p2 | ||
|
||
fwdiff_nlsolve_timing = @be ForwardDiff.gradient($p) do p | ||
prob = NonlinearProblem(testing_f, u0, p) | ||
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) | ||
end | ||
|
||
fwdiff_batched_timing = @be ForwardDiff.gradient($p) do p | ||
prob = NonlinearProblem(testing_f, u0, p) | ||
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) | ||
end | ||
|
||
@info "ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing))." | ||
@info "ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing))." | ||
end |