From 05b7a58b0935971d59ec1d3472f07a18abb9fd93 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Apr 2024 18:14:04 -0400 Subject: [PATCH] Add Batched Nonlinear Solvers for Forward AD rules --- Project.toml | 11 ++++- .../BatchedRoutinesForwardDiffExt.jl | 24 ++++++++++ .../jacobian.jl} | 17 ------- .../nonlinearsolve_ad.jl | 44 ++++++++++++++++++ ext/BatchedRoutinesLinearSolveExt.jl | 5 +- src/BatchedRoutines.jl | 18 ++++++-- src/helpers.jl | 3 ++ src/nlsolve/batched_raphson.jl | 33 +++++++++++++ src/nlsolve/utils.jl | 26 +++++++++++ test/nlsolve_tests.jl | 46 +++++++++++++++++++ 10 files changed, 203 insertions(+), 24 deletions(-) create mode 100644 ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl rename ext/{BatchedRoutinesForwardDiffExt.jl => BatchedRoutinesForwardDiffExt/jacobian.jl} (94%) create mode 100644 ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl create mode 100644 src/nlsolve/batched_raphson.jl create mode 100644 src/nlsolve/utils.jl create mode 100644 test/nlsolve_tests.jl diff --git a/Project.toml b/Project.toml index e2d43b8..2597f11 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" [weakdeps] @@ -27,8 +28,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BatchedRoutinesCUDAExt = ["CUDA"] -BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"] BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"] +BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"] BatchedRoutinesFiniteDiffExt = ["FiniteDiff"] BatchedRoutinesForwardDiffExt = ["ForwardDiff"] BatchedRoutinesLinearSolveExt = ["LinearSolve"] @@ -42,6 +43,7 @@ Aqua = "0.8.4" ArrayInterface = "7.8.1" CUDA = "5.2.0" ChainRulesCore = "1.23" +Chairmarks = "1.2" ComponentArrays = "0.15.10" ConcreteStructs = "0.2.3" ExplicitImports = "1.4.0" @@ -59,7 +61,9 @@ PrecompileTools = "1.2.0" Random = "<0.0.1, 1" ReTestItems = "1.23.1" ReverseDiff = "1.15" +SciMLBase = "2.31" SciMLOperators = "0.3.8" +SimpleNonlinearSolve = "1.7" StableRNGs = "1.0.1" Statistics = "1.11.1" Test = "<0.0.1, 1" @@ -68,6 +72,7 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -80,10 +85,12 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Zygote"] +test = ["Aqua", "Chairmarks", 
"ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"] diff --git a/ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl new file mode 100644 index 0000000..e7917dd --- /dev/null +++ b/ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl @@ -0,0 +1,24 @@ +module BatchedRoutinesForwardDiffExt + +using ADTypes: AutoForwardDiff +using ArrayInterface: parameterless_type +using BatchedRoutines: BatchedRoutines, AbstractBatchedNonlinearAlgorithm, + UniformBlockDiagonalOperator, batched_jacobian, batched_mul, + batched_pickchunksize, _assert_type +using ChainRulesCore: ChainRulesCore +using FastClosures: @closure +using ForwardDiff: ForwardDiff, Dual +using LinearAlgebra: LinearAlgebra +using LuxDeviceUtils: LuxDeviceUtils, get_device +using SciMLBase: SciMLBase, NonlinearProblem + +const CRC = ChainRulesCore + +@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true + +@inline BatchedRoutines.__can_forwarddiff_dual(::Type{T}) where {T} = ForwardDiff.can_dual(T) + +include("jacobian.jl") +include("nonlinearsolve_ad.jl") + +end diff --git a/ext/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt/jacobian.jl similarity index 94% rename from ext/BatchedRoutinesForwardDiffExt.jl rename to ext/BatchedRoutinesForwardDiffExt/jacobian.jl index b5b3d3c..b3574a8 100644 --- a/ext/BatchedRoutinesForwardDiffExt.jl +++ b/ext/BatchedRoutinesForwardDiffExt/jacobian.jl @@ -1,18 +1,3 @@ -module BatchedRoutinesForwardDiffExt - -using ADTypes: AutoForwardDiff -using ArrayInterface: parameterless_type -using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian, - batched_mul, batched_pickchunksize, _assert_type -using ChainRulesCore: ChainRulesCore -using FastClosures: @closure -using ForwardDiff: ForwardDiff -using LuxDeviceUtils: LuxDeviceUtils, get_device - -const CRC = ChainRulesCore - -@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true - # api.jl function BatchedRoutines.batched_pickchunksize( X::AbstractArray{T, N}, n::Int=ForwardDiff.DEFAULT_CHUNK_THRESHOLD) where {T, N} @@ -242,5 +227,3 @@ end partials = ForwardDiff.Partials{1, T}.(tuple.(u)) return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x))) end - -end diff --git a/ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl b/ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl new file mode 100644 index 0000000..75b0c09 --- /dev/null +++ b/ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl @@ -0,0 +1,44 @@ +function SciMLBase.solve( + prob::NonlinearProblem{<:AbstractArray, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + alg::AbstractBatchedNonlinearAlgorithm, + args...; + kwargs...) where {T, V, P, iip} + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +end + +function __nlsolve_ad(prob::NonlinearProblem, alg, args...; kwargs...) + p = ForwardDiff.value.(prob.p) + u0 = ForwardDiff.value.(prob.u0) + newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...) + + sol = SciMLBase.solve(newprob, alg, args...; kwargs...) 
+ + uu = sol.u + Jₚ = ForwardDiff.jacobian(Base.Fix1(prob.f, uu), p) + Jᵤ = if prob.f.jac === nothing + BatchedRoutines.batched_jacobian(AutoForwardDiff(), prob.f, uu, p) + else + BatchedRoutines._wrap_batched_operator(prob.f.jac(uu, p)) + end + + Jᵤ_fact = LinearAlgebra.lu!(Jᵤ) + + map_fn = @closure zp -> begin + Jₚᵢ, p = zp + LinearAlgebra.ldiv!(Jᵤ_fact, Jₚᵢ) + Jₚᵢ .*= -1 + return map(Base.Fix2(*, ForwardDiff.partials(p)), Jₚᵢ) + end + + return sol, sum(map_fn, zip(eachcol(Jₚ), prob.p)) +end + +@inline function __nlsolve_dual_soln(u::AbstractArray, partials, + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} + _partials = reshape(partials, size(u)) + return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials)) +end diff --git a/ext/BatchedRoutinesLinearSolveExt.jl b/ext/BatchedRoutinesLinearSolveExt.jl index 42b2808..16b6912 100644 --- a/ext/BatchedRoutinesLinearSolveExt.jl +++ b/ext/BatchedRoutinesLinearSolveExt.jl @@ -5,7 +5,8 @@ using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata using ChainRulesCore: ChainRulesCore, NoTangent using FastClosures: @closure using LinearAlgebra: LinearAlgebra -using LinearSolve: LinearSolve, SciMLBase +using LinearSolve: LinearSolve +using SciMLBase: SciMLBase const CRC = ChainRulesCore @@ -113,7 +114,7 @@ function LinearSolve.solve!(cache::LinearSolve.LinearCache{<:UniformBlockDiagona y = LinearAlgebra.ldiv!( cache.u, LinearSolve.@get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b) - return LinearSolve.SciMLBase.build_linear_solution(alg, y, nothing, cache) + return SciMLBase.build_linear_solution(alg, y, nothing, cache) end # SVDFactorization diff --git a/src/BatchedRoutines.jl b/src/BatchedRoutines.jl index 692b71c..304b477 100644 --- a/src/BatchedRoutines.jl +++ b/src/BatchedRoutines.jl @@ -3,8 +3,9 @@ module BatchedRoutines import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoSparseForwardDiff, - AutoSparsePolyesterForwardDiff, AutoPolyesterForwardDiff, AutoZygote + using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, + AutoReverseDiff, AutoSparseForwardDiff, AutoSparsePolyesterForwardDiff, + AutoPolyesterForwardDiff, AutoZygote using Adapt: Adapt using ArrayInterface: ArrayInterface, parameterless_type using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig @@ -14,6 +15,7 @@ import PrecompileTools: @recompile_invalidations using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!, pinv using LuxDeviceUtils: LuxDeviceUtils, get_device + using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem, ReturnCode using SciMLOperators: SciMLOperators, AbstractSciMLOperator end @@ -40,23 +42,33 @@ const AutoAllForwardDiff{CK} = Union{<:AutoForwardDiff{CK}, <:AutoSparseForwardD const BatchedVector{T} = AbstractMatrix{T} const BatchedMatrix{T} = AbstractArray{T, 3} +abstract type AbstractBatchedNonlinearAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end + @inline _is_extension_loaded(::Val) = false +include("operator.jl") + include("api.jl") include("helpers.jl") -include("operator.jl") include("factorization.jl") +include("nlsolve/utils.jl") +include("nlsolve/batched_raphson.jl") + include("impl/batched_mul.jl") include("impl/batched_gmres.jl") include("chainrules.jl") +# Core export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote export batched_adjoint, batched_gradient, batched_jacobian, 
batched_pickchunksize, batched_mul, batched_pinv, batched_transpose export batchview, nbatches export UniformBlockDiagonalOperator +# Nonlinear Solvers +export BatchedSimpleNewtonRaphson, BatchedSimpleGaussNewton + end diff --git a/src/helpers.jl b/src/helpers.jl index fd071af..25a52e2 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -167,3 +167,6 @@ end CRC.@non_differentiable fill_like(::Any...) CRC.@non_differentiable zeros_like(::Any...) CRC.@non_differentiable ones_like(::Any...) + +@inline _wrap_batched_operator(x::AbstractArray{T, 3}) where {T} = UniformBlockDiagonalOperator(x) +@inline _wrap_batched_operator(x::UniformBlockDiagonalOperator) = x diff --git a/src/nlsolve/batched_raphson.jl b/src/nlsolve/batched_raphson.jl new file mode 100644 index 0000000..d38052c --- /dev/null +++ b/src/nlsolve/batched_raphson.jl @@ -0,0 +1,33 @@ +@kwdef @concrete struct BatchedSimpleNewtonRaphson <: AbstractBatchedNonlinearAlgorithm + autodiff = nothing +end + +const BatchedSimpleGaussNewton = BatchedSimpleNewtonRaphson + +function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + alg::BatchedSimpleNewtonRaphson, args...; + abstol=nothing, maxiters::Int=1000, kwargs...) + @assert !SciMLBase.isinplace(prob) "BatchedSimpleNewtonRaphson does not support inplace." + + x = deepcopy(prob.u0) + fx = prob.f(x, prob.p) + @assert (ndims(x) == 2)&&(ndims(fx) == 2) "BatchedSimpleNewtonRaphson only supports matrices." + + autodiff = __get_concrete_autodiff(prob, alg.autodiff) + abstol = __get_tolerance(abstol, x) + + maximum(abs, fx) < abstol && + return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success) + + for _ in 1:maxiters + fx, J = __value_and_jacobian(prob, x, autodiff) + δx = J \ fx + + maximum(abs, fx) < abstol && + return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success) + + @. x -= δx + end + + return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.MaxIters) +end diff --git a/src/nlsolve/utils.jl b/src/nlsolve/utils.jl new file mode 100644 index 0000000..6b80bbe --- /dev/null +++ b/src/nlsolve/utils.jl @@ -0,0 +1,26 @@ +@inline __get_concrete_autodiff(prob, ad::AbstractADType; kwargs...) = ad +@inline function __get_concrete_autodiff(prob, ::Nothing; kwargs...) + prob.f.jac !== nothing && return nothing + if _is_extension_loaded(Val(:ForwardDiff)) && __can_forwarddiff_dual(eltype(prob.u0)) + return AutoForwardDiff() + elseif _is_extension_loaded(Val(:FiniteDiff)) + return AutoFiniteDiff() + else + error("No AD backend loaded. Please load an AD backend first.") + end +end + +function __can_forwarddiff_dual end + +@inline function __value_and_jacobian(prob, x, autodiff) + if prob.f.jac === nothing + return prob.f(x, prob.p), batched_jacobian(autodiff, prob.f, x, prob.p) + else + return prob.f(x, prob.p), _wrap_batched_operator(prob.f.jac(x, prob.p)) + end +end + +@inline __get_tolerance(abstol, u0) = __get_tolerance(abstol, eltype(u0)) +@inline function __get_tolerance(abstol, ::Type{T}) where {T} + return abstol === nothing ? 
real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol +end diff --git a/test/nlsolve_tests.jl b/test/nlsolve_tests.jl new file mode 100644 index 0000000..2916431 --- /dev/null +++ b/test/nlsolve_tests.jl @@ -0,0 +1,46 @@ +@testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin + using Chairmarks, ForwardDiff, SciMLBase, SimpleNonlinearSolve, Statistics, Zygote + + testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p + + u0 = rand(3, 128) + p = rand(1, 128) + + prob = NonlinearProblem(testing_f, u0, p) + + sol_nlsolve = solve(prob, SimpleNewtonRaphson()) + sol_batched = solve(prob, BatchedSimpleNewtonRaphson()) + + @test abs.(sol_nlsolve.u) ≈ abs.(sol_batched.u) + + nlsolve_timing = @be solve($prob, $SimpleNewtonRaphson()) + batched_timing = @be solve($prob, $BatchedSimpleNewtonRaphson()) + + @info "SimpleNonlinearSolve Timing: $(median(nlsolve_timing))." + @info "BatchedSimpleNewtonRaphson Timing: $(median(batched_timing))." + + ∂p1 = ForwardDiff.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) + end + + ∂p2 = ForwardDiff.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) + end + + @test ∂p1 ≈ ∂p2 + + fwdiff_nlsolve_timing = @be ForwardDiff.gradient($p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) + end + + fwdiff_batched_timing = @be ForwardDiff.gradient($p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) + end + + @info "ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing))." + @info "ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing))." +end
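
A minimal usage sketch of the new solver, mirroring test/nlsolve_tests.jl above: the `states × batch` matrix layout for `u0` and `p`, and the exported names `BatchedSimpleNewtonRaphson`, `NonlinearProblem`, and `solve`, are taken from the hunks in this patch; the function `f` and the sizes are illustrative only.

    using BatchedRoutines, ForwardDiff, SciMLBase

    f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p    # broadcasts across the batch (column) dimension
    u0, p = rand(3, 128), rand(1, 128)      # 128 independent 3-dimensional problems
    prob = NonlinearProblem(f, u0, p)
    sol = solve(prob, BatchedSimpleNewtonRaphson())   # Newton iterations with a batched Jacobian

    # Dual-valued `p` dispatches to the SciMLBase.solve overload added in
    # ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl
    ForwardDiff.gradient(
        p -> sum(abs2, solve(NonlinearProblem(f, u0, p), BatchedSimpleNewtonRaphson()).u), p)

The forward rule implemented in nonlinearsolve_ad.jl is the usual implicit-function-theorem argument: at a solution u*(p) of f(u, p) = 0,

    (∂f/∂u) du/dp + ∂f/∂p = 0   =>   du/dp = -(∂f/∂u)^(-1) (∂f/∂p),

which is why `__nlsolve_ad` solves the problem once on the primal values, `lu!`-factorizes the batched Jacobian Jᵤ, back-substitutes the columns of Jₚ, and contracts the result with `ForwardDiff.partials` of `p` to assemble the dual-valued solution.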