Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Make Enzyme discrete adjoints work
MWE now works: ```julia using Enzyme, OrdinaryDiffEq, StaticArrays Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true function lorenz!(du, u, p, t) du[1] = 10.0(u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] du[3] = u[1] * u[2] - (8 / 3) * u[3] end const _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0] function f(y::Array{Float64}, u0::Array{Float64}) tspan = (0.0, 3.0) prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan) sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough()) y .= sol[1,:] return nothing end; u0 = [1.0; 0.0; 0.0] d_u0 = zeros(3) y = zeros(13) dy = zeros(13) Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0)); ``` Core issues to finish this: 1. I shouldn't have to pull all of the logging out to a separate function, but there seems to be a bug in enzyme with int inactivity EnzymeAD/Enzyme.jl#1636 2. `saveat` has issues because it uses Julia ranges, which can have a floating point fix issue EnzymeAD/Enzyme.jl#274 3. adding the zero(u), zero(u) is required because Enzyme does not seem to support non-fully initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use undef is and easy fix to that. But that's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it's probably something that improves the package anyways.
- Loading branch information