Skip to content

Commit

Permalink
WIP: Make Enzyme discrete adjoints work
Browse files Browse the repository at this point in the history
MWE now works:

```julia
using Enzyme, OrdinaryDiffEq, StaticArrays

Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));
```

Core issues to finish this:

1. I shouldn't have to pull all of the logging out to a separate function, but there seems to be a bug in enzyme with int inactivity EnzymeAD/Enzyme.jl#1636
2. `saveat` has issues because it uses Julia ranges, which can have a floating point fix issue EnzymeAD/Enzyme.jl#274
3. adding the zero(u), zero(u) is required because Enzyme does not seem to support non-fully initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use undef is and easy fix to that. But that's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it's probably something that improves the package anyways.
  • Loading branch information
ChrisRackauckas committed Jul 13, 2024
1 parent 6ece080 commit 7dba466
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 36 deletions.
6 changes: 5 additions & 1 deletion src/integrators/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,11 @@ function DiffEqBase.auto_dt_reset!(integrator::ODEIntegrator)
integrator.opts.internalnorm, integrator.sol.prob,
integrator)
integrator.dtpropose = integrator.dt
integrator.stats.nf += 2
increment_nf_from_initdt!(integrator.stats)
end

function increment_nf_from_initdt!(stats)
stats.nf += 2
end

function DiffEqBase.set_t!(integrator::ODEIntegrator, t::Real)
Expand Down
71 changes: 39 additions & 32 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,57 +231,33 @@ function _loopfooter!(integrator)
(integrator.opts.force_dtmin &&
abs(integrator.dt) <= timedepentdtmin(integrator))
if integrator.accept_step # Accept
integrator.stats.naccept += 1
increment_accept!(integrator.stats)
integrator.last_stepfail = false
dtnew = DiffEqBase.value(step_accept_controller!(integrator,
integrator.alg,
q)) *
oneunit(integrator.dt)
integrator.tprev = integrator.t
integrator.t = if has_tstop(integrator)
tstop = integrator.tdir * first_tstop(integrator)
if abs(ttmp - tstop) <
100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) *
oneunit(integrator.t)
tstop
else
ttmp
end
else
ttmp
end
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
calc_dt_propose!(integrator, dtnew)
handle_callbacks!(integrator)
else # Reject
increment_reject!(integrator.stats)
integrator.stats.nreject += 1
end
elseif !integrator.opts.adaptive #Not adaptive
integrator.stats.naccept += 1
increment_accept!(integrator.stats)
integrator.tprev = integrator.t
integrator.t = if has_tstop(integrator)
tstop = integrator.tdir * first_tstop(integrator)
if abs(ttmp - tstop) <
100eps(float(integrator.t / oneunit(integrator.t))) * oneunit(integrator.t)
tstop
else
ttmp
end
else
ttmp
end
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
integrator.last_stepfail = false
integrator.accept_step = true
integrator.dtpropose = integrator.dt
handle_callbacks!(integrator)
end
if integrator.opts.progress && integrator.iter % integrator.opts.progress_steps == 0
t1, t2 = integrator.sol.prob.tspan
@logmsg(LogLevel(-1),
integrator.opts.progress_name,
_id=integrator.opts.progress_id,
message=integrator.opts.progress_message(integrator.dt, integrator.u,
integrator.p, integrator.t),
progress=(integrator.t - t1) / (t2 - t1))
log_step!(integrator.opts.progress_name, integrator.opts.progress_id,
integrator.opts.progress_message, integrator.dt, integrator.u,
integrator.p, integrator.t, integrator.sol.prob.tspan)
end

# Take value because if t is dual then maxeig can be dual
Expand All @@ -295,6 +271,37 @@ function _loopfooter!(integrator)
nothing
end

function increment_accept!(stats)
stats.naccept += 1
end

function increment_reject!(stats)
stats.nreject += 1
end

function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, tspan)
t1, t2 = tspan
@logmsg(LogLevel(-1),progress_name,
_id=progress_id,
message=progress_message(dt, u, p, t),
progress=(t - t1) / (t2 - t1))
end

function fixed_t_for_floatingpoint_error!(integrator, ttmp)
if has_tstop(integrator)
tstop = integrator.tdir * first_tstop(integrator)
if abs(ttmp - tstop) <
100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) *
oneunit(integrator.t)
tstop
else
ttmp
end
else
ttmp
end
end

# Use a generated function to call apply_callback! in a type-stable way
@generated function apply_ith_callback!(integrator,
time, upcrossing, event_idx, cb_idx,
Expand Down
2 changes: 1 addition & 1 deletion src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
do_error_check,
event_last_time, vector_event_last_time, last_event_error,
accept_step, isout, reeval_fsal, u_modified, reinitialize, isdae,
opts, stats, initializealg, differential_vars) # Leave off fsalfirst and last
opts, stats, initializealg, differential_vars, zero(u), zero(u))
end
end

Expand Down
12 changes: 10 additions & 2 deletions src/perform_step/low_order_rk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -797,10 +797,14 @@ function initialize!(integrator, cache::Tsit5Cache)
integrator.k[6] = cache.k6
integrator.k[7] = cache.k7
integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
integrator.stats.nf += 1
increment_nf!(integrator.stats)
return nothing
end

function increment_nf!(stats)
stats.nf += 1
end

@muladd function perform_step!(integrator, cache::Tsit5Cache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
T = constvalue(recursive_unitless_bottom_eltype(u))
Expand Down Expand Up @@ -832,7 +836,7 @@ end
stage_limiter!(u, integrator, p, t + dt)
step_limiter!(u, integrator, p, t + dt)
f(k7, u, p, t + dt)
integrator.stats.nf += 6
increment_nf_perform_step!(integrator.stats)
if integrator.alg isa CompositeAlgorithm
g7 = u
g6 = tmp
Expand All @@ -853,6 +857,10 @@ end
return nothing
end

function increment_nf_perform_step!(stats)
stats.nf += 6
end

function initialize!(integrator, cache::DP5ConstantCache)
integrator.kshortsize = 4
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
Expand Down

0 comments on commit 7dba466

Please sign in to comment.