Skip to content

Commit

Permalink
Add JET test for MIRKN methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Jan 19, 2025
1 parent cb24722 commit e5d4028
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 85 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "1"
ADTypes = "1.9"
Aqua = "0.8.9"
ArrayInterface = "7.18"
BoundaryValueDiffEqAscher = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRKN/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqMIRKN"
uuid = "9255f1d6-53bf-473e-b6bd-23f1ff009da4"
authors = ["Qingyu Qu <erikqqy123@gmail.com>"]
version = "1.2.0"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, AbstractBVProblem,
using Setfield: @set!, @set
using SparseArrays: sparse
using SparseDiffTools: init_jacobian, sparse_jacobian, sparse_jacobian_cache,
sparse_jacobian!, matrix_colors, SymbolicsSparsityDetection
sparse_jacobian!, matrix_colors, SymbolicsSparsityDetection,
NoSparsityDetection

@reexport using ADTypes, BoundaryValueDiffEqCore, SciMLBase

Expand Down
40 changes: 27 additions & 13 deletions lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,34 @@ function SciMLBase.__init(prob::SecondOrderBVProblem, alg::AbstractMIRKN;

return MIRKNCache{iip, T}(
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type,
prob.p, alg, TU, bcresid_prototype, mesh, mesh_dt, k_discrete,
y, y₀, residual, fᵢ_cache, fᵢ₂_cache, resid_size, kwargs)
prob.p, alg, TU, bcresid_prototype, mesh, mesh_dt, k_discrete, y,
y₀, residual, fᵢ_cache, fᵢ₂_cache, resid_size, (; dt, kwargs...))
end

function __split_mirkn_kwargs(; dt, kwargs...)
return ((dt), (; kwargs...))
end

function SciMLBase.solve!(cache::MIRKNCache{iip, T}) where {iip, T}
(; mesh, M, p, prob, kwargs) = cache
nlprob = __construct_nlproblem(cache, vec(cache.y₀))
(_), kwargs = __split_mirkn_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success

sol_nlprob, info = __perform_mirkn_iteration(cache; kwargs...)

solu = ArrayPartition.(
cache.y₀.u[1:length(cache.mesh)], cache.y₀.u[(length(cache.mesh) + 1):end])
odesol = SciMLBase.build_solution(
cache.prob, cache.alg, cache.mesh, solu; retcode = info)
return __build_solution(cache.prob, odesol, sol_nlprob)
end

function __perform_mirkn_iteration(cache::MIRKNCache; nlsolve_kwargs = (;), kwargs...)
nlprob::NonlinearProblem = __construct_nlproblem(cache, vec(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)
sol_nlprob = __solve(nlprob, nlsolve_alg; kwargs..., nlsolve_kwargs..., alias_u0 = true)
recursive_unflatten!(cache.y₀, sol_nlprob.u)
solu = ArrayPartition.(cache.y₀.u[1:length(mesh)], cache.y₀.u[(length(mesh) + 1):end])
return SciMLBase.build_solution(
prob, cache.alg, mesh, solu; retcode = sol_nlprob.retcode)

return sol_nlprob, sol_nlprob.retcode
end

function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where {iip}
Expand All @@ -115,19 +130,18 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where
sd = alg.jac_alg.diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
NoSparsityDetection()
ad = alg.jac_alg.diffmode
lz = reduce(vcat, cache.y₀)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, lz)
lz = similar(y)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, y)
jac_prototype = init_jacobian(jac_cache)
jac = if iip
@closure (J, u, p) -> __mirkn_mpoint_jacobian!(J, u, ad, jac_cache, lossₚ, lz)
else
@closure (u, p) -> __mirkn_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
end
resid_prototype = zero(lz)
_nlf = NonlinearFunction{iip}(
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
return nlprob
return __internal_nlsolve_problem(cache.prob, resid_prototype, lz, nlf, lz, cache.p)
end

function __mirkn_2point_jacobian!(J, x, diffmode, diffcache, loss_fn::L, resid) where {L}
Expand Down
158 changes: 89 additions & 69 deletions lib/BoundaryValueDiffEqMIRKN/test/mirkn_basic_tests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,92 @@
@testsetup module MIRKNConvergenceTests

using BoundaryValueDiffEqMIRKN

for order in (4, 6)
s = Symbol("MIRKN$(order)")
@eval mirkn_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...)
end

function f!(ddu, du, u, p, t)
ddu[1] = u[1]
end
function f(du, u, p, t)
return u[1]
end
function bc!(res, du, u, p, t)
res[1] = u(0.0)[1] - 1
res[2] = u(1.0)[1]
end
function bc(du, u, p, t)
return [u(0.0)[1] - 1, u(1.0)[1]]
end
function bc_indexing!(res, du, u, p, t)
res[1] = u[:, 1][1] - 1
res[2] = u[:, end][1]
end
function bc_indexing(du, u, p, t)
return [u[:, 1][1] - 1, u[:, end][1]]
end
function bc_a!(res, du, u, p)
res[1] = u[1] - 1
end
function bc_b!(res, du, u, p)
res[1] = u[1]
end
function bc_a(du, u, p)
return [u[1] - 1]
end
function bc_b(du, u, p)
return [u[1]]
end
analytical_solution = (u0, p, t) -> [
(exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))]
u0 = [1.0]
tspan = (0.0, 1.0)
testTol = 0.2
bvpf1 = DynamicalBVPFunction(f!, bc!, analytic = analytical_solution)
bvpf2 = DynamicalBVPFunction(f, bc, analytic = analytical_solution)
bvpf3 = DynamicalBVPFunction(f!, bc_indexing!, analytic = analytical_solution)
bvpf4 = DynamicalBVPFunction(f, bc_indexing, analytic = analytical_solution)
bvpf5 = DynamicalBVPFunction(f!, (bc_a!, bc_b!), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
bvpf6 = DynamicalBVPFunction(f, (bc_a, bc_b), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
probArr = [SecondOrderBVProblem(bvpf1, u0, tspan), SecondOrderBVProblem(bvpf2, u0, tspan),
SecondOrderBVProblem(bvpf3, u0, tspan), SecondOrderBVProblem(bvpf4, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf5, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf6, u0, tspan)]
dts = 1 .// 2 .^ (3:-1:1)

end

@testitem "Convergence on Linear" setup=[MIRKNConvergenceTests] begin
using LinearAlgebra, DiffEqDevTools

@testset "Problem: $i" for i in (1, 2, 3, 4, 5, 6)
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
sim = test_convergence(
dts, prob, mirkn_solver(Val(order)); abstol = 1e-8, reltol = 1e-8)
@test sim.𝒪est[:final]order atol=testTol
end
end
end

@testitem "JET tests" setup=[MIRKNConvergenceTests] begin
using JET

@testset "Problem: $i" for i in 1:6
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
solver = mirkn_solver(Val(order); nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))
@test_call target_modules=(BoundaryValueDiffEqMIRKN,) solve(
prob, solver; dt = 0.2)
end
end
end

@testitem "Example problem from paper" begin
using BoundaryValueDiffEqMIRKN

Expand Down Expand Up @@ -81,72 +170,3 @@
end
end
end

@testitem "Convergence on Linear" begin
using LinearAlgebra, DiffEqDevTools

for order in (4, 6)
s = Symbol("MIRKN$(order)")
@eval mirkn_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...)
end

function f!(ddu, du, u, p, t)
ddu[1] = u[1]
end
function f(du, u, p, t)
return u[1]
end
function bc!(res, du, u, p, t)
res[1] = u(0.0)[1] - 1
res[2] = u(1.0)[1]
end
function bc(du, u, p, t)
return [u(0.0)[1] - 1, u(1.0)[1]]
end
function bc_indexing!(res, du, u, p, t)
res[1] = u[:, 1][1] - 1
res[2] = u[:, end][1]
end
function bc_indexing(du, u, p, t)
return [u[:, 1][1] - 1, u[:, end][1]]
end
function bc_a!(res, du, u, p)
res[1] = u[1] - 1
end
function bc_b!(res, du, u, p)
res[1] = u[1]
end
function bc_a(du, u, p)
return [u[1] - 1]
end
function bc_b(du, u, p)
return [u[1]]
end
analytical_solution = (u0, p, t) -> [
(exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))]
u0 = [1.0]
tspan = (0.0, 1.0)
testTol = 0.2
bvpf1 = DynamicalBVPFunction(f!, bc!, analytic = analytical_solution)
bvpf2 = DynamicalBVPFunction(f, bc, analytic = analytical_solution)
bvpf3 = DynamicalBVPFunction(f!, bc_indexing!, analytic = analytical_solution)
bvpf4 = DynamicalBVPFunction(f, bc_indexing, analytic = analytical_solution)
bvpf5 = DynamicalBVPFunction(f!, (bc_a!, bc_b!), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
bvpf6 = DynamicalBVPFunction(f, (bc_a, bc_b), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
probArr = [
SecondOrderBVProblem(bvpf1, u0, tspan), SecondOrderBVProblem(bvpf2, u0, tspan),
SecondOrderBVProblem(bvpf3, u0, tspan), SecondOrderBVProblem(bvpf4, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf5, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf6, u0, tspan)]
dts = 1 .// 2 .^ (3:-1:1)
@testset "Problem: $i" for i in (1, 2, 3, 4, 5, 6)
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
sim = test_convergence(
dts, prob, mirkn_solver(Val(order)); abstol = 1e-8, reltol = 1e-8)
@test sim.𝒪est[:final]order atol=testTol
end
end
end

0 comments on commit e5d4028

Please sign in to comment.