From 9d317392406d8239b4bae87091cb8016f59025d0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 21 May 2024 17:25:52 +0530 Subject: [PATCH 1/8] refactor: move type-pirated function from BoundaryValueDiffEq here --- src/solutions/ode_solutions.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 46abb98c1..f8664bdff 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -425,6 +425,14 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N sol.original) end +function solution_new_original_retcode( + sol::ODESolution{T, N}, original, retcode, resid) where {T, N} + return ODESolution{T, N}( + sol.u, sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg, + sol.interp, sol.dense, sol.tslocation, sol.stats, sol.alg_choice, retcode, resid, + original) +end + function solution_slice(sol::ODESolution{T, N}, I) where {T, N} ODESolution{T, N}(sol.u[I], sol.u_analytic === nothing ? nothing : sol.u_analytic[I], From 6f62b15b5fb10d12f3bd1a14ea1bfda1604e9ebd Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 22 May 2024 16:36:56 +0530 Subject: [PATCH 2/8] refactor: use Accessors.jl when modifying ODESolution fields --- Project.toml | 2 + src/SciMLBase.jl | 1 + src/solutions/ode_solutions.jl | 89 ++++++++-------------------------- 3 files changed, 22 insertions(+), 70 deletions(-) diff --git a/Project.toml b/Project.toml index 98d008501..d771170dd 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "2.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote" [compat] ADTypes = "0.2.5,1.0.0" +Accessors = "0.1.36" ArrayInterface = "7.6" ChainRules = "1.58.0" ChainRulesCore = "1.18" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index dad4ff058..595d0be35 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,6 +22,7 @@ import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX import ADTypes: AbstractADType +import Accessors: @set, @reset using Reexport using SciMLOperators diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index f8664bdff..7c22f9b4a 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -125,6 +125,10 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S, original::O end +function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}} + ODESolution{T, N} +end + Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol) if s === :destats Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats") @@ -372,83 +376,31 @@ function calculate_solution_errors!(sol::AbstractODESolution; fill_uanalytic = t end function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N} - ODESolution{T, N}(sol.u, - u_analytic, - errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + @reset sol.u_analytic = u_analytic + return @set sol.errors = errors end function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} - ODESolution{T, N}(sol.u, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.alg_choice, - retcode, - sol.resid, - sol.original) + return @set sol.retcode = retcode end function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N} - ODESolution{T, N}(sol.u, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + return @set sol.tslocation = tslocation end function solution_new_original_retcode( sol::ODESolution{T, N}, original, retcode, resid) where {T, N} - return ODESolution{T, N}( - sol.u, sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg, - sol.interp, sol.dense, sol.tslocation, sol.stats, sol.alg_choice, retcode, resid, - original) + @reset sol.original = original + @reset sol.retcode = retcode + return @set sol.resid = resid end function solution_slice(sol::ODESolution{T, N}, I) where {T, N} - ODESolution{T, N}(sol.u[I], - sol.u_analytic === nothing ? nothing : sol.u_analytic[I], - sol.errors, - sol.t[I], - sol.dense ? sol.k[I] : sol.k, - sol.prob, - sol.alg, - sol.interp, - false, - sol.tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + @reset sol.u = sol.u[I] + @reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I] + @reset sol.t = sol.t[I] + @reset sol.k = sol.dense ? sol.k[I] : sol.k + return @set sol.alg = false end function sensitivity_solution(sol::ODESolution, u, t) @@ -463,10 +415,7 @@ function sensitivity_solution(sol::ODESolution, u, t) end interp = enable_interpolation_sensitivitymode(sol.interp) - ODESolution{T, N}(u, sol.u_analytic, sol.errors, - t isa Vector ? t : collect(t), - sol.k, sol.prob, - sol.alg, interp, - sol.dense, sol.tslocation, - sol.stats, sol.alg_choice, sol.retcode, sol.resid, sol.original) + @reset sol.u = u + @reset sol.t = t isa Vector ? t : collect(t) + return @set sol.interp = interp end From da4b58a2ae006f96392def98ff4fa008a5fe4f36 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 26 May 2024 21:25:26 +0530 Subject: [PATCH 3/8] refactor: format --- src/integrator_interface.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 9ca7c5ddf..c995a87d5 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -622,7 +622,8 @@ function check_error(integrator::DEIntegrator) @warn("dt($(integrator.dt)) <= dtmin($(opts.dtmin)) at t=$(integrator.t)$EEst. Aborting. There is either an error in your model specification or the true solution is unstable.") end return ReturnCode.DtLessThanMin - elseif !step_accepted && integrator.t isa AbstractFloat && abs(integrator.dt) <= abs(eps(integrator.t)) + elseif !step_accepted && integrator.t isa AbstractFloat && + abs(integrator.dt) <= abs(eps(integrator.t)) if verbose if isdefined(integrator, :EEst) EEst = ", and step error estimate = $(integrator.EEst)" @@ -634,7 +635,8 @@ function check_error(integrator::DEIntegrator) return ReturnCode.Unstable end end - if step_accepted && opts.unstable_check(integrator.dt, integrator.u, integrator.p, integrator.t) + if step_accepted && + opts.unstable_check(integrator.dt, integrator.u, integrator.p, integrator.t) if verbose @warn("Instability detected. Aborting") end From b5051d46f791d11be07cea034afe568202b7a5f9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 26 May 2024 22:08:21 +0530 Subject: [PATCH 4/8] ci: upper-bound SymbolicUtils compat to <1.6 --- test/downstream/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 11ce0ff96..7d599049a 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -15,6 +15,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -35,5 +36,6 @@ SciMLSensitivity = "7.11" SciMLStructures = "1.1" Sundials = "4.11" SymbolicIndexingInterface = "0.3" +SymbolicUtils = "<1.6" Unitful = "1.12" Zygote = "0.6" From 0c02ba81d30bd8adb742165814031ee7c873bca0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 29 May 2024 12:25:48 +0530 Subject: [PATCH 5/8] test: do not import SciMLBase in `runtests.jl` This avoids loading `ADTypes.jl` before downstream tests are precompiled --- test/aqua.jl | 3 +++ test/runtests.jl | 4 ---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/aqua.jl b/test/aqua.jl index 6dbee8706..b403790aa 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -2,6 +2,9 @@ using Test using SciMLBase using Aqua +# https://github.com/JuliaArrays/FillArrays.jl/pull/163 +@test_broken isempty(detect_ambiguities(SciMLBase)) + @testset "Aqua tests (performance)" begin # This tests that we don't accidentally run into # https://github.com/JuliaLang/julia/issues/29393 diff --git a/test/runtests.jl b/test/runtests.jl index f5b3c2d22..0172f4a66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,6 @@ using Pkg using SafeTestsets using Test -using SciMLBase - -# https://github.com/JuliaArrays/FillArrays.jl/pull/163 -@test_broken isempty(detect_ambiguities(SciMLBase)) const GROUP = get(ENV, "GROUP", "All") const is_APPVEYOR = (Sys.iswindows() && haskey(ENV, "APPVEYOR")) From 0c72de24f730ab139b9cd130f46190e2aa07f1e2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Jun 2024 13:50:00 +0530 Subject: [PATCH 6/8] test: don't check inference for symbolic indexing with nested arrays/tuples --- test/downstream/symbol_indexing.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index ca4823b55..4952d729c 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -235,7 +235,7 @@ x_val = vcat.(getindex.((sol,), x_idx, :)...) y_val = sol[y_idx, :] obs_val = sol[x[1] + y] -# checking inference for mixed-type arrays will always fail +# don't check inference for weird cases of nested arrays/tuples for (sym, val, check_inference) in [ (x, x_val, true), (y, y_val, true), @@ -254,7 +254,7 @@ for (sym, val, check_inference) in [ ((x, x), [(i, i) for i in x_val], true), ((x, x_idx), [(i, i) for i in x_val], true), ((x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), - ((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true), + ((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), ([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false), ((x, [x[1] + y, y], (x[1] + y, y_idx)), [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), @@ -311,6 +311,7 @@ end pval = [1.0, 2.0, 3.0] pval_new = [4.0, 5.0, 6.0] +# don't check inference for nested tuples/arrays for (sym, oldval, newval, check_inference) in [ (p[1], pval[1], pval_new[1], true), (p, pval, pval_new, true), @@ -319,7 +320,7 @@ for (sym, oldval, newval, check_inference) in [ ((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true), ([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false), ((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), - (pval_new[1], (pval_new[2],), [pval_new[3]]), true), + (pval_new[1], (pval_new[2],), [pval_new[3]]), false), ([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], [pval_new[1], (pval_new[2],), [pval_new[3]]], false) ] From 55d81af15d274d094a4c50f3e6ae82e6321d5d13 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 3 Jun 2024 17:00:33 +0530 Subject: [PATCH 7/8] fix: fix incorrect dimensionality of `ODESolution` in `build_function` and `@set` --- ext/SciMLBaseZygoteExt.jl | 10 +--------- src/solutions/ode_solutions.jl | 12 +++++++++++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 2f7399f11..58e7bb309 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -58,15 +58,7 @@ end du, dprob end T = eltype(eltype(VA.u)) - if dprob.u0 === nothing - N = 2 - elseif dprob isa SciMLBase.BVProblem && !hasmethod(size, Tuple{typeof(dprob.u0)}) - __u0 = hasmethod(dprob.u0, Tuple{typeof(dprob.p), typeof(first(dprob.tspan))}) ? - dprob.u0(dprob.p, first(dprob.tspan)) : dprob.u0(first(dprob.tspan)) - N = length((size(__u0)..., length(du))) - else - N = length((size(dprob.u0)..., length(du))) - end + N = ndims(VA) Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 7c22f9b4a..5636c4993 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -129,6 +129,16 @@ function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution ODESolution{T, N} end +function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple) + u = get(patch, :u, sol.u) + N = u === nothing ? 2 : ndims(eltype(u)) + 1 + T = eltype(eltype(u)) + patch = merge(getproperties(sol), patch) + return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k, + patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats, + patch.alg_choice, patch.retcode, patch.resid, patch.original) +end + Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol) if s === :destats Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats") @@ -276,7 +286,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, prob.u0(prob.p, first(prob.tspan)) : prob.u0(first(prob.tspan)) N = length((size(__u0)..., length(u))) else - N = length((size(prob.u0)..., length(u))) + N = ndims(eltype(u)) + 1 end if prob.f isa Tuple From 17a165c654dd9bc48ccc1e4514c38daec8912bab Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 4 Jun 2024 14:05:12 +0530 Subject: [PATCH 8/8] build: bump RecursiveArrayTools compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d771170dd..0d874b63b 100644 --- a/Project.toml +++ b/Project.toml @@ -78,7 +78,7 @@ PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" RecipesBase = "1.3.4" -RecursiveArrayTools = "3.14.0" +RecursiveArrayTools = "3.22.0" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7"