Skip to content

Commit

Permalink
WIP: Refactor ODEIntegrator to not allow undef fsal states
Browse files Browse the repository at this point in the history
This was used as a performance optimization early on, dropping the construction of those two vectors since we already construct so much in the caches, we might as well reuse one of the cache pointers. And it's also built into some methods what cache pointer that must be. So the integrator is made with undef's and then during initialization phase the pointers are set.

However, this is unnecessary and adds some complexity. For one, it makes the constructor a bit of a mess. But for two, it gives Enzyme issues as demonstrated in #2282. A better solution is then to just, construct the type correctly.

To do this, we simply need to refactor the information of what vectors correspond to fsal first and last into a function that is per-cache, and use that function in the integrator construction. That's already done in this PR. All that's required to complete this PR is to ensure this refactor is done on every method.
  • Loading branch information
ChrisRackauckas committed Aug 17, 2024
1 parent d88f255 commit a57ef07
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 70 deletions.
12 changes: 5 additions & 7 deletions lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end
struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end
struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end

# Don't worry about the potential alloc on a constant cache
get_fsalfirstlast(cache::OrdinaryDiffEqConstantCache) = zero(cache.u), zero(cache.u)

mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
caches::T
choice_function::F
current::Int
end

TruncatedStacktraces.@truncate_stacktrace CompositeCache 1
get_fsalfirstlast(cache::CompositeCache) = get_fsalfirstlast(cache.caches[1])

mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache
args::A
Expand All @@ -28,12 +31,7 @@ mutable struct DefaultCache{T1, T2, T3, T4, T5, T6, A, F} <: OrdinaryDiffEqCache
end
end

TruncatedStacktraces.@truncate_stacktrace DefaultCache 1

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(CompositeCache)
Base.Experimental.silence!(DefaultCache)
end
get_fsalfirstlast(cache::DefaultCache) = get_fsalfirstlast(cache.cache1)

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
50 changes: 0 additions & 50 deletions lib/OrdinaryDiffEqCore/src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,54 +136,4 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
differential_vars::DV
fsalfirst::FSALType
fsallast::FSALType

function ODEIntegrator{algType, IIP, uType, duType, tType, pType, eigenType, EEstT,
tTypeNoUnits, tdirType, ksEltype, SolType,
F, CacheType, O, FSALType, EventErrorType, CallbackCacheType,
InitializeAlgType, DV}(sol, u, du, k, t, dt, f, p, uprev, uprev2,
duprev, tprev,
alg, dtcache, dtchangeable, dtpropose, tdir,
eigen_est, EEst, qold, q11, erracc, dtacc,
success_iter,
iter, saveiter, saveiter_dense, cache,
callback_cache,
kshortsize, force_stepfail, last_stepfail,
just_hit_tstop,
do_error_check,
event_last_time, vector_event_last_time,
last_event_error,
accept_step, isout, reeval_fsal, u_modified,
reinitialize, isdae,
opts, stats,
initializealg, differential_vars) where {algType, IIP, uType,
duType, tType, pType,
eigenType, EEstT,
tTypeNoUnits, tdirType,
ksEltype, SolType, F,
CacheType, O,
FSALType,
EventErrorType,
CallbackCacheType,
InitializeAlgType, DV}
new{algType, IIP, uType, duType, tType, pType, eigenType, EEstT, tTypeNoUnits,
tdirType, ksEltype, SolType,
F, CacheType, O, FSALType, EventErrorType,
CallbackCacheType, InitializeAlgType, DV
}(sol, u, du, k, t, dt, f, p, uprev, uprev2, duprev, tprev,
alg, dtcache, dtchangeable, dtpropose, tdir,
eigen_est, EEst, qold, q11, erracc, dtacc, success_iter,
iter, saveiter, saveiter_dense, cache, callback_cache,
kshortsize, force_stepfail, last_stepfail, just_hit_tstop,
do_error_check,
event_last_time, vector_event_last_time, last_event_error,
accept_step, isout, reeval_fsal, u_modified, reinitialize, isdae,
opts, stats, initializealg, differential_vars) # Leave off fsalfirst and last
end
end

if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(ODEIntegrator)
end
# When this is changed, DelayDiffEq.jl must be changed as well!

TruncatedStacktraces.@truncate_stacktrace ODEIntegrator 2 1 3 4
4 changes: 3 additions & 1 deletion lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ function DiffEqBase.__init(
reinitiailize = true
saveiter = 0 # Starts at 0 so first save is at 1
saveiter_dense = 0
faslfirst, fsallast = get_fsalfirstlast(cache)

integrator = ODEIntegrator{typeof(_alg), isinplace(prob), uType, typeof(du),
tType, typeof(p),
Expand All @@ -494,7 +495,8 @@ function DiffEqBase.__init(
last_event_error, accept_step,
isout, reeval_fsal,
u_modified, reinitiailize, isdae,
opts, stats, initializealg, differential_vars)
opts, stats, initializealg, differential_vars,
faslfirst, fsallast)

if initialize_integrator
if isdae || SciMLBase.has_initializeprob(prob.f)
Expand Down
5 changes: 1 addition & 4 deletions lib/OrdinaryDiffEqTsit5/src/tsit_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
step_limiter!::StepLimiter
thread::Thread
end
if isdefined(Base, :Experimental) && isdefined(Base.Experimental, :silence!)
Base.Experimental.silence!(Tsit5Cache)
end

function alg_cache(alg::Tsit5, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand All @@ -39,7 +36,7 @@ function alg_cache(alg::Tsit5, u, rate_prototype, ::Type{uEltypeNoUnits},
alg.stage_limiter!, alg.step_limiter!, alg.thread)
end

TruncatedStacktraces.@truncate_stacktrace Tsit5Cache 1
get_fsalfirstlast(cache::Tsit5Cache) = (cache.k1, cache.k7)

function alg_cache(alg::Tsit5, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
2 changes: 0 additions & 2 deletions lib/OrdinaryDiffEqTsit5/src/tsit_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ end

function initialize!(integrator, cache::Tsit5Cache)
integrator.kshortsize = 7
integrator.fsalfirst = cache.k1
integrator.fsallast = cache.k7 # setup pointers
resize!(integrator.k, integrator.kshortsize)
# Setup k pointers
integrator.k[1] = cache.k1
Expand Down
11 changes: 7 additions & 4 deletions lib/OrdinaryDiffEqVerner/src/verner_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern6Cache 1
get_fsalfirstlast(cache::Vern6Cache) = (cache.k1, cache.k9)

function alg_cache(alg::Vern6, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down Expand Up @@ -86,7 +86,8 @@ end
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern7Cache 1
# fake values since non-FSAL method
get_fsalfirstlast(cache::Vern7Cache) = (cache.k1, cache.k2)

function alg_cache(alg::Vern7, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down Expand Up @@ -151,7 +152,8 @@ end
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern8Cache 1
# fake values since non-FSAL method
get_fsalfirstlast(cache::Vern8Cache) = (cache.k1, cache.k2)

function alg_cache(alg::Vern8, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down Expand Up @@ -224,7 +226,8 @@ end
lazy::Bool
end

TruncatedStacktraces.@truncate_stacktrace Vern9Cache 1
# fake values since non-FSAL method
get_fsalfirstlast(cache::Vern9Cache) = (cache.k1, cache.k2)

function alg_cache(alg::Vern9, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
Expand Down
2 changes: 0 additions & 2 deletions lib/OrdinaryDiffEqVerner/src/verner_rk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ end
function initialize!(integrator, cache::Vern6Cache)
alg = unwrap_alg(integrator, false)
cache.lazy ? (integrator.kshortsize = 9) : (integrator.kshortsize = 12)
integrator.fsalfirst = cache.k1
integrator.fsallast = cache.k9
@unpack k = integrator
resize!(k, integrator.kshortsize)
k[1] = cache.k1
Expand Down

0 comments on commit a57ef07

Please sign in to comment.