Skip to content

Commit

Permalink
Add broken test for inference failure in callbacks loop
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 19, 2024
1 parent cdcab7f commit 420f606
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions perf/jet.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ArgParse, JET, Test, BenchmarkTools, DiffEqBase, ClimaTimeSteppers
# using Revise; include("perf/jet.jl")
using ArgParse, JET, Test, BenchmarkTools, SciMLBase, ClimaTimeSteppers
import ClimaTimeSteppers as CTS
function parse_commandline()
s = ArgParse.ArgParseSettings()
Expand All @@ -15,10 +16,29 @@ end
cts = joinpath(dirname(@__DIR__));
include(joinpath(cts, "test", "problems.jl"))
config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob)

struct Foo end
foo!(integrator) = nothing
(::Foo)(integrator) = foo!(integrator)
struct Bar end
bar!(integrator) = nothing
(::Bar)(integrator) = bar!(integrator)

function discrete_cb(cb!, n)
cond = if n == 1
(u, t, integrator) -> isnothing(cb!(integrator))
else
(u, t, integrator) -> isnothing(cb!(integrator)) || rand() 0.5
end
SciMLBase.DiscreteCallback(cond, cb!;)
end
function config_integrators(problem)
algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2))
dt = 0.01
integrator = DiffEqBase.init(problem, algorithm; dt)
discrete_callbacks = (discrete_cb(Foo(), 0), discrete_cb(Bar(), 0), discrete_cb(Foo(), 1), discrete_cb(Bar(), 1))
callback = SciMLBase.CallbackSet((), discrete_callbacks)

integrator = SciMLBase.init(problem, algorithm; dt, callback)
integrator.cache = CTS.init_cache(problem, algorithm)
return (; integrator)
end
Expand All @@ -33,7 +53,12 @@ else
end
(; integrator) = config_integrators(prob)

CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs
step_allocs = @allocated CTS.step_u!(integrator, integrator.cache)
@show step_allocs
JET.@test_opt CTS.step_u!(integrator, integrator.cache)
@testset "JET / allocations" begin
CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs
step_allocs = @allocated CTS.step_u!(integrator, integrator.cache)
@show step_allocs
JET.@test_opt CTS.step_u!(integrator, integrator.cache)

CTS.__step!(integrator) # compile first, and make sure it runs
JET.@test_opt broken = true CTS.__step!(integrator)
end

0 comments on commit 420f606

Please sign in to comment.