Skip to content

Commit

Permalink
feat: turn parameter dependencies into initialization equations
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 26, 2024
1 parent 1a03b5a commit e715e28
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
9 changes: 9 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,15 @@ function has_observed_with_lhs(sys, sym)
end
end

function has_parameter_dependency_with_lhs(sys, sym)
has_parameter_dependencies(sys) || return false
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return any(isequal(sym), ic.dependent_pars)
else
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
end
end

function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym)
if is_variable(sys, sym) || is_independent_variable(sys, sym)
push!(ts_idxs, ContinuousTimeseries())
Expand Down
12 changes: 11 additions & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -860,13 +860,23 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
solvablepars = [p
for p in parameters(sys)
if is_parameter_solvable(p, parammap, defs, guesses)]

pvarmap = if parammap === nothing || parammap == SciMLBase.NullParameters() || !(eltype(parammap) <: Pair) && isempty(parammap)
defs
else
merge(defs, todict(parammap))
end
setparobserved = filter(keys(pvarmap)) do var
has_parameter_dependency_with_lhs(sys, var)
end
else
solvablepars = ()
setparobserved = ()
end
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if sys isa ODESystem && build_initializeprob &&
(((implicit_dae || !isempty(missingvars) || !isempty(solvablepars) ||
!isempty(setobserved)) &&
!isempty(setobserved) || !isempty(setparobserved)) &&
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys))) && t !== nothing
if eltype(u0map) <: Number
Expand Down
25 changes: 22 additions & 3 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,34 @@ function generate_initializesystem(sys::ODESystem;
end
end
end

# parameter dependencies become equations, their LHS become unknowns
for eq in parameter_dependencies(sys)
varp = tovar(eq.lhs)
paramsubs[eq.lhs] = varp
push!(eqs_ics, eq)
guessval = get(guesses, eq.lhs, eq.rhs)
push!(u0, varp => guessval)
end

# handle values provided for dependent parameters
for (k, v) in merge(defaults(sys), pmap)
if has_parameter_dependency_with_lhs(sys, k)
push!(eqs_ics, paramsubs[k] ~ v)
end
end
pars = vcat(
[get_iv(sys)],
[p for p in parameters(sys) if !haskey(paramsubs, p)]
)
nleqs = [eqs_ics; observed(sys)]
nleqs = Symbolics.substitute.(nleqs, (paramsubs,))
unks = [full_states; collect(values(paramsubs))]

u0 = Dict(k => substitute(v, paramsubs) for (k, v) in u0)
sys_nl = NonlinearSystem(nleqs,
unks,
pars;
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess, pmap),
parameter_dependencies = parameter_dependencies(sys),
checks = check_units,
name,
kwargs...)
Expand Down Expand Up @@ -205,8 +220,12 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
solvablepars = [par
for par in parameters(sys)
if is_parameter_solvable(par, p, defs, guesses)]
pvarmap = merge(defs, p)
setparobserved = filter(keys(pvarmap)) do var
has_parameter_dependency_with_lhs(sys, var)
end
if (((!isempty(missingvars) || !isempty(solvablepars) ||
!isempty(setobserved)) &&
!isempty(setobserved) || !isempty(setparobserved)) &&
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys)))
initprob = InitializationProblem(sys, t0, u0, p)
Expand Down

0 comments on commit e715e28

Please sign in to comment.