Skip to content

Commit

Permalink
Compile fewer methods
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 14, 2024
1 parent 1a6fdd4 commit ba3559c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 16 deletions.
18 changes: 9 additions & 9 deletions src/integrators.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import DataStructures
import Base.Cartesian: @nexprs

"""
DistributedODEIntegrator <: AbstractODEIntegrator
Expand Down Expand Up @@ -226,14 +227,13 @@ reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) =
integrator.t == tstop || (!stop_at_tstop && is_past_t(integrator, tstop))


@inline unrolled_foreach(::Tuple{}, integrator) = nothing
@inline unrolled_foreach(callback, integrator) =
callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing
@inline unrolled_foreach(discrete_callbacks::Tuple{Any}, integrator) =
unrolled_foreach(first(discrete_callbacks), integrator)
@inline function unrolled_foreach(discrete_callbacks::Tuple, integrator)
unrolled_foreach(first(discrete_callbacks), integrator)
unrolled_foreach(Base.tail(discrete_callbacks), integrator)
@generated function unrolled_foreach(::Val{N}, callbacks, integrator) where {N}
return quote
@nexprs $N i -> begin
callback = callbacks[i]
callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing
end
end
end

function __step!(integrator)
Expand All @@ -257,7 +257,7 @@ function __step!(integrator)

# apply callbacks
discrete_callbacks = integrator.callback.discrete_callbacks
unrolled_foreach(discrete_callbacks, integrator)
unrolled_foreach(Val(length(discrete_callbacks)), discrete_callbacks, integrator)

# remove tstops that were just reached
while !isempty(tstops) && reached_tstop(integrator, first(tstops))
Expand Down
12 changes: 5 additions & 7 deletions src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import NVTX
import Base.Cartesian: @nexprs

has_jac(T_imp!) =
hasfield(typeof(T_imp!), :Wfact) &&
Expand Down Expand Up @@ -68,7 +69,7 @@ function step_u!(integrator, cache::IMEXARKCache)
end
end

update_stage!(integrator, cache, ntuple(i -> i, Val(s)))
update_stage!(Val(s), integrator, cache)

t_final = t + dt

Expand All @@ -88,13 +89,10 @@ function step_u!(integrator, cache::IMEXARKCache)
return u
end


@inline update_stage!(integrator, cache, ::Tuple{}) = nothing
@inline update_stage!(integrator, cache, is::Tuple{Int}) = update_stage!(integrator, cache, first(is))
@inline function update_stage!(integrator, cache, is::Tuple)
update_stage!(integrator, cache, first(is))
update_stage!(integrator, cache, Base.tail(is))
@generated update_stage!(::Val{s}, integrator, cache::IMEXARKCache) where {s} = quote
@nexprs $s i -> update_stage!(integrator, cache, i)
end

@inline function update_stage!(integrator, cache::IMEXARKCache, i::Int)
(; u, p, t, dt, alg) = integrator
(; f) = integrator.sol.prob
Expand Down

0 comments on commit ba3559c

Please sign in to comment.