Skip to content

Commit

Permalink
Switch find_first_continuous_callback to use a generated implementation.
Browse files Browse the repository at this point in the history
As mentioned in SciML/DifferentialEquations.jl#971, the current
recursive method for identifying the first continuous callback can cause
the compiler to give up on type inference, especially when there are
many callbacks. The fallback then allocates.

This switches this function to using a generated function (along with an
inline function that takes splatted tuples). Because this generated
function explicitly unrolls the tuple, there are no type inference
problems.

I added a test that allocates using the old implementation (about 19kb
allocations!) but does not with the new system.
  • Loading branch information
meson800 committed Aug 22, 2023
1 parent a2ac2da commit 677636b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 37 deletions.
71 changes: 34 additions & 37 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,45 +116,42 @@ function get_condition(integrator::DEIntegrator, callback, abst)
end
end

# Use Recursion to find the first callback for type-stability

# Base Case: Only one callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback)
(find_callback_time(integrator, callback, 1)..., 1, 1)
end

# Starting Case: Compute on the first callback
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback,
args...)
find_first_continuous_callback(integrator,
find_callback_time(integrator, callback, 1)..., 1, 1,
args...)
# Use a generated function for type stability even when many callbacks are given
@inline function find_first_continuous_callback(integrator,
callbacks::Vararg{
AbstractContinuousCallback,
N}) where {N}
find_first_continuous_callback(integrator, tuple(callbacks...))
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int,
callback2)
counter += 1 # counter is idx for callback2.
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callback2, counter)

if event_occurred2 && (tmin2 < tmin || !event_occurred)
return tmin2, upcrossing2, true, event_idx2, counter, counter
else
return tmin, upcrossing, event_occurred, event_idx, idx, counter
@generated function find_first_continuous_callback(integrator,
callbacks::NTuple{N,
AbstractContinuousCallback
}) where {N}
ex = quote
tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator,
callbacks[1], 1)
identified_idx = 1
end
end

function find_first_continuous_callback(integrator, tmin::Number, upcrossing::Number,
event_occurred::Bool, event_idx::Int, idx::Int,
counter::Int, callback2, args...)
find_first_continuous_callback(integrator,
find_first_continuous_callback(integrator, tmin,
upcrossing,
event_occurred,
event_idx, idx, counter,
callback2)..., args...)
for i in 2:N
ex = quote
$ex

Check warning on line 137 in src/callbacks.jl

View check run for this annotation

Codecov / codecov/patch

src/callbacks.jl#L137

Added line #L137 was not covered by tests
tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator,
callbacks[$i],
$i)
if event_occurred2 && (tmin2 < tmin || !event_occurred)
tmin = tmin2
upcrossing = upcrossing2
event_occurred = true
event_idx = event_idx2
identified_idx = $i
end
end
end
ex = quote
$ex

Check warning on line 151 in src/callbacks.jl

View check run for this annotation

Codecov / codecov/patch

src/callbacks.jl#L151

Added line #L151 was not covered by tests
return tmin, upcrossing, event_occurred, event_idx, identified_idx, $N
end
ex
end

@inline function determine_event_occurance(integrator, callback::VectorContinuousCallback,
Expand Down
52 changes: 52 additions & 0 deletions test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,55 @@ cbs5 = CallbackSet(cbs1, cbs2)

@test length(cbs5.discrete_callbacks) == 1
@test length(cbs5.continuous_callbacks) == 2

# For the purposes of this test, create a empty integrator type and
# override find_callback_time, since we don't actually care about testing
# the find callback time aspect, just the inference failure
struct EmptyIntegrator
u::Vector{Float64}
end
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
callback::ContinuousCallback, counter)
1.0 + counter, 0.9 + counter, true, counter
end
function DiffEqBase.find_callback_time(integrator::EmptyIntegrator,
callback::VectorContinuousCallback, counter)
1.0 + counter, 0.9 + counter, true, counter
end
find_first_integrator = EmptyIntegrator([1.0, 2.0])
vector_affect! = function (integrator, idx)
integrator.u = integrator.u + idx
end

cond_1(u, t, integrator) = t - 1.0
cond_2(u, t, integrator) = t - 1.1
cond_3(u, t, integrator) = t - 1.2
cond_4(u, t, integrator) = t - 1.3
cond_5(u, t, integrator) = t - 1.4
cond_6(u, t, integrator) = t - 1.5
cond_7(u, t, integrator) = t - 1.6
cond_8(u, t, integrator) = t - 1.7
cond_9(u, t, integrator) = t - 1.8
cond_10(u, t, integrator) = t - 1.9
# Setup a lot of callbacks so the recursive inference failure happens
callbacks = (ContinuousCallback(cond_1, affect!),
ContinuousCallback(cond_2, affect!),
ContinuousCallback(cond_3, affect!),
ContinuousCallback(cond_4, affect!),
ContinuousCallback(cond_5, affect!),
ContinuousCallback(cond_6, affect!),
ContinuousCallback(cond_7, affect!),
ContinuousCallback(cond_8, affect!),
ContinuousCallback(cond_9, affect!),
ContinuousCallback(cond_10, affect!),
VectorContinuousCallback(cond_1, vector_affect!, 2),
VectorContinuousCallback(cond_2, vector_affect!, 2),
VectorContinuousCallback(cond_3, vector_affect!, 2),
VectorContinuousCallback(cond_4, vector_affect!, 2),
VectorContinuousCallback(cond_5, vector_affect!, 2),
VectorContinuousCallback(cond_6, vector_affect!, 2));
function test_find_first_callback(callbacks, int)
@timed(DiffEqBase.find_first_continuous_callback(int, callbacks...))
end
test_find_first_callback(callbacks, find_first_integrator);
@test test_find_first_callback(callbacks, find_first_integrator).bytes == 0

0 comments on commit 677636b

Please sign in to comment.