diff --git a/src/internal_rules.jl b/src/internal_rules.jl index d874dd5380..03e5b6ee3f 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -815,3 +815,17 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end end end + +# Ranges +# Float64 ranges in Julia use bitwise `&` with higher precision +# to correct for numerical error, thus we put rules over the +# operations as this is not directly differentiable + +getval(x) = hasproperty(x, :val) ? x.val : x +function forward(func::Const{Colon}, ::Type{<:Duplicated}, start::Union{Const, Active}, step::Union{Const, Active}, stop::Union{Const, Active}) + ret = func.val(getval.((start, step, stop))...) + dstart = start isa Const ? zero(eltype(ret)) : one(eltype(ret)) + dstep = step isa Const ? zero(eltype(ret)) : one(eltype(ret)) + + return Duplicated(ret, range(dstart, step=dstep, length=length(ret))) +end \ No newline at end of file diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 7cc5c07321..25c7ab2838 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -618,4 +618,27 @@ end @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) end +@testset "Ranges" begin + function f1(x) + ts = Array(0.0:x:3.0) + sum(ts) + end + function f2(x) + ts = Array(0.0:.25:3.0) + sum(ts) + x + end + function f3(x) + ts = Array(x:.25:3.0) + sum(ts) + end + function f4(x) + ts = Array(0.0:.25:x) + sum(ts) + end + @test Enzyme.autodiff(Forward, f1, Duplicated(0.25, 1.0)) == (78,) + @test Enzyme.autodiff(Forward, f2, Duplicated(0.25, 1.0)) == (1.0,) + @test Enzyme.autodiff(Forward, f3, Duplicated(0.25, 1.0)) == (12,) + @test Enzyme.autodiff(Forward, f4, Duplicated(3.0, 1.0)) == (0,) +end + end # InternalRules