Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Make Enzyme discrete adjoints work #2282

Closed
wants to merge 2 commits into from
Closed

Conversation

ChrisRackauckas
Copy link
Member

@ChrisRackauckas ChrisRackauckas commented Jul 13, 2024

MWE now works:

using Enzyme, OrdinaryDiffEq, StaticArrays

Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));

Core issues to finish this:

  1. I shouldn't have to pull all of the logging out to a separate function, but there seems to be a bug in enzyme with int inactivity Operation on Int not deduced as const EnzymeAD/Enzyme.jl#1636
  2. saveat has issues because it uses Julia ranges, which can have a floating point fix issue floatrange causes "unkown" binary operator EnzymeAD/Enzyme.jl#274
  3. adding the zero(u), zero(u) is required because Enzyme does not seem to support non-fully initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use undef is and easy fix to that. But that's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it's probably something that improves the package anyways.
  4. Inactivating fastpow directly probably isn't a good idea. But the step controller should be inactive, we already do the dual/tracker dropping on that for the other AD methods. So we can move that inactivity to the right spot before merging.

MWE now works:

```julia
using Enzyme, OrdinaryDiffEq, StaticArrays

Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat =  SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;
u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y  = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f,  Duplicated(y, dy), Duplicated(u0, d_u0));
```

Core issues to finish this:

1. I shouldn't have to pull all of the logging out to a separate function, but there seems to be a bug in enzyme with int inactivity EnzymeAD/Enzyme.jl#1636
2. `saveat` has issues because it uses Julia ranges, which can have a floating point fix issue EnzymeAD/Enzyme.jl#274
3. adding the zero(u), zero(u) is required because Enzyme does not seem to support non-fully initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use undef is and easy fix to that. But that's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it's probably something that improves the package anyways.
@wsmoses
Copy link
Contributor

wsmoses commented Jul 14, 2024

@ChrisRackauckas I don't follow, what's the error with non-fully initialized types?

@wsmoses
Copy link
Contributor

wsmoses commented Jul 14, 2024

Presumably not known, so open an issue

@ChrisRackauckas
Copy link
Member Author

@ChrisRackauckas I don't follow, what's the error with non-fully initialized types?

Julia allows for a type to be partially defined by having an inner type constructor which only puts a value on a subset of arguments. The others are set to undef and need to be defined before access. It seems that Enzyme does not do well with that.

@@ -65,7 +65,7 @@ end
q = inv(qmax)
else
expo = 1 / (get_current_adaptive_order(alg, integrator.cache) + 1)
qtmp = DiffEqBase.fastpow(EEst, expo) / gamma
qtmp = ^(EEst, expo) / gamma
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't these have a @fastmath (and Float32)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was just seeing what's required in order to make Enzyme work on this.

ChrisRackauckas added a commit that referenced this pull request Aug 17, 2024
This was used as a performance optimization early on, dropping the construction of those two vectors since we already construct so much in the caches, we might as well reuse one of the cache pointers. And it's also built into some methods what cache pointer that must be. So the integrator is made with undef's and then during initialization phase the pointers are set.

However, this is unnecessary and adds some complexity. For one, it makes the constructor a bit of a mess. But for two, it gives Enzyme issues as demonstrated in #2282. A better solution is then to just, construct the type correctly.

To do this, we simply need to refactor the information of what vectors correspond to fsal first and last into a function that is per-cache, and use that function in the integrator construction. That's already done in this PR. All that's required to complete this PR is to ensure this refactor is done on every method.
@ChrisRackauckas
Copy link
Member Author

All of the pieces are in now, Enzyme now works on the explicit solvers!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants