Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up generate_initializesystem() #3051

Merged
merged 12 commits into from
Oct 5, 2024
140 changes: 56 additions & 84 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,109 +5,81 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
"""
function generate_initializesystem(sys::ODESystem;
u0map = Dict(),
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
algebraic_only = false,
initialization_eqs = [],
check_units = true,
kwargs...)
sts, eqs = unknowns(sys), equations(sys)
guesses = Dict(),
default_dd_guess = 0.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the old name documented anywhere?

algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), kwargs...)
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
vars_set = Set(vars) # for efficient in-lookup

eqs = equations(sys)
idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff
num_alge = sum(idxs_alge)

# Start the equations list with algebraic equations
eqs_ics = eqs[idxs_alge]
u0 = Vector{Pair}(undef, 0)

# prepare map for dummy derivative substitution
eqs_diff = eqs[idxs_diff]
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
full_diffmap = merge(diffmap, observed_diffmap)
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
)

full_states = unique([sts; getfield.((observed(sys)), :lhs)])
set_full_states = Set(full_states)
# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
guesses = merge(get_guesses(sys), todict(guesses))
schedule = getfield(sys, :schedule)

if schedule !== nothing
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
for x in schedule.dummy_sub]
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
if u0map === nothing || isempty(u0map)
filtered_u0 = u0map
else
filtered_u0 = Pair[]
for x in u0map
y = get(schedule.dummy_sub, x[1], x[1])
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)

if y ∈ set_full_states
# defer initialization until defaults are merged below
push!(filtered_u0, y => x[2])
elseif y isa Symbolics.Arr
# scalarize array # TODO: don't scalarize arrays
_y = collect(y)
for i in eachindex(_y)
push!(filtered_u0, _y[i] => x[2][i])
end
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded
# add to the initialization equations
push!(eqs_ics, y ~ x[2])
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
if !isnothing(schedule)
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
for (y, x) in u0map
y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y ∈ vars_set
# variables specified in u0 overrides defaults
push!(defs, y => x)
elseif y isa Symbolics.Arr
# TODO: don't scalarize arrays
merge!(defs, Dict(scalarize(y .=> x)))
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded; add it to the initialization equations
push!(eqs_ics, y ~ x)
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
filtered_u0 = todict(filtered_u0)
end
else
dd_guess = Dict()
filtered_u0 = todict(u0map)
end

defs = merge(defaults(sys), filtered_u0)

for st in full_states
if st ∈ keys(defs)
def = defs[st]

if def isa Equation
st ∉ keys(guesses) && check_defguess &&
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, st => guesses[st])
else
push!(eqs_ics, st ~ def)
push!(u0, st => def)
end
elseif st ∈ keys(guesses)
push!(u0, st => guesses[st])
# 2) process other variables
for var in vars
if var ∈ keys(defs)
push!(eqs_ics, var ~ defs[var])
elseif var ∈ keys(guesses)
push!(defs, var => guesses[var])
elseif check_defguess
error("Invalid setup: unknown $(st) has no default value or initial guess")
error("Invalid setup: variable $(var) has no default value or initial guess")
end
end

# 3) process explicitly provided initialization equations
if !algebraic_only
for eq in [get_initialization_eqs(sys); initialization_eqs]
_eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap)
push!(eqs_ics, _eq)
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
for eq in initialization_eqs
eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives
push!(eqs_ics, eq)
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = [eqs_ics; observed(sys)]

sys_nl = NonlinearSystem(nleqs,
full_states,
pars;
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
parameter_dependencies = parameter_dependencies(sys),
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
eqs_ics = [eqs_ics; observed(sys)]
return NonlinearSystem(
eqs_ics, vars, pars;
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
checks = check_units,
name,
kwargs...)

return sys_nl
name, kwargs...
)
end
24 changes: 11 additions & 13 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,25 +126,23 @@ function NonlinearSystem(eqs, unknowns, ps;
throw(ArgumentError("NonlinearSystem does not accept `continuous_events`, you provided $continuous_events"))
discrete_events === nothing || isempty(discrete_events) ||
throw(ArgumentError("NonlinearSystem does not accept `discrete_events`, you provided $discrete_events"))

name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
# Move things over, but do not touch array expressions
#
# # we cannot scalarize in the loop because `eqs` itself might require
# scalarization
eqs = [x.lhs isa Union{Symbolic, Number} ? 0 ~ x.rhs - x.lhs : x
for x in scalarize(eqs)]

if !(isempty(default_u0) && isempty(default_p))
length(unique(nameof.(systems))) == length(systems) ||
throw(ArgumentError("System names must be unique."))
(isempty(default_u0) && isempty(default_p)) ||
Base.depwarn(
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:NonlinearSystem, force = true)

# Accept a single (scalar/vector) equation, but make array for consistent internal handling
if !(eqs isa AbstractArray)
eqs = [eqs]
end
sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
end

# Copy equations to canonical form, but do not touch array expressions
eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs]

jac = RefValue{Any}(EMPTY_JAC)
defaults = todict(defaults)
defaults = Dict{Any, Any}(value(k) => value(v)
Expand Down
9 changes: 9 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,12 @@ oprob_2nd_order_2 = ODEProblem(sys_2nd_order, u0_2nd_order_2, tspan, ps)
sol = solve(oprob_2nd_order_2, Rosenbrock23()) # retcode: Success
@test sol[Y][1] == 2.0
@test sol[D(Y)][1] == 0.5

@testset "Vector in initial conditions" begin
@variables x(t)[1:5] y(t)[1:5]
@named sys = ODESystem([D(x) ~ x, D(y) ~ y], t; initialization_eqs = [y ~ -x])
sys = structural_simplify(sys)
prob = ODEProblem(sys, [sys.x => ones(5)], (0.0, 1.0), [])
sol = solve(prob, Tsit5(), reltol = 1e-4)
@test all(sol(1.0, idxs = sys.x) .≈ +exp(1)) && all(sol(1.0, idxs = sys.y) .≈ -exp(1))
end
14 changes: 14 additions & 0 deletions test/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,20 @@ end
@test_nowarn solve(prob)
end

@testset "System of linear equations with vector variable" begin
# 1st example in https://en.wikipedia.org/w/index.php?title=System_of_linear_equations&oldid=1247697953
@variables x[1:3]
A = [3 2 -1
2 -2 4
-1 1/2 -1]
b = [1, -2, 0]
@named sys = NonlinearSystem(A * x ~ b, [x], [])
sys = structural_simplify(sys)
prob = NonlinearProblem(sys, unknowns(sys) .=> 0.0)
sol = solve(prob)
@test all(sol[x] .≈ A \ b)
end

@testset "resid_prototype when system has no unknowns and an equation" begin
@variables x
@parameters p
Expand Down
1 change: 1 addition & 0 deletions test/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ A = reshape(1:(N^2), N, N)
eqs = xs ~ A * xs
@named sys′ = NonlinearSystem(eqs, [xs], [])
sys = structural_simplify(sys′)
@test length(equations(sys)) == 3 && length(observed(sys)) == 2

# issue 958
@parameters k₁ k₂ k₋₁ E₀
Expand Down
Loading