floatrange causes "unknown" binary operator #274
throwing this one back to you |
So this is interesting. We basically have a custom floating point type, so teaching Enzyme about that is going to be fun. Probably best to wait for #177 |
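The "custom floating point type" here is the double-double `Base.TwicePrecision` that Julia uses inside float ranges. A small Base-only sketch shows what a range like `0.0:0.25:3.0` actually stores. It relies on the `StepRangeLen` fields `ref`, `step`, `len` and `offset` and the `TwicePrecision` fields `hi`/`lo`, which are internal details that may change between Julia versions:

```julia
r = 0.0:0.25:3.0

# A float literal range is a StepRangeLen whose reference value and step
# are stored as Base.TwicePrecision (a hi/lo pair of Float64s).
println(typeof(r))
println(r.ref)     # reference value, as a TwicePrecision
println(r.step)    # step, as a TwicePrecision
println(r.len)     # number of elements
println(r.offset)  # index at which r.ref sits

# The hi and lo parts are ordinary Float64 fields.
println(r.step.hi, " ", r.step.lo)
```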
As a band-aid, should Enzyme just define a rule over range construction? |
MWE now works:

```julia
using Enzyme, OrdinaryDiffEq, StaticArrays

Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat = SA[0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1, :]
    return nothing
end;

u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0));
```

Core issues to finish this:

1. I shouldn't have to pull all of the logging out into a separate function, but there seems to be a bug in Enzyme with int inactivity: EnzymeAD/Enzyme.jl#1636
2. `saveat` has issues because it uses Julia ranges, which can have a floating-point fix issue: EnzymeAD/Enzyme.jl#274
3. Adding the `zero(u), zero(u)` is required because Enzyme does not seem to support non-fully-initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use `undef` is an easy fix. But that's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it probably improves the package anyway.
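A minimal sketch of the inner-constructor fix from point 3, using a hypothetical `StepCache` type as a stand-in (it is not OrdinaryDiffEq's actual cache type). The buffers are allocated with `zero(u)` instead of `similar(u)` or `Vector{T}(undef, n)`, so no field ever holds uninitialized memory:

```julia
# Hypothetical stand-in for an integrator cache (not OrdinaryDiffEq's actual type).
struct StepCache{T}
    k1::Vector{T}
    k2::Vector{T}
    # Inner constructor that allocates zero-filled buffers instead of
    # `similar(u)` / `Vector{T}(undef, n)`, so every element is defined.
    StepCache(u::Vector{T}) where {T} = new{T}(zero(u), zero(u))
end

u = [1.0, 0.0, 0.0]
cache = StepCache(u)
println(cache.k1)  # [0.0, 0.0, 0.0]
```

The memory cost mentioned above is visible here: `zero(u)` writes every element once at construction, which an `undef` allocation would skip.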
I've been working from this MWE:

```julia
using Enzyme

function f(x)
    ts = Array(0.0:x:3.0)
    sum(ts)
end

f(0.25)
Enzyme.autodiff(Forward, f, Duplicated(0.25, 1.0))
Enzyme.autodiff(Reverse, f, Active, Active(0.25))
```
|
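It helps to pin down what derivative this MWE should produce. The points are `0, x, 2x, ..., 12x`, so `f(x) = 78x` as long as there are 13 points, but the point count is itself a step function of `x`. The Base-only sketch below (no Enzyme involved; `g` is a helper introduced here) shows that nudging `x` upward drops a point, while a central difference taken at a fixed length recovers 78:

```julia
# f from the MWE: sum of the points 0.0, x, 2x, ... up to 3.0
f(x) = sum(Array(0.0:x:3.0))

# Same points, but with the number of elements pinned to 13
g(x) = sum(Array(range(0.0; step = x, length = 13)))

x, h = 0.25, 1e-6
println(f(x))                     # 19.5
println(length(0.0:x:3.0))        # 13
println(length(0.0:(x + h):3.0))  # 12: the last point would overshoot 3.0 and is dropped

# Central difference with the length held fixed: d/dx of 78x
dg = (g(x + h) - g(x - h)) / (2h)
println(dg)  # close to 78
```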
Inside of the range code there is a call to `Base.truncbits`, so I tried custom rules for it:

```julia
using Enzyme, ReverseDiff, Tracker
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

function forward(func::Const{typeof(Base.truncbits)}, ::Type{<:Duplicated}, x::Duplicated, mask)
    println("Using custom rule!")
    maskval = if hasproperty(mask, :val)
        mask.val
    else
        mask
    end
    ret = func.val(x.val, maskval)
    @show x.dval
    return Duplicated(ret, one(ret))
end

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(Base.truncbits)}, ::Type{<:Union{Const,Active}},
                          x::Union{Const,Active}, mask)
    println("In custom augmented primal rule.")
    maskval = if hasproperty(mask, :val)
        mask.val
    else
        mask
    end
    # Compute primal
    if needs_primal(config)
        primal = func.val(x.val, maskval)
    else
        primal = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, nothing)
end

function reverse(config::ConfigWidth{1}, func::Const{typeof(Base.truncbits)}, ::Type{<:Union{Const,Active}}, dret::Union{Active,Const},
                 x::Union{Const,Active}, mask)
    println("In custom reverse rule.")
    return (one(x.val), nothing)
end

function reverse(config::ConfigWidth{1}, func::Const{typeof(Base.truncbits)}, ::Type{<:Active}, dret::Nothing,
                 x::Active, mask)
    println("In custom reverse rule.")
    return (one(x.val), nothing)
end

function reverse(config::ConfigWidth{1}, func::Const{typeof(Base.truncbits)}, ::Type{<:Const}, dret::Nothing,
                 x::Const, mask)
    println("In custom reverse rule.")
    return (nothing, nothing)
end

function reverse(config::ConfigWidth{1}, func::Const{typeof(Base.truncbits)}, ::Active, dret::Nothing,
                 x::Active, mask)
    println("In custom reverse rule.")
    return (one(x.val), nothing)
end

function reverse(config::ConfigWidth{1}, func::Const{typeof(Base.truncbits)}, ::Const, dret::Nothing,
                 x::Const, mask)
    println("In custom reverse rule.")
    return (nothing, nothing)
end

function f(x)
    ts = Array(0.0:x:3.0)
    sum(ts)
end

f(0.25)
Enzyme.autodiff(Forward, f, Duplicated(0.25, 1.0))
Enzyme.autodiff(Reverse, f, Active, Active(0.25))
```

To me that seems like a bug in the activity detection, possibly caused by some of the reinterprets back from `UInt` representations. I'll see if I can fix this with a rule targeted a bit higher.
|
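For context on the reinterpret concern: a function like `truncbits` moves a float into its `UInt64` bit pattern, masks it, and moves it back. The helper below is a hypothetical illustration of that pattern, not Base's actual implementation. Every intermediate value between the two `reinterpret` calls is an integer, which is plausibly where an activity analysis can lose track of the fact that the result still depends on `x`:

```julia
# Hypothetical helper illustrating the bit-masking pattern (not Base's code):
# clear the `nb` lowest bits of the significand of a Float64.
function mask_low_bits(x::Float64, nb::Integer)
    bits = reinterpret(UInt64, x)             # Float64 -> raw bit pattern
    mask = typemax(UInt64) << nb              # ones everywhere except the nb lowest bits
    return reinterpret(Float64, bits & mask)  # raw bits -> Float64
end

x = 0.1
y = mask_low_bits(x, 26)
println(y)                                # slightly smaller than 0.1
println((x - y) / x)                      # relative error below 2^-26
println(mask_low_bits(0.25, 26) == 0.25)  # true: 0.25 has no low significand bits set
```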
For reference, targeting it like this with other AD systems works well:

```julia
Base.div(x::Tracker.TrackedReal, y::Tracker.TrackedReal, r::RoundingMode) = div(Tracker.value(x), Tracker.value(y), r)
_y, back = Tracker.forward(f, 0.25)
back(1)
```
|
In my journey here, one level up:

```julia
function forward(func::Const{typeof(Base.steprangelen_hp)}, ::Type{<:Duplicated}, outtype::Const{Type{Float64}}, ref::Union{Const,Active}, step::Union{Const,Active}, nb, len, offset)
    println("Using custom rule!")
    ret = func.val(getval.((outtype, ref, step, nb, len, offset))...)
    @show outtype, ref, step, nb, len, offset
    start = ref isa Const ? zero(eltype(ret)) : one(eltype(ret))
    dstep = step isa Const ? zero(eltype(ret)) : one(eltype(ret))
    return Duplicated(ret, StepRangeLen(Base.TwicePrecision(start), Base.TwicePrecision(dstep), length(ret)))
end
```

All of the values are `Int` at this level, though, so again Enzyme keeps them const. I think I know the right target now: it has to be the `floatrange` method:

```julia
# Construct range for rational start=start_n/den, step=step_n/den
function floatrange(::Type{T}, start_n::Integer, step_n::Integer, len::Integer, den::Integer) where T
    len = len + 0 # promote with Int
    if len < 2 || step_n == 0
        return steprangelen_hp(T, (start_n, den), (step_n, den), 0, len, oneunit(len))
    end
    # index of smallest-magnitude value
    L = typeof(len)
    imin = clamp(round(typeof(len), -start_n/step_n+1), oneunit(L), len)
    # Compute smallest-magnitude element to 2x precision
    ref_n = start_n+(imin-1)*step_n # this shouldn't overflow, so don't check
    nb = nbitslen(T, len, imin)
    @show steprangelen_hp(T, (ref_n, den), (step_n, den), nb, len, imin)
end
```

so everything below that is already treated as const. |
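The rational detour in `floatrange` is what lets float ranges land on their end points exactly, and it is also why everything below it is integers (`start_n`, `step_n`, `den`). A quick Base-only check comparing a range against naive multiplication of the `Float64` step:

```julia
r = 0.0:0.1:0.3

# The range is built from the rationals 0/10 and 1/10, so each element
# is the correctly rounded value of k/10.
println(collect(r))              # [0.0, 0.1, 0.2, 0.3]

# Naively scaling the Float64 step accumulates representation error.
println([k * 0.1 for k in 0:3])  # [0.0, 0.1, 0.2, 0.30000000000000004]
```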
Okay, moving a level higher, I got a forward rule to work:

```julia
using Enzyme, ReverseDiff, Tracker
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

getval(x) = hasproperty(x, :val) ? x.val : x

function forward(func::Const{Colon}, ::Type{<:Duplicated}, start::Union{Const, Active}, step::Union{Const, Active}, stop::Union{Const, Active})
    ret = func.val(getval.((start, step, stop))...)
    dstart = start isa Const ? zero(eltype(ret)) : one(eltype(ret))
    dstep = step isa Const ? zero(eltype(ret)) : one(eltype(ret))
    return Duplicated(ret, range(dstart, step=dstep, length=length(ret)))
end

function f1(x)
    ts = Array(0.0:x:3.0)
    sum(ts)
end

function f2(x)
    ts = Array(0.0:.25:3.0)
    sum(ts) + x
end

function f3(x)
    ts = Array(x:.25:3.0)
    sum(ts)
end

function f4(x)
    ts = Array(0.0:.25:x)
    sum(ts)
end

f1(0.25)
Enzyme.autodiff(Forward, f1, Duplicated(0.25, 1.0)) == (78,)
Enzyme.autodiff(Forward, f2, Duplicated(0.25, 1.0)) == (1.0,)
Enzyme.autodiff(Forward, f3, Duplicated(0.25, 1.0)) == (12,)
Enzyme.autodiff(Forward, f4, Duplicated(3.0, 1.0)) == (0,)

using ForwardDiff
ForwardDiff.derivative(f1, 0.25)
ForwardDiff.derivative(f2, 0.25)
ForwardDiff.derivative(f3, 0.25)
ForwardDiff.derivative(f4, 3.0)
```

🎉 |
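Why `range(dstart, step=dstep, length=length(ret))` is the right tangent: with the length held fixed, element `i` is `start + (i - 1) * step`, so its derivative is `dstart + (i - 1) * dstep`, which is again a range of the same length. Note that the rule above uses `one(...)` in place of the incoming tangents, which matches unit seeds such as `Duplicated(0.25, 1.0)`. The Base-only sketch below (the name `tangent_points` is introduced here) recomputes the forward-mode results for `f1` and `f3`:

```julia
# Tangent of each element of a fixed-length range start + (i - 1) * step,
# given tangents dstart and dstep for start and step.
tangent_points(dstart, dstep, n) = [dstart + (i - 1) * dstep for i in 1:n]

# f1: start = 0.0 is constant, step = x is active (seed 1.0), 13 points
t1 = tangent_points(0.0, 1.0, 13)
println(sum(t1))  # 78.0, matching the result for f1

# f3: start = x is active (seed 1.0), step = 0.25 is constant, 12 points
t3 = tangent_points(1.0, 0.0, 12)
println(sum(t3))  # 12.0, matching the result for f3

# Same values as the range form used in the rule
println(collect(range(0.0; step = 1.0, length = 13)) == t1)  # true
```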
For the reverse mode, I need to figure out how to make Enzyme run a custom piece of code. The underlying problem is that malformed ranges are lossy:

```julia
julia> 10:1:1
10:1:9
```

The reason this comes up is that Enzyme does a very naive construction of the `dret` range:

```julia
using Enzyme, ReverseDiff, Tracker
import .EnzymeRules: reverse, augmented_primal
using .EnzymeRules

function augmented_primal(config::ConfigWidth{1}, func::Const{Colon}, ::Type{<:Active},
                          start, step, stop)
    println("In custom augmented primal rule.")
    # Compute primal
    if needs_primal(config)
        primal = func.val(start.val, step.val, stop.val)
    else
        primal = nothing
    end
    return AugmentedReturn(primal, nothing, nothing)
end

function reverse(config::ConfigWidth{1}, func::Const{Colon}, dret, tape::Nothing,
                 start, step, stop)
    println("In custom reverse rule.")
    _dret = @show dret.val
    #fixedreverse = if _dret.start > _dret.stop && _dret.step > 0
    #    _dret.stop:_dret.step:_dret.start
    #else
    #    _dret
    #end
    dstart = start isa Const ? nothing : one(eltype(dret.val))
    dstep = step isa Const ? nothing : one(eltype(dret.val))
    dstop = stop isa Const ? nothing : zero(eltype(dret.val))
    return (dstart, dstep, dstop)
end

Enzyme.autodiff(Reverse, f1, Active, Active(0.25))
Enzyme.autodiff(Reverse, f2, Active, Active(0.25)) == ((1.0,),)
Enzyme.autodiff(Reverse, f3, Active, Active(0.25))
Enzyme.autodiff(Reverse, f4, Active, Active(0.25)) == ((0.0,),)

using ForwardDiff
ForwardDiff.derivative(f1, 0.25)
ForwardDiff.derivative(f2, 0.25)
ForwardDiff.derivative(f3, 0.25)
ForwardDiff.derivative(f4, 3.0)
```

That's the set of test cases. For the first one you can see it:

```julia
julia> Enzyme.autodiff(Reverse, f1, Active, Active(0.25))
In custom augmented primal rule.
(ref, step, len, offset) = (Base.TwicePrecision{Float64}(0.0, 0.0), Base.TwicePrecision{Float64}(0.25, 0.0), 13, 1)
primal = 0.0:0.25:3.0
In custom reverse rule.
dret.val = 182.0:156.0:26.0
((1.0,),)
```

Basically, just as `10:1:1` becomes `10:1:9`, the 26 is simply not the true value. I can patch Julia to not be lossy here by Revise-ing this in `Base/twiceprecision.jl`:

```julia
function (:)(start::T, step::T, stop::T) where T<:IEEEFloat
    step == 0 && throw(ArgumentError("range step cannot be zero"))
    # see if the inputs have exact rational approximations (and if so,
    # perform all computations in terms of the rationals)
    step_n, step_d = rat(step)
    if step_d != 0 && T(step_n/step_d) == step
        start_n, start_d = rat(start)
        stop_n, stop_d = rat(stop)
        if start_d != 0 && stop_d != 0 &&
                T(start_n/start_d) == start && T(stop_n/stop_d) == stop
            den = lcm_unchecked(start_d, step_d) # use same denominator for start and step
            m = maxintfloat(T, Int)
            if den != 0 && abs(start*den) <= m && abs(step*den) <= m &&  # will round succeed?
                    rem(den, start_d) == 0 && rem(den, step_d) == 0      # check lcm overflow
                start_n = round(Int, start*den)
                step_n = round(Int, step*den)
                len = max(0, Int(div(den*stop_n - stop_d*start_n + step_n*stop_d, step_n*stop_d)))
                # Integer ops could overflow, so check that this makes sense
                if isbetween(start, start + (len-1)*step, stop + step/2) &&
                        !isbetween(start, start + len*step, stop)
                    # Return a 2x precision range
                    return floatrange(T, start_n, step_n, len, den)
                end
            end
        end
    end
    # Fallback, taking start and step literally
    # n.b. we use Int as the default length type for IEEEFloats
    lf = (stop-start)/step
    if lf < 0
        len = 0
    elseif lf == 0
        len = 1
    else
        len = round(Int, lf) + 1
        stop′ = start + (len-1)*step
        # if we've overshot the end, subtract one:
        len -= (start < stop < stop′) + (start > stop > stop′)
    end
    steprangelen_hp(T, start, step, 0, len, 1)
end
```

becomes:

```julia
function (:)(start::T, step::T, stop::T) where T<:IEEEFloat
    step == 0 && throw(ArgumentError("range step cannot be zero"))
    # see if the inputs have exact rational approximations (and if so,
    # perform all computations in terms of the rationals)
    step_n, step_d = rat(step)
    if step_d != 0 && T(step_n/step_d) == step
        start_n, start_d = rat(start)
        stop_n, stop_d = rat(stop)
        if start_d != 0 && stop_d != 0 &&
                T(start_n/start_d) == start && T(stop_n/stop_d) == stop
            den = lcm_unchecked(start_d, step_d) # use same denominator for start and step
            m = maxintfloat(T, Int)
            if den != 0 && abs(start*den) <= m && abs(step*den) <= m &&  # will round succeed?
                    rem(den, start_d) == 0 && rem(den, step_d) == 0      # check lcm overflow
                start_n = round(Int, start*den)
                step_n = round(Int, step*den)
                len = max(0, Int(div(den*stop_n - stop_d*start_n + step_n*stop_d, step_n*stop_d)))
                # Integer ops could overflow, so check that this makes sense
                if isbetween(start, start + (len-1)*step, stop + step/2) &&
                        !isbetween(start, start + len*step, stop)
                    # Return a 2x precision range
                    return floatrange(T, start_n, step_n, len, den)
                end
            end
        end
    end
    # Fallback, taking start and step literally
    # n.b. we use Int as the default length type for IEEEFloats
    lf = (stop-start)/step
    #if lf < 0
    #    len = 0
    #elseif lf == 0
    #    len = 1
    #else
    len = round(Int, lf) + 1
    stop′ = start + (len-1)*step
    # if we've overshot the end, subtract one:
    len -= (start < stop < stop′) + (start > stop > stop′)
    #end
    if len < 0
        step = -step
        len = -len + 2
    end
    steprangelen_hp(T, start, step, 0, len, 1)
end
```

and with this patch:

```julia
julia> 10.0:1.0:1.0
(ref, step, len, offset) = (Base.TwicePrecision{Float64}(10.0, 0.0), Base.TwicePrecision{Float64}(-1.0, 0.0), 10, 1)
10.0:-1.0:1.0
```

So good, this lets me retain the information. However, when I Revise this in, Enzyme does not seem to use it in its construction of the `dret` range:

```
In custom augmented primal rule.
In custom reverse rule.
dret.val = 182.0:156.0:26.0
((1.0,),)
In custom augmented primal rule.
In custom reverse rule.
dret.val = 182.0:156.0:26.0
true
In custom augmented primal rule.
In custom reverse rule.
dret.val = 156.0:132.0:24.0
((1.0,),)
In custom augmented primal rule.
In custom reverse rule.
dret.val = 6.0:2.0:4.0
true
```

You can see it's still using some constructor that's forcing the malformed range to be lossy, and thus the rule cannot be written. So there are two paths here:
|
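The length formula in the rational branch of `(:)` is plain integer arithmetic: with `start = start_n/den`, `step = step_n/den` and `stop = stop_n/stop_d`, the count is `floor((stop - start)/step) + 1` with all denominators cleared. The Base-only sketch below reproduces it; `rational_len` is a hypothetical helper that takes exact `Rational`s instead of calling Base's internal `rat`, and uses `lcm` rather than the internal `lcm_unchecked`. It also shows that a malformed input like `10.0:1.0:1.0` ends up with length 0, which is why the printed stop value carries no information:

```julia
# Hypothetical helper: the length formula from the rational branch of
# Base's `(:)` for floats, fed with exact rationals.
function rational_len(start::Rational{Int}, step::Rational{Int}, stop::Rational{Int})
    den = lcm(denominator(start), denominator(step))  # shared denominator
    start_n = numerator(start) * div(den, denominator(start))
    step_n = numerator(step) * div(den, denominator(step))
    stop_n, stop_d = numerator(stop), denominator(stop)
    # floor((stop - start) / step) + 1 for a positive step, denominators cleared
    return max(0, div(den * stop_n - stop_d * start_n + step_n * stop_d, step_n * stop_d))
end

println(rational_len(0//1, 1//4, 3//1), " ", length(0.0:0.25:3.0))   # 13 13
println(rational_len(0//1, 1//10, 3//10), " ", length(0.0:0.1:0.3))  # 4 4
println(rational_len(10//1, 1//1, 1//1), " ", length(10.0:1.0:1.0))  # 0 0
```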
I'm not sure I understand/follow. Enzyme wouldn't construct the dval by calling this function, but by creating a zeroed tuple and `+=`-ing the values from the uses into it. |
@ChrisRackauckas do you want to move this to a PR to make it easier to comment? |
bump @ChrisRackauckas can you move this to a PR? |
This is part 1 of solving EnzymeAD#274. It does the forward-mode rules, as those are simpler. A separate PR will do the WIP reverse-mode rules, as that seems to be a bit more complex.
This is the second PR to fix EnzymeAD#274. It's separate because I think the forward-mode one can just be merged without problems, and this one may take a bit more time. The crux of why this one is hard is how Julia deals with malformed ranges. Basically, `dret.val = 182.0:156.0:26.0`, and the 26.0 is not the true value. It's the same as:

```julia
julia> 10:1:1
10:1:9
```

Because of that behavior, the reverse `dret` does not actually carry the information about its final point, and its length is "incorrect" because the constructor changes it. To "fix" the reverse, we'd want to flip the `step` to negative and then use the same start/stop, but that information is already lost, so it cannot be fixed within the rule. You can see the commented-out code that would do the fixing if the information were there; without it we cannot get a correctly sized reversed range for the rule. But it's a bit puzzling to figure out how to remove that behavior. In Base Julia it seems to be done in `function (:)(start::T, step::T, stop::T) where T<:IEEEFloat`, and as I showed in the issue, I can overload that function and the behavior goes away, but Enzyme's constructed range still has that truncation behavior, which means I missed a spot somewhere.
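For reference, this is what a reverse rule for a fixed-length range ultimately has to produce. Since element `i` is `start + (i - 1) * step`, the cotangents are `dstart = sum(dr)` and `dstep = sum((i - 1) * dr[i])`, where `dr` is the per-element cotangent of the range. The Base-only sketch below uses a hypothetical function name and takes the cotangent as a plain vector, not in whatever form Enzyme passes `dret`; it checks the formulas against the forward-mode results:

```julia
# Hypothetical pullback of r[i] = start + (i - 1) * step for a fixed length,
# given the per-element cotangent vector dr.
function range_pullback(dr::AbstractVector)
    dstart = sum(dr)
    dstep = sum((k - 1) * d for (k, d) in enumerate(dr))
    return dstart, dstep
end

# f1 = sum(0.0:x:3.0): the cotangent of sum is 1 for each of the 13 points,
# and x is the step, so df1/dx = dstep.
println(range_pullback(ones(13)))  # (13.0, 78.0)

# f3 = sum(x:0.25:3.0): 12 points, and x is the start, so df3/dx = dstart.
println(range_pullback(ones(12)))  # (12.0, 66.0)
```

The sketch needs the true length of the range, which is exactly the information described above as lost from `dret`.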
* Add internal forward-mode rules for ranges
  Add missing `@test`; don't forget the rule
* namespace
* Update internal_rules.jl
* Update internal_rules.jl
* Update src/internal_rules.jl
* Update internal_rules.jl
* Update internal_rules.jl

---------

Co-authored-by: William Moses <gh@wsmoses.com>
* WIP: Add internal reverse-mode rules for ranges
  namespace ConfigWidth; namespace; namespace needs_primal; namespace AugmentedReturn
* Complete implementation
* fix
* fix

---------

Co-authored-by: Billy Moses <wmoses@google.com>
cc: @boriskaus