From f91eabb764d6d0e0d24b6e929a2aa0ffc86aec9b Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Thu, 26 Sep 2024 17:16:26 -0400 Subject: [PATCH] Add WithPrimal and NoPrimal function (#1898) * Add WithPrimal and NoPrimal function * version bumps --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 23 +++++++++++++++++++++++ src/Enzyme.jl | 6 +++++- test/runtests.jl | 19 +++++++++++++++++++ 5 files changed, 49 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index f2b99062a0..4cffef367c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.3" +version = "0.13.4" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 37ddaf6457..3a871b930c 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.2" +version = "0.8.3" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f51c742f5d..3231674de5 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -244,6 +244,21 @@ const ReverseHolomorphicWithPrimal = ReverseMode{true,false,DefaultABI, true, fa @inline set_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,rt,ABI,Holomorphic,ErrIfFuncWritten}() @inline clear_runtime_activity(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{ReturnPrimal,false,ABI,Holomorphic,ErrIfFuncWritten}() +""" + WithPrimal(::Enzyme.Mode) + +Modifies the mode to include the primal value. +""" +@inline WithPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{true,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() + +""" + NoPrimal(::Enzyme.Mode) + +Modifies the mode to exclude the primal value. +""" +@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() + + """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI} <: Mode{ABI,ErrIfFuncWritten,RuntimeActivity} @@ -267,6 +282,10 @@ const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,Defau @inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() @inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() + + """ struct Forward{ReturnPrimal, ABI, ErrIfFuncWritten,RuntimeActivity} <: Mode{ABI, ErrIfFuncWritten, RuntimeActivity} @@ -286,6 +305,10 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline set_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, rt::Bool) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,rt}() @inline clear_runtime_activity(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,false}() +@inline WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}() +@inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() + + function autodiff end function autodiff_deferred end function autodiff_thunk end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c99114e038..b49c3738f6 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -46,7 +46,9 @@ import EnzymeCore: set_abi, set_runtime_activity, clear_runtime_activity, - within_autodiff + within_autodiff, + WithPrimal, + NoPrimal export Annotation, Const, Active, @@ -63,6 +65,8 @@ export Annotation, set_abi, set_runtime_activity, clear_runtime_activity, + WithPrimal, + NoPrimal, within_autodiff import EnzymeCore: BatchDuplicatedFunc diff --git a/test/runtests.jl b/test/runtests.jl index 69e6d51cd5..d499febd77 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4066,6 +4066,25 @@ end @test res[2][6] ≈ 6.0 end +@testset "WithPrimal" begin + @test WithPrimal(Reverse) === ReverseWithPrimal + @test NoPrimal(Reverse) === Reverse + @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal + @test NoPrimal(ReverseWithPrimal) === Reverse + + @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) + + @test WithPrimal(Forward) === ForwardWithPrimal + @test NoPrimal(Forward) === Forward + @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal + @test NoPrimal(ForwardWithPrimal) === Forward + + @test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal + @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal +end + # TEST EXTENSIONS using SpecialFunctions @testset "SpecialFunctions ext" begin