diff --git a/src/callbacks.jl b/src/callbacks.jl index e5a09045c..a4333d2cd 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -116,45 +116,42 @@ function get_condition(integrator::DEIntegrator, callback, abst) end end -# Use Recursion to find the first callback for type-stability - -# Base Case: Only one callback -function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback) - (find_callback_time(integrator, callback, 1)..., 1, 1) -end - -# Starting Case: Compute on the first callback -function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback, - args...) - find_first_continuous_callback(integrator, - find_callback_time(integrator, callback, 1)..., 1, 1, - args...) +# Use a generated function for type stability even when many callbacks are given +@inline function find_first_continuous_callback(integrator, + callbacks::Vararg{ + AbstractContinuousCallback, + N}) where {N} + find_first_continuous_callback(integrator, tuple(callbacks...)) end - -function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number, - event_occurred::Bool, event_idx::Int, idx::Int, - counter::Int, - callback2) - counter += 1 # counter is idx for callback2. - tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator, - callback2, counter) - - if event_occurred2 && (tmin2 < tmin || !event_occurred) - return tmin2, upcrossing2, true, event_idx2, counter, counter - else - return tmin, upcrossing, event_occurred, event_idx, idx, counter +@generated function find_first_continuous_callback(integrator, + callbacks::NTuple{N, + AbstractContinuousCallback + }) where {N} + ex = quote + tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator, + callbacks[1], 1) + identified_idx = 1 end -end - -function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number, - event_occurred::Bool, event_idx::Int, idx::Int, - counter::Int, callback2, args...) - find_first_continuous_callback(integrator, - find_first_continuous_callback(integrator, tmin, - upcrossing, - event_occurred, - event_idx, idx, counter, - callback2)..., args...) + for i in 2:N + ex = quote + $ex + tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator, + callbacks[$i], + $i) + if event_occurred2 && (tmin2 < tmin || !event_occurred) + tmin = tmin2 + upcrossing = upcrossing2 + event_occurred = true + event_idx = event_idx2 + identified_idx = $i + end + end + end + ex = quote + $ex + return tmin, upcrossing, event_occurred, event_idx, identified_idx, $N + end + ex end @inline function determine_event_occurance(integrator, callback::VectorContinuousCallback, diff --git a/test/callbacks.jl b/test/callbacks.jl index 9260e980e..252af55bd 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -50,3 +50,55 @@ cbs5 = CallbackSet(cbs1, cbs2) @test length(cbs5.discrete_callbacks) == 1 @test length(cbs5.continuous_callbacks) == 2 + +# For the purposes of this test, create a empty integrator type and +# override find_callback_time, since we don't actually care about testing +# the find callback time aspect, just the inference failure +struct EmptyIntegrator + u::Vector{Float64} +end +function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, + callback::ContinuousCallback, counter) + 1.0 + counter, 0.9 + counter, true, counter +end +function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, + callback::VectorContinuousCallback, counter) + 1.0 + counter, 0.9 + counter, true, counter +end +find_first_integrator = EmptyIntegrator([1.0, 2.0]) +vector_affect! = function (integrator, idx) + integrator.u = integrator.u + idx +end + +cond_1(u, t, integrator) = t - 1.0 +cond_2(u, t, integrator) = t - 1.1 +cond_3(u, t, integrator) = t - 1.2 +cond_4(u, t, integrator) = t - 1.3 +cond_5(u, t, integrator) = t - 1.4 +cond_6(u, t, integrator) = t - 1.5 +cond_7(u, t, integrator) = t - 1.6 +cond_8(u, t, integrator) = t - 1.7 +cond_9(u, t, integrator) = t - 1.8 +cond_10(u, t, integrator) = t - 1.9 +# Setup a lot of callbacks so the recursive inference failure happens +callbacks = (ContinuousCallback(cond_1, affect!), + ContinuousCallback(cond_2, affect!), + ContinuousCallback(cond_3, affect!), + ContinuousCallback(cond_4, affect!), + ContinuousCallback(cond_5, affect!), + ContinuousCallback(cond_6, affect!), + ContinuousCallback(cond_7, affect!), + ContinuousCallback(cond_8, affect!), + ContinuousCallback(cond_9, affect!), + ContinuousCallback(cond_10, affect!), + VectorContinuousCallback(cond_1, vector_affect!, 2), + VectorContinuousCallback(cond_2, vector_affect!, 2), + VectorContinuousCallback(cond_3, vector_affect!, 2), + VectorContinuousCallback(cond_4, vector_affect!, 2), + VectorContinuousCallback(cond_5, vector_affect!, 2), + VectorContinuousCallback(cond_6, vector_affect!, 2)); +function test_find_first_callback(callbacks, int) + @timed(DiffEqBase.find_first_continuous_callback(int, callbacks...)) +end +test_find_first_callback(callbacks, find_first_integrator); +@test test_find_first_callback(callbacks, find_first_integrator).bytes == 0