diff --git a/Project.toml b/Project.toml index b06fe1817..a6d2ef892 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "2.38.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote" [compat] ADTypes = "0.2.5,1.0.0" +Accessors = "0.1.36" ArrayInterface = "7.6" ChainRules = "1.58.0" ChainRulesCore = "1.18" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index dad4ff058..595d0be35 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,6 +22,7 @@ import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX import ADTypes: AbstractADType +import Accessors: @set, @reset using Reexport using SciMLOperators diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index f8664bdff..1f2cdb387 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -372,83 +372,31 @@ function calculate_solution_errors!(sol::AbstractODESolution; fill_uanalytic = t end function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N} - ODESolution{T, N}(sol.u, - u_analytic, - errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + @reset sol.u_analytic = u_analytic + return @set sol.errors = errors end function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} - ODESolution{T, N}(sol.u, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.stats, - sol.alg_choice, - retcode, - sol.resid, - sol.original) + return @set sol.retcode = retcode end function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N} - ODESolution{T, N}(sol.u, - sol.u_analytic, - sol.errors, - sol.t, - sol.k, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + return @set sol.tslocation = tslocation end function solution_new_original_retcode( sol::ODESolution{T, N}, original, retcode, resid) where {T, N} - return ODESolution{T, N}( - sol.u, sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg, - sol.interp, sol.dense, sol.tslocation, sol.stats, sol.alg_choice, retcode, resid, - original) + @reset sol.original = original + @reset sol.retcode = retcode + return @set sol.resid = resid end function solution_slice(sol::ODESolution{T, N}, I) where {T, N} - ODESolution{T, N}(sol.u[I], - sol.u_analytic === nothing ? nothing : sol.u_analytic[I], - sol.errors, - sol.t[I], - sol.dense ? sol.k[I] : sol.k, - sol.prob, - sol.alg, - sol.interp, - false, - sol.tslocation, - sol.stats, - sol.alg_choice, - sol.retcode, - sol.resid, - sol.original) + @reset sol.u = sol.u[I] + @reset sol.u_analytic = sol.u_analytic === nothing ? nothing : sol.u_analytic[I] + @reset sol.t = sol.t[I] + @reset sol.k = sol.dense ? sol.k[I] : sol.k + return @set sol.alg = false end function sensitivity_solution(sol::ODESolution, u, t) @@ -463,10 +411,7 @@ function sensitivity_solution(sol::ODESolution, u, t) end interp = enable_interpolation_sensitivitymode(sol.interp) - ODESolution{T, N}(u, sol.u_analytic, sol.errors, - t isa Vector ? t : collect(t), - sol.k, sol.prob, - sol.alg, interp, - sol.dense, sol.tslocation, - sol.stats, sol.alg_choice, sol.retcode, sol.resid, sol.original) + @reset sol.u = u + @reset sol.t = t isa Vector ? t : collect(t) + return @set sol.interp = interp end