diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 94650d43a8..8d9e0b5381 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -879,6 +879,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], check_length = true, eval_expression = false, eval_module = @__MODULE__, + u0_constructor = identity, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DDEProblem`") @@ -892,6 +893,9 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], h(p, t) = h_oop(p, t) h(p::MTKParameters, t) = h_oop(p..., t) u0 = h(p, tspan[1]) + if u0 !== nothing + u0 = u0_constructor(u0) + end cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) kwargs = filter_kwargs(kwargs) @@ -914,6 +918,7 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], sparsenoise = nothing, eval_expression = false, eval_module = @__MODULE__, + u0_constructor = identity, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `SDDEProblem`") @@ -929,6 +934,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [], h(p::MTKParameters, t) = h_oop(p..., t) h(out, p::MTKParameters, t) = h_iip(out, p..., t) u0 = h(p, tspan[1]) + if u0 !== nothing + u0 = u0_constructor(u0) + end cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...) kwargs = filter_kwargs(kwargs) diff --git a/test/dde.jl b/test/dde.jl index f5e72ee1bb..ec076fba53 100644 --- a/test/dde.jl +++ b/test/dde.jl @@ -1,4 +1,4 @@ -using ModelingToolkit, DelayDiffEq, Test +using ModelingToolkit, DelayDiffEq, StaticArrays, Test using SymbolicIndexingInterface: is_markovian using ModelingToolkit: t_nounits as t, D_nounits as D @@ -89,6 +89,10 @@ eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c + (α * x(t) + γ) * η] prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,)); @test_nowarn sol_mtk = solve(prob_mtk, RKMil()) +prob_sa = SDDEProblem( + sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,), u0_constructor = SVector{1}) +@test prob_sa.u0 isa SVector{4, Float64} + @parameters x(..) a function oscillator(; name, k = 1.0, τ = 0.01) @@ -126,6 +130,10 @@ obsfn = ModelingToolkit.build_explicit_observed_function( @test_nowarn sol[[sys.osc1.delx, sys.osc2.delx]] @test sol[sys.osc1.delx] ≈ sol(sol.t .- 0.01; idxs = sys.osc1.x).u +prob_sa = DDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc2.τ], + u0_constructor = SVector{4}) +@test prob_sa.u0 isa SVector{4, Float64} + @testset "DDE observed with array variables" begin @component function valve(; name) @parameters begin