Skip to content

Commit

Permalink
Merge pull request #920 from meson800/find_first_callback_gen_func
Browse files Browse the repository at this point in the history
Switch find_first_continuous_callback to use a generated implementation.
  • Loading branch information
ChrisRackauckas authored Aug 23, 2023
2 parents a2ac2da + 677636b commit 1799fc3
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 37 deletions.
71 changes: 34 additions & 37 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,45 +116,42 @@ function get_condition(integrator::DEIntegrator, callback, abst)
end
end

# Use Recursion to find the first callback for type-stability

# Base Case: Only one callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback)
(find_callback_time(integrator, callback, 1)..., 1, 1)
end

# Starting Case: Compute on the first callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback,
args...)
find_first_continuous_callback(integrator,
find_callback_time(integrator, callback, 1)..., 1, 1,
args...)
# Use a generated function for type stability even when many callbacks are given
@inline function find_first_continuous_callback(integrator,
callbacks::Vararg{
AbstractContinuousCallback,
N}) where {N}
find_first_continuous_callback(integrator, tuple(callbacks...))
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int,
callback2)
counter += 1 # counter is idx for callback2.
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callback2, counter)

if event_occurred2 && (tmin2 < tmin || !event_occurred)
return tmin2, upcrossing2, true, event_idx2, counter, counter
else
return tmin, upcrossing, event_occurred, event_idx, idx, counter
@generated function find_first_continuous_callback(integrator,
callbacks::NTuple{N,
AbstractContinuousCallback
}) where {N}
ex = quote
tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator,
callbacks[1], 1)
identified_idx = 1
end
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int, callback2, args...)
find_first_continuous_callback(integrator,
find_first_continuous_callback(integrator, tmin,
upcrossing,
event_occurred,
event_idx, idx, counter,
callback2)..., args...)
for i in 2:N
ex = quote
$ex
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callbacks[$i],
$i)
if event_occurred2 && (tmin2 < tmin || !event_occurred)
tmin = tmin2
upcrossing = upcrossing2
event_occurred = true
event_idx = event_idx2
identified_idx = $i
end
end
end
ex = quote
$ex
return tmin, upcrossing, event_occurred, event_idx, identified_idx, $N
end
ex
end

@inline function determine_event_occurance(integrator, callback::VectorContinuousCallback,
Expand Down
52 changes: 52 additions & 0 deletions test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,55 @@ cbs5 = CallbackSet(cbs1, cbs2)

@test length(cbs5.discrete_callbacks) == 1
@test length(cbs5.continuous_callbacks) == 2

# For the purposes of this test, create a empty integrator type and
# override find_callback_time, since we don't actually care about testing
# the find callback time aspect, just the inference failure
struct EmptyIntegrator
u::Vector{Float64}
end
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
callback::ContinuousCallback, counter)
1.0 + counter, 0.9 + counter, true, counter
end
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
callback::VectorContinuousCallback, counter)
1.0 + counter, 0.9 + counter, true, counter
end
find_first_integrator = EmptyIntegrator([1.0, 2.0])
vector_affect! = function (integrator, idx)
integrator.u = integrator.u + idx
end

cond_1(u, t, integrator) = t - 1.0
cond_2(u, t, integrator) = t - 1.1
cond_3(u, t, integrator) = t - 1.2
cond_4(u, t, integrator) = t - 1.3
cond_5(u, t, integrator) = t - 1.4
cond_6(u, t, integrator) = t - 1.5
cond_7(u, t, integrator) = t - 1.6
cond_8(u, t, integrator) = t - 1.7
cond_9(u, t, integrator) = t - 1.8
cond_10(u, t, integrator) = t - 1.9
# Setup a lot of callbacks so the recursive inference failure happens
callbacks = (ContinuousCallback(cond_1, affect!),
ContinuousCallback(cond_2, affect!),
ContinuousCallback(cond_3, affect!),
ContinuousCallback(cond_4, affect!),
ContinuousCallback(cond_5, affect!),
ContinuousCallback(cond_6, affect!),
ContinuousCallback(cond_7, affect!),
ContinuousCallback(cond_8, affect!),
ContinuousCallback(cond_9, affect!),
ContinuousCallback(cond_10, affect!),
VectorContinuousCallback(cond_1, vector_affect!, 2),
VectorContinuousCallback(cond_2, vector_affect!, 2),
VectorContinuousCallback(cond_3, vector_affect!, 2),
VectorContinuousCallback(cond_4, vector_affect!, 2),
VectorContinuousCallback(cond_5, vector_affect!, 2),
VectorContinuousCallback(cond_6, vector_affect!, 2));
function test_find_first_callback(callbacks, int)
@timed(DiffEqBase.find_first_continuous_callback(int, callbacks...))
end
test_find_first_callback(callbacks, find_first_integrator);
@test test_find_first_callback(callbacks, find_first_integrator).bytes == 0

0 comments on commit 1799fc3

Please sign in to comment.