From 26ca6fe82e64aaca34d85554b8d60de24360632d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Wed, 30 Oct 2024 22:17:48 +0100 Subject: [PATCH] Import `needs_primal` from Reactant.jl (#2021) * Implement `needs_primal` * Test `needs_primal` * Export `needs_primal` * Bump versions * Move `needs_primal` tests EnzymeCore * Use tilde version specifier to also support v0.8.4 * Move `WithPrimal`, `NoPrimal` tests to EnzymeCore --------- Co-authored-by: William Moses --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 14 ++++++++++++++ lib/EnzymeCore/test/runtests.jl | 28 ++++++++++++++++++++++++++++ test/runtests.jl | 19 ------------------- 5 files changed, 44 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index fd1866fa56..ed748bdc77 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.4" +EnzymeCore = "0.8.4, 0.8.5" Enzyme_jll = "0.0.158" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 2e45d2c2f6..18c3bbad00 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.4" +version = "0.8.5" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 394cd00a5f..c751aaac38 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -7,6 +7,7 @@ export MixedDuplicated, BatchMixedDuplicated export DefaultABI, FFIABI, InlineABI, NonGenABI export BatchDuplicatedFunc export within_autodiff +export needs_primal function batch_size end @@ -351,6 +352,15 @@ Return a new mode which excludes the primal value. """ @inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}() +""" + needs_primal(::Mode) + needs_primal(::Type{Mode}) + +Returns `true` if the mode needs the primal value, otherwise `false`. +""" +@inline needs_primal(::ReverseMode{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal +@inline needs_primal(::Type{<:ReverseMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal + """ struct ReverseModeSplit{ ReturnPrimal, @@ -424,6 +434,8 @@ Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to @inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() @inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline needs_primal(::ReverseModeSplit{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal +@inline needs_primal(::Type{<:ReverseModeSplit{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal """ struct ForwardMode{ @@ -480,6 +492,8 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}() @inline WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}() @inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}() +@inline needs_primal(::ForwardMode{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal +@inline needs_primal(::Type{<:ForwardMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal function autodiff end function autodiff_deferred end diff --git a/lib/EnzymeCore/test/runtests.jl b/lib/EnzymeCore/test/runtests.jl index 114e7d7157..61f0e7af5c 100644 --- a/lib/EnzymeCore/test/runtests.jl +++ b/lib/EnzymeCore/test/runtests.jl @@ -2,6 +2,34 @@ using Test using EnzymeCore @testset verbose = true "EnzymeCore" begin + @testset "WithPrimal" begin + @test WithPrimal(Reverse) === ReverseWithPrimal + @test NoPrimal(Reverse) === Reverse + @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal + @test NoPrimal(ReverseWithPrimal) === Reverse + + @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) + + @test WithPrimal(Forward) === ForwardWithPrimal + @test NoPrimal(Forward) === Forward + @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal + @test NoPrimal(ForwardWithPrimal) === Forward + + @test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal + @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal + @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal + end + + @testset "needs_primal" begin + @test needs_primal(Reverse) === false + @test needs_primal(ReverseWithPrimal) === true + @test needs_primal(Forward) === false + @test needs_primal(ForwardWithPrimal) === true + @test needs_primal(ReverseSplitNoPrimal) === false + @test needs_primal(ReverseSplitWithPrimal) === true + end + @testset "Miscellaneous" begin include("misc.jl") end diff --git a/test/runtests.jl b/test/runtests.jl index 777cb4a63d..d28461d26a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3615,25 +3615,6 @@ end @test res[2][6] ≈ 6.0 end -@testset "WithPrimal" begin - @test WithPrimal(Reverse) === ReverseWithPrimal - @test NoPrimal(Reverse) === Reverse - @test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal - @test NoPrimal(ReverseWithPrimal) === Reverse - - @test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal) - - @test WithPrimal(Forward) === ForwardWithPrimal - @test NoPrimal(Forward) === Forward - @test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal - @test NoPrimal(ForwardWithPrimal) === Forward - - @test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal - @test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal - @test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal - @test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal -end - # TEST EXTENSIONS using SpecialFunctions @testset "SpecialFunctions ext" begin