diff --git a/src/remake.jl b/src/remake.jl index cc826636d..a5c50b967 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -97,26 +97,29 @@ function remake(prob::ODEProblem; f = missing, tspan = prob.tspan end - u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) + newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults) iip = isinplace(prob) if f === missing + initializeprob, initializeprobmap = remake_initializeprob(prob.f.sys, prob.f, u0 === missing ? newu0 : u0, tspan[1], p === missing ? newp : p) if specialization(prob.f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) if iip _f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip( unwrapped_f(prob.f.f), - (u0, u0, p, - ptspan[1]))) + (newu0, newu0, newp, + ptspan[1])); initializeprob, initializeprobmap) else _f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop( unwrapped_f(prob.f.f), - (u0, p, - ptspan[1]))) + (newu0, newp, + ptspan[1])); initializeprob, initializeprobmap) end else _f = prob.f + @reset _f.initializeprob = initializeprob + @reset _f.initializeprobmap = initializeprobmap end elseif f isa AbstractODEFunction _f = f @@ -124,22 +127,37 @@ function remake(prob::ODEProblem; f = missing, ptspan = promote_tspan(tspan) if iip _f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f, - (u0, u0, p, + (newu0, newu0, newp, ptspan[1]))) else _f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f, - (u0, p, ptspan[1]))) + (newu0, newp, ptspan[1]))) end else _f = ODEFunction{isinplace(prob), specialization(prob.f)}(f) end if kwargs === missing - ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; prob.kwargs..., + ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; prob.kwargs..., _kwargs...) else - ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; kwargs...) + ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...) + end +end + +""" + remake_initializeprob(sys, scimlfn, u0, t0, p) + +Re-create the initialization problem present in the function `scimlfn`, using the +associated system `sys`, and the new values of `u0`, initial time `t0` and `p`. By +default, returns `nothing, nothing` if `scimlfn` does not have an initialization +problem, and `scimlfn.initializeprob, scimlfn.initializeprobmap` if it does. +""" +function remake_initializeprob(sys, scimlfn, u0, t0, p) + if !has_initializeprob(scimlfn) + return nothing, nothing end + return scimlfn.initializeprob, scimlfn.initializeprobmap end """