Skip to content

Commit

Permalink
Merge pull request #271 from ErikQQY/qqy/mirkn_test
Browse files Browse the repository at this point in the history
Add JET test for MIRKN methods
  • Loading branch information
ErikQQY authored Jan 20, 2025
2 parents cb24722 + 22664fc commit 8045331
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 93 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"

[compat]
ADTypes = "1"
ADTypes = "1.11"
Aqua = "0.8.9"
ArrayInterface = "7.18"
BoundaryValueDiffEqAscher = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqAscher/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
AlmostBlockDiagonals = "0.1.10"
ArrayInterface = "7.18"
Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqCore"
uuid = "56b672f2-a5fe-4263-ab2d-da677488eb3a"
authors = ["Qingyu Qu <erikqqy123@gmail.com>"]
version = "1.5.0"
version = "1.6.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -24,7 +24,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
ArrayInterface = "7.18"
Aqua = "0.8.9"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ end
end

# Construct BVP Solution
function __build_solution(prob::BVProblem, odesol, nlsol)
function __build_solution(prob::AbstractBVProblem, odesol, nlsol)
retcode = ifelse(SciMLBase.successful_retcode(nlsol), odesol.retcode, nlsol.retcode)
return SciMLBase.solution_new_original_retcode(odesol, nlsol, retcode, nlsol.resid)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqFIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRK/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqMIRKN/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BoundaryValueDiffEqMIRKN"
uuid = "9255f1d6-53bf-473e-b6bd-23f1ff009da4"
authors = ["Qingyu Qu <erikqqy123@gmail.com>"]
version = "1.2.0"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -29,7 +29,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqMIRKN/src/BoundaryValueDiffEqMIRKN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ using SciMLBase: SciMLBase, AbstractDiffEqInterpolation, AbstractBVProblem,
using Setfield: @set!, @set
using SparseArrays: sparse
using SparseDiffTools: init_jacobian, sparse_jacobian, sparse_jacobian_cache,
sparse_jacobian!, matrix_colors, SymbolicsSparsityDetection
sparse_jacobian!, matrix_colors, SymbolicsSparsityDetection,
NoSparsityDetection

@reexport using ADTypes, BoundaryValueDiffEqCore, SciMLBase

Expand Down
40 changes: 27 additions & 13 deletions lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,34 @@ function SciMLBase.__init(prob::SecondOrderBVProblem, alg::AbstractMIRKN;

return MIRKNCache{iip, T}(
alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type,
prob.p, alg, TU, bcresid_prototype, mesh, mesh_dt, k_discrete,
y, y₀, residual, fᵢ_cache, fᵢ₂_cache, resid_size, kwargs)
prob.p, alg, TU, bcresid_prototype, mesh, mesh_dt, k_discrete, y,
y₀, residual, fᵢ_cache, fᵢ₂_cache, resid_size, (; dt, kwargs...))
end

function __split_mirkn_kwargs(; dt, kwargs...)
return ((dt), (; kwargs...))
end

function SciMLBase.solve!(cache::MIRKNCache{iip, T}) where {iip, T}
(; mesh, M, p, prob, kwargs) = cache
nlprob = __construct_nlproblem(cache, vec(cache.y₀))
(_), kwargs = __split_mirkn_kwargs(; cache.kwargs...)
info::ReturnCode.T = ReturnCode.Success

sol_nlprob, info = __perform_mirkn_iteration(cache; kwargs...)

solu = ArrayPartition.(
cache.y₀.u[1:length(cache.mesh)], cache.y₀.u[(length(cache.mesh) + 1):end])
odesol = SciMLBase.build_solution(
cache.prob, cache.alg, cache.mesh, solu; retcode = info)
return __build_solution(cache.prob, odesol, sol_nlprob)
end

function __perform_mirkn_iteration(cache::MIRKNCache; nlsolve_kwargs = (;), kwargs...)
nlprob::NonlinearProblem = __construct_nlproblem(cache, vec(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(nlprob, nlsolve_alg; kwargs..., alias_u0 = true)
sol_nlprob = __solve(nlprob, nlsolve_alg; kwargs..., nlsolve_kwargs..., alias_u0 = true)
recursive_unflatten!(cache.y₀, sol_nlprob.u)
solu = ArrayPartition.(cache.y₀.u[1:length(mesh)], cache.y₀.u[(length(mesh) + 1):end])
return SciMLBase.build_solution(
prob, cache.alg, mesh, solu; retcode = sol_nlprob.retcode)

return sol_nlprob, sol_nlprob.retcode
end

function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where {iip}
Expand All @@ -115,19 +130,18 @@ function __construct_nlproblem(cache::MIRKNCache{iip}, y::AbstractVector) where
sd = alg.jac_alg.diffmode isa AutoSparse ? SymbolicsSparsityDetection() :
NoSparsityDetection()
ad = alg.jac_alg.diffmode
lz = reduce(vcat, cache.y₀)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, lz)
lz = __similar(y)
jac_cache = __sparse_jacobian_cache(Val(iip), ad, sd, lossₚ, lz, y)
jac_prototype = init_jacobian(jac_cache)
jac = if iip
@closure (J, u, p) -> __mirkn_mpoint_jacobian!(J, u, ad, jac_cache, lossₚ, lz)
else
@closure (u, p) -> __mirkn_mpoint_jacobian(jac_prototype, u, ad, jac_cache, lossₚ)
end
resid_prototype = zero(lz)
_nlf = NonlinearFunction{iip}(
nlf = NonlinearFunction{iip}(
loss; jac = jac, resid_prototype = resid_prototype, jac_prototype = jac_prototype)
nlprob::NonlinearProblem = NonlinearProblem(_nlf, lz, cache.p)
return nlprob
return __internal_nlsolve_problem(cache.prob, resid_prototype, lz, nlf, lz, cache.p)
end

function __mirkn_2point_jacobian!(J, x, diffmode, diffcache, loss_fn::L, resid) where {L}
Expand Down
160 changes: 91 additions & 69 deletions lib/BoundaryValueDiffEqMIRKN/test/mirkn_basic_tests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,94 @@
@testsetup module MIRKNConvergenceTests

using BoundaryValueDiffEqMIRKN

for order in (4, 6)
s = Symbol("MIRKN$(order)")
@eval mirkn_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...)
end

function f!(ddu, du, u, p, t)
ddu[1] = u[1]
end
function f(du, u, p, t)
return u[1]
end
function bc!(res, du, u, p, t)
res[1] = u(0.0)[1] - 1
res[2] = u(1.0)[1]
end
function bc(du, u, p, t)
return [u(0.0)[1] - 1, u(1.0)[1]]
end
function bc_indexing!(res, du, u, p, t)
res[1] = u[:, 1][1] - 1
res[2] = u[:, end][1]
end
function bc_indexing(du, u, p, t)
return [u[:, 1][1] - 1, u[:, end][1]]
end
function bc_a!(res, du, u, p)
res[1] = u[1] - 1
end
function bc_b!(res, du, u, p)
res[1] = u[1]
end
function bc_a(du, u, p)
return [u[1] - 1]
end
function bc_b(du, u, p)
return [u[1]]
end
analytical_solution = (u0, p, t) -> [
(exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))]
u0 = [1.0]
tspan = (0.0, 1.0)
testTol = 0.2
bvpf1 = DynamicalBVPFunction(f!, bc!, analytic = analytical_solution)
bvpf2 = DynamicalBVPFunction(f, bc, analytic = analytical_solution)
bvpf3 = DynamicalBVPFunction(f!, bc_indexing!, analytic = analytical_solution)
bvpf4 = DynamicalBVPFunction(f, bc_indexing, analytic = analytical_solution)
bvpf5 = DynamicalBVPFunction(f!, (bc_a!, bc_b!), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
bvpf6 = DynamicalBVPFunction(f, (bc_a, bc_b), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
probArr = [SecondOrderBVProblem(bvpf1, u0, tspan), SecondOrderBVProblem(bvpf2, u0, tspan),
SecondOrderBVProblem(bvpf3, u0, tspan), SecondOrderBVProblem(bvpf4, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf5, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf6, u0, tspan)]
dts = 1 .// 2 .^ (3:-1:1)

export probArr, dts, testTol, mirkn_solver

end

@testitem "Convergence on Linear" setup=[MIRKNConvergenceTests] begin
using LinearAlgebra, DiffEqDevTools

@testset "Problem: $i" for i in (1, 2, 3, 4, 5, 6)
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
sim = test_convergence(
dts, prob, mirkn_solver(Val(order)); abstol = 1e-8, reltol = 1e-8)
@test sim.𝒪est[:final]order atol=testTol
end
end
end

@testitem "JET tests" setup=[MIRKNConvergenceTests] begin
using JET

@testset "Problem: $i" for i in 1:6
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
solver = mirkn_solver(Val(order); nlsolve = NewtonRaphson(),
jac_alg = BVPJacobianAlgorithm(AutoForwardDiff(; chunksize = 2)))
@test_call target_modules=(BoundaryValueDiffEqMIRKN,) solve(
prob, solver; dt = 0.2)
end
end
end

@testitem "Example problem from paper" begin
using BoundaryValueDiffEqMIRKN

Expand Down Expand Up @@ -81,72 +172,3 @@
end
end
end

@testitem "Convergence on Linear" begin
using LinearAlgebra, DiffEqDevTools

for order in (4, 6)
s = Symbol("MIRKN$(order)")
@eval mirkn_solver(::Val{$order}, args...; kwargs...) = $(s)(args...; kwargs...)
end

function f!(ddu, du, u, p, t)
ddu[1] = u[1]
end
function f(du, u, p, t)
return u[1]
end
function bc!(res, du, u, p, t)
res[1] = u(0.0)[1] - 1
res[2] = u(1.0)[1]
end
function bc(du, u, p, t)
return [u(0.0)[1] - 1, u(1.0)[1]]
end
function bc_indexing!(res, du, u, p, t)
res[1] = u[:, 1][1] - 1
res[2] = u[:, end][1]
end
function bc_indexing(du, u, p, t)
return [u[:, 1][1] - 1, u[:, end][1]]
end
function bc_a!(res, du, u, p)
res[1] = u[1] - 1
end
function bc_b!(res, du, u, p)
res[1] = u[1]
end
function bc_a(du, u, p)
return [u[1] - 1]
end
function bc_b(du, u, p)
return [u[1]]
end
analytical_solution = (u0, p, t) -> [
(exp(-t) - exp(t - 2)) / (1 - exp(-2)), (-exp(-t) - exp(t - 2)) / (1 - exp(-2))]
u0 = [1.0]
tspan = (0.0, 1.0)
testTol = 0.2
bvpf1 = DynamicalBVPFunction(f!, bc!, analytic = analytical_solution)
bvpf2 = DynamicalBVPFunction(f, bc, analytic = analytical_solution)
bvpf3 = DynamicalBVPFunction(f!, bc_indexing!, analytic = analytical_solution)
bvpf4 = DynamicalBVPFunction(f, bc_indexing, analytic = analytical_solution)
bvpf5 = DynamicalBVPFunction(f!, (bc_a!, bc_b!), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
bvpf6 = DynamicalBVPFunction(f, (bc_a, bc_b), analytic = analytical_solution,
bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
probArr = [
SecondOrderBVProblem(bvpf1, u0, tspan), SecondOrderBVProblem(bvpf2, u0, tspan),
SecondOrderBVProblem(bvpf3, u0, tspan), SecondOrderBVProblem(bvpf4, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf5, u0, tspan),
TwoPointSecondOrderBVProblem(bvpf6, u0, tspan)]
dts = 1 .// 2 .^ (3:-1:1)
@testset "Problem: $i" for i in (1, 2, 3, 4, 5, 6)
prob = probArr[i]
@testset "MIRKN$order" for order in (4, 6)
sim = test_convergence(
dts, prob, mirkn_solver(Val(order)); abstol = 1e-8, reltol = 1e-8)
@test sim.𝒪est[:final]order atol=testTol
end
end
end
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqShooting/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[compat]
ADTypes = "1.9"
ADTypes = "1.11"
Adapt = "4.1.1"
Aqua = "0.8.7"
ArrayInterface = "7.18"
Expand Down

0 comments on commit 8045331

Please sign in to comment.