diff --git a/Project.toml b/Project.toml index 98d008501..0d874b63b 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "2.39.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote" [compat] ADTypes = "0.2.5,1.0.0" +Accessors = "0.1.36" ArrayInterface = "7.6" ChainRules = "1.58.0" ChainRulesCore = "1.18" @@ -76,7 +78,7 @@ PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" RecipesBase = "1.3.4" -RecursiveArrayTools = "3.14.0" +RecursiveArrayTools = "3.22.0" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7" diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 2f7399f11..58e7bb309 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -58,15 +58,7 @@ end du, dprob end T = eltype(eltype(VA.u)) - if dprob.u0 === nothing - N = 2 - elseif dprob isa SciMLBase.BVProblem && !hasmethod(size, Tuple{typeof(dprob.u0)}) - __u0 = hasmethod(dprob.u0, Tuple{typeof(dprob.p), typeof(first(dprob.tspan))}) ? - dprob.u0(dprob.p, first(dprob.tspan)) : dprob.u0(first(dprob.tspan)) - N = length((size(__u0)..., length(du))) - else - N = length((size(dprob.u0)..., length(du))) - end + N = ndims(VA) Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index dad4ff058..595d0be35 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,6 +22,7 @@ import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX import ADTypes: AbstractADType +import Accessors: @set, @reset using Reexport using SciMLOperators diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 9ca7c5ddf..c995a87d5 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -622,7 +622,8 @@ function check_error(integrator::DEIntegrator) @warn("dt($(integrator.dt)) <= dtmin($(opts.dtmin)) at t=$(integrator.t)$EEst. Aborting. There is either an error in your model specification or the true solution is unstable.") end return ReturnCode.DtLessThanMin - elseif !step_accepted && integrator.t isa AbstractFloat && abs(integrator.dt) <= abs(eps(integrator.t)) + elseif !step_accepted && integrator.t isa AbstractFloat && + abs(integrator.dt) <= abs(eps(integrator.t)) if verbose if isdefined(integrator, :EEst) EEst = ", and step error estimate = $(integrator.EEst)" @@ -634,7 +635,8 @@ function check_error(integrator::DEIntegrator) return ReturnCode.Unstable end end - if step_accepted && opts.unstable_check(integrator.dt, integrator.u, integrator.p, integrator.t) + if step_accepted && + opts.unstable_check(integrator.dt, integrator.u, integrator.p, integrator.t) if verbose @warn("Instability detected. Aborting") end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 46abb98c1..5636c4993 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -125,6 +125,20 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S, original::O end +function ConstructionBase.constructorof(::Type{O}) where {T, N, O <: ODESolution{T, N}} + ODESolution{T, N} +end + +function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple) + u = get(patch, :u, sol.u) + N = u === nothing ? 2 : ndims(eltype(u)) + 1 + T = eltype(eltype(u)) + patch = merge(getproperties(sol), patch) + return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k, + patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats, + patch.alg_choice, patch.retcode, patch.resid, patch.original) +end + Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Symbol) if s === :destats Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats") @@ -272,7 +286,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, prob.u0(prob.p, first(prob.tspan)) : prob.u0(first(prob.tspan)) N = length((size(__u0)..., length(u))) else - N = length((size(prob.u0)..., length(u))) + N = ndims(eltype(u)) + 1 end if prob.f isa Tuple @@ -372,75 +386,31 @@ function calculate_solution_errors!(sol::AbstractODESolution; fill_uanalytic = t end function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N} - ODESolution{T, N}(sol.u, - u_analytic, - errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + @reset sol.u_analytic = u_analytic + return @set sol.errors = errors end function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} - ODESolution{T, N}(sol.u, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.alg_choice, - retcode, - sol.resid, - sol.original) + return @set sol.retcode = retcode end function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N} - ODESolution{T, N}(sol.u, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + return @set sol.tslocation = tslocation +end + +function solution_new_original_retcode( + sol::ODESolution{T, N}, original, retcode, resid) where {T, N} + @reset sol.original = original + @reset sol.retcode = retcode + return @set sol.resid = resid end function solution_slice(sol::ODESolution{T, N}, I) where {T, N} - ODESolution{T, N}(sol.u[I], - sol.u_analytic === nothing ? nothing : sol.u_analytic[I], - sol.errors, - sol.t[I], - sol.dense ? sol.k[I] : sol.k, - sol.prob, - sol.alg, - sol.interp, - false, - sol.tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + @reset sol.u = sol.u[I] + @reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I] + @reset sol.t = sol.t[I] + @reset sol.k = sol.dense ? sol.k[I] : sol.k + return @set sol.alg = false end function sensitivity_solution(sol::ODESolution, u, t) @@ -455,10 +425,7 @@ function sensitivity_solution(sol::ODESolution, u, t) end interp = enable_interpolation_sensitivitymode(sol.interp) - ODESolution{T, N}(u, sol.u_analytic, sol.errors, - t isa Vector ? t : collect(t), - sol.k, sol.prob, - sol.alg, interp, - sol.dense, sol.tslocation, - sol.stats, sol.alg_choice, sol.retcode, sol.resid, sol.original) + @reset sol.u = u + @reset sol.t = t isa Vector ? t : collect(t) + return @set sol.interp = interp end diff --git a/test/aqua.jl b/test/aqua.jl index 6dbee8706..b403790aa 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -2,6 +2,9 @@ using Test using SciMLBase using Aqua +# https://github.com/JuliaArrays/FillArrays.jl/pull/163 +@test_broken isempty(detect_ambiguities(SciMLBase)) + @testset "Aqua tests (performance)" begin # This tests that we don't accidentally run into # https://github.com/JuliaLang/julia/issues/29393 diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 11ce0ff96..7d599049a 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -15,6 +15,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -35,5 +36,6 @@ SciMLSensitivity = "7.11" SciMLStructures = "1.1" Sundials = "4.11" SymbolicIndexingInterface = "0.3" +SymbolicUtils = "<1.6" Unitful = "1.12" Zygote = "0.6" diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index ca4823b55..4952d729c 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -235,7 +235,7 @@ x_val = vcat.(getindex.((sol,), x_idx, :)...) y_val = sol[y_idx, :] obs_val = sol[x[1] + y] -# checking inference for mixed-type arrays will always fail +# don't check inference for weird cases of nested arrays/tuples for (sym, val, check_inference) in [ (x, x_val, true), (y, y_val, true), @@ -254,7 +254,7 @@ for (sym, val, check_inference) in [ ((x, x), [(i, i) for i in x_val], true), ((x, x_idx), [(i, i) for i in x_val], true), ((x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true), - ((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true), + ((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), ([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false), ((x, [x[1] + y, y], (x[1] + y, y_idx)), [(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false), @@ -311,6 +311,7 @@ end pval = [1.0, 2.0, 3.0] pval_new = [4.0, 5.0, 6.0] +# don't check inference for nested tuples/arrays for (sym, oldval, newval, check_inference) in [ (p[1], pval[1], pval_new[1], true), (p, pval, pval_new, true), @@ -319,7 +320,7 @@ for (sym, oldval, newval, check_inference) in [ ((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true), ([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false), ((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]), - (pval_new[1], (pval_new[2],), [pval_new[3]]), true), + (pval_new[1], (pval_new[2],), [pval_new[3]]), false), ([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]], [pval_new[1], (pval_new[2],), [pval_new[3]]], false) ] diff --git a/test/runtests.jl b/test/runtests.jl index f5b3c2d22..0172f4a66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,6 @@ using Pkg using SafeTestsets using Test -using SciMLBase - -# https://github.com/JuliaArrays/FillArrays.jl/pull/163 -@test_broken isempty(detect_ambiguities(SciMLBase)) const GROUP = get(ENV, "GROUP", "All") const is_APPVEYOR = (Sys.iswindows() && haskey(ENV, "APPVEYOR"))