Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Make Enzyme discrete adjoints work #2282

Closed
wants to merge 2 commits into from
Closed

Commits on Jul 13, 2024

  1. WIP: Make Enzyme discrete adjoints work

    MWE now works:
    
    ```julia
    using Enzyme, OrdinaryDiffEq, StaticArrays
    
    Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
    Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true
    
    function lorenz!(du, u, p, t)
        du[1] = 10.0(u[2] - u[1])
        du[2] = u[1] * (28.0 - u[3]) - u[2]
        du[3] = u[1] * u[2] - (8 / 3) * u[3]
    end
    
    const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]
    
    function f(y::Array{Float64}, u0::Array{Float64})
        tspan = (0.0, 3.0)
        prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
        sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
        y .= sol[1,:]
        return nothing
    end;
    u0 = [1.0; 0.0; 0.0]
    d_u0 = zeros(3)
    y  = zeros(13)
    dy = zeros(13)
    
    Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));
    ```
    
    Core issues to finish this:
    
    1. I shouldn't have to pull all of the logging out to a separate function, but there seems to be a bug in enzyme with int inactivity EnzymeAD/Enzyme.jl#1636
    2. `saveat` has issues because it uses Julia ranges, which can have a floating point fix issue EnzymeAD/Enzyme.jl#274
    3. adding the zero(u), zero(u) is required because Enzyme does not seem to support non-fully initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use undef is and easy fix to that. But that's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it's probably something that improves the package anyways.
    ChrisRackauckas committed Jul 13, 2024
    Configuration menu
    Copy the full SHA
    7dba466 View commit details
    Browse the repository at this point in the history

Commits on Jul 14, 2024

  1. Fastpow removal

    ChrisRackauckas committed Jul 14, 2024
    Configuration menu
    Copy the full SHA
    079f118 View commit details
    Browse the repository at this point in the history