Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inference failure in callbacks #318

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaTimeSteppers"
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
authors = ["Climate Modeling Alliance"]
version = "0.7.37"
version = "0.7.38"

[deps]
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Expand Down
2 changes: 1 addition & 1 deletion perf/jet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,5 @@ end
JET.@test_opt CTS.step_u!(integrator, integrator.cache)

CTS.__step!(integrator) # compile first, and make sure it runs
JET.@test_opt broken = true CTS.__step!(integrator)
JET.@test_opt CTS.__step!(integrator)
end
19 changes: 12 additions & 7 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,17 @@ is_past_t(integrator, t) = tdir(integrator) * (t - integrator.t) < zero(integrat
reached_tstop(integrator, tstop, stop_at_tstop = integrator.dtchangeable) =
integrator.t == tstop || (!stop_at_tstop && is_past_t(integrator, tstop))


@inline unrolled_foreach(::Tuple{}, integrator) = nothing
@inline unrolled_foreach(callback, integrator) =
callback.condition(integrator.u, integrator.t, integrator) ? callback.affect!(integrator) : nothing
@inline unrolled_foreach(discrete_callbacks::Tuple{Any}, integrator) =
unrolled_foreach(first(discrete_callbacks), integrator)
@inline function unrolled_foreach(discrete_callbacks::Tuple, integrator)
unrolled_foreach(first(discrete_callbacks), integrator)
unrolled_foreach(Base.tail(discrete_callbacks), integrator)
end

function __step!(integrator)
(; _dt, dtchangeable, tstops) = integrator

Expand All @@ -246,13 +257,7 @@ function __step!(integrator)

# apply callbacks
discrete_callbacks = integrator.callback.discrete_callbacks
for (ncb, callback) in enumerate(discrete_callbacks)
if callback.condition(integrator.u, integrator.t, integrator)::Bool
NVTX.@range "Callback $ncb of $(length(discrete_callbacks))" color = colorant"yellow" begin
callback.affect!(integrator)
end
end
end
unrolled_foreach(discrete_callbacks, integrator)

# remove tstops that were just reached
while !isempty(tstops) && reached_tstop(integrator, first(tstops))
Expand Down
Loading