Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Add Batched Nonlinear Solvers for Forward AD rules
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 2, 2024
1 parent 54b894e commit 50e683e
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 25 deletions.
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"

[weakdeps]
Expand All @@ -27,8 +28,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BatchedRoutinesCUDAExt = ["CUDA"]
BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"]
BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"]
BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"]
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
BatchedRoutinesLinearSolveExt = ["LinearSolve"]
Expand All @@ -42,6 +43,7 @@ Aqua = "0.8.4"
ArrayInterface = "7.8.1"
CUDA = "5.2.0"
ChainRulesCore = "1.23"
Chairmarks = "1.2"
ComponentArrays = "0.15.10"
ConcreteStructs = "0.2.3"
ExplicitImports = "1.4.0"
Expand All @@ -59,7 +61,9 @@ PrecompileTools = "1.2.0"
Random = "<0.0.1, 1"
ReTestItems = "1.23.1"
ReverseDiff = "1.15"
SciMLBase = "2.31"
SciMLOperators = "0.3.8"
SimpleNonlinearSolve = "1.7"
StableRNGs = "1.0.1"
Statistics = "1.11.1"
Test = "<0.0.1, 1"
Expand All @@ -68,6 +72,7 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand All @@ -80,10 +85,12 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Zygote"]
test = ["Aqua", "Chairmarks", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"]
24 changes: 24 additions & 0 deletions ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, AbstractBatchedNonlinearAlgorithm,
UniformBlockDiagonalOperator, batched_jacobian, batched_mul,
batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: LinearAlgebra
using LuxDeviceUtils: LuxDeviceUtils, get_device
using SciMLBase: SciMLBase, NonlinearProblem

const CRC = ChainRulesCore

@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true

@inline BatchedRoutines.__can_forwarddiff_dual(::Type{T}) where {T} = ForwardDiff.can_dual(T)

include("jacobian.jl")
include("nonlinearsolve_ad.jl")

end
Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian,
batched_mul, batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using LuxDeviceUtils: LuxDeviceUtils, get_device

const CRC = ChainRulesCore

@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true

# api.jl
function BatchedRoutines.batched_pickchunksize(
X::AbstractArray{T, N}, n::Int=ForwardDiff.DEFAULT_CHUNK_THRESHOLD) where {T, N}
Expand Down Expand Up @@ -241,6 +226,4 @@ end
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
end

end
end
44 changes: 44 additions & 0 deletions ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
function SciMLBase.solve(
prob::NonlinearProblem{<:AbstractArray, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractBatchedNonlinearAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function __nlsolve_ad(prob::NonlinearProblem, alg, args...; kwargs...)
p = ForwardDiff.value.(prob.p)
u0 = ForwardDiff.value.(prob.u0)
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)

sol = SciMLBase.solve(newprob, alg, args...; kwargs...)

uu = sol.u
Jₚ = ForwardDiff.jacobian(Base.Fix1(prob.f, uu), p)
Jᵤ = if prob.f.jac === nothing
BatchedRoutines.batched_jacobian(AutoForwardDiff(), prob.f, uu, p)
else
BatchedRoutines._wrap_batched_operator(prob.f.jac(uu, p))

Check warning on line 25 in ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl

View check run for this annotation

Codecov / codecov/patch

ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl#L25

Added line #L25 was not covered by tests
end

Jᵤ_fact = LinearAlgebra.lu!(Jᵤ)

map_fn = @closure zp -> begin
Jₚᵢ, p = zp
LinearAlgebra.ldiv!(Jᵤ_fact, Jₚᵢ)
Jₚᵢ .*= -1
return map(Base.Fix2(*, ForwardDiff.partials(p)), Jₚᵢ)
end

return sol, sum(map_fn, zip(eachcol(Jₚ), prob.p))
end

@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
_partials = reshape(partials, size(u))
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
end
5 changes: 3 additions & 2 deletions ext/BatchedRoutinesLinearSolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata
using ChainRulesCore: ChainRulesCore, NoTangent
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra
using LinearSolve: LinearSolve, SciMLBase
using LinearSolve: LinearSolve
using SciMLBase: SciMLBase

const CRC = ChainRulesCore

Expand Down Expand Up @@ -113,7 +114,7 @@ function LinearSolve.solve!(cache::LinearSolve.LinearCache{<:UniformBlockDiagona
y = LinearAlgebra.ldiv!(
cache.u, LinearSolve.@get_cacheval(cache, :NormalCholeskyFactorization),
A' * cache.b)
return LinearSolve.SciMLBase.build_linear_solution(alg, y, nothing, cache)
return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

# SVDFactorization
Expand Down
18 changes: 15 additions & 3 deletions src/BatchedRoutines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module BatchedRoutines
import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoSparseForwardDiff,
AutoSparsePolyesterForwardDiff, AutoPolyesterForwardDiff, AutoZygote
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoReverseDiff, AutoSparseForwardDiff, AutoSparsePolyesterForwardDiff,
AutoPolyesterForwardDiff, AutoZygote
using Adapt: Adapt
using ArrayInterface: ArrayInterface, parameterless_type
using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig
Expand All @@ -14,6 +15,7 @@ import PrecompileTools: @recompile_invalidations
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
mul!, pinv
using LuxDeviceUtils: LuxDeviceUtils, get_device
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem, ReturnCode
using SciMLOperators: SciMLOperators, AbstractSciMLOperator
end

Expand All @@ -40,23 +42,33 @@ const AutoAllForwardDiff{CK} = Union{<:AutoForwardDiff{CK}, <:AutoSparseForwardD
const BatchedVector{T} = AbstractMatrix{T}
const BatchedMatrix{T} = AbstractArray{T, 3}

abstract type AbstractBatchedNonlinearAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end

@inline _is_extension_loaded(::Val) = false

include("operator.jl")

include("api.jl")
include("helpers.jl")

include("operator.jl")
include("factorization.jl")

include("nlsolve/utils.jl")
include("nlsolve/batched_raphson.jl")

include("impl/batched_mul.jl")
include("impl/batched_gmres.jl")

include("chainrules.jl")

# Core
export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote
export batched_adjoint, batched_gradient, batched_jacobian, batched_pickchunksize,
batched_mul, batched_pinv, batched_transpose
export batchview, nbatches
export UniformBlockDiagonalOperator

# Nonlinear Solvers
export BatchedSimpleNewtonRaphson, BatchedSimpleGaussNewton

end
3 changes: 3 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,6 @@ end
CRC.@non_differentiable fill_like(::Any...)
CRC.@non_differentiable zeros_like(::Any...)
CRC.@non_differentiable ones_like(::Any...)

@inline _wrap_batched_operator(x::AbstractArray{T, 3}) where {T} = UniformBlockDiagonalOperator(x)
@inline _wrap_batched_operator(x::UniformBlockDiagonalOperator) = x

Check warning on line 172 in src/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/helpers.jl#L171-L172

Added lines #L171 - L172 were not covered by tests
33 changes: 33 additions & 0 deletions src/nlsolve/batched_raphson.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
@kwdef @concrete struct BatchedSimpleNewtonRaphson <: AbstractBatchedNonlinearAlgorithm
autodiff = nothing
end

const BatchedSimpleGaussNewton = BatchedSimpleNewtonRaphson

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::BatchedSimpleNewtonRaphson, args...;
abstol=nothing, maxiters::Int=1000, kwargs...)
@assert !SciMLBase.isinplace(prob) "BatchedSimpleNewtonRaphson does not support inplace."

x = deepcopy(prob.u0)
fx = prob.f(x, prob.p)
@assert (ndims(x) == 2)&&(ndims(fx) == 2) "BatchedSimpleNewtonRaphson only supports matrices."

autodiff = __get_concrete_autodiff(prob, alg.autodiff)
abstol = __get_tolerance(abstol, x)

maximum(abs, fx) < abstol &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success)

for _ in 1:maxiters
fx, J = __value_and_jacobian(prob, x, autodiff)
δx = J \ fx

maximum(abs, fx) < abstol &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.Success)

@. x -= δx
end

return SciMLBase.build_solution(prob, alg, x, fx; retcode=ReturnCode.MaxIters)

Check warning on line 32 in src/nlsolve/batched_raphson.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/batched_raphson.jl#L32

Added line #L32 was not covered by tests
end
26 changes: 26 additions & 0 deletions src/nlsolve/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
@inline __get_concrete_autodiff(prob, ad::AbstractADType; kwargs...) = ad

Check warning on line 1 in src/nlsolve/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/utils.jl#L1

Added line #L1 was not covered by tests
@inline function __get_concrete_autodiff(prob, ::Nothing; kwargs...)
prob.f.jac !== nothing && return nothing
if _is_extension_loaded(Val(:ForwardDiff)) && __can_forwarddiff_dual(eltype(prob.u0))
return AutoForwardDiff()
elseif _is_extension_loaded(Val(:FiniteDiff))
return AutoFiniteDiff()

Check warning on line 7 in src/nlsolve/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/utils.jl#L6-L7

Added lines #L6 - L7 were not covered by tests
else
error("No AD backend loaded. Please load an AD backend first.")

Check warning on line 9 in src/nlsolve/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/utils.jl#L9

Added line #L9 was not covered by tests
end
end

function __can_forwarddiff_dual end

@inline function __value_and_jacobian(prob, x, autodiff)
if prob.f.jac === nothing
return prob.f(x, prob.p), batched_jacobian(autodiff, prob.f, x, prob.p)
else
return prob.f(x, prob.p), _wrap_batched_operator(prob.f.jac(x, prob.p))

Check warning on line 19 in src/nlsolve/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/utils.jl#L19

Added line #L19 was not covered by tests
end
end

@inline __get_tolerance(abstol, u0) = __get_tolerance(abstol, eltype(u0))
@inline function __get_tolerance(abstol, ::Type{T}) where {T}
return abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
end
46 changes: 46 additions & 0 deletions test/nlsolve_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
@testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin
using Chairmarks, ForwardDiff, SciMLBase, SimpleNonlinearSolve, Statistics, Zygote

testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p

u0 = rand(3, 128)
p = rand(1, 128)

prob = NonlinearProblem(testing_f, u0, p)

sol_nlsolve = solve(prob, SimpleNewtonRaphson())
sol_batched = solve(prob, BatchedSimpleNewtonRaphson())

@test abs.(sol_nlsolve.u) abs.(sol_batched.u)

nlsolve_timing = @be solve($prob, $SimpleNewtonRaphson())
batched_timing = @be solve($prob, $BatchedSimpleNewtonRaphson())

@info "SimpleNonlinearSolve Timing: $(median(nlsolve_timing))."
@info "BatchedSimpleNewtonRaphson Timing: $(median(batched_timing))."

∂p1 = ForwardDiff.gradient(p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u)
end

∂p2 = ForwardDiff.gradient(p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u)
end

@test ∂p1 ∂p2

fwdiff_nlsolve_timing = @be ForwardDiff.gradient($p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u)
end

fwdiff_batched_timing = @be ForwardDiff.gradient($p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u)
end

@info "ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing))."
@info "ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing))."
end

0 comments on commit 50e683e

Please sign in to comment.