Skip to content

Commit

Permalink
WIP: Add internal reverse-mode rules for ranges
Browse files Browse the repository at this point in the history
This is the second PR to fix EnzymeAD#274. It's separated as I think the forward mode one can just be merged no problem, and this one may take a little bit more time.

The crux of why this one is hard is because of how Julia deals with malformed ranges.

```
Basically dret.val = 182.0:156.0:26.0, the 26.0 is not the true value. Same as

julia> 10:1:1
10:1:9
```

Because of that behavior, the reverse `dret` does not actually have the information as to what its final point is, and its length is "incorrect" as it's changed by the constructor. In order to "fix" the reverse, we'd want to swap the `step` to negative and then use the same start/stop, but that information is already lost so it cannot be fixed within the rule. You can see the commented out code that would do the fixing if the information is there, and without that we cannot get a correctly sized reversed range for the rule.

But it's a bit puzzling to figure out how to remove that behavior. In Base Julia it seems to be done in the `function (:)(start::T, step::T, stop::T) where T<:IEEEFloat`, and as I showed in the issue, I can overload that function and the behavior goes away, but Enzyme's constructed range still has that truncation behavior, which means I missed spot or something.

namespace ConfigWidth

namespace

namespace needs_primal

namespace AugmentedReturn
  • Loading branch information
ChrisRackauckas authored and wsmoses committed Aug 26, 2024
1 parent 32b7aa2 commit 95b2972
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -922,3 +922,28 @@ function EnzymeRules.reverse(
)
return ()
end

function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, func::Const{Colon}, ::Type{<:Active},
start, step ,stop)

if EnzymeRules.needs_primal(config)
primal = func.val(start.val, step.val, stop.val)
else
primal = nothing
end
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, func::Const{Colon}, dret, tape::Nothing,
start, step, stop)

#fixedreverse = if _dret.start > _dret.stop && _dret.step > 0
# _dret.stop:_dret.step:_dret.start
#else
# _dret
#end
dstart = start isa Const ? nothing : one(eltype(dret.val))
dstep = step isa Const ? nothing : one(eltype(dret.val))
dstop = stop isa Const ? nothing : zero(eltype(dret.val))
return (dstart, dstep, dstop)
end
5 changes: 5 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,11 @@ end
((var"1"=75.0, var"2"=150.0),)
@test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) ==
((var"1"=0.0, var"2"=0.0),)

@test Enzyme.autodiff(Reverse, f1, Active, Active(0.25)) == ((78,),)
@test Enzyme.autodiff(Reverse, f2, Active, Active(0.25)) == ((1.0,),)
@test Enzyme.autodiff(Reverse, f3, Active, Active(0.25)) == ((12,),)
@test Enzyme.autodiff(Reverse, f4, Active, Active(0.25)) == ((0.0,),)
end

end # InternalRules

0 comments on commit 95b2972

Please sign in to comment.