From 7dba4664382923b21a2b8ee23248f893385e00d4 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 13 Jul 2024 16:47:54 +0200 Subject: [PATCH 1/2] WIP: Make Enzyme discrete adjoints work MWE now works: ```julia using Enzyme, OrdinaryDiffEq, StaticArrays Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true function lorenz!(du, u, p, t) du[1] = 10.0(u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] du[3] = u[1] * u[2] - (8 / 3) * u[3] end const _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0] function f(y::Array{Float64}, u0::Array{Float64}) tspan = (0.0, 3.0) prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan) sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough()) y .= sol[1,:] return nothing end; u0 = [1.0; 0.0; 0.0] d_u0 = zeros(3) y = zeros(13) dy = zeros(13) Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0)); ``` Core issues to finish this: 1. 
I shouldn't have to pull all of the logging out to a separate function, but there seems to be a bug in Enzyme with int inactivity https://github.com/EnzymeAD/Enzyme.jl/issues/1636 2. `saveat` has issues because it uses Julia ranges, which can have a floating point fix issue https://github.com/EnzymeAD/Enzyme.jl/issues/274 3. adding the zero(u), zero(u) is required because Enzyme does not seem to support non-fully initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use undef is an easy fix to that. But that's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it's probably something that improves the package anyways. --- src/integrators/integrator_interface.jl | 6 +- src/integrators/integrator_utils.jl | 71 ++++++++++--------- src/integrators/type.jl | 2 +- src/perform_step/low_order_rk_perform_step.jl | 12 +++- 4 files changed, 55 insertions(+), 36 deletions(-) diff --git a/src/integrators/integrator_interface.jl b/src/integrators/integrator_interface.jl index 5bf13dc578..1242d8a0c9 100644 --- a/src/integrators/integrator_interface.jl +++ b/src/integrators/integrator_interface.jl @@ -480,7 +480,11 @@ function DiffEqBase.auto_dt_reset!(integrator::ODEIntegrator) integrator.opts.internalnorm, integrator.sol.prob, integrator) integrator.dtpropose = integrator.dt - integrator.stats.nf += 2 + increment_nf_from_initdt!(integrator.stats) +end + +function increment_nf_from_initdt!(stats) + stats.nf += 2 end function DiffEqBase.set_t!(integrator::ODEIntegrator, t::Real) diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index a3f45f456a..7abe83e6c9 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -231,57 +231,33 @@ function _loopfooter!(integrator) (integrator.opts.force_dtmin && abs(integrator.dt) <= timedepentdtmin(integrator)) if 
integrator.accept_step # Accept - integrator.stats.naccept += 1 + increment_accept!(integrator.stats) integrator.last_stepfail = false dtnew = DiffEqBase.value(step_accept_controller!(integrator, integrator.alg, q)) * oneunit(integrator.dt) integrator.tprev = integrator.t - integrator.t = if has_tstop(integrator) - tstop = integrator.tdir * first_tstop(integrator) - if abs(ttmp - tstop) < - 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * - oneunit(integrator.t) - tstop - else - ttmp - end - else - ttmp - end + integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp) calc_dt_propose!(integrator, dtnew) handle_callbacks!(integrator) else # Reject + increment_reject!(integrator.stats) integrator.stats.nreject += 1 end elseif !integrator.opts.adaptive #Not adaptive - integrator.stats.naccept += 1 + increment_accept!(integrator.stats) integrator.tprev = integrator.t - integrator.t = if has_tstop(integrator) - tstop = integrator.tdir * first_tstop(integrator) - if abs(ttmp - tstop) < - 100eps(float(integrator.t / oneunit(integrator.t))) * oneunit(integrator.t) - tstop - else - ttmp - end - else - ttmp - end + integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp) integrator.last_stepfail = false integrator.accept_step = true integrator.dtpropose = integrator.dt handle_callbacks!(integrator) end if integrator.opts.progress && integrator.iter % integrator.opts.progress_steps == 0 - t1, t2 = integrator.sol.prob.tspan - @logmsg(LogLevel(-1), - integrator.opts.progress_name, - _id=integrator.opts.progress_id, - message=integrator.opts.progress_message(integrator.dt, integrator.u, - integrator.p, integrator.t), - progress=(integrator.t - t1) / (t2 - t1)) + log_step!(integrator.opts.progress_name, integrator.opts.progress_id, + integrator.opts.progress_message, integrator.dt, integrator.u, + integrator.p, integrator.t, integrator.sol.prob.tspan) end # Take value because if t is dual then maxeig can be dual @@ -295,6 +271,37 @@ function 
_loopfooter!(integrator) nothing end +function increment_accept!(stats) + stats.naccept += 1 +end + +function increment_reject!(stats) + stats.nreject += 1 +end + +function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, tspan) + t1, t2 = tspan + @logmsg(LogLevel(-1),progress_name, + _id=progress_id, + message=progress_message(dt, u, p, t), + progress=(t - t1) / (t2 - t1)) +end + +function fixed_t_for_floatingpoint_error!(integrator, ttmp) + if has_tstop(integrator) + tstop = integrator.tdir * first_tstop(integrator) + if abs(ttmp - tstop) < + 100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) * + oneunit(integrator.t) + tstop + else + ttmp + end + else + ttmp + end +end + # Use a generated function to call apply_callback! in a type-stable way @generated function apply_ith_callback!(integrator, time, upcrossing, event_idx, cb_idx, diff --git a/src/integrators/type.jl b/src/integrators/type.jl index a361302f9e..a3474ee785 100644 --- a/src/integrators/type.jl +++ b/src/integrators/type.jl @@ -177,7 +177,7 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori do_error_check, event_last_time, vector_event_last_time, last_event_error, accept_step, isout, reeval_fsal, u_modified, reinitialize, isdae, - opts, stats, initializealg, differential_vars) # Leave off fsalfirst and last + opts, stats, initializealg, differential_vars, zero(u), zero(u)) end end diff --git a/src/perform_step/low_order_rk_perform_step.jl b/src/perform_step/low_order_rk_perform_step.jl index 8b8a0f63dd..f15fec13b9 100644 --- a/src/perform_step/low_order_rk_perform_step.jl +++ b/src/perform_step/low_order_rk_perform_step.jl @@ -797,10 +797,14 @@ function initialize!(integrator, cache::Tsit5Cache) integrator.k[6] = cache.k6 integrator.k[7] = cache.k7 integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # Pre-start fsal - integrator.stats.nf += 1 + increment_nf!(integrator.stats) return nothing end +function 
increment_nf!(stats) + stats.nf += 1 +end + @muladd function perform_step!(integrator, cache::Tsit5Cache, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator T = constvalue(recursive_unitless_bottom_eltype(u)) @@ -832,7 +836,7 @@ end stage_limiter!(u, integrator, p, t + dt) step_limiter!(u, integrator, p, t + dt) f(k7, u, p, t + dt) - integrator.stats.nf += 6 + increment_nf_perform_step!(integrator.stats) if integrator.alg isa CompositeAlgorithm g7 = u g6 = tmp @@ -853,6 +857,10 @@ end return nothing end +function increment_nf_perform_step!(stats) + stats.nf += 6 +end + function initialize!(integrator, cache::DP5ConstantCache) integrator.kshortsize = 4 integrator.k = typeof(integrator.k)(undef, integrator.kshortsize) From 079f118e892a694e3d915f1dd0f766b511bb1cc0 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 14 Jul 2024 18:56:12 +0200 Subject: [PATCH 2/2] Fastpow removal --- src/integrators/controllers.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/integrators/controllers.jl b/src/integrators/controllers.jl index 416a6ea99d..a61328c65b 100644 --- a/src/integrators/controllers.jl +++ b/src/integrators/controllers.jl @@ -65,7 +65,7 @@ end q = inv(qmax) else expo = 1 / (get_current_adaptive_order(alg, integrator.cache) + 1) - qtmp = DiffEqBase.fastpow(EEst, expo) / gamma + qtmp = ^(EEst, expo) / gamma @fastmath q = DiffEqBase.value(max(inv(qmax), min(inv(qmin), qtmp))) # TODO: Shouldn't this be in `step_accept_controller!` as for the PI controller? 
integrator.qold = DiffEqBase.value(integrator.dt) / q @@ -138,8 +138,8 @@ end if iszero(EEst) q = inv(qmax) else - q11 = DiffEqBase.fastpow(EEst, float(beta1)) - q = q11 / DiffEqBase.fastpow(qold, float(beta2)) + q11 = ^(EEst, float(beta1)) + q = q11 / ^(qold, float(beta2)) integrator.q11 = q11 @fastmath q = max(inv(qmax), min(inv(qmin), q / gamma)) end @@ -412,7 +412,7 @@ end fac = min(gamma, (1 + 2 * maxiters) * gamma / (iter + 2 * maxiters)) end expo = 1 / (get_current_adaptive_order(alg, integrator.cache) + 1) - qtmp = DiffEqBase.fastpow(EEst, expo) / fac + qtmp = ^(EEst, expo) / fac @fastmath q = DiffEqBase.value(max(inv(qmax), min(inv(qmin), qtmp))) integrator.qold = q end @@ -426,7 +426,7 @@ function step_accept_controller!(integrator, controller::PredictiveController, a if integrator.success_iter > 0 expo = 1 / (get_current_adaptive_order(alg, integrator.cache) + 1) qgus = (integrator.dtacc / integrator.dt) * - DiffEqBase.fastpow((EEst^2) / integrator.erracc, expo) + ^((EEst^2) / integrator.erracc, expo) qgus = max(inv(qmax), min(inv(qmin), qgus / gamma)) qacc = max(q, qgus) else