Skip to content

Commit

Permalink
fix: fix u0_constructor for DDEProblem/SDDEProblem
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 25, 2024
1 parent c8ac522 commit 93c1e8f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
check_length = true,
eval_expression = false,
eval_module = @__MODULE__,
u0_constructor = identity,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DDEProblem`")
Expand All @@ -892,6 +893,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h(p, t) = h_oop(p, t)
h(p::MTKParameters, t) = h_oop(p..., t)
u0 = h(p, tspan[1])
if u0 !== nothing
u0 = u0_constructor(u0)
end

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
kwargs = filter_kwargs(kwargs)
Expand All @@ -914,6 +918,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
sparsenoise = nothing,
eval_expression = false,
eval_module = @__MODULE__,
u0_constructor = identity,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SDDEProblem`")
Expand All @@ -929,6 +934,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h(p::MTKParameters, t) = h_oop(p..., t)
h(out, p::MTKParameters, t) = h_iip(out, p..., t)
u0 = h(p, tspan[1])
if u0 !== nothing
u0 = u0_constructor(u0)
end

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
kwargs = filter_kwargs(kwargs)
Expand Down
10 changes: 9 additions & 1 deletion test/dde.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ModelingToolkit, DelayDiffEq, Test
using ModelingToolkit, DelayDiffEq, StaticArrays, Test
using SymbolicIndexingInterface: is_markovian
using ModelingToolkit: t_nounits as t, D_nounits as D

Expand Down Expand Up @@ -89,6 +89,10 @@ eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c + (α * x(t) + γ) * η]
prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,));
@test_nowarn sol_mtk = solve(prob_mtk, RKMil())

prob_sa = SDDEProblem(
sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,), u0_constructor = SVector{1})
@test prob_sa.u0 isa SVector{4, Float64}

@parameters x(..) a

function oscillator(; name, k = 1.0, τ = 0.01)
Expand Down Expand Up @@ -126,6 +130,10 @@ obsfn = ModelingToolkit.build_explicit_observed_function(
@test_nowarn sol[[sys.osc1.delx, sys.osc2.delx]]
@test sol[sys.osc1.delx] sol(sol.t .- 0.01; idxs = sys.osc1.x).u

prob_sa = DDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc2.τ],
u0_constructor = SVector{4})
@test prob_sa.u0 isa SVector{4, Float64}

@testset "DDE observed with array variables" begin
@component function valve(; name)
@parameters begin
Expand Down

0 comments on commit 93c1e8f

Please sign in to comment.