Skip to content

Commit

Permalink
Add WithPrimal and NoPrimal function (#1898)
Browse files Browse the repository at this point in the history
* Add WithPrimal and NoPrimal function

* version bumps
  • Loading branch information
ptiede authored Sep 26, 2024
1 parent 9495698 commit f91eabb
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.13.3"
version = "0.13.4"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.8.2"
version = "0.8.3"

[compat]
Adapt = "3, 4"
Expand Down
23 changes: 23 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,21 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa
@inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}()
@inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}()

"""
WithPrimal(::Enzyme.Mode)
Modifies the mode to include the primal value.
"""
@inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()

"""
NoPrimal(::Enzyme.Mode)
Modifies the mode to exclude the primal value.
"""
@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()


"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity}
Expand All @@ -267,6 +282,10 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau
@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()

@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()


"""
struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity}
Expand All @@ -286,6 +305,10 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}()
@inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,rt}()
@inline clear_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,false}()

@inline WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}()
@inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}()


function autodiff end
function autodiff_deferred end
function autodiff_thunk end
Expand Down
6 changes: 5 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ import EnzymeCore:
set_abi,
set_runtime_activity,
clear_runtime_activity,
within_autodiff
within_autodiff,
WithPrimal,
NoPrimal
export Annotation,
Const,
Active,
Expand All @@ -63,6 +65,8 @@ export Annotation,
set_abi,
set_runtime_activity,
clear_runtime_activity,
WithPrimal,
NoPrimal,
within_autodiff

import EnzymeCore: BatchDuplicatedFunc
Expand Down
19 changes: 19 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4066,6 +4066,25 @@ end
@test res[2][6] 6.0
end

@testset "WithPrimal" begin
@test WithPrimal(Reverse) === ReverseWithPrimal
@test NoPrimal(Reverse) === Reverse
@test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal
@test NoPrimal(ReverseWithPrimal) === Reverse

@test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal)

@test WithPrimal(Forward) === ForwardWithPrimal
@test NoPrimal(Forward) === Forward
@test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal
@test NoPrimal(ForwardWithPrimal) === Forward

@test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal
@test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal
end

# TEST EXTENSIONS
using SpecialFunctions
@testset "SpecialFunctions ext" begin
Expand Down

0 comments on commit f91eabb

Please sign in to comment.