From e3b7d7a05f23a890eefcc949c037a385fec840d1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 22 May 2024 14:06:57 +0530 Subject: [PATCH] fix: remake initialization problem during DAE initialization --- Project.toml | 2 ++ src/OrdinaryDiffEq.jl | 2 ++ src/initialize_dae.jl | 11 ++++++++++- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c09c7acd72..bcfff85422 100644 --- a/Project.toml +++ b/Project.toml @@ -42,6 +42,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [compat] @@ -84,6 +85,7 @@ SparseArrays = "1.9" SparseDiffTools = "2.3" StaticArrayInterface = "1.2" StaticArrays = "1.0" +SymbolicIndexingInterface = "0.3.16" TruncatedStacktraces = "1.2" julia = "1.10" diff --git a/src/OrdinaryDiffEq.jl b/src/OrdinaryDiffEq.jl index 3a0506f30d..025ec50c46 100644 --- a/src/OrdinaryDiffEq.jl +++ b/src/OrdinaryDiffEq.jl @@ -64,6 +64,8 @@ using ExponentialUtilities using NonlinearSolve +using SymbolicIndexingInterface + # Required by temporary fix in not in-place methods with 12+ broadcasts # `MVector` is used by Nordsieck forms import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 9404d9531f..5d38c3945d 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -134,7 +134,16 @@ end function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, alg::OverrideInit, isinplace::Union{Val{true}, Val{false}}) initializeprob = prob.f.initializeprob - + if initializeprob.f.sys !== nothing && prob.f.sys !== nothing + initu0vars = variable_symbols(initializeprob) + initu0order = variable_index.((initializeprob,), initu0vars) + # Variable symbols are not guaranteed to be in order + invpermute!(initu0vars, initu0order) + initu0 = getu(prob.f.initializeprob, initu0vars)(prob) + initp = remake_buffer(initializeprob, parameter_values(initializeprob), + Dict(sym => getu(prob, sym)(prob) for sym in parameter_symbols(initializeprob))) + initializeprob = remake(initializeprob; u0 = initu0, p = initp) + end # If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit # Since then it's the case of not a DAE but has initializeprob # In which case, it should be differentiable