diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 9481726529..5c87221c26 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -46,37 +46,37 @@ end if length(integrator.cache.caches) == 2 if cache_current == 1 _ode_addsteps!(integrator.k, integrator.tprev, integrator.uprev, - integrator.u, - integrator.dt, f, integrator.p, - cache.caches[1], - always_calc_begin, allow_calc_end, force_calc_end) + integrator.u, + integrator.dt, f, integrator.p, + cache.caches[1], + always_calc_begin, allow_calc_end, force_calc_end) else @assert cache_current == 2 _ode_addsteps!(integrator.k, integrator.tprev, integrator.uprev, - integrator.u, - integrator.dt, f, integrator.p, - cache.caches[2], - always_calc_begin, allow_calc_end, force_calc_end) + integrator.u, + integrator.dt, f, integrator.p, + cache.caches[2], + always_calc_begin, allow_calc_end, force_calc_end) end else if cache_current == 1 _ode_addsteps!(integrator.k, integrator.tprev, integrator.uprev, - integrator.u, - integrator.dt, f, integrator.p, - cache.caches[1], - always_calc_begin, allow_calc_end, force_calc_end) + integrator.u, + integrator.dt, f, integrator.p, + cache.caches[1], + always_calc_begin, allow_calc_end, force_calc_end) elseif cache_current == 2 _ode_addsteps!(integrator.k, integrator.tprev, integrator.uprev, - integrator.u, - integrator.dt, f, integrator.p, - cache.caches[2], - always_calc_begin, allow_calc_end, force_calc_end) + integrator.u, + integrator.dt, f, integrator.p, + cache.caches[2], + always_calc_begin, allow_calc_end, force_calc_end) else _ode_addsteps!(integrator.k, integrator.tprev, integrator.uprev, - integrator.u, - integrator.dt, f, integrator.p, - cache.caches[cache_current], - always_calc_begin, allow_calc_end, force_calc_end) + integrator.u, + integrator.dt, f, integrator.p, + cache.caches[cache_current], + always_calc_begin, allow_calc_end, force_calc_end) end end end diff --git a/src/perform_step/explicit_rk_perform_step.jl b/src/perform_step/explicit_rk_perform_step.jl index c2acb4e07a..d198a3778e 100644 --- a/src/perform_step/explicit_rk_perform_step.jl +++ b/src/perform_step/explicit_rk_perform_step.jl @@ -77,12 +77,24 @@ function initialize!(integrator, cache::ExplicitRKCache) integrator.destats.nf += 1 end -@muladd function perform_step!(integrator, cache::ExplicitRKCache, repeat_step = false) - @unpack t, dt, uprev, u, f, p = integrator - alg = unwrap_alg(integrator, nothing) - @unpack A, c, α, αEEst, stages = cache.tab - @unpack kk, utilde, tmp, atmp = cache +@generated function accumulate_explicit_stages!(out, A, uprev, kk, dt, ::Val{s}) where {s} + s <= 1 && error("$s must be > 1") + # Note that `A` is transposed + if s == 2 + return :(@muladd @.. broadcast=false out=uprev + dt * (A[1, $s] * kk[1])) + else + expr = :(@muladd @.. broadcast=false out=uprev + + dt * (A[1, $s] * kk[1] + A[2, $s] * kk[2])) + acc = expr.args[end].args[end].args[end].args[end].args[end].args + for i in 3:(s - 1) + push!(acc, :(A[$i, $s] * kk[$i])) + end + return expr + end +end +@muladd function compute_stages!(f::F, A, c, utilde, u, tmp, uprev, kk, p, t, dt, + stages::Integer) where {F} # Middle for i in 2:(stages - 1) @.. broadcast=false utilde=zero(kk[1][1]) @@ -91,7 +103,6 @@ end end @.. broadcast=false tmp=uprev + dt * utilde f(kk[i], tmp, p, t + c[i] * dt) - integrator.destats.nf += 1 end #Last @@ -101,7 +112,47 @@ end end @.. broadcast=false u=uprev + dt * utilde f(kk[end], u, p, t + c[end] * dt) #fsallast is tmp even if not fsal - integrator.destats.nf += 1 + return nothing +end + +@generated function compute_stages!(f::F, A, c, u, tmp, uprev, kk, p, t, dt, + ::Val{s}) where {F, s} + quote + Base.@nexprs $(s - 2) i′->begin + i = i′ + 1 + accumulate_explicit_stages!(tmp, A, uprev, kk, dt, Val(i)) + f(kk[i], tmp, p, t + c[i] * dt) + end + accumulate_explicit_stages!(u, A, uprev, kk, dt, Val(s)) + f(kk[s], u, p, t + c[end] * dt) + end +end + +function runtime_split_stages!(f::F, A, c, utilde, u, tmp, uprev, kk, p, t, dt, + stages::Integer) where {F} + Base.@nif 16 (s->(s == stages)) (s->compute_stages!(f, A, c, u, tmp, uprev, kk, p, t, + dt, Val(s))) (s->compute_stages!(f, + A, + c, + utilde, + u, + tmp, + uprev, + kk, + p, + t, + dt, + stages)) +end + +@muladd function perform_step!(integrator, cache::ExplicitRKCache, repeat_step = false) + @unpack t, dt, uprev, u, f, p = integrator + alg = unwrap_alg(integrator, nothing) + @unpack A, c, α, αEEst, stages = cache.tab + @unpack kk, utilde, tmp, atmp = cache + + runtime_split_stages!(f, A, c, utilde, u, tmp, uprev, kk, p, t, dt, stages) + integrator.destats.nf += stages - 1 #Accumulate if !isfsal(alg.tableau)