Skip to content

Commit

Permalink
Import needs_primal from Reactant.jl (#2021)
Browse files Browse the repository at this point in the history
* Implement `needs_primal`

* Test `needs_primal`

* Export `needs_primal`

* Bump versions

* Move `needs_primal` tests EnzymeCore

* Use tilde version specifier to also support v0.8.4

* Move `WithPrimal`, `NoPrimal` tests to EnzymeCore

---------

Co-authored-by: William Moses <gh@wsmoses.com>
  • Loading branch information
mofeing and wsmoses authored Oct 30, 2024
1 parent 3728b0c commit 26ca6fe
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays"
BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.4"
EnzymeCore = "0.8.4, 0.8.5"
Enzyme_jll = "0.0.158"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.8.4"
version = "0.8.5"

[compat]
Adapt = "3, 4"
Expand Down
14 changes: 14 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export MixedDuplicated, BatchMixedDuplicated
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff
export needs_primal

function batch_size end

Expand Down Expand Up @@ -351,6 +352,15 @@ Return a new mode which excludes the primal value.
"""
@inline NoPrimal(::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}) where {ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten} = ReverseMode{false,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}()

"""
needs_primal(::Mode)
needs_primal(::Type{Mode})
Returns `true` if the mode needs the primal value, otherwise `false`.
"""
@inline needs_primal(::ReverseMode{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal
@inline needs_primal(::Type{<:ReverseMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal

"""
struct ReverseModeSplit{
ReturnPrimal,
Expand Down Expand Up @@ -424,6 +434,8 @@ Return a new instance of [`ReverseModeSplit`](@ref) mode where `Width` is set to
@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()
@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}()

@inline needs_primal(::ReverseModeSplit{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal
@inline needs_primal(::Type{<:ReverseModeSplit{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal

"""
struct ForwardMode{
Expand Down Expand Up @@ -480,6 +492,8 @@ const ForwardWithPrimal = ForwardMode{true, DefaultABI, false, false}()
@inline WithPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{true,ABI,ErrIfFuncWritten,RuntimeActivity}()
@inline NoPrimal(::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}) where {ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity} = ForwardMode{false,ABI,ErrIfFuncWritten,RuntimeActivity}()

@inline needs_primal(::ForwardMode{ReturnPrimal}) where {ReturnPrimal} = ReturnPrimal
@inline needs_primal(::Type{<:ForwardMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal

function autodiff end
function autodiff_deferred end
Expand Down
28 changes: 28 additions & 0 deletions lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,34 @@ using Test
using EnzymeCore

@testset verbose = true "EnzymeCore" begin
@testset "WithPrimal" begin
@test WithPrimal(Reverse) === ReverseWithPrimal
@test NoPrimal(Reverse) === Reverse
@test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal
@test NoPrimal(ReverseWithPrimal) === Reverse

@test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal)

@test WithPrimal(Forward) === ForwardWithPrimal
@test NoPrimal(Forward) === Forward
@test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal
@test NoPrimal(ForwardWithPrimal) === Forward

@test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal
@test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal
end

@testset "needs_primal" begin
@test needs_primal(Reverse) === false
@test needs_primal(ReverseWithPrimal) === true
@test needs_primal(Forward) === false
@test needs_primal(ForwardWithPrimal) === true
@test needs_primal(ReverseSplitNoPrimal) === false
@test needs_primal(ReverseSplitWithPrimal) === true
end

@testset "Miscellaneous" begin
include("misc.jl")
end
Expand Down
19 changes: 0 additions & 19 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3615,25 +3615,6 @@ end
@test res[2][6] 6.0
end

@testset "WithPrimal" begin
@test WithPrimal(Reverse) === ReverseWithPrimal
@test NoPrimal(Reverse) === Reverse
@test WithPrimal(ReverseWithPrimal) === ReverseWithPrimal
@test NoPrimal(ReverseWithPrimal) === Reverse

@test WithPrimal(set_runtime_activity(Reverse)) === set_runtime_activity(ReverseWithPrimal)

@test WithPrimal(Forward) === ForwardWithPrimal
@test NoPrimal(Forward) === Forward
@test WithPrimal(ForwardWithPrimal) === ForwardWithPrimal
@test NoPrimal(ForwardWithPrimal) === Forward

@test WithPrimal(ReverseSplitNoPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitNoPrimal) === ReverseSplitNoPrimal
@test WithPrimal(ReverseSplitWithPrimal) === ReverseSplitWithPrimal
@test NoPrimal(ReverseSplitWithPrimal) === ReverseSplitNoPrimal
end

# TEST EXTENSIONS
using SpecialFunctions
@testset "SpecialFunctions ext" begin
Expand Down

4 comments on commit 26ca6fe

@wsmoses
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir="lib/EnzymeTestUtils"

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118417

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a EnzymeTestUtils-v0.2.2 -m "<description of version>" 26ca6fe82e64aaca34d85554b8d60de24360632d
git push origin EnzymeTestUtils-v0.2.2

Also, note the warning: Version 0.2.2 skips over 0.2.1
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@wsmoses
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir="lib/EnzymeCore"

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118418

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a EnzymeCore-v0.8.5 -m "<description of version>" 26ca6fe82e64aaca34d85554b8d60de24360632d
git push origin EnzymeCore-v0.8.5

Please sign in to comment.