diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 0199d7077c..d9b2193ce8 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -922,3 +922,28 @@ function EnzymeRules.reverse( ) return () end + +function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, func::Const{Colon}, ::Type{<:Active}, + start, step ,stop) + + if EnzymeRules.needs_primal(config) + primal = func.val(start.val, step.val, stop.val) + else + primal = nothing + end + return EnzymeRules.AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, func::Const{Colon}, dret, tape::Nothing, + start, step, stop) + + #fixedreverse = if _dret.start > _dret.stop && _dret.step > 0 + # _dret.stop:_dret.step:_dret.start + #else + # _dret + #end + dstart = start isa Const ? nothing : one(eltype(dret.val)) + dstep = step isa Const ? nothing : one(eltype(dret.val)) + dstop = stop isa Const ? nothing : zero(eltype(dret.val)) + return (dstart, dstep, dstop) +end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 659d5dee98..b2d851677a 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -668,6 +668,11 @@ end ((var"1"=75.0, var"2"=150.0),) @test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) == ((var"1"=0.0, var"2"=0.0),) + + @test Enzyme.autodiff(Reverse, f1, Active, Active(0.25)) == ((78,),) + @test Enzyme.autodiff(Reverse, f2, Active, Active(0.25)) == ((1.0,),) + @test Enzyme.autodiff(Reverse, f3, Active, Active(0.25)) == ((12,),) + @test Enzyme.autodiff(Reverse, f4, Active, Active(0.25)) == ((0.0,),) end end # InternalRules