From ba3559ca03252ab3b78e95cb7390267141b644b3 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Mon, 14 Oct 2024 09:39:08 -0400 Subject: [PATCH] Compile fewer methods --- src/integrators.jl | 18 +++++++++--------- src/solvers/imex_ark.jl | 12 +++++------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/integrators.jl b/src/integrators.jl index 09345304..ce17e66a 100644 --- a/src/integrators.jl +++ b/src/integrators.jl @@ -1,4 +1,5 @@ import DataStructures +import Base.Cartesian: @nexprs """ DistributedODEIntegrator <: AbstractODEIntegrator @@ -226,14 +227,13 @@ reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) = integrator.t == tstop || (!stop_at_tstop && is_past_t(integrator, tstop)) -@inline unrolled_foreach(::Tuple{}, integrator) = nothing -@inline unrolled_foreach(callback, integrator) = - callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing -@inline unrolled_foreach(discrete_callbacks::Tuple{Any}, integrator) = - unrolled_foreach(first(discrete_callbacks), integrator) -@inline function unrolled_foreach(discrete_callbacks::Tuple, integrator) - unrolled_foreach(first(discrete_callbacks), integrator) - unrolled_foreach(Base.tail(discrete_callbacks), integrator) +@generated function unrolled_foreach(::Val{N}, callbacks, integrator) where {N} + return quote + @nexprs $N i -> begin + callback = callbacks[i] + callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing + end + end end function __step!(integrator) @@ -257,7 +257,7 @@ function __step!(integrator) # apply callbacks discrete_callbacks = integrator.callback.discrete_callbacks - unrolled_foreach(discrete_callbacks, integrator) + unrolled_foreach(Val(length(discrete_callbacks)), discrete_callbacks, integrator) # remove tstops that were just reached while !isempty(tstops) && reached_tstop(integrator, first(tstops)) diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 4c2d24f9..2bda6001 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -1,4 +1,5 @@ import NVTX +import Base.Cartesian: @nexprs has_jac(T_imp!) = hasfield(typeof(T_imp!), :Wfact) && @@ -68,7 +69,7 @@ function step_u!(integrator, cache::IMEXARKCache) end end - update_stage!(integrator, cache, ntuple(i -> i, Val(s))) + update_stage!(Val(s), integrator, cache) t_final = t + dt @@ -88,13 +89,10 @@ function step_u!(integrator, cache::IMEXARKCache) return u end - -@inline update_stage!(integrator, cache, ::Tuple{}) = nothing -@inline update_stage!(integrator, cache, is::Tuple{Int}) = update_stage!(integrator, cache, first(is)) -@inline function update_stage!(integrator, cache, is::Tuple) - update_stage!(integrator, cache, first(is)) - update_stage!(integrator, cache, Base.tail(is)) +@generated update_stage!(::Val{s}, integrator, cache::IMEXARKCache) where {s} = quote + @nexprs $s i -> update_stage!(integrator, cache, i) end + @inline function update_stage!(integrator, cache::IMEXARKCache, i::Int) (; u, p, t, dt, alg) = integrator (; f) = integrator.sol.prob