Skip to content

Commit

Permalink
Add internal forward-mode rules for ranges
Browse files Browse the repository at this point in the history
This is part 1 one solving EnzymeAD#274. It does the forward mode rules as those are simpler. A separate PR will do the WIP reverse mode rules as that seems to be a bit more complex.

Add missing `@test`

don't forget the rule
  • Loading branch information
ChrisRackauckas committed Jul 21, 2024
1 parent 1e15769 commit 9e976d4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -815,3 +815,17 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)},
end
end
end

# Ranges
# Float64 ranges in Julia use bitwise `&` with higher precision
# to correct for numerical error, thus we put rules over the
# operations as this is not directly differentiable

getval(x) = hasproperty(x, :val) ? x.val : x
function forward(func::Const{Colon}, ::Type{<:Duplicated}, start::Union{Const, Active}, step::Union{Const, Active}, stop::Union{Const, Active})
ret = func.val(getval.((start, step, stop))...)
dstart = start isa Const ? zero(eltype(ret)) : one(eltype(ret))
dstep = step isa Const ? zero(eltype(ret)) : one(eltype(ret))

return Duplicated(ret, range(dstart, step=dstep, length=length(ret)))
end
23 changes: 23 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -618,4 +618,27 @@ end
@test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),)
end

@testset "Ranges" begin
function f1(x)
ts = Array(0.0:x:3.0)
sum(ts)
end
function f2(x)
ts = Array(0.0:.25:3.0)
sum(ts) + x
end
function f3(x)
ts = Array(x:.25:3.0)
sum(ts)
end
function f4(x)
ts = Array(0.0:.25:x)
sum(ts)
end
@test Enzyme.autodiff(Forward, f1, Duplicated(0.25, 1.0)) == (78,)
@test Enzyme.autodiff(Forward, f2, Duplicated(0.25, 1.0)) == (1.0,)
@test Enzyme.autodiff(Forward, f3, Duplicated(0.25, 1.0)) == (12,)
@test Enzyme.autodiff(Forward, f4, Duplicated(3.0, 1.0)) == (0,)
end

end # InternalRules

0 comments on commit 9e976d4

Please sign in to comment.