Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WithPrimal and NoPrimal function #1898

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.13.3"
version = "0.13.4"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.8.2"
version = "0.8.3"

[compat]
Adapt = "3, 4"
Expand Down
23 changes: 23 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,21 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa
@inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}()
@inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}()

"""
WithPrimal(::Enzyme.Mode)

Modifies the mode to include the primal value.
"""
@inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()

"""
NoPrimal(::Enzyme.Mode)

Modifies the mode to exclude the primal value.
"""
@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()


"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity}

Expand All @@ -267,6 +282,10 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau
@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()

@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make these lower case like the other setters?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we also have some setters uppercase at the moment as well

@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}()

I kind of like it uppercase like this since then WithPrimal(Forward) == ForwardWithPrimal and so on.

But yeah it is a shame it's not consistent (but it's already inconsnstent =/)

Copy link
Contributor Author

@ptiede ptiede Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with either since we are already inconsistent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just going to go for this for now then. We can try to standardize more on the next bump

@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}()


"""
struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity}

Expand All @@ -286,6 +305,10 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}()
@inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,rt}()
@inline clear_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,false}()

@inline WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}()
@inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}()


function autodiff end
function autodiff_deferred end
function autodiff_thunk end
Expand Down
6 changes: 5 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ import EnzymeCore:
set_abi,
set_runtime_activity,
clear_runtime_activity,
within_autodiff
within_autodiff,
WithPrimal,
NoPrimal
export Annotation,
Const,
Active,
Expand All @@ -63,6 +65,8 @@ export Annotation,
set_abi,
set_runtime_activity,
clear_runtime_activity,
WithPrimal,
NoPrimal,
within_autodiff

import EnzymeCore: BatchDuplicatedFunc
Expand Down
19 changes: 19 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4066,6 +4066,25 @@ end
@test res[2][6] ≈ 6.0
end

@testset "WithPrimal" begin
@test WithPrimal(Reverse) === ReverseWithPrimal
@test NoPrimal(Reverse) === Reverse
@test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal
@test NoPrimal(ReverseWithPrimal) === Reverse

@test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal)

@test WithPrimal(Forward) === ForwardWithPrimal
@test NoPrimal(Forward) === Forward
@test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal
@test NoPrimal(ForwardWithPrimal) === Forward

@test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal
@test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal
end

# TEST EXTENSIONS
using SpecialFunctions
@testset "SpecialFunctions ext" begin
Expand Down
Loading