From 05b7a58b0935971d59ec1d3472f07a18abb9fd93 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Apr 2024 18:14:04 -0400 Subject: [PATCH 1/4] Add Batched Nonlinear Solvers for Forward AD rules --- Project.toml | 11 ++++- .../BatchedRoutinesForwardDiffExt.jl | 24 ++++++++++ .../jacobian.jl} | 17 ------- .../nonlinearsolve_ad.jl | 44 ++++++++++++++++++ ext/BatchedRoutinesLinearSolveExt.jl | 5 +- src/BatchedRoutines.jl | 18 ++++++-- src/helpers.jl | 3 ++ src/nlsolve/batched_raphson.jl | 33 +++++++++++++ src/nlsolve/utils.jl | 26 +++++++++++ test/nlsolve_tests.jl | 46 +++++++++++++++++++ 10 files changed, 203 insertions(+), 24 deletions(-) create mode 100644 ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl rename ext/{BatchedRoutinesForwardDiffExt.jl => BatchedRoutinesForwardDiffExt/jacobian.jl} (94%) create mode 100644 ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl create mode 100644 src/nlsolve/batched_raphson.jl create mode 100644 src/nlsolve/utils.jl create mode 100644 test/nlsolve_tests.jl diff --git a/Project.toml b/Project.toml index e2d43b8..2597f11 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" [weakdeps] @@ -27,8 +28,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BatchedRoutinesCUDAExt = ["CUDA"] -BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"] BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"] +BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"] BatchedRoutinesFiniteDiffExt = ["FiniteDiff"] BatchedRoutinesForwardDiffExt = ["ForwardDiff"] BatchedRoutinesLinearSolveExt = ["LinearSolve"] @@ -42,6 +43,7 @@ Aqua = "0.8.4" ArrayInterface = "7.8.1" CUDA = "5.2.0" ChainRulesCore = "1.23" +Chairmarks = "1.2" ComponentArrays = "0.15.10" ConcreteStructs = "0.2.3" ExplicitImports = "1.4.0" @@ -59,7 +61,9 @@ PrecompileTools = "1.2.0" Random = "<0.0.1, 1" ReTestItems = "1.23.1" ReverseDiff = "1.15" +SciMLBase = "2.31" SciMLOperators = "0.3.8" +SimpleNonlinearSolve = "1.7" StableRNGs = "1.0.1" Statistics = "1.11.1" Test = "<0.0.1, 1" @@ -68,6 +72,7 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -80,10 +85,12 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Zygote"] +test = ["Aqua", "Chairmarks", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"] diff --git a/ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl new file mode 100644 index 0000000..e7917dd --- /dev/null +++ b/ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl @@ -0,0 +1,24 @@ +module BatchedRoutinesForwardDiffExt + +using ADTypes: AutoForwardDiff +using ArrayInterface: parameterless_type +using BatchedRoutines: BatchedRoutines, AbstractBatchedNonlinearAlgorithm, + UniformBlockDiagonalOperator, batched_jacobian, batched_mul, + batched_pickchunksize, _assert_type +using ChainRulesCore: ChainRulesCore +using FastClosures: @closure +using ForwardDiff: ForwardDiff, Dual +using LinearAlgebra: LinearAlgebra +using LuxDeviceUtils: LuxDeviceUtils, get_device +using SciMLBase: SciMLBase, NonlinearProblem + +const CRC = ChainRulesCore + +@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true + +@inline BatchedRoutines.__can_forwarddiff_dual(::Type{T}) where {T} = ForwardDiff.can_dual(T) + +include("jacobian.jl") +include("nonlinearsolve_ad.jl") + +end diff --git a/ext/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt/jacobian.jl similarity index 94% rename from ext/BatchedRoutinesForwardDiffExt.jl rename to ext/BatchedRoutinesForwardDiffExt/jacobian.jl index b5b3d3c..b3574a8 100644 --- a/ext/BatchedRoutinesForwardDiffExt.jl +++ b/ext/BatchedRoutinesForwardDiffExt/jacobian.jl @@ -1,18 +1,3 @@ -module BatchedRoutinesForwardDiffExt - -using ADTypes: AutoForwardDiff -using ArrayInterface: parameterless_type -using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian, - batched_mul, batched_pickchunksize, _assert_type -using ChainRulesCore: ChainRulesCore -using FastClosures: @closure -using ForwardDiff: ForwardDiff -using LuxDeviceUtils: LuxDeviceUtils, get_device - -const CRC = ChainRulesCore - -@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true - # api.jl function BatchedRoutines.batched_pickchunksize( X::AbstractArray{T, N}, n::Int=ForwardDiff.DEFAULT_CHUNK_THRESHOLD) where {T, N} @@ -242,5 +227,3 @@ end partials = ForwardDiff.Partials{1, T}.(tuple.(u)) return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x))) end - -end diff --git a/ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl b/ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl new file mode 100644 index 0000000..75b0c09 --- /dev/null +++ b/ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl @@ -0,0 +1,44 @@ +function SciMLBase.solve( + prob::NonlinearProblem{<:AbstractArray, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + alg::AbstractBatchedNonlinearAlgorithm, + args...; + kwargs...) where {T, V, P, iip} + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +end + +function __nlsolve_ad(prob::NonlinearProblem, alg, args...; kwargs...) + p = ForwardDiff.value.(prob.p) + u0 = ForwardDiff.value.(prob.u0) + newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...) + + sol = SciMLBase.solve(newprob, alg, args...; kwargs...) + + uu = sol.u + Jₚ = ForwardDiff.jacobian(Base.Fix1(prob.f, uu), p) + Jᵤ = if prob.f.jac === nothing + BatchedRoutines.batched_jacobian(AutoForwardDiff(), prob.f, uu, p) + else + BatchedRoutines._wrap_batched_operator(prob.f.jac(uu, p)) + end + + Jᵤ_fact = LinearAlgebra.lu!(Jᵤ) + + map_fn = @closure zp -> begin + Jₚᵢ, p = zp + LinearAlgebra.ldiv!(Jᵤ_fact, Jₚᵢ) + Jₚᵢ .*= -1 + return map(Base.Fix2(*, ForwardDiff.partials(p)), Jₚᵢ) + end + + return sol, sum(map_fn, zip(eachcol(Jₚ), prob.p)) +end + +@inline function __nlsolve_dual_soln(u::AbstractArray, partials, + ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} + _partials = reshape(partials, size(u)) + return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials)) +end diff --git a/ext/BatchedRoutinesLinearSolveExt.jl b/ext/BatchedRoutinesLinearSolveExt.jl index 42b2808..16b6912 100644 --- a/ext/BatchedRoutinesLinearSolveExt.jl +++ b/ext/BatchedRoutinesLinearSolveExt.jl @@ -5,7 +5,8 @@ using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata using ChainRulesCore: ChainRulesCore, NoTangent using FastClosures: @closure using LinearAlgebra: LinearAlgebra -using LinearSolve: LinearSolve, SciMLBase +using LinearSolve: LinearSolve +using SciMLBase: SciMLBase const CRC = ChainRulesCore @@ -113,7 +114,7 @@ function LinearSolve.solve!(cache::LinearSolve.LinearCache{<:UniformBlockDiagona y = LinearAlgebra.ldiv!( cache.u, LinearSolve.@get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b) - return LinearSolve.SciMLBase.build_linear_solution(alg, y, nothing, cache) + return SciMLBase.build_linear_solution(alg, y, nothing, cache) end # SVDFactorization diff --git a/src/BatchedRoutines.jl b/src/BatchedRoutines.jl index 692b71c..304b477 100644 --- a/src/BatchedRoutines.jl +++ b/src/BatchedRoutines.jl @@ -3,8 +3,9 @@ module BatchedRoutines import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoSparseForwardDiff, - AutoSparsePolyesterForwardDiff, AutoPolyesterForwardDiff, AutoZygote + using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff, + AutoReverseDiff, AutoSparseForwardDiff, AutoSparsePolyesterForwardDiff, + AutoPolyesterForwardDiff, AutoZygote using Adapt: Adapt using ArrayInterface: ArrayInterface, parameterless_type using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig @@ -14,6 +15,7 @@ import PrecompileTools: @recompile_invalidations using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!, pinv using LuxDeviceUtils: LuxDeviceUtils, get_device + using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem, ReturnCode using SciMLOperators: SciMLOperators, AbstractSciMLOperator end @@ -40,23 +42,33 @@ const AutoAllForwardDiff{CK} = Union{<:AutoForwardDiff{CK}, <:AutoSparseForwardD const BatchedVector{T} = AbstractMatrix{T} const BatchedMatrix{T} = AbstractArray{T, 3} +abstract type AbstractBatchedNonlinearAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end + @inline _is_extension_loaded(::Val) = false +include("operator.jl") + include("api.jl") include("helpers.jl") -include("operator.jl") include("factorization.jl") +include("nlsolve/utils.jl") +include("nlsolve/batched_raphson.jl") + include("impl/batched_mul.jl") include("impl/batched_gmres.jl") include("chainrules.jl") +# Core export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote export batched_adjoint, batched_gradient, batched_jacobian, batched_pickchunksize, batched_mul, batched_pinv, batched_transpose export batchview, nbatches export UniformBlockDiagonalOperator +# Nonlinear Solvers +export BatchedSimpleNewtonRaphson, BatchedSimpleGaussNewton + end diff --git a/src/helpers.jl b/src/helpers.jl index fd071af..25a52e2 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -167,3 +167,6 @@ end CRC.@non_differentiable fill_like(::Any...) CRC.@non_differentiable zeros_like(::Any...) CRC.@non_differentiable ones_like(::Any...) + +@inline _wrap_batched_operator(x::AbstractArray{T, 3}) where {T} = UniformBlockDiagonalOperator(x) +@inline _wrap_batched_operator(x::UniformBlockDiagonalOperator) = x diff --git a/src/nlsolve/batched_raphson.jl b/src/nlsolve/batched_raphson.jl new file mode 100644 index 0000000..d38052c --- /dev/null +++ b/src/nlsolve/batched_raphson.jl @@ -0,0 +1,33 @@ +@kwdef @concrete struct BatchedSimpleNewtonRaphson <: AbstractBatchedNonlinearAlgorithm + autodiff = nothing +end + +const BatchedSimpleGaussNewton = BatchedSimpleNewtonRaphson + +function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, + alg::BatchedSimpleNewtonRaphson, args...; + abstol=nothing, maxiters::Int=1000, kwargs...) + @assert !SciMLBase.isinplace(prob) "BatchedSimpleNewtonRaphson does not support inplace." + + x = deepcopy(prob.u0) + fx = prob.f(x, prob.p) + @assert (ndims(x) == 2)&&(ndims(fx) == 2) "BatchedSimpleNewtonRaphson only supports matrices." + + autodiff = __get_concrete_autodiff(prob, alg.autodiff) + abstol = __get_tolerance(abstol, x) + + maximum(abs, fx) < abstol && + return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success) + + for _ in 1:maxiters + fx, J = __value_and_jacobian(prob, x, autodiff) + δx = J \ fx + + maximum(abs, fx) < abstol && + return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success) + + @. x -= δx + end + + return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.MaxIters) +end diff --git a/src/nlsolve/utils.jl b/src/nlsolve/utils.jl new file mode 100644 index 0000000..6b80bbe --- /dev/null +++ b/src/nlsolve/utils.jl @@ -0,0 +1,26 @@ +@inline __get_concrete_autodiff(prob, ad::AbstractADType; kwargs...) = ad +@inline function __get_concrete_autodiff(prob, ::Nothing; kwargs...) + prob.f.jac !== nothing && return nothing + if _is_extension_loaded(Val(:ForwardDiff)) && __can_forwarddiff_dual(eltype(prob.u0)) + return AutoForwardDiff() + elseif _is_extension_loaded(Val(:FiniteDiff)) + return AutoFiniteDiff() + else + error("No AD backend loaded. Please load an AD backend first.") + end +end + +function __can_forwarddiff_dual end + +@inline function __value_and_jacobian(prob, x, autodiff) + if prob.f.jac === nothing + return prob.f(x, prob.p), batched_jacobian(autodiff, prob.f, x, prob.p) + else + return prob.f(x, prob.p), _wrap_batched_operator(prob.f.jac(x, prob.p)) + end +end + +@inline __get_tolerance(abstol, u0) = __get_tolerance(abstol, eltype(u0)) +@inline function __get_tolerance(abstol, ::Type{T}) where {T} + return abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol +end diff --git a/test/nlsolve_tests.jl b/test/nlsolve_tests.jl new file mode 100644 index 0000000..2916431 --- /dev/null +++ b/test/nlsolve_tests.jl @@ -0,0 +1,46 @@ +@testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin + using Chairmarks, ForwardDiff, SciMLBase, SimpleNonlinearSolve, Statistics, Zygote + + testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p + + u0 = rand(3, 128) + p = rand(1, 128) + + prob = NonlinearProblem(testing_f, u0, p) + + sol_nlsolve = solve(prob, SimpleNewtonRaphson()) + sol_batched = solve(prob, BatchedSimpleNewtonRaphson()) + + @test abs.(sol_nlsolve.u) ≈ abs.(sol_batched.u) + + nlsolve_timing = @be solve($prob, $SimpleNewtonRaphson()) + batched_timing = @be solve($prob, $BatchedSimpleNewtonRaphson()) + + @info "SimpleNonlinearSolve Timing: $(median(nlsolve_timing))." + @info "BatchedSimpleNewtonRaphson Timing: $(median(batched_timing))." + + ∂p1 = ForwardDiff.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) + end + + ∂p2 = ForwardDiff.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) + end + + @test ∂p1 ≈ ∂p2 + + fwdiff_nlsolve_timing = @be ForwardDiff.gradient($p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) + end + + fwdiff_batched_timing = @be ForwardDiff.gradient($p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) + end + + @info "ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing))." + @info "ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing))." +end From b6f5491bf551518ad824fac553125f5f78bbaa4d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 3 Apr 2024 13:56:38 -0400 Subject: [PATCH 2/4] Add extensions for batched adjoint for Nonlinear Systems --- Project.toml | 9 +- .../BatchedRoutinesSciMLSensitivityExt.jl | 13 +++ .../steadystateadjoint.jl | 108 ++++++++++++++++++ src/BatchedRoutines.jl | 3 +- src/api.jl | 16 ++- src/nlsolve/utils.jl | 3 + src/operator.jl | 6 +- test/nlsolve_tests.jl | 37 +++++- 8 files changed, 179 insertions(+), 16 deletions(-) create mode 100644 ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl create mode 100644 ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl diff --git a/Project.toml b/Project.toml index 2597f11..00d086c 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -34,6 +35,7 @@ BatchedRoutinesFiniteDiffExt = ["FiniteDiff"] BatchedRoutinesForwardDiffExt = ["ForwardDiff"] BatchedRoutinesLinearSolveExt = ["LinearSolve"] BatchedRoutinesReverseDiffExt = ["ReverseDiff"] +BatchedRoutinesSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity", "Zygote"] BatchedRoutinesZygoteExt = ["Zygote"] [compat] @@ -58,7 +60,7 @@ LuxCUDA = "0.3.2" LuxDeviceUtils = "0.1.17" LuxTestUtils = "0.1.15" PrecompileTools = "1.2.0" -Random = "<0.0.1, 1" +Random = "1.10" ReTestItems = "1.23.1" ReverseDiff = "1.15" SciMLBase = "2.31" @@ -66,7 +68,7 @@ SciMLOperators = "0.3.8" SimpleNonlinearSolve = "1.7" StableRNGs = "1.0.1" Statistics = "1.11.1" -Test = "<0.0.1, 1" +Test = "1.10" Zygote = "0.6.69" julia = "1.10" @@ -86,6 +88,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -93,4 +96,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Chairmarks", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"] +test = ["Aqua", "Chairmarks", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SciMLSensitivity", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"] diff --git a/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl b/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl new file mode 100644 index 0000000..3919062 --- /dev/null +++ b/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl @@ -0,0 +1,13 @@ +module BatchedRoutinesSciMLSensitivityExt + +using ADTypes: AutoForwardDiff, AutoFiniteDiff +using BatchedRoutines: BatchedRoutines, BatchedNonlinearSolution +using FastClosures: @closure +using LinearSolve: LinearSolve +using SciMLBase: SciMLBase, NonlinearProblem, NonlinearSolution +using SciMLSensitivity: SciMLSensitivity, SteadyStateAdjoint, ZygoteVJP +using Zygote: Zygote + +include("steadystateadjoint.jl") + +end \ No newline at end of file diff --git a/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl b/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl new file mode 100644 index 0000000..a97f1ae --- /dev/null +++ b/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl @@ -0,0 +1,108 @@ +import SciMLSensitivity: SteadyStateAdjointProblem, SteadyStateAdjointSensitivityFunction + +function SteadyStateAdjointProblem( + sol::BatchedNonlinearSolution, sensealg::SteadyStateAdjoint, alg, + dgdu::DG1=nothing, dgdp::DG2=nothing, g::G=nothing; kwargs...) where {DG1, DG2, G} + @assert sol.prob isa NonlinearProblem + (; f, p, u0) = sol.prob + f = SciMLBase.ODEFunction(f) + + @assert !SciMLBase.isinplace(sol.prob) "Adjoint for Batched Problems does not support \ + inplace problems." + @assert ndims(u0)==2 "u0 must be a matrix." + @assert dgdu!==nothing "`dgdu` must be specified. Automatic differentiation is not \ + currently implemented for this part." + @assert sensealg.autojacvec isa ZygoteVJP + + dgdu === nothing && + dgdp === nothing && + g === nothing && + error("Either `dgdu`, `dgdp`, or `g` must be specified.") + + needs_jac = ifelse(SciMLBase.has_adjoint(f), + false, + ifelse(sensealg.linsolve === nothing, size(u0, 1) ≤ 50, + SciMLSensitivity.__needs_concrete_A(sensealg.linsolve))) + + p === SciMLBase.NullParameters() && + error("Your model does not have parameters, and thus it is impossible to calculate \ + the derivative of the solution with respect to the parameters. Your model \ + must have parameters to use parameter sensitivity calculations!") + + # sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, + # f, f.colorvec, false) # Dont allocate the Jacobian yet in diffcache + # @show sense.vjp + y = sol.u + + if needs_jac + if SciMLBase.has_jac(f) + J = BatchedRoutines._wrap_batched_operator(f.jac(y, p, nothing)) + else + uf = SciMLBase.UJacobianWrapper(f, nothing, p) + if SciMLSensitivity.alg_autodiff(sensealg) + J = BatchedRoutines.batched_jacobian(AutoFiniteDiff(), uf, y) + else + J = BatchedRoutines.batched_jacobian(AutoForwardDiff(), uf, y) + end + end + end + + if dgdp === nothing && g === nothing + dgdu_val = similar(u0, length(u0)) + dgdp_val = nothing + else + dgdu_val, dgdp_val = similar(u0, length(u0)), similar(u0, length(p)) + end + + if dgdu !== nothing + dgdu(dgdu_val, y, p, nothing, nothing) + else + # TODO: Implement this part + error("Not implemented yet") + # if g !== nothing + # if dgdp_val !== nothing + # gradient!(vec(dgdu_val), diffcache.g[1], y, sensealg, + # diffcache.g_grad_config[1]) + # else + # gradient!(vec(dgdu_val), diffcache.g, y, sensealg, diffcache.g_grad_config) + # end + # end + end + + if !needs_jac # Construct an operator and use Jacobian-Free Linear Solve + error("Todo Jacobian Free Linear Solve") + # usize = size(y) + # __f = y -> vec(f(reshape(y, usize), p, nothing)) + # operator = VecJac(__f, vec(y); + # autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) + # linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) + # solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ) + else + linear_problem = SciMLBase.LinearProblem(J', dgdu_val) + linsol = SciMLBase.solve( + linear_problem, sensealg.linsolve; alias_A=true, sensealg.linsolve_kwargs...) + λ = linsol.u + end + + _, pb_f = Zygote.pullback(@closure(p->vec(f(y, p, nothing))), p) + ∂p = only(pb_f(λ)) + ∂p === nothing && + !sensealg.autojacvec.allow_nothing && + throw(SciMLSensitivity.ZygoteVJPNothingError()) + + if g !== nothing || dgdp !== nothing + error("Not implemented yet") + # compute del g/del p + # if dgdp !== nothing + # dgdp(dgdp_val, y, p, nothing, nothing) + # else + # @unpack g_grad_config = diffcache + # gradient!(dgdp_val, diffcache.g[2], p, sensealg, g_grad_config[2]) + # end + # recursive_sub!(dgdp_val, vjp) + # return dgdp_val + else + SciMLSensitivity.recursive_neg!(∂p) + return ∂p + end +end diff --git a/src/BatchedRoutines.jl b/src/BatchedRoutines.jl index 304b477..39a9060 100644 --- a/src/BatchedRoutines.jl +++ b/src/BatchedRoutines.jl @@ -15,7 +15,8 @@ import PrecompileTools: @recompile_invalidations using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!, pinv using LuxDeviceUtils: LuxDeviceUtils, get_device - using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem, ReturnCode + using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem, + NonlinearSolution, ReturnCode using SciMLOperators: SciMLOperators, AbstractSciMLOperator end diff --git a/src/api.jl b/src/api.jl index c68c3dd..9f73fa8 100644 --- a/src/api.jl +++ b/src/api.jl @@ -62,16 +62,24 @@ batched_mul!(C, A, B, α=true, β=false) = _batched_mul!(C, A, B, α, β) Transpose the first two dimensions of `X`. """ -batched_transpose(X::BatchedMatrix) = PermutedDimsArray(X, (2, 1, 3)) -batched_transpose(X::AbstractMatrix) = reshape(X, 1, size(X, 1), size(X, 2)) +function batched_transpose(X::BatchedMatrix; copy::Val{C}=Val(false)) where {C} + C || return PermutedDimsArray(X, (2, 1, 3)) + return permutedims(X, (2, 1, 3)) +end +function batched_transpose(X::AbstractMatrix; copy::Val{C}=Val(false)) where {C} + return reshape(X, 1, size(X, 1), size(X, 2)) +end """ batched_adjoint(X::AbstractArray{T, 3}) where {T} Adjoint the first two dimensions of `X`. """ -batched_adjoint(X::BatchedMatrix{<:Real}) = batched_transpose(X) -batched_adjoint(X::BatchedMatrix) = mapfoldl(adjoint, _cat3, batchview(X)) +batched_adjoint(X::BatchedMatrix{<:Real}; copy::Val{C}=Val(false)) where {C} = batched_transpose( + X; copy) +function batched_adjoint(X::BatchedMatrix; copy::Val{C}=Val(false)) where {C} + return mapfoldl(adjoint, _cat3, batchview(X)) +end """ nbatches(A::AbstractArray) diff --git a/src/nlsolve/utils.jl b/src/nlsolve/utils.jl index 6b80bbe..a0ad495 100644 --- a/src/nlsolve/utils.jl +++ b/src/nlsolve/utils.jl @@ -24,3 +24,6 @@ end @inline function __get_tolerance(abstol, ::Type{T}) where {T} return abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol end + +const BatchedNonlinearSolution{T, N, uType, R, P, O, uType2, S, Tr} = NonlinearSolution{ + T, N, uType, R, P, <:AbstractBatchedNonlinearAlgorithm, O, uType2, S, Tr} diff --git a/src/operator.jl b/src/operator.jl index 80e6469..8a52fa8 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -23,8 +23,8 @@ end for f in ( :batched_transpose, :batched_adjoint, :batched_inv, :batched_pinv, :batched_reshape) - @eval function $(f)(op::UniformBlockDiagonalOperator, args...) - return UniformBlockDiagonalOperator($(f)(op.data, args...)) + @eval function $(f)(op::UniformBlockDiagonalOperator, args...; kwargs...) + return UniformBlockDiagonalOperator($(f)(op.data, args...; kwargs...)) end end @@ -60,7 +60,7 @@ end for f in (:transpose, :adjoint) batched_f = Symbol("batched_", f) - @eval (Base.$(f))(op::UniformBlockDiagonalOperator) = $(batched_f)(op) + @eval (Base.$(f))(op::UniformBlockDiagonalOperator) = $(batched_f)(op; copy=Val(true)) end @inline function Base.size(op::UniformBlockDiagonalOperator) diff --git a/test/nlsolve_tests.jl b/test/nlsolve_tests.jl index 2916431..fd9eaf1 100644 --- a/test/nlsolve_tests.jl +++ b/test/nlsolve_tests.jl @@ -1,5 +1,6 @@ @testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin - using Chairmarks, ForwardDiff, SciMLBase, SimpleNonlinearSolve, Statistics, Zygote + using Chairmarks, ForwardDiff, SciMLBase, SciMLSensitivity, SimpleNonlinearSolve, + Statistics, Zygote testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p @@ -16,8 +17,8 @@ nlsolve_timing = @be solve($prob, $SimpleNewtonRaphson()) batched_timing = @be solve($prob, $BatchedSimpleNewtonRaphson()) - @info "SimpleNonlinearSolve Timing: $(median(nlsolve_timing))." - @info "BatchedSimpleNewtonRaphson Timing: $(median(batched_timing))." + println("SimpleNonlinearSolve Timing: $(median(nlsolve_timing)).") + println("BatchedSimpleNewtonRaphson Timing: $(median(batched_timing)).") ∂p1 = ForwardDiff.gradient(p) do p prob = NonlinearProblem(testing_f, u0, p) @@ -41,6 +42,32 @@ return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) end - @info "ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing))." - @info "ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing))." + println("ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing)).") + println("ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing)).") + + ∂p3 = only(Zygote.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) + end) + + ∂p4 = only(Zygote.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) + end) + + @test ∂p3 ≈ ∂p4 + @test ∂p1 ≈ ∂p4 + + zygote_nlsolve_timing = @be Zygote.gradient($p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, SimpleNewtonRaphson()).u) + end + + zygote_batched_timing = @be Zygote.gradient($p) do p + prob = NonlinearProblem(testing_f, u0, p) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) + end + + println("Zygote SimpleNonlinearSolve Timing: $(median(zygote_nlsolve_timing)).") + println("Zygote BatchedNonlinearSolve Timing: $(median(zygote_batched_timing)).") end From 3022cf9a95f63dbe4a59ea5ec25fa7c9ccc5c28e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 3 Apr 2024 14:27:03 -0400 Subject: [PATCH 3/4] Add the jacobian free version --- Project.toml | 1 + .../BatchedRoutinesSciMLSensitivityExt.jl | 2 +- .../steadystateadjoint.jl | 42 ++++++------------- test/nlsolve_tests.jl | 15 +++++-- 4 files changed, 26 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index 00d086c..3439429 100644 --- a/Project.toml +++ b/Project.toml @@ -65,6 +65,7 @@ ReTestItems = "1.23.1" ReverseDiff = "1.15" SciMLBase = "2.31" SciMLOperators = "0.3.8" +SciMLSensitivity = "7.56" SimpleNonlinearSolve = "1.7" StableRNGs = "1.0.1" Statistics = "1.11.1" diff --git a/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl b/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl index 3919062..95ace1d 100644 --- a/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl +++ b/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl @@ -10,4 +10,4 @@ using Zygote: Zygote include("steadystateadjoint.jl") -end \ No newline at end of file +end diff --git a/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl b/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl index a97f1ae..6df0ef9 100644 --- a/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl +++ b/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl @@ -29,9 +29,6 @@ function SteadyStateAdjointProblem( the derivative of the solution with respect to the parameters. Your model \ must have parameters to use parameter sensitivity calculations!") - # sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, - # f, f.colorvec, false) # Dont allocate the Jacobian yet in diffcache - # @show sense.vjp y = sol.u if needs_jac @@ -57,32 +54,28 @@ function SteadyStateAdjointProblem( if dgdu !== nothing dgdu(dgdu_val, y, p, nothing, nothing) else - # TODO: Implement this part error("Not implemented yet") - # if g !== nothing - # if dgdp_val !== nothing - # gradient!(vec(dgdu_val), diffcache.g[1], y, sensealg, - # diffcache.g_grad_config[1]) - # else - # gradient!(vec(dgdu_val), diffcache.g, y, sensealg, diffcache.g_grad_config) - # end - # end end if !needs_jac # Construct an operator and use Jacobian-Free Linear Solve - error("Todo Jacobian Free Linear Solve") - # usize = size(y) - # __f = y -> vec(f(reshape(y, usize), p, nothing)) - # operator = VecJac(__f, vec(y); - # autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) - # linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) - # solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ) + linsolve = if sensealg.linsolve === nothing + LinearSolve.SimpleGMRES(; blocksize=size(u0, 1)) + else + sensealg.linsolve + end + usize = size(y) + __f = @closure y -> vec(f(reshape(y, usize), p, nothing)) + operator = SciMLSensitivity.VecJac(__f, vec(y); + autodiff=SciMLSensitivity.get_autodiff_from_vjp(sensealg.autojacvec)) + linear_problem = SciMLBase.LinearProblem(operator, dgdu_val) + linsol = SciMLBase.solve( + linear_problem, linsolve; alias_A=true, sensealg.linsolve_kwargs...) else linear_problem = SciMLBase.LinearProblem(J', dgdu_val) linsol = SciMLBase.solve( linear_problem, sensealg.linsolve; alias_A=true, sensealg.linsolve_kwargs...) - λ = linsol.u end + λ = linsol.u _, pb_f = Zygote.pullback(@closure(p->vec(f(y, p, nothing))), p) ∂p = only(pb_f(λ)) @@ -92,15 +85,6 @@ function SteadyStateAdjointProblem( if g !== nothing || dgdp !== nothing error("Not implemented yet") - # compute del g/del p - # if dgdp !== nothing - # dgdp(dgdp_val, y, p, nothing, nothing) - # else - # @unpack g_grad_config = diffcache - # gradient!(dgdp_val, diffcache.g[2], p, sensealg, g_grad_config[2]) - # end - # recursive_sub!(dgdp_val, vjp) - # return dgdp_val else SciMLSensitivity.recursive_neg!(∂p) return ∂p diff --git a/test/nlsolve_tests.jl b/test/nlsolve_tests.jl index fd9eaf1..110ab61 100644 --- a/test/nlsolve_tests.jl +++ b/test/nlsolve_tests.jl @@ -1,6 +1,6 @@ @testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin - using Chairmarks, ForwardDiff, SciMLBase, SciMLSensitivity, SimpleNonlinearSolve, - Statistics, Zygote + using Chairmarks, ForwardDiff, LinearSolve, SciMLBase, SciMLSensitivity, + SimpleNonlinearSolve, Statistics, Zygote testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p @@ -55,8 +55,15 @@ return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) end) - @test ∂p3 ≈ ∂p4 - @test ∂p1 ≈ ∂p4 + ∂p5 = only(Zygote.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + sensealg = SteadyStateAdjoint(; linsolve=KrylovJL_GMRES()) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson(); sensealg).u) + end) + + @test ∂p1≈∂p3 atol=1e-5 + @test ∂p3≈∂p4 atol=1e-5 + @test ∂p4≈∂p5 atol=1e-5 zygote_nlsolve_timing = @be Zygote.gradient($p) do p prob = NonlinearProblem(testing_f, u0, p) From a37f69464492f1d7bc2fb22ea093b86698fd3e0d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 18 Apr 2024 14:18:03 -0400 Subject: [PATCH 4/4] Make nested AD work for non-compiled ReverseDiff --- ext/BatchedRoutinesReverseDiffExt.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/ext/BatchedRoutinesReverseDiffExt.jl b/ext/BatchedRoutinesReverseDiffExt.jl index 6c0678c..de2f982 100644 --- a/ext/BatchedRoutinesReverseDiffExt.jl +++ b/ext/BatchedRoutinesReverseDiffExt.jl @@ -2,7 +2,8 @@ module BatchedRoutinesReverseDiffExt using ADTypes: AutoReverseDiff, AutoForwardDiff using ArrayInterface: ArrayInterface -using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type +using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type, + UniformBlockDiagonalOperator, getdata using ChainRulesCore: ChainRulesCore, NoTangent using ConcreteStructs: @concrete using FastClosures: @closure @@ -30,6 +31,21 @@ function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F} return ∂u end +# ReverseDiff compatible `UniformBlockDiagonalOperator` +@inline function ReverseDiff.track( + op::UniformBlockDiagonalOperator, tp::ReverseDiff.InstructionTape) + return UniformBlockDiagonalOperator(ReverseDiff.track(getdata(op), tp)) +end + +@inline function ReverseDiff.deriv(x::UniformBlockDiagonalOperator) + return UniformBlockDiagonalOperator(ReverseDiff.deriv(getdata(x))) +end + +@inline function ReverseDiff.value!( + op::UniformBlockDiagonalOperator, val::UniformBlockDiagonalOperator) + ReverseDiff.value!(getdata(op), getdata(val)) +end + # Chain rules integration function BatchedRoutines.batched_jacobian( ad, f::F, x::AbstractMatrix{<:ReverseDiff.TrackedReal}) where {F}