Commit b6f5491 (parent 05b7a58)
Add extensions for batched adjoint for Nonlinear Systems
avik-pal committed Apr 3, 2024
Showing 8 changed files with 179 additions and 16 deletions.
9 changes: 6 additions & 3 deletions Project.toml
@@ -24,6 +24,7 @@
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
@@ -34,6 +35,7 @@
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
BatchedRoutinesLinearSolveExt = ["LinearSolve"]
BatchedRoutinesReverseDiffExt = ["ReverseDiff"]
BatchedRoutinesSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity", "Zygote"]
BatchedRoutinesZygoteExt = ["Zygote"]

[compat]
@@ -58,15 +60,15 @@
LuxCUDA = "0.3.2"
LuxDeviceUtils = "0.1.17"
LuxTestUtils = "0.1.15"
PrecompileTools = "1.2.0"
Random = "<0.0.1, 1"
Random = "1.10"
ReTestItems = "1.23.1"
ReverseDiff = "1.15"
SciMLBase = "2.31"
SciMLOperators = "0.3.8"
SimpleNonlinearSolve = "1.7"
StableRNGs = "1.0.1"
Statistics = "1.11.1"
Test = "<0.0.1, 1"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"

@@ -86,11 +88,12 @@
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Chairmarks", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"]
test = ["Aqua", "Chairmarks", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SciMLSensitivity", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"]
13 changes: 13 additions & 0 deletions ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl
@@ -0,0 +1,13 @@
module BatchedRoutinesSciMLSensitivityExt

using ADTypes: AutoForwardDiff, AutoFiniteDiff
using BatchedRoutines: BatchedRoutines, BatchedNonlinearSolution
using FastClosures: @closure
using LinearSolve: LinearSolve
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearSolution
using SciMLSensitivity: SciMLSensitivity, SteadyStateAdjoint, ZygoteVJP
using Zygote: Zygote

include("steadystateadjoint.jl")

end
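
For context, here is a minimal end-to-end sketch of what loading this extension enables. It mirrors test/nlsolve_tests.jl below and assumes `BatchedSimpleNewtonRaphson` is exported by BatchedRoutines, as those tests suggest:

using BatchedRoutines, LinearSolve, SciMLSensitivity, SimpleNonlinearSolve, Zygote

testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p  # batched residual from the tests

u0, p = rand(3, 8), rand(3, 8)  # 8 independent 3-dimensional problems

# Loading the weak dependencies activates the extension, so Zygote can
# differentiate through the batched solve via the steady-state adjoint.
∂p = only(Zygote.gradient(p) do p
    prob = NonlinearProblem(testing_f, u0, p)
    sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u)
end)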
108 changes: 108 additions & 0 deletions ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl
@@ -0,0 +1,108 @@
import SciMLSensitivity: SteadyStateAdjointProblem, SteadyStateAdjointSensitivityFunction

function SteadyStateAdjointProblem(
sol::BatchedNonlinearSolution, sensealg::SteadyStateAdjoint, alg,
dgdu::DG1=nothing, dgdp::DG2=nothing, g::G=nothing; kwargs...) where {DG1, DG2, G}
@assert sol.prob isa NonlinearProblem
(; f, p, u0) = sol.prob
f = SciMLBase.ODEFunction(f)

@assert !SciMLBase.isinplace(sol.prob) "Adjoint for Batched Problems does not support \
inplace problems."
@assert ndims(u0)==2 "u0 must be a matrix."
    @assert dgdu !== nothing "`dgdu` must be specified. Automatic differentiation is not \
        currently implemented for this part."
@assert sensealg.autojacvec isa ZygoteVJP

dgdu === nothing &&
dgdp === nothing &&
g === nothing &&
error("Either `dgdu`, `dgdp`, or `g` must be specified.")

    # Heuristic: materialize a batched Jacobian only when `f` has no analytic
    # adjoint and either no linsolve is given (dense solve for small systems)
    # or the chosen linsolve needs a concrete matrix.
    needs_jac = ifelse(SciMLBase.has_adjoint(f),
        false,
        ifelse(sensealg.linsolve === nothing, size(u0, 1) ≤ 50,
            SciMLSensitivity.__needs_concrete_A(sensealg.linsolve)))

p === SciMLBase.NullParameters() &&
error("Your model does not have parameters, and thus it is impossible to calculate \
the derivative of the solution with respect to the parameters. Your model \
must have parameters to use parameter sensitivity calculations!")

# sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp,
# f, f.colorvec, false) # Dont allocate the Jacobian yet in diffcache
# @show sense.vjp
y = sol.u

    if needs_jac
        if SciMLBase.has_jac(f)
            J = BatchedRoutines._wrap_batched_operator(f.jac(y, p, nothing))
        else
            uf = SciMLBase.UJacobianWrapper(f, nothing, p)
            if SciMLSensitivity.alg_autodiff(sensealg)
                J = BatchedRoutines.batched_jacobian(AutoForwardDiff(), uf, y)
            else
                J = BatchedRoutines.batched_jacobian(AutoFiniteDiff(), uf, y)
            end
        end
    end

if dgdp === nothing && g === nothing
dgdu_val = similar(u0, length(u0))
dgdp_val = nothing
else
dgdu_val, dgdp_val = similar(u0, length(u0)), similar(u0, length(p))
end

if dgdu !== nothing
dgdu(dgdu_val, y, p, nothing, nothing)
else
# TODO: Implement this part
error("Not implemented yet")
# if g !== nothing
# if dgdp_val !== nothing
# gradient!(vec(dgdu_val), diffcache.g[1], y, sensealg,
# diffcache.g_grad_config[1])
# else
# gradient!(vec(dgdu_val), diffcache.g, y, sensealg, diffcache.g_grad_config)
# end
# end
end

if !needs_jac # Construct an operator and use Jacobian-Free Linear Solve
error("Todo Jacobian Free Linear Solve")
# usize = size(y)
# __f = y -> vec(f(reshape(y, usize), p, nothing))
# operator = VecJac(__f, vec(y);
# autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
# linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ))
# solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ)
    else
        # Solve the adjoint linear system Jᵀ λ = ∂g/∂u with the batched Jacobian.
        linear_problem = SciMLBase.LinearProblem(J', dgdu_val)
        linsol = SciMLBase.solve(
            linear_problem, sensealg.linsolve; alias_A=true, sensealg.linsolve_kwargs...)
        λ = linsol.u
    end

    # VJP: contract λ with ∂f/∂p to get λᵀ ∂f/∂p; it is negated at the end,
    # since ∂g/∂p = -λᵀ ∂f/∂p at a solution of f(u, p) = 0.
    _, pb_f = Zygote.pullback(@closure(p -> vec(f(y, p, nothing))), p)
    ∂p = only(pb_f(λ))
    ∂p === nothing &&
        !sensealg.autojacvec.allow_nothing &&
        throw(SciMLSensitivity.ZygoteVJPNothingError())

if g !== nothing || dgdp !== nothing
error("Not implemented yet")
# compute del g/del p
# if dgdp !== nothing
# dgdp(dgdp_val, y, p, nothing, nothing)
# else
# @unpack g_grad_config = diffcache
# gradient!(dgdp_val, diffcache.g[2], p, sensealg, g_grad_config[2])
# end
# recursive_sub!(dgdp_val, vjp)
# return dgdp_val
else
SciMLSensitivity.recursive_neg!(∂p)
return ∂p
end
end
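
For orientation, the identity the routine above implements: at a solution u*(p) of f(u, p) = 0, d/dp g(u*) = -λᵀ (∂f/∂p), where λ solves Jᵀ λ = ∂g/∂u and J = ∂f/∂u. A hypothetical scalar sanity check of that identity (a standalone illustration, not part of this commit):

f(u, p) = u^3 + u - p  # scalar steady-state residual
g(u) = u^2             # loss evaluated at the solution

function newton_solve(p; iters=50)
    u = 1.0
    for _ in 1:iters
        u -= f(u, p) / (3u^2 + 1)  # Newton step; ∂f/∂u = 3u² + 1
    end
    return u
end

p = 2.0
u = newton_solve(p)       # u = 1, since 1³ + 1 - 2 = 0
J = 3u^2 + 1              # ∂f/∂u at the solution
λ = 2u / J                # adjoint solve: Jᵀ λ = ∂g/∂u = 2u
∂p_adjoint = -λ * (-1.0)  # -λᵀ ∂f/∂p, with ∂f/∂p = -1

h = 1e-6                  # finite-difference reference
∂p_fd = (g(newton_solve(p + h)) - g(u)) / h
@assert isapprox(∂p_adjoint, ∂p_fd; atol=1e-4)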
3 changes: 2 additions & 1 deletion src/BatchedRoutines.jl
@@ -15,7 +15,8 @@
import PrecompileTools: @recompile_invalidations
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
mul!, pinv
using LuxDeviceUtils: LuxDeviceUtils, get_device
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem, ReturnCode
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem,
NonlinearSolution, ReturnCode
using SciMLOperators: SciMLOperators, AbstractSciMLOperator
end

16 changes: 12 additions & 4 deletions src/api.jl
@@ -62,16 +62,24 @@
batched_mul!(C, A, B, α=true, β=false) = _batched_mul!(C, A, B, α, β)
Transpose the first two dimensions of `X`.
"""
batched_transpose(X::BatchedMatrix) = PermutedDimsArray(X, (2, 1, 3))
batched_transpose(X::AbstractMatrix) = reshape(X, 1, size(X, 1), size(X, 2))
function batched_transpose(X::BatchedMatrix; copy::Val{C}=Val(false)) where {C}
C || return PermutedDimsArray(X, (2, 1, 3))
return permutedims(X, (2, 1, 3))
end
function batched_transpose(X::AbstractMatrix; copy::Val{C}=Val(false)) where {C}
return reshape(X, 1, size(X, 1), size(X, 2))
end

"""
batched_adjoint(X::AbstractArray{T, 3}) where {T}
Adjoint the first two dimensions of `X`.
"""
batched_adjoint(X::BatchedMatrix{<:Real}) = batched_transpose(X)
batched_adjoint(X::BatchedMatrix) = mapfoldl(adjoint, _cat3, batchview(X))
batched_adjoint(X::BatchedMatrix{<:Real}; copy::Val{C}=Val(false)) where {C} = batched_transpose(
X; copy)
function batched_adjoint(X::BatchedMatrix; copy::Val{C}=Val(false)) where {C}
return mapfoldl(adjoint, _cat3, batchview(X))
end
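
A small sketch of the new `copy` keyword: the `Val` flag selects, at compile time, between the existing lazy `PermutedDimsArray` view and a materialized `permutedims` copy (a standalone illustration, not part of this commit):

X = rand(2, 3, 4)  # a batch of four 2×3 matrices

Xt_view = batched_transpose(X)                  # lazy view, no allocation
Xt_copy = batched_transpose(X; copy=Val(true))  # materialized 3×2×4 copy

@assert Xt_view == Xt_copy  # same values, different storage

The operator diff below relies on this: `Base.transpose` and `Base.adjoint` on a `UniformBlockDiagonalOperator` now request `copy=Val(true)`, so the returned operator owns its data.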

"""
nbatches(A::AbstractArray)
3 changes: 3 additions & 0 deletions src/nlsolve/utils.jl
@@ -24,3 +24,6 @@
@inline function __get_tolerance(abstol, ::Type{T}) where {T}
return abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
end

# Dispatch target for the SciMLSensitivity extension: a NonlinearSolution whose
# algorithm is an AbstractBatchedNonlinearAlgorithm.
const BatchedNonlinearSolution{T, N, uType, R, P, O, uType2, S, Tr} = NonlinearSolution{
    T, N, uType, R, P, <:AbstractBatchedNonlinearAlgorithm, O, uType2, S, Tr}
6 changes: 3 additions & 3 deletions src/operator.jl
@@ -23,8 +23,8 @@

for f in (
:batched_transpose, :batched_adjoint, :batched_inv, :batched_pinv, :batched_reshape)
@eval function $(f)(op::UniformBlockDiagonalOperator, args...)
return UniformBlockDiagonalOperator($(f)(op.data, args...))
@eval function $(f)(op::UniformBlockDiagonalOperator, args...; kwargs...)
return UniformBlockDiagonalOperator($(f)(op.data, args...; kwargs...))
end
end

@@ -60,7 +60,7 @@

for f in (:transpose, :adjoint)
batched_f = Symbol("batched_", f)
@eval (Base.$(f))(op::UniformBlockDiagonalOperator) = $(batched_f)(op)
@eval (Base.$(f))(op::UniformBlockDiagonalOperator) = $(batched_f)(op; copy=Val(true))
end

@inline function Base.size(op::UniformBlockDiagonalOperator)
37 changes: 32 additions & 5 deletions test/nlsolve_tests.jl
@@ -1,5 +1,6 @@
@testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin
using Chairmarks, ForwardDiff, SciMLBase, SimpleNonlinearSolve, Statistics, Zygote
using Chairmarks, ForwardDiff, SciMLBase, SciMLSensitivity, SimpleNonlinearSolve,
Statistics, Zygote

testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p

@@ -16,8 +17,8 @@
nlsolve_timing = @be solve($prob, $SimpleNewtonRaphson())
batched_timing = @be solve($prob, $BatchedSimpleNewtonRaphson())

@info "SimpleNonlinearSolve Timing: $(median(nlsolve_timing))."
@info "BatchedSimpleNewtonRaphson Timing: $(median(batched_timing))."
println("SimpleNonlinearSolve Timing: $(median(nlsolve_timing)).")
println("BatchedSimpleNewtonRaphson Timing: $(median(batched_timing)).")

∂p1 = ForwardDiff.gradient(p) do p
prob = NonlinearProblem(testing_f, u0, p)
@@ -41,6 +42,32 @@
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u)
end

@info "ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing))."
@info "ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing))."
println("ForwardDiff SimpleNonlinearSolve Timing: $(median(fwdiff_nlsolve_timing)).")
println("ForwardDiff BatchedNonlinearSolve Timing: $(median(fwdiff_batched_timing)).")

∂p3 = only(Zygote.gradient(p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u)
end)

∂p4 = only(Zygote.gradient(p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u)
end)

    @test ∂p3 ≈ ∂p4
    @test ∂p1 ≈ ∂p4

zygote_nlsolve_timing = @be Zygote.gradient($p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, SimpleNewtonRaphson()).u)
end

zygote_batched_timing = @be Zygote.gradient($p) do p
prob = NonlinearProblem(testing_f, u0, p)
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u)
end

println("Zygote SimpleNonlinearSolve Timing: $(median(zygote_nlsolve_timing)).")
println("Zygote BatchedNonlinearSolve Timing: $(median(zygote_batched_timing)).")
end
